Add support subquery on SetExpr #1428

Merged
lunny merged 3 commits from lunny/support_set_expr_builder into master 2019-09-23 15:34:27 +00:00
13 changed files with 450 additions and 198 deletions

View File

@ -729,7 +729,7 @@ func (engine *Engine) Decr(column string, arg ...interface{}) *Session {
} }
// SetExpr provides a update string like "column = {expression}" // SetExpr provides a update string like "column = {expression}"
func (engine *Engine) SetExpr(column string, expression string) *Session { func (engine *Engine) SetExpr(column string, expression interface{}) *Session {
session := engine.NewSession() session := engine.NewSession()
session.isAutoClose = true session.isAutoClose = true
return session.SetExpr(column, expression) return session.SetExpr(column, expression)

2
go.mod
View File

@ -15,6 +15,6 @@ require (
github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24 // indirect github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24 // indirect
github.com/stretchr/testify v1.3.0 github.com/stretchr/testify v1.3.0
github.com/ziutek/mymysql v1.5.4 github.com/ziutek/mymysql v1.5.4
xorm.io/builder v0.3.6-0.20190906062455-b937eb46ecfb xorm.io/builder v0.3.6
xorm.io/core v0.7.0 xorm.io/core v0.7.0
) )

4
go.sum
View File

@ -158,7 +158,7 @@ gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
xorm.io/builder v0.3.6-0.20190906062455-b937eb46ecfb h1:2idZcp79ldX5qLeQ6WKCdS7aEFNOMvQc9wrtt5hSRwM= xorm.io/builder v0.3.6 h1:ha28mQ2M+TFx96Hxo+iq6tQgnkC9IZkM6D8w9sKHHF8=
xorm.io/builder v0.3.6-0.20190906062455-b937eb46ecfb/go.mod h1:LEFAPISnRzG+zxaxj2vPicRwz67BdhFreKg8yv8/TgU= xorm.io/builder v0.3.6/go.mod h1:LEFAPISnRzG+zxaxj2vPicRwz67BdhFreKg8yv8/TgU=
xorm.io/core v0.7.0 h1:hKxuOKWZNeiFQsSuGet/KV8HZ788hclvAl+7azx3tkM= xorm.io/core v0.7.0 h1:hKxuOKWZNeiFQsSuGet/KV8HZ788hclvAl+7azx3tkM=
xorm.io/core v0.7.0/go.mod h1:TuOJjIVa7e3w/rN8tDcAvuLBMtwzdHPbyOzE6Gk1EUI= xorm.io/core v0.7.0/go.mod h1:TuOJjIVa7e3w/rN8tDcAvuLBMtwzdHPbyOzE6Gk1EUI=

View File

@ -54,7 +54,7 @@ type Interface interface {
QueryInterface(sqlOrArgs ...interface{}) ([]map[string]interface{}, error) QueryInterface(sqlOrArgs ...interface{}) ([]map[string]interface{}, error)
QueryString(sqlOrArgs ...interface{}) ([]map[string]string, error) QueryString(sqlOrArgs ...interface{}) ([]map[string]string, error)
Rows(bean interface{}) (*Rows, error) Rows(bean interface{}) (*Rows, error)
SetExpr(string, string) *Session SetExpr(string, interface{}) *Session
SQL(interface{}, ...interface{}) *Session SQL(interface{}, ...interface{}) *Session
Sum(bean interface{}, colName string) (float64, error) Sum(bean interface{}, colName string) (float64, error)
SumInt(bean interface{}, colName string) (int64, error) SumInt(bean interface{}, colName string) (int64, error)

View File

@ -12,49 +12,6 @@ import (
"xorm.io/core" "xorm.io/core"
) )
type incrParam struct {
colName string
arg interface{}
}
type decrParam struct {
colName string
arg interface{}
}
type exprParam struct {
colName string
expr string
}
type columnMap []string
func (m columnMap) contain(colName string) bool {
if len(m) == 0 {
return false
}
n := len(colName)
for _, mk := range m {
if len(mk) != n {
continue
}
if strings.EqualFold(mk, colName) {
return true
}
}
return false
}
func (m *columnMap) add(colName string) bool {
if m.contain(colName) {
return false
}
*m = append(*m, colName)
return true
}
func setColumnInt(bean interface{}, col *core.Column, t int64) { func setColumnInt(bean interface{}, col *core.Column, t int64) {
v, err := col.ValueOf(bean) v, err := col.ValueOf(bean)
if err != nil { if err != nil {
@ -132,7 +89,7 @@ func (session *Session) Decr(column string, arg ...interface{}) *Session {
} }
// SetExpr provides a query string like "column = {expression}" // SetExpr provides a query string like "column = {expression}"
func (session *Session) SetExpr(column string, expression string) *Session { func (session *Session) SetExpr(column string, expression interface{}) *Session {
session.statement.SetExpr(column, expression) session.statement.SetExpr(column, expression)
return session return session
} }

View File

@ -340,74 +340,96 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
if err != nil { if err != nil {
return 0, err return 0, err
} }
// insert expr columns, override if exists
exprColumns := session.statement.getExpr()
exprColVals := make([]string, 0, len(exprColumns))
for _, v := range exprColumns {
// remove the expr columns
for i, colName := range colNames {
if colName == strings.Trim(v.colName, "`") {
colNames = append(colNames[:i], colNames[i+1:]...)
args = append(args[:i], args[i+1:]...)
}
}
// append expr column to the end exprs := session.statement.exprColumns
colNames = append(colNames, v.colName) colPlaces := strings.Repeat("?, ", len(colNames))
exprColVals = append(exprColVals, v.expr) if exprs.Len() <= 0 && len(colPlaces) > 0 {
colPlaces = colPlaces[0 : len(colPlaces)-2]
} }
colPlaces := strings.Repeat("?, ", len(colNames)-len(exprColumns))
if len(exprColVals) > 0 {
colPlaces = colPlaces + strings.Join(exprColVals, ", ")
} else {
if len(colPlaces) > 0 {
colPlaces = colPlaces[0 : len(colPlaces)-2]
}
}
var sqlStr string
var tableName = session.statement.TableName() var tableName = session.statement.TableName()
var output string var output string
if session.engine.dialect.DBType() == core.MSSQL && len(table.AutoIncrement) > 0 { if session.engine.dialect.DBType() == core.MSSQL && len(table.AutoIncrement) > 0 {
output = fmt.Sprintf(" OUTPUT Inserted.%s", table.AutoIncrement) output = fmt.Sprintf(" OUTPUT Inserted.%s", table.AutoIncrement)
} }
if len(colPlaces) > 0 { var buf = builder.NewWriter()
if _, err := buf.WriteString(fmt.Sprintf("INSERT INTO %s", session.engine.Quote(tableName))); err != nil {
return 0, err
}
if len(colPlaces) <= 0 {
if session.engine.dialect.DBType() == core.MYSQL {
if _, err := buf.WriteString(" VALUES ()"); err != nil {
return 0, err
}
} else {
if _, err := buf.WriteString(fmt.Sprintf("%s DEFAULT VALUES", output)); err != nil {
return 0, err
}
}
} else {
if _, err := buf.WriteString(" ("); err != nil {
return 0, err
}
if err := writeStrings(buf, append(colNames, exprs.colNames...), "`", "`"); err != nil {
return 0, err
}
if session.statement.cond.IsValid() { if session.statement.cond.IsValid() {
condSQL, condArgs, err := builder.ToSQL(session.statement.cond) if _, err := buf.WriteString(fmt.Sprintf(")%s SELECT ", output)); err != nil {
if err != nil {
return 0, err return 0, err
} }
sqlStr = fmt.Sprintf("INSERT INTO %s (%v)%s SELECT %v FROM %v WHERE %v", if err := writeArgs(buf, args); err != nil {
session.engine.Quote(tableName), return 0, err
quoteColumns(colNames, session.engine.Quote, ","), }
output,
colPlaces, if len(exprs.args) > 0 {
session.engine.Quote(tableName), if _, err := buf.WriteString(","); err != nil {
condSQL, return 0, err
) }
args = append(args, condArgs...) }
if err := exprs.writeArgs(buf); err != nil {
return 0, err
}
if _, err := buf.WriteString(fmt.Sprintf(" FROM %v WHERE ", session.engine.Quote(tableName))); err != nil {
return 0, err
}
if err := session.statement.cond.WriteTo(buf); err != nil {
return 0, err
}
} else { } else {
sqlStr = fmt.Sprintf("INSERT INTO %s (%v)%s VALUES (%v)", buf.Append(args...)
session.engine.Quote(tableName),
quoteColumns(colNames, session.engine.Quote, ","), if _, err := buf.WriteString(fmt.Sprintf(")%s VALUES (%v",
output, output,
colPlaces) colPlaces)); err != nil {
} return 0, err
} else { }
if session.engine.dialect.DBType() == core.MYSQL {
sqlStr = fmt.Sprintf("INSERT INTO %s VALUES ()", session.engine.Quote(tableName)) if err := exprs.writeArgs(buf); err != nil {
} else { return 0, err
sqlStr = fmt.Sprintf("INSERT INTO %s%s DEFAULT VALUES", session.engine.Quote(tableName), output) }
if _, err := buf.WriteString(")"); err != nil {
return 0, err
}
} }
} }
if len(table.AutoIncrement) > 0 && session.engine.dialect.DBType() == core.POSTGRES { if len(table.AutoIncrement) > 0 && session.engine.dialect.DBType() == core.POSTGRES {
sqlStr = sqlStr + " RETURNING " + session.engine.Quote(table.AutoIncrement) if _, err := buf.WriteString(" RETURNING " + session.engine.Quote(table.AutoIncrement)); err != nil {
return 0, err
}
} }
sqlStr := buf.String()
args = buf.Args()
handleAfterInsertProcessorFunc := func(bean interface{}) { handleAfterInsertProcessorFunc := func(bean interface{}) {
if session.isAutoCommit { if session.isAutoCommit {
for _, closure := range session.afterClosures { for _, closure := range session.afterClosures {
@ -611,9 +633,11 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac
continue continue
} }
if _, ok := session.statement.incrColumns[col.Name]; ok { if session.statement.incrColumns.isColExist(col.Name) {
continue continue
} else if _, ok := session.statement.decrColumns[col.Name]; ok { } else if session.statement.decrColumns.isColExist(col.Name) {
continue
} else if session.statement.exprColumns.isColExist(col.Name) {
continue continue
} }
@ -688,46 +712,66 @@ func (session *Session) insertMapInterface(m map[string]interface{}) (int64, err
} }
var columns = make([]string, 0, len(m)) var columns = make([]string, 0, len(m))
exprs := session.statement.exprColumns
for k := range m { for k := range m {
columns = append(columns, k) if !exprs.isColExist(k) {
columns = append(columns, k)
}
} }
sort.Strings(columns) sort.Strings(columns)
qm := strings.Repeat("?,", len(columns))
var args = make([]interface{}, 0, len(m)) var args = make([]interface{}, 0, len(m))
for _, colName := range columns { for _, colName := range columns {
args = append(args, m[colName]) args = append(args, m[colName])
} }
// insert expr columns, override if exists w := builder.NewWriter()
exprColumns := session.statement.getExpr()
for _, col := range exprColumns {
columns = append(columns, strings.Trim(col.colName, "`"))
qm = qm + col.expr + ","
}
qm = qm[:len(qm)-1]
var sql string
if session.statement.cond.IsValid() { if session.statement.cond.IsValid() {
condSQL, condArgs, err := builder.ToSQL(session.statement.cond) if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (", session.engine.Quote(tableName))); err != nil {
if err != nil { return 0, err
}
if err := writeStrings(w, append(columns, exprs.colNames...), "`", "`"); err != nil {
return 0, err
}
if _, err := w.WriteString(") SELECT "); err != nil {
return 0, err
}
if err := writeArgs(w, args); err != nil {
return 0, err
}
if len(exprs.args) > 0 {
if _, err := w.WriteString(","); err != nil {
return 0, err
}
if err := exprs.writeArgs(w); err != nil {
return 0, err
}
}
if _, err := w.WriteString(fmt.Sprintf(" FROM %s WHERE ", session.engine.Quote(tableName))); err != nil {
return 0, err
}
if err := session.statement.cond.WriteTo(w); err != nil {
return 0, err return 0, err
} }
sql = fmt.Sprintf("INSERT INTO %s (`%s`) SELECT %s FROM %s WHERE %s",
session.engine.Quote(tableName),
strings.Join(columns, "`,`"),
qm,
session.engine.Quote(tableName),
condSQL,
)
args = append(args, condArgs...)
} else { } else {
sql = fmt.Sprintf("INSERT INTO %s (`%s`) VALUES (%s)", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm) qm := strings.Repeat("?,", len(columns))
qm = qm[:len(qm)-1]
if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (`%s`) VALUES (%s)", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm)); err != nil {
return 0, err
}
w.Append(args...)
} }
sql := w.String()
args = w.Args()
if err := session.cacheInsert(tableName); err != nil { if err := session.cacheInsert(tableName); err != nil {
return 0, err return 0, err
} }
@ -754,8 +798,11 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) {
} }
var columns = make([]string, 0, len(m)) var columns = make([]string, 0, len(m))
exprs := session.statement.exprColumns
for k := range m { for k := range m {
columns = append(columns, k) if !exprs.isColExist(k) {
columns = append(columns, k)
}
} }
sort.Strings(columns) sort.Strings(columns)
@ -764,37 +811,53 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) {
args = append(args, m[colName]) args = append(args, m[colName])
} }
qm := strings.Repeat("?,", len(columns)) w := builder.NewWriter()
// insert expr columns, override if exists
exprColumns := session.statement.getExpr()
for _, col := range exprColumns {
columns = append(columns, strings.Trim(col.colName, "`"))
qm = qm + col.expr + ","
}
qm = qm[:len(qm)-1]
var sql string
if session.statement.cond.IsValid() { if session.statement.cond.IsValid() {
qm = "(" + qm[:len(qm)-1] + ")" if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (", session.engine.Quote(tableName))); err != nil {
condSQL, condArgs, err := builder.ToSQL(session.statement.cond) return 0, err
if err != nil { }
if err := writeStrings(w, append(columns, exprs.colNames...), "`", "`"); err != nil {
return 0, err
}
if _, err := w.WriteString(") SELECT "); err != nil {
return 0, err
}
if err := writeArgs(w, args); err != nil {
return 0, err
}
if len(exprs.args) > 0 {
if _, err := w.WriteString(","); err != nil {
return 0, err
}
if err := exprs.writeArgs(w); err != nil {
return 0, err
}
}
if _, err := w.WriteString(fmt.Sprintf(" FROM %s WHERE ", session.engine.Quote(tableName))); err != nil {
return 0, err
}
if err := session.statement.cond.WriteTo(w); err != nil {
return 0, err return 0, err
} }
sql = fmt.Sprintf("INSERT INTO %s (`%s`) SELECT %s FROM %s WHERE %s",
session.engine.Quote(tableName),
strings.Join(columns, "`,`"),
qm,
session.engine.Quote(tableName),
condSQL,
)
args = append(args, condArgs...)
} else { } else {
sql = fmt.Sprintf("INSERT INTO %s (`%s`) VALUES (%s)", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm) qm := strings.Repeat("?,", len(columns))
qm = qm[:len(qm)-1]
if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (`%s`) VALUES (%s)", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm)); err != nil {
return 0, err
}
w.Append(args...)
} }
sql := w.String()
args = w.Args()
if err := session.cacheInsert(tableName); err != nil { if err := session.cacheInsert(tableName); err != nil {
return 0, err return 0, err
} }

View File

@ -892,4 +892,19 @@ func TestInsertWhere(t *testing.T) {
assert.EqualValues(t, 40, j2.Height) assert.EqualValues(t, 40, j2.Height)
assert.EqualValues(t, "trest2", j2.Name) assert.EqualValues(t, "trest2", j2.Name)
assert.EqualValues(t, 2, j2.Index) assert.EqualValues(t, 2, j2.Index)
inserted, err = testEngine.Table(new(InsertWhere)).Where("repo_id=?", 1).
SetExpr("`index`", "coalesce(MAX(`index`),0)+1").
Insert(map[string]string{
"name": "trest3",
})
assert.NoError(t, err)
assert.EqualValues(t, 1, inserted)
var j3 InsertWhere
has, err = testEngine.ID(3).Get(&j3)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, "trest3", j3.Name)
assert.EqualValues(t, 3, j3.Index)
} }

View File

@ -223,21 +223,31 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
} }
// for update action to like "column = column + ?" // for update action to like "column = column + ?"
incColumns := session.statement.getInc() incColumns := session.statement.incrColumns
for _, v := range incColumns { for i, colName := range incColumns.colNames {
colNames = append(colNames, session.engine.Quote(v.colName)+" = "+session.engine.Quote(v.colName)+" + ?") colNames = append(colNames, session.engine.Quote(colName)+" = "+session.engine.Quote(colName)+" + ?")
args = append(args, v.arg) args = append(args, incColumns.args[i])
} }
// for update action to like "column = column - ?" // for update action to like "column = column - ?"
decColumns := session.statement.getDec() decColumns := session.statement.decrColumns
for _, v := range decColumns { for i, colName := range decColumns.colNames {
colNames = append(colNames, session.engine.Quote(v.colName)+" = "+session.engine.Quote(v.colName)+" - ?") colNames = append(colNames, session.engine.Quote(colName)+" = "+session.engine.Quote(colName)+" - ?")
args = append(args, v.arg) args = append(args, decColumns.args[i])
} }
// for update action to like "column = expression" // for update action to like "column = expression"
exprColumns := session.statement.getExpr() exprColumns := session.statement.exprColumns
for _, v := range exprColumns { for i, colName := range exprColumns.colNames {
colNames = append(colNames, session.engine.Quote(v.colName)+" = "+v.expr) switch tp := exprColumns.args[i].(type) {
case string:
colNames = append(colNames, session.engine.Quote(colName)+" = "+tp)
case *builder.Builder:
subQuery, subArgs, err := builder.ToSQL(tp)
if err != nil {
return 0, err
}
colNames = append(colNames, session.engine.Quote(colName)+" = "+subQuery)
args = append(args, subArgs...)
}
} }
if err = session.statement.processIDParam(); err != nil { if err = session.statement.processIDParam(); err != nil {
@ -468,14 +478,17 @@ func (session *Session) genUpdateColumns(bean interface{}) ([]string, []interfac
continue continue
} }
if len(session.statement.columnMap) > 0 { // if only update specify columns
if !session.statement.columnMap.contain(col.Name) { if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) {
continue continue
} else if _, ok := session.statement.incrColumns[col.Name]; ok { }
continue
} else if _, ok := session.statement.decrColumns[col.Name]; ok { if session.statement.incrColumns.isColExist(col.Name) {
continue continue
} } else if session.statement.decrColumns.isColExist(col.Name) {
continue
} else if session.statement.exprColumns.isColExist(col.Name) {
continue
} }
// !evalphobia! set fieldValue as nil when column is nullable and zero-value // !evalphobia! set fieldValue as nil when column is nullable and zero-value

View File

@ -52,9 +52,9 @@ type Statement struct {
omitColumnMap columnMap omitColumnMap columnMap
mustColumnMap map[string]bool mustColumnMap map[string]bool
nullableMap map[string]bool nullableMap map[string]bool
incrColumns map[string]incrParam incrColumns exprParams
decrColumns map[string]decrParam decrColumns exprParams
exprColumns map[string]exprParam exprColumns exprParams
cond builder.Cond cond builder.Cond
bufferSize int bufferSize int
context ContextCache context ContextCache
@ -94,9 +94,9 @@ func (statement *Statement) Init() {
statement.nullableMap = make(map[string]bool) statement.nullableMap = make(map[string]bool)
statement.checkVersion = true statement.checkVersion = true
statement.unscoped = false statement.unscoped = false
statement.incrColumns = make(map[string]incrParam) statement.incrColumns = exprParams{}
statement.decrColumns = make(map[string]decrParam) statement.decrColumns = exprParams{}
statement.exprColumns = make(map[string]exprParam) statement.exprColumns = exprParams{}
statement.cond = builder.NewCond() statement.cond = builder.NewCond()
statement.bufferSize = 0 statement.bufferSize = 0
statement.context = nil statement.context = nil
@ -534,48 +534,30 @@ func (statement *Statement) ID(id interface{}) *Statement {
// Incr Generate "Update ... Set column = column + arg" statement // Incr Generate "Update ... Set column = column + arg" statement
func (statement *Statement) Incr(column string, arg ...interface{}) *Statement { func (statement *Statement) Incr(column string, arg ...interface{}) *Statement {
k := strings.ToLower(column)
if len(arg) > 0 { if len(arg) > 0 {
statement.incrColumns[k] = incrParam{column, arg[0]} statement.incrColumns.addParam(column, arg[0])
} else { } else {
statement.incrColumns[k] = incrParam{column, 1} statement.incrColumns.addParam(column, 1)
} }
return statement return statement
} }
// Decr Generate "Update ... Set column = column - arg" statement // Decr Generate "Update ... Set column = column - arg" statement
func (statement *Statement) Decr(column string, arg ...interface{}) *Statement { func (statement *Statement) Decr(column string, arg ...interface{}) *Statement {
k := strings.ToLower(column)
if len(arg) > 0 { if len(arg) > 0 {
statement.decrColumns[k] = decrParam{column, arg[0]} statement.decrColumns.addParam(column, arg[0])
} else { } else {
statement.decrColumns[k] = decrParam{column, 1} statement.decrColumns.addParam(column, 1)
} }
return statement return statement
} }
// SetExpr Generate "Update ... Set column = {expression}" statement // SetExpr Generate "Update ... Set column = {expression}" statement
func (statement *Statement) SetExpr(column string, expression string) *Statement { func (statement *Statement) SetExpr(column string, expression interface{}) *Statement {
k := strings.ToLower(column) statement.exprColumns.addParam(column, expression)
statement.exprColumns[k] = exprParam{column, expression}
return statement return statement
} }
// Generate "Update ... Set column = column + arg" statement
func (statement *Statement) getInc() map[string]incrParam {
return statement.incrColumns
}
// Generate "Update ... Set column = column - arg" statement
func (statement *Statement) getDec() map[string]decrParam {
return statement.decrColumns
}
// Generate "Update ... Set column = {expression}" statement
func (statement *Statement) getExpr() map[string]exprParam {
return statement.exprColumns
}
func (statement *Statement) col2NewColsWithQuote(columns ...string) []string { func (statement *Statement) col2NewColsWithQuote(columns ...string) []string {
newColumns := make([]string, 0) newColumns := make([]string, 0)
quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`") quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`")

68
statement_args.go Normal file
View File

@ -0,0 +1,68 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import (
"fmt"
"xorm.io/builder"
)
func writeArg(w *builder.BytesWriter, arg interface{}) error {
switch argv := arg.(type) {
case string:
if _, err := w.WriteString("'" + argv + "'"); err != nil {
return err
}
case *builder.Builder:
if err := argv.WriteTo(w); err != nil {
return err
}
default:
if _, err := w.WriteString(fmt.Sprintf("%v", argv)); err != nil {
return err
}
}
return nil
}
func writeArgs(w *builder.BytesWriter, args []interface{}) error {
for i, arg := range args {
if err := writeArg(w, arg); err != nil {
return err
}
if i+1 != len(args) {
if _, err := w.WriteString(","); err != nil {
return err
}
}
}
return nil
}
func writeStrings(w *builder.BytesWriter, cols []string, leftQuote, rightQuote string) error {
for i, colName := range cols {
if len(leftQuote) > 0 && colName[0] != '`' {
if _, err := w.WriteString(leftQuote); err != nil {
return err
}
}
if _, err := w.WriteString(colName); err != nil {
return err
}
if len(rightQuote) > 0 && colName[len(colName)-1] != '`' {
if _, err := w.WriteString(rightQuote); err != nil {
return err
}
}
if i+1 != len(cols) {
if _, err := w.WriteString(","); err != nil {
return err
}
}
}
return nil
}

35
statement_columnmap.go Normal file
View File

@ -0,0 +1,35 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import "strings"
type columnMap []string
func (m columnMap) contain(colName string) bool {
if len(m) == 0 {
return false
}
n := len(colName)
for _, mk := range m {
if len(mk) != n {
continue
}
if strings.EqualFold(mk, colName) {
return true
}
}
return false
}
func (m *columnMap) add(colName string) bool {
if m.contain(colName) {
return false
}
*m = append(*m, colName)
return true
}

100
statement_exprparam.go Normal file
View File

@ -0,0 +1,100 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import (
"fmt"
"strings"
"xorm.io/builder"
)
type ErrUnsupportedExprType struct {
tp string
}
func (err ErrUnsupportedExprType) Error() string {
return fmt.Sprintf("Unsupported expression type: %v", err.tp)
}
type exprParam struct {
colName string
arg interface{}
}
type exprParams struct {
colNames []string
args []interface{}
}
func (exprs *exprParams) Len() int {
return len(exprs.colNames)
}
func (exprs *exprParams) addParam(colName string, arg interface{}) {
exprs.colNames = append(exprs.colNames, colName)
exprs.args = append(exprs.args, arg)
}
func (exprs *exprParams) isColExist(colName string) bool {
for _, name := range exprs.colNames {
if strings.EqualFold(trimQuote(name), trimQuote(colName)) {
return true
}
}
return false
}
func (exprs *exprParams) getByName(colName string) (exprParam, bool) {
for i, name := range exprs.colNames {
if strings.EqualFold(name, colName) {
return exprParam{name, exprs.args[i]}, true
}
}
return exprParam{}, false
}
func (exprs *exprParams) writeArgs(w *builder.BytesWriter) error {
for _, expr := range exprs.args {
switch arg := expr.(type) {
case *builder.Builder:
if err := arg.WriteTo(w); err != nil {
return err
}
default:
if _, err := w.WriteString(fmt.Sprintf("%v", arg)); err != nil {
return err
}
}
}
return nil
}
func (exprs *exprParams) writeNameArgs(w *builder.BytesWriter) error {
for i, colName := range exprs.colNames {
if _, err := w.WriteString(colName); err != nil {
return err
}
if _, err := w.WriteString("="); err != nil {
return err
}
switch arg := exprs.args[i].(type) {
case *builder.Builder:
if err := arg.WriteTo(w); err != nil {
return err
}
default:
w.Append(exprs.args[i])
}
if i+1 != len(exprs.colNames) {
if _, err := w.WriteString(","); err != nil {
return err
}
}
}
return nil
}

19
statement_quote.go Normal file
View File

@ -0,0 +1,19 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
func trimQuote(s string) string {
if len(s) == 0 {
return s
}
if s[0] == '`' {
s = s[1:]
}
if len(s) > 0 && s[len(s)-1] == '`' {
return s[:len(s)-1]
}
return s
}