Refactor orderby and support arguments #2150

Merged
lunny merged 14 commits from lunny/orderby into master 2022-05-31 03:00:29 +00:00
18 changed files with 813 additions and 632 deletions

View File

@ -400,7 +400,7 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
" `COLUMN_KEY`, `EXTRA`, `COLUMN_COMMENT`, `CHARACTER_MAXIMUM_LENGTH`, " + " `COLUMN_KEY`, `EXTRA`, `COLUMN_COMMENT`, `CHARACTER_MAXIMUM_LENGTH`, " +
alreadyQuoted + " AS NEEDS_QUOTE " + alreadyQuoted + " AS NEEDS_QUOTE " +
"FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" + "FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" +
" ORDER BY `COLUMNS`.ORDINAL_POSITION" " ORDER BY `COLUMNS`.ORDINAL_POSITION ASC"
rows, err := queryer.QueryContext(ctx, s, args...) rows, err := queryer.QueryContext(ctx, s, args...)
if err != nil { if err != nil {

View File

@ -380,7 +380,7 @@ func (engine *Engine) loadTableInfo(table *schemas.Table) error {
seq = 0 seq = 0
} }
} }
var colName = strings.Trim(parts[0], `"`) colName := strings.Trim(parts[0], `"`)
if col := table.GetColumn(colName); col != nil { if col := table.GetColumn(colName); col != nil {
col.Indexes[index.Name] = index.Type col.Indexes[index.Name] = index.Type
} else { } else {
@ -502,9 +502,9 @@ func (engine *Engine) dumpTables(ctx context.Context, tables []*schemas.Table, w
} }
} }
var dstTableName = dstTable.Name dstTableName := dstTable.Name
var quoter = dstDialect.Quoter().Quote quoter := dstDialect.Quoter().Quote
var quotedDstTableName = quoter(dstTable.Name) quotedDstTableName := quoter(dstTable.Name)
if dstDialect.URI().Schema != "" { if dstDialect.URI().Schema != "" {
dstTableName = fmt.Sprintf("%s.%s", dstDialect.URI().Schema, dstTable.Name) dstTableName = fmt.Sprintf("%s.%s", dstDialect.URI().Schema, dstTable.Name)
quotedDstTableName = fmt.Sprintf("%s.%s", quoter(dstDialect.URI().Schema), quoter(dstTable.Name)) quotedDstTableName = fmt.Sprintf("%s.%s", quoter(dstDialect.URI().Schema), quoter(dstTable.Name))
@ -1006,10 +1006,10 @@ func (engine *Engine) Asc(colNames ...string) *Session {
} }
// OrderBy will generate "ORDER BY order" // OrderBy will generate "ORDER BY order"
func (engine *Engine) OrderBy(order string) *Session { func (engine *Engine) OrderBy(order interface{}, args ...interface{}) *Session {
session := engine.NewSession() session := engine.NewSession()
session.isAutoClose = true session.isAutoClose = true
return session.OrderBy(order) return session.OrderBy(order, args...)
} }
// Prepare enables prepare statement // Prepare enables prepare statement

2
go.mod
View File

@ -17,5 +17,5 @@ require (
github.com/syndtr/goleveldb v1.0.0 github.com/syndtr/goleveldb v1.0.0
github.com/ziutek/mymysql v1.5.4 github.com/ziutek/mymysql v1.5.4
modernc.org/sqlite v1.14.2 modernc.org/sqlite v1.14.2
xorm.io/builder v0.3.9 xorm.io/builder v0.3.11-0.20220531020008-1bd24a7dc978
) )

4
go.sum
View File

@ -659,5 +659,5 @@ modernc.org/z v1.2.19 h1:BGyRFWhDVn5LFS5OcX4Yd/MlpRTOc7hOPTdcIpCiUao=
modernc.org/z v1.2.19/go.mod h1:+ZpP0pc4zz97eukOzW3xagV/lS82IpPN9NGG5pNF9vY= modernc.org/z v1.2.19/go.mod h1:+ZpP0pc4zz97eukOzW3xagV/lS82IpPN9NGG5pNF9vY=
sigs.k8s.io/yaml v1.1.0/go.mod h1:UJmg0vDUVViEyp3mgSv9WPwZCDxu4rQW1olrI1uml+o= sigs.k8s.io/yaml v1.1.0/go.mod h1:UJmg0vDUVViEyp3mgSv9WPwZCDxu4rQW1olrI1uml+o=
sourcegraph.com/sourcegraph/appdash v0.0.0-20190731080439-ebfcffb1b5c0/go.mod h1:hI742Nqp5OhwiqlzhgfbWU4mW4yO10fP+LoT9WOswdU= sourcegraph.com/sourcegraph/appdash v0.0.0-20190731080439-ebfcffb1b5c0/go.mod h1:hI742Nqp5OhwiqlzhgfbWU4mW4yO10fP+LoT9WOswdU=
xorm.io/builder v0.3.9 h1:Sd65/LdWyO7LR8+Cbd+e7mm3sK/7U9k0jS3999IDHMc= xorm.io/builder v0.3.11-0.20220531020008-1bd24a7dc978 h1:bvLlAPW1ZMTWA32LuZMBEGHAUOcATZjzHcotf3SWweM=
xorm.io/builder v0.3.9/go.mod h1:aUW0S9eb9VCaPohFCH3j7czOx1PMW3i1HrSzbLYGBSE= xorm.io/builder v0.3.11-0.20220531020008-1bd24a7dc978/go.mod h1:aUW0S9eb9VCaPohFCH3j7czOx1PMW3i1HrSzbLYGBSE=

View File

@ -247,6 +247,10 @@ func TestOrder(t *testing.T) {
users2 := make([]Userinfo, 0) users2 := make([]Userinfo, 0)
err = testEngine.Asc("id", "username").Desc("height").Find(&users2) err = testEngine.Asc("id", "username").Desc("height").Find(&users2)
assert.NoError(t, err) assert.NoError(t, err)
users = make([]Userinfo, 0)
err = testEngine.OrderBy("CASE WHEN username LIKE ? THEN 0 ELSE 1 END DESC", "a").Find(&users)
assert.NoError(t, err)
} }
func TestGroupBy(t *testing.T) { func TestGroupBy(t *testing.T) {

View File

@ -54,7 +54,7 @@ type Interface interface {
Nullable(...string) *Session Nullable(...string) *Session
Join(joinOperator string, tablename interface{}, condition string, args ...interface{}) *Session Join(joinOperator string, tablename interface{}, condition string, args ...interface{}) *Session
Omit(columns ...string) *Session Omit(columns ...string) *Session
OrderBy(order string) *Session OrderBy(order interface{}, args ...interface{}) *Session
Ping() error Ping() error
Query(sqlOrArgs ...interface{}) (resultsSlice []map[string][]byte, err error) Query(sqlOrArgs ...interface{}) (resultsSlice []map[string][]byte, err error)
QueryInterface(sqlOrArgs ...interface{}) ([]map[string]interface{}, error) QueryInterface(sqlOrArgs ...interface{}) ([]map[string]interface{}, error)

111
internal/statements/cond.go Normal file
View File

@ -0,0 +1,111 @@
// Copyright 2022 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package statements
import (
"xorm.io/builder"
"xorm.io/xorm/schemas"
)
type QuoteReplacer struct {
*builder.BytesWriter
quoter schemas.Quoter
}
func (q *QuoteReplacer) Write(p []byte) (n int, err error) {
c := q.quoter.Replace(string(p))
return q.BytesWriter.Builder.WriteString(c)
}
func (statement *Statement) QuoteReplacer(w *builder.BytesWriter) *QuoteReplacer {
return &QuoteReplacer{
BytesWriter: w,
quoter: statement.dialect.Quoter(),
}
}
// Where add Where statement
func (statement *Statement) Where(query interface{}, args ...interface{}) *Statement {
return statement.And(query, args...)
}
// And add Where & and statement
func (statement *Statement) And(query interface{}, args ...interface{}) *Statement {
switch qr := query.(type) {
case string:
cond := builder.Expr(qr, args...)
statement.cond = statement.cond.And(cond)
case map[string]interface{}:
cond := make(builder.Eq)
for k, v := range qr {
cond[statement.quote(k)] = v
}
statement.cond = statement.cond.And(cond)
case builder.Cond:
statement.cond = statement.cond.And(qr)
for _, v := range args {
if vv, ok := v.(builder.Cond); ok {
statement.cond = statement.cond.And(vv)
}
}
default:
statement.LastError = ErrConditionType
}
return statement
}
// Or add Where & Or statement
func (statement *Statement) Or(query interface{}, args ...interface{}) *Statement {
switch qr := query.(type) {
case string:
cond := builder.Expr(qr, args...)
statement.cond = statement.cond.Or(cond)
case map[string]interface{}:
cond := make(builder.Eq)
for k, v := range qr {
cond[statement.quote(k)] = v
}
statement.cond = statement.cond.Or(cond)
case builder.Cond:
statement.cond = statement.cond.Or(qr)
for _, v := range args {
if vv, ok := v.(builder.Cond); ok {
statement.cond = statement.cond.Or(vv)
}
}
default:
statement.LastError = ErrConditionType
}
return statement
}
// In generate "Where column IN (?) " statement
func (statement *Statement) In(column string, args ...interface{}) *Statement {
in := builder.In(statement.quote(column), args...)
statement.cond = statement.cond.And(in)
return statement
}
// NotIn generate "Where column NOT IN (?) " statement
func (statement *Statement) NotIn(column string, args ...interface{}) *Statement {
notIn := builder.NotIn(statement.quote(column), args...)
statement.cond = statement.cond.And(notIn)
return statement
}
// SetNoAutoCondition if you do not want convert bean's field as query condition, then use this function
func (statement *Statement) SetNoAutoCondition(no ...bool) *Statement {
statement.NoAutoCondition = true
if len(no) > 0 {
statement.NoAutoCondition = no[0]
}
return statement
}
// Conds returns condtions
func (statement *Statement) Conds() builder.Cond {
return statement.cond
}

View File

@ -0,0 +1,78 @@
// Copyright 2022 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package statements
import (
"fmt"
"strings"
"xorm.io/builder"
"xorm.io/xorm/dialects"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas"
)
// Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement {
var buf strings.Builder
if len(statement.JoinStr) > 0 {
fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, joinOP)
} else {
fmt.Fprintf(&buf, "%v JOIN ", joinOP)
}
switch tp := tablename.(type) {
case builder.Builder:
subSQL, subQueryArgs, err := tp.ToSQL()
if err != nil {
statement.LastError = err
return statement
}
fields := strings.Split(tp.TableName(), ".")
aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1])
aliasName = schemas.CommonQuoter.Trim(aliasName)
fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), statement.quote(aliasName), statement.ReplaceQuote(condition))
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
case *builder.Builder:
subSQL, subQueryArgs, err := tp.ToSQL()
if err != nil {
statement.LastError = err
return statement
}
fields := strings.Split(tp.TableName(), ".")
aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1])
aliasName = schemas.CommonQuoter.Trim(aliasName)
fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), statement.quote(aliasName), statement.ReplaceQuote(condition))
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
default:
tbName := dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), tablename, true)
if !utils.IsSubQuery(tbName) {
var buf strings.Builder
_ = statement.dialect.Quoter().QuoteTo(&buf, tbName)
tbName = buf.String()
} else {
tbName = statement.ReplaceQuote(tbName)
}
fmt.Fprintf(&buf, "%s ON %v", tbName, statement.ReplaceQuote(condition))
}
statement.JoinStr = buf.String()
statement.joinArgs = append(statement.joinArgs, args...)
return statement
}
func (statement *Statement) writeJoin(w builder.Writer) error {
if statement.JoinStr != "" {
if _, err := fmt.Fprint(w, " ", statement.JoinStr); err != nil {
return err
}
w.Append(statement.joinArgs...)
}
return nil
}

View File

@ -0,0 +1,90 @@
// Copyright 2022 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package statements
import (
"fmt"
"strings"
"xorm.io/builder"
)
func (statement *Statement) HasOrderBy() bool {
return statement.orderStr != ""
}
// ResetOrderBy reset ordery conditions
func (statement *Statement) ResetOrderBy() {
statement.orderStr = ""
statement.orderArgs = nil
}
// WriteOrderBy write order by to writer
func (statement *Statement) WriteOrderBy(w builder.Writer) error {
if len(statement.orderStr) > 0 {
if _, err := fmt.Fprintf(w, " ORDER BY %s", statement.orderStr); err != nil {
return err
}
w.Append(statement.orderArgs...)
}
return nil
}
// OrderBy generate "Order By order" statement
func (statement *Statement) OrderBy(order interface{}, args ...interface{}) *Statement {
if len(statement.orderStr) > 0 {
statement.orderStr += ", "
}
var rawOrder string
switch t := order.(type) {
case (*builder.Expression):
rawOrder = t.Content()
args = t.Args()
case string:
rawOrder = t
default:
statement.LastError = ErrUnSupportedSQLType
return statement
}
statement.orderStr += statement.ReplaceQuote(rawOrder)
if len(args) > 0 {
statement.orderArgs = append(statement.orderArgs, args...)
}
return statement
}
// Desc generate `ORDER BY xx DESC`
func (statement *Statement) Desc(colNames ...string) *Statement {
var buf strings.Builder
if len(statement.orderStr) > 0 {
fmt.Fprint(&buf, statement.orderStr, ", ")
}
for i, col := range colNames {
if i > 0 {
fmt.Fprint(&buf, ", ")
}
_ = statement.dialect.Quoter().QuoteTo(&buf, col)
fmt.Fprint(&buf, " DESC")
}
statement.orderStr = buf.String()
return statement
}
// Asc provide asc order by query condition, the input parameters are columns.
func (statement *Statement) Asc(colNames ...string) *Statement {
var buf strings.Builder
if len(statement.orderStr) > 0 {
fmt.Fprint(&buf, statement.orderStr, ", ")
}
for i, col := range colNames {
if i > 0 {
fmt.Fprint(&buf, ", ")
}
_ = statement.dialect.Quoter().QuoteTo(&buf, col)
fmt.Fprint(&buf, " ASC")
}
statement.orderStr = buf.String()
return statement
}

View File

@ -11,6 +11,7 @@ import (
"strings" "strings"
"xorm.io/builder" "xorm.io/builder"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas" "xorm.io/xorm/schemas"
) )
@ -28,7 +29,7 @@ func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []int
return "", nil, ErrTableNotFound return "", nil, ErrTableNotFound
} }
var columnStr = statement.ColumnStr() columnStr := statement.ColumnStr()
if len(statement.SelectStr) > 0 { if len(statement.SelectStr) > 0 {
columnStr = statement.SelectStr columnStr = statement.SelectStr
} else { } else {
@ -58,19 +59,7 @@ func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []int
return "", nil, err return "", nil, err
} }
sqlStr, condArgs, err := statement.genSelectSQL(columnStr, true, true) return statement.genSelectSQL(columnStr, true, true)
if err != nil {
return "", nil, err
}
args := append(statement.joinArgs, condArgs...)
// for mssql and use limit
qs := strings.Count(sqlStr, "?")
if len(args)*2 == qs {
args = append(args, args...)
}
return sqlStr, args, nil
} }
// GenSumSQL generates sum SQL // GenSumSQL generates sum SQL
@ -83,7 +72,7 @@ func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (stri
return "", nil, err return "", nil, err
} }
var sumStrs = make([]string, 0, len(columns)) sumStrs := make([]string, 0, len(columns))
for _, colName := range columns { for _, colName := range columns {
if !strings.Contains(colName, " ") && !strings.Contains(colName, "(") { if !strings.Contains(colName, " ") && !strings.Contains(colName, "(") {
colName = statement.quote(colName) colName = statement.quote(colName)
@ -94,16 +83,11 @@ func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (stri
} }
sumSelect := strings.Join(sumStrs, ", ") sumSelect := strings.Join(sumStrs, ", ")
if err := statement.mergeConds(bean); err != nil { if err := statement.MergeConds(bean); err != nil {
return "", nil, err return "", nil, err
} }
sqlStr, condArgs, err := statement.genSelectSQL(sumSelect, true, true) return statement.genSelectSQL(sumSelect, true, true)
if err != nil {
return "", nil, err
}
return sqlStr, append(statement.joinArgs, condArgs...), nil
} }
// GenGetSQL generates Get SQL // GenGetSQL generates Get SQL
@ -119,7 +103,7 @@ func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{},
} }
} }
var columnStr = statement.ColumnStr() columnStr := statement.ColumnStr()
if len(statement.SelectStr) > 0 { if len(statement.SelectStr) > 0 {
columnStr = statement.SelectStr columnStr = statement.SelectStr
} else { } else {
@ -146,7 +130,7 @@ func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{},
} }
if isStruct { if isStruct {
if err := statement.mergeConds(bean); err != nil { if err := statement.MergeConds(bean); err != nil {
return "", nil, err return "", nil, err
} }
} else { } else {
@ -155,12 +139,7 @@ func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{},
} }
} }
sqlStr, condArgs, err := statement.genSelectSQL(columnStr, true, true) return statement.genSelectSQL(columnStr, true, true)
if err != nil {
return "", nil, err
}
return sqlStr, append(statement.joinArgs, condArgs...), nil
} }
// GenCountSQL generates the SQL for counting // GenCountSQL generates the SQL for counting
@ -175,12 +154,12 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa
if err := statement.SetRefBean(beans[0]); err != nil { if err := statement.SetRefBean(beans[0]); err != nil {
return "", nil, err return "", nil, err
} }
if err := statement.mergeConds(beans[0]); err != nil { if err := statement.MergeConds(beans[0]); err != nil {
return "", nil, err return "", nil, err
} }
} }
var selectSQL = statement.SelectStr selectSQL := statement.SelectStr
if len(selectSQL) <= 0 { if len(selectSQL) <= 0 {
if statement.IsDistinct { if statement.IsDistinct {
selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.ColumnStr()) selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.ColumnStr())
@ -206,55 +185,58 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa
sqlStr = fmt.Sprintf("SELECT %s FROM (%s) sub", selectSQL, sqlStr) sqlStr = fmt.Sprintf("SELECT %s FROM (%s) sub", selectSQL, sqlStr)
} }
return sqlStr, append(statement.joinArgs, condArgs...), nil return sqlStr, condArgs, nil
} }
func (statement *Statement) fromBuilder() *strings.Builder { func (statement *Statement) writeFrom(w builder.Writer) error {
var builder strings.Builder if _, err := fmt.Fprint(w, " FROM "); err != nil {
var quote = statement.quote return err
var dialect = statement.dialect
builder.WriteString(" FROM ")
if dialect.URI().DBType == schemas.MSSQL && strings.Contains(statement.TableName(), "..") {
builder.WriteString(statement.TableName())
} else {
builder.WriteString(quote(statement.TableName()))
} }
if err := statement.writeTableName(w); err != nil {
return err
}
if err := statement.writeAlias(w); err != nil {
return err
}
return statement.writeJoin(w)
}
if statement.TableAlias != "" { func (statement *Statement) writeLimitOffset(w builder.Writer) error {
if dialect.URI().DBType == schemas.ORACLE { if statement.Start > 0 {
builder.WriteString(" ") if statement.LimitN != nil {
} else { _, err := fmt.Fprintf(w, " LIMIT %v OFFSET %v", *statement.LimitN, statement.Start)
builder.WriteString(" AS ") return err
} }
builder.WriteString(quote(statement.TableAlias)) _, err := fmt.Fprintf(w, " LIMIT 0 OFFSET %v", statement.Start)
return err
} }
if statement.JoinStr != "" { if statement.LimitN != nil {
builder.WriteString(" ") _, err := fmt.Fprint(w, " LIMIT ", *statement.LimitN)
builder.WriteString(statement.JoinStr) return err
} }
return &builder // no limit statement
return nil
} }
func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderBy bool) (string, []interface{}, error) { func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderBy bool) (string, []interface{}, error) {
var ( var (
distinct string distinct string
dialect = statement.dialect dialect = statement.dialect
fromStr = statement.fromBuilder().String() top, whereStr string
top, mssqlCondi, whereStr string mssqlCondi = builder.NewWriter()
) )
if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") { if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") {
distinct = "DISTINCT " distinct = "DISTINCT "
} }
condSQL, condArgs, err := statement.GenCondSQL(statement.cond) condWriter := builder.NewWriter()
if err != nil { if err := statement.cond.WriteTo(statement.QuoteReplacer(condWriter)); err != nil {
return "", nil, err return "", nil, err
} }
if len(condSQL) > 0 {
whereStr = fmt.Sprintf(" WHERE %s", condSQL) if condWriter.Len() > 0 {
whereStr = " WHERE "
} }
pLimitN := statement.LimitN pLimitN := statement.LimitN
@ -289,49 +271,81 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB
} }
} }
var orderStr string if _, err := fmt.Fprintf(mssqlCondi, "(%s NOT IN (SELECT TOP %d %s",
if needOrderBy && len(statement.OrderStr) > 0 { column, statement.Start, column); err != nil {
orderStr = fmt.Sprintf(" ORDER BY %s", statement.OrderStr) return "", nil, err
} }
if err := statement.writeFrom(mssqlCondi); err != nil {
var groupStr string return "", nil, err
if len(statement.GroupByStr) > 0 { }
groupStr = fmt.Sprintf(" GROUP BY %s", statement.GroupByStr) if whereStr != "" {
if _, err := fmt.Fprint(mssqlCondi, whereStr); err != nil {
return "", nil, err
}
if err := utils.WriteBuilder(mssqlCondi, statement.QuoteReplacer(condWriter)); err != nil {
return "", nil, err
}
}
if needOrderBy {
if err := statement.WriteOrderBy(mssqlCondi); err != nil {
return "", nil, err
}
}
if err := statement.WriteGroupBy(mssqlCondi); err != nil {
return "", nil, err
}
if _, err := fmt.Fprint(mssqlCondi, "))"); err != nil {
return "", nil, err
} }
mssqlCondi = fmt.Sprintf("(%s NOT IN (SELECT TOP %d %s%s%s%s%s))",
column, statement.Start, column, fromStr, whereStr, orderStr, groupStr)
} }
} }
var buf strings.Builder buf := builder.NewWriter()
fmt.Fprintf(&buf, "SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr) if _, err := fmt.Fprintf(buf, "SELECT %v%v%v", distinct, top, columnStr); err != nil {
if len(mssqlCondi) > 0 { return "", nil, err
}
if err := statement.writeFrom(buf); err != nil {
return "", nil, err
}
if whereStr != "" {
if _, err := fmt.Fprint(buf, whereStr); err != nil {
return "", nil, err
}
if err := utils.WriteBuilder(buf, statement.QuoteReplacer(condWriter)); err != nil {
return "", nil, err
}
}
if mssqlCondi.Len() > 0 {
if len(whereStr) > 0 { if len(whereStr) > 0 {
fmt.Fprint(&buf, " AND ", mssqlCondi) if _, err := fmt.Fprint(buf, " AND "); err != nil {
return "", nil, err
}
} else { } else {
fmt.Fprint(&buf, " WHERE ", mssqlCondi) if _, err := fmt.Fprint(buf, " WHERE "); err != nil {
return "", nil, err
}
}
if err := utils.WriteBuilder(buf, mssqlCondi); err != nil {
return "", nil, err
} }
} }
if statement.GroupByStr != "" { if err := statement.WriteGroupBy(buf); err != nil {
fmt.Fprint(&buf, " GROUP BY ", statement.GroupByStr) return "", nil, err
} }
if statement.HavingStr != "" { if err := statement.writeHaving(buf); err != nil {
fmt.Fprint(&buf, " ", statement.HavingStr) return "", nil, err
} }
if needOrderBy && statement.OrderStr != "" { if needOrderBy {
fmt.Fprint(&buf, " ORDER BY ", statement.OrderStr) if err := statement.WriteOrderBy(buf); err != nil {
return "", nil, err
}
} }
if needLimit { if needLimit {
if dialect.URI().DBType != schemas.MSSQL && dialect.URI().DBType != schemas.ORACLE { if dialect.URI().DBType != schemas.MSSQL && dialect.URI().DBType != schemas.ORACLE {
if statement.Start > 0 { if err := statement.writeLimitOffset(buf); err != nil {
if pLimitN != nil { return "", nil, err
fmt.Fprintf(&buf, " LIMIT %v OFFSET %v", *pLimitN, statement.Start)
} else {
fmt.Fprintf(&buf, " LIMIT 0 OFFSET %v", statement.Start)
}
} else if pLimitN != nil {
fmt.Fprint(&buf, " LIMIT ", *pLimitN)
} }
} else if dialect.URI().DBType == schemas.ORACLE { } else if dialect.URI().DBType == schemas.ORACLE {
if pLimitN != nil { if pLimitN != nil {
@ -341,16 +355,16 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB
if rawColStr == "*" { if rawColStr == "*" {
rawColStr = "at.*" rawColStr = "at.*"
} }
fmt.Fprintf(&buf, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", fmt.Fprintf(buf, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d",
columnStr, rawColStr, oldString, statement.Start+*pLimitN, statement.Start) columnStr, rawColStr, oldString, statement.Start+*pLimitN, statement.Start)
} }
} }
} }
if statement.IsForUpdate { if statement.IsForUpdate {
return dialect.ForUpdateSQL(buf.String()), condArgs, nil return dialect.ForUpdateSQL(buf.String()), buf.Args(), nil
} }
return buf.String(), condArgs, nil return buf.String(), buf.Args(), nil
} }
// GenExistSQL generates Exist SQL // GenExistSQL generates Exist SQL
@ -359,10 +373,6 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac
return statement.GenRawSQL(), statement.RawParams, nil return statement.GenRawSQL(), statement.RawParams, nil
} }
var sqlStr string
var args []interface{}
var joinStr string
var err error
var b interface{} var b interface{}
if len(bean) > 0 { if len(bean) > 0 {
b = bean[0] b = bean[0]
@ -381,45 +391,70 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac
if len(tableName) <= 0 { if len(tableName) <= 0 {
return "", nil, ErrTableNotFound return "", nil, ErrTableNotFound
} }
if statement.RefTable == nil { if statement.RefTable != nil {
tableName = statement.quote(tableName) return statement.Limit(1).GenGetSQL(b)
if len(statement.JoinStr) > 0 { }
joinStr = statement.JoinStr
}
tableName = statement.quote(tableName)
buf := builder.NewWriter()
if statement.dialect.URI().DBType == schemas.MSSQL {
if _, err := fmt.Fprintf(buf, "SELECT TOP 1 * FROM %s", tableName); err != nil {
return "", nil, err
}
if err := statement.writeJoin(buf); err != nil {
return "", nil, err
}
if statement.Conds().IsValid() { if statement.Conds().IsValid() {
condSQL, condArgs, err := statement.GenCondSQL(statement.Conds()) if _, err := fmt.Fprintf(buf, " WHERE "); err != nil {
if err != nil {
return "", nil, err return "", nil, err
} }
if err := statement.Conds().WriteTo(statement.QuoteReplacer(buf)); err != nil {
if statement.dialect.URI().DBType == schemas.MSSQL { return "", nil, err
sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s %s WHERE %s", tableName, joinStr, condSQL)
} else if statement.dialect.URI().DBType == schemas.ORACLE {
sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE (%s) %s AND ROWNUM=1", tableName, joinStr, condSQL)
} else {
sqlStr = fmt.Sprintf("SELECT 1 FROM %s %s WHERE %s LIMIT 1", tableName, joinStr, condSQL)
} }
args = condArgs }
} else { } else if statement.dialect.URI().DBType == schemas.ORACLE {
if statement.dialect.URI().DBType == schemas.MSSQL { if _, err := fmt.Fprintf(buf, "SELECT * FROM %s", tableName); err != nil {
sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s %s", tableName, joinStr) return "", nil, err
} else if statement.dialect.URI().DBType == schemas.ORACLE { }
sqlStr = fmt.Sprintf("SELECT * FROM %s %s WHERE ROWNUM=1", tableName, joinStr) if err := statement.writeJoin(buf); err != nil {
} else { return "", nil, err
sqlStr = fmt.Sprintf("SELECT 1 FROM %s %s LIMIT 1", tableName, joinStr) }
if _, err := fmt.Fprintf(buf, " WHERE "); err != nil {
return "", nil, err
}
if statement.Conds().IsValid() {
if err := statement.Conds().WriteTo(statement.QuoteReplacer(buf)); err != nil {
return "", nil, err
} }
args = []interface{}{} if _, err := fmt.Fprintf(buf, " AND "); err != nil {
return "", nil, err
}
}
if _, err := fmt.Fprintf(buf, "ROWNUM=1"); err != nil {
return "", nil, err
} }
} else { } else {
statement.Limit(1) if _, err := fmt.Fprintf(buf, "SELECT 1 FROM %s", tableName); err != nil {
sqlStr, args, err = statement.GenGetSQL(b) return "", nil, err
if err != nil { }
if err := statement.writeJoin(buf); err != nil {
return "", nil, err
}
if statement.Conds().IsValid() {
if _, err := fmt.Fprintf(buf, " WHERE "); err != nil {
return "", nil, err
}
if err := statement.Conds().WriteTo(statement.QuoteReplacer(buf)); err != nil {
return "", nil, err
}
}
if _, err := fmt.Fprintf(buf, " LIMIT 1"); err != nil {
return "", nil, err return "", nil, err
} }
} }
return sqlStr, args, nil return buf.String(), buf.Args(), nil
} }
// GenFindSQL generates Find SQL // GenFindSQL generates Find SQL
@ -428,15 +463,11 @@ func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interfa
return statement.GenRawSQL(), statement.RawParams, nil return statement.GenRawSQL(), statement.RawParams, nil
} }
var sqlStr string
var args []interface{}
var err error
if len(statement.TableName()) <= 0 { if len(statement.TableName()) <= 0 {
return "", nil, ErrTableNotFound return "", nil, ErrTableNotFound
} }
var columnStr = statement.ColumnStr() columnStr := statement.ColumnStr()
if len(statement.SelectStr) > 0 { if len(statement.SelectStr) > 0 {
columnStr = statement.SelectStr columnStr = statement.SelectStr
} else { } else {
@ -464,16 +495,5 @@ func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interfa
statement.cond = statement.cond.And(autoCond) statement.cond = statement.cond.And(autoCond)
sqlStr, condArgs, err := statement.genSelectSQL(columnStr, true, true) return statement.genSelectSQL(columnStr, true, true)
if err != nil {
return "", nil, err
}
args = append(statement.joinArgs, condArgs...)
// for mssql and use limit
qs := strings.Count(sqlStr, "?")
if len(args)*2 == qs {
args = append(args, args...)
}
return sqlStr, args, nil
} }

View File

@ -0,0 +1,137 @@
// Copyright 2022 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package statements
import (
"fmt"
"strings"
"xorm.io/xorm/schemas"
)
// Select replace select
func (statement *Statement) Select(str string) *Statement {
statement.SelectStr = statement.ReplaceQuote(str)
return statement
}
func col2NewCols(columns ...string) []string {
newColumns := make([]string, 0, len(columns))
for _, col := range columns {
col = strings.Replace(col, "`", "", -1)
col = strings.Replace(col, `"`, "", -1)
ccols := strings.Split(col, ",")
for _, c := range ccols {
newColumns = append(newColumns, strings.TrimSpace(c))
}
}
return newColumns
}
// Cols generate "col1, col2" statement
func (statement *Statement) Cols(columns ...string) *Statement {
cols := col2NewCols(columns...)
for _, nc := range cols {
statement.ColumnMap.Add(nc)
}
return statement
}
// ColumnStr returns column string
func (statement *Statement) ColumnStr() string {
return statement.dialect.Quoter().Join(statement.ColumnMap, ", ")
}
// AllCols update use only: update all columns
func (statement *Statement) AllCols() *Statement {
statement.useAllCols = true
return statement
}
// MustCols update use only: must update columns
func (statement *Statement) MustCols(columns ...string) *Statement {
newColumns := col2NewCols(columns...)
for _, nc := range newColumns {
statement.MustColumnMap[strings.ToLower(nc)] = true
}
return statement
}
// UseBool indicates that use bool fields as update contents and query contiditions
func (statement *Statement) UseBool(columns ...string) *Statement {
if len(columns) > 0 {
statement.MustCols(columns...)
} else {
statement.allUseBool = true
}
return statement
}
// Omit do not use the columns
func (statement *Statement) Omit(columns ...string) {
newColumns := col2NewCols(columns...)
for _, nc := range newColumns {
statement.OmitColumnMap = append(statement.OmitColumnMap, nc)
}
}
func (statement *Statement) genColumnStr() string {
if statement.RefTable == nil {
return ""
}
var buf strings.Builder
columns := statement.RefTable.Columns()
for _, col := range columns {
if statement.OmitColumnMap.Contain(col.Name) {
continue
}
if len(statement.ColumnMap) > 0 && !statement.ColumnMap.Contain(col.Name) {
continue
}
if col.MapType == schemas.ONLYTODB {
continue
}
if buf.Len() != 0 {
buf.WriteString(", ")
}
if statement.JoinStr != "" {
if statement.TableAlias != "" {
buf.WriteString(statement.TableAlias)
} else {
buf.WriteString(statement.TableName())
}
buf.WriteString(".")
}
statement.dialect.Quoter().QuoteTo(&buf, col.Name)
}
return buf.String()
}
func (statement *Statement) colName(col *schemas.Column, tableName string) string {
if statement.needTableName() {
nm := tableName
if len(statement.TableAlias) > 0 {
nm = statement.TableAlias
}
return fmt.Sprintf("%s.%s", statement.quote(nm), statement.quote(col.Name))
}
return statement.quote(col.Name)
}
// Distinct generates "DISTINCT col1, col2 " statement
func (statement *Statement) Distinct(columns ...string) *Statement {
statement.IsDistinct = true
statement.Cols(columns...)
return statement
}

View File

@ -43,7 +43,8 @@ type Statement struct {
Start int Start int
LimitN *int LimitN *int
idParam schemas.PK idParam schemas.PK
OrderStr string orderStr string
orderArgs []interface{}
JoinStr string JoinStr string
joinArgs []interface{} joinArgs []interface{}
GroupByStr string GroupByStr string
@ -101,15 +102,6 @@ func (statement *Statement) GenRawSQL() string {
return statement.ReplaceQuote(statement.RawSQL) return statement.ReplaceQuote(statement.RawSQL)
} }
// GenCondSQL generates condition SQL
func (statement *Statement) GenCondSQL(condOrBuilder interface{}) (string, []interface{}, error) {
condSQL, condArgs, err := builder.ToSQL(condOrBuilder)
if err != nil {
return "", nil, err
}
return statement.ReplaceQuote(condSQL), condArgs, nil
}
// ReplaceQuote replace sql key words with quote // ReplaceQuote replace sql key words with quote
func (statement *Statement) ReplaceQuote(sql string) string { func (statement *Statement) ReplaceQuote(sql string) string {
if sql == "" || statement.dialect.URI().DBType == schemas.MYSQL || if sql == "" || statement.dialect.URI().DBType == schemas.MYSQL ||
@ -129,7 +121,7 @@ func (statement *Statement) Reset() {
statement.RefTable = nil statement.RefTable = nil
statement.Start = 0 statement.Start = 0
statement.LimitN = nil statement.LimitN = nil
statement.OrderStr = "" statement.ResetOrderBy()
statement.UseCascade = true statement.UseCascade = true
statement.JoinStr = "" statement.JoinStr = ""
statement.joinArgs = make([]interface{}, 0) statement.joinArgs = make([]interface{}, 0)
@ -164,21 +156,6 @@ func (statement *Statement) Reset() {
statement.LastError = nil statement.LastError = nil
} }
// SetNoAutoCondition if you do not want convert bean's field as query condition, then use this function
func (statement *Statement) SetNoAutoCondition(no ...bool) *Statement {
statement.NoAutoCondition = true
if len(no) > 0 {
statement.NoAutoCondition = no[0]
}
return statement
}
// Alias set the table alias
func (statement *Statement) Alias(alias string) *Statement {
statement.TableAlias = alias
return statement
}
// SQL adds raw sql statement // SQL adds raw sql statement
func (statement *Statement) SQL(query interface{}, args ...interface{}) *Statement { func (statement *Statement) SQL(query interface{}, args ...interface{}) *Statement {
switch query.(type) { switch query.(type) {
@ -198,80 +175,10 @@ func (statement *Statement) SQL(query interface{}, args ...interface{}) *Stateme
return statement return statement
} }
// Where add Where statement
func (statement *Statement) Where(query interface{}, args ...interface{}) *Statement {
return statement.And(query, args...)
}
func (statement *Statement) quote(s string) string { func (statement *Statement) quote(s string) string {
return statement.dialect.Quoter().Quote(s) return statement.dialect.Quoter().Quote(s)
} }
// And add Where & and statement
func (statement *Statement) And(query interface{}, args ...interface{}) *Statement {
switch qr := query.(type) {
case string:
cond := builder.Expr(qr, args...)
statement.cond = statement.cond.And(cond)
case map[string]interface{}:
cond := make(builder.Eq)
for k, v := range qr {
cond[statement.quote(k)] = v
}
statement.cond = statement.cond.And(cond)
case builder.Cond:
statement.cond = statement.cond.And(qr)
for _, v := range args {
if vv, ok := v.(builder.Cond); ok {
statement.cond = statement.cond.And(vv)
}
}
default:
statement.LastError = ErrConditionType
}
return statement
}
// Or add Where & Or statement
func (statement *Statement) Or(query interface{}, args ...interface{}) *Statement {
switch qr := query.(type) {
case string:
cond := builder.Expr(qr, args...)
statement.cond = statement.cond.Or(cond)
case map[string]interface{}:
cond := make(builder.Eq)
for k, v := range qr {
cond[statement.quote(k)] = v
}
statement.cond = statement.cond.Or(cond)
case builder.Cond:
statement.cond = statement.cond.Or(qr)
for _, v := range args {
if vv, ok := v.(builder.Cond); ok {
statement.cond = statement.cond.Or(vv)
}
}
default:
statement.LastError = ErrConditionType
}
return statement
}
// In generate "Where column IN (?) " statement
func (statement *Statement) In(column string, args ...interface{}) *Statement {
in := builder.In(statement.quote(column), args...)
statement.cond = statement.cond.And(in)
return statement
}
// NotIn generate "Where column NOT IN (?) " statement
func (statement *Statement) NotIn(column string, args ...interface{}) *Statement {
notIn := builder.NotIn(statement.quote(column), args...)
statement.cond = statement.cond.And(notIn)
return statement
}
// SetRefValue set ref value // SetRefValue set ref value
func (statement *Statement) SetRefValue(v reflect.Value) error { func (statement *Statement) SetRefValue(v reflect.Value) error {
var err error var err error
@ -302,26 +209,6 @@ func (statement *Statement) needTableName() bool {
return len(statement.JoinStr) > 0 return len(statement.JoinStr) > 0
} }
func (statement *Statement) colName(col *schemas.Column, tableName string) string {
if statement.needTableName() {
nm := tableName
if len(statement.TableAlias) > 0 {
nm = statement.TableAlias
}
return fmt.Sprintf("%s.%s", statement.quote(nm), statement.quote(col.Name))
}
return statement.quote(col.Name)
}
// TableName return current tableName
func (statement *Statement) TableName() string {
if statement.AltTableName != "" {
return statement.AltTableName
}
return statement.tableName
}
// Incr Generate "Update ... Set column = column + arg" statement // Incr Generate "Update ... Set column = column + arg" statement
func (statement *Statement) Incr(column string, arg ...interface{}) *Statement { func (statement *Statement) Incr(column string, arg ...interface{}) *Statement {
if len(arg) > 0 { if len(arg) > 0 {
@ -352,85 +239,12 @@ func (statement *Statement) SetExpr(column string, expression interface{}) *Stat
return statement return statement
} }
// Distinct generates "DISTINCT col1, col2 " statement
func (statement *Statement) Distinct(columns ...string) *Statement {
statement.IsDistinct = true
statement.Cols(columns...)
return statement
}
// ForUpdate generates "SELECT ... FOR UPDATE" statement // ForUpdate generates "SELECT ... FOR UPDATE" statement
func (statement *Statement) ForUpdate() *Statement { func (statement *Statement) ForUpdate() *Statement {
statement.IsForUpdate = true statement.IsForUpdate = true
return statement return statement
} }
// Select replace select
func (statement *Statement) Select(str string) *Statement {
statement.SelectStr = statement.ReplaceQuote(str)
return statement
}
func col2NewCols(columns ...string) []string {
newColumns := make([]string, 0, len(columns))
for _, col := range columns {
col = strings.Replace(col, "`", "", -1)
col = strings.Replace(col, `"`, "", -1)
ccols := strings.Split(col, ",")
for _, c := range ccols {
newColumns = append(newColumns, strings.TrimSpace(c))
}
}
return newColumns
}
// Cols generate "col1, col2" statement
func (statement *Statement) Cols(columns ...string) *Statement {
cols := col2NewCols(columns...)
for _, nc := range cols {
statement.ColumnMap.Add(nc)
}
return statement
}
// ColumnStr returns column string
func (statement *Statement) ColumnStr() string {
return statement.dialect.Quoter().Join(statement.ColumnMap, ", ")
}
// AllCols update use only: update all columns
func (statement *Statement) AllCols() *Statement {
statement.useAllCols = true
return statement
}
// MustCols update use only: must update columns
func (statement *Statement) MustCols(columns ...string) *Statement {
newColumns := col2NewCols(columns...)
for _, nc := range newColumns {
statement.MustColumnMap[strings.ToLower(nc)] = true
}
return statement
}
// UseBool indicates that use bool fields as update contents and query contiditions
func (statement *Statement) UseBool(columns ...string) *Statement {
if len(columns) > 0 {
statement.MustCols(columns...)
} else {
statement.allUseBool = true
}
return statement
}
// Omit do not use the columns
func (statement *Statement) Omit(columns ...string) {
newColumns := col2NewCols(columns...)
for _, nc := range newColumns {
statement.OmitColumnMap = append(statement.OmitColumnMap, nc)
}
}
// Nullable Update use only: update columns to null when value is nullable and zero-value // Nullable Update use only: update columns to null when value is nullable and zero-value
func (statement *Statement) Nullable(columns ...string) { func (statement *Statement) Nullable(columns ...string) {
newColumns := col2NewCols(columns...) newColumns := col2NewCols(columns...)
@ -454,54 +268,6 @@ func (statement *Statement) Limit(limit int, start ...int) *Statement {
return statement return statement
} }
// OrderBy generate "Order By order" statement
func (statement *Statement) OrderBy(order string) *Statement {
if len(statement.OrderStr) > 0 {
statement.OrderStr += ", "
}
statement.OrderStr += statement.ReplaceQuote(order)
return statement
}
// Desc generate `ORDER BY xx DESC`
func (statement *Statement) Desc(colNames ...string) *Statement {
var buf strings.Builder
if len(statement.OrderStr) > 0 {
fmt.Fprint(&buf, statement.OrderStr, ", ")
}
for i, col := range colNames {
if i > 0 {
fmt.Fprint(&buf, ", ")
}
_ = statement.dialect.Quoter().QuoteTo(&buf, col)
fmt.Fprint(&buf, " DESC")
}
statement.OrderStr = buf.String()
return statement
}
// Asc provide asc order by query condition, the input parameters are columns.
func (statement *Statement) Asc(colNames ...string) *Statement {
var buf strings.Builder
if len(statement.OrderStr) > 0 {
fmt.Fprint(&buf, statement.OrderStr, ", ")
}
for i, col := range colNames {
if i > 0 {
fmt.Fprint(&buf, ", ")
}
_ = statement.dialect.Quoter().QuoteTo(&buf, col)
fmt.Fprint(&buf, " ASC")
}
statement.OrderStr = buf.String()
return statement
}
// Conds returns condtions
func (statement *Statement) Conds() builder.Cond {
return statement.cond
}
// SetTable tempororily set table name, the parameter could be a string or a pointer of struct // SetTable tempororily set table name, the parameter could be a string or a pointer of struct
func (statement *Statement) SetTable(tableNameOrBean interface{}) error { func (statement *Statement) SetTable(tableNameOrBean interface{}) error {
v := rValue(tableNameOrBean) v := rValue(tableNameOrBean)
@ -518,71 +284,34 @@ func (statement *Statement) SetTable(tableNameOrBean interface{}) error {
return nil return nil
} }
// Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement {
var buf strings.Builder
if len(statement.JoinStr) > 0 {
fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, joinOP)
} else {
fmt.Fprintf(&buf, "%v JOIN ", joinOP)
}
switch tp := tablename.(type) {
case builder.Builder:
subSQL, subQueryArgs, err := tp.ToSQL()
if err != nil {
statement.LastError = err
return statement
}
fields := strings.Split(tp.TableName(), ".")
aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1])
aliasName = schemas.CommonQuoter.Trim(aliasName)
fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), statement.quote(aliasName), statement.ReplaceQuote(condition))
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
case *builder.Builder:
subSQL, subQueryArgs, err := tp.ToSQL()
if err != nil {
statement.LastError = err
return statement
}
fields := strings.Split(tp.TableName(), ".")
aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1])
aliasName = schemas.CommonQuoter.Trim(aliasName)
fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), statement.quote(aliasName), statement.ReplaceQuote(condition))
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
default:
tbName := dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), tablename, true)
if !utils.IsSubQuery(tbName) {
var buf strings.Builder
_ = statement.dialect.Quoter().QuoteTo(&buf, tbName)
tbName = buf.String()
} else {
tbName = statement.ReplaceQuote(tbName)
}
fmt.Fprintf(&buf, "%s ON %v", tbName, statement.ReplaceQuote(condition))
}
statement.JoinStr = buf.String()
statement.joinArgs = append(statement.joinArgs, args...)
return statement
}
// GroupBy generate "Group By keys" statement // GroupBy generate "Group By keys" statement
func (statement *Statement) GroupBy(keys string) *Statement { func (statement *Statement) GroupBy(keys string) *Statement {
statement.GroupByStr = statement.ReplaceQuote(keys) statement.GroupByStr = statement.ReplaceQuote(keys)
return statement return statement
} }
func (statement *Statement) WriteGroupBy(w builder.Writer) error {
if statement.GroupByStr == "" {
return nil
}
_, err := fmt.Fprintf(w, " GROUP BY %s", statement.GroupByStr)
return err
}
// Having generate "Having conditions" statement // Having generate "Having conditions" statement
func (statement *Statement) Having(conditions string) *Statement { func (statement *Statement) Having(conditions string) *Statement {
statement.HavingStr = fmt.Sprintf("HAVING %v", statement.ReplaceQuote(conditions)) statement.HavingStr = fmt.Sprintf("HAVING %v", statement.ReplaceQuote(conditions))
return statement return statement
} }
func (statement *Statement) writeHaving(w builder.Writer) error {
if statement.HavingStr == "" {
return nil
}
_, err := fmt.Fprint(w, " ", statement.HavingStr)
return err
}
// SetUnscoped always disable struct tag "deleted" // SetUnscoped always disable struct tag "deleted"
func (statement *Statement) SetUnscoped() *Statement { func (statement *Statement) SetUnscoped() *Statement {
statement.unscoped = true statement.unscoped = true
@ -594,47 +323,6 @@ func (statement *Statement) GetUnscoped() bool {
return statement.unscoped return statement.unscoped
} }
func (statement *Statement) genColumnStr() string {
if statement.RefTable == nil {
return ""
}
var buf strings.Builder
columns := statement.RefTable.Columns()
for _, col := range columns {
if statement.OmitColumnMap.Contain(col.Name) {
continue
}
if len(statement.ColumnMap) > 0 && !statement.ColumnMap.Contain(col.Name) {
continue
}
if col.MapType == schemas.ONLYTODB {
continue
}
if buf.Len() != 0 {
buf.WriteString(", ")
}
if statement.JoinStr != "" {
if statement.TableAlias != "" {
buf.WriteString(statement.TableAlias)
} else {
buf.WriteString(statement.TableName())
}
buf.WriteString(".")
}
statement.dialect.Quoter().QuoteTo(&buf, col.Name)
}
return buf.String()
}
// GenIndexSQL generated create index SQL // GenIndexSQL generated create index SQL
func (statement *Statement) GenIndexSQL() []string { func (statement *Statement) GenIndexSQL() []string {
var sqls []string var sqls []string
@ -914,7 +602,8 @@ func (statement *Statement) BuildConds(table *schemas.Table, bean interface{}, i
statement.unscoped, statement.MustColumnMap, statement.TableName(), statement.TableAlias, addedTableName) statement.unscoped, statement.MustColumnMap, statement.TableName(), statement.TableAlias, addedTableName)
} }
func (statement *Statement) mergeConds(bean interface{}) error { // MergeConds merge conditions from bean and id
func (statement *Statement) MergeConds(bean interface{}) error {
if !statement.NoAutoCondition && statement.RefTable != nil { if !statement.NoAutoCondition && statement.RefTable != nil {
addedTableName := (len(statement.JoinStr) > 0) addedTableName := (len(statement.JoinStr) > 0)
autoCond, err := statement.BuildConds(statement.RefTable, bean, true, true, false, true, addedTableName) autoCond, err := statement.BuildConds(statement.RefTable, bean, true, true, false, true, addedTableName)
@ -927,15 +616,6 @@ func (statement *Statement) mergeConds(bean interface{}) error {
return statement.ProcessIDParam() return statement.ProcessIDParam()
} }
// GenConds generates conditions
func (statement *Statement) GenConds(bean interface{}) (string, []interface{}, error) {
if err := statement.mergeConds(bean); err != nil {
return "", nil, err
}
return statement.GenCondSQL(statement.cond)
}
func (statement *Statement) quoteColumnStr(columnStr string) string { func (statement *Statement) quoteColumnStr(columnStr string) string {
columns := strings.Split(columnStr, ",") columns := strings.Split(columnStr, ",")
return statement.dialect.Quoter().Join(columns, ",") return statement.dialect.Quoter().Join(columns, ",")

View File

@ -0,0 +1,56 @@
// Copyright 2022 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package statements
import (
"fmt"
"strings"
"xorm.io/builder"
"xorm.io/xorm/schemas"
)
// TableName return current tableName
func (statement *Statement) TableName() string {
if statement.AltTableName != "" {
return statement.AltTableName
}
return statement.tableName
}
// Alias set the table alias
func (statement *Statement) Alias(alias string) *Statement {
statement.TableAlias = alias
return statement
}
func (statement *Statement) writeAlias(w builder.Writer) error {
if statement.TableAlias != "" {
if statement.dialect.URI().DBType == schemas.ORACLE {
if _, err := fmt.Fprint(w, " ", statement.quote(statement.TableAlias)); err != nil {
return err
}
} else {
if _, err := fmt.Fprint(w, " AS ", statement.quote(statement.TableAlias)); err != nil {
return err
}
}
}
return nil
}
func (statement *Statement) writeTableName(w builder.Writer) error {
if statement.dialect.URI().DBType == schemas.MSSQL && strings.Contains(statement.TableName(), "..") {
if _, err := fmt.Fprint(w, statement.TableName()); err != nil {
return err
}
} else {
if _, err := fmt.Fprint(w, statement.quote(statement.TableName())); err != nil {
return err
}
}
return nil
}

27
internal/utils/builder.go Normal file
View File

@ -0,0 +1,27 @@
// Copyright 2022 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package utils
import (
"fmt"
"xorm.io/builder"
)
type BuildReader interface {
String() string
Args() []interface{}
}
// WriteBuilder writes writers to one
func WriteBuilder(w *builder.BytesWriter, inputs ...BuildReader) error {
for _, input := range inputs {
if _, err := fmt.Fprint(w, input.String()); err != nil {
return err
}
w.Append(input.Args()...)
}
return nil
}

View File

@ -275,8 +275,8 @@ func (session *Session) Limit(limit int, start ...int) *Session {
// OrderBy provide order by query condition, the input parameter is the content // OrderBy provide order by query condition, the input parameter is the content
// after order by on a sql statement. // after order by on a sql statement.
func (session *Session) OrderBy(order string) *Session { func (session *Session) OrderBy(order interface{}, args ...interface{}) *Session {
session.statement.OrderBy(order) session.statement.OrderBy(order, args...)
return session return session
} }

View File

@ -9,7 +9,9 @@ import (
"fmt" "fmt"
"strconv" "strconv"
"xorm.io/builder"
"xorm.io/xorm/caches" "xorm.io/xorm/caches"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas" "xorm.io/xorm/schemas"
) )
@ -99,10 +101,9 @@ func (session *Session) Delete(beans ...interface{}) (int64, error) {
} }
var ( var (
condSQL string condWriter = builder.NewWriter()
condArgs []interface{} err error
err error bean interface{}
bean interface{}
) )
if len(beans) > 0 { if len(beans) > 0 {
bean = beans[0] bean = beans[0]
@ -116,115 +117,97 @@ func (session *Session) Delete(beans ...interface{}) (int64, error) {
processor.BeforeDelete() processor.BeforeDelete()
} }
condSQL, condArgs, err = session.statement.GenConds(bean) if err = session.statement.MergeConds(bean); err != nil {
} else { return 0, err
condSQL, condArgs, err = session.statement.GenCondSQL(session.statement.Conds()) }
} }
if err != nil {
if err = session.statement.Conds().WriteTo(session.statement.QuoteReplacer(condWriter)); err != nil {
return 0, err return 0, err
} }
pLimitN := session.statement.LimitN pLimitN := session.statement.LimitN
if len(condSQL) == 0 && (pLimitN == nil || *pLimitN == 0) { if condWriter.Len() == 0 && (pLimitN == nil || *pLimitN == 0) {
return 0, ErrNeedDeletedCond return 0, ErrNeedDeletedCond
} }
var tableNameNoQuote = session.statement.TableName() tableNameNoQuote := session.statement.TableName()
var tableName = session.engine.Quote(tableNameNoQuote) tableName := session.engine.Quote(tableNameNoQuote)
var table = session.statement.RefTable table := session.statement.RefTable
var deleteSQL string deleteSQLWriter := builder.NewWriter()
if len(condSQL) > 0 { fmt.Fprintf(deleteSQLWriter, "DELETE FROM %v", tableName)
deleteSQL = fmt.Sprintf("DELETE FROM %v WHERE %v", tableName, condSQL) if condWriter.Len() > 0 {
} else { fmt.Fprintf(deleteSQLWriter, " WHERE %v", condWriter.String())
deleteSQL = fmt.Sprintf("DELETE FROM %v", tableName) deleteSQLWriter.Append(condWriter.Args()...)
} }
var orderSQL string orderSQLWriter := builder.NewWriter()
if len(session.statement.OrderStr) > 0 { if err := session.statement.WriteOrderBy(orderSQLWriter); err != nil {
orderSQL += fmt.Sprintf(" ORDER BY %s", session.statement.OrderStr) return 0, err
} }
if pLimitN != nil && *pLimitN > 0 { if pLimitN != nil && *pLimitN > 0 {
limitNValue := *pLimitN limitNValue := *pLimitN
orderSQL += fmt.Sprintf(" LIMIT %d", limitNValue) if _, err := fmt.Fprintf(orderSQLWriter, " LIMIT %d", limitNValue); err != nil {
return 0, err
}
} }
if len(orderSQL) > 0 { orderCondWriter := builder.NewWriter()
if orderSQLWriter.Len() > 0 {
switch session.engine.dialect.URI().DBType { switch session.engine.dialect.URI().DBType {
case schemas.POSTGRES: case schemas.POSTGRES:
inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL) if condWriter.Len() > 0 {
if len(condSQL) > 0 { fmt.Fprintf(orderCondWriter, " AND ")
deleteSQL += " AND " + inSQL
} else { } else {
deleteSQL += " WHERE " + inSQL fmt.Fprintf(orderCondWriter, " WHERE ")
} }
fmt.Fprintf(orderCondWriter, "ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQLWriter.String())
orderCondWriter.Append(orderSQLWriter.Args()...)
case schemas.SQLITE: case schemas.SQLITE:
inSQL := fmt.Sprintf("rowid IN (SELECT rowid FROM %s%s)", tableName, orderSQL) if condWriter.Len() > 0 {
if len(condSQL) > 0 { fmt.Fprintf(orderCondWriter, " AND ")
deleteSQL += " AND " + inSQL
} else { } else {
deleteSQL += " WHERE " + inSQL fmt.Fprintf(orderCondWriter, " WHERE ")
} }
fmt.Fprintf(orderCondWriter, "rowid IN (SELECT rowid FROM %s%s)", tableName, orderSQLWriter.String())
// TODO: how to handle delete limit on mssql? // TODO: how to handle delete limit on mssql?
case schemas.MSSQL: case schemas.MSSQL:
return 0, ErrNotImplemented return 0, ErrNotImplemented
default: default:
deleteSQL += orderSQL fmt.Fprint(orderCondWriter, orderSQLWriter.String())
orderCondWriter.Append(orderSQLWriter.Args()...)
} }
} }
var realSQL string realSQLWriter := builder.NewWriter()
argsForCache := make([]interface{}, 0, len(condArgs)*2) argsForCache := make([]interface{}, 0, len(deleteSQLWriter.Args())*2)
copy(argsForCache, deleteSQLWriter.Args())
argsForCache = append(deleteSQLWriter.Args(), argsForCache...)
if session.statement.GetUnscoped() || table == nil || table.DeletedColumn() == nil { // tag "deleted" is disabled if session.statement.GetUnscoped() || table == nil || table.DeletedColumn() == nil { // tag "deleted" is disabled
realSQL = deleteSQL if err := utils.WriteBuilder(realSQLWriter, deleteSQLWriter, orderCondWriter); err != nil {
copy(argsForCache, condArgs) return 0, err
argsForCache = append(condArgs, argsForCache...) }
} else { } else {
// !oinume! sqlStrForCache and argsForCache is needed to behave as executing "DELETE FROM ..." for caches.
copy(argsForCache, condArgs)
argsForCache = append(condArgs, argsForCache...)
deletedColumn := table.DeletedColumn() deletedColumn := table.DeletedColumn()
realSQL = fmt.Sprintf("UPDATE %v SET %v = ? WHERE %v", if _, err := fmt.Fprintf(realSQLWriter, "UPDATE %v SET %v = ? WHERE %v",
session.engine.Quote(session.statement.TableName()), session.engine.Quote(session.statement.TableName()),
session.engine.Quote(deletedColumn.Name), session.engine.Quote(deletedColumn.Name),
condSQL) condWriter.String()); err != nil {
return 0, err
if len(orderSQL) > 0 {
switch session.engine.dialect.URI().DBType {
case schemas.POSTGRES:
inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL)
if len(condSQL) > 0 {
realSQL += " AND " + inSQL
} else {
realSQL += " WHERE " + inSQL
}
case schemas.SQLITE:
inSQL := fmt.Sprintf("rowid IN (SELECT rowid FROM %s%s)", tableName, orderSQL)
if len(condSQL) > 0 {
realSQL += " AND " + inSQL
} else {
realSQL += " WHERE " + inSQL
}
// TODO: how to handle delete limit on mssql?
case schemas.MSSQL:
return 0, ErrNotImplemented
default:
realSQL += orderSQL
}
} }
// !oinume! Insert nowTime to the head of session.statement.Params
condArgs = append(condArgs, "")
paramsLen := len(condArgs)
copy(condArgs[1:paramsLen], condArgs[0:paramsLen-1])
val, t, err := session.engine.nowTime(deletedColumn) val, t, err := session.engine.nowTime(deletedColumn)
if err != nil { if err != nil {
return 0, err return 0, err
} }
condArgs[0] = val realSQLWriter.Append(val)
realSQLWriter.Append(condWriter.Args()...)
var colName = deletedColumn.Name if err := utils.WriteBuilder(realSQLWriter, orderCondWriter); err != nil {
return 0, err
}
colName := deletedColumn.Name
session.afterClosures = append(session.afterClosures, func(bean interface{}) { session.afterClosures = append(session.afterClosures, func(bean interface{}) {
col := table.GetColumn(colName) col := table.GetColumn(colName)
setColumnTime(bean, col, t) setColumnTime(bean, col, t)
@ -232,11 +215,11 @@ func (session *Session) Delete(beans ...interface{}) (int64, error) {
} }
if cacher := session.engine.GetCacher(tableNameNoQuote); cacher != nil && session.statement.UseCache { if cacher := session.engine.GetCacher(tableNameNoQuote); cacher != nil && session.statement.UseCache {
_ = session.cacheDelete(table, tableNameNoQuote, deleteSQL, argsForCache...) _ = session.cacheDelete(table, tableNameNoQuote, deleteSQLWriter.String(), argsForCache...)
} }
session.statement.RefTable = table session.statement.RefTable = table
res, err := session.exec(realSQL, condArgs...) res, err := session.exec(realSQLWriter.String(), realSQLWriter.Args()...)
if err != nil { if err != nil {
return 0, err return 0, err
} }

View File

@ -60,9 +60,7 @@ func (session *Session) FindAndCount(rowsSlicePtr interface{}, condiBean ...inte
if len(session.statement.ColumnMap) > 0 && !session.statement.IsDistinct { if len(session.statement.ColumnMap) > 0 && !session.statement.IsDistinct {
session.statement.ColumnMap = []string{} session.statement.ColumnMap = []string{}
} }
if session.statement.OrderStr != "" { session.statement.ResetOrderBy()
session.statement.OrderStr = ""
}
if session.statement.LimitN != nil { if session.statement.LimitN != nil {
session.statement.LimitN = nil session.statement.LimitN = nil
} }
@ -85,15 +83,15 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
} }
sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
var isSlice = sliceValue.Kind() == reflect.Slice isSlice := sliceValue.Kind() == reflect.Slice
var isMap = sliceValue.Kind() == reflect.Map isMap := sliceValue.Kind() == reflect.Map
if !isSlice && !isMap { if !isSlice && !isMap {
return errors.New("needs a pointer to a slice or a map") return errors.New("needs a pointer to a slice or a map")
} }
sliceElementType := sliceValue.Type().Elem() sliceElementType := sliceValue.Type().Elem()
var tp = tpStruct tp := tpStruct
if session.statement.RefTable == nil { if session.statement.RefTable == nil {
if sliceElementType.Kind() == reflect.Ptr { if sliceElementType.Kind() == reflect.Ptr {
if sliceElementType.Elem().Kind() == reflect.Struct { if sliceElementType.Elem().Kind() == reflect.Struct {
@ -190,7 +188,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect
return err return err
} }
var newElemFunc = func(fields []string) reflect.Value { newElemFunc := func(fields []string) reflect.Value {
return utils.New(elemType, len(fields), len(fields)) return utils.New(elemType, len(fields), len(fields))
} }
@ -235,7 +233,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect
} }
if elemType.Kind() == reflect.Struct { if elemType.Kind() == reflect.Struct {
var newValue = newElemFunc(fields) newValue := newElemFunc(fields)
tb, err := session.engine.tagParser.ParseWithCache(newValue) tb, err := session.engine.tagParser.ParseWithCache(newValue)
if err != nil { if err != nil {
return err return err
@ -249,7 +247,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect
} }
for rows.Next() { for rows.Next() {
var newValue = newElemFunc(fields) newValue := newElemFunc(fields)
bean := newValue.Interface() bean := newValue.Interface()
switch elemType.Kind() { switch elemType.Kind() {
@ -310,7 +308,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
session.engine.logger.Debugf("[cacheFind] ids length > 500, no cache") session.engine.logger.Debugf("[cacheFind] ids length > 500, no cache")
return ErrCacheFailed return ErrCacheFailed
} }
var res = make([]string, len(table.PrimaryKeys)) res := make([]string, len(table.PrimaryKeys))
err = rows.ScanSlice(&res) err = rows.ScanSlice(&res)
if err != nil { if err != nil {
return err return err
@ -342,7 +340,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
ididxes := make(map[string]int) ididxes := make(map[string]int)
var ides []schemas.PK var ides []schemas.PK
var temps = make([]interface{}, len(ids)) temps := make([]interface{}, len(ids))
for idx, id := range ids { for idx, id := range ids {
sid, err := id.ToString() sid, err := id.ToString()
@ -457,7 +455,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(bean)))) sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(bean))))
} }
} else if sliceValue.Kind() == reflect.Map { } else if sliceValue.Kind() == reflect.Map {
var key = ids[j] key := ids[j]
keyType := sliceValue.Type().Key() keyType := sliceValue.Type().Key()
keyValue := reflect.New(keyType) keyValue := reflect.New(keyType)
var ikey interface{} var ikey interface{}

View File

@ -60,7 +60,7 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri
ids = make([]schemas.PK, 0) ids = make([]schemas.PK, 0)
for rows.Next() { for rows.Next() {
var res = make([]string, len(table.PrimaryKeys)) res := make([]string, len(table.PrimaryKeys))
err = rows.ScanSlice(&res) err = rows.ScanSlice(&res)
if err != nil { if err != nil {
return err return err
@ -176,8 +176,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
// -- // --
var err error var err error
var isMap = t.Kind() == reflect.Map isMap := t.Kind() == reflect.Map
var isStruct = t.Kind() == reflect.Struct isStruct := t.Kind() == reflect.Struct
if isStruct { if isStruct {
if err := session.statement.SetRefBean(bean); err != nil { if err := session.statement.SetRefBean(bean); err != nil {
return 0, err return 0, err
@ -226,7 +226,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
args = append(args, val) args = append(args, val)
} }
var colName = col.Name colName := col.Name
if isStruct { if isStruct {
session.afterClosures = append(session.afterClosures, func(bean interface{}) { session.afterClosures = append(session.afterClosures, func(bean interface{}) {
col := table.GetColumn(colName) col := table.GetColumn(colName)
@ -258,10 +258,11 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
} }
colNames = append(colNames, session.engine.Quote(expr.ColName)+"="+tp) colNames = append(colNames, session.engine.Quote(expr.ColName)+"="+tp)
case *builder.Builder: case *builder.Builder:
subQuery, subArgs, err := session.statement.GenCondSQL(tp) subQuery, subArgs, err := builder.ToSQL(tp)
if err != nil { if err != nil {
return 0, err return 0, err
} }
subQuery = session.statement.ReplaceQuote(subQuery)
colNames = append(colNames, session.engine.Quote(expr.ColName)+"=("+subQuery+")") colNames = append(colNames, session.engine.Quote(expr.ColName)+"=("+subQuery+")")
args = append(args, subArgs...) args = append(args, subArgs...)
default: default:
@ -279,7 +280,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
condBeanIsStruct := false condBeanIsStruct := false
if len(condiBean) > 0 { if len(condiBean) > 0 {
if c, ok := condiBean[0].(map[string]interface{}); ok { if c, ok := condiBean[0].(map[string]interface{}); ok {
var eq = make(builder.Eq) eq := make(builder.Eq)
for k, v := range c { for k, v := range c {
eq[session.engine.Quote(k)] = v eq[session.engine.Quote(k)] = v
} }
@ -323,11 +324,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
st := session.statement st := session.statement
var ( var (
sqlStr string
condArgs []interface{}
condSQL string
cond = session.statement.Conds().And(autoCond) cond = session.statement.Conds().And(autoCond)
doIncVer = isStruct && (table != nil && table.Version != "" && session.statement.CheckVersion) doIncVer = isStruct && (table != nil && table.Version != "" && session.statement.CheckVersion)
verValue *reflect.Value verValue *reflect.Value
) )
@ -347,70 +344,65 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
return 0, ErrNoColumnsTobeUpdated return 0, ErrNoColumnsTobeUpdated
} }
condSQL, condArgs, err = session.statement.GenCondSQL(cond) whereWriter := builder.NewWriter()
if err != nil { if cond.IsValid() {
fmt.Fprint(whereWriter, "WHERE ")
}
if err := cond.WriteTo(st.QuoteReplacer(whereWriter)); err != nil {
return 0, err
}
if err := st.WriteOrderBy(whereWriter); err != nil {
return 0, err return 0, err
} }
if len(condSQL) > 0 { tableName := session.statement.TableName()
condSQL = "WHERE " + condSQL
}
if st.OrderStr != "" {
condSQL += fmt.Sprintf(" ORDER BY %v", st.OrderStr)
}
var tableName = session.statement.TableName()
// TODO: Oracle support needed // TODO: Oracle support needed
var top string var top string
if st.LimitN != nil { if st.LimitN != nil {
limitValue := *st.LimitN limitValue := *st.LimitN
switch session.engine.dialect.URI().DBType { switch session.engine.dialect.URI().DBType {
case schemas.MYSQL: case schemas.MYSQL:
condSQL += fmt.Sprintf(" LIMIT %d", limitValue) fmt.Fprintf(whereWriter, " LIMIT %d", limitValue)
case schemas.SQLITE: case schemas.SQLITE:
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue) fmt.Fprintf(whereWriter, " LIMIT %d", limitValue)
cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)", cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)",
session.engine.Quote(tableName), tempCondSQL), condArgs...)) session.engine.Quote(tableName), whereWriter.String()), whereWriter.Args()...))
condSQL, condArgs, err = session.statement.GenCondSQL(cond)
if err != nil { whereWriter = builder.NewWriter()
fmt.Fprint(whereWriter, "WHERE ")
if err := cond.WriteTo(st.QuoteReplacer(whereWriter)); err != nil {
return 0, err return 0, err
} }
if len(condSQL) > 0 {
condSQL = "WHERE " + condSQL
}
case schemas.POSTGRES: case schemas.POSTGRES:
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue) fmt.Fprintf(whereWriter, " LIMIT %d", limitValue)
cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)", cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)",
session.engine.Quote(tableName), tempCondSQL), condArgs...)) session.engine.Quote(tableName), whereWriter.String()), whereWriter.Args()...))
condSQL, condArgs, err = session.statement.GenCondSQL(cond)
if err != nil { whereWriter = builder.NewWriter()
fmt.Fprint(whereWriter, "WHERE ")
if err := cond.WriteTo(st.QuoteReplacer(whereWriter)); err != nil {
return 0, err return 0, err
} }
if len(condSQL) > 0 {
condSQL = "WHERE " + condSQL
}
case schemas.MSSQL: case schemas.MSSQL:
if st.OrderStr != "" && table != nil && len(table.PrimaryKeys) == 1 { if st.HasOrderBy() && table != nil && len(table.PrimaryKeys) == 1 {
cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)", cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)",
table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0], table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0],
session.engine.Quote(tableName), condSQL), condArgs...) session.engine.Quote(tableName), whereWriter.String()), whereWriter.Args()...)
condSQL, condArgs, err = session.statement.GenCondSQL(cond) whereWriter = builder.NewWriter()
if err != nil { fmt.Fprint(whereWriter, "WHERE ")
if err := cond.WriteTo(whereWriter); err != nil {
return 0, err return 0, err
} }
if len(condSQL) > 0 {
condSQL = "WHERE " + condSQL
}
} else { } else {
top = fmt.Sprintf("TOP (%d) ", limitValue) top = fmt.Sprintf("TOP (%d) ", limitValue)
} }
} }
} }
var tableAlias = session.engine.Quote(tableName) tableAlias := session.engine.Quote(tableName)
var fromSQL string var fromSQL string
if session.statement.TableAlias != "" { if session.statement.TableAlias != "" {
switch session.engine.dialect.URI().DBType { switch session.engine.dialect.URI().DBType {
@ -422,14 +414,19 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
} }
} }
sqlStr = fmt.Sprintf("UPDATE %v%v SET %v %v%v", updateWriter := builder.NewWriter()
if _, err := fmt.Fprintf(updateWriter, "UPDATE %v%v SET %v %v",
top, top,
tableAlias, tableAlias,
strings.Join(colNames, ", "), strings.Join(colNames, ", "),
fromSQL, fromSQL); err != nil {
condSQL) return 0, err
}
if err := utils.WriteBuilder(updateWriter, whereWriter); err != nil {
return 0, err
}
res, err := session.exec(sqlStr, append(args, condArgs...)...) res, err := session.exec(updateWriter.String(), append(args, updateWriter.Args()...)...)
if err != nil { if err != nil {
return 0, err return 0, err
} else if doIncVer { } else if doIncVer {
@ -535,7 +532,7 @@ func (session *Session) genUpdateColumns(bean interface{}) ([]string, []interfac
} }
args = append(args, val) args = append(args, val)
var colName = col.Name colName := col.Name
session.afterClosures = append(session.afterClosures, func(bean interface{}) { session.afterClosures = append(session.afterClosures, func(bean interface{}) {
col := table.GetColumn(colName) col := table.GetColumn(colName)
setColumnTime(bean, col, t) setColumnTime(bean, col, t)