Browse Source

SetExpr support more go types (#1499)

Improve tests

SetExpr support more go types

fix vet

fix drone lint

remove go1.10 test on drone

Reviewed-on: #1499
tags/v0.8.2
Lunny Xiao 1 month ago
parent
commit
a18e35f7f5
5 changed files with 153 additions and 64 deletions
  1. +33
    -61
      session_insert.go
  2. +58
    -0
      session_insert_test.go
  3. +8
    -2
      session_update.go
  4. +45
    -0
      session_update_test.go
  5. +9
    -1
      statement_exprparam.go

+ 33
- 61
session_insert.go View File

@@ -729,66 +729,7 @@ func (session *Session) insertMapInterface(m map[string]interface{}) (int64, err
args = append(args, m[colName])
}

w := builder.NewWriter()
if session.statement.cond.IsValid() {
if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (", session.engine.Quote(tableName))); err != nil {
return 0, err
}

if err := writeStrings(w, append(columns, exprs.colNames...), "`", "`"); err != nil {
return 0, err
}

if _, err := w.WriteString(") SELECT "); err != nil {
return 0, err
}

if err := session.statement.writeArgs(w, args); err != nil {
return 0, err
}

if len(exprs.args) > 0 {
if _, err := w.WriteString(","); err != nil {
return 0, err
}
if err := exprs.writeArgs(w); err != nil {
return 0, err
}
}

if _, err := w.WriteString(fmt.Sprintf(" FROM %s WHERE ", session.engine.Quote(tableName))); err != nil {
return 0, err
}

if err := session.statement.cond.WriteTo(w); err != nil {
return 0, err
}
} else {
qm := strings.Repeat("?,", len(columns))
qm = qm[:len(qm)-1]

if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (`%s`) VALUES (%s)", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm)); err != nil {
return 0, err
}
w.Append(args...)
}

sql := w.String()
args = w.Args()

if err := session.cacheInsert(tableName); err != nil {
return 0, err
}

res, err := session.exec(sql, args...)
if err != nil {
return 0, err
}
affected, err := res.RowsAffected()
if err != nil {
return 0, err
}
return affected, nil
return session.insertMap(columns, args)
}

func (session *Session) insertMapString(m map[string]string) (int64, error) {
@@ -808,6 +749,7 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) {
columns = append(columns, k)
}
}

sort.Strings(columns)

var args = make([]interface{}, 0, len(m))
@@ -815,7 +757,18 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) {
args = append(args, m[colName])
}

return session.insertMap(columns, args)
}

func (session *Session) insertMap(columns []string, args []interface{}) (int64, error) {
tableName := session.statement.TableName()
if len(tableName) <= 0 {
return 0, ErrTableNotFound
}

exprs := session.statement.exprColumns
w := builder.NewWriter()
// if insert where
if session.statement.cond.IsValid() {
if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (", session.engine.Quote(tableName))); err != nil {
return 0, err
@@ -853,10 +806,29 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) {
qm := strings.Repeat("?,", len(columns))
qm = qm[:len(qm)-1]

if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (`%s`) VALUES (%s)", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm)); err != nil {
if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (", session.engine.Quote(tableName))); err != nil {
return 0, err
}

if err := writeStrings(w, append(columns, exprs.colNames...), "`", "`"); err != nil {
return 0, err
}
if _, err := w.WriteString(fmt.Sprintf(") VALUES (%s", qm)); err != nil {
return 0, err
}

w.Append(args...)
if len(exprs.args) > 0 {
if _, err := w.WriteString(","); err != nil {
return 0, err
}
if err := exprs.writeArgs(w); err != nil {
return 0, err
}
}
if _, err := w.WriteString(")"); err != nil {
return 0, err
}
}

sql := w.String()


+ 58
- 0
session_insert_test.go View File

@@ -928,6 +928,64 @@ func TestInsertWhere(t *testing.T) {
assert.EqualValues(t, 5, j5.Index)
}

func TestInsertExpr2(t *testing.T) {
assert.NoError(t, prepareEngine())

type InsertExprsRelease struct {
Id int64
RepoId int
IsTag bool
IsDraft bool
NumCommits int
Sha1 string
}

assertSync(t, new(InsertExprsRelease))

var ie = InsertExprsRelease{
RepoId: 1,
IsTag: true,
}
inserted, err := testEngine.
SetExpr("is_draft", true).
SetExpr("num_commits", 0).
SetExpr("sha1", "").
Insert(&ie)
assert.NoError(t, err)
assert.EqualValues(t, 1, inserted)

var ie2 InsertExprsRelease
has, err := testEngine.ID(ie.Id).Get(&ie2)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, true, ie2.IsDraft)
assert.EqualValues(t, "", ie2.Sha1)
assert.EqualValues(t, 0, ie2.NumCommits)
assert.EqualValues(t, 1, ie2.RepoId)
assert.EqualValues(t, true, ie2.IsTag)

inserted, err = testEngine.Table(new(InsertExprsRelease)).
SetExpr("is_draft", true).
SetExpr("num_commits", 0).
SetExpr("sha1", "").
Insert(map[string]interface{}{
"repo_id": 1,
"is_tag": true,
})
assert.NoError(t, err)
assert.EqualValues(t, 1, inserted)

var ie3 InsertExprsRelease
has, err = testEngine.ID(ie.Id + 1).Get(&ie3)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, true, ie3.IsDraft)
assert.EqualValues(t, "", ie3.Sha1)
assert.EqualValues(t, 0, ie3.NumCommits)
assert.EqualValues(t, 1, ie3.RepoId)
assert.EqualValues(t, true, ie3.IsTag)
}

type NightlyRate struct {
ID int64 `xorm:"'id' not null pk BIGINT(20)" json:"id"`
}


+ 8
- 2
session_update.go View File

@@ -239,14 +239,20 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
for i, colName := range exprColumns.colNames {
switch tp := exprColumns.args[i].(type) {
case string:
colNames = append(colNames, session.engine.Quote(colName)+" = "+tp)
if len(tp) == 0 {
tp = "''"
}
colNames = append(colNames, session.engine.Quote(colName)+"="+tp)
case *builder.Builder:
subQuery, subArgs, err := builder.ToSQL(tp)
if err != nil {
return 0, err
}
colNames = append(colNames, session.engine.Quote(colName)+" = ("+subQuery+")")
colNames = append(colNames, session.engine.Quote(colName)+"=("+subQuery+")")
args = append(args, subArgs...)
default:
colNames = append(colNames, session.engine.Quote(colName)+"=?")
args = append(args, exprColumns.args[i])
}
}



+ 45
- 0
session_update_test.go View File

@@ -1359,3 +1359,48 @@ func TestUpdateAlias(t *testing.T) {
assert.EqualValues(t, 2, ue.NumIssues)
assert.EqualValues(t, "lunny xiao", ue.Name)
}

func TestUpdateExprs2(t *testing.T) {
assert.NoError(t, prepareEngine())

type UpdateExprsRelease struct {
Id int64
RepoId int
IsTag bool
IsDraft bool
NumCommits int
Sha1 string
}

assertSync(t, new(UpdateExprsRelease))

var uer = UpdateExprsRelease{
RepoId: 1,
IsTag: false,
IsDraft: false,
NumCommits: 1,
Sha1: "sha1",
}
inserted, err := testEngine.Insert(&uer)
assert.NoError(t, err)
assert.EqualValues(t, 1, inserted)

updated, err := testEngine.
Where("repo_id = ? AND is_tag = ?", 1, false).
SetExpr("is_draft", true).
SetExpr("num_commits", 0).
SetExpr("sha1", "").
Update(new(UpdateExprsRelease))
assert.NoError(t, err)
assert.EqualValues(t, 1, updated)

var uer2 UpdateExprsRelease
has, err := testEngine.ID(uer.Id).Get(&uer2)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, 1, uer2.RepoId)
assert.EqualValues(t, false, uer2.IsTag)
assert.EqualValues(t, true, uer2.IsDraft)
assert.EqualValues(t, 0, uer2.NumCommits)
assert.EqualValues(t, "", uer2.Sha1)
}

+ 9
- 1
statement_exprparam.go View File

@@ -69,10 +69,18 @@ func (exprs *exprParams) writeArgs(w *builder.BytesWriter) error {
if _, err := w.WriteString(")"); err != nil {
return err
}
default:
case string:
if arg == "" {
arg = "''"
}
if _, err := w.WriteString(fmt.Sprintf("%v", arg)); err != nil {
return err
}
default:
if _, err := w.WriteString("?"); err != nil {
return err
}
w.Append(arg)
}
if i != len(exprs.args)-1 {
if _, err := w.WriteString(","); err != nil {


Loading…
Cancel
Save