SetExpr support more go types #1499

Merged
lunny merged 2 commits from lunny/support_set_expr_gotypes into master 2020-01-19 09:36:08 +00:00
5 changed files with 153 additions and 64 deletions

View File

@ -729,66 +729,7 @@ func (session *Session) insertMapInterface(m map[string]interface{}) (int64, err
args = append(args, m[colName]) args = append(args, m[colName])
} }
w := builder.NewWriter() return session.insertMap(columns, args)
if session.statement.cond.IsValid() {
if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (", session.engine.Quote(tableName))); err != nil {
return 0, err
}
if err := writeStrings(w, append(columns, exprs.colNames...), "`", "`"); err != nil {
return 0, err
}
if _, err := w.WriteString(") SELECT "); err != nil {
return 0, err
}
if err := session.statement.writeArgs(w, args); err != nil {
return 0, err
}
if len(exprs.args) > 0 {
if _, err := w.WriteString(","); err != nil {
return 0, err
}
if err := exprs.writeArgs(w); err != nil {
return 0, err
}
}
if _, err := w.WriteString(fmt.Sprintf(" FROM %s WHERE ", session.engine.Quote(tableName))); err != nil {
return 0, err
}
if err := session.statement.cond.WriteTo(w); err != nil {
return 0, err
}
} else {
qm := strings.Repeat("?,", len(columns))
qm = qm[:len(qm)-1]
if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (`%s`) VALUES (%s)", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm)); err != nil {
return 0, err
}
w.Append(args...)
}
sql := w.String()
args = w.Args()
if err := session.cacheInsert(tableName); err != nil {
return 0, err
}
res, err := session.exec(sql, args...)
if err != nil {
return 0, err
}
affected, err := res.RowsAffected()
if err != nil {
return 0, err
}
return affected, nil
} }
func (session *Session) insertMapString(m map[string]string) (int64, error) { func (session *Session) insertMapString(m map[string]string) (int64, error) {
@ -808,6 +749,7 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) {
columns = append(columns, k) columns = append(columns, k)
} }
} }
sort.Strings(columns) sort.Strings(columns)
var args = make([]interface{}, 0, len(m)) var args = make([]interface{}, 0, len(m))
@ -815,7 +757,18 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) {
args = append(args, m[colName]) args = append(args, m[colName])
} }
return session.insertMap(columns, args)
}
func (session *Session) insertMap(columns []string, args []interface{}) (int64, error) {
tableName := session.statement.TableName()
if len(tableName) <= 0 {
return 0, ErrTableNotFound
}
exprs := session.statement.exprColumns
w := builder.NewWriter() w := builder.NewWriter()
// if insert where
if session.statement.cond.IsValid() { if session.statement.cond.IsValid() {
if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (", session.engine.Quote(tableName))); err != nil { if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (", session.engine.Quote(tableName))); err != nil {
return 0, err return 0, err
@ -853,10 +806,29 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) {
qm := strings.Repeat("?,", len(columns)) qm := strings.Repeat("?,", len(columns))
qm = qm[:len(qm)-1] qm = qm[:len(qm)-1]
if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (`%s`) VALUES (%s)", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm)); err != nil { if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (", session.engine.Quote(tableName))); err != nil {
return 0, err return 0, err
} }
if err := writeStrings(w, append(columns, exprs.colNames...), "`", "`"); err != nil {
return 0, err
}
if _, err := w.WriteString(fmt.Sprintf(") VALUES (%s", qm)); err != nil {
return 0, err
}
w.Append(args...) w.Append(args...)
if len(exprs.args) > 0 {
if _, err := w.WriteString(","); err != nil {
return 0, err
}
if err := exprs.writeArgs(w); err != nil {
return 0, err
}
}
if _, err := w.WriteString(")"); err != nil {
return 0, err
}
} }
sql := w.String() sql := w.String()

View File

@ -928,6 +928,64 @@ func TestInsertWhere(t *testing.T) {
assert.EqualValues(t, 5, j5.Index) assert.EqualValues(t, 5, j5.Index)
} }
func TestInsertExpr2(t *testing.T) {
assert.NoError(t, prepareEngine())
type InsertExprsRelease struct {
Id int64
RepoId int
IsTag bool
IsDraft bool
NumCommits int
Sha1 string
}
assertSync(t, new(InsertExprsRelease))
var ie = InsertExprsRelease{
RepoId: 1,
IsTag: true,
}
inserted, err := testEngine.
SetExpr("is_draft", true).
SetExpr("num_commits", 0).
SetExpr("sha1", "").
Insert(&ie)
assert.NoError(t, err)
assert.EqualValues(t, 1, inserted)
var ie2 InsertExprsRelease
has, err := testEngine.ID(ie.Id).Get(&ie2)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, true, ie2.IsDraft)
assert.EqualValues(t, "", ie2.Sha1)
assert.EqualValues(t, 0, ie2.NumCommits)
assert.EqualValues(t, 1, ie2.RepoId)
assert.EqualValues(t, true, ie2.IsTag)

It would be nicer if the test would reread the inserted rows and check the column values.

It would be nicer if the test would reread the inserted rows and check the column values.
inserted, err = testEngine.Table(new(InsertExprsRelease)).
SetExpr("is_draft", true).
SetExpr("num_commits", 0).
SetExpr("sha1", "").
Insert(map[string]interface{}{
"repo_id": 1,
"is_tag": true,
})
assert.NoError(t, err)
assert.EqualValues(t, 1, inserted)
var ie3 InsertExprsRelease
has, err = testEngine.ID(ie.Id + 1).Get(&ie3)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, true, ie3.IsDraft)
assert.EqualValues(t, "", ie3.Sha1)
assert.EqualValues(t, 0, ie3.NumCommits)
assert.EqualValues(t, 1, ie3.RepoId)
assert.EqualValues(t, true, ie3.IsTag)
}
type NightlyRate struct { type NightlyRate struct {
ID int64 `xorm:"'id' not null pk BIGINT(20)" json:"id"` ID int64 `xorm:"'id' not null pk BIGINT(20)" json:"id"`
} }

View File

@ -239,14 +239,20 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
for i, colName := range exprColumns.colNames { for i, colName := range exprColumns.colNames {
switch tp := exprColumns.args[i].(type) { switch tp := exprColumns.args[i].(type) {
case string: case string:
colNames = append(colNames, session.engine.Quote(colName)+" = "+tp) if len(tp) == 0 {
tp = "''"
}
colNames = append(colNames, session.engine.Quote(colName)+"="+tp)

So this:

SetExpr("sha1", "''")

is equivalent to:

SetExpr("sha1", "")

It's a little awkward, but I don't see a way around it either.

So this: ``` SetExpr("sha1", "''") ``` is equivalent to: ``` SetExpr("sha1", "") ``` It's a little awkward, but I don't see a way around it either.
case *builder.Builder: case *builder.Builder:
subQuery, subArgs, err := builder.ToSQL(tp) subQuery, subArgs, err := builder.ToSQL(tp)
if err != nil { if err != nil {
return 0, err return 0, err
} }
colNames = append(colNames, session.engine.Quote(colName)+" = ("+subQuery+")") colNames = append(colNames, session.engine.Quote(colName)+"=("+subQuery+")")
args = append(args, subArgs...) args = append(args, subArgs...)
default:
colNames = append(colNames, session.engine.Quote(colName)+"=?")
args = append(args, exprColumns.args[i])
} }
} }

View File

@ -1359,3 +1359,48 @@ func TestUpdateAlias(t *testing.T) {
assert.EqualValues(t, 2, ue.NumIssues) assert.EqualValues(t, 2, ue.NumIssues)
assert.EqualValues(t, "lunny xiao", ue.Name) assert.EqualValues(t, "lunny xiao", ue.Name)
} }
func TestUpdateExprs2(t *testing.T) {
assert.NoError(t, prepareEngine())
type UpdateExprsRelease struct {
Id int64
RepoId int
IsTag bool
IsDraft bool
NumCommits int
Sha1 string
}
assertSync(t, new(UpdateExprsRelease))
var uer = UpdateExprsRelease{
RepoId: 1,
IsTag: false,
IsDraft: false,
NumCommits: 1,
Sha1: "sha1",
}
inserted, err := testEngine.Insert(&uer)
assert.NoError(t, err)
assert.EqualValues(t, 1, inserted)
updated, err := testEngine.
Where("repo_id = ? AND is_tag = ?", 1, false).
SetExpr("is_draft", true).
SetExpr("num_commits", 0).
SetExpr("sha1", "").
Update(new(UpdateExprsRelease))
assert.NoError(t, err)
assert.EqualValues(t, 1, updated)
var uer2 UpdateExprsRelease
has, err := testEngine.ID(uer.Id).Get(&uer2)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, 1, uer2.RepoId)
assert.EqualValues(t, false, uer2.IsTag)
assert.EqualValues(t, true, uer2.IsDraft)
assert.EqualValues(t, 0, uer2.NumCommits)
assert.EqualValues(t, "", uer2.Sha1)
}

View File

@ -69,10 +69,18 @@ func (exprs *exprParams) writeArgs(w *builder.BytesWriter) error {
if _, err := w.WriteString(")"); err != nil { if _, err := w.WriteString(")"); err != nil {
return err return err
} }
default: case string:
if arg == "" {
arg = "''"
}
if _, err := w.WriteString(fmt.Sprintf("%v", arg)); err != nil { if _, err := w.WriteString(fmt.Sprintf("%v", arg)); err != nil {
return err return err
} }
default:
if _, err := w.WriteString("?"); err != nil {
return err
}
w.Append(arg)
} }
if i != len(exprs.args)-1 { if i != len(exprs.args)-1 {
if _, err := w.WriteString(","); err != nil { if _, err := w.WriteString(","); err != nil {