Support batch insert map #2019

Merged
lunny merged 1 commits from lunny/batch_insert_map into master 2021-07-29 08:12:09 +00:00
3 changed files with 239 additions and 52 deletions

View File

@ -1065,3 +1065,82 @@ func TestInsertDeleted(t *testing.T) {
assert.NoError(t, err)
assert.True(t, has)
}
func TestInsertMultipleMap(t *testing.T) {
type InsertMultipleMap struct {
Id int64
Width uint32
Height uint32
Name string
}
assert.NoError(t, PrepareEngine())
assertSync(t, new(InsertMultipleMap))
cnt, err := testEngine.Table(new(InsertMultipleMap)).Insert([]map[string]interface{}{
{
"width": 20,
"height": 10,
"name": "lunny",
},
{
"width": 30,
"height": 20,
"name": "xiaolunwen",
},
})
assert.NoError(t, err)
assert.EqualValues(t, 2, cnt)
var res []InsertMultipleMap
err = testEngine.Find(&res)
assert.NoError(t, err)
assert.EqualValues(t, 2, len(res))
assert.EqualValues(t, InsertMultipleMap{
Id: 1,
Width: 20,
Height: 10,
Name: "lunny",
}, res[0])
assert.EqualValues(t, InsertMultipleMap{
Id: 2,
Width: 30,
Height: 20,
Name: "xiaolunwen",
}, res[1])
assert.NoError(t, PrepareEngine())
assertSync(t, new(InsertMultipleMap))
cnt, err = testEngine.Table(new(InsertMultipleMap)).Insert([]map[string]string{
{
"width": "20",
"height": "10",
"name": "lunny",
},
{
"width": "30",
"height": "20",
"name": "xiaolunwen",
},
})
assert.NoError(t, err)
assert.EqualValues(t, 2, cnt)
res = make([]InsertMultipleMap, 0, 2)
err = testEngine.Find(&res)
assert.NoError(t, err)
assert.EqualValues(t, 2, len(res))
assert.EqualValues(t, InsertMultipleMap{
Id: 1,
Width: 20,
Height: 10,
Name: "lunny",
}, res[0])
assert.EqualValues(t, InsertMultipleMap{
Id: 2,
Width: 30,
Height: 20,
Name: "xiaolunwen",
}, res[1])
}

View File

@ -5,6 +5,7 @@
package statements
import (
"errors"
"fmt"
"strings"
@ -205,3 +206,55 @@ func (statement *Statement) GenInsertMapSQL(columns []string, args []interface{}
return buf.String(), buf.Args(), nil
}
func (statement *Statement) GenInsertMultipleMapSQL(columns []string, argss [][]interface{}) (string, []interface{}, error) {
var (
buf = builder.NewWriter()
exprs = statement.ExprColumns
tableName = statement.TableName()
)
if _, err := buf.WriteString(fmt.Sprintf("INSERT INTO %s (", statement.quote(tableName))); err != nil {
return "", nil, err
}
if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(columns, exprs.ColNames()...), ","); err != nil {
return "", nil, err
}
// if insert where
if statement.Conds().IsValid() {
return "", nil, errors.New("batch insert don't support with where")
}
if _, err := buf.WriteString(") VALUES "); err != nil {
return "", nil, err
}
for i, args := range argss {
if _, err := buf.WriteString("("); err != nil {
return "", nil, err
}
if err := statement.WriteArgs(buf, args); err != nil {
return "", nil, err
}
if len(exprs) > 0 {
if _, err := buf.WriteString(","); err != nil {
return "", nil, err
}
if err := exprs.WriteArgs(buf); err != nil {
return "", nil, err
}
}
if _, err := buf.WriteString(")"); err != nil {
return "", nil, err
}
if i < len(argss)-1 {
if _, err := buf.WriteString(","); err != nil {
return "", nil, err
}
}
}
return buf.String(), buf.Args(), nil
}

View File

@ -18,7 +18,7 @@ import (
)
// ErrNoElementsOnSlice represents an error there is no element when insert
var ErrNoElementsOnSlice = errors.New("No element on slice when insert")
var ErrNoElementsOnSlice = errors.New("no element on slice when insert")
// Insert insert one or more beans
func (session *Session) Insert(beans ...interface{}) (int64, error) {
@ -36,71 +36,42 @@ func (session *Session) Insert(beans ...interface{}) (int64, error) {
}()
for _, bean := range beans {
switch bean.(type) {
var cnt int64
var err error
switch v := bean.(type) {
case map[string]interface{}:
cnt, err := session.insertMapInterface(bean.(map[string]interface{}))
if err != nil {
return affected, err
}
affected += cnt
cnt, err = session.insertMapInterface(v)
case []map[string]interface{}:
s := bean.([]map[string]interface{})
for i := 0; i < len(s); i++ {
cnt, err := session.insertMapInterface(s[i])
if err != nil {
return affected, err
}
affected += cnt
}
cnt, err = session.insertMultipleMapInterface(v)
case map[string]string:
cnt, err := session.insertMapString(bean.(map[string]string))
if err != nil {
return affected, err
}
affected += cnt
cnt, err = session.insertMapString(v)
case []map[string]string:
s := bean.([]map[string]string)
for i := 0; i < len(s); i++ {
cnt, err := session.insertMapString(s[i])
if err != nil {
return affected, err
}
affected += cnt
}
cnt, err = session.insertMultipleMapString(v)
default:
sliceValue := reflect.Indirect(reflect.ValueOf(bean))
if sliceValue.Kind() == reflect.Slice {
size := sliceValue.Len()
if size <= 0 {
return 0, ErrNoElementsOnSlice
}
cnt, err := session.innerInsertMulti(bean)
if err != nil {
return affected, err
}
affected += cnt
cnt, err = session.insertMultipleStruct(bean)
} else {
cnt, err := session.innerInsert(bean)
if err != nil {
return affected, err
}
affected += cnt
cnt, err = session.insertStruct(bean)
}
}
if err != nil {
return affected, err
}
affected += cnt
}
return affected, err
}
func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error) {
func (session *Session) insertMultipleStruct(rowsSlicePtr interface{}) (int64, error) {
sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
if sliceValue.Kind() != reflect.Slice {
return 0, errors.New("needs a pointer to a slice")
}
if sliceValue.Len() <= 0 {
return 0, errors.New("could not insert a empty slice")
return 0, ErrNoElementsOnSlice
}
if err := session.statement.SetRefBean(sliceValue.Index(0).Interface()); err != nil {
@ -269,14 +240,10 @@ func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) {
return 0, ErrPtrSliceType
}
if sliceValue.Len() <= 0 {
return 0, ErrNoElementsOnSlice
}
return session.innerInsertMulti(rowsSlicePtr)
return session.insertMultipleStruct(rowsSlicePtr)
}
func (session *Session) innerInsert(bean interface{}) (int64, error) {
func (session *Session) insertStruct(bean interface{}) (int64, error) {
if err := session.statement.SetRefBean(bean); err != nil {
return 0, err
}
@ -434,7 +401,7 @@ func (session *Session) InsertOne(bean interface{}) (int64, error) {
defer session.Close()
}
return session.innerInsert(bean)
return session.insertStruct(bean)
}
func (session *Session) cacheInsert(table string) error {
@ -561,6 +528,37 @@ func (session *Session) insertMapInterface(m map[string]interface{}) (int64, err
return session.insertMap(columns, args)
}
func (session *Session) insertMultipleMapInterface(maps []map[string]interface{}) (int64, error) {
if len(maps) <= 0 {
return 0, ErrNoElementsOnSlice
}
tableName := session.statement.TableName()
if len(tableName) <= 0 {
return 0, ErrTableNotFound
}
var columns = make([]string, 0, len(maps[0]))
exprs := session.statement.ExprColumns
for k := range maps[0] {
if !exprs.IsColExist(k) {
columns = append(columns, k)
}
}
sort.Strings(columns)
var argss = make([][]interface{}, 0, len(maps))
for _, m := range maps {
var args = make([]interface{}, 0, len(m))
for _, colName := range columns {
args = append(args, m[colName])
}
argss = append(argss, args)
}
return session.insertMultipleMap(columns, argss)
}
func (session *Session) insertMapString(m map[string]string) (int64, error) {
if len(m) == 0 {
return 0, ErrParamsType
@ -589,6 +587,37 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) {
return session.insertMap(columns, args)
}
func (session *Session) insertMultipleMapString(maps []map[string]string) (int64, error) {
if len(maps) <= 0 {
return 0, ErrNoElementsOnSlice
}
tableName := session.statement.TableName()
if len(tableName) <= 0 {
return 0, ErrTableNotFound
}
var columns = make([]string, 0, len(maps[0]))
exprs := session.statement.ExprColumns
for k := range maps[0] {
if !exprs.IsColExist(k) {
columns = append(columns, k)
}
}
sort.Strings(columns)
var argss = make([][]interface{}, 0, len(maps))
for _, m := range maps {
var args = make([]interface{}, 0, len(m))
for _, colName := range columns {
args = append(args, m[colName])
}
argss = append(argss, args)
}
return session.insertMultipleMap(columns, argss)
}
func (session *Session) insertMap(columns []string, args []interface{}) (int64, error) {
tableName := session.statement.TableName()
if len(tableName) <= 0 {
@ -614,3 +643,29 @@ func (session *Session) insertMap(columns []string, args []interface{}) (int64,
}
return affected, nil
}
func (session *Session) insertMultipleMap(columns []string, argss [][]interface{}) (int64, error) {
tableName := session.statement.TableName()
if len(tableName) <= 0 {
return 0, ErrTableNotFound
}
sql, args, err := session.statement.GenInsertMultipleMapSQL(columns, argss)
if err != nil {
return 0, err
}
if err := session.cacheInsert(tableName); err != nil {
return 0, err
}
res, err := session.exec(sql, args...)
if err != nil {
return 0, err
}
affected, err := res.RowsAffected()
if err != nil {
return 0, err
}
return affected, nil
}