builder/sql_test.go
2020-12-22 20:59:54 +08:00

258 lines
8.1 KiB
Go

// Copyright 2018 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package builder
import (
sql2 "database/sql"
"fmt"
"io/ioutil"
"os"
"testing"
"gitea.com/xorm/sqlfiddle"
"github.com/stretchr/testify/assert"
)
// Fixtures shared by the placeholder-conversion and bound-SQL tests.
const (
	// placeholderConverterSQL is the input statement: six `?` placeholders,
	// one of them inside a sub-select.
	placeholderConverterSQL = "SELECT a, b FROM table_a WHERE b_id=(SELECT id FROM table_b WHERE b=?) AND id=? AND c=? AND d=? AND e=? AND f=?"

	// placeholderConvertedSQL is the same statement with every `?` replaced
	// by a numbered `$N` placeholder.
	placeholderConvertedSQL = "SELECT a, b FROM table_a WHERE b_id=(SELECT id FROM table_b WHERE b=$1) AND id=$2 AND c=$3 AND d=$4 AND e=$5 AND f=$6"

	// placeholderBoundSQL is the same statement with literal values written
	// in place of the placeholders.
	placeholderBoundSQL = "SELECT a, b FROM table_a WHERE b_id=(SELECT id FROM table_b WHERE b=1) AND id=2.1 AND c='3' AND d=4 AND e='5' AND f=true"
)
func TestNoSQLQuoteNeeded(t *testing.T) {
assert.False(t, noSQLQuoteNeeded(nil))
}
// TestPlaceholderConverter runs ConvertPlaceholder over a table of statements
// and compares each result with the expected text. The cases expect every bare
// `?` to become a numbered placeholder carrying the given prefix, and a `?`
// that sits inside a quoted string literal (including one following an escaped
// quote) to be left as is.
func TestPlaceholderConverter(t *testing.T) {
	type conversion struct {
		input  string // statement handed to ConvertPlaceholder
		want   string // statement expected back
		prefix string // placeholder prefix handed to ConvertPlaceholder
	}

	cases := []conversion{
		{input: placeholderConverterSQL, want: placeholderConvertedSQL, prefix: "$"},
		{
			input:  "SELECT a, b, 'a?b' FROM table_a WHERE id=?",
			want:   "SELECT a, b, 'a?b' FROM table_a WHERE id=:1",
			prefix: ":",
		},
		{
			input:  "SELECT a, b, 'a\\'?b' FROM table_a WHERE id=?",
			want:   "SELECT a, b, 'a\\'?b' FROM table_a WHERE id=$1",
			prefix: "$",
		},
		{
			input:  "SELECT a, b, 'a\\'b' FROM table_a WHERE id=?",
			want:   "SELECT a, b, 'a\\'b' FROM table_a WHERE id=$1",
			prefix: "$",
		},
	}

	for i := range cases {
		tc := cases[i]
		// The input statement doubles as the sub-test name.
		t.Run(tc.input, func(t *testing.T) {
			got, err := ConvertPlaceholder(tc.input, tc.prefix)
			assert.NoError(t, err)
			assert.EqualValues(t, tc.want, got)
		})
	}
}
// BenchmarkPlaceholderConverter measures ConvertPlaceholder on the shared
// six-placeholder statement using the "$" prefix.
//
// The returned error is checked on every iteration so that a failing
// conversion aborts the benchmark instead of being timed as if it succeeded.
func BenchmarkPlaceholderConverter(b *testing.B) {
	for i := 0; i < b.N; i++ {
		if _, err := ConvertPlaceholder(placeholderConverterSQL, "$"); err != nil {
			b.Fatal(err)
		}
	}
}
// TestBoundSQLConverter covers ConvertToBoundSQL and ToBoundSQL:
//   - a full argument list is expected to yield placeholderBoundSQL,
//   - a sql.NamedArg among the arguments is expected to yield the same text,
//   - one argument too few is expected to yield ErrNeedMoreArguments,
//   - ToBoundSQL is expected to accept a builder and a condition, and to
//     answer ErrNotSupportType for a plain int.
func TestBoundSQLConverter(t *testing.T) {
	// One value per placeholder in placeholderConverterSQL.
	fullArgs := []interface{}{1, 2.1, "3", uint(4), "5", true}
	bound, err := ConvertToBoundSQL(placeholderConverterSQL, fullArgs)
	assert.NoError(t, err)
	assert.EqualValues(t, placeholderBoundSQL, bound)

	// Same values, with the third one wrapped in a named argument.
	withNamed := []interface{}{1, 2.1, sql2.Named("any", "3"), uint(4), "5", true}
	bound, err = ConvertToBoundSQL(placeholderConverterSQL, withNamed)
	assert.NoError(t, err)
	assert.EqualValues(t, placeholderBoundSQL, bound)

	// Five values for six placeholders.
	tooFew := []interface{}{1, 2.1, "3", 4, "5"}
	_, err = ConvertToBoundSQL(placeholderConverterSQL, tooFew)
	assert.Error(t, err)
	assert.EqualValues(t, ErrNeedMoreArguments, err)

	selectByIn := Select("id").From("table").Where(In("a", 1, 2))
	bound, err = ToBoundSQL(selectByIn)
	assert.NoError(t, err)
	assert.EqualValues(t, "SELECT id FROM table WHERE a IN (1,2)", bound)

	bound, err = ToBoundSQL(Eq{"a": 1})
	assert.NoError(t, err)
	assert.EqualValues(t, "a=1", bound)

	_, err = ToBoundSQL(1)
	assert.Error(t, err)
	assert.EqualValues(t, ErrNotSupportType, err)
}
// TestSQL covers the package-level ToSQL: a condition and a full builder are
// both expected to render with `?` placeholders and return their arguments,
// while a plain int is expected to be rejected with ErrNotSupportType.
func TestSQL(t *testing.T) {
	wantArgs := []interface{}{1, 2}

	// A bare condition.
	text, params, err := ToSQL(In("a", 1, 2))
	assert.NoError(t, err)
	assert.EqualValues(t, "a IN (?,?)", text)
	assert.EqualValues(t, wantArgs, params)

	// A complete SELECT built around the same condition.
	selectByIn := Select("id").From("table").Where(In("a", 1, 2))
	text, params, err = ToSQL(selectByIn)
	assert.NoError(t, err)
	assert.EqualValues(t, "SELECT id FROM table WHERE a IN (?,?)", text)
	assert.EqualValues(t, wantArgs, params)

	// A value that is neither a condition nor a builder.
	_, _, err = ToSQL(1)
	assert.Error(t, err)
	assert.EqualValues(t, ErrNotSupportType, err)
}
// fiddler bundles what the tests need to run SQL through a sqlfiddle server:
// the client, the database type and the code of a created schema.
type fiddler struct {
// sessionCode is the Code field of the response returned by CreateSchema;
// it is passed back to RunSQL on every check.
sessionCode string
// dbType is the sqlfiddle database-type constant chosen from the dialect
// name in newFiddler.
dbType int
// f is the sqlfiddle client used to create the schema and run statements.
f *sqlfiddle.Fiddle
}
// readPreparationSQLFromFile returns the entire contents of the file at path
// as a string.
//
// It returns an empty string together with the underlying error when the file
// cannot be opened or read.
func readPreparationSQLFromFile(path string) (string, error) {
	file, err := os.Open(path)
	if err != nil {
		return "", err
	}
	// Register the close only after the open is known to have succeeded, so
	// Close is never invoked on a nil *os.File.
	defer file.Close()

	data, err := ioutil.ReadAll(file)
	if err != nil {
		return "", err
	}
	return string(data), nil
}
// newFiddler creates a schema on the sqlfiddle server at fiddleServerAddr and
// returns a fiddler bound to it.
//
// dbDialect must be one of MYSQL, MSSQL, POSTGRES, ORACLE or SQLITE; any other
// value yields ErrNotSupportDialectType before the server is contacted.
// preparationSQL is passed to CreateSchema, and an error from that call is
// returned unchanged.
func newFiddler(fiddleServerAddr, dbDialect, preparationSQL string) (*fiddler, error) {
	// Dialect name -> sqlfiddle database type.
	fiddleTypes := map[string]int{
		MYSQL:    sqlfiddle.Mysql5_6,
		MSSQL:    sqlfiddle.MSSQL2017,
		POSTGRES: sqlfiddle.PostgreSQL96,
		ORACLE:   sqlfiddle.Oracle11gR2,
		SQLITE:   sqlfiddle.SQLite_WebSQL,
	}

	kind, known := fiddleTypes[dbDialect]
	if !known {
		return nil, ErrNotSupportDialectType
	}

	client := sqlfiddle.NewFiddle(fiddleServerAddr)
	schema, err := client.CreateSchema(kind, preparationSQL)
	if err != nil {
		return nil, err
	}

	return &fiddler{
		sessionCode: schema.Code,
		dbType:      kind,
		f:           client,
	}, nil
}
// executableCheck runs a statement through the fiddle session and reports
// whether RunSQL accepted it.
//
// obj may be a *Builder, which is rendered with ToBoundSQL first, or a string
// that is sent as is. Any other type returns ErrNotSupportType instead of
// sending an empty statement to the server. Errors from ToBoundSQL and RunSQL
// are returned unchanged.
func (f *fiddler) executableCheck(obj interface{}) error {
	var sql string
	var err error
	// Bind the asserted value once instead of re-asserting in every case.
	switch v := obj.(type) {
	case *Builder:
		sql, err = v.ToBoundSQL()
		if err != nil {
			return err
		}
	case string:
		sql = v
	default:
		return ErrNotSupportType
	}

	_, err = f.f.RunSQL(f.dbType, f.sessionCode, sql)
	return err
}
// TestReadPreparationSQLFromFile loads the MySQL fixture from testdata,
// expects no error, and prints what was read to standard output.
func TestReadPreparationSQLFromFile(t *testing.T) {
	const fixture = "testdata/mysql_fiddle_data.sql"

	contents, err := readPreparationSQLFromFile(fixture)
	assert.NoError(t, err)

	fmt.Println(contents)
}
// TestNewFiddler creates a fiddler for the MySQL fixture and expects the
// resulting session code to be non-empty.
//
// Each step stops the test when it fails: newFiddler returns a nil fiddler
// alongside an error, so reading f.sessionCode after a failure would panic
// rather than report the real cause.
func TestNewFiddler(t *testing.T) {
	sqlFromFile, err := readPreparationSQLFromFile("testdata/mysql_fiddle_data.sql")
	if !assert.NoError(t, err) {
		return
	}

	f, err := newFiddler("", MYSQL, sqlFromFile)
	if !assert.NoError(t, err) {
		return
	}
	assert.NotEmpty(t, f.sessionCode)
}
// TestExecutableCheck creates a fiddler for the MySQL fixture, then expects a
// SELECT on table1 to run cleanly and a SELECT on table3 to fail.
//
// Each setup step stops the test when it fails: newFiddler returns a nil
// fiddler alongside an error, so using f after a failure would panic rather
// than report the real cause.
func TestExecutableCheck(t *testing.T) {
	sqlFromFile, err := readPreparationSQLFromFile("testdata/mysql_fiddle_data.sql")
	if !assert.NoError(t, err) {
		return
	}

	f, err := newFiddler("", MYSQL, sqlFromFile)
	if !assert.NoError(t, err) {
		return
	}
	assert.NotEmpty(t, f.sessionCode)

	assert.NoError(t, f.executableCheck("SELECT * FROM table1"))

	err = f.executableCheck("SELECT * FROM table3")
	assert.Error(t, err)
}
// TestToSQLInDifferentDialects renders the same SELECT ... WHERE a = ? AND
// b <> ? through each dialect constructor and checks the placeholder syntax
// and argument list each one is expected to produce:
//   - Postgres: $1, $2 with plain arguments
//   - MySQL and SQLite: ? with plain arguments
//   - MsSQL: @p1, @p2 with sql.NamedArg arguments
//   - Oracle: :p1, :p2 with sql.NamedArg arguments
//
// For MsSQL and Oracle it also passes a sql.NamedArg inside the condition and
// expects the same p1/p2 output as for a plain value.
func TestToSQLInDifferentDialects(t *testing.T) {
	// Expected argument lists, shared by the dialects that agree on them.
	plainArgs := []interface{}{"1", "100"}
	namedArgs := []interface{}{sql2.Named("p1", "1"), sql2.Named("p2", "100")}

	// Postgres
	postgres := Postgres().Select().From("table1").Where(Eq{"a": "1"}.And(Neq{"b": "100"}))
	query, params, err := postgres.ToSQL()
	assert.NoError(t, err)
	assert.EqualValues(t, "SELECT * FROM table1 WHERE a=$1 AND b<>$2", query)
	assert.EqualValues(t, plainArgs, params)

	// MySQL
	mysql := MySQL().Select().From("table1").Where(Eq{"a": "1"}.And(Neq{"b": "100"}))
	query, params, err = mysql.ToSQL()
	assert.NoError(t, err)
	assert.EqualValues(t, "SELECT * FROM table1 WHERE a=? AND b<>?", query)
	assert.EqualValues(t, plainArgs, params)

	// MsSQL with plain values
	mssql := MsSQL().Select().From("table1").Where(Eq{"a": "1"}.And(Neq{"b": "100"}))
	query, params, err = mssql.ToSQL()
	assert.NoError(t, err)
	assert.EqualValues(t, "SELECT * FROM table1 WHERE a=@p1 AND b<>@p2", query)
	assert.EqualValues(t, namedArgs, params)

	// MsSQL with a sql.NamedArg inside the condition
	mssqlNamed := MsSQL().Select().From("table1").Where(Eq{"a": sql2.NamedArg{Name: "param", Value: "1"}}.And(Neq{"b": "100"}))
	query, params, err = mssqlNamed.ToSQL()
	assert.NoError(t, err)
	assert.EqualValues(t, "SELECT * FROM table1 WHERE a=@p1 AND b<>@p2", query)
	assert.EqualValues(t, namedArgs, params)

	// Oracle with plain values
	oracle := Oracle().Select().From("table1").Where(Eq{"a": "1"}.And(Neq{"b": "100"}))
	query, params, err = oracle.ToSQL()
	assert.NoError(t, err)
	assert.EqualValues(t, "SELECT * FROM table1 WHERE a=:p1 AND b<>:p2", query)
	assert.EqualValues(t, namedArgs, params)

	// Oracle with a sql.NamedArg inside the condition
	oracleNamed := Oracle().Select().From("table1").Where(Eq{"a": sql2.Named("a", "1")}.And(Neq{"b": "100"}))
	query, params, err = oracleNamed.ToSQL()
	assert.NoError(t, err)
	assert.EqualValues(t, "SELECT * FROM table1 WHERE a=:p1 AND b<>:p2", query)
	assert.EqualValues(t, namedArgs, params)

	// SQLite
	sqlite := SQLite().Select().From("table1").Where(Eq{"a": "1"}.And(Neq{"b": "100"}))
	query, params, err = sqlite.ToSQL()
	assert.NoError(t, err)
	assert.EqualValues(t, "SELECT * FROM table1 WHERE a=? AND b<>?", query)
	assert.EqualValues(t, plainArgs, params)
}
// TestToSQLInjectionHarmlessDisposal checks how values are written into the
// generated SQL:
//   - a string containing a single quote is expected to appear in bound SQL
//     with that quote doubled,
//   - a nil value in an UPDATE is expected to be written as the literal null,
//     both in bound SQL and in placeholder SQL, where it is also expected to
//     be left out of the argument list.
func TestToSQLInjectionHarmlessDisposal(t *testing.T) {
	// Value carrying a quote followed by a second statement.
	hostile := MySQL().Select("*").From("table1").Where(Cond(Eq{"name": "cat';truncate table table1;"}))
	bound, err := hostile.ToBoundSQL()
	assert.NoError(t, err)
	assert.EqualValues(t, "SELECT * FROM table1 WHERE name='cat'';truncate table table1;'", bound)

	// UPDATE with a nil value, rendered as bound SQL.
	bound, err = MySQL().Update(Eq{`a`: 1, `b`: nil}).From(`table1`).ToBoundSQL()
	assert.NoError(t, err)
	assert.EqualValues(t, "UPDATE table1 SET a=1,b=null", bound)

	// The same UPDATE rendered with placeholders.
	query, args, err := MySQL().Update(Eq{`a`: 1, `b`: nil}).From(`table1`).ToSQL()
	assert.NoError(t, err)
	assert.EqualValues(t, "UPDATE table1 SET a=?,b=null", query)
	assert.EqualValues(t, []interface{}{1}, args)
}