diff --git a/integrations/session_insert_test.go b/integrations/session_insert_test.go index ce52d3c4..cd56a958 100644 --- a/integrations/session_insert_test.go +++ b/integrations/session_insert_test.go @@ -1065,3 +1065,82 @@ func TestInsertDeleted(t *testing.T) { assert.NoError(t, err) assert.True(t, has) } + +func TestInsertMultipleMap(t *testing.T) { + type InsertMultipleMap struct { + Id int64 + Width uint32 + Height uint32 + Name string + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(InsertMultipleMap)) + + cnt, err := testEngine.Table(new(InsertMultipleMap)).Insert([]map[string]interface{}{ + { + "width": 20, + "height": 10, + "name": "lunny", + }, + { + "width": 30, + "height": 20, + "name": "xiaolunwen", + }, + }) + assert.NoError(t, err) + assert.EqualValues(t, 2, cnt) + + var res []InsertMultipleMap + err = testEngine.Find(&res) + assert.NoError(t, err) + assert.EqualValues(t, 2, len(res)) + assert.EqualValues(t, InsertMultipleMap{ + Id: 1, + Width: 20, + Height: 10, + Name: "lunny", + }, res[0]) + assert.EqualValues(t, InsertMultipleMap{ + Id: 2, + Width: 30, + Height: 20, + Name: "xiaolunwen", + }, res[1]) + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(InsertMultipleMap)) + + cnt, err = testEngine.Table(new(InsertMultipleMap)).Insert([]map[string]string{ + { + "width": "20", + "height": "10", + "name": "lunny", + }, + { + "width": "30", + "height": "20", + "name": "xiaolunwen", + }, + }) + assert.NoError(t, err) + assert.EqualValues(t, 2, cnt) + + res = make([]InsertMultipleMap, 0, 2) + err = testEngine.Find(&res) + assert.NoError(t, err) + assert.EqualValues(t, 2, len(res)) + assert.EqualValues(t, InsertMultipleMap{ + Id: 1, + Width: 20, + Height: 10, + Name: "lunny", + }, res[0]) + assert.EqualValues(t, InsertMultipleMap{ + Id: 2, + Width: 30, + Height: 20, + Name: "xiaolunwen", + }, res[1]) +} diff --git a/internal/statements/insert.go b/internal/statements/insert.go index 4e43c5bd..84547cdf 100644 --- a/internal/statements/insert.go +++ b/internal/statements/insert.go @@ -5,6 +5,7 @@ package statements import ( + "errors" "fmt" "strings" @@ -205,3 +206,55 @@ func (statement *Statement) GenInsertMapSQL(columns []string, args []interface{} return buf.String(), buf.Args(), nil } + +func (statement *Statement) GenInsertMultipleMapSQL(columns []string, argss [][]interface{}) (string, []interface{}, error) { + var ( + buf = builder.NewWriter() + exprs = statement.ExprColumns + tableName = statement.TableName() + ) + + if _, err := buf.WriteString(fmt.Sprintf("INSERT INTO %s (", statement.quote(tableName))); err != nil { + return "", nil, err + } + + if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(columns, exprs.ColNames()...), ","); err != nil { + return "", nil, err + } + + // if insert where + if statement.Conds().IsValid() { + return "", nil, errors.New("batch insert don't support with where") + } + + if _, err := buf.WriteString(") VALUES "); err != nil { + return "", nil, err + } + for i, args := range argss { + if _, err := buf.WriteString("("); err != nil { + return "", nil, err + } + if err := statement.WriteArgs(buf, args); err != nil { + return "", nil, err + } + + if len(exprs) > 0 { + if _, err := buf.WriteString(","); err != nil { + return "", nil, err + } + if err := exprs.WriteArgs(buf); err != nil { + return "", nil, err + } + } + if _, err := buf.WriteString(")"); err != nil { + return "", nil, err + } + if i < len(argss)-1 { + if _, err := buf.WriteString(","); err != nil { + return "", nil, err + } + } + } + + return buf.String(), buf.Args(), nil +} diff --git a/session_insert.go b/session_insert.go index 1583858e..b116b9ff 100644 --- a/session_insert.go +++ b/session_insert.go @@ -18,7 +18,7 @@ import ( ) // ErrNoElementsOnSlice represents an error there is no element when insert -var ErrNoElementsOnSlice = errors.New("No element on slice when insert") +var ErrNoElementsOnSlice = errors.New("no element on slice when insert") // Insert insert one or more beans func (session *Session) Insert(beans ...interface{}) (int64, error) { @@ -36,71 +36,42 @@ func (session *Session) Insert(beans ...interface{}) (int64, error) { }() for _, bean := range beans { - switch bean.(type) { + var cnt int64 + var err error + switch v := bean.(type) { case map[string]interface{}: - cnt, err := session.insertMapInterface(bean.(map[string]interface{})) - if err != nil { - return affected, err - } - affected += cnt + cnt, err = session.insertMapInterface(v) case []map[string]interface{}: - s := bean.([]map[string]interface{}) - for i := 0; i < len(s); i++ { - cnt, err := session.insertMapInterface(s[i]) - if err != nil { - return affected, err - } - affected += cnt - } + cnt, err = session.insertMultipleMapInterface(v) case map[string]string: - cnt, err := session.insertMapString(bean.(map[string]string)) - if err != nil { - return affected, err - } - affected += cnt + cnt, err = session.insertMapString(v) case []map[string]string: - s := bean.([]map[string]string) - for i := 0; i < len(s); i++ { - cnt, err := session.insertMapString(s[i]) - if err != nil { - return affected, err - } - affected += cnt - } + cnt, err = session.insertMultipleMapString(v) default: sliceValue := reflect.Indirect(reflect.ValueOf(bean)) if sliceValue.Kind() == reflect.Slice { - size := sliceValue.Len() - if size <= 0 { - return 0, ErrNoElementsOnSlice - } - - cnt, err := session.innerInsertMulti(bean) - if err != nil { - return affected, err - } - affected += cnt + cnt, err = session.insertMultipleStruct(bean) } else { - cnt, err := session.innerInsert(bean) - if err != nil { - return affected, err - } - affected += cnt + cnt, err = session.insertStruct(bean) } } + if err != nil { + return affected, err + } + affected += cnt } return affected, err } -func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error) { +func (session *Session) insertMultipleStruct(rowsSlicePtr interface{}) (int64, error) { sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) if sliceValue.Kind() != reflect.Slice { return 0, errors.New("needs a pointer to a slice") } if sliceValue.Len() <= 0 { - return 0, errors.New("could not insert a empty slice") + return 0, ErrNoElementsOnSlice } if err := session.statement.SetRefBean(sliceValue.Index(0).Interface()); err != nil { @@ -269,14 +240,10 @@ func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) { return 0, ErrPtrSliceType } - if sliceValue.Len() <= 0 { - return 0, ErrNoElementsOnSlice - } - - return session.innerInsertMulti(rowsSlicePtr) + return session.insertMultipleStruct(rowsSlicePtr) } -func (session *Session) innerInsert(bean interface{}) (int64, error) { +func (session *Session) insertStruct(bean interface{}) (int64, error) { if err := session.statement.SetRefBean(bean); err != nil { return 0, err } @@ -434,7 +401,7 @@ func (session *Session) InsertOne(bean interface{}) (int64, error) { defer session.Close() } - return session.innerInsert(bean) + return session.insertStruct(bean) } func (session *Session) cacheInsert(table string) error { @@ -561,6 +528,37 @@ func (session *Session) insertMapInterface(m map[string]interface{}) (int64, err return session.insertMap(columns, args) } +func (session *Session) insertMultipleMapInterface(maps []map[string]interface{}) (int64, error) { + if len(maps) <= 0 { + return 0, ErrNoElementsOnSlice + } + + tableName := session.statement.TableName() + if len(tableName) <= 0 { + return 0, ErrTableNotFound + } + + var columns = make([]string, 0, len(maps[0])) + exprs := session.statement.ExprColumns + for k := range maps[0] { + if !exprs.IsColExist(k) { + columns = append(columns, k) + } + } + sort.Strings(columns) + + var argss = make([][]interface{}, 0, len(maps)) + for _, m := range maps { + var args = make([]interface{}, 0, len(m)) + for _, colName := range columns { + args = append(args, m[colName]) + } + argss = append(argss, args) + } + + return session.insertMultipleMap(columns, argss) +} + func (session *Session) insertMapString(m map[string]string) (int64, error) { if len(m) == 0 { return 0, ErrParamsType @@ -589,6 +587,37 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) { return session.insertMap(columns, args) } +func (session *Session) insertMultipleMapString(maps []map[string]string) (int64, error) { + if len(maps) <= 0 { + return 0, ErrNoElementsOnSlice + } + + tableName := session.statement.TableName() + if len(tableName) <= 0 { + return 0, ErrTableNotFound + } + + var columns = make([]string, 0, len(maps[0])) + exprs := session.statement.ExprColumns + for k := range maps[0] { + if !exprs.IsColExist(k) { + columns = append(columns, k) + } + } + sort.Strings(columns) + + var argss = make([][]interface{}, 0, len(maps)) + for _, m := range maps { + var args = make([]interface{}, 0, len(m)) + for _, colName := range columns { + args = append(args, m[colName]) + } + argss = append(argss, args) + } + + return session.insertMultipleMap(columns, argss) +} + func (session *Session) insertMap(columns []string, args []interface{}) (int64, error) { tableName := session.statement.TableName() if len(tableName) <= 0 { @@ -614,3 +643,29 @@ func (session *Session) insertMap(columns []string, args []interface{}) (int64, } return affected, nil } + +func (session *Session) insertMultipleMap(columns []string, argss [][]interface{}) (int64, error) { + tableName := session.statement.TableName() + if len(tableName) <= 0 { + return 0, ErrTableNotFound + } + + sql, args, err := session.statement.GenInsertMultipleMapSQL(columns, argss) + if err != nil { + return 0, err + } + + if err := session.cacheInsert(tableName); err != nil { + return 0, err + } + + res, err := session.exec(sql, args...) + if err != nil { + return 0, err + } + affected, err := res.RowsAffected() + if err != nil { + return 0, err + } + return affected, nil +}