Improve quote policy #1567

Merged
lunny merged 8 commits from lunny/quote_policy2 into master 2020-03-06 07:48:36 +00:00
18 changed files with 490 additions and 256 deletions

View File

@ -22,9 +22,10 @@ steps:
commands:
- make test-sqlite
- TEST_CACHE_ENABLE=true make test-sqlite
- TEST_QUOTE_POLICY=reserved make test-sqlite
- go test ./caches/... ./contexts/... ./convert/... ./core/... ./dialects/... \
./log/... ./migrate/... ./names/... ./schemas/... ./tags/... \
./internal/json/... ./internal/statements/... ./internal/utils/... \
./log/... ./migrate/... ./names/... ./schemas/... ./tags/...
when:
event:
@ -44,6 +45,7 @@ steps:
commands:
- make test-mysql
- TEST_CACHE_ENABLE=true make test-mysql
- TEST_QUOTE_POLICY=reserved make test-mysql
when:
event:
- push
@ -62,6 +64,7 @@ steps:
commands:
- make test-mysql
- TEST_CACHE_ENABLE=true make test-mysql
- TEST_QUOTE_POLICY=reserved make test-mysql
when:
event:
- push
@ -82,6 +85,7 @@ steps:
commands:
- make test-mysql
- TEST_CACHE_ENABLE=true make test-mysql
- TEST_QUOTE_POLICY=reserved make test-mysql
when:
event:
- push
@ -102,6 +106,7 @@ steps:
commands:
- make test-mymysql
- TEST_CACHE_ENABLE=true make test-mymysql
- TEST_QUOTE_POLICY=reserved make test-mymysql
when:
event:
- push
@ -120,6 +125,7 @@ steps:
commands:
- make test-postgres
- TEST_CACHE_ENABLE=true make test-postgres
- TEST_QUOTE_POLICY=reserved make test-postgres
when:
event:
- push
@ -141,6 +147,7 @@ steps:
commands:
- make test-postgres
- TEST_CACHE_ENABLE=true make test-postgres
- TEST_QUOTE_POLICY=reserved make test-postgres
when:
event:
- push
@ -159,6 +166,7 @@ steps:
commands:
- make test-mssql
- TEST_CACHE_ENABLE=true make test-mssql
- TEST_QUOTE_POLICY=reserved make test-mssql
when:
event:
- push
@ -177,6 +185,7 @@ steps:
commands:
- make test-tidb
- TEST_CACHE_ENABLE=true make test-tidb
- TEST_QUOTE_POLICY=reserved make test-tidb
when:
event:
- push

View File

@ -39,6 +39,7 @@ TEST_TIDB_USERNAME ?= root
TEST_TIDB_PASSWORD ?=
TEST_CACHE_ENABLE ?= false
TEST_QUOTE_POLICY ?= always
.PHONY: all
all: build
@ -135,73 +136,73 @@ test-cockroach\#%: go-check
.PNONY: test-mssql
test-mssql: go-check
$(GO) test -v -race -db=mssql -cache=$(TEST_CACHE_ENABLE) \
$(GO) test -v -race -db=mssql -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \
-conn_str="server=$(TEST_MSSQL_HOST);user id=$(TEST_MSSQL_USERNAME);password=$(TEST_MSSQL_PASSWORD);database=$(TEST_MSSQL_DBNAME)" \
-coverprofile=mssql.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
-coverprofile=mssql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
.PNONY: test-mssql\#%
test-mssql\#%: go-check
$(GO) test -v -race -run $* -db=mssql -cache=$(TEST_CACHE_ENABLE) \
$(GO) test -v -race -run $* -db=mssql -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \
-conn_str="server=$(TEST_MSSQL_HOST);user id=$(TEST_MSSQL_USERNAME);password=$(TEST_MSSQL_PASSWORD);database=$(TEST_MSSQL_DBNAME)" \
-coverprofile=mssql.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
-coverprofile=mssql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
.PNONY: test-mymysql
test-mymysql: go-check
$(GO) test -v -race -db=mymysql -cache=$(TEST_CACHE_ENABLE) \
$(GO) test -v -race -db=mymysql -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \
-conn_str="tcp:$(TEST_MYSQL_HOST)*$(TEST_MYSQL_DBNAME)/$(TEST_MYSQL_USERNAME)/$(TEST_MYSQL_PASSWORD)" \
-coverprofile=mymysql.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
-coverprofile=mymysql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
.PNONY: test-mymysql\#%
test-mymysql\#%: go-check
$(GO) test -v -race -run $* -db=mymysql -cache=$(TEST_CACHE_ENABLE) \
$(GO) test -v -race -run $* -db=mymysql -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \
-conn_str="tcp:$(TEST_MYSQL_HOST)*$(TEST_MYSQL_DBNAME)/$(TEST_MYSQL_USERNAME)/$(TEST_MYSQL_PASSWORD)" \
-coverprofile=mymysql.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
-coverprofile=mymysql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
.PNONY: test-mysql
test-mysql: go-check
$(GO) test -v -race -db=mysql -cache=$(TEST_CACHE_ENABLE) \
$(GO) test -v -race -db=mysql -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \
-conn_str="$(TEST_MYSQL_USERNAME):$(TEST_MYSQL_PASSWORD)@tcp($(TEST_MYSQL_HOST))/$(TEST_MYSQL_DBNAME)?charset=$(TEST_MYSQL_CHARSET)" \
-coverprofile=mysql.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
-coverprofile=mysql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
.PHONY: test-mysql\#%
test-mysql\#%: go-check
$(GO) test -v -race -run $* -db=mysql -cache=$(TEST_CACHE_ENABLE) \
$(GO) test -v -race -run $* -db=mysql -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \
-conn_str="$(TEST_MYSQL_USERNAME):$(TEST_MYSQL_PASSWORD)@tcp($(TEST_MYSQL_HOST))/$(TEST_MYSQL_DBNAME)?charset=$(TEST_MYSQL_CHARSET)" \
-coverprofile=mysql.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
-coverprofile=mysql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
.PNONY: test-postgres
test-postgres: go-check
$(GO) test -v -race -db=postgres -schema='$(TEST_PGSQL_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \
-conn_str="postgres://$(TEST_PGSQL_USERNAME):$(TEST_PGSQL_PASSWORD)@$(TEST_PGSQL_HOST)/$(TEST_PGSQL_DBNAME)?sslmode=disable" \
-coverprofile=postgres.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
-quote=$(TEST_QUOTE_POLICY) -coverprofile=postgres.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
.PHONY: test-postgres\#%
test-postgres\#%: go-check
$(GO) test -v -race -run $* -db=postgres -schema='$(TEST_PGSQL_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \
-conn_str="postgres://$(TEST_PGSQL_USERNAME):$(TEST_PGSQL_PASSWORD)@$(TEST_PGSQL_HOST)/$(TEST_PGSQL_DBNAME)?sslmode=disable" \
-coverprofile=postgres.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
-quote=$(TEST_QUOTE_POLICY) -coverprofile=postgres.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
.PHONY: test-sqlite
test-sqlite: go-check
$(GO) test -v -race -cache=$(TEST_CACHE_ENABLE) -db=sqlite3 -conn_str="./test.db?cache=shared&mode=rwc" \
-coverprofile=sqlite.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
-quote=$(TEST_QUOTE_POLICY) -coverprofile=sqlite.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
.PHONY: test-sqlite\#%
test-sqlite\#%: go-check
$(GO) test -v -race -run $* -cache=$(TEST_CACHE_ENABLE) -db=sqlite3 -conn_str="./test.db?cache=shared&mode=rwc" \
-coverprofile=sqlite.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
-quote=$(TEST_QUOTE_POLICY) -coverprofile=sqlite.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
.PNONY: test-tidb
test-tidb: go-check
$(GO) test -v -race -db=mysql -cache=$(TEST_CACHE_ENABLE) -ignore_select_update=true \
-conn_str="$(TEST_TIDB_USERNAME):$(TEST_TIDB_PASSWORD)@tcp($(TEST_TIDB_HOST))/$(TEST_TIDB_DBNAME)" \
-coverprofile=tidb.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
-quote=$(TEST_QUOTE_POLICY) -coverprofile=tidb.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
.PHONY: test-tidb\#%
test-tidb\#%: go-check
$(GO) test -v -race -run $* -db=mysql -cache=$(TEST_CACHE_ENABLE) -ignore_select_update=true \
-conn_str="$(TEST_TIDB_USERNAME):$(TEST_TIDB_PASSWORD)@tcp($(TEST_TIDB_HOST))/$(TEST_TIDB_DBNAME)" \
-coverprofile=tidb.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
-quote=$(TEST_QUOTE_POLICY) -coverprofile=tidb.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
.PHONY: vet
vet:

View File

@ -41,6 +41,7 @@ type Dialect interface {
IsReserved(string) bool
Quoter() schemas.Quoter
SetQuotePolicy(quotePolicy QuotePolicy)
AutoIncrStr() string
@ -79,6 +80,11 @@ type Base struct {
db *core.DB
dialect Dialect
uri *URI
quoter schemas.Quoter
}
func (b *Base) Quoter() schemas.Quoter {
return b.quoter
}
func (b *Base) DB() *core.DB {
@ -210,7 +216,7 @@ func (db *Base) CreateIndexSQL(tableName string, index *schemas.Index) string {
idxName = index.XName(tableName)
return fmt.Sprintf("CREATE%s INDEX %v ON %v (%v)", unique,
quoter.Quote(idxName), quoter.Quote(tableName),
quoter.Quote(strings.Join(index.Cols, quoter.ReverseQuote(","))))
quoter.Join(index.Cols, ","))
}
func (db *Base) DropIndexSQL(tableName string, index *schemas.Index) string {
@ -258,7 +264,7 @@ func (b *Base) CreateTableSQL(table *schemas.Table, tableName, storeEngine, char
if len(pkList) > 1 {
sql += "PRIMARY KEY ( "
sql += quoter.Quote(strings.Join(pkList, quoter.ReverseQuote(",")))
sql += quoter.Join(pkList, ",")
sql += " ), "
}

View File

@ -26,20 +26,36 @@ func (s *QuoteFilter) Do(sql string) string {
return sql
}
prefix, suffix := s.quoter[0][0], s.quoter[1][0]
raw := []byte(sql)
for i, cnt := 0, 0; i < len(raw); i = i + 1 {
if raw[i] == '`' {
if cnt%2 == 0 {
raw[i] = prefix
} else {
raw[i] = suffix
var buf strings.Builder
buf.Grow(len(sql))
var beginSingleQuote bool
for i := 0; i < len(sql); i++ {
if !beginSingleQuote && sql[i] == '`' {
var j = i + 1
for ; j < len(sql); j++ {
if sql[j] == '`' {
break
}
}
cnt++
word := sql[i+1 : j]
isReserved := s.quoter.IsReserved(word)
if isReserved {
buf.WriteByte(s.quoter.Prefix)
}
buf.WriteString(word)
if isReserved {
buf.WriteByte(s.quoter.Suffix)
}
i = j
} else {
if sql[i] == '\'' {
beginSingleQuote = !beginSingleQuote
}
buf.WriteByte(sql[i])
}
}
return string(raw)
return buf.String()
}
// SeqFilter filter SQL replace ?, ? ... to $1, $2 ...

View File

@ -9,13 +9,30 @@ import (
)
func TestQuoteFilter_Do(t *testing.T) {
f := QuoteFilter{schemas.Quoter{"[", "]"}}
sql := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?"
res := f.Do(sql)
assert.EqualValues(t,
"SELECT [COLUMN_NAME] FROM [INFORMATION_SCHEMA].[COLUMNS] WHERE [TABLE_SCHEMA] = ? AND [TABLE_NAME] = ? AND [COLUMN_NAME] = ?",
res,
)
f := QuoteFilter{schemas.Quoter{'[', ']', schemas.AlwaysReserve}}
var kases = []struct {
source string
expected string
}{
{
"SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?",
"SELECT [COLUMN_NAME] FROM [INFORMATION_SCHEMA].[COLUMNS] WHERE [TABLE_SCHEMA] = ? AND [TABLE_NAME] = ? AND [COLUMN_NAME] = ?",
},
{
"SELECT 'abc```test```''', `a` FROM b",
"SELECT 'abc```test```''', [a] FROM b",
},
{
"UPDATE table SET `a` = ~ `a`, `b`='abc`'",
"UPDATE table SET [a] = ~ [a], [b]='abc`'",
},
}
for _, kase := range kases {
t.Run(kase.source, func(t *testing.T) {
assert.EqualValues(t, kase.expected, f.Do(kase.source))
})
}
}
func TestSeqFilter(t *testing.T) {

View File

@ -204,6 +204,8 @@ var (
"EXIT": true,
"PROC": true,
}
mssqlQuoter = schemas.Quoter{'[', ']', schemas.AlwaysReserve}
)
type mssql struct {
@ -211,6 +213,7 @@ type mssql struct {
}
func (db *mssql) Init(d *core.DB, uri *URI) error {
db.quoter = mssqlQuoter
return db.Base.Init(d, db, uri)
}
@ -283,12 +286,25 @@ func (db *mssql) SupportInsertMany() bool {
}
func (db *mssql) IsReserved(name string) bool {
_, ok := mssqlReservedWords[name]
_, ok := mssqlReservedWords[strings.ToUpper(name)]
return ok
}
func (db *mssql) Quoter() schemas.Quoter {
return schemas.Quoter{"[", "]"}
func (db *mssql) SetQuotePolicy(quotePolicy QuotePolicy) {
switch quotePolicy {
case QuotePolicyNone:
var q = mssqlQuoter
q.IsReserved = schemas.AlwaysNoReserve
db.quoter = q
case QuotePolicyReserved:
var q = mssqlQuoter
q.IsReserved = db.IsReserved
db.quoter = q
case QuotePolicyAlways:
fallthrough
default:
db.quoter = mssqlQuoter
}
}
func (db *mssql) SupportEngine() bool {

View File

@ -161,6 +161,8 @@ var (
"YEAR_MONTH": true,
"ZEROFILL": true,
}
mysqlQuoter = schemas.Quoter{'`', '`', schemas.AlwaysReserve}
)
type mysql struct {
@ -178,6 +180,7 @@ type mysql struct {
}
func (db *mysql) Init(d *core.DB, uri *URI) error {
db.quoter = mysqlQuoter
return db.Base.Init(d, db, uri)
}
@ -272,14 +275,10 @@ func (db *mysql) SupportInsertMany() bool {
}
func (db *mysql) IsReserved(name string) bool {
_, ok := mysqlReservedWords[name]
_, ok := mysqlReservedWords[strings.ToUpper(name)]
return ok
}
func (db *mysql) Quoter() schemas.Quoter {
return schemas.Quoter{"`", "`"}
}
func (db *mysql) SupportEngine() bool {
return true
}
@ -458,6 +457,23 @@ func (db *mysql) GetTables(ctx context.Context) ([]*schemas.Table, error) {
return tables, nil
}
func (db *mysql) SetQuotePolicy(quotePolicy QuotePolicy) {
switch quotePolicy {
case QuotePolicyNone:
var q = mysqlQuoter
q.IsReserved = schemas.AlwaysNoReserve
db.quoter = q
case QuotePolicyReserved:
var q = mysqlQuoter
q.IsReserved = db.IsReserved
db.quoter = q
case QuotePolicyAlways:
fallthrough
default:
db.quoter = mysqlQuoter
}
}
func (db *mysql) GetIndexes(ctx context.Context, tableName string) (map[string]*schemas.Index, error) {
args := []interface{}{db.uri.DBName, tableName}
s := "SELECT `INDEX_NAME`, `NON_UNIQUE`, `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?"
@ -538,7 +554,7 @@ func (db *mysql) CreateTableSQL(table *schemas.Table, tableName, storeEngine, ch
if len(pkList) > 1 {
sql += "PRIMARY KEY ( "
sql += quoter.Quote(strings.Join(pkList, quoter.ReverseQuote(",")))
sql += quoter.Join(pkList, ",")
sql += " ), "
}

View File

@ -498,6 +498,8 @@ var (
"YEAR": true,
"ZONE": true,
}
oracleQuoter = schemas.Quoter{'[', ']', schemas.AlwaysReserve}
)
type oracle struct {
@ -505,6 +507,7 @@ type oracle struct {
}
func (db *oracle) Init(d *core.DB, uri *URI) error {
db.quoter = oracleQuoter
return db.Base.Init(d, db, uri)
}
@ -549,14 +552,10 @@ func (db *oracle) SupportInsertMany() bool {
}
func (db *oracle) IsReserved(name string) bool {
_, ok := oracleReservedWords[name]
_, ok := oracleReservedWords[strings.ToUpper(name)]
return ok
}
func (db *oracle) Quoter() schemas.Quoter {
return schemas.Quoter{"\"", "\""}
}
func (db *oracle) SupportEngine() bool {
return false
}
@ -601,7 +600,7 @@ func (db *oracle) CreateTableSQL(table *schemas.Table, tableName, storeEngine, c
if len(pkList) > 0 {
sql += "PRIMARY KEY ( "
sql += quoter.Quote(strings.Join(pkList, quoter.ReverseQuote(",")))
sql += quoter.Join(pkList, ",")
sql += " ), "
}
@ -620,6 +619,23 @@ func (db *oracle) CreateTableSQL(table *schemas.Table, tableName, storeEngine, c
return sql
}
func (db *oracle) SetQuotePolicy(quotePolicy QuotePolicy) {
switch quotePolicy {
case QuotePolicyNone:
var q = oracleQuoter
q.IsReserved = schemas.AlwaysNoReserve
db.quoter = q
case QuotePolicyReserved:
var q = oracleQuoter
q.IsReserved = db.IsReserved
db.quoter = q
case QuotePolicyAlways:
fallthrough
default:
db.quoter = oracleQuoter
}
}
func (db *oracle) IndexCheckSQL(tableName, idxName string) (string, []interface{}) {
args := []interface{}{tableName, idxName}
return `SELECT INDEX_NAME FROM USER_INDEXES ` +

View File

@ -766,6 +766,8 @@ var (
"YES": true,
"ZONE": true,
}
postgresQuoter = schemas.Quoter{'"', '"', schemas.AlwaysReserve}
)
const postgresPublicSchema = "public"
@ -775,6 +777,7 @@ type postgres struct {
}
func (db *postgres) Init(d *core.DB, uri *URI) error {
db.quoter = postgresQuoter
err := db.Base.Init(d, db, uri)
if err != nil {
return err
@ -785,6 +788,35 @@ func (db *postgres) Init(d *core.DB, uri *URI) error {
return nil
}
func (db *postgres) needQuote(name string) bool {
if db.IsReserved(name) {
return true
}
for _, c := range name {
if c >= 'A' && c <= 'Z' {
return true
}
}
return false
}
func (db *postgres) SetQuotePolicy(quotePolicy QuotePolicy) {
switch quotePolicy {
case QuotePolicyNone:
var q = postgresQuoter
q.IsReserved = schemas.AlwaysNoReserve
db.quoter = q
case QuotePolicyReserved:
var q = postgresQuoter
q.IsReserved = db.needQuote
db.quoter = q
case QuotePolicyAlways:
fallthrough
default:
db.quoter = postgresQuoter
}
}
func (db *postgres) DefaultSchema() string {
return postgresPublicSchema
}
@ -857,14 +889,10 @@ func (db *postgres) SupportInsertMany() bool {
}
func (db *postgres) IsReserved(name string) bool {
_, ok := postgresReservedWords[name]
_, ok := postgresReservedWords[strings.ToUpper(name)]
return ok
}
func (db *postgres) Quoter() schemas.Quoter {
return schemas.Quoter{`"`, `"`}
}
func (db *postgres) AutoIncrStr() string {
return ""
}

15
dialects/quote.go Normal file
View File

@ -0,0 +1,15 @@
// Copyright 2020 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package dialects
// QuotePolicy describes quote handle policy
type QuotePolicy int
// All QuotePolicies
const (
QuotePolicyAlways QuotePolicy = iota
QuotePolicyNone
QuotePolicyReserved
)

View File

@ -143,6 +143,8 @@ var (
"WITH": true,
"WITHOUT": true,
}
sqlite3Quoter = schemas.Quoter{'`', '`', schemas.AlwaysReserve}
)
type sqlite3 struct {
@ -150,9 +152,27 @@ type sqlite3 struct {
}
func (db *sqlite3) Init(d *core.DB, uri *URI) error {
db.quoter = sqlite3Quoter
return db.Base.Init(d, db, uri)
}
func (db *sqlite3) SetQuotePolicy(quotePolicy QuotePolicy) {
switch quotePolicy {
case QuotePolicyNone:
var q = sqlite3Quoter
q.IsReserved = schemas.AlwaysNoReserve
db.quoter = q
case QuotePolicyReserved:
var q = sqlite3Quoter
q.IsReserved = db.IsReserved
db.quoter = q
case QuotePolicyAlways:
fallthrough
default:
db.quoter = sqlite3Quoter
}
}
func (db *sqlite3) SQLType(c *schemas.Column) string {
switch t := c.SQLType.Name; t {
case schemas.Bool:
@ -196,14 +216,10 @@ func (db *sqlite3) SupportInsertMany() bool {
}
func (db *sqlite3) IsReserved(name string) bool {
_, ok := sqlite3ReservedWords[name]
_, ok := sqlite3ReservedWords[strings.ToUpper(name)]
return ok
}
func (db *sqlite3) Quoter() schemas.Quoter {
return schemas.Quoter{"`", "`"}
}
func (db *sqlite3) AutoIncrStr() string {
return "AUTOINCREMENT"
}
@ -250,18 +266,24 @@ func (db *sqlite3) ForUpdateSQL(query string) string {
}
func (db *sqlite3) IsColumnExist(ctx context.Context, tableName, colName string) (bool, error) {
args := []interface{}{tableName}
query := "SELECT name FROM sqlite_master WHERE type='table' and name = ? and ((sql like '%`" + colName + "`%') or (sql like '%[" + colName + "]%'))"
rows, err := db.DB().QueryContext(ctx, query, args...)
query := "SELECT * FROM " + tableName + " LIMIT 0"
rows, err := db.DB().QueryContext(ctx, query)
if err != nil {
return false, err
}
defer rows.Close()
if rows.Next() {
return true, nil
cols, err := rows.Columns()
if err != nil {
return false, err
}
for _, col := range cols {
if strings.EqualFold(col, colName) {
return true, nil
}
}
return false, nil
}

View File

@ -54,6 +54,10 @@ func (engine *Engine) GetCacher(tableName string) caches.Cacher {
return engine.cacherMgr.GetCacher(tableName)
}
func (engine *Engine) SetQuotePolicy(quotePolicy dialects.QuotePolicy) {
engine.dialect.SetQuotePolicy(quotePolicy)
}
// BufferSize sets buffer size for iterate
func (engine *Engine) BufferSize(size int) *Session {
session := engine.NewSession()

View File

@ -9,6 +9,7 @@ import (
"time"
"xorm.io/xorm/caches"
"xorm.io/xorm/dialects"
"xorm.io/xorm/log"
"xorm.io/xorm/names"
)
@ -180,6 +181,13 @@ func (eg *EngineGroup) SetPolicy(policy GroupPolicy) *EngineGroup {
return eg
}
func (eg *EngineGroup) SetQuotePolicy(quotePolicy dialects.QuotePolicy) {
eg.Engine.SetQuotePolicy(quotePolicy)
for i := 0; i < len(eg.slaves); i++ {
eg.slaves[i].SetQuotePolicy(quotePolicy)
}
}
// SetTableMapper set the table name mapping rule
func (eg *EngineGroup) SetTableMapper(mapper names.Mapper) {
eg.Engine.SetTableMapper(mapper)

View File

@ -104,6 +104,7 @@ type EngineInterface interface {
SetMapper(names.Mapper)
SetMaxOpenConns(int)
SetMaxIdleConns(int)
SetQuotePolicy(dialects.QuotePolicy)
SetSchema(string)
SetTableMapper(names.Mapper)
SetTZDatabase(tz *time.Location)

View File

@ -8,14 +8,29 @@ import (
"strings"
)
// Quoter represents two quote characters
type Quoter [2]string
// Quoter represents a quoter to the SQL table name and column name
type Quoter struct {
Prefix byte
Suffix byte
IsReserved func(string) bool
}
// CommonQuoter represetns a common quoter
var CommonQuoter = Quoter{"`", "`"}
var (
// AlwaysFalseReverse always think it's not a reverse word
AlwaysNoReserve = func(string) bool { return false }
// AlwaysReverse always reverse the word
AlwaysReserve = func(string) bool { return true }
// CommanQuoteMark represnets the common quote mark
CommanQuoteMark byte = '`'
// CommonQuoter represetns a common quoter
CommonQuoter = Quoter{CommanQuoteMark, CommanQuoteMark, AlwaysReserve}
)
func (q Quoter) IsEmpty() bool {
return q[0] == "" && q[1] == ""
return q.Prefix == 0 && q.Suffix == 0
}
func (q Quoter) Quote(s string) string {
@ -24,42 +39,6 @@ func (q Quoter) Quote(s string) string {
return buf.String()
}
func (q Quoter) Replace(sql string, newQuoter Quoter) string {
if q.IsEmpty() {
return sql
}
if newQuoter.IsEmpty() {
var buf strings.Builder
for i := 0; i < len(sql); i = i + 1 {
if sql[i] != q[0][0] && sql[i] != q[1][0] {
_ = buf.WriteByte(sql[i])
}
}
return buf.String()
}
prefix, suffix := newQuoter[0][0], newQuoter[1][0]
var buf strings.Builder
for i, cnt := 0, 0; i < len(sql); i = i + 1 {
if cnt == 0 && sql[i] == q[0][0] {
_ = buf.WriteByte(prefix)
cnt = 1
} else if cnt == 1 && sql[i] == q[1][0] {
_ = buf.WriteByte(suffix)
cnt = 0
} else {
_ = buf.WriteByte(sql[i])
}
}
return buf.String()
}
func (q Quoter) ReverseQuote(s string) string {
reverseQuoter := Quoter{q[1], q[0]}
return reverseQuoter.Quote(s)
}
// Trim removes quotes from s
func (q Quoter) Trim(s string) string {
if len(s) < 2 {
@ -69,10 +48,10 @@ func (q Quoter) Trim(s string) string {
var buf strings.Builder
for i := 0; i < len(s); i++ {
switch {
case i == 0 && s[i:i+1] == q[0]:
case i == len(s)-1 && s[i:i+1] == q[1]:
case s[i:i+1] == q[1] && s[i+1] == '.':
case s[i:i+1] == q[0] && s[i-1] == '.':
case i == 0 && s[i] == q.Prefix:
case i == len(s)-1 && s[i] == q.Suffix:
case s[i] == q.Suffix && s[i+1] == '.':
case s[i] == q.Prefix && s[i-1] == '.':
default:
buf.WriteByte(s[i])
}
@ -81,31 +60,8 @@ func (q Quoter) Trim(s string) string {
}
func (q Quoter) Join(a []string, sep string) string {
switch len(a) {
case 0:
return ""
case 1:
return a[0]
}
n := len(sep) * (len(a) - 1)
for i := 0; i < len(a); i++ {
n += len(a[i])
}
var b strings.Builder
b.Grow(n)
for i, s := range a {
if i > 0 {
b.WriteString(sep)
}
if q[0] != "" && s != "*" {
b.WriteString(q[0])
}
b.WriteString(strings.TrimSpace(s))
if q[1] != "" && s != "*" {
b.WriteString(q[1])
}
}
q.JoinWrite(&b, a, sep)
return b.String()
}
@ -126,23 +82,113 @@ func (q Quoter) JoinWrite(b *strings.Builder, a []string, sep string) error {
return err
}
}
if q[0] != "" && s != "*" && s[0] != '`' {
if _, err := b.WriteString(q[0]); err != nil {
return err
}
}
if _, err := b.WriteString(strings.TrimSpace(s)); err != nil {
return err
}
if q[1] != "" && s != "*" && s[0] != '`' {
if _, err := b.WriteString(q[1]); err != nil {
return err
}
if s != "*" {
q.QuoteTo(b, strings.TrimSpace(s))
}
}
return nil
}
func findWord(v string, start int) int {
for j := start; j < len(v); j++ {
switch v[j] {
case '.', ' ':
return j
}
}
return len(v)
}
func findStart(value string, start int) int {
if value[start] == '.' {
return start + 1
}
if value[start] != ' ' {
return start
}
var k int
for j := start; j < len(value); j++ {
if value[j] != ' ' {
k = j
break
}
}
if k-1 == len(value) {
return len(value)
}
if (value[k] == 'A' || value[k] == 'a') && (value[k+1] == 'S' || value[k+1] == 's') {
k = k + 2
}
for j := k; j < len(value); j++ {
if value[j] != ' ' {
return j
}
}
return len(value)
}
func (q Quoter) quoteWordTo(buf *strings.Builder, word string) error {
var realWord = word
if (word[0] == CommanQuoteMark && word[len(word)-1] == CommanQuoteMark) ||
(word[0] == q.Prefix && word[len(word)-1] == q.Suffix) {
realWord = word[1 : len(word)-1]
}
if q.IsEmpty() {
_, err := buf.WriteString(realWord)
return err
}
isReserved := q.IsReserved(realWord)
if isReserved {
if err := buf.WriteByte(q.Prefix); err != nil {
return err
}
}
if _, err := buf.WriteString(realWord); err != nil {
return err
}
if isReserved {
return buf.WriteByte(q.Suffix)
}
return nil
}
// QuoteTo quotes the table or column names. i.e. if the quotes are [ and ]
// name -> [name]
// `name` -> [name]
// [name] -> [name]
// schema.name -> [schema].[name]
// `schema`.`name` -> [schema].[name]
// `schema`.name -> [schema].[name]
// schema.`name` -> [schema].[name]
// [schema].name -> [schema].[name]
// schema.[name] -> [schema].[name]
// name AS a -> [name] AS a
// schema.name AS a -> [schema].[name] AS a
func (q Quoter) QuoteTo(buf *strings.Builder, value string) error {
var i int
for i < len(value) {
start := findStart(value, i)
if start > i {
if _, err := buf.WriteString(value[i:start]); err != nil {
return err
}
}
var nextEnd = findWord(value, start)
if err := q.quoteWordTo(buf, value[start:nextEnd]); err != nil {
return err
}
i = nextEnd
}
return nil
}
// Strings quotes a slice of string
func (q Quoter) Strings(s []string) []string {
var res = make([]string, 0, len(s))
for _, a := range s {
@ -150,64 +196,3 @@ func (q Quoter) Strings(s []string) []string {
}
return res
}
func (q Quoter) QuoteTo(buf *strings.Builder, value string) {
if q.IsEmpty() {
buf.WriteString(value)
return
}
prefix, suffix := q[0][0], q[1][0]
lastCh := 0 // 0 prefix, 1 char, 2 suffix
i := 0
for i < len(value) {
// start of a token; might be already quoted
if value[i] == '.' {
_ = buf.WriteByte('.')
lastCh = 1
i++
} else if value[i] == prefix || value[i] == '`' {
// Has quotes; skip/normalize `name` to prefix+name+sufix
var ch byte
if value[i] == prefix {
ch = suffix
} else {
ch = '`'
}
_ = buf.WriteByte(prefix)
i++
lastCh = 0
for ; i < len(value) && value[i] != ch && value[i] != ' '; i++ {
_ = buf.WriteByte(value[i])
lastCh = 1
}
_ = buf.WriteByte(suffix)
lastCh = 2
i++
} else if value[i] == ' ' {
if lastCh != 2 {
_ = buf.WriteByte(suffix)
lastCh = 2
}
// a AS b or a b
for ; i < len(value); i++ {
if value[i] != ' ' && value[i-1] == ' ' && (len(value) > i+1 && !strings.EqualFold(value[i:i+2], "AS")) {
break
}
_ = buf.WriteByte(value[i])
lastCh = 1
}
} else {
// Requires quotes
_ = buf.WriteByte(prefix)
for ; i < len(value) && value[i] != '.' && value[i] != ' '; i++ {
_ = buf.WriteByte(value[i])
lastCh = 1
}
_ = buf.WriteByte(suffix)
lastCh = 2
}
}
}

View File

@ -11,54 +11,125 @@ import (
"github.com/stretchr/testify/assert"
)
func TestQuoteTo(t *testing.T) {
var quoter = Quoter{"[", "]"}
func TestAlwaysQuoteTo(t *testing.T) {
var (
quoter = Quoter{'[', ']', AlwaysReserve}
kases = []struct {
expected string
value string
}{
{"[mytable]", "mytable"},
{"[mytable]", "`mytable`"},
{"[mytable]", `[mytable]`},
{`["mytable"]`, `"mytable"`},
{"[myschema].[mytable]", "myschema.mytable"},
{"[myschema].[mytable]", "`myschema`.mytable"},
{"[myschema].[mytable]", "myschema.`mytable`"},
{"[myschema].[mytable]", "`myschema`.`mytable`"},
{"[myschema].[mytable]", `[myschema].mytable`},
{"[myschema].[mytable]", `myschema.[mytable]`},
{"[myschema].[mytable]", `[myschema].[mytable]`},
{`["myschema].[mytable"]`, `"myschema.mytable"`},
{"[message_user] AS [sender]", "`message_user` AS `sender`"},
{"[myschema].[mytable] AS [table]", "myschema.mytable AS table"},
}
)
test := func(t *testing.T, expected string, value string) {
buf := &strings.Builder{}
quoter.QuoteTo(buf, value)
assert.EqualValues(t, expected, buf.String())
for _, v := range kases {
t.Run(v.value, func(t *testing.T) {
buf := &strings.Builder{}
quoter.QuoteTo(buf, v.value)
assert.EqualValues(t, v.expected, buf.String())
})
}
}
test(t, "[mytable]", "mytable")
test(t, "[mytable]", "`mytable`")
test(t, "[mytable]", `[mytable]`)
func TestReversedQuoteTo(t *testing.T) {
var (
quoter = Quoter{'[', ']', func(s string) bool {
if s == "mytable" {
return true
}
return false
}}
kases = []struct {
expected string
value string
}{
{"[mytable]", "mytable"},
{"[mytable]", "`mytable`"},
{"[mytable]", `[mytable]`},
{`"mytable"`, `"mytable"`},
{"myschema.[mytable]", "myschema.mytable"},
{"myschema.[mytable]", "`myschema`.mytable"},
{"myschema.[mytable]", "myschema.`mytable`"},
{"myschema.[mytable]", "`myschema`.`mytable`"},
{"myschema.[mytable]", `[myschema].mytable`},
{"myschema.[mytable]", `myschema.[mytable]`},
{"myschema.[mytable]", `[myschema].[mytable]`},
{`"myschema.mytable"`, `"myschema.mytable"`},
{"message_user AS sender", "`message_user` AS `sender`"},
{"myschema.[mytable] AS table", "myschema.mytable AS table"},
}
)
test(t, `["mytable"]`, `"mytable"`)
for _, v := range kases {
t.Run(v.value, func(t *testing.T) {
buf := &strings.Builder{}
quoter.QuoteTo(buf, v.value)
assert.EqualValues(t, v.expected, buf.String())
})
}
}
test(t, "[myschema].[mytable]", "myschema.mytable")
test(t, "[myschema].[mytable]", "`myschema`.mytable")
test(t, "[myschema].[mytable]", "myschema.`mytable`")
test(t, "[myschema].[mytable]", "`myschema`.`mytable`")
test(t, "[myschema].[mytable]", `[myschema].mytable`)
test(t, "[myschema].[mytable]", `myschema.[mytable]`)
test(t, "[myschema].[mytable]", `[myschema].[mytable]`)
func TestNoQuoteTo(t *testing.T) {
var (
quoter = Quoter{'[', ']', AlwaysNoReserve}
kases = []struct {
expected string
value string
}{
{"mytable", "mytable"},
{"mytable", "`mytable`"},
{"mytable", `[mytable]`},
{`"mytable"`, `"mytable"`},
{"myschema.mytable", "myschema.mytable"},
{"myschema.mytable", "`myschema`.mytable"},
{"myschema.mytable", "myschema.`mytable`"},
{"myschema.mytable", "`myschema`.`mytable`"},
{"myschema.mytable", `[myschema].mytable`},
{"myschema.mytable", `myschema.[mytable]`},
{"myschema.mytable", `[myschema].[mytable]`},
{`"myschema.mytable"`, `"myschema.mytable"`},
{"message_user AS sender", "`message_user` AS `sender`"},
{"myschema.mytable AS table", "myschema.mytable AS table"},
}
)
test(t, `["myschema].[mytable"]`, `"myschema.mytable"`)
test(t, "[message_user] AS [sender]", "`message_user` AS `sender`")
assert.EqualValues(t, "[a],[b]", quoter.Join([]string{"a", " b"}, ","))
buf := &strings.Builder{}
quoter = Quoter{"", ""}
quoter.QuoteTo(buf, "noquote")
assert.EqualValues(t, "noquote", buf.String())
for _, v := range kases {
t.Run(v.value, func(t *testing.T) {
buf := &strings.Builder{}
quoter.QuoteTo(buf, v.value)
assert.EqualValues(t, v.expected, buf.String())
})
}
}
func TestJoin(t *testing.T) {
cols := []string{"f1", "f2", "f3"}
quoter := Quoter{"[", "]"}
quoter := Quoter{'[', ']', AlwaysReserve}
assert.EqualValues(t, "[a],[b]", quoter.Join([]string{"a", " b"}, ","))
assert.EqualValues(t, "[f1], [f2], [f3]", quoter.Join(cols, ", "))
quoter = Quoter{"", ""}
quoter.IsReserved = AlwaysNoReserve
assert.EqualValues(t, "f1, f2, f3", quoter.Join(cols, ", "))
}
func TestStrings(t *testing.T) {
cols := []string{"f1", "f2", "t3.f3"}
quoter := Quoter{"[", "]"}
quoter := Quoter{'[', ']', AlwaysReserve}
quotedCols := quoter.Strings(cols)
assert.EqualValues(t, []string{"[f1]", "[f2]", "[t3].[f3]"}, quotedCols)
@ -72,6 +143,6 @@ func TestTrim(t *testing.T) {
for src, dst := range kases {
assert.EqualValues(t, src, CommonQuoter.Trim(src))
assert.EqualValues(t, dst, Quoter{"[", "]"}.Trim(src))
assert.EqualValues(t, dst, Quoter{'[', ']', AlwaysReserve}.Trim(src))
}
}

View File

@ -22,7 +22,7 @@ type NullType struct {
Age sql.NullInt64
Height sql.NullFloat64
IsMan sql.NullBool `xorm:"null"`
CustomStruct CustomStruct `xorm:"valchar(64) null"`
CustomStruct CustomStruct `xorm:"varchar(64) null"`
}
type CustomStruct struct {
@ -58,14 +58,12 @@ func (m CustomStruct) Value() (driver.Value, error) {
func TestCreateNullStructTable(t *testing.T) {
assert.NoError(t, prepareEngine())
err := testEngine.CreateTables(new(NullType))
assert.NoError(t, err)
}
func TestDropNullStructTable(t *testing.T) {
assert.NoError(t, prepareEngine())
err := testEngine.DropTables(new(NullType))
assert.NoError(t, err)
}
@ -78,7 +76,7 @@ func TestNullStructInsert(t *testing.T) {
item := new(NullType)
_, err := testEngine.Insert(item)
assert.NoError(t, err)
assert.EqualValues(t, item.Id, 1)
assert.EqualValues(t, 1, item.Id)
}
if true {
@ -90,12 +88,11 @@ func TestNullStructInsert(t *testing.T) {
}
_, err := testEngine.Insert(&item)
assert.NoError(t, err)
assert.EqualValues(t, item.Id, 2)
assert.EqualValues(t, 2, item.Id)
}
if true {
items := []NullType{}
for i := 0; i < 5; i++ {
item := NullType{
Name: sql.NullString{String: "haolei_" + fmt.Sprint(i+1), Valid: true},
@ -152,7 +149,7 @@ func TestNullStructUpdate(t *testing.T) {
affected, err := testEngine.ID(2).Cols("age", "height", "is_man").Update(item)
assert.NoError(t, err)
assert.EqualValues(t, affected, 1)
assert.EqualValues(t, 1, affected)
}
if true { // 测试In update
@ -160,7 +157,7 @@ func TestNullStructUpdate(t *testing.T) {
item.Age = sql.NullInt64{Int64: 23, Valid: true}
affected, err := testEngine.In("id", 3, 4).Cols("age", "height", "is_man").Update(item)
assert.NoError(t, err)
assert.EqualValues(t, affected, 2)
assert.EqualValues(t, 2, affected)
}
if true { // 测试where
@ -183,9 +180,7 @@ func TestNullStructUpdate(t *testing.T) {
_, err := testEngine.AllCols().ID(6).Update(item)
assert.NoError(t, err)
fmt.Println(item)
}
}
func TestNullStructFind(t *testing.T) {
@ -274,9 +269,8 @@ func TestNullStructCount(t *testing.T) {
if true {
item := new(NullType)
total, err := testEngine.Where("age IS NOT NULL").Count(item)
_, err := testEngine.Where("age IS NOT NULL").Count(item)
assert.NoError(t, err)
fmt.Println(total)
}
}
@ -292,7 +286,6 @@ func TestNullStructRows(t *testing.T) {
for rows.Next() {
err = rows.Scan(item)
assert.NoError(t, err)
fmt.Println(item)
}
}

View File

@ -18,6 +18,7 @@ import (
_ "github.com/mattn/go-sqlite3"
_ "github.com/ziutek/mymysql/godrv"
"xorm.io/xorm/caches"
"xorm.io/xorm/dialects"
"xorm.io/xorm/log"
"xorm.io/xorm/names"
"xorm.io/xorm/schemas"
@ -38,6 +39,7 @@ var (
schema = flag.String("schema", "", "specify the schema")
ignoreSelectUpdate = flag.Bool("ignore_select_update", false, "ignore select update if implementation difference, only for tidb")
ingoreUpdateLimit = flag.Bool("ignore_update_limit", false, "ignore update limit if implementation difference, only for cockroach")
quotePolicyStr = flag.String("quote", "always", "quote could be always, none, reversed")
tableMapper names.Mapper
colMapper names.Mapper
)
@ -131,6 +133,14 @@ func createEngine(dbType, connStr string) error {
testEngine.SetMapper(names.LintGonicMapper)
}
}
if *quotePolicyStr == "none" {
testEngine.SetQuotePolicy(dialects.QuotePolicyNone)
} else if *quotePolicyStr == "reserved" {
testEngine.SetQuotePolicy(dialects.QuotePolicyReserved)
} else {
testEngine.SetQuotePolicy(dialects.QuotePolicyAlways)
}
}
tableMapper = testEngine.GetTableMapper()