xorm/session_insert.go
satorunooshie 57365108ae
All checks were successful
continuous-integration/drone/push Build is passing
Fix insertMultipleStruct to insert null value under certain circumstances (#2077)
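When a column is marked with Nullable() and the struct field holds its zero value, a single insert writes NULL, but a bulk (slice) insert writes the zero value instead.
When the struct field is a pointer, on the other hand, single and bulk insertion already behave identically: both write NULL.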
Please see the example below.
```go
s := engineGroup.NewSession()

type User struct {
    ID   int    `xorm:"not null pk autoincr INT(10)"`
    Name string `xorm:"not null VARCHAR(191)"`
    Age  int64  `xorm:"null BIGINT(20)"`
}
list := []*User{
    {
        Name: "John",
    },
    {
    	Name: "Wick",
    },
}
s.Nullable("age")

// Single insertion works
_, err := s.Insert(list[0]) // [table] `user` INSERT INTO `user` (`name`,`age`) VALUES (?, ?) [John <nil>]

s.Nullable("age")
// Bulk insertion does not work
_, err := s.Insert(list) // [table] `user` INSERT INTO `user` (`name`,`age`) VALUES (?, ?),(?, ?) [John, 0, Wick, 0]

//---------------------------------
//Using pointer, which is nullable, the generated sql has no difference.
//---------------------------------

type User struct {
    ID   int    `xorm:"not null pk autoincr INT(10)"`
    Name string `xorm:"not null VARCHAR(191)"`
    Age  *int64 `xorm:"null BIGINT(20)"`
}
list := []*User{
    {
        Name: "John",
    },
    {
    	Name: "Wick",
    },
}
s.Nullable("age")
// Single insertion works
_, err := s.Insert(list[0]) // [table] `user` INSERT INTO `user` (`name`,`age`) VALUES (?, ?) [John <nil>]

s.Nullable("age")
// Bulk insertion works as well
_, err := s.Insert(list) // [table] `user` INSERT INTO `user` (`name`,`age`) VALUES (?, ?),(?, ?) [John, <nil>, Wick, <nil>]
```

Reviewed-on: #2077
Co-authored-by: satorunooshie <satorunooshie@noreply.gitea.io>
Co-committed-by: satorunooshie <satorunooshie@noreply.gitea.io>
2021-12-27 10:11:44 +08:00

710 lines
18 KiB
Go

// Copyright 2016 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import (
"errors"
"fmt"
"reflect"
"sort"
"strings"
"time"
"xorm.io/xorm/convert"
"xorm.io/xorm/dialects"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas"
)
// ErrNoElementsOnSlice is returned when an insert is given an empty slice.
var ErrNoElementsOnSlice = errors.New("no element on slice when insert")
// Insert inserts one or more beans and returns the total number of affected
// rows.
//
// Each bean is dispatched on its dynamic type: a map or slice of maps is
// inserted through the map helpers, anything that is (or points to) a slice
// is inserted as a multi-row statement, and everything else is inserted as a
// single struct. Processing stops at the first error, in which case the rows
// counted so far are returned together with that error.
func (session *Session) Insert(beans ...interface{}) (int64, error) {
	if session.isAutoClose {
		defer session.Close()
	}

	// Keep the statement alive across all beans; it is reset once on exit.
	session.autoResetStatement = false
	defer func() {
		session.autoResetStatement = true
		session.resetStatement()
	}()

	var total int64
	for _, bean := range beans {
		var (
			n   int64
			err error
		)

		switch typed := bean.(type) {
		case map[string]interface{}:
			n, err = session.insertMapInterface(typed)
		case []map[string]interface{}:
			n, err = session.insertMultipleMapInterface(typed)
		case map[string]string:
			n, err = session.insertMapString(typed)
		case []map[string]string:
			n, err = session.insertMultipleMapString(typed)
		default:
			if reflect.Indirect(reflect.ValueOf(bean)).Kind() == reflect.Slice {
				n, err = session.insertMultipleStruct(bean)
			} else {
				n, err = session.insertStruct(bean)
			}
		}

		if err != nil {
			return total, err
		}
		total += n
	}

	return total, nil
}
// insertMultipleStruct inserts every element of a slice (or of a pointer to a
// slice) of beans using one multi-row INSERT statement.
//
// The first element is used as the reference bean for the statement. For each
// element the registered Before closures and the BeforeInsert hook are run,
// then one group of placeholders and its arguments is collected. Column names
// are recorded while processing the first element only, so every element is
// expected to yield the same set of columns.
//
// It returns the number of affected rows reported by the executed statement.
// It returns ErrNoElementsOnSlice for an empty slice, ErrTableNotFound when
// the statement has no table name, and an error when the argument is not a
// slice.
func (session *Session) insertMultipleStruct(rowsSlicePtr interface{}) (int64, error) {
sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
if sliceValue.Kind() != reflect.Slice {
return 0, errors.New("needs a pointer to a slice")
}
if sliceValue.Len() <= 0 {
return 0, ErrNoElementsOnSlice
}
// The first element serves as the reference bean for the statement.
if err := session.statement.SetRefBean(sliceValue.Index(0).Interface()); err != nil {
return 0, err
}
tableName := session.statement.TableName()
if len(tableName) == 0 {
return 0, ErrTableNotFound
}
var (
table = session.statement.RefTable
size = sliceValue.Len()
colNames []string
colMultiPlaces []string
args []interface{}
)
// Build one placeholder group and the matching arguments per element.
for i := 0; i < size; i++ {
v := sliceValue.Index(i)
// vv is the value the columns are read from: interface elements are
// unwrapped first, then pointers are dereferenced.
var vv reflect.Value
switch v.Kind() {
case reflect.Interface:
vv = reflect.Indirect(v.Elem())
default:
vv = reflect.Indirect(v)
}
elemValue := v.Interface()
var colPlaces []string
// handle BeforeInsertProcessor
// !nashtsai! does user expect it's same slice to passed closure when using Before()/After() when insert multi??
for _, closure := range session.beforeClosures {
closure(elemValue)
}
if processor, ok := interface{}(elemValue).(BeforeInsertProcessor); ok {
processor.BeforeInsert()
}
// --
for _, col := range table.Columns() {
ptrFieldValue, err := col.ValueOfV(&vv)
if err != nil {
return 0, err
}
fieldValue := *ptrFieldValue
// A zero-valued auto-increment column is left out of the statement. For
// dialects in sequence mode the sequence's nextval expression is emitted
// in place of a bind placeholder.
if col.IsAutoIncrement && utils.IsZero(fieldValue.Interface()) {
if session.engine.dialect.Features().AutoincrMode == dialects.SequenceAutoincrMode {
if i == 0 {
colNames = append(colNames, col.Name)
}
colPlaces = append(colPlaces, utils.SeqName(tableName)+".nextval")
}
continue
}
// Columns that are read-only from the DB, the soft-delete column, omitted
// columns and columns outside an explicit column list are skipped.
if col.MapType == schemas.ONLYFROMDB {
continue
}
if col.IsDeleted {
continue
}
if session.statement.OmitColumnMap.Contain(col.Name) {
continue
}
if len(session.statement.ColumnMap) > 0 && !session.statement.ColumnMap.Contain(col.Name) {
continue
}
// !satorunooshie! set fieldValue as nil when column is nullable and zero-value
if _, ok := getFlagForColumn(session.statement.NullableMap, col); ok {
if col.Nullable && utils.IsValueZero(fieldValue) {
var nilValue *int
fieldValue = reflect.ValueOf(nilValue)
}
}
if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime {
// Created/updated columns take the engine's current time; the queued
// closure calls setColumnTime with that same time after the insert.
// NOTE(review): a closure is queued per row and every queued closure is
// later run against every element, so each bean appears to end up with
// the time captured for the last row — confirm whether that is intended.
val, t, err := session.engine.nowTime(col)
if err != nil {
return 0, err
}
args = append(args, val)
var colName = col.Name
session.afterClosures = append(session.afterClosures, func(bean interface{}) {
col := table.GetColumn(colName)
setColumnTime(bean, col, t)
})
} else if col.IsVersion && session.statement.CheckVersion {
// Version columns are inserted as 1; the queued closure calls
// setColumnInt with 1 on the bean after the insert.
args = append(args, 1)
var colName = col.Name
session.afterClosures = append(session.afterClosures, func(bean interface{}) {
col := table.GetColumn(colName)
setColumnInt(bean, col, 1)
})
} else {
arg, err := session.statement.Value2Interface(col, fieldValue)
if err != nil {
return 0, err
}
args = append(args, arg)
}
// Column names are collected from the first element only.
if i == 0 {
colNames = append(colNames, col.Name)
}
colPlaces = append(colPlaces, "?")
}
colMultiPlaces = append(colMultiPlaces, strings.Join(colPlaces, ", "))
}
cleanupProcessorsClosures(&session.beforeClosures)
quoter := session.engine.dialect.Quoter()
var sql string
colStr := quoter.Join(colNames, ",")
if session.engine.dialect.URI().DBType == schemas.ORACLE {
// Oracle uses the INSERT ALL ... SELECT 1 FROM DUAL form; temp is the
// separator placed between the VALUES groups of consecutive rows.
temp := fmt.Sprintf(") INTO %s (%v) VALUES (",
quoter.Quote(tableName),
colStr)
sql = fmt.Sprintf("INSERT ALL INTO %s (%v) VALUES (%v) SELECT 1 FROM DUAL",
quoter.Quote(tableName),
colStr,
strings.Join(colMultiPlaces, temp))
} else {
sql = fmt.Sprintf("INSERT INTO %s (%v) VALUES (%v)",
quoter.Quote(tableName),
colStr,
strings.Join(colMultiPlaces, "),("))
}
res, err := session.exec(sql, args...)
if err != nil {
return 0, err
}
// The result of cacheInsert is deliberately ignored.
_ = session.cacheInsert(tableName)
lenAfterClosures := len(session.afterClosures)
for i := 0; i < size; i++ {
// The After hooks receive the address of the (dereferenced) element.
elemValue := reflect.Indirect(sliceValue.Index(i)).Addr().Interface()
// handle AfterInsertProcessor
if session.isAutoCommit {
// !nashtsai! does user expect it's same slice to passed closure when using Before()/After() when insert multi??
for _, closure := range session.afterClosures {
closure(elemValue)
}
if processor, ok := elemValue.(AfterInsertProcessor); ok {
processor.AfterInsert()
}
} else {
// Not in auto-commit mode: nothing is invoked here; the closures are
// recorded in afterInsertBeans keyed by the element (presumably to be
// run when the transaction commits — confirm against Commit).
if lenAfterClosures > 0 {
if value, has := session.afterInsertBeans[elemValue]; has && value != nil {
*value = append(*value, session.afterClosures...)
} else {
afterClosures := make([]func(interface{}), lenAfterClosures)
copy(afterClosures, session.afterClosures)
session.afterInsertBeans[elemValue] = &afterClosures
}
} else {
if _, ok := elemValue.(AfterInsertProcessor); ok {
session.afterInsertBeans[elemValue] = nil
}
}
}
}
cleanupProcessorsClosures(&session.afterClosures)
return res.RowsAffected()
}
// InsertMulti inserts all elements of a slice, or of a pointer to a slice, as
// multiple records. It returns ErrPtrSliceType when rowsSlicePtr is neither.
func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) {
	if session.isAutoClose {
		defer session.Close()
	}

	if kind := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)).Kind(); kind != reflect.Slice {
		return 0, ErrPtrSliceType
	}

	return session.insertMultipleStruct(rowsSlicePtr)
}
// insertStruct inserts a single bean as one row and returns the number of
// affected rows.
//
// The bean becomes the statement's reference bean. Before closures and the
// BeforeInsert hook run first, then the column list and SQL are generated.
//
// When the table has an auto-increment column and the driver cannot return
// the inserted ID, the ID is obtained in one of three ways: from the insert
// arguments when the caller supplied it, from the sequence's currval on
// Oracle/Dameng, or by running the insert statement itself as a query and
// scanning its single result. Otherwise the ID comes from LastInsertId.
//
// After a successful insert the After closures/AfterInsert hook are handled
// (deferred), the version field is incremented when version checking is on,
// and the obtained ID is assigned to the bean's auto-increment field when that
// field is settable.
func (session *Session) insertStruct(bean interface{}) (int64, error) {
	if err := session.statement.SetRefBean(bean); err != nil {
		return 0, err
	}
	if len(session.statement.TableName()) == 0 {
		return 0, ErrTableNotFound
	}

	// handle BeforeInsertProcessor
	for _, closure := range session.beforeClosures {
		closure(bean)
	}
	cleanupProcessorsClosures(&session.beforeClosures) // cleanup after used

	if processor, ok := interface{}(bean).(BeforeInsertProcessor); ok {
		processor.BeforeInsert()
	}

	var tableName = session.statement.TableName()
	table := session.statement.RefTable

	colNames, args, err := session.genInsertColumns(bean)
	if err != nil {
		return 0, err
	}

	sqlStr, args, err := session.statement.GenInsertSQL(colNames, args)
	if err != nil {
		return 0, err
	}
	sqlStr = session.engine.dialect.Quoter().Replace(sqlStr)

	// In auto-commit mode the After closures and AfterInsert hook run right
	// away; otherwise they are recorded in afterInsertBeans. Either way the
	// session's pending After closures are cleaned up afterwards.
	handleAfterInsertProcessorFunc := func(bean interface{}) {
		if session.isAutoCommit {
			for _, closure := range session.afterClosures {
				closure(bean)
			}
			if processor, ok := interface{}(bean).(AfterInsertProcessor); ok {
				processor.AfterInsert()
			}
		} else {
			lenAfterClosures := len(session.afterClosures)
			if lenAfterClosures > 0 {
				if value, has := session.afterInsertBeans[bean]; has && value != nil {
					*value = append(*value, session.afterClosures...)
				} else {
					afterClosures := make([]func(interface{}), lenAfterClosures)
					copy(afterClosures, session.afterClosures)
					session.afterInsertBeans[bean] = &afterClosures
				}
			} else {
				if _, ok := interface{}(bean).(AfterInsertProcessor); ok {
					session.afterInsertBeans[bean] = nil
				}
			}
		}
		cleanupProcessorsClosures(&session.afterClosures) // cleanup after used
	}

	// if there is auto increment column and driver don't support return it
	if len(table.AutoIncrement) > 0 && !session.engine.driver.Features().SupportReturnInsertedID {
		var sql string
		var newArgs []interface{}
		var needCommit bool
		var id int64
		if session.engine.dialect.URI().DBType == schemas.ORACLE || session.engine.dialect.URI().DBType == schemas.DAMENG {
			// The insert and the following currval query are wrapped in a
			// transaction when the session is not already inside one.
			// NOTE(review): the error returns below leave that transaction
			// open; confirm whether an explicit rollback is wanted here.
			if session.isAutoCommit { // if it's not in transaction
				if err := session.Begin(); err != nil {
					return 0, err
				}
				needCommit = true
			}
			_, err := session.exec(sqlStr, args...)
			if err != nil {
				return 0, err
			}
			i := utils.IndexSlice(colNames, table.AutoIncrement)
			if i > -1 {
				// The caller supplied the auto-increment value explicitly.
				id, err = convert.AsInt64(args[i])
				if err != nil {
					return 0, err
				}
			} else {
				sql = fmt.Sprintf("select %s.currval from dual", utils.SeqName(tableName))
			}
		} else {
			// The insert statement itself is executed as a query below and
			// its single result row is scanned into id.
			sql = sqlStr
			newArgs = args
		}

		if id == 0 {
			if err := session.queryRow(sql, newArgs...).Scan(&id); err != nil {
				return 0, err
			}
		}

		// Commit the transaction opened above on every successful path.
		// Previously the commit only happened when the ID had to be queried,
		// so an insert that carried its own auto-increment value left the
		// implicitly started transaction uncommitted.
		if needCommit {
			if err := session.Commit(); err != nil {
				return 0, err
			}
		}

		if id == 0 {
			return 0, errors.New("insert successfully but not returned id")
		}

		defer handleAfterInsertProcessorFunc(bean)

		// The result of cacheInsert is deliberately ignored.
		_ = session.cacheInsert(tableName)

		if table.Version != "" && session.statement.CheckVersion {
			verValue, err := table.VersionColumn().ValueOf(bean)
			if err != nil {
				session.engine.logger.Errorf("%v", err)
			} else if verValue.IsValid() && verValue.CanSet() {
				session.incrVersionFieldValue(verValue)
			}
		}

		aiValue, err := table.AutoIncrColumn().ValueOf(bean)
		if err != nil {
			session.engine.logger.Errorf("%v", err)
		}

		// Without a settable auto-increment field there is nothing to write back.
		if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
			return 1, nil
		}

		return 1, convert.AssignValue(*aiValue, id)
	}

	res, err := session.exec(sqlStr, args...)
	if err != nil {
		return 0, err
	}

	defer handleAfterInsertProcessorFunc(bean)

	// The result of cacheInsert is deliberately ignored.
	_ = session.cacheInsert(tableName)

	if table.Version != "" && session.statement.CheckVersion {
		verValue, err := table.VersionColumn().ValueOf(bean)
		if err != nil {
			session.engine.logger.Errorf("%v", err)
		} else if verValue.IsValid() && verValue.CanSet() {
			session.incrVersionFieldValue(verValue)
		}
	}

	if table.AutoIncrement == "" {
		return res.RowsAffected()
	}

	// A failing or non-positive LastInsertId is not treated as an error; the
	// bean's auto-increment field is simply left untouched.
	var id int64
	id, err = res.LastInsertId()
	if err != nil || id <= 0 {
		return res.RowsAffected()
	}

	aiValue, err := table.AutoIncrColumn().ValueOf(bean)
	if err != nil {
		session.engine.logger.Errorf("%v", err)
	}

	if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
		return res.RowsAffected()
	}

	if err := convert.AssignValue(*aiValue, id); err != nil {
		return 0, err
	}

	return res.RowsAffected()
}
// InsertOne inserts exactly one struct into the database as a record.
// The bean must be a struct or a pointer to a struct. It returns the number
// of inserted rows and any error encountered.
//
// Deprecated: Please use Insert directly
func (session *Session) InsertOne(bean interface{}) (int64, error) {
	if session.isAutoClose {
		defer session.Close()
	}

	affected, err := session.insertStruct(bean)
	return affected, err
}
// cacheInsert clears the cached IDs of the given table after an insert.
// It does nothing when the statement does not use the cache or when no cacher
// is registered for the table, and it always returns nil.
func (session *Session) cacheInsert(table string) error {
	if session.statement.UseCache {
		if cacher := session.engine.cacherMgr.GetCacher(table); cacher != nil {
			session.engine.logger.Debugf("[cache] clear SQL: %v", table)
			cacher.ClearIds(table)
		}
	}
	return nil
}
// genInsertColumns collects the column names and the matching argument values
// needed to insert bean into the statement's reference table.
//
// Columns are skipped when they are read-only from the database, omitted or
// outside an explicit column list, handled through Incr/Decr/Expr, or are a
// zero-valued auto-increment column. On any error it returns nil slices
// together with the error.
func (session *Session) genInsertColumns(bean interface{}) ([]string, []interface{}, error) {
	table := session.statement.RefTable
	capacity := len(table.ColumnsSeq())
	colNames := make([]string, 0, capacity)
	args := make([]interface{}, 0, capacity)

	for _, col := range table.Columns() {
		if col.MapType == schemas.ONLYFROMDB {
			continue
		}
		if session.statement.OmitColumnMap.Contain(col.Name) {
			continue
		}
		if len(session.statement.ColumnMap) > 0 && !session.statement.ColumnMap.Contain(col.Name) {
			continue
		}

		// Columns driven by Incr/Decr/SetExpr are not taken from the bean.
		if session.statement.IncrColumns.IsColExist(col.Name) ||
			session.statement.DecrColumns.IsColExist(col.Name) ||
			session.statement.ExprColumns.IsColExist(col.Name) {
			continue
		}

		// The soft-delete column is always inserted as the formatted zero time.
		if col.IsDeleted {
			arg, err := dialects.FormatColumnTime(session.engine.dialect, session.engine.DatabaseTZ, col, time.Time{})
			if err != nil {
				return nil, nil, err
			}
			args = append(args, arg)
			colNames = append(colNames, col.Name)
			continue
		}

		fieldValuePtr, err := col.ValueOf(bean)
		if err != nil {
			return nil, nil, err
		}
		fieldValue := *fieldValuePtr

		if col.IsAutoIncrement && utils.IsValueZero(fieldValue) {
			continue
		}

		// !evalphobia! set fieldValue as nil when column is nullable and zero-value
		if _, ok := getFlagForColumn(session.statement.NullableMap, col); ok {
			if col.Nullable && utils.IsValueZero(fieldValue) {
				var nilValue *int
				fieldValue = reflect.ValueOf(nilValue)
			}
		}

		switch {
		case (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime:
			// The current time is used regardless of what the field holds;
			// the queued closure calls setColumnTime with that same time
			// once the After closures run.
			val, t, err := session.engine.nowTime(col)
			if err != nil {
				return nil, nil, err
			}
			args = append(args, val)

			var colName = col.Name
			session.afterClosures = append(session.afterClosures, func(bean interface{}) {
				col := table.GetColumn(colName)
				setColumnTime(bean, col, t)
			})
		case col.IsVersion && session.statement.CheckVersion:
			// Version columns start at 1; the bean's field is not touched here.
			args = append(args, 1)
		default:
			arg, err := session.statement.Value2Interface(col, fieldValue)
			if err != nil {
				// Match the other error paths: no partially built results.
				return nil, nil, err
			}
			args = append(args, arg)
		}

		colNames = append(colNames, col.Name)
	}
	return colNames, args, nil
}
// insertMapInterface inserts one row described by a column-name to value map.
// Keys that are registered as expression columns on the statement are left
// out; the remaining keys are sorted so the column order is deterministic.
// It returns ErrParamsType for an empty map and ErrTableNotFound when the
// statement has no table name.
func (session *Session) insertMapInterface(m map[string]interface{}) (int64, error) {
	if len(m) == 0 {
		return 0, ErrParamsType
	}
	if session.statement.TableName() == "" {
		return 0, ErrTableNotFound
	}

	exprColumns := session.statement.ExprColumns
	names := make([]string, 0, len(m))
	for name := range m {
		if exprColumns.IsColExist(name) {
			continue
		}
		names = append(names, name)
	}
	sort.Strings(names)

	values := make([]interface{}, 0, len(m))
	for _, name := range names {
		values = append(values, m[name])
	}

	return session.insertMap(names, values)
}
// insertMultipleMapInterface inserts one row per map in a single statement.
// The keys of the first map (minus expression columns, sorted) define the
// columns; every map is then read using those same keys. It returns
// ErrNoElementsOnSlice for an empty slice and ErrTableNotFound when the
// statement has no table name.
func (session *Session) insertMultipleMapInterface(maps []map[string]interface{}) (int64, error) {
	if len(maps) == 0 {
		return 0, ErrNoElementsOnSlice
	}
	if session.statement.TableName() == "" {
		return 0, ErrTableNotFound
	}

	first := maps[0]
	exprColumns := session.statement.ExprColumns
	names := make([]string, 0, len(first))
	for name := range first {
		if exprColumns.IsColExist(name) {
			continue
		}
		names = append(names, name)
	}
	sort.Strings(names)

	rows := make([][]interface{}, len(maps))
	for i, m := range maps {
		row := make([]interface{}, 0, len(m))
		for _, name := range names {
			row = append(row, m[name])
		}
		rows[i] = row
	}

	return session.insertMultipleMap(names, rows)
}
// insertMapString inserts one row described by a column-name to string map.
// Keys that are registered as expression columns on the statement are left
// out; the remaining keys are sorted so the column order is deterministic.
// It returns ErrParamsType for an empty map and ErrTableNotFound when the
// statement has no table name.
func (session *Session) insertMapString(m map[string]string) (int64, error) {
	if len(m) == 0 {
		return 0, ErrParamsType
	}
	if session.statement.TableName() == "" {
		return 0, ErrTableNotFound
	}

	exprColumns := session.statement.ExprColumns
	names := make([]string, 0, len(m))
	for name := range m {
		if exprColumns.IsColExist(name) {
			continue
		}
		names = append(names, name)
	}
	sort.Strings(names)

	values := make([]interface{}, 0, len(m))
	for _, name := range names {
		values = append(values, m[name])
	}

	return session.insertMap(names, values)
}
// insertMultipleMapString inserts one row per string map in a single
// statement. The keys of the first map (minus expression columns, sorted)
// define the columns; every map is then read using those same keys. It
// returns ErrNoElementsOnSlice for an empty slice and ErrTableNotFound when
// the statement has no table name.
func (session *Session) insertMultipleMapString(maps []map[string]string) (int64, error) {
	if len(maps) == 0 {
		return 0, ErrNoElementsOnSlice
	}
	if session.statement.TableName() == "" {
		return 0, ErrTableNotFound
	}

	first := maps[0]
	exprColumns := session.statement.ExprColumns
	names := make([]string, 0, len(first))
	for name := range first {
		if exprColumns.IsColExist(name) {
			continue
		}
		names = append(names, name)
	}
	sort.Strings(names)

	rows := make([][]interface{}, len(maps))
	for i, m := range maps {
		row := make([]interface{}, 0, len(m))
		for _, name := range names {
			row = append(row, m[name])
		}
		rows[i] = row
	}

	return session.insertMultipleMap(names, rows)
}
// insertMap builds and executes a single-row INSERT from parallel slices of
// column names and values, and returns the number of affected rows. It
// returns ErrTableNotFound when the statement has no table name.
func (session *Session) insertMap(columns []string, args []interface{}) (int64, error) {
	tableName := session.statement.TableName()
	if tableName == "" {
		return 0, ErrTableNotFound
	}

	query, params, err := session.statement.GenInsertMapSQL(columns, args)
	if err != nil {
		return 0, err
	}
	query = session.engine.dialect.Quoter().Replace(query)

	// The table's cache is cleared before the statement is executed.
	if err = session.cacheInsert(tableName); err != nil {
		return 0, err
	}

	result, err := session.exec(query, params...)
	if err != nil {
		return 0, err
	}

	count, err := result.RowsAffected()
	if err != nil {
		return 0, err
	}
	return count, nil
}
// insertMultipleMap builds and executes a multi-row INSERT from a list of
// column names and one value slice per row, and returns the number of
// affected rows. It returns ErrTableNotFound when the statement has no table
// name.
func (session *Session) insertMultipleMap(columns []string, argss [][]interface{}) (int64, error) {
	tableName := session.statement.TableName()
	if tableName == "" {
		return 0, ErrTableNotFound
	}

	query, params, err := session.statement.GenInsertMultipleMapSQL(columns, argss)
	if err != nil {
		return 0, err
	}
	query = session.engine.dialect.Quoter().Replace(query)

	// The table's cache is cleared before the statement is executed.
	if err = session.cacheInsert(tableName); err != nil {
		return 0, err
	}

	result, err := session.exec(query, params...)
	if err != nil {
		return 0, err
	}

	count, err := result.RowsAffected()
	if err != nil {
		return 0, err
	}
	return count, nil
}