Browse Source

Add support subquery on SetExpr (#1428)

* add support subquery on SetExpr

* fix tests
tags/v0.7.8
Lunny Xiao GitHub 2 months ago
parent
commit
6d11913765
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 445 additions and 193 deletions
  1. +1
    -1
      engine.go
  2. +1
    -1
      go.mod
  3. +2
    -2
      go.sum
  4. +1
    -1
      interface.go
  5. +1
    -44
      session_cols.go
  6. +158
    -95
      session_insert.go
  7. +15
    -0
      session_insert_test.go
  8. +32
    -19
      session_update.go
  9. +12
    -30
      statement.go
  10. +68
    -0
      statement_args.go
  11. +35
    -0
      statement_columnmap.go
  12. +100
    -0
      statement_exprparam.go
  13. +19
    -0
      statement_quote.go

+ 1
- 1
engine.go View File

@@ -729,7 +729,7 @@ func (engine *Engine) Decr(column string, arg ...interface{}) *Session {
}

// SetExpr provides a update string like "column = {expression}"
func (engine *Engine) SetExpr(column string, expression string) *Session {
func (engine *Engine) SetExpr(column string, expression interface{}) *Session {
session := engine.NewSession()
session.isAutoClose = true
return session.SetExpr(column, expression)


+ 1
- 1
go.mod View File

@@ -15,6 +15,6 @@ require (
github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24 // indirect
github.com/stretchr/testify v1.3.0
github.com/ziutek/mymysql v1.5.4
xorm.io/builder v0.3.6-0.20190906062455-b937eb46ecfb
xorm.io/builder v0.3.6
xorm.io/core v0.7.0
)

+ 2
- 2
go.sum View File

@@ -158,7 +158,7 @@ gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
xorm.io/builder v0.3.6-0.20190906062455-b937eb46ecfb h1:2idZcp79ldX5qLeQ6WKCdS7aEFNOMvQc9wrtt5hSRwM=
xorm.io/builder v0.3.6-0.20190906062455-b937eb46ecfb/go.mod h1:LEFAPISnRzG+zxaxj2vPicRwz67BdhFreKg8yv8/TgU=
xorm.io/builder v0.3.6 h1:ha28mQ2M+TFx96Hxo+iq6tQgnkC9IZkM6D8w9sKHHF8=
xorm.io/builder v0.3.6/go.mod h1:LEFAPISnRzG+zxaxj2vPicRwz67BdhFreKg8yv8/TgU=
xorm.io/core v0.7.0 h1:hKxuOKWZNeiFQsSuGet/KV8HZ788hclvAl+7azx3tkM=
xorm.io/core v0.7.0/go.mod h1:TuOJjIVa7e3w/rN8tDcAvuLBMtwzdHPbyOzE6Gk1EUI=

+ 1
- 1
interface.go View File

@@ -54,7 +54,7 @@ type Interface interface {
QueryInterface(sqlOrArgs ...interface{}) ([]map[string]interface{}, error)
QueryString(sqlOrArgs ...interface{}) ([]map[string]string, error)
Rows(bean interface{}) (*Rows, error)
SetExpr(string, string) *Session
SetExpr(string, interface{}) *Session
SQL(interface{}, ...interface{}) *Session
Sum(bean interface{}, colName string) (float64, error)
SumInt(bean interface{}, colName string) (int64, error)


+ 1
- 44
session_cols.go View File

@@ -12,49 +12,6 @@ import (
"xorm.io/core"
)

type incrParam struct {
colName string
arg interface{}
}

type decrParam struct {
colName string
arg interface{}
}

type exprParam struct {
colName string
expr string
}

type columnMap []string

func (m columnMap) contain(colName string) bool {
if len(m) == 0 {
return false
}

n := len(colName)
for _, mk := range m {
if len(mk) != n {
continue
}
if strings.EqualFold(mk, colName) {
return true
}
}

return false
}

func (m *columnMap) add(colName string) bool {
if m.contain(colName) {
return false
}
*m = append(*m, colName)
return true
}

func setColumnInt(bean interface{}, col *core.Column, t int64) {
v, err := col.ValueOf(bean)
if err != nil {
@@ -132,7 +89,7 @@ func (session *Session) Decr(column string, arg ...interface{}) *Session {
}

// SetExpr provides a query string like "column = {expression}"
func (session *Session) SetExpr(column string, expression string) *Session {
func (session *Session) SetExpr(column string, expression interface{}) *Session {
session.statement.SetExpr(column, expression)
return session
}


+ 158
- 95
session_insert.go View File

@@ -340,74 +340,96 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
if err != nil {
return 0, err
}
// insert expr columns, override if exists
exprColumns := session.statement.getExpr()
exprColVals := make([]string, 0, len(exprColumns))
for _, v := range exprColumns {
// remove the expr columns
for i, colName := range colNames {
if colName == strings.Trim(v.colName, "`") {
colNames = append(colNames[:i], colNames[i+1:]...)
args = append(args[:i], args[i+1:]...)
}
}

// append expr column to the end
colNames = append(colNames, v.colName)
exprColVals = append(exprColVals, v.expr)
}

colPlaces := strings.Repeat("?, ", len(colNames)-len(exprColumns))
if len(exprColVals) > 0 {
colPlaces = colPlaces + strings.Join(exprColVals, ", ")
} else {
if len(colPlaces) > 0 {
colPlaces = colPlaces[0 : len(colPlaces)-2]
}
exprs := session.statement.exprColumns
colPlaces := strings.Repeat("?, ", len(colNames))
if exprs.Len() <= 0 && len(colPlaces) > 0 {
colPlaces = colPlaces[0 : len(colPlaces)-2]
}

var sqlStr string
var tableName = session.statement.TableName()
var output string
if session.engine.dialect.DBType() == core.MSSQL && len(table.AutoIncrement) > 0 {
output = fmt.Sprintf(" OUTPUT Inserted.%s", table.AutoIncrement)
}

if len(colPlaces) > 0 {
if session.statement.cond.IsValid() {
condSQL, condArgs, err := builder.ToSQL(session.statement.cond)
if err != nil {
var buf = builder.NewWriter()
if _, err := buf.WriteString(fmt.Sprintf("INSERT INTO %s", session.engine.Quote(tableName))); err != nil {
return 0, err
}

if len(colPlaces) <= 0 {
if session.engine.dialect.DBType() == core.MYSQL {
if _, err := buf.WriteString(" VALUES ()"); err != nil {
return 0, err
}

sqlStr = fmt.Sprintf("INSERT INTO %s (%v)%s SELECT %v FROM %v WHERE %v",
session.engine.Quote(tableName),
quoteColumns(colNames, session.engine.Quote, ","),
output,
colPlaces,
session.engine.Quote(tableName),
condSQL,
)
args = append(args, condArgs...)
} else {
sqlStr = fmt.Sprintf("INSERT INTO %s (%v)%s VALUES (%v)",
session.engine.Quote(tableName),
quoteColumns(colNames, session.engine.Quote, ","),
output,
colPlaces)
if _, err := buf.WriteString(fmt.Sprintf("%s DEFAULT VALUES", output)); err != nil {
return 0, err
}
}
} else {
if session.engine.dialect.DBType() == core.MYSQL {
sqlStr = fmt.Sprintf("INSERT INTO %s VALUES ()", session.engine.Quote(tableName))
if _, err := buf.WriteString(" ("); err != nil {
return 0, err
}

if err := writeStrings(buf, append(colNames, exprs.colNames...), "`", "`"); err != nil {
return 0, err
}

if session.statement.cond.IsValid() {
if _, err := buf.WriteString(fmt.Sprintf(")%s SELECT ", output)); err != nil {
return 0, err
}

if err := writeArgs(buf, args); err != nil {
return 0, err
}

if len(exprs.args) > 0 {
if _, err := buf.WriteString(","); err != nil {
return 0, err
}
}
if err := exprs.writeArgs(buf); err != nil {
return 0, err
}

if _, err := buf.WriteString(fmt.Sprintf(" FROM %v WHERE ", session.engine.Quote(tableName))); err != nil {
return 0, err
}

if err := session.statement.cond.WriteTo(buf); err != nil {
return 0, err
}
} else {
sqlStr = fmt.Sprintf("INSERT INTO %s%s DEFAULT VALUES", session.engine.Quote(tableName), output)
buf.Append(args...)

if _, err := buf.WriteString(fmt.Sprintf(")%s VALUES (%v",
output,
colPlaces)); err != nil {
return 0, err
}

if err := exprs.writeArgs(buf); err != nil {
return 0, err
}

if _, err := buf.WriteString(")"); err != nil {
return 0, err
}
}
}

if len(table.AutoIncrement) > 0 && session.engine.dialect.DBType() == core.POSTGRES {
sqlStr = sqlStr + " RETURNING " + session.engine.Quote(table.AutoIncrement)
if _, err := buf.WriteString(" RETURNING " + session.engine.Quote(table.AutoIncrement)); err != nil {
return 0, err
}
}

sqlStr := buf.String()
args = buf.Args()

handleAfterInsertProcessorFunc := func(bean interface{}) {
if session.isAutoCommit {
for _, closure := range session.afterClosures {
@@ -611,9 +633,11 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac
continue
}

if _, ok := session.statement.incrColumns[col.Name]; ok {
if session.statement.incrColumns.isColExist(col.Name) {
continue
} else if _, ok := session.statement.decrColumns[col.Name]; ok {
} else if session.statement.decrColumns.isColExist(col.Name) {
continue
} else if session.statement.exprColumns.isColExist(col.Name) {
continue
}

@@ -688,46 +712,66 @@ func (session *Session) insertMapInterface(m map[string]interface{}) (int64, err
}

var columns = make([]string, 0, len(m))
exprs := session.statement.exprColumns
for k := range m {
columns = append(columns, k)
if !exprs.isColExist(k) {
columns = append(columns, k)
}
}
sort.Strings(columns)

qm := strings.Repeat("?,", len(columns))

var args = make([]interface{}, 0, len(m))
for _, colName := range columns {
args = append(args, m[colName])
}

// insert expr columns, override if exists
exprColumns := session.statement.getExpr()
for _, col := range exprColumns {
columns = append(columns, strings.Trim(col.colName, "`"))
qm = qm + col.expr + ","
}
w := builder.NewWriter()
if session.statement.cond.IsValid() {
if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (", session.engine.Quote(tableName))); err != nil {
return 0, err
}

qm = qm[:len(qm)-1]
if err := writeStrings(w, append(columns, exprs.colNames...), "`", "`"); err != nil {
return 0, err
}

var sql string
if _, err := w.WriteString(") SELECT "); err != nil {
return 0, err
}

if session.statement.cond.IsValid() {
condSQL, condArgs, err := builder.ToSQL(session.statement.cond)
if err != nil {
if err := writeArgs(w, args); err != nil {
return 0, err
}

if len(exprs.args) > 0 {
if _, err := w.WriteString(","); err != nil {
return 0, err
}
if err := exprs.writeArgs(w); err != nil {
return 0, err
}
}

if _, err := w.WriteString(fmt.Sprintf(" FROM %s WHERE ", session.engine.Quote(tableName))); err != nil {
return 0, err
}

if err := session.statement.cond.WriteTo(w); err != nil {
return 0, err
}
sql = fmt.Sprintf("INSERT INTO %s (`%s`) SELECT %s FROM %s WHERE %s",
session.engine.Quote(tableName),
strings.Join(columns, "`,`"),
qm,
session.engine.Quote(tableName),
condSQL,
)
args = append(args, condArgs...)
} else {
sql = fmt.Sprintf("INSERT INTO %s (`%s`) VALUES (%s)", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm)
qm := strings.Repeat("?,", len(columns))
qm = qm[:len(qm)-1]

if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (`%s`) VALUES (%s)", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm)); err != nil {
return 0, err
}
w.Append(args...)
}

sql := w.String()
args = w.Args()

if err := session.cacheInsert(tableName); err != nil {
return 0, err
}
@@ -754,8 +798,11 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) {
}

var columns = make([]string, 0, len(m))
exprs := session.statement.exprColumns
for k := range m {
columns = append(columns, k)
if !exprs.isColExist(k) {
columns = append(columns, k)
}
}
sort.Strings(columns)

@@ -764,37 +811,53 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) {
args = append(args, m[colName])
}

qm := strings.Repeat("?,", len(columns))
w := builder.NewWriter()
if session.statement.cond.IsValid() {
if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (", session.engine.Quote(tableName))); err != nil {
return 0, err
}

// insert expr columns, override if exists
exprColumns := session.statement.getExpr()
for _, col := range exprColumns {
columns = append(columns, strings.Trim(col.colName, "`"))
qm = qm + col.expr + ","
}
if err := writeStrings(w, append(columns, exprs.colNames...), "`", "`"); err != nil {
return 0, err
}

qm = qm[:len(qm)-1]
if _, err := w.WriteString(") SELECT "); err != nil {
return 0, err
}

var sql string
if err := writeArgs(w, args); err != nil {
return 0, err
}

if session.statement.cond.IsValid() {
qm = "(" + qm[:len(qm)-1] + ")"
condSQL, condArgs, err := builder.ToSQL(session.statement.cond)
if err != nil {
if len(exprs.args) > 0 {
if _, err := w.WriteString(","); err != nil {
return 0, err
}
if err := exprs.writeArgs(w); err != nil {
return 0, err
}
}

if _, err := w.WriteString(fmt.Sprintf(" FROM %s WHERE ", session.engine.Quote(tableName))); err != nil {
return 0, err
}

if err := session.statement.cond.WriteTo(w); err != nil {
return 0, err
}
sql = fmt.Sprintf("INSERT INTO %s (`%s`) SELECT %s FROM %s WHERE %s",
session.engine.Quote(tableName),
strings.Join(columns, "`,`"),
qm,
session.engine.Quote(tableName),
condSQL,
)
args = append(args, condArgs...)
} else {
sql = fmt.Sprintf("INSERT INTO %s (`%s`) VALUES (%s)", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm)
qm := strings.Repeat("?,", len(columns))
qm = qm[:len(qm)-1]

if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (`%s`) VALUES (%s)", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm)); err != nil {
return 0, err
}
w.Append(args...)
}

sql := w.String()
args = w.Args()

if err := session.cacheInsert(tableName); err != nil {
return 0, err
}


+ 15
- 0
session_insert_test.go View File

@@ -892,4 +892,19 @@ func TestInsertWhere(t *testing.T) {
assert.EqualValues(t, 40, j2.Height)
assert.EqualValues(t, "trest2", j2.Name)
assert.EqualValues(t, 2, j2.Index)

inserted, err = testEngine.Table(new(InsertWhere)).Where("repo_id=?", 1).
SetExpr("`index`", "coalesce(MAX(`index`),0)+1").
Insert(map[string]string{
"name": "trest3",
})
assert.NoError(t, err)
assert.EqualValues(t, 1, inserted)

var j3 InsertWhere
has, err = testEngine.ID(3).Get(&j3)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, "trest3", j3.Name)
assert.EqualValues(t, 3, j3.Index)
}

+ 32
- 19
session_update.go View File

@@ -223,21 +223,31 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
}

// for update action to like "column = column + ?"
incColumns := session.statement.getInc()
for _, v := range incColumns {
colNames = append(colNames, session.engine.Quote(v.colName)+" = "+session.engine.Quote(v.colName)+" + ?")
args = append(args, v.arg)
incColumns := session.statement.incrColumns
for i, colName := range incColumns.colNames {
colNames = append(colNames, session.engine.Quote(colName)+" = "+session.engine.Quote(colName)+" + ?")
args = append(args, incColumns.args[i])
}
// for update action to like "column = column - ?"
decColumns := session.statement.getDec()
for _, v := range decColumns {
colNames = append(colNames, session.engine.Quote(v.colName)+" = "+session.engine.Quote(v.colName)+" - ?")
args = append(args, v.arg)
decColumns := session.statement.decrColumns
for i, colName := range decColumns.colNames {
colNames = append(colNames, session.engine.Quote(colName)+" = "+session.engine.Quote(colName)+" - ?")
args = append(args, decColumns.args[i])
}
// for update action to like "column = expression"
exprColumns := session.statement.getExpr()
for _, v := range exprColumns {
colNames = append(colNames, session.engine.Quote(v.colName)+" = "+v.expr)
exprColumns := session.statement.exprColumns
for i, colName := range exprColumns.colNames {
switch tp := exprColumns.args[i].(type) {
case string:
colNames = append(colNames, session.engine.Quote(colName)+" = "+tp)
case *builder.Builder:
subQuery, subArgs, err := builder.ToSQL(tp)
if err != nil {
return 0, err
}
colNames = append(colNames, session.engine.Quote(colName)+" = "+subQuery)
args = append(args, subArgs...)
}
}

if err = session.statement.processIDParam(); err != nil {
@@ -468,14 +478,17 @@ func (session *Session) genUpdateColumns(bean interface{}) ([]string, []interfac
continue
}

if len(session.statement.columnMap) > 0 {
if !session.statement.columnMap.contain(col.Name) {
continue
} else if _, ok := session.statement.incrColumns[col.Name]; ok {
continue
} else if _, ok := session.statement.decrColumns[col.Name]; ok {
continue
}
// if only update specify columns
if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) {
continue
}

if session.statement.incrColumns.isColExist(col.Name) {
continue
} else if session.statement.decrColumns.isColExist(col.Name) {
continue
} else if session.statement.exprColumns.isColExist(col.Name) {
continue
}

// !evalphobia! set fieldValue as nil when column is nullable and zero-value


+ 12
- 30
statement.go View File

@@ -52,9 +52,9 @@ type Statement struct {
omitColumnMap columnMap
mustColumnMap map[string]bool
nullableMap map[string]bool
incrColumns map[string]incrParam
decrColumns map[string]decrParam
exprColumns map[string]exprParam
incrColumns exprParams
decrColumns exprParams
exprColumns exprParams
cond builder.Cond
bufferSize int
context ContextCache
@@ -94,9 +94,9 @@ func (statement *Statement) Init() {
statement.nullableMap = make(map[string]bool)
statement.checkVersion = true
statement.unscoped = false
statement.incrColumns = make(map[string]incrParam)
statement.decrColumns = make(map[string]decrParam)
statement.exprColumns = make(map[string]exprParam)
statement.incrColumns = exprParams{}
statement.decrColumns = exprParams{}
statement.exprColumns = exprParams{}
statement.cond = builder.NewCond()
statement.bufferSize = 0
statement.context = nil
@@ -534,48 +534,30 @@ func (statement *Statement) ID(id interface{}) *Statement {

// Incr Generate "Update ... Set column = column + arg" statement
func (statement *Statement) Incr(column string, arg ...interface{}) *Statement {
k := strings.ToLower(column)
if len(arg) > 0 {
statement.incrColumns[k] = incrParam{column, arg[0]}
statement.incrColumns.addParam(column, arg[0])
} else {
statement.incrColumns[k] = incrParam{column, 1}
statement.incrColumns.addParam(column, 1)
}
return statement
}

// Decr Generate "Update ... Set column = column - arg" statement
func (statement *Statement) Decr(column string, arg ...interface{}) *Statement {
k := strings.ToLower(column)
if len(arg) > 0 {
statement.decrColumns[k] = decrParam{column, arg[0]}
statement.decrColumns.addParam(column, arg[0])
} else {
statement.decrColumns[k] = decrParam{column, 1}
statement.decrColumns.addParam(column, 1)
}
return statement
}

// SetExpr Generate "Update ... Set column = {expression}" statement
func (statement *Statement) SetExpr(column string, expression string) *Statement {
k := strings.ToLower(column)
statement.exprColumns[k] = exprParam{column, expression}
func (statement *Statement) SetExpr(column string, expression interface{}) *Statement {
statement.exprColumns.addParam(column, expression)
return statement
}

// Generate "Update ... Set column = column + arg" statement
func (statement *Statement) getInc() map[string]incrParam {
return statement.incrColumns
}

// Generate "Update ... Set column = column - arg" statement
func (statement *Statement) getDec() map[string]decrParam {
return statement.decrColumns
}

// Generate "Update ... Set column = {expression}" statement
func (statement *Statement) getExpr() map[string]exprParam {
return statement.exprColumns
}

func (statement *Statement) col2NewColsWithQuote(columns ...string) []string {
newColumns := make([]string, 0)
quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`")


+ 68
- 0
statement_args.go View File

@@ -0,0 +1,68 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package xorm

import (
"fmt"

"xorm.io/builder"
)

func writeArg(w *builder.BytesWriter, arg interface{}) error {
switch argv := arg.(type) {
case string:
if _, err := w.WriteString("'" + argv + "'"); err != nil {
return err
}
case *builder.Builder:
if err := argv.WriteTo(w); err != nil {
return err
}
default:
if _, err := w.WriteString(fmt.Sprintf("%v", argv)); err != nil {
return err
}
}
return nil
}

func writeArgs(w *builder.BytesWriter, args []interface{}) error {
for i, arg := range args {
if err := writeArg(w, arg); err != nil {
return err
}

if i+1 != len(args) {
if _, err := w.WriteString(","); err != nil {
return err
}
}
}
return nil
}

func writeStrings(w *builder.BytesWriter, cols []string, leftQuote, rightQuote string) error {
for i, colName := range cols {
if len(leftQuote) > 0 && colName[0] != '`' {
if _, err := w.WriteString(leftQuote); err != nil {
return err
}
}
if _, err := w.WriteString(colName); err != nil {
return err
}
if len(rightQuote) > 0 && colName[len(colName)-1] != '`' {
if _, err := w.WriteString(rightQuote); err != nil {
return err
}
}
if i+1 != len(cols) {
if _, err := w.WriteString(","); err != nil {
return err
}
}
}
return nil
}

+ 35
- 0
statement_columnmap.go View File

@@ -0,0 +1,35 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package xorm

import "strings"

type columnMap []string

func (m columnMap) contain(colName string) bool {
if len(m) == 0 {
return false
}

n := len(colName)
for _, mk := range m {
if len(mk) != n {
continue
}
if strings.EqualFold(mk, colName) {
return true
}
}

return false
}

func (m *columnMap) add(colName string) bool {
if m.contain(colName) {
return false
}
*m = append(*m, colName)
return true
}

+ 100
- 0
statement_exprparam.go View File

@@ -0,0 +1,100 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package xorm

import (
"fmt"
"strings"

"xorm.io/builder"
)

type ErrUnsupportedExprType struct {
tp string
}

func (err ErrUnsupportedExprType) Error() string {
return fmt.Sprintf("Unsupported expression type: %v", err.tp)
}

type exprParam struct {
colName string
arg interface{}
}

type exprParams struct {
colNames []string
args []interface{}
}

func (exprs *exprParams) Len() int {
return len(exprs.colNames)
}

func (exprs *exprParams) addParam(colName string, arg interface{}) {
exprs.colNames = append(exprs.colNames, colName)
exprs.args = append(exprs.args, arg)
}

func (exprs *exprParams) isColExist(colName string) bool {
for _, name := range exprs.colNames {
if strings.EqualFold(trimQuote(name), trimQuote(colName)) {
return true
}
}
return false
}

func (exprs *exprParams) getByName(colName string) (exprParam, bool) {
for i, name := range exprs.colNames {
if strings.EqualFold(name, colName) {
return exprParam{name, exprs.args[i]}, true
}
}
return exprParam{}, false
}

func (exprs *exprParams) writeArgs(w *builder.BytesWriter) error {
for _, expr := range exprs.args {
switch arg := expr.(type) {
case *builder.Builder:
if err := arg.WriteTo(w); err != nil {
return err
}
default:
if _, err := w.WriteString(fmt.Sprintf("%v", arg)); err != nil {
return err
}
}
}
return nil
}

func (exprs *exprParams) writeNameArgs(w *builder.BytesWriter) error {
for i, colName := range exprs.colNames {
if _, err := w.WriteString(colName); err != nil {
return err
}
if _, err := w.WriteString("="); err != nil {
return err
}

switch arg := exprs.args[i].(type) {
case *builder.Builder:
if err := arg.WriteTo(w); err != nil {
return err
}
default:
w.Append(exprs.args[i])
}

if i+1 != len(exprs.colNames) {
if _, err := w.WriteString(","); err != nil {
return err
}
}
}
return nil
}

+ 19
- 0
statement_quote.go View File

@@ -0,0 +1,19 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package xorm

func trimQuote(s string) string {
if len(s) == 0 {
return s
}

if s[0] == '`' {
s = s[1:]
}
if len(s) > 0 && s[len(s)-1] == '`' {
return s[:len(s)-1]
}
return s
}

Loading…
Cancel
Save