From 18ce3866630e9753571d39ff6921678d49a60337 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Sun, 22 Sep 2019 17:25:15 +0800 Subject: [PATCH 1/2] add support subquery on SetExpr --- engine.go | 2 +- go.mod | 2 +- go.sum | 4 +- interface.go | 2 +- session_cols.go | 45 +------ session_insert.go | 266 +++++++++++++++++++++++++---------------- session_insert_test.go | 15 +++ session_update.go | 51 +++++--- statement.go | 42 ++----- statement_args.go | 69 +++++++++++ statement_columnmap.go | 35 ++++++ statement_exprparam.go | 101 ++++++++++++++++ statement_quote.go | 19 +++ 13 files changed, 455 insertions(+), 198 deletions(-) create mode 100644 statement_args.go create mode 100644 statement_columnmap.go create mode 100644 statement_exprparam.go create mode 100644 statement_quote.go diff --git a/engine.go b/engine.go index f04c702e..649fd1e3 100644 --- a/engine.go +++ b/engine.go @@ -729,7 +729,7 @@ func (engine *Engine) Decr(column string, arg ...interface{}) *Session { } // SetExpr provides a update string like "column = {expression}" -func (engine *Engine) SetExpr(column string, expression string) *Session { +func (engine *Engine) SetExpr(column string, expression interface{}) *Session { session := engine.NewSession() session.isAutoClose = true return session.SetExpr(column, expression) diff --git a/go.mod b/go.mod index a3e78cae..ac982b05 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,6 @@ require ( github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24 // indirect github.com/stretchr/testify v1.3.0 github.com/ziutek/mymysql v1.5.4 - xorm.io/builder v0.3.6-0.20190906062455-b937eb46ecfb + xorm.io/builder v0.3.6 xorm.io/core v0.7.0 ) diff --git a/go.sum b/go.sum index 0f2baf17..2fecfbd0 100644 --- a/go.sum +++ b/go.sum @@ -162,7 +162,7 @@ gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -xorm.io/builder v0.3.6-0.20190906062455-b937eb46ecfb h1:2idZcp79ldX5qLeQ6WKCdS7aEFNOMvQc9wrtt5hSRwM= -xorm.io/builder v0.3.6-0.20190906062455-b937eb46ecfb/go.mod h1:LEFAPISnRzG+zxaxj2vPicRwz67BdhFreKg8yv8/TgU= +xorm.io/builder v0.3.6 h1:ha28mQ2M+TFx96Hxo+iq6tQgnkC9IZkM6D8w9sKHHF8= +xorm.io/builder v0.3.6/go.mod h1:LEFAPISnRzG+zxaxj2vPicRwz67BdhFreKg8yv8/TgU= xorm.io/core v0.7.0 h1:hKxuOKWZNeiFQsSuGet/KV8HZ788hclvAl+7azx3tkM= xorm.io/core v0.7.0/go.mod h1:TuOJjIVa7e3w/rN8tDcAvuLBMtwzdHPbyOzE6Gk1EUI= diff --git a/interface.go b/interface.go index 0928f66a..a564db12 100644 --- a/interface.go +++ b/interface.go @@ -54,7 +54,7 @@ type Interface interface { QueryInterface(sqlOrArgs ...interface{}) ([]map[string]interface{}, error) QueryString(sqlOrArgs ...interface{}) ([]map[string]string, error) Rows(bean interface{}) (*Rows, error) - SetExpr(string, string) *Session + SetExpr(string, interface{}) *Session SQL(interface{}, ...interface{}) *Session Sum(bean interface{}, colName string) (float64, error) SumInt(bean interface{}, colName string) (int64, error) diff --git a/session_cols.go b/session_cols.go index dc3befcf..1558074f 100644 --- a/session_cols.go +++ b/session_cols.go @@ -12,49 +12,6 @@ import ( "xorm.io/core" ) -type incrParam struct { - colName string - arg interface{} -} - -type decrParam struct { - colName string - arg interface{} -} - -type exprParam struct { - colName string - expr string -} - -type columnMap []string - -func (m columnMap) contain(colName string) bool { - if len(m) == 0 { - return false - } - - n := len(colName) - for _, mk := range m { - if len(mk) != n { - continue - } - if strings.EqualFold(mk, colName) { - return true - } - } - - return false -} - -func (m *columnMap) add(colName string) bool { - if m.contain(colName) { - return false - } - *m = append(*m, colName) - return true -} - func setColumnInt(bean interface{}, col *core.Column, t int64) { v, err := col.ValueOf(bean) if err != nil { @@ -132,7 +89,7 @@ func (session *Session) Decr(column string, arg ...interface{}) *Session { } // SetExpr provides a query string like "column = {expression}" -func (session *Session) SetExpr(column string, expression string) *Session { +func (session *Session) SetExpr(column string, expression interface{}) *Session { session.statement.SetExpr(column, expression) return session } diff --git a/session_insert.go b/session_insert.go index 24b32831..c943ae11 100644 --- a/session_insert.go +++ b/session_insert.go @@ -7,6 +7,7 @@ package xorm import ( "errors" "fmt" + "io" "reflect" "sort" "strconv" @@ -340,72 +341,96 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { if err != nil { return 0, err } - // insert expr columns, override if exists - exprColumns := session.statement.getExpr() - exprColVals := make([]string, 0, len(exprColumns)) - for _, v := range exprColumns { - // remove the expr columns - for i, colName := range colNames { - if colName == strings.Trim(v.colName, "`") { - colNames = append(colNames[:i], colNames[i+1:]...) - args = append(args[:i], args[i+1:]...) - } - } - // append expr column to the end - colNames = append(colNames, v.colName) - exprColVals = append(exprColVals, v.expr) + exprs := session.statement.exprColumns + colPlaces := strings.Repeat("?, ", len(colNames)) + if exprs.Len() <= 0 && len(colPlaces) > 0 { + colPlaces = colPlaces[0 : len(colPlaces)-2] } - colPlaces := strings.Repeat("?, ", len(colNames)-len(exprColumns)) - if len(exprColVals) > 0 { - colPlaces = colPlaces + strings.Join(exprColVals, ", ") - } else { - if len(colPlaces) > 0 { - colPlaces = colPlaces[0 : len(colPlaces)-2] - } - } - - var sqlStr string var tableName = session.statement.TableName() var output string if session.engine.dialect.DBType() == core.MSSQL && len(table.AutoIncrement) > 0 { output = fmt.Sprintf(" OUTPUT Inserted.%s", table.AutoIncrement) } - if len(colPlaces) > 0 { + var buf = builder.NewWriter() + if _, err := io.WriteString(buf, fmt.Sprintf("INSERT INTO %s", session.engine.Quote(tableName))); err != nil { + return 0, err + } + + if len(colPlaces) <= 0 { + if session.engine.dialect.DBType() == core.MYSQL { + if _, err := io.WriteString(buf, " VALUES ()"); err != nil { + return 0, err + } + } else { + if _, err := io.WriteString(buf, fmt.Sprintf("%s DEFAULT VALUES", output)); err != nil { + return 0, err + } + } + } else { + if _, err := io.WriteString(buf, " ("); err != nil { + return 0, err + } + + if err := writeStrings(buf, append(colNames, exprs.colNames...), "`", "`"); err != nil { + return 0, err + } + if session.statement.cond.IsValid() { - condSQL, condArgs, err := builder.ToSQL(session.statement.cond) - if err != nil { + if _, err := io.WriteString(buf, fmt.Sprintf(")%s SELECT ", + output, + )); err != nil { return 0, err } - sqlStr = fmt.Sprintf("INSERT INTO %s (%v)%s SELECT %v FROM %v WHERE %v", - session.engine.Quote(tableName), - quoteColumns(colNames, session.engine.Quote, ","), - output, - colPlaces, - session.engine.Quote(tableName), - condSQL, - ) - args = append(args, condArgs...) + if err := writeArgs(buf, args); err != nil { + return 0, err + } + + if len(exprs.args) > 0 { + if _, err := buf.WriteString(","); err != nil { + return 0, err + } + } + if err := exprs.writeArgs(buf); err != nil { + return 0, err + } + + if _, err := io.WriteString(buf, fmt.Sprintf(" FROM %v WHERE ", session.engine.Quote(tableName))); err != nil { + return 0, err + } + + if err := session.statement.cond.WriteTo(buf); err != nil { + return 0, err + } } else { - sqlStr = fmt.Sprintf("INSERT INTO %s (%v)%s VALUES (%v)", - session.engine.Quote(tableName), - quoteColumns(colNames, session.engine.Quote, ","), + buf.Append(args...) + + if _, err := io.WriteString(buf, fmt.Sprintf(")%s VALUES (%v", output, - colPlaces) - } - } else { - if session.engine.dialect.DBType() == core.MYSQL { - sqlStr = fmt.Sprintf("INSERT INTO %s VALUES ()", session.engine.Quote(tableName)) - } else { - sqlStr = fmt.Sprintf("INSERT INTO %s%s DEFAULT VALUES", session.engine.Quote(tableName), output) + colPlaces)); err != nil { + return 0, err + } + + if err := exprs.writeArgs(buf); err != nil { + return 0, err + } + + if _, err := io.WriteString(buf, ")"); err != nil { + return 0, err + } } } + sqlStr := buf.String() + args = buf.Args() + if len(table.AutoIncrement) > 0 && session.engine.dialect.DBType() == core.POSTGRES { - sqlStr = sqlStr + " RETURNING " + session.engine.Quote(table.AutoIncrement) + if _, err := io.WriteString(buf, " RETURNING "+session.engine.Quote(table.AutoIncrement)); err != nil { + return 0, err + } } handleAfterInsertProcessorFunc := func(bean interface{}) { @@ -611,9 +636,11 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac continue } - if _, ok := session.statement.incrColumns[col.Name]; ok { + if session.statement.incrColumns.isColExist(col.Name) { continue - } else if _, ok := session.statement.decrColumns[col.Name]; ok { + } else if session.statement.decrColumns.isColExist(col.Name) { + continue + } else if session.statement.exprColumns.isColExist(col.Name) { continue } @@ -688,46 +715,66 @@ func (session *Session) insertMapInterface(m map[string]interface{}) (int64, err } var columns = make([]string, 0, len(m)) + exprs := session.statement.exprColumns for k := range m { - columns = append(columns, k) + if !exprs.isColExist(k) { + columns = append(columns, k) + } } sort.Strings(columns) - qm := strings.Repeat("?,", len(columns)) - var args = make([]interface{}, 0, len(m)) for _, colName := range columns { args = append(args, m[colName]) } - // insert expr columns, override if exists - exprColumns := session.statement.getExpr() - for _, col := range exprColumns { - columns = append(columns, strings.Trim(col.colName, "`")) - qm = qm + col.expr + "," - } - - qm = qm[:len(qm)-1] - - var sql string - + w := builder.NewWriter() if session.statement.cond.IsValid() { - condSQL, condArgs, err := builder.ToSQL(session.statement.cond) - if err != nil { + if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (", session.engine.Quote(tableName))); err != nil { + return 0, err + } + + if err := writeStrings(w, append(columns, exprs.colNames...), "`", "`"); err != nil { + return 0, err + } + + if _, err := w.WriteString(") SELECT "); err != nil { + return 0, err + } + + if err := writeArgs(w, args); err != nil { + return 0, err + } + + if len(exprs.args) > 0 { + if _, err := w.WriteString(","); err != nil { + return 0, err + } + if err := exprs.writeArgs(w); err != nil { + return 0, err + } + } + + if _, err := w.WriteString(fmt.Sprintf(" FROM %s WHERE ", session.engine.Quote(tableName))); err != nil { + return 0, err + } + + if err := session.statement.cond.WriteTo(w); err != nil { return 0, err } - sql = fmt.Sprintf("INSERT INTO %s (`%s`) SELECT %s FROM %s WHERE %s", - session.engine.Quote(tableName), - strings.Join(columns, "`,`"), - qm, - session.engine.Quote(tableName), - condSQL, - ) - args = append(args, condArgs...) } else { - sql = fmt.Sprintf("INSERT INTO %s (`%s`) VALUES (%s)", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm) + qm := strings.Repeat("?,", len(columns)) + qm = qm[:len(qm)-1] + + if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (`%s`) VALUES (%s)", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm)); err != nil { + return 0, err + } + w.Append(args...) } + sql := w.String() + args = w.Args() + if err := session.cacheInsert(tableName); err != nil { return 0, err } @@ -754,8 +801,11 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) { } var columns = make([]string, 0, len(m)) + exprs := session.statement.exprColumns for k := range m { - columns = append(columns, k) + if !exprs.isColExist(k) { + columns = append(columns, k) + } } sort.Strings(columns) @@ -764,37 +814,53 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) { args = append(args, m[colName]) } - qm := strings.Repeat("?,", len(columns)) - - // insert expr columns, override if exists - exprColumns := session.statement.getExpr() - for _, col := range exprColumns { - columns = append(columns, strings.Trim(col.colName, "`")) - qm = qm + col.expr + "," - } - - qm = qm[:len(qm)-1] - - var sql string - + w := builder.NewWriter() if session.statement.cond.IsValid() { - qm = "(" + qm[:len(qm)-1] + ")" - condSQL, condArgs, err := builder.ToSQL(session.statement.cond) - if err != nil { + if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (", session.engine.Quote(tableName))); err != nil { + return 0, err + } + + if err := writeStrings(w, append(columns, exprs.colNames...), "`", "`"); err != nil { + return 0, err + } + + if _, err := w.WriteString(") SELECT "); err != nil { + return 0, err + } + + if err := writeArgs(w, args); err != nil { + return 0, err + } + + if len(exprs.args) > 0 { + if _, err := w.WriteString(","); err != nil { + return 0, err + } + if err := exprs.writeArgs(w); err != nil { + return 0, err + } + } + + if _, err := w.WriteString(fmt.Sprintf(" FROM %s WHERE ", session.engine.Quote(tableName))); err != nil { + return 0, err + } + + if err := session.statement.cond.WriteTo(w); err != nil { return 0, err } - sql = fmt.Sprintf("INSERT INTO %s (`%s`) SELECT %s FROM %s WHERE %s", - session.engine.Quote(tableName), - strings.Join(columns, "`,`"), - qm, - session.engine.Quote(tableName), - condSQL, - ) - args = append(args, condArgs...) } else { - sql = fmt.Sprintf("INSERT INTO %s (`%s`) VALUES (%s)", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm) + qm := strings.Repeat("?,", len(columns)) + qm = qm[:len(qm)-1] + + if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (`%s`) VALUES (%s)", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm)); err != nil { + return 0, err + } + w.Append(args...) } + sql := w.String() + args = w.Args() + if err := session.cacheInsert(tableName); err != nil { return 0, err } diff --git a/session_insert_test.go b/session_insert_test.go index daf08e7f..f9c99071 100644 --- a/session_insert_test.go +++ b/session_insert_test.go @@ -892,4 +892,19 @@ func TestInsertWhere(t *testing.T) { assert.EqualValues(t, 40, j2.Height) assert.EqualValues(t, "trest2", j2.Name) assert.EqualValues(t, 2, j2.Index) + + inserted, err = testEngine.Table(new(InsertWhere)).Where("repo_id=?", 1). + SetExpr("`index`", "coalesce(MAX(`index`),0)+1"). + Insert(map[string]string{ + "name": "trest3", + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, inserted) + + var j3 InsertWhere + has, err = testEngine.ID(3).Get(&j3) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, "trest3", j3.Name) + assert.EqualValues(t, 3, j3.Index) } diff --git a/session_update.go b/session_update.go index 85b0bb0b..402470e5 100644 --- a/session_update.go +++ b/session_update.go @@ -223,21 +223,31 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } // for update action to like "column = column + ?" - incColumns := session.statement.getInc() - for _, v := range incColumns { - colNames = append(colNames, session.engine.Quote(v.colName)+" = "+session.engine.Quote(v.colName)+" + ?") - args = append(args, v.arg) + incColumns := session.statement.incrColumns + for i, colName := range incColumns.colNames { + colNames = append(colNames, session.engine.Quote(colName)+" = "+session.engine.Quote(colName)+" + ?") + args = append(args, incColumns.args[i]) } // for update action to like "column = column - ?" - decColumns := session.statement.getDec() - for _, v := range decColumns { - colNames = append(colNames, session.engine.Quote(v.colName)+" = "+session.engine.Quote(v.colName)+" - ?") - args = append(args, v.arg) + decColumns := session.statement.decrColumns + for i, colName := range decColumns.colNames { + colNames = append(colNames, session.engine.Quote(colName)+" = "+session.engine.Quote(colName)+" - ?") + args = append(args, decColumns.args[i]) } // for update action to like "column = expression" - exprColumns := session.statement.getExpr() - for _, v := range exprColumns { - colNames = append(colNames, session.engine.Quote(v.colName)+" = "+v.expr) + exprColumns := session.statement.exprColumns + for i, colName := range exprColumns.colNames { + switch tp := exprColumns.args[i].(type) { + case string: + colNames = append(colNames, session.engine.Quote(colName)+" = "+tp) + case *builder.Builder: + subQuery, subArgs, err := builder.ToSQL(tp) + if err != nil { + return 0, err + } + colNames = append(colNames, session.engine.Quote(colName)+" = "+subQuery) + args = append(args, subArgs...) + } } if err = session.statement.processIDParam(); err != nil { @@ -468,14 +478,17 @@ func (session *Session) genUpdateColumns(bean interface{}) ([]string, []interfac continue } - if len(session.statement.columnMap) > 0 { - if !session.statement.columnMap.contain(col.Name) { - continue - } else if _, ok := session.statement.incrColumns[col.Name]; ok { - continue - } else if _, ok := session.statement.decrColumns[col.Name]; ok { - continue - } + // if only update specify columns + if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) { + continue + } + + if session.statement.incrColumns.isColExist(col.Name) { + continue + } else if session.statement.decrColumns.isColExist(col.Name) { + continue + } else if session.statement.exprColumns.isColExist(col.Name) { + continue } // !evalphobia! set fieldValue as nil when column is nullable and zero-value diff --git a/statement.go b/statement.go index 6cdbad7d..3cc0831e 100644 --- a/statement.go +++ b/statement.go @@ -52,9 +52,9 @@ type Statement struct { omitColumnMap columnMap mustColumnMap map[string]bool nullableMap map[string]bool - incrColumns map[string]incrParam - decrColumns map[string]decrParam - exprColumns map[string]exprParam + incrColumns exprParams + decrColumns exprParams + exprColumns exprParams cond builder.Cond bufferSize int context ContextCache @@ -94,9 +94,9 @@ func (statement *Statement) Init() { statement.nullableMap = make(map[string]bool) statement.checkVersion = true statement.unscoped = false - statement.incrColumns = make(map[string]incrParam) - statement.decrColumns = make(map[string]decrParam) - statement.exprColumns = make(map[string]exprParam) + statement.incrColumns = exprParams{} + statement.decrColumns = exprParams{} + statement.exprColumns = exprParams{} statement.cond = builder.NewCond() statement.bufferSize = 0 statement.context = nil @@ -534,48 +534,30 @@ func (statement *Statement) ID(id interface{}) *Statement { // Incr Generate "Update ... Set column = column + arg" statement func (statement *Statement) Incr(column string, arg ...interface{}) *Statement { - k := strings.ToLower(column) if len(arg) > 0 { - statement.incrColumns[k] = incrParam{column, arg[0]} + statement.incrColumns.addParam(column, arg[0]) } else { - statement.incrColumns[k] = incrParam{column, 1} + statement.incrColumns.addParam(column, 1) } return statement } // Decr Generate "Update ... Set column = column - arg" statement func (statement *Statement) Decr(column string, arg ...interface{}) *Statement { - k := strings.ToLower(column) if len(arg) > 0 { - statement.decrColumns[k] = decrParam{column, arg[0]} + statement.decrColumns.addParam(column, arg[0]) } else { - statement.decrColumns[k] = decrParam{column, 1} + statement.decrColumns.addParam(column, 1) } return statement } // SetExpr Generate "Update ... Set column = {expression}" statement -func (statement *Statement) SetExpr(column string, expression string) *Statement { - k := strings.ToLower(column) - statement.exprColumns[k] = exprParam{column, expression} +func (statement *Statement) SetExpr(column string, expression interface{}) *Statement { + statement.exprColumns.addParam(column, expression) return statement } -// Generate "Update ... Set column = column + arg" statement -func (statement *Statement) getInc() map[string]incrParam { - return statement.incrColumns -} - -// Generate "Update ... Set column = column - arg" statement -func (statement *Statement) getDec() map[string]decrParam { - return statement.decrColumns -} - -// Generate "Update ... Set column = {expression}" statement -func (statement *Statement) getExpr() map[string]exprParam { - return statement.exprColumns -} - func (statement *Statement) col2NewColsWithQuote(columns ...string) []string { newColumns := make([]string, 0) quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`") diff --git a/statement_args.go b/statement_args.go new file mode 100644 index 00000000..aec12240 --- /dev/null +++ b/statement_args.go @@ -0,0 +1,69 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "fmt" + "io" + + "xorm.io/builder" +) + +func writeArg(w *builder.BytesWriter, arg interface{}) error { + switch argv := arg.(type) { + case string: + if _, err := w.WriteString("'" + argv + "'"); err != nil { + return err + } + case *builder.Builder: + if err := argv.WriteTo(w); err != nil { + return err + } + default: + if _, err := w.WriteString(fmt.Sprintf("%v", argv)); err != nil { + return err + } + } + return nil +} + +func writeArgs(w *builder.BytesWriter, args []interface{}) error { + for i, arg := range args { + if err := writeArg(w, arg); err != nil { + return err + } + + if i+1 != len(args) { + if _, err := w.WriteString(","); err != nil { + return err + } + } + } + return nil +} + +func writeStrings(w *builder.BytesWriter, cols []string, leftQuote, rightQuote string) error { + for i, colName := range cols { + if len(leftQuote) > 0 && colName[0] != '`' { + if _, err := io.WriteString(w, leftQuote); err != nil { + return err + } + } + if _, err := io.WriteString(w, colName); err != nil { + return err + } + if len(rightQuote) > 0 && colName[len(colName)-1] != '`' { + if _, err := io.WriteString(w, rightQuote); err != nil { + return err + } + } + if i+1 != len(cols) { + if _, err := io.WriteString(w, ","); err != nil { + return err + } + } + } + return nil +} diff --git a/statement_columnmap.go b/statement_columnmap.go new file mode 100644 index 00000000..b6523b1e --- /dev/null +++ b/statement_columnmap.go @@ -0,0 +1,35 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import "strings" + +type columnMap []string + +func (m columnMap) contain(colName string) bool { + if len(m) == 0 { + return false + } + + n := len(colName) + for _, mk := range m { + if len(mk) != n { + continue + } + if strings.EqualFold(mk, colName) { + return true + } + } + + return false +} + +func (m *columnMap) add(colName string) bool { + if m.contain(colName) { + return false + } + *m = append(*m, colName) + return true +} diff --git a/statement_exprparam.go b/statement_exprparam.go new file mode 100644 index 00000000..dfe3730c --- /dev/null +++ b/statement_exprparam.go @@ -0,0 +1,101 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "fmt" + "io" + "strings" + + "xorm.io/builder" +) + +type ErrUnsupportedExprType struct { + tp string +} + +func (err ErrUnsupportedExprType) Error() string { + return fmt.Sprintf("Unsupported expression type: %v", err.tp) +} + +type exprParam struct { + colName string + arg interface{} +} + +type exprParams struct { + colNames []string + args []interface{} +} + +func (exprs *exprParams) Len() int { + return len(exprs.colNames) +} + +func (exprs *exprParams) addParam(colName string, arg interface{}) { + exprs.colNames = append(exprs.colNames, colName) + exprs.args = append(exprs.args, arg) +} + +func (exprs *exprParams) isColExist(colName string) bool { + for _, name := range exprs.colNames { + if strings.EqualFold(trimQuote(name), trimQuote(colName)) { + return true + } + } + return false +} + +func (exprs *exprParams) getByName(colName string) (exprParam, bool) { + for i, name := range exprs.colNames { + if strings.EqualFold(name, colName) { + return exprParam{name, exprs.args[i]}, true + } + } + return exprParam{}, false +} + +func (exprs *exprParams) writeArgs(w builder.Writer) error { + for _, expr := range exprs.args { + switch arg := expr.(type) { + case *builder.Builder: + if err := arg.WriteTo(w); err != nil { + return err + } + default: + if _, err := io.WriteString(w, fmt.Sprintf("%v", arg)); err != nil { + return err + } + } + } + return nil +} + +func (exprs *exprParams) writeNameArgs(w builder.Writer) error { + for i, colName := range exprs.colNames { + if _, err := io.WriteString(w, colName); err != nil { + return err + } + if _, err := io.WriteString(w, "="); err != nil { + return err + } + + switch arg := exprs.args[i].(type) { + case *builder.Builder: + if err := arg.WriteTo(w); err != nil { + return err + } + default: + w.Append(exprs.args[i]) + } + + if i+1 != len(exprs.colNames) { + if _, err := io.WriteString(w, ","); err != nil { + return err + } + } + } + return nil +} diff --git a/statement_quote.go b/statement_quote.go new file mode 100644 index 00000000..e22e0d14 --- /dev/null +++ b/statement_quote.go @@ -0,0 +1,19 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +func trimQuote(s string) string { + if len(s) == 0 { + return s + } + + if s[0] == '`' { + s = s[1:] + } + if len(s) > 0 && s[len(s)-1] == '`' { + return s[:len(s)-1] + } + return s +} -- 2.40.1 From e7221d4aff0393601fe7ab1342c08be80947e02f Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Sun, 22 Sep 2019 20:50:00 +0800 Subject: [PATCH 2/2] fix tests --- session_insert.go | 27 ++++++++++++--------------- statement_args.go | 9 ++++----- statement_exprparam.go | 13 ++++++------- 3 files changed, 22 insertions(+), 27 deletions(-) diff --git a/session_insert.go b/session_insert.go index c943ae11..44cae7c4 100644 --- a/session_insert.go +++ b/session_insert.go @@ -7,7 +7,6 @@ package xorm import ( "errors" "fmt" - "io" "reflect" "sort" "strconv" @@ -355,22 +354,22 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { } var buf = builder.NewWriter() - if _, err := io.WriteString(buf, fmt.Sprintf("INSERT INTO %s", session.engine.Quote(tableName))); err != nil { + if _, err := buf.WriteString(fmt.Sprintf("INSERT INTO %s", session.engine.Quote(tableName))); err != nil { return 0, err } if len(colPlaces) <= 0 { if session.engine.dialect.DBType() == core.MYSQL { - if _, err := io.WriteString(buf, " VALUES ()"); err != nil { + if _, err := buf.WriteString(" VALUES ()"); err != nil { return 0, err } } else { - if _, err := io.WriteString(buf, fmt.Sprintf("%s DEFAULT VALUES", output)); err != nil { + if _, err := buf.WriteString(fmt.Sprintf("%s DEFAULT VALUES", output)); err != nil { return 0, err } } } else { - if _, err := io.WriteString(buf, " ("); err != nil { + if _, err := buf.WriteString(" ("); err != nil { return 0, err } @@ -379,9 +378,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { } if session.statement.cond.IsValid() { - if _, err := io.WriteString(buf, fmt.Sprintf(")%s SELECT ", - output, - )); err != nil { + if _, err := buf.WriteString(fmt.Sprintf(")%s SELECT ", output)); err != nil { return 0, err } @@ -398,7 +395,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return 0, err } - if _, err := io.WriteString(buf, fmt.Sprintf(" FROM %v WHERE ", session.engine.Quote(tableName))); err != nil { + if _, err := buf.WriteString(fmt.Sprintf(" FROM %v WHERE ", session.engine.Quote(tableName))); err != nil { return 0, err } @@ -408,7 +405,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { } else { buf.Append(args...) - if _, err := io.WriteString(buf, fmt.Sprintf(")%s VALUES (%v", + if _, err := buf.WriteString(fmt.Sprintf(")%s VALUES (%v", output, colPlaces)); err != nil { return 0, err @@ -418,21 +415,21 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return 0, err } - if _, err := io.WriteString(buf, ")"); err != nil { + if _, err := buf.WriteString(")"); err != nil { return 0, err } } } - sqlStr := buf.String() - args = buf.Args() - if len(table.AutoIncrement) > 0 && session.engine.dialect.DBType() == core.POSTGRES { - if _, err := io.WriteString(buf, " RETURNING "+session.engine.Quote(table.AutoIncrement)); err != nil { + if _, err := buf.WriteString(" RETURNING " + session.engine.Quote(table.AutoIncrement)); err != nil { return 0, err } } + sqlStr := buf.String() + args = buf.Args() + handleAfterInsertProcessorFunc := func(bean interface{}) { if session.isAutoCommit { for _, closure := range session.afterClosures { diff --git a/statement_args.go b/statement_args.go index aec12240..c6168db1 100644 --- a/statement_args.go +++ b/statement_args.go @@ -6,7 +6,6 @@ package xorm import ( "fmt" - "io" "xorm.io/builder" ) @@ -47,20 +46,20 @@ func writeArgs(w *builder.BytesWriter, args []interface{}) error { func writeStrings(w *builder.BytesWriter, cols []string, leftQuote, rightQuote string) error { for i, colName := range cols { if len(leftQuote) > 0 && colName[0] != '`' { - if _, err := io.WriteString(w, leftQuote); err != nil { + if _, err := w.WriteString(leftQuote); err != nil { return err } } - if _, err := io.WriteString(w, colName); err != nil { + if _, err := w.WriteString(colName); err != nil { return err } if len(rightQuote) > 0 && colName[len(colName)-1] != '`' { - if _, err := io.WriteString(w, rightQuote); err != nil { + if _, err := w.WriteString(rightQuote); err != nil { return err } } if i+1 != len(cols) { - if _, err := io.WriteString(w, ","); err != nil { + if _, err := w.WriteString(","); err != nil { return err } } diff --git a/statement_exprparam.go b/statement_exprparam.go index dfe3730c..a72f0aea 100644 --- a/statement_exprparam.go +++ b/statement_exprparam.go @@ -6,7 +6,6 @@ package xorm import ( "fmt" - "io" "strings" "xorm.io/builder" @@ -57,7 +56,7 @@ func (exprs *exprParams) getByName(colName string) (exprParam, bool) { return exprParam{}, false } -func (exprs *exprParams) writeArgs(w builder.Writer) error { +func (exprs *exprParams) writeArgs(w *builder.BytesWriter) error { for _, expr := range exprs.args { switch arg := expr.(type) { case *builder.Builder: @@ -65,7 +64,7 @@ func (exprs *exprParams) writeArgs(w builder.Writer) error { return err } default: - if _, err := io.WriteString(w, fmt.Sprintf("%v", arg)); err != nil { + if _, err := w.WriteString(fmt.Sprintf("%v", arg)); err != nil { return err } } @@ -73,12 +72,12 @@ func (exprs *exprParams) writeArgs(w builder.Writer) error { return nil } -func (exprs *exprParams) writeNameArgs(w builder.Writer) error { +func (exprs *exprParams) writeNameArgs(w *builder.BytesWriter) error { for i, colName := range exprs.colNames { - if _, err := io.WriteString(w, colName); err != nil { + if _, err := w.WriteString(colName); err != nil { return err } - if _, err := io.WriteString(w, "="); err != nil { + if _, err := w.WriteString("="); err != nil { return err } @@ -92,7 +91,7 @@ func (exprs *exprParams) writeNameArgs(w builder.Writer) error { } if i+1 != len(exprs.colNames) { - if _, err := io.WriteString(w, ","); err != nil { + if _, err := w.WriteString(","); err != nil { return err } } -- 2.40.1