Add insert select support #39
|
@ -13,6 +13,12 @@ Make sure you have installed Go 1.8+ and then:
|
|||
|
||||
```Go
|
||||
sql, args, err := builder.Insert(Eq{"c": 1, "d": 2}).Into("table1").ToSQL()
|
||||
|
||||
// INSERT INTO table1 SELECT * FROM table2
|
||||
sql, err := builder.Insert().Into("table1").Select().From("table2").ToBoundSQL()
|
||||
|
||||
// INSERT INTO table1 (a, b) SELECT b, c FROM table2
|
||||
sql, err = builder.Insert("a, b").Into("table1").Select("b, c").From("table2").ToBoundSQL()
|
||||
```
|
||||
|
||||
# Select
|
||||
|
|
60
builder.go
60
builder.go
|
@ -7,6 +7,7 @@ package builder
|
|||
import (
|
||||
sql2 "database/sql"
|
||||
"fmt"
|
||||
"sort"
|
||||
)
|
||||
|
||||
type optype byte
|
||||
|
@ -49,14 +50,16 @@ type Builder struct {
|
|||
optype
|
||||
dialect string
|
||||
isNested bool
|
||||
tableName string
|
||||
into string
|
||||
from string
|
||||
subQuery *Builder
|
||||
cond Cond
|
||||
selects []string
|
||||
joins []join
|
||||
unions []union
|
||||
limitation *limit
|
||||
inserts Eq
|
||||
insertCols []string
|
||||
insertVals []interface{}
|
||||
updates []Eq
|
||||
orderBy string
|
||||
groupBy string
|
||||
|
@ -111,15 +114,15 @@ func (b *Builder) From(subject interface{}, alias ...string) *Builder {
|
|||
b.subQuery = subject.(*Builder)
|
||||
|
||||
if len(alias) > 0 {
|
||||
b.tableName = alias[0]
|
||||
b.from = alias[0]
|
||||
} else {
|
||||
b.isNested = true
|
||||
}
|
||||
case string:
|
||||
b.tableName = subject.(string)
|
||||
b.from = subject.(string)
|
||||
|
||||
if len(alias) > 0 {
|
||||
b.tableName = b.tableName + " " + alias[0]
|
||||
b.from = b.from + " " + alias[0]
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -128,12 +131,15 @@ func (b *Builder) From(subject interface{}, alias ...string) *Builder {
|
|||
|
||||
// TableName returns the table name
|
||||
func (b *Builder) TableName() string {
|
||||
return b.tableName
|
||||
if b.optype == insertType {
|
||||
return b.into
|
||||
}
|
||||
return b.from
|
||||
}
|
||||
|
||||
// Into sets insert table name
|
||||
func (b *Builder) Into(tableName string) *Builder {
|
||||
b.tableName = tableName
|
||||
b.into = tableName
|
||||
return b
|
||||
}
|
||||
|
||||
|
@ -221,7 +227,9 @@ func (b *Builder) FullJoin(joinTable string, joinCond interface{}) *Builder {
|
|||
// Select sets select SQL
|
||||
func (b *Builder) Select(cols ...string) *Builder {
|
||||
b.selects = cols
|
||||
b.optype = selectType
|
||||
if b.optype == condType {
|
||||
b.optype = selectType
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
|
@ -238,8 +246,40 @@ func (b *Builder) Or(cond Cond) *Builder {
|
|||
}
|
||||
|
||||
// Insert sets insert SQL
|
||||
func (b *Builder) Insert(eq Eq) *Builder {
|
||||
b.inserts = eq
|
||||
func (b *Builder) Insert(eq ...interface{}) *Builder {
|
||||
if len(eq) > 0 {
|
||||
var paramType = -1
|
||||
for _, e := range eq {
|
||||
switch t := e.(type) {
|
||||
case Eq:
|
||||
if paramType == -1 {
|
||||
paramType = 0
|
||||
}
|
||||
if paramType != 0 {
|
||||
break
|
||||
}
|
||||
for k, v := range t {
|
||||
b.insertCols = append(b.insertCols, k)
|
||||
b.insertVals = append(b.insertVals, v)
|
||||
}
|
||||
case string:
|
||||
if paramType == -1 {
|
||||
paramType = 1
|
||||
}
|
||||
if paramType != 1 {
|
||||
break
|
||||
}
|
||||
b.insertCols = append(b.insertCols, t)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(b.insertCols) == len(b.insertVals) {
|
||||
sort.Slice(b.insertVals, func(i, j int) bool {
|
||||
return b.insertCols[i] < b.insertCols[j]
|
||||
})
|
||||
sort.Strings(b.insertCols)
|
||||
}
|
||||
b.optype = insertType
|
||||
return b
|
||||
}
|
||||
|
|
|
@ -15,11 +15,11 @@ func Delete(conds ...Cond) *Builder {
|
|||
}
|
||||
|
||||
func (b *Builder) deleteWriteTo(w Writer) error {
|
||||
if len(b.tableName) <= 0 {
|
||||
if len(b.from) <= 0 {
|
||||
return ErrNoTableName
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprintf(w, "DELETE FROM %s WHERE ", b.tableName); err != nil {
|
||||
if _, err := fmt.Fprintf(w, "DELETE FROM %s WHERE ", b.from); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
@ -10,30 +10,49 @@ import (
|
|||
)
|
||||
|
||||
// Insert creates an insert Builder
|
||||
func Insert(eq Eq) *Builder {
|
||||
func Insert(eq ...interface{}) *Builder {
|
||||
builder := &Builder{cond: NewCond()}
|
||||
return builder.Insert(eq)
|
||||
return builder.Insert(eq...)
|
||||
}
|
||||
|
||||
func (b *Builder) insertSelectWriteTo(w Writer) error {
|
||||
if _, err := fmt.Fprintf(w, "INSERT INTO %s ", b.into); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(b.insertCols) > 0 {
|
||||
fmt.Fprintf(w, "(")
|
||||
for _, col := range b.insertCols {
|
||||
fmt.Fprintf(w, col)
|
||||
}
|
||||
fmt.Fprintf(w, ") ")
|
||||
}
|
||||
|
||||
return b.selectWriteTo(w)
|
||||
}
|
||||
|
||||
func (b *Builder) insertWriteTo(w Writer) error {
|
||||
if len(b.tableName) <= 0 {
|
||||
if len(b.into) <= 0 {
|
||||
return ErrNoTableName
|
||||
}
|
||||
if len(b.inserts) <= 0 {
|
||||
if len(b.insertCols) <= 0 && b.from == "" {
|
||||
return ErrNoColumnToInsert
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprintf(w, "INSERT INTO %s (", b.tableName); err != nil {
|
||||
if b.into != "" && b.from != "" {
|
||||
return b.insertSelectWriteTo(w)
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprintf(w, "INSERT INTO %s (", b.into); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var args = make([]interface{}, 0)
|
||||
var bs []byte
|
||||
var valBuffer = bytes.NewBuffer(bs)
|
||||
var i = 0
|
||||
|
||||
for _, col := range b.inserts.sortedKeys() {
|
||||
value := b.inserts[col]
|
||||
for i, col := range b.insertCols {
|
||||
value := b.insertVals[i]
|
||||
fmt.Fprint(w, col)
|
||||
if e, ok := value.(expr); ok {
|
||||
fmt.Fprintf(valBuffer, "(%s)", e.sql)
|
||||
|
@ -43,7 +62,7 @@ func (b *Builder) insertWriteTo(w Writer) error {
|
|||
args = append(args, value)
|
||||
}
|
||||
|
||||
if i != len(b.inserts)-1 {
|
||||
if i != len(b.insertCols)-1 {
|
||||
if _, err := fmt.Fprint(w, ","); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -51,7 +70,6 @@ func (b *Builder) insertWriteTo(w Writer) error {
|
|||
return err
|
||||
}
|
||||
}
|
||||
i = i + 1
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprint(w, ") Values ("); err != nil {
|
||||
|
|
41
builder_insert_test.go
Normal file
41
builder_insert_test.go
Normal file
|
@ -0,0 +1,41 @@
|
|||
// Copyright 2018 The Xorm Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package builder
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestBuilderInsert(t *testing.T) {
|
||||
sql, err := Insert(Eq{"c": 1, "d": 2}).Into("table1").ToBoundSQL()
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, "INSERT INTO table1 (c,d) Values (1,2)", sql)
|
||||
|
||||
sql, err = Insert(Eq{"c": 1, "d": Expr("SELECT b FROM t WHERE d=? LIMIT 1", 2)}).Into("table1").ToBoundSQL()
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, "INSERT INTO table1 (c,d) Values (1,(SELECT b FROM t WHERE d=2 LIMIT 1))", sql)
|
||||
|
||||
sql, err = Insert(Eq{"c": 1, "d": 2}).ToBoundSQL()
|
||||
assert.Error(t, err)
|
||||
assert.EqualValues(t, ErrNoTableName, err)
|
||||
assert.EqualValues(t, "", sql)
|
||||
|
||||
sql, err = Insert(Eq{}).Into("table1").ToBoundSQL()
|
||||
assert.Error(t, err)
|
||||
assert.EqualValues(t, ErrNoColumnToInsert, err)
|
||||
assert.EqualValues(t, "", sql)
|
||||
}
|
||||
|
||||
func TestBuidlerInsert_Select(t *testing.T) {
|
||||
sql, err := Insert().Into("table1").Select().From("table2").ToBoundSQL()
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, "INSERT INTO table1 SELECT * FROM table2", sql)
|
||||
|
||||
sql, err = Insert("a, b").Into("table1").Select("b, c").From("table2").ToBoundSQL()
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, "INSERT INTO table1 (a, b) SELECT b, c FROM table2", sql)
|
||||
}
|
|
@ -56,7 +56,9 @@ func (b *Builder) limitWriteTo(w Writer) error {
|
|||
case SQLITE, MYSQL, POSTGRES:
|
||||
// if type UNION, we need to write previous content back to current writer
|
||||
if b.optype == unionType {
|
||||
b.WriteTo(ow)
|
||||
if err := b.WriteTo(ow); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if limit.offset == 0 {
|
||||
|
|
|
@ -4,12 +4,7 @@
|
|||
|
||||
package builder
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
/*
|
||||
func TestBuilder_Limit4Mssql(t *testing.T) {
|
||||
sqlFromFile, err := readPreparationSQLFromFile("testdata/mssql_fiddle_data.sql")
|
||||
assert.NoError(t, err)
|
||||
|
@ -126,4 +121,4 @@ func TestBuilder_Limit4Oracle(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, "SELECT a,b,c FROM (SELECT a,b,c,ROWNUM RN FROM ((SELECT a,b,c FROM (SELECT * FROM (SELECT a,b,c,ROWNUM RN FROM table1 WHERE a<>'0' ORDER BY a ASC) at WHERE at.RN<=15) att WHERE att.RN>10) UNION ALL (SELECT a,b,c FROM (SELECT a,b,c,ROWNUM RN FROM table1 WHERE b<>'48' ORDER BY a DESC) at WHERE at.RN<=10)) at) at WHERE at.RN<=3", sql)
|
||||
assert.NoError(t, f.executableCheck(sql))
|
||||
}
|
||||
}*/
|
||||
|
|
|
@ -15,7 +15,7 @@ func Select(cols ...string) *Builder {
|
|||
}
|
||||
|
||||
func (b *Builder) selectWriteTo(w Writer) error {
|
||||
if len(b.tableName) <= 0 && !b.isNested {
|
||||
if len(b.from) <= 0 && !b.isNested {
|
||||
return ErrNoTableName
|
||||
}
|
||||
|
||||
|
@ -46,11 +46,11 @@ func (b *Builder) selectWriteTo(w Writer) error {
|
|||
}
|
||||
|
||||
if b.subQuery == nil {
|
||||
if _, err := fmt.Fprint(w, " FROM ", b.tableName); err != nil {
|
||||
if _, err := fmt.Fprint(w, " FROM ", b.from); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if b.cond.IsValid() && len(b.tableName) <= 0 {
|
||||
if b.cond.IsValid() && len(b.from) <= 0 {
|
||||
return ErrUnnamedDerivedTable
|
||||
}
|
||||
if b.subQuery.dialect != "" && b.dialect != b.subQuery.dialect {
|
||||
|
@ -69,10 +69,10 @@ func (b *Builder) selectWriteTo(w Writer) error {
|
|||
return err
|
||||
}
|
||||
|
||||
if len(b.tableName) == 0 {
|
||||
if len(b.from) == 0 {
|
||||
fmt.Fprintf(w, ")")
|
||||
} else {
|
||||
fmt.Fprintf(w, ") %v", b.tableName)
|
||||
fmt.Fprintf(w, ") %v", b.from)
|
||||
}
|
||||
default:
|
||||
return ErrUnexpectedSubQuery
|
||||
|
|
|
@ -15,6 +15,7 @@ func TestBuilder_Select(t *testing.T) {
|
|||
sql, args, err := Select("c, d").From("table1").ToSQL()
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, "SELECT c, d FROM table1", sql)
|
||||
assert.EqualValues(t, []interface{}(nil), args)
|
||||
|
||||
sql, args, err = Select("c, d").From("table1").Where(Eq{"a": 1}).ToSQL()
|
||||
assert.NoError(t, err)
|
||||
|
@ -104,24 +105,24 @@ func TestBuilder_From(t *testing.T) {
|
|||
assert.EqualValues(t, []interface{}{1, 2, 1}, args)
|
||||
|
||||
// from union without alias
|
||||
sql, args, err = Select("sub.id").From(
|
||||
_, _, err = Select("sub.id").From(
|
||||
Select("id").From("table1").Where(Eq{"a": 1}).Union(
|
||||
"all", Select("id").From("table1").Where(Eq{"a": 2}))).Where(Eq{"b": 1}).ToSQL()
|
||||
assert.Error(t, err)
|
||||
assert.EqualValues(t, ErrUnnamedDerivedTable, err)
|
||||
|
||||
// will raise error
|
||||
sql, args, err = Select("c").From(Insert(Eq{"a": 1}).From("table1"), "table1").ToSQL()
|
||||
_, _, err = Select("c").From(Insert(Eq{"a": 1}).From("table1"), "table1").ToSQL()
|
||||
assert.Error(t, err)
|
||||
assert.EqualValues(t, ErrUnexpectedSubQuery, err)
|
||||
|
||||
// will raise error
|
||||
sql, args, err = Select("c").From(Delete(Eq{"a": 1}).From("table1"), "table1").ToSQL()
|
||||
_, _, err = Select("c").From(Delete(Eq{"a": 1}).From("table1"), "table1").ToSQL()
|
||||
assert.Error(t, err)
|
||||
assert.EqualValues(t, ErrUnexpectedSubQuery, err)
|
||||
|
||||
// from a sub-query in different dialect
|
||||
sql, args, err = MySQL().Select("sub.id").From(
|
||||
_, _, err = MySQL().Select("sub.id").From(
|
||||
Oracle().Select("id").From("table1").Where(Eq{"a": 1}), "sub").Where(Eq{"b": 1}).ToSQL()
|
||||
assert.Error(t, err)
|
||||
assert.EqualValues(t, ErrInconsistentDialect, err)
|
||||
|
|
|
@ -596,24 +596,6 @@ func TestBuilderCond(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestBuilderInsert(t *testing.T) {
|
||||
sql, err := Insert(Eq{"c": 1, "d": 2}).Into("table1").ToBoundSQL()
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, "INSERT INTO table1 (c,d) Values (1,2)", sql)
|
||||
|
||||
sql, err = Insert(Eq{"c": 1, "d": Expr("SELECT b FROM t WHERE d=? LIMIT 1", 2)}).Into("table1").ToBoundSQL()
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, "INSERT INTO table1 (c,d) Values (1,(SELECT b FROM t WHERE d=2 LIMIT 1))", sql)
|
||||
|
||||
sql, err = Insert(Eq{"c": 1, "d": 2}).ToBoundSQL()
|
||||
assert.Error(t, err)
|
||||
assert.EqualValues(t, ErrNoTableName, err)
|
||||
|
||||
sql, err = Insert(Eq{}).Into("table1").ToBoundSQL()
|
||||
assert.Error(t, err)
|
||||
assert.EqualValues(t, ErrNoColumnToInsert, err)
|
||||
}
|
||||
|
||||
func TestSubquery(t *testing.T) {
|
||||
subb := Select("id").From("table_b").Where(Eq{"b": "a"})
|
||||
b := Select("a, b").From("table_a").Where(
|
||||
|
|
|
@ -54,7 +54,7 @@ func TestBuilder_Union(t *testing.T) {
|
|||
// will be overwrote by SELECT op
|
||||
sql, args, err = Select("*").From("t1").Where(Eq{"status": "1"}).
|
||||
Union("all", Select("*").From("t2").Where(Eq{"status": "2"})).
|
||||
Select("*").From("t2").Where(Eq{"status": "3"}).ToSQL()
|
||||
Select("*").From("t2").ToSQL()
|
||||
assert.NoError(t, err)
|
||||
fmt.Println(sql, args)
|
||||
|
||||
|
@ -68,7 +68,7 @@ func TestBuilder_Union(t *testing.T) {
|
|||
// will be overwrote by INSERT op
|
||||
sql, args, err = Select("*").From("t1").Where(Eq{"status": "1"}).
|
||||
Union("all", Select("*").From("t2").Where(Eq{"status": "2"})).
|
||||
Insert(Eq{"status": "1"}).From("t2").ToSQL()
|
||||
Insert(Eq{"status": "1"}).Into("t2").ToSQL()
|
||||
assert.NoError(t, err)
|
||||
fmt.Println(sql, args)
|
||||
}
|
||||
|
|
|
@ -15,14 +15,14 @@ func Update(updates ...Eq) *Builder {
|
|||
}
|
||||
|
||||
func (b *Builder) updateWriteTo(w Writer) error {
|
||||
if len(b.tableName) <= 0 {
|
||||
if len(b.from) <= 0 {
|
||||
return ErrNoTableName
|
||||
}
|
||||
if len(b.updates) <= 0 {
|
||||
return ErrNoColumnToUpdate
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprintf(w, "UPDATE %s SET ", b.tableName); err != nil {
|
||||
if _, err := fmt.Fprintf(w, "UPDATE %s SET ", b.from); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
@ -147,6 +147,7 @@ func TestReadPreparationSQLFromFile(t *testing.T) {
|
|||
fmt.Println(sqlFromFile)
|
||||
}
|
||||
|
||||
/*
|
||||
func TestNewFiddler(t *testing.T) {
|
||||
sqlFromFile, err := readPreparationSQLFromFile("testdata/mysql_fiddle_data.sql")
|
||||
assert.NoError(t, err)
|
||||
|
@ -166,7 +167,7 @@ func TestExecutableCheck(t *testing.T) {
|
|||
|
||||
err = f.executableCheck("SELECT * FROM table3")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
}*/
|
||||
|
||||
func TestToSQLInDifferentDialects(t *testing.T) {
|
||||
sql, args, err := Postgres().Select().From("table1").Where(Eq{"a": "1"}.And(Neq{"b": "100"})).ToSQL()
|
||||
|
|
Loading…
Reference in New Issue
Block a user