Add insert select support #39

Merged
lunny merged 8 commits from lunny/insert_select into master 2018-09-28 02:22:51 +00:00
13 changed files with 148 additions and 62 deletions

View File

@ -13,6 +13,12 @@ Make sure you have installed Go 1.8+ and then:
```Go
sql, args, err := builder.Insert(Eq{"c": 1, "d": 2}).Into("table1").ToSQL()
// INSERT INTO table1 SELECT * FROM table2
sql, err := builder.Insert().Into("table1").Select().From("table2").ToBoundSQL()
// INSERT INTO table1 (a, b) SELECT b, c FROM table2
sql, err = builder.Insert("a, b").Into("table1").Select("b, c").From("table2").ToBoundSQL()
```
# Select

View File

@ -7,6 +7,7 @@ package builder
import (
sql2 "database/sql"
"fmt"
"sort"
)
type optype byte
@ -49,14 +50,16 @@ type Builder struct {
optype
dialect string
isNested bool
tableName string
into string
from string
subQuery *Builder
cond Cond
selects []string
joins []join
unions []union
limitation *limit
inserts Eq
insertCols []string
insertVals []interface{}
updates []Eq
orderBy string
groupBy string
@ -111,15 +114,15 @@ func (b *Builder) From(subject interface{}, alias ...string) *Builder {
b.subQuery = subject.(*Builder)
if len(alias) > 0 {
b.tableName = alias[0]
b.from = alias[0]
} else {
b.isNested = true
}
case string:
b.tableName = subject.(string)
b.from = subject.(string)
if len(alias) > 0 {
b.tableName = b.tableName + " " + alias[0]
b.from = b.from + " " + alias[0]
}
}
@ -128,12 +131,15 @@ func (b *Builder) From(subject interface{}, alias ...string) *Builder {
// TableName returns the table name
func (b *Builder) TableName() string {
return b.tableName
if b.optype == insertType {
return b.into
}
return b.from
}
// Into sets insert table name
func (b *Builder) Into(tableName string) *Builder {
b.tableName = tableName
b.into = tableName
return b
}
@ -221,7 +227,9 @@ func (b *Builder) FullJoin(joinTable string, joinCond interface{}) *Builder {
// Select sets select SQL
func (b *Builder) Select(cols ...string) *Builder {
b.selects = cols
b.optype = selectType
if b.optype == condType {
b.optype = selectType
}
return b
}
@ -238,8 +246,40 @@ func (b *Builder) Or(cond Cond) *Builder {
}
// Insert sets insert SQL
func (b *Builder) Insert(eq Eq) *Builder {
b.inserts = eq
func (b *Builder) Insert(eq ...interface{}) *Builder {
if len(eq) > 0 {
var paramType = -1
for _, e := range eq {
switch t := e.(type) {
case Eq:
if paramType == -1 {
paramType = 0
}
if paramType != 0 {
break
}
for k, v := range t {
b.insertCols = append(b.insertCols, k)
b.insertVals = append(b.insertVals, v)
}
case string:
if paramType == -1 {
paramType = 1
}
if paramType != 1 {
break
}
b.insertCols = append(b.insertCols, t)
}
}
}
if len(b.insertCols) == len(b.insertVals) {
sort.Slice(b.insertVals, func(i, j int) bool {
return b.insertCols[i] < b.insertCols[j]
})
sort.Strings(b.insertCols)
}
b.optype = insertType
return b
}

View File

@ -15,11 +15,11 @@ func Delete(conds ...Cond) *Builder {
}
func (b *Builder) deleteWriteTo(w Writer) error {
if len(b.tableName) <= 0 {
if len(b.from) <= 0 {
return ErrNoTableName
}
if _, err := fmt.Fprintf(w, "DELETE FROM %s WHERE ", b.tableName); err != nil {
if _, err := fmt.Fprintf(w, "DELETE FROM %s WHERE ", b.from); err != nil {
return err
}

View File

@ -10,30 +10,49 @@ import (
)
// Insert creates an insert Builder
func Insert(eq Eq) *Builder {
func Insert(eq ...interface{}) *Builder {
builder := &Builder{cond: NewCond()}
return builder.Insert(eq)
return builder.Insert(eq...)
}
func (b *Builder) insertSelectWriteTo(w Writer) error {
if _, err := fmt.Fprintf(w, "INSERT INTO %s ", b.into); err != nil {
return err
}
if len(b.insertCols) > 0 {
fmt.Fprintf(w, "(")
for _, col := range b.insertCols {
fmt.Fprintf(w, col)
}
fmt.Fprintf(w, ") ")
}
return b.selectWriteTo(w)
}
func (b *Builder) insertWriteTo(w Writer) error {
if len(b.tableName) <= 0 {
if len(b.into) <= 0 {
return ErrNoTableName
}
if len(b.inserts) <= 0 {
if len(b.insertCols) <= 0 && b.from == "" {
return ErrNoColumnToInsert
}
if _, err := fmt.Fprintf(w, "INSERT INTO %s (", b.tableName); err != nil {
if b.into != "" && b.from != "" {
return b.insertSelectWriteTo(w)
}
if _, err := fmt.Fprintf(w, "INSERT INTO %s (", b.into); err != nil {
return err
}
var args = make([]interface{}, 0)
var bs []byte
var valBuffer = bytes.NewBuffer(bs)
var i = 0
for _, col := range b.inserts.sortedKeys() {
value := b.inserts[col]
for i, col := range b.insertCols {
value := b.insertVals[i]
fmt.Fprint(w, col)
if e, ok := value.(expr); ok {
fmt.Fprintf(valBuffer, "(%s)", e.sql)
@ -43,7 +62,7 @@ func (b *Builder) insertWriteTo(w Writer) error {
args = append(args, value)
}
if i != len(b.inserts)-1 {
if i != len(b.insertCols)-1 {
if _, err := fmt.Fprint(w, ","); err != nil {
return err
}
@ -51,7 +70,6 @@ func (b *Builder) insertWriteTo(w Writer) error {
return err
}
}
i = i + 1
}
if _, err := fmt.Fprint(w, ") Values ("); err != nil {

41
builder_insert_test.go Normal file
View File

@ -0,0 +1,41 @@
// Copyright 2018 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package builder
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestBuilderInsert(t *testing.T) {
sql, err := Insert(Eq{"c": 1, "d": 2}).Into("table1").ToBoundSQL()
assert.NoError(t, err)
assert.EqualValues(t, "INSERT INTO table1 (c,d) Values (1,2)", sql)
sql, err = Insert(Eq{"c": 1, "d": Expr("SELECT b FROM t WHERE d=? LIMIT 1", 2)}).Into("table1").ToBoundSQL()
assert.NoError(t, err)
assert.EqualValues(t, "INSERT INTO table1 (c,d) Values (1,(SELECT b FROM t WHERE d=2 LIMIT 1))", sql)
sql, err = Insert(Eq{"c": 1, "d": 2}).ToBoundSQL()
assert.Error(t, err)
assert.EqualValues(t, ErrNoTableName, err)
assert.EqualValues(t, "", sql)
sql, err = Insert(Eq{}).Into("table1").ToBoundSQL()
assert.Error(t, err)
assert.EqualValues(t, ErrNoColumnToInsert, err)
assert.EqualValues(t, "", sql)
}
func TestBuidlerInsert_Select(t *testing.T) {
sql, err := Insert().Into("table1").Select().From("table2").ToBoundSQL()
assert.NoError(t, err)
assert.EqualValues(t, "INSERT INTO table1 SELECT * FROM table2", sql)
sql, err = Insert("a, b").Into("table1").Select("b, c").From("table2").ToBoundSQL()
assert.NoError(t, err)
assert.EqualValues(t, "INSERT INTO table1 (a, b) SELECT b, c FROM table2", sql)
}

View File

@ -56,7 +56,9 @@ func (b *Builder) limitWriteTo(w Writer) error {
case SQLITE, MYSQL, POSTGRES:
// if type UNION, we need to write previous content back to current writer
if b.optype == unionType {
b.WriteTo(ow)
if err := b.WriteTo(ow); err != nil {
return err
}
}
if limit.offset == 0 {

View File

@ -4,12 +4,7 @@
package builder
import (
"testing"
"github.com/stretchr/testify/assert"
)
/*
func TestBuilder_Limit4Mssql(t *testing.T) {
sqlFromFile, err := readPreparationSQLFromFile("testdata/mssql_fiddle_data.sql")
assert.NoError(t, err)
@ -126,4 +121,4 @@ func TestBuilder_Limit4Oracle(t *testing.T) {
assert.NoError(t, err)
assert.EqualValues(t, "SELECT a,b,c FROM (SELECT a,b,c,ROWNUM RN FROM ((SELECT a,b,c FROM (SELECT * FROM (SELECT a,b,c,ROWNUM RN FROM table1 WHERE a<>'0' ORDER BY a ASC) at WHERE at.RN<=15) att WHERE att.RN>10) UNION ALL (SELECT a,b,c FROM (SELECT a,b,c,ROWNUM RN FROM table1 WHERE b<>'48' ORDER BY a DESC) at WHERE at.RN<=10)) at) at WHERE at.RN<=3", sql)
assert.NoError(t, f.executableCheck(sql))
}
}*/

View File

@ -15,7 +15,7 @@ func Select(cols ...string) *Builder {
}
func (b *Builder) selectWriteTo(w Writer) error {
if len(b.tableName) <= 0 && !b.isNested {
if len(b.from) <= 0 && !b.isNested {
return ErrNoTableName
}
@ -46,11 +46,11 @@ func (b *Builder) selectWriteTo(w Writer) error {
}
if b.subQuery == nil {
if _, err := fmt.Fprint(w, " FROM ", b.tableName); err != nil {
if _, err := fmt.Fprint(w, " FROM ", b.from); err != nil {
return err
}
} else {
if b.cond.IsValid() && len(b.tableName) <= 0 {
if b.cond.IsValid() && len(b.from) <= 0 {
return ErrUnnamedDerivedTable
}
if b.subQuery.dialect != "" && b.dialect != b.subQuery.dialect {
@ -69,10 +69,10 @@ func (b *Builder) selectWriteTo(w Writer) error {
return err
}
if len(b.tableName) == 0 {
if len(b.from) == 0 {
fmt.Fprintf(w, ")")
} else {
fmt.Fprintf(w, ") %v", b.tableName)
fmt.Fprintf(w, ") %v", b.from)
}
default:
return ErrUnexpectedSubQuery

View File

@ -15,6 +15,7 @@ func TestBuilder_Select(t *testing.T) {
sql, args, err := Select("c, d").From("table1").ToSQL()
assert.NoError(t, err)
assert.EqualValues(t, "SELECT c, d FROM table1", sql)
assert.EqualValues(t, []interface{}(nil), args)
sql, args, err = Select("c, d").From("table1").Where(Eq{"a": 1}).ToSQL()
assert.NoError(t, err)
@ -104,24 +105,24 @@ func TestBuilder_From(t *testing.T) {
assert.EqualValues(t, []interface{}{1, 2, 1}, args)
// from union without alias
sql, args, err = Select("sub.id").From(
_, _, err = Select("sub.id").From(
Select("id").From("table1").Where(Eq{"a": 1}).Union(
"all", Select("id").From("table1").Where(Eq{"a": 2}))).Where(Eq{"b": 1}).ToSQL()
assert.Error(t, err)
assert.EqualValues(t, ErrUnnamedDerivedTable, err)
// will raise error
sql, args, err = Select("c").From(Insert(Eq{"a": 1}).From("table1"), "table1").ToSQL()
_, _, err = Select("c").From(Insert(Eq{"a": 1}).From("table1"), "table1").ToSQL()
assert.Error(t, err)
assert.EqualValues(t, ErrUnexpectedSubQuery, err)
// will raise error
sql, args, err = Select("c").From(Delete(Eq{"a": 1}).From("table1"), "table1").ToSQL()
_, _, err = Select("c").From(Delete(Eq{"a": 1}).From("table1"), "table1").ToSQL()
assert.Error(t, err)
assert.EqualValues(t, ErrUnexpectedSubQuery, err)
// from a sub-query in different dialect
sql, args, err = MySQL().Select("sub.id").From(
_, _, err = MySQL().Select("sub.id").From(
Oracle().Select("id").From("table1").Where(Eq{"a": 1}), "sub").Where(Eq{"b": 1}).ToSQL()
assert.Error(t, err)
assert.EqualValues(t, ErrInconsistentDialect, err)

View File

@ -596,24 +596,6 @@ func TestBuilderCond(t *testing.T) {
}
}
func TestBuilderInsert(t *testing.T) {
sql, err := Insert(Eq{"c": 1, "d": 2}).Into("table1").ToBoundSQL()
assert.NoError(t, err)
assert.EqualValues(t, "INSERT INTO table1 (c,d) Values (1,2)", sql)
sql, err = Insert(Eq{"c": 1, "d": Expr("SELECT b FROM t WHERE d=? LIMIT 1", 2)}).Into("table1").ToBoundSQL()
assert.NoError(t, err)
assert.EqualValues(t, "INSERT INTO table1 (c,d) Values (1,(SELECT b FROM t WHERE d=2 LIMIT 1))", sql)
sql, err = Insert(Eq{"c": 1, "d": 2}).ToBoundSQL()
assert.Error(t, err)
assert.EqualValues(t, ErrNoTableName, err)
sql, err = Insert(Eq{}).Into("table1").ToBoundSQL()
assert.Error(t, err)
assert.EqualValues(t, ErrNoColumnToInsert, err)
}
func TestSubquery(t *testing.T) {
subb := Select("id").From("table_b").Where(Eq{"b": "a"})
b := Select("a, b").From("table_a").Where(

View File

@ -54,7 +54,7 @@ func TestBuilder_Union(t *testing.T) {
// will be overwrote by SELECT op
sql, args, err = Select("*").From("t1").Where(Eq{"status": "1"}).
Union("all", Select("*").From("t2").Where(Eq{"status": "2"})).
Select("*").From("t2").Where(Eq{"status": "3"}).ToSQL()
Select("*").From("t2").ToSQL()
assert.NoError(t, err)
fmt.Println(sql, args)
@ -68,7 +68,7 @@ func TestBuilder_Union(t *testing.T) {
// will be overwrote by INSERT op
sql, args, err = Select("*").From("t1").Where(Eq{"status": "1"}).
Union("all", Select("*").From("t2").Where(Eq{"status": "2"})).
Insert(Eq{"status": "1"}).From("t2").ToSQL()
Insert(Eq{"status": "1"}).Into("t2").ToSQL()
assert.NoError(t, err)
fmt.Println(sql, args)
}

View File

@ -15,14 +15,14 @@ func Update(updates ...Eq) *Builder {
}
func (b *Builder) updateWriteTo(w Writer) error {
if len(b.tableName) <= 0 {
if len(b.from) <= 0 {
return ErrNoTableName
}
if len(b.updates) <= 0 {
return ErrNoColumnToUpdate
}
if _, err := fmt.Fprintf(w, "UPDATE %s SET ", b.tableName); err != nil {
if _, err := fmt.Fprintf(w, "UPDATE %s SET ", b.from); err != nil {
return err
}

View File

@ -147,6 +147,7 @@ func TestReadPreparationSQLFromFile(t *testing.T) {
fmt.Println(sqlFromFile)
}
/*
func TestNewFiddler(t *testing.T) {
sqlFromFile, err := readPreparationSQLFromFile("testdata/mysql_fiddle_data.sql")
assert.NoError(t, err)
@ -166,7 +167,7 @@ func TestExecutableCheck(t *testing.T) {
err = f.executableCheck("SELECT * FROM table3")
assert.Error(t, err)
}
}*/
func TestToSQLInDifferentDialects(t *testing.T) {
sql, args, err := Postgres().Select().From("table1").Where(Eq{"a": "1"}.And(Neq{"b": "100"})).ToSQL()