diff --git a/internal/statements/expr.go b/internal/statements/expr.go new file mode 100644 index 00000000..b44c96ca --- /dev/null +++ b/internal/statements/expr.go @@ -0,0 +1,93 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package statements + +import ( + "fmt" + "strings" + + "xorm.io/builder" + "xorm.io/xorm/schemas" +) + +// ErrUnsupportedExprType represents an error with unsupported express type +type ErrUnsupportedExprType struct { + tp string +} + +func (err ErrUnsupportedExprType) Error() string { + return fmt.Sprintf("Unsupported expression type: %v", err.tp) +} + +// Expr represents an SQL express +type Expr struct { + ColName string + Arg interface{} +} + +func (expr *Expr) WriteArgs(w *builder.BytesWriter) error { + switch arg := expr.Arg.(type) { + case *builder.Builder: + if _, err := w.WriteString("("); err != nil { + return err + } + if err := arg.WriteTo(w); err != nil { + return err + } + if _, err := w.WriteString(")"); err != nil { + return err + } + case string: + if arg == "" { + arg = "''" + } + if _, err := w.WriteString(fmt.Sprintf("%v", arg)); err != nil { + return err + } + default: + if _, err := w.WriteString("?"); err != nil { + return err + } + w.Append(arg) + } + return nil +} + +type exprParams []Expr + +func (exprs exprParams) ColNames() []string { + var cols = make([]string, 0, len(exprs)) + for _, expr := range exprs { + cols = append(cols, expr.ColName) + } + return cols +} + +func (exprs *exprParams) Add(name string, arg interface{}) { + *exprs = append(*exprs, Expr{name, arg}) +} + +func (exprs exprParams) IsColExist(colName string) bool { + for _, expr := range exprs { + if strings.EqualFold(schemas.CommonQuoter.Trim(expr.ColName), schemas.CommonQuoter.Trim(colName)) { + return true + } + } + return false +} + +func (exprs exprParams) WriteArgs(w *builder.BytesWriter) error { + for i, expr := range exprs { + if err := expr.WriteArgs(w); err != nil { + return err + } + if i != len(exprs)-1 { + if _, err := w.WriteString(","); err != nil { + return err + } + } + } + return nil +} diff --git a/internal/statements/expr_param.go b/internal/statements/expr_param.go deleted file mode 100644 index d0c355d3..00000000 --- a/internal/statements/expr_param.go +++ /dev/null @@ -1,127 +0,0 @@ -// Copyright 2019 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package statements - -import ( - "fmt" - "strings" - - "xorm.io/builder" - "xorm.io/xorm/schemas" -) - -// ErrUnsupportedExprType represents an error with unsupported express type -type ErrUnsupportedExprType struct { - tp string -} - -func (err ErrUnsupportedExprType) Error() string { - return fmt.Sprintf("Unsupported expression type: %v", err.tp) -} - -type exprParam struct { - colName string - arg interface{} -} - -type exprParams struct { - ColNames []string - Args []interface{} -} - -func (exprs *exprParams) Len() int { - return len(exprs.ColNames) -} - -func (exprs *exprParams) addParam(colName string, arg interface{}) { - exprs.ColNames = append(exprs.ColNames, colName) - exprs.Args = append(exprs.Args, arg) -} - -func (exprs *exprParams) IsColExist(colName string) bool { - for _, name := range exprs.ColNames { - if strings.EqualFold(schemas.CommonQuoter.Trim(name), schemas.CommonQuoter.Trim(colName)) { - return true - } - } - return false -} - -func (exprs *exprParams) getByName(colName string) (exprParam, bool) { - for i, name := range exprs.ColNames { - if strings.EqualFold(name, colName) { - return exprParam{name, exprs.Args[i]}, true - } - } - return exprParam{}, false -} - -func (exprs *exprParams) WriteArgs(w *builder.BytesWriter) error { - for i, expr := range exprs.Args { - switch arg := expr.(type) { - case *builder.Builder: - if _, err := w.WriteString("("); err != nil { - return err - } - if err := arg.WriteTo(w); err != nil { - return err - } - if _, err := w.WriteString(")"); err != nil { - return err - } - case string: - if arg == "" { - arg = "''" - } - if _, err := w.WriteString(fmt.Sprintf("%v", arg)); err != nil { - return err - } - default: - if _, err := w.WriteString("?"); err != nil { - return err - } - w.Append(arg) - } - if i != len(exprs.Args)-1 { - if _, err := w.WriteString(","); err != nil { - return err - } - } - } - return nil -} - -func (exprs *exprParams) writeNameArgs(w *builder.BytesWriter) error { - for i, colName := range exprs.ColNames { - if _, err := w.WriteString(colName); err != nil { - return err - } - if _, err := w.WriteString("="); err != nil { - return err - } - - switch arg := exprs.Args[i].(type) { - case *builder.Builder: - if _, err := w.WriteString("("); err != nil { - return err - } - if err := arg.WriteTo(w); err != nil { - return err - } - if _, err := w.WriteString("("); err != nil { - return err - } - default: - w.Append(exprs.Args[i]) - } - - if i+1 != len(exprs.ColNames) { - if _, err := w.WriteString(","); err != nil { - return err - } - } - } - return nil -} diff --git a/internal/statements/insert.go b/internal/statements/insert.go index 6cbbbeda..367dbdc9 100644 --- a/internal/statements/insert.go +++ b/internal/statements/insert.go @@ -59,7 +59,7 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) return "", nil, err } - if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(colNames, exprs.ColNames...), ","); err != nil { + if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(colNames, exprs.ColNames()...), ","); err != nil { return "", nil, err } @@ -79,7 +79,7 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) return "", nil, err } - if len(exprs.Args) > 0 { + if len(exprs) > 0 { if _, err := buf.WriteString(","); err != nil { return "", nil, err } @@ -112,7 +112,7 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) return "", nil, err } - if len(exprs.Args) > 0 { + if len(exprs) > 0 { if _, err := buf.WriteString(","); err != nil { return "", nil, err } @@ -152,7 +152,7 @@ func (statement *Statement) GenInsertMapSQL(columns []string, args []interface{} return "", nil, err } - if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(columns, exprs.ColNames...), ","); err != nil { + if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(columns, exprs.ColNames()...), ","); err != nil { return "", nil, err } @@ -166,7 +166,7 @@ func (statement *Statement) GenInsertMapSQL(columns []string, args []interface{} return "", nil, err } - if len(exprs.Args) > 0 { + if len(exprs) > 0 { if _, err := buf.WriteString(","); err != nil { return "", nil, err } @@ -190,7 +190,7 @@ func (statement *Statement) GenInsertMapSQL(columns []string, args []interface{} return "", nil, err } - if len(exprs.Args) > 0 { + if len(exprs) > 0 { if _, err := buf.WriteString(","); err != nil { return "", nil, err } diff --git a/internal/statements/statement.go b/internal/statements/statement.go index 3dd036a6..a52c6ca2 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -324,9 +324,9 @@ func (statement *Statement) TableName() string { // Incr Generate "Update ... Set column = column + arg" statement func (statement *Statement) Incr(column string, arg ...interface{}) *Statement { if len(arg) > 0 { - statement.IncrColumns.addParam(column, arg[0]) + statement.IncrColumns.Add(column, arg[0]) } else { - statement.IncrColumns.addParam(column, 1) + statement.IncrColumns.Add(column, 1) } return statement } @@ -334,9 +334,9 @@ func (statement *Statement) Incr(column string, arg ...interface{}) *Statement { // Decr Generate "Update ... Set column = column - arg" statement func (statement *Statement) Decr(column string, arg ...interface{}) *Statement { if len(arg) > 0 { - statement.DecrColumns.addParam(column, arg[0]) + statement.DecrColumns.Add(column, arg[0]) } else { - statement.DecrColumns.addParam(column, 1) + statement.DecrColumns.Add(column, 1) } return statement } @@ -344,9 +344,9 @@ func (statement *Statement) Decr(column string, arg ...interface{}) *Statement { // SetExpr Generate "Update ... Set column = {expression}" statement func (statement *Statement) SetExpr(column string, expression interface{}) *Statement { if e, ok := expression.(string); ok { - statement.ExprColumns.addParam(column, statement.dialect.Quoter().Replace(e)) + statement.ExprColumns.Add(column, statement.dialect.Quoter().Replace(e)) } else { - statement.ExprColumns.addParam(column, expression) + statement.ExprColumns.Add(column, expression) } return statement } diff --git a/session_update.go b/session_update.go index 0adac25e..9e4cddb1 100644 --- a/session_update.go +++ b/session_update.go @@ -224,35 +224,35 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 // for update action to like "column = column + ?" incColumns := session.statement.IncrColumns - for i, colName := range incColumns.ColNames { - colNames = append(colNames, session.engine.Quote(colName)+" = "+session.engine.Quote(colName)+" + ?") - args = append(args, incColumns.Args[i]) + for _, expr := range incColumns { + colNames = append(colNames, session.engine.Quote(expr.ColName)+" = "+session.engine.Quote(expr.ColName)+" + ?") + args = append(args, expr.Arg) } // for update action to like "column = column - ?" decColumns := session.statement.DecrColumns - for i, colName := range decColumns.ColNames { - colNames = append(colNames, session.engine.Quote(colName)+" = "+session.engine.Quote(colName)+" - ?") - args = append(args, decColumns.Args[i]) + for _, expr := range decColumns { + colNames = append(colNames, session.engine.Quote(expr.ColName)+" = "+session.engine.Quote(expr.ColName)+" - ?") + args = append(args, expr.Arg) } // for update action to like "column = expression" exprColumns := session.statement.ExprColumns - for i, colName := range exprColumns.ColNames { - switch tp := exprColumns.Args[i].(type) { + for _, expr := range exprColumns { + switch tp := expr.Arg.(type) { case string: if len(tp) == 0 { tp = "''" } - colNames = append(colNames, session.engine.Quote(colName)+"="+tp) + colNames = append(colNames, session.engine.Quote(expr.ColName)+"="+tp) case *builder.Builder: subQuery, subArgs, err := session.statement.GenCondSQL(tp) if err != nil { return 0, err } - colNames = append(colNames, session.engine.Quote(colName)+"=("+subQuery+")") + colNames = append(colNames, session.engine.Quote(expr.ColName)+"=("+subQuery+")") args = append(args, subArgs...) default: - colNames = append(colNames, session.engine.Quote(colName)+"=?") - args = append(args, exprColumns.Args[i]) + colNames = append(colNames, session.engine.Quote(expr.ColName)+"=?") + args = append(args, expr.Arg) } }