Simple and powerful ORM for Go, supporting MySQL, PostgreSQL, TiDB, SQLite3, MSSQL and Oracle. https://xorm.io
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

644 lines
16 KiB

  1. // Copyright 2016 The Xorm Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package xorm
  5. import (
  6. "errors"
  7. "fmt"
  8. "reflect"
  9. "sort"
  10. "strconv"
  11. "strings"
  12. "xorm.io/xorm/internal/utils"
  13. "xorm.io/xorm/schemas"
  14. )
  15. // ErrNoElementsOnSlice represents an error there is no element when insert
  16. var ErrNoElementsOnSlice = errors.New("No element on slice when insert")
  17. // Insert insert one or more beans
  18. func (session *Session) Insert(beans ...interface{}) (int64, error) {
  19. var affected int64
  20. var err error
  21. if session.isAutoClose {
  22. defer session.Close()
  23. }
  24. session.autoResetStatement = false
  25. defer func() {
  26. session.autoResetStatement = true
  27. session.resetStatement()
  28. }()
  29. for _, bean := range beans {
  30. switch bean.(type) {
  31. case map[string]interface{}:
  32. cnt, err := session.insertMapInterface(bean.(map[string]interface{}))
  33. if err != nil {
  34. return affected, err
  35. }
  36. affected += cnt
  37. case []map[string]interface{}:
  38. s := bean.([]map[string]interface{})
  39. for i := 0; i < len(s); i++ {
  40. cnt, err := session.insertMapInterface(s[i])
  41. if err != nil {
  42. return affected, err
  43. }
  44. affected += cnt
  45. }
  46. case map[string]string:
  47. cnt, err := session.insertMapString(bean.(map[string]string))
  48. if err != nil {
  49. return affected, err
  50. }
  51. affected += cnt
  52. case []map[string]string:
  53. s := bean.([]map[string]string)
  54. for i := 0; i < len(s); i++ {
  55. cnt, err := session.insertMapString(s[i])
  56. if err != nil {
  57. return affected, err
  58. }
  59. affected += cnt
  60. }
  61. default:
  62. sliceValue := reflect.Indirect(reflect.ValueOf(bean))
  63. if sliceValue.Kind() == reflect.Slice {
  64. size := sliceValue.Len()
  65. if size <= 0 {
  66. return 0, ErrNoElementsOnSlice
  67. }
  68. cnt, err := session.innerInsertMulti(bean)
  69. if err != nil {
  70. return affected, err
  71. }
  72. affected += cnt
  73. } else {
  74. cnt, err := session.innerInsert(bean)
  75. if err != nil {
  76. return affected, err
  77. }
  78. affected += cnt
  79. }
  80. }
  81. }
  82. return affected, err
  83. }
  84. func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error) {
  85. sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
  86. if sliceValue.Kind() != reflect.Slice {
  87. return 0, errors.New("needs a pointer to a slice")
  88. }
  89. if sliceValue.Len() <= 0 {
  90. return 0, errors.New("could not insert a empty slice")
  91. }
  92. if err := session.statement.SetRefBean(sliceValue.Index(0).Interface()); err != nil {
  93. return 0, err
  94. }
  95. tableName := session.statement.TableName()
  96. if len(tableName) <= 0 {
  97. return 0, ErrTableNotFound
  98. }
  99. var (
  100. table = session.statement.RefTable
  101. size = sliceValue.Len()
  102. colNames []string
  103. colMultiPlaces []string
  104. args []interface{}
  105. cols []*schemas.Column
  106. )
  107. for i := 0; i < size; i++ {
  108. v := sliceValue.Index(i)
  109. var vv reflect.Value
  110. switch v.Kind() {
  111. case reflect.Interface:
  112. vv = reflect.Indirect(v.Elem())
  113. default:
  114. vv = reflect.Indirect(v)
  115. }
  116. elemValue := v.Interface()
  117. var colPlaces []string
  118. // handle BeforeInsertProcessor
  119. // !nashtsai! does user expect it's same slice to passed closure when using Before()/After() when insert multi??
  120. for _, closure := range session.beforeClosures {
  121. closure(elemValue)
  122. }
  123. if processor, ok := interface{}(elemValue).(BeforeInsertProcessor); ok {
  124. processor.BeforeInsert()
  125. }
  126. // --
  127. for _, col := range table.Columns() {
  128. ptrFieldValue, err := col.ValueOfV(&vv)
  129. if err != nil {
  130. return 0, err
  131. }
  132. fieldValue := *ptrFieldValue
  133. if col.IsAutoIncrement && utils.IsZero(fieldValue.Interface()) {
  134. continue
  135. }
  136. if col.MapType == schemas.ONLYFROMDB {
  137. continue
  138. }
  139. if col.IsDeleted {
  140. continue
  141. }
  142. if session.statement.OmitColumnMap.Contain(col.Name) {
  143. continue
  144. }
  145. if len(session.statement.ColumnMap) > 0 && !session.statement.ColumnMap.Contain(col.Name) {
  146. continue
  147. }
  148. if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime {
  149. val, t := session.engine.nowTime(col)
  150. args = append(args, val)
  151. var colName = col.Name
  152. session.afterClosures = append(session.afterClosures, func(bean interface{}) {
  153. col := table.GetColumn(colName)
  154. setColumnTime(bean, col, t)
  155. })
  156. } else if col.IsVersion && session.statement.CheckVersion {
  157. args = append(args, 1)
  158. var colName = col.Name
  159. session.afterClosures = append(session.afterClosures, func(bean interface{}) {
  160. col := table.GetColumn(colName)
  161. setColumnInt(bean, col, 1)
  162. })
  163. } else {
  164. arg, err := session.statement.Value2Interface(col, fieldValue)
  165. if err != nil {
  166. return 0, err
  167. }
  168. args = append(args, arg)
  169. }
  170. if i == 0 {
  171. colNames = append(colNames, col.Name)
  172. cols = append(cols, col)
  173. }
  174. colPlaces = append(colPlaces, "?")
  175. }
  176. colMultiPlaces = append(colMultiPlaces, strings.Join(colPlaces, ", "))
  177. }
  178. cleanupProcessorsClosures(&session.beforeClosures)
  179. quoter := session.engine.dialect.Quoter()
  180. var sql string
  181. colStr := quoter.Join(colNames, ",")
  182. if session.engine.dialect.URI().DBType == schemas.ORACLE {
  183. temp := fmt.Sprintf(") INTO %s (%v) VALUES (",
  184. quoter.Quote(tableName),
  185. colStr)
  186. sql = fmt.Sprintf("INSERT ALL INTO %s (%v) VALUES (%v) SELECT 1 FROM DUAL",
  187. quoter.Quote(tableName),
  188. colStr,
  189. strings.Join(colMultiPlaces, temp))
  190. } else {
  191. sql = fmt.Sprintf("INSERT INTO %s (%v) VALUES (%v)",
  192. quoter.Quote(tableName),
  193. colStr,
  194. strings.Join(colMultiPlaces, "),("))
  195. }
  196. res, err := session.exec(sql, args...)
  197. if err != nil {
  198. return 0, err
  199. }
  200. session.cacheInsert(tableName)
  201. lenAfterClosures := len(session.afterClosures)
  202. for i := 0; i < size; i++ {
  203. elemValue := reflect.Indirect(sliceValue.Index(i)).Addr().Interface()
  204. // handle AfterInsertProcessor
  205. if session.isAutoCommit {
  206. // !nashtsai! does user expect it's same slice to passed closure when using Before()/After() when insert multi??
  207. for _, closure := range session.afterClosures {
  208. closure(elemValue)
  209. }
  210. if processor, ok := elemValue.(AfterInsertProcessor); ok {
  211. processor.AfterInsert()
  212. }
  213. } else {
  214. if lenAfterClosures > 0 {
  215. if value, has := session.afterInsertBeans[elemValue]; has && value != nil {
  216. *value = append(*value, session.afterClosures...)
  217. } else {
  218. afterClosures := make([]func(interface{}), lenAfterClosures)
  219. copy(afterClosures, session.afterClosures)
  220. session.afterInsertBeans[elemValue] = &afterClosures
  221. }
  222. } else {
  223. if _, ok := elemValue.(AfterInsertProcessor); ok {
  224. session.afterInsertBeans[elemValue] = nil
  225. }
  226. }
  227. }
  228. }
  229. cleanupProcessorsClosures(&session.afterClosures)
  230. return res.RowsAffected()
  231. }
  232. // InsertMulti insert multiple records
  233. func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) {
  234. if session.isAutoClose {
  235. defer session.Close()
  236. }
  237. sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
  238. if sliceValue.Kind() != reflect.Slice {
  239. return 0, ErrPtrSliceType
  240. }
  241. if sliceValue.Len() <= 0 {
  242. return 0, ErrNoElementsOnSlice
  243. }
  244. return session.innerInsertMulti(rowsSlicePtr)
  245. }
  246. func (session *Session) innerInsert(bean interface{}) (int64, error) {
  247. if err := session.statement.SetRefBean(bean); err != nil {
  248. return 0, err
  249. }
  250. if len(session.statement.TableName()) <= 0 {
  251. return 0, ErrTableNotFound
  252. }
  253. // handle BeforeInsertProcessor
  254. for _, closure := range session.beforeClosures {
  255. closure(bean)
  256. }
  257. cleanupProcessorsClosures(&session.beforeClosures) // cleanup after used
  258. if processor, ok := interface{}(bean).(BeforeInsertProcessor); ok {
  259. processor.BeforeInsert()
  260. }
  261. var tableName = session.statement.TableName()
  262. table := session.statement.RefTable
  263. colNames, args, err := session.genInsertColumns(bean)
  264. if err != nil {
  265. return 0, err
  266. }
  267. sqlStr, args, err := session.statement.GenInsertSQL(colNames, args)
  268. if err != nil {
  269. return 0, err
  270. }
  271. handleAfterInsertProcessorFunc := func(bean interface{}) {
  272. if session.isAutoCommit {
  273. for _, closure := range session.afterClosures {
  274. closure(bean)
  275. }
  276. if processor, ok := interface{}(bean).(AfterInsertProcessor); ok {
  277. processor.AfterInsert()
  278. }
  279. } else {
  280. lenAfterClosures := len(session.afterClosures)
  281. if lenAfterClosures > 0 {
  282. if value, has := session.afterInsertBeans[bean]; has && value != nil {
  283. *value = append(*value, session.afterClosures...)
  284. } else {
  285. afterClosures := make([]func(interface{}), lenAfterClosures)
  286. copy(afterClosures, session.afterClosures)
  287. session.afterInsertBeans[bean] = &afterClosures
  288. }
  289. } else {
  290. if _, ok := interface{}(bean).(AfterInsertProcessor); ok {
  291. session.afterInsertBeans[bean] = nil
  292. }
  293. }
  294. }
  295. cleanupProcessorsClosures(&session.afterClosures) // cleanup after used
  296. }
  297. // for postgres, many of them didn't implement lastInsertId, so we should
  298. // implemented it ourself.
  299. if session.engine.dialect.URI().DBType == schemas.ORACLE && len(table.AutoIncrement) > 0 {
  300. res, err := session.queryBytes("select seq_atable.currval from dual", args...)
  301. if err != nil {
  302. return 0, err
  303. }
  304. defer handleAfterInsertProcessorFunc(bean)
  305. session.cacheInsert(tableName)
  306. if table.Version != "" && session.statement.CheckVersion {
  307. verValue, err := table.VersionColumn().ValueOf(bean)
  308. if err != nil {
  309. session.engine.logger.Errorf("%v", err)
  310. } else if verValue.IsValid() && verValue.CanSet() {
  311. session.incrVersionFieldValue(verValue)
  312. }
  313. }
  314. if len(res) < 1 {
  315. return 0, errors.New("insert no error but not returned id")
  316. }
  317. idByte := res[0][table.AutoIncrement]
  318. id, err := strconv.ParseInt(string(idByte), 10, 64)
  319. if err != nil || id <= 0 {
  320. return 1, err
  321. }
  322. aiValue, err := table.AutoIncrColumn().ValueOf(bean)
  323. if err != nil {
  324. session.engine.logger.Errorf("%v", err)
  325. }
  326. if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
  327. return 1, nil
  328. }
  329. aiValue.Set(int64ToIntValue(id, aiValue.Type()))
  330. return 1, nil
  331. } else if len(table.AutoIncrement) > 0 && (session.engine.dialect.URI().DBType == schemas.POSTGRES ||
  332. session.engine.dialect.URI().DBType == schemas.MSSQL) {
  333. res, err := session.queryBytes(sqlStr, args...)
  334. if err != nil {
  335. return 0, err
  336. }
  337. defer handleAfterInsertProcessorFunc(bean)
  338. session.cacheInsert(tableName)
  339. if table.Version != "" && session.statement.CheckVersion {
  340. verValue, err := table.VersionColumn().ValueOf(bean)
  341. if err != nil {
  342. session.engine.logger.Errorf("%v", err)
  343. } else if verValue.IsValid() && verValue.CanSet() {
  344. session.incrVersionFieldValue(verValue)
  345. }
  346. }
  347. if len(res) < 1 {
  348. return 0, errors.New("insert successfully but not returned id")
  349. }
  350. idByte := res[0][table.AutoIncrement]
  351. id, err := strconv.ParseInt(string(idByte), 10, 64)
  352. if err != nil || id <= 0 {
  353. return 1, err
  354. }
  355. aiValue, err := table.AutoIncrColumn().ValueOf(bean)
  356. if err != nil {
  357. session.engine.logger.Errorf("%v", err)
  358. }
  359. if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
  360. return 1, nil
  361. }
  362. aiValue.Set(int64ToIntValue(id, aiValue.Type()))
  363. return 1, nil
  364. }
  365. res, err := session.exec(sqlStr, args...)
  366. if err != nil {
  367. return 0, err
  368. }
  369. defer handleAfterInsertProcessorFunc(bean)
  370. session.cacheInsert(tableName)
  371. if table.Version != "" && session.statement.CheckVersion {
  372. verValue, err := table.VersionColumn().ValueOf(bean)
  373. if err != nil {
  374. session.engine.logger.Errorf("%v", err)
  375. } else if verValue.IsValid() && verValue.CanSet() {
  376. session.incrVersionFieldValue(verValue)
  377. }
  378. }
  379. if table.AutoIncrement == "" {
  380. return res.RowsAffected()
  381. }
  382. var id int64
  383. id, err = res.LastInsertId()
  384. if err != nil || id <= 0 {
  385. return res.RowsAffected()
  386. }
  387. aiValue, err := table.AutoIncrColumn().ValueOf(bean)
  388. if err != nil {
  389. session.engine.logger.Errorf("%v", err)
  390. }
  391. if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
  392. return res.RowsAffected()
  393. }
  394. aiValue.Set(int64ToIntValue(id, aiValue.Type()))
  395. return res.RowsAffected()
  396. }
  397. // InsertOne insert only one struct into database as a record.
  398. // The in parameter bean must a struct or a point to struct. The return
  399. // parameter is inserted and error
  400. func (session *Session) InsertOne(bean interface{}) (int64, error) {
  401. if session.isAutoClose {
  402. defer session.Close()
  403. }
  404. return session.innerInsert(bean)
  405. }
  406. func (session *Session) cacheInsert(table string) error {
  407. if !session.statement.UseCache {
  408. return nil
  409. }
  410. cacher := session.engine.cacherMgr.GetCacher(table)
  411. if cacher == nil {
  412. return nil
  413. }
  414. session.engine.logger.Debugf("[cache] clear SQL: %v", table)
  415. cacher.ClearIds(table)
  416. return nil
  417. }
  418. // genInsertColumns generates insert needed columns
  419. func (session *Session) genInsertColumns(bean interface{}) ([]string, []interface{}, error) {
  420. table := session.statement.RefTable
  421. colNames := make([]string, 0, len(table.ColumnsSeq()))
  422. args := make([]interface{}, 0, len(table.ColumnsSeq()))
  423. for _, col := range table.Columns() {
  424. if col.MapType == schemas.ONLYFROMDB {
  425. continue
  426. }
  427. if col.IsDeleted {
  428. continue
  429. }
  430. if session.statement.OmitColumnMap.Contain(col.Name) {
  431. continue
  432. }
  433. if len(session.statement.ColumnMap) > 0 && !session.statement.ColumnMap.Contain(col.Name) {
  434. continue
  435. }
  436. if session.statement.IncrColumns.IsColExist(col.Name) {
  437. continue
  438. } else if session.statement.DecrColumns.IsColExist(col.Name) {
  439. continue
  440. } else if session.statement.ExprColumns.IsColExist(col.Name) {
  441. continue
  442. }
  443. fieldValuePtr, err := col.ValueOf(bean)
  444. if err != nil {
  445. return nil, nil, err
  446. }
  447. fieldValue := *fieldValuePtr
  448. if col.IsAutoIncrement && utils.IsValueZero(fieldValue) {
  449. continue
  450. }
  451. // !evalphobia! set fieldValue as nil when column is nullable and zero-value
  452. if _, ok := getFlagForColumn(session.statement.NullableMap, col); ok {
  453. if col.Nullable && utils.IsValueZero(fieldValue) {
  454. var nilValue *int
  455. fieldValue = reflect.ValueOf(nilValue)
  456. }
  457. }
  458. if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime /*&& isZero(fieldValue.Interface())*/ {
  459. // if time is non-empty, then set to auto time
  460. val, t := session.engine.nowTime(col)
  461. args = append(args, val)
  462. var colName = col.Name
  463. session.afterClosures = append(session.afterClosures, func(bean interface{}) {
  464. col := table.GetColumn(colName)
  465. setColumnTime(bean, col, t)
  466. })
  467. } else if col.IsVersion && session.statement.CheckVersion {
  468. args = append(args, 1)
  469. } else {
  470. arg, err := session.statement.Value2Interface(col, fieldValue)
  471. if err != nil {
  472. return colNames, args, err
  473. }
  474. args = append(args, arg)
  475. }
  476. colNames = append(colNames, col.Name)
  477. }
  478. return colNames, args, nil
  479. }
  480. func (session *Session) insertMapInterface(m map[string]interface{}) (int64, error) {
  481. if len(m) == 0 {
  482. return 0, ErrParamsType
  483. }
  484. tableName := session.statement.TableName()
  485. if len(tableName) <= 0 {
  486. return 0, ErrTableNotFound
  487. }
  488. var columns = make([]string, 0, len(m))
  489. exprs := session.statement.ExprColumns
  490. for k := range m {
  491. if !exprs.IsColExist(k) {
  492. columns = append(columns, k)
  493. }
  494. }
  495. sort.Strings(columns)
  496. var args = make([]interface{}, 0, len(m))
  497. for _, colName := range columns {
  498. args = append(args, m[colName])
  499. }
  500. return session.insertMap(columns, args)
  501. }
  502. func (session *Session) insertMapString(m map[string]string) (int64, error) {
  503. if len(m) == 0 {
  504. return 0, ErrParamsType
  505. }
  506. tableName := session.statement.TableName()
  507. if len(tableName) <= 0 {
  508. return 0, ErrTableNotFound
  509. }
  510. var columns = make([]string, 0, len(m))
  511. exprs := session.statement.ExprColumns
  512. for k := range m {
  513. if !exprs.IsColExist(k) {
  514. columns = append(columns, k)
  515. }
  516. }
  517. sort.Strings(columns)
  518. var args = make([]interface{}, 0, len(m))
  519. for _, colName := range columns {
  520. args = append(args, m[colName])
  521. }
  522. return session.insertMap(columns, args)
  523. }
  524. func (session *Session) insertMap(columns []string, args []interface{}) (int64, error) {
  525. tableName := session.statement.TableName()
  526. if len(tableName) <= 0 {
  527. return 0, ErrTableNotFound
  528. }
  529. sql, args, err := session.statement.GenInsertMapSQL(columns, args)
  530. if err != nil {
  531. return 0, err
  532. }
  533. if err := session.cacheInsert(tableName); err != nil {
  534. return 0, err
  535. }
  536. res, err := session.exec(sql, args...)
  537. if err != nil {
  538. return 0, err
  539. }
  540. affected, err := res.RowsAffected()
  541. if err != nil {
  542. return 0, err
  543. }
  544. return affected, nil
  545. }