Refactor orderby and support arguments #2150

Merged
lunny merged 14 commits from lunny/orderby into master 2022-05-31 03:00:29 +00:00
6 changed files with 93 additions and 72 deletions
Showing only changes of commit a0f42c421a - Show all commits

View File

@ -249,7 +249,7 @@ func TestOrder(t *testing.T) {
assert.NoError(t, err)
users = make([]Userinfo, 0)
err = testEngine.OrderBy("case username like ? desc", "a").Find(&users)
err = testEngine.OrderBy("CASE WHEN username LIKE ? THEN 0 ELSE 1 END DESC", "a").Find(&users)
assert.NoError(t, err)
}

View File

@ -11,6 +11,7 @@ import (
"strings"
"xorm.io/builder"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas"
)
@ -250,12 +251,13 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB
distinct = "DISTINCT "
}
condSQL, condArgs, err := statement.GenCondSQL(statement.cond)
if err != nil {
condWriter := builder.NewWriter()
if err := statement.cond.WriteTo(condWriter); err != nil {
return "", nil, err
}
if len(condSQL) > 0 {
whereStr = fmt.Sprintf(" WHERE %s", condSQL)
if condWriter.Len() > 0 {
whereStr = " WHERE "
}
pLimitN := statement.LimitN
@ -297,11 +299,13 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB
}
}
if _, err := fmt.Fprintf(mssqlCondi, "(%s NOT IN (SELECT TOP %d %s%s%s%s",
column, statement.Start, column, fromStr, whereStr, orderByWriter.String()); err != nil {
if _, err := fmt.Fprintf(mssqlCondi, "(%s NOT IN (SELECT TOP %d %s%s%s",
column, statement.Start, column, fromStr, whereStr); err != nil {
return "", nil, err
}
if err := utils.WriteBuilder(mssqlCondi, condWriter, orderByWriter); err != nil {
return "", nil, err
}
mssqlCondi.Append(orderByWriter.Args()...)
if err := statement.WriteGroupBy(mssqlCondi); err != nil {
return "", nil, err
@ -315,14 +319,19 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB
buf := builder.NewWriter()
fmt.Fprintf(buf, "SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr)
if err := utils.WriteBuilder(buf, condWriter); err != nil {
return "", nil, err
}
if mssqlCondi.Len() > 0 {
if len(whereStr) > 0 {
fmt.Fprint(buf, " AND ")
} else {
fmt.Fprint(buf, " WHERE ")
}
fmt.Fprint(buf, mssqlCondi.String())
buf.Append(mssqlCondi.Args()...)
if err := utils.WriteBuilder(buf, mssqlCondi); err != nil {
return "", nil, err
}
}
if err := statement.WriteGroupBy(buf); err != nil {
@ -361,10 +370,10 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB
}
}
if statement.IsForUpdate {
return dialect.ForUpdateSQL(buf.String()), condArgs, nil
return dialect.ForUpdateSQL(buf.String()), buf.Args(), nil
}
return buf.String(), condArgs, nil
return buf.String(), buf.Args(), nil
}
// GenExistSQL generates Exist SQL

View File

@ -455,6 +455,10 @@ func (statement *Statement) Limit(limit int, start ...int) *Statement {
return statement
}
func (statement *Statement) HasOrderBy() bool {
return statement.OrderStr != ""
}
// ResetOrderBy reset ordery conditions
func (statement *Statement) ResetOrderBy() {
statement.OrderStr = ""

22
internal/utils/builder.go Normal file
View File

@ -0,0 +1,22 @@
// Copyright 2022 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package utils
import (
"fmt"
"xorm.io/builder"
)
// WriteBuilder writes writers to one
func WriteBuilder(w *builder.BytesWriter, inputs ...*builder.BytesWriter) error {
for _, input := range inputs {
if _, err := fmt.Fprint(w, input.String()); err != nil {
return err
}
w.Append(input.Args()...)
}
return nil
}

View File

@ -11,6 +11,7 @@ import (
"xorm.io/builder"
"xorm.io/xorm/caches"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas"
)
@ -89,16 +90,6 @@ func (session *Session) cacheDelete(table *schemas.Table, tableName, sqlStr stri
return nil
}
func writeBuilder(w *builder.BytesWriter, inputs ...*builder.BytesWriter) error {
for _, input := range inputs {
if _, err := fmt.Fprint(w, input.String()); err != nil {
return err
}
w.Append(input.Args()...)
}
return nil
}
// Delete records, bean's non-empty fields are conditions
func (session *Session) Delete(beans ...interface{}) (int64, error) {
if session.isAutoClose {
@ -194,7 +185,7 @@ func (session *Session) Delete(beans ...interface{}) (int64, error) {
copy(argsForCache, deleteSQLWriter.Args())
argsForCache = append(deleteSQLWriter.Args(), argsForCache...)
if session.statement.GetUnscoped() || table == nil || table.DeletedColumn() == nil { // tag "deleted" is disabled
if err := writeBuilder(realSQLWriter, deleteSQLWriter, orderCondWriter); err != nil {
if err := utils.WriteBuilder(realSQLWriter, deleteSQLWriter, orderCondWriter); err != nil {
return 0, err
}
} else {
@ -212,7 +203,7 @@ func (session *Session) Delete(beans ...interface{}) (int64, error) {
realSQLWriter.Append(val)
realSQLWriter.Append(condWriter.Args()...)
if err := writeBuilder(realSQLWriter, orderCondWriter); err != nil {
if err := utils.WriteBuilder(realSQLWriter, orderCondWriter); err != nil {
return 0, err
}

View File

@ -60,7 +60,7 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri
ids = make([]schemas.PK, 0)
for rows.Next() {
var res = make([]string, len(table.PrimaryKeys))
res := make([]string, len(table.PrimaryKeys))
err = rows.ScanSlice(&res)
if err != nil {
return err
@ -176,8 +176,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
// --
var err error
var isMap = t.Kind() == reflect.Map
var isStruct = t.Kind() == reflect.Struct
isMap := t.Kind() == reflect.Map
isStruct := t.Kind() == reflect.Struct
if isStruct {
if err := session.statement.SetRefBean(bean); err != nil {
return 0, err
@ -226,7 +226,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
args = append(args, val)
}
var colName = col.Name
colName := col.Name
if isStruct {
session.afterClosures = append(session.afterClosures, func(bean interface{}) {
col := table.GetColumn(colName)
@ -279,7 +279,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
condBeanIsStruct := false
if len(condiBean) > 0 {
if c, ok := condiBean[0].(map[string]interface{}); ok {
var eq = make(builder.Eq)
eq := make(builder.Eq)
for k, v := range c {
eq[session.engine.Quote(k)] = v
}
@ -323,11 +323,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
st := session.statement
var (
sqlStr string
condArgs []interface{}
condSQL string
cond = session.statement.Conds().And(autoCond)
doIncVer = isStruct && (table != nil && table.Version != "" && session.statement.CheckVersion)
verValue *reflect.Value
)
@ -347,70 +343,65 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
return 0, ErrNoColumnsTobeUpdated
}
condSQL, condArgs, err = session.statement.GenCondSQL(cond)
if err != nil {
whereWriter := builder.NewWriter()
if cond.IsValid() {
fmt.Fprint(whereWriter, "WHERE ")
}
if err := cond.WriteTo(whereWriter); err != nil {
return 0, err
}
if err := st.WriteOrderBy(whereWriter); err != nil {
return 0, err
}
if len(condSQL) > 0 {
condSQL = "WHERE " + condSQL
}
if st.OrderStr != "" {
condSQL += fmt.Sprintf(" ORDER BY %v", st.OrderStr)
}
var tableName = session.statement.TableName()
tableName := session.statement.TableName()
// TODO: Oracle support needed
var top string
if st.LimitN != nil {
limitValue := *st.LimitN
switch session.engine.dialect.URI().DBType {
case schemas.MYSQL:
condSQL += fmt.Sprintf(" LIMIT %d", limitValue)
fmt.Fprintf(whereWriter, " LIMIT %d", limitValue)
case schemas.SQLITE:
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue)
fmt.Fprintf(whereWriter, " LIMIT %d", limitValue)
cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)",
session.engine.Quote(tableName), tempCondSQL), condArgs...))
condSQL, condArgs, err = session.statement.GenCondSQL(cond)
if err != nil {
session.engine.Quote(tableName), whereWriter.String()), whereWriter.Args()...))
whereWriter = builder.NewWriter()
fmt.Fprint(whereWriter, "WHERE ")
if err := cond.WriteTo(whereWriter); err != nil {
return 0, err
}
if len(condSQL) > 0 {
condSQL = "WHERE " + condSQL
}
case schemas.POSTGRES:
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue)
fmt.Fprintf(whereWriter, " LIMIT %d", limitValue)
cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)",
session.engine.Quote(tableName), tempCondSQL), condArgs...))
condSQL, condArgs, err = session.statement.GenCondSQL(cond)
if err != nil {
session.engine.Quote(tableName), whereWriter.String()), whereWriter.Args()...))
whereWriter = builder.NewWriter()
fmt.Fprint(whereWriter, "WHERE ")
if err := cond.WriteTo(whereWriter); err != nil {
return 0, err
}
if len(condSQL) > 0 {
condSQL = "WHERE " + condSQL
}
case schemas.MSSQL:
if st.OrderStr != "" && table != nil && len(table.PrimaryKeys) == 1 {
if st.HasOrderBy() && table != nil && len(table.PrimaryKeys) == 1 {
cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)",
table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0],
session.engine.Quote(tableName), condSQL), condArgs...)
session.engine.Quote(tableName), whereWriter.String()), whereWriter.Args()...)
condSQL, condArgs, err = session.statement.GenCondSQL(cond)
if err != nil {
whereWriter = builder.NewWriter()
fmt.Fprint(whereWriter, "WHERE ")
if err := cond.WriteTo(whereWriter); err != nil {
return 0, err
}
if len(condSQL) > 0 {
condSQL = "WHERE " + condSQL
}
} else {
top = fmt.Sprintf("TOP (%d) ", limitValue)
}
}
}
var tableAlias = session.engine.Quote(tableName)
tableAlias := session.engine.Quote(tableName)
var fromSQL string
if session.statement.TableAlias != "" {
switch session.engine.dialect.URI().DBType {
@ -422,14 +413,18 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
}
}
sqlStr = fmt.Sprintf("UPDATE %v%v SET %v %v%v",
updateWriter := builder.NewWriter()
if _, err := fmt.Fprintf(updateWriter, "UPDATE %v%v SET %v %v%v",
top,
tableAlias,
strings.Join(colNames, ", "),
fromSQL,
condSQL)
whereWriter.String()); err != nil {
return 0, err
}
updateWriter.Append(whereWriter.Args()...)
res, err := session.exec(sqlStr, append(args, condArgs...)...)
res, err := session.exec(updateWriter.String(), append(args, updateWriter.Args()...)...)
if err != nil {
return 0, err
} else if doIncVer {
@ -535,7 +530,7 @@ func (session *Session) genUpdateColumns(bean interface{}) ([]string, []interfac
}
args = append(args, val)
var colName = col.Name
colName := col.Name
session.afterClosures = append(session.afterClosures, func(bean interface{}) {
col := table.GetColumn(colName)
setColumnTime(bean, col, t)