SetExpr support more go types #1499

Merged
lunny merged 2 commits from lunny/support_set_expr_gotypes into master 2020-01-19 09:36:08 +00:00
5 changed files with 153 additions and 64 deletions

View File

@ -729,66 +729,7 @@ func (session *Session) insertMapInterface(m map[string]interface{}) (int64, err
args = append(args, m[colName])
}
w := builder.NewWriter()
if session.statement.cond.IsValid() {
if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (", session.engine.Quote(tableName))); err != nil {
return 0, err
}
if err := writeStrings(w, append(columns, exprs.colNames...), "`", "`"); err != nil {
return 0, err
}
if _, err := w.WriteString(") SELECT "); err != nil {
return 0, err
}
if err := session.statement.writeArgs(w, args); err != nil {
return 0, err
}
if len(exprs.args) > 0 {
if _, err := w.WriteString(","); err != nil {
return 0, err
}
if err := exprs.writeArgs(w); err != nil {
return 0, err
}
}
if _, err := w.WriteString(fmt.Sprintf(" FROM %s WHERE ", session.engine.Quote(tableName))); err != nil {
return 0, err
}
if err := session.statement.cond.WriteTo(w); err != nil {
return 0, err
}
} else {
qm := strings.Repeat("?,", len(columns))
qm = qm[:len(qm)-1]
if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (`%s`) VALUES (%s)", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm)); err != nil {
return 0, err
}
w.Append(args...)
}
sql := w.String()
args = w.Args()
if err := session.cacheInsert(tableName); err != nil {
return 0, err
}
res, err := session.exec(sql, args...)
if err != nil {
return 0, err
}
affected, err := res.RowsAffected()
if err != nil {
return 0, err
}
return affected, nil
return session.insertMap(columns, args)
}
func (session *Session) insertMapString(m map[string]string) (int64, error) {
@ -808,6 +749,7 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) {
columns = append(columns, k)
}
}
sort.Strings(columns)
var args = make([]interface{}, 0, len(m))
@ -815,7 +757,18 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) {
args = append(args, m[colName])
}
return session.insertMap(columns, args)
}
func (session *Session) insertMap(columns []string, args []interface{}) (int64, error) {
tableName := session.statement.TableName()
if len(tableName) <= 0 {
return 0, ErrTableNotFound
}
exprs := session.statement.exprColumns
w := builder.NewWriter()
// if insert where
if session.statement.cond.IsValid() {
if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (", session.engine.Quote(tableName))); err != nil {
return 0, err
@ -853,10 +806,29 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) {
qm := strings.Repeat("?,", len(columns))
qm = qm[:len(qm)-1]
if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (`%s`) VALUES (%s)", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm)); err != nil {
if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (", session.engine.Quote(tableName))); err != nil {
return 0, err
}
if err := writeStrings(w, append(columns, exprs.colNames...), "`", "`"); err != nil {
return 0, err
}
if _, err := w.WriteString(fmt.Sprintf(") VALUES (%s", qm)); err != nil {
return 0, err
}
w.Append(args...)
if len(exprs.args) > 0 {
if _, err := w.WriteString(","); err != nil {
return 0, err
}
if err := exprs.writeArgs(w); err != nil {
return 0, err
}
}
if _, err := w.WriteString(")"); err != nil {
return 0, err
}
}
sql := w.String()

View File

@ -928,6 +928,64 @@ func TestInsertWhere(t *testing.T) {
assert.EqualValues(t, 5, j5.Index)
}
func TestInsertExpr2(t *testing.T) {
assert.NoError(t, prepareEngine())
type InsertExprsRelease struct {
Id int64
RepoId int
IsTag bool
IsDraft bool
NumCommits int
Sha1 string
}
assertSync(t, new(InsertExprsRelease))
var ie = InsertExprsRelease{
RepoId: 1,
IsTag: true,
}
inserted, err := testEngine.
SetExpr("is_draft", true).
SetExpr("num_commits", 0).
SetExpr("sha1", "").
Insert(&ie)
assert.NoError(t, err)
assert.EqualValues(t, 1, inserted)
var ie2 InsertExprsRelease
has, err := testEngine.ID(ie.Id).Get(&ie2)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, true, ie2.IsDraft)
assert.EqualValues(t, "", ie2.Sha1)
assert.EqualValues(t, 0, ie2.NumCommits)
assert.EqualValues(t, 1, ie2.RepoId)
assert.EqualValues(t, true, ie2.IsTag)

It would be nicer if the test would reread the inserted rows and check the column values.

It would be nicer if the test would reread the inserted rows and check the column values.
inserted, err = testEngine.Table(new(InsertExprsRelease)).
SetExpr("is_draft", true).
SetExpr("num_commits", 0).
SetExpr("sha1", "").
Insert(map[string]interface{}{
"repo_id": 1,
"is_tag": true,
})
assert.NoError(t, err)
assert.EqualValues(t, 1, inserted)
var ie3 InsertExprsRelease
has, err = testEngine.ID(ie.Id + 1).Get(&ie3)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, true, ie3.IsDraft)
assert.EqualValues(t, "", ie3.Sha1)
assert.EqualValues(t, 0, ie3.NumCommits)
assert.EqualValues(t, 1, ie3.RepoId)
assert.EqualValues(t, true, ie3.IsTag)
}
type NightlyRate struct {
ID int64 `xorm:"'id' not null pk BIGINT(20)" json:"id"`
}

View File

@ -239,14 +239,20 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
for i, colName := range exprColumns.colNames {
switch tp := exprColumns.args[i].(type) {
case string:
colNames = append(colNames, session.engine.Quote(colName)+" = "+tp)
if len(tp) == 0 {
tp = "''"
}
colNames = append(colNames, session.engine.Quote(colName)+"="+tp)

So this:

SetExpr("sha1", "''")

is equivalent to:

SetExpr("sha1", "")

It's a little awkward, but I don't see a way around it either.

So this: ``` SetExpr("sha1", "''") ``` is equivalent to: ``` SetExpr("sha1", "") ``` It's a little awkward, but I don't see a way around it either.
case *builder.Builder:
subQuery, subArgs, err := builder.ToSQL(tp)
if err != nil {
return 0, err
}
colNames = append(colNames, session.engine.Quote(colName)+" = ("+subQuery+")")
colNames = append(colNames, session.engine.Quote(colName)+"=("+subQuery+")")
args = append(args, subArgs...)
default:
colNames = append(colNames, session.engine.Quote(colName)+"=?")
args = append(args, exprColumns.args[i])
}
}

View File

@ -1359,3 +1359,48 @@ func TestUpdateAlias(t *testing.T) {
assert.EqualValues(t, 2, ue.NumIssues)
assert.EqualValues(t, "lunny xiao", ue.Name)
}
func TestUpdateExprs2(t *testing.T) {
assert.NoError(t, prepareEngine())
type UpdateExprsRelease struct {
Id int64
RepoId int
IsTag bool
IsDraft bool
NumCommits int
Sha1 string
}
assertSync(t, new(UpdateExprsRelease))
var uer = UpdateExprsRelease{
RepoId: 1,
IsTag: false,
IsDraft: false,
NumCommits: 1,
Sha1: "sha1",
}
inserted, err := testEngine.Insert(&uer)
assert.NoError(t, err)
assert.EqualValues(t, 1, inserted)
updated, err := testEngine.
Where("repo_id = ? AND is_tag = ?", 1, false).
SetExpr("is_draft", true).
SetExpr("num_commits", 0).
SetExpr("sha1", "").
Update(new(UpdateExprsRelease))
assert.NoError(t, err)
assert.EqualValues(t, 1, updated)
var uer2 UpdateExprsRelease
has, err := testEngine.ID(uer.Id).Get(&uer2)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, 1, uer2.RepoId)
assert.EqualValues(t, false, uer2.IsTag)
assert.EqualValues(t, true, uer2.IsDraft)
assert.EqualValues(t, 0, uer2.NumCommits)
assert.EqualValues(t, "", uer2.Sha1)
}

View File

@ -69,10 +69,18 @@ func (exprs *exprParams) writeArgs(w *builder.BytesWriter) error {
if _, err := w.WriteString(")"); err != nil {
return err
}
default:
case string:
if arg == "" {
arg = "''"
}
if _, err := w.WriteString(fmt.Sprintf("%v", arg)); err != nil {
return err
}
default:
if _, err := w.WriteString("?"); err != nil {
return err
}
w.Append(arg)
}
if i != len(exprs.args)-1 {
if _, err := w.WriteString(","); err != nil {