Browse Source

Improve insert map generating SQL (#1634)

Fix writeArg

Improve insert map generating SQL

Reviewed-on: #1634
pull/1637/head
Lunny Xiao 2 months ago
parent
commit
2ac051f075
3 changed files with 98 additions and 112 deletions
  1. +86
    -22
      internal/statements/insert.go
  2. +9
    -23
      internal/statements/statement_args.go
  3. +3
    -67
      session_insert.go

+ 86
- 22
internal/statements/insert.go View File

@@ -5,6 +5,7 @@
package statements

import (
"fmt"
"strings"

"xorm.io/builder"
@@ -23,18 +24,15 @@ func (statement *Statement) writeInsertOutput(buf *strings.Builder, table *schem
return nil
}

// GenInsertSQL generates insert beans SQL
func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) (string, []interface{}, error) {
var (
buf = builder.NewWriter()
exprs = statement.ExprColumns
table = statement.RefTable
tableName = statement.TableName()
exprs = statement.ExprColumns
colPlaces = strings.Repeat("?, ", len(colNames))
)
if exprs.Len() <= 0 && len(colPlaces) > 0 {
colPlaces = colPlaces[0 : len(colPlaces)-2]
}

var buf = builder.NewWriter()
if _, err := buf.WriteString("INSERT INTO "); err != nil {
return "", nil, err
}
@@ -43,7 +41,7 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{})
return "", nil, err
}

if len(colPlaces) <= 0 {
if len(colNames) <= 0 {
if statement.dialect.URI().DBType == schemas.MYSQL {
if _, err := buf.WriteString(" VALUES ()"); err != nil {
return "", nil, err
@@ -65,13 +63,14 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{})
return "", nil, err
}

if _, err := buf.WriteString(")"); err != nil {
return "", nil, err
}
if err := statement.writeInsertOutput(buf.Builder, table); err != nil {
return "", nil, err
}

if statement.Conds().IsValid() {
if _, err := buf.WriteString(")"); err != nil {
return "", nil, err
}
if err := statement.writeInsertOutput(buf.Builder, table); err != nil {
return "", nil, err
}
if _, err := buf.WriteString(" SELECT "); err != nil {
return "", nil, err
}
@@ -105,21 +104,20 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{})
return "", nil, err
}
} else {
buf.Append(args...)

if _, err := buf.WriteString(")"); err != nil {
return "", nil, err
}
if err := statement.writeInsertOutput(buf.Builder, table); err != nil {
return "", nil, err
}
if _, err := buf.WriteString(" VALUES ("); err != nil {
return "", nil, err
}
if _, err := buf.WriteString(colPlaces); err != nil {

if err := statement.WriteArgs(buf, args); err != nil {
return "", nil, err
}

if len(exprs.Args) > 0 {
if _, err := buf.WriteString(","); err != nil {
return "", nil, err
}
}

if err := exprs.WriteArgs(buf); err != nil {
return "", nil, err
}
@@ -141,3 +139,69 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{})

return buf.String(), buf.Args(), nil
}

// GenInsertMapSQL generates insert map SQL
func (statement *Statement) GenInsertMapSQL(columns []string, args []interface{}) (string, []interface{}, error) {
var (
buf = builder.NewWriter()
exprs = statement.ExprColumns
tableName = statement.TableName()
)

if _, err := buf.WriteString(fmt.Sprintf("INSERT INTO %s (", statement.quote(tableName))); err != nil {
return "", nil, err
}

if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(columns, exprs.ColNames...), ","); err != nil {
return "", nil, err
}

// if insert where
if statement.Conds().IsValid() {
if _, err := buf.WriteString(") SELECT "); err != nil {
return "", nil, err
}

if err := statement.WriteArgs(buf, args); err != nil {
return "", nil, err
}

if len(exprs.Args) > 0 {
if _, err := buf.WriteString(","); err != nil {
return "", nil, err
}
if err := exprs.WriteArgs(buf); err != nil {
return "", nil, err
}
}

if _, err := buf.WriteString(fmt.Sprintf(" FROM %s WHERE ", statement.quote(tableName))); err != nil {
return "", nil, err
}

if err := statement.Conds().WriteTo(buf); err != nil {
return "", nil, err
}
} else {
if _, err := buf.WriteString(") VALUES ("); err != nil {
return "", nil, err
}
if err := statement.WriteArgs(buf, args); err != nil {
return "", nil, err
}

if len(exprs.Args) > 0 {
if _, err := buf.WriteString(","); err != nil {
return "", nil, err
}
if err := exprs.WriteArgs(buf); err != nil {
return "", nil, err
}
}
if _, err := buf.WriteString(")"); err != nil {
return "", nil, err
}
}

return buf.String(), buf.Args(), nil
}

+ 9
- 23
internal/statements/statement_args.go View File

@@ -79,28 +79,6 @@ const insertSelectPlaceHolder = true

func (statement *Statement) WriteArg(w *builder.BytesWriter, arg interface{}) error {
switch argv := arg.(type) {
case bool:
if statement.dialect.URI().DBType == schemas.MSSQL {
if argv {
if _, err := w.WriteString("1"); err != nil {
return err
}
} else {
if _, err := w.WriteString("0"); err != nil {
return err
}
}
} else {
if argv {
if _, err := w.WriteString("true"); err != nil {
return err
}
} else {
if _, err := w.WriteString("false"); err != nil {
return err
}
}
}
case *builder.Builder:
if _, err := w.WriteString("("); err != nil {
return err
@@ -116,7 +94,15 @@ func (statement *Statement) WriteArg(w *builder.BytesWriter, arg interface{}) er
if err := w.WriteByte('?'); err != nil {
return err
}
w.Append(arg)
if v, ok := arg.(bool); ok && statement.dialect.URI().DBType == schemas.MSSQL {
if v {
w.Append(1)
} else {
w.Append(0)
}
} else {
w.Append(arg)
}
} else {
var convertFunc = convertStringSingleQuote
if statement.dialect.URI().DBType == schemas.MYSQL {


+ 3
- 67
session_insert.go View File

@@ -12,7 +12,6 @@ import (
"strconv"
"strings"

"xorm.io/builder"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas"
)
@@ -623,74 +622,11 @@ func (session *Session) insertMap(columns []string, args []interface{}) (int64,
return 0, ErrTableNotFound
}

exprs := session.statement.ExprColumns
w := builder.NewWriter()
// if insert where
if session.statement.Conds().IsValid() {
if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (", session.engine.Quote(tableName))); err != nil {
return 0, err
}

if err := session.engine.dialect.Quoter().JoinWrite(w.Builder, append(columns, exprs.ColNames...), ","); err != nil {
return 0, err
}

if _, err := w.WriteString(") SELECT "); err != nil {
return 0, err
}

if err := session.statement.WriteArgs(w, args); err != nil {
return 0, err
}

if len(exprs.Args) > 0 {
if _, err := w.WriteString(","); err != nil {
return 0, err
}
if err := exprs.WriteArgs(w); err != nil {
return 0, err
}
}

if _, err := w.WriteString(fmt.Sprintf(" FROM %s WHERE ", session.engine.Quote(tableName))); err != nil {
return 0, err
}

if err := session.statement.Conds().WriteTo(w); err != nil {
return 0, err
}
} else {
qm := strings.Repeat("?,", len(columns))
qm = qm[:len(qm)-1]

if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (", session.engine.Quote(tableName))); err != nil {
return 0, err
}

if err := session.engine.dialect.Quoter().JoinWrite(w.Builder, append(columns, exprs.ColNames...), ","); err != nil {
return 0, err
}
if _, err := w.WriteString(fmt.Sprintf(") VALUES (%s", qm)); err != nil {
return 0, err
}

w.Append(args...)
if len(exprs.Args) > 0 {
if _, err := w.WriteString(","); err != nil {
return 0, err
}
if err := exprs.WriteArgs(w); err != nil {
return 0, err
}
}
if _, err := w.WriteString(")"); err != nil {
return 0, err
}
sql, args, err := session.statement.GenInsertMapSQL(columns, args)
if err != nil {
return 0, err
}

sql := w.String()
args = w.Args()

if err := session.cacheInsert(tableName); err != nil {
return 0, err
}


Loading…
Cancel
Save