fix setexpr missing big quotes #1431

Merged
lunny merged 3 commits from lunny/fix_setexpr into master 2019-09-24 04:58:26 +00:00
4 changed files with 30 additions and 2 deletions
Showing only changes of commit 9daf957e5f - Show all commits

View File

@ -7,8 +7,9 @@ package xorm
import ( import (
"testing" "testing"
"xorm.io/core"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"xorm.io/builder"
"xorm.io/core"
) )
func TestSetExpr(t *testing.T) { func TestSetExpr(t *testing.T) {
@ -34,6 +35,15 @@ func TestSetExpr(t *testing.T) {
cnt, err = testEngine.SetExpr("show", not+" `show`").ID(1).Update(new(UserExpr)) cnt, err = testEngine.SetExpr("show", not+" `show`").ID(1).Update(new(UserExpr))
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
cnt, err = testEngine.SetExpr("show",
builder.Select("NOT show").
From("user_expr").
Where(builder.Eq{"id": 1})).
ID(1).
Update(new(UserExpr))
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
} }
func TestCols(t *testing.T) { func TestCols(t *testing.T) {

View File

@ -245,7 +245,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
if err != nil { if err != nil {
return 0, err return 0, err
} }
colNames = append(colNames, session.engine.Quote(colName)+" = "+subQuery) colNames = append(colNames, session.engine.Quote(colName)+" = ("+subQuery+")")
args = append(args, subArgs...) args = append(args, subArgs...)
} }
} }

View File

@ -17,9 +17,15 @@ func writeArg(w *builder.BytesWriter, arg interface{}) error {
return err return err
} }
case *builder.Builder: case *builder.Builder:
if _, err := w.WriteString("("); err != nil {
return err
}
if err := argv.WriteTo(w); err != nil { if err := argv.WriteTo(w); err != nil {
return err return err
} }
if _, err := w.WriteString(")"); err != nil {
return err
}
default: default:
if _, err := w.WriteString(fmt.Sprintf("%v", argv)); err != nil { if _, err := w.WriteString(fmt.Sprintf("%v", argv)); err != nil {
return err return err

View File

@ -60,9 +60,15 @@ func (exprs *exprParams) writeArgs(w *builder.BytesWriter) error {
for _, expr := range exprs.args { for _, expr := range exprs.args {
switch arg := expr.(type) { switch arg := expr.(type) {
case *builder.Builder: case *builder.Builder:
if _, err := w.WriteString("("); err != nil {
return err
}
if err := arg.WriteTo(w); err != nil { if err := arg.WriteTo(w); err != nil {
return err return err
} }
if _, err := w.WriteString(")"); err != nil {
return err
}
default: default:
if _, err := w.WriteString(fmt.Sprintf("%v", arg)); err != nil { if _, err := w.WriteString(fmt.Sprintf("%v", arg)); err != nil {
return err return err
@ -83,9 +89,15 @@ func (exprs *exprParams) writeNameArgs(w *builder.BytesWriter) error {
switch arg := exprs.args[i].(type) { switch arg := exprs.args[i].(type) {
case *builder.Builder: case *builder.Builder:
if _, err := w.WriteString("("); err != nil {
return err
}
if err := arg.WriteTo(w); err != nil { if err := arg.WriteTo(w); err != nil {
return err return err
} }
if _, err := w.WriteString("("); err != nil {
return err
}
default: default:
w.Append(exprs.args[i]) w.Append(exprs.args[i])
} }