Browse Source

Improve quote policy (#1567)

names with upper charactor on postgres will need quotes

Fix bug

Add new quote parameter on tests

Fix bug

Fix tests

Fix quotes

fix test

Improve quote policy

Reviewed-on: #1567
tags/v1.0.0
Lunny Xiao 3 weeks ago
parent
commit
7455014823
18 changed files with 488 additions and 254 deletions
  1. +10
    -1
      .drone.yml
  2. +19
    -18
      Makefile
  3. +8
    -2
      dialects/dialect.go
  4. +27
    -11
      dialects/filter.go
  5. +24
    -7
      dialects/filter_test.go
  6. +19
    -3
      dialects/mssql.go
  7. +22
    -6
      dialects/mysql.go
  8. +22
    -6
      dialects/oracle.go
  9. +33
    -5
      dialects/postgres.go
  10. +15
    -0
      dialects/quote.go
  11. +33
    -11
      dialects/sqlite3.go
  12. +4
    -0
      engine.go
  13. +8
    -0
      engine_group.go
  14. +1
    -0
      interface.go
  15. +122
    -137
      schemas/quote.go
  16. +105
    -34
      schemas/quote_test.go
  17. +6
    -13
      types_null_test.go
  18. +10
    -0
      xorm_test.go

+ 10
- 1
.drone.yml View File

@@ -22,9 +22,10 @@ steps:
commands:
- make test-sqlite
- TEST_CACHE_ENABLE=true make test-sqlite
- TEST_QUOTE_POLICY=reserved make test-sqlite
- go test ./caches/... ./contexts/... ./convert/... ./core/... ./dialects/... \
./log/... ./migrate/... ./names/... ./schemas/... ./tags/... \
./internal/json/... ./internal/statements/... ./internal/utils/... \
./log/... ./migrate/... ./names/... ./schemas/... ./tags/...

when:
event:
@@ -44,6 +45,7 @@ steps:
commands:
- make test-mysql
- TEST_CACHE_ENABLE=true make test-mysql
- TEST_QUOTE_POLICY=reserved make test-mysql
when:
event:
- push
@@ -62,6 +64,7 @@ steps:
commands:
- make test-mysql
- TEST_CACHE_ENABLE=true make test-mysql
- TEST_QUOTE_POLICY=reserved make test-mysql
when:
event:
- push
@@ -82,6 +85,7 @@ steps:
commands:
- make test-mysql
- TEST_CACHE_ENABLE=true make test-mysql
- TEST_QUOTE_POLICY=reserved make test-mysql
when:
event:
- push
@@ -102,6 +106,7 @@ steps:
commands:
- make test-mymysql
- TEST_CACHE_ENABLE=true make test-mymysql
- TEST_QUOTE_POLICY=reserved make test-mymysql
when:
event:
- push
@@ -120,6 +125,7 @@ steps:
commands:
- make test-postgres
- TEST_CACHE_ENABLE=true make test-postgres
- TEST_QUOTE_POLICY=reserved make test-postgres
when:
event:
- push
@@ -141,6 +147,7 @@ steps:
commands:
- make test-postgres
- TEST_CACHE_ENABLE=true make test-postgres
- TEST_QUOTE_POLICY=reserved make test-postgres
when:
event:
- push
@@ -159,6 +166,7 @@ steps:
commands:
- make test-mssql
- TEST_CACHE_ENABLE=true make test-mssql
- TEST_QUOTE_POLICY=reserved make test-mssql
when:
event:
- push
@@ -177,6 +185,7 @@ steps:
commands:
- make test-tidb
- TEST_CACHE_ENABLE=true make test-tidb
- TEST_QUOTE_POLICY=reserved make test-tidb
when:
event:
- push


+ 19
- 18
Makefile View File

@@ -39,6 +39,7 @@ TEST_TIDB_USERNAME ?= root
TEST_TIDB_PASSWORD ?=

TEST_CACHE_ENABLE ?= false
TEST_QUOTE_POLICY ?= always

.PHONY: all
all: build
@@ -135,73 +136,73 @@ test-cockroach\#%: go-check

.PNONY: test-mssql
test-mssql: go-check
$(GO) test -v -race -db=mssql -cache=$(TEST_CACHE_ENABLE) \
$(GO) test -v -race -db=mssql -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \
-conn_str="server=$(TEST_MSSQL_HOST);user id=$(TEST_MSSQL_USERNAME);password=$(TEST_MSSQL_PASSWORD);database=$(TEST_MSSQL_DBNAME)" \
-coverprofile=mssql.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
-coverprofile=mssql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic

.PNONY: test-mssql\#%
test-mssql\#%: go-check
$(GO) test -v -race -run $* -db=mssql -cache=$(TEST_CACHE_ENABLE) \
$(GO) test -v -race -run $* -db=mssql -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \
-conn_str="server=$(TEST_MSSQL_HOST);user id=$(TEST_MSSQL_USERNAME);password=$(TEST_MSSQL_PASSWORD);database=$(TEST_MSSQL_DBNAME)" \
-coverprofile=mssql.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
-coverprofile=mssql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic

.PNONY: test-mymysql
test-mymysql: go-check
$(GO) test -v -race -db=mymysql -cache=$(TEST_CACHE_ENABLE) \
$(GO) test -v -race -db=mymysql -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \
-conn_str="tcp:$(TEST_MYSQL_HOST)*$(TEST_MYSQL_DBNAME)/$(TEST_MYSQL_USERNAME)/$(TEST_MYSQL_PASSWORD)" \
-coverprofile=mymysql.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
-coverprofile=mymysql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic

.PNONY: test-mymysql\#%
test-mymysql\#%: go-check
$(GO) test -v -race -run $* -db=mymysql -cache=$(TEST_CACHE_ENABLE) \
$(GO) test -v -race -run $* -db=mymysql -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \
-conn_str="tcp:$(TEST_MYSQL_HOST)*$(TEST_MYSQL_DBNAME)/$(TEST_MYSQL_USERNAME)/$(TEST_MYSQL_PASSWORD)" \
-coverprofile=mymysql.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
-coverprofile=mymysql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic

.PNONY: test-mysql
test-mysql: go-check
$(GO) test -v -race -db=mysql -cache=$(TEST_CACHE_ENABLE) \
$(GO) test -v -race -db=mysql -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \
-conn_str="$(TEST_MYSQL_USERNAME):$(TEST_MYSQL_PASSWORD)@tcp($(TEST_MYSQL_HOST))/$(TEST_MYSQL_DBNAME)?charset=$(TEST_MYSQL_CHARSET)" \
-coverprofile=mysql.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
-coverprofile=mysql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic

.PHONY: test-mysql\#%
test-mysql\#%: go-check
$(GO) test -v -race -run $* -db=mysql -cache=$(TEST_CACHE_ENABLE) \
$(GO) test -v -race -run $* -db=mysql -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \
-conn_str="$(TEST_MYSQL_USERNAME):$(TEST_MYSQL_PASSWORD)@tcp($(TEST_MYSQL_HOST))/$(TEST_MYSQL_DBNAME)?charset=$(TEST_MYSQL_CHARSET)" \
-coverprofile=mysql.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
-coverprofile=mysql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic

.PNONY: test-postgres
test-postgres: go-check
$(GO) test -v -race -db=postgres -schema='$(TEST_PGSQL_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \
-conn_str="postgres://$(TEST_PGSQL_USERNAME):$(TEST_PGSQL_PASSWORD)@$(TEST_PGSQL_HOST)/$(TEST_PGSQL_DBNAME)?sslmode=disable" \
-coverprofile=postgres.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
-quote=$(TEST_QUOTE_POLICY) -coverprofile=postgres.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic

.PHONY: test-postgres\#%
test-postgres\#%: go-check
$(GO) test -v -race -run $* -db=postgres -schema='$(TEST_PGSQL_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \
-conn_str="postgres://$(TEST_PGSQL_USERNAME):$(TEST_PGSQL_PASSWORD)@$(TEST_PGSQL_HOST)/$(TEST_PGSQL_DBNAME)?sslmode=disable" \
-coverprofile=postgres.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
-quote=$(TEST_QUOTE_POLICY) -coverprofile=postgres.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic

.PHONY: test-sqlite
test-sqlite: go-check
$(GO) test -v -race -cache=$(TEST_CACHE_ENABLE) -db=sqlite3 -conn_str="./test.db?cache=shared&mode=rwc" \
-coverprofile=sqlite.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
-quote=$(TEST_QUOTE_POLICY) -coverprofile=sqlite.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic

.PHONY: test-sqlite\#%
test-sqlite\#%: go-check
$(GO) test -v -race -run $* -cache=$(TEST_CACHE_ENABLE) -db=sqlite3 -conn_str="./test.db?cache=shared&mode=rwc" \
-coverprofile=sqlite.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
-quote=$(TEST_QUOTE_POLICY) -coverprofile=sqlite.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic

.PNONY: test-tidb
test-tidb: go-check
$(GO) test -v -race -db=mysql -cache=$(TEST_CACHE_ENABLE) -ignore_select_update=true \
-conn_str="$(TEST_TIDB_USERNAME):$(TEST_TIDB_PASSWORD)@tcp($(TEST_TIDB_HOST))/$(TEST_TIDB_DBNAME)" \
-coverprofile=tidb.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
-quote=$(TEST_QUOTE_POLICY) -coverprofile=tidb.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic

.PHONY: test-tidb\#%
test-tidb\#%: go-check
$(GO) test -v -race -run $* -db=mysql -cache=$(TEST_CACHE_ENABLE) -ignore_select_update=true \
-conn_str="$(TEST_TIDB_USERNAME):$(TEST_TIDB_PASSWORD)@tcp($(TEST_TIDB_HOST))/$(TEST_TIDB_DBNAME)" \
-coverprofile=tidb.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
-quote=$(TEST_QUOTE_POLICY) -coverprofile=tidb.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic

.PHONY: vet
vet:

+ 8
- 2
dialects/dialect.go View File

@@ -41,6 +41,7 @@ type Dialect interface {

IsReserved(string) bool
Quoter() schemas.Quoter
SetQuotePolicy(quotePolicy QuotePolicy)

AutoIncrStr() string

@@ -79,6 +80,11 @@ type Base struct {
db *core.DB
dialect Dialect
uri *URI
quoter schemas.Quoter
}

func (b *Base) Quoter() schemas.Quoter {
return b.quoter
}

func (b *Base) DB() *core.DB {
@@ -210,7 +216,7 @@ func (db *Base) CreateIndexSQL(tableName string, index *schemas.Index) string {
idxName = index.XName(tableName)
return fmt.Sprintf("CREATE%s INDEX %v ON %v (%v)", unique,
quoter.Quote(idxName), quoter.Quote(tableName),
quoter.Quote(strings.Join(index.Cols, quoter.ReverseQuote(","))))
quoter.Join(index.Cols, ","))
}

func (db *Base) DropIndexSQL(tableName string, index *schemas.Index) string {
@@ -258,7 +264,7 @@ func (b *Base) CreateTableSQL(table *schemas.Table, tableName, storeEngine, char

if len(pkList) > 1 {
sql += "PRIMARY KEY ( "
sql += quoter.Quote(strings.Join(pkList, quoter.ReverseQuote(",")))
sql += quoter.Join(pkList, ",")
sql += " ), "
}



+ 27
- 11
dialects/filter.go View File

@@ -26,20 +26,36 @@ func (s *QuoteFilter) Do(sql string) string {
return sql
}

prefix, suffix := s.quoter[0][0], s.quoter[1][0]
raw := []byte(sql)
for i, cnt := 0, 0; i < len(raw); i = i + 1 {
if raw[i] == '`' {
if cnt%2 == 0 {
raw[i] = prefix
} else {
raw[i] = suffix
var buf strings.Builder
buf.Grow(len(sql))

var beginSingleQuote bool
for i := 0; i < len(sql); i++ {
if !beginSingleQuote && sql[i] == '`' {
var j = i + 1
for ; j < len(sql); j++ {
if sql[j] == '`' {
break
}
}
word := sql[i+1 : j]
isReserved := s.quoter.IsReserved(word)
if isReserved {
buf.WriteByte(s.quoter.Prefix)
}
buf.WriteString(word)
if isReserved {
buf.WriteByte(s.quoter.Suffix)
}
cnt++
i = j
} else {
if sql[i] == '\'' {
beginSingleQuote = !beginSingleQuote
}
buf.WriteByte(sql[i])
}
}
return string(raw)

return buf.String()
}

// SeqFilter filter SQL replace ?, ? ... to $1, $2 ...


+ 24
- 7
dialects/filter_test.go View File

@@ -9,13 +9,30 @@ import (
)

func TestQuoteFilter_Do(t *testing.T) {
f := QuoteFilter{schemas.Quoter{"[", "]"}}
sql := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?"
res := f.Do(sql)
assert.EqualValues(t,
"SELECT [COLUMN_NAME] FROM [INFORMATION_SCHEMA].[COLUMNS] WHERE [TABLE_SCHEMA] = ? AND [TABLE_NAME] = ? AND [COLUMN_NAME] = ?",
res,
)
f := QuoteFilter{schemas.Quoter{'[', ']', schemas.AlwaysReserve}}
var kases = []struct {
source string
expected string
}{
{
"SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?",
"SELECT [COLUMN_NAME] FROM [INFORMATION_SCHEMA].[COLUMNS] WHERE [TABLE_SCHEMA] = ? AND [TABLE_NAME] = ? AND [COLUMN_NAME] = ?",
},
{
"SELECT 'abc```test```''', `a` FROM b",
"SELECT 'abc```test```''', [a] FROM b",
},
{
"UPDATE table SET `a` = ~ `a`, `b`='abc`'",
"UPDATE table SET [a] = ~ [a], [b]='abc`'",
},
}

for _, kase := range kases {
t.Run(kase.source, func(t *testing.T) {
assert.EqualValues(t, kase.expected, f.Do(kase.source))
})
}
}

func TestSeqFilter(t *testing.T) {


+ 19
- 3
dialects/mssql.go View File

@@ -204,6 +204,8 @@ var (
"EXIT": true,
"PROC": true,
}

mssqlQuoter = schemas.Quoter{'[', ']', schemas.AlwaysReserve}
)

type mssql struct {
@@ -211,6 +213,7 @@ type mssql struct {
}

func (db *mssql) Init(d *core.DB, uri *URI) error {
db.quoter = mssqlQuoter
return db.Base.Init(d, db, uri)
}

@@ -283,12 +286,25 @@ func (db *mssql) SupportInsertMany() bool {
}

func (db *mssql) IsReserved(name string) bool {
_, ok := mssqlReservedWords[name]
_, ok := mssqlReservedWords[strings.ToUpper(name)]
return ok
}

func (db *mssql) Quoter() schemas.Quoter {
return schemas.Quoter{"[", "]"}
func (db *mssql) SetQuotePolicy(quotePolicy QuotePolicy) {
switch quotePolicy {
case QuotePolicyNone:
var q = mssqlQuoter
q.IsReserved = schemas.AlwaysNoReserve
db.quoter = q
case QuotePolicyReserved:
var q = mssqlQuoter
q.IsReserved = db.IsReserved
db.quoter = q
case QuotePolicyAlways:
fallthrough
default:
db.quoter = mssqlQuoter
}
}

func (db *mssql) SupportEngine() bool {


+ 22
- 6
dialects/mysql.go View File

@@ -161,6 +161,8 @@ var (
"YEAR_MONTH": true,
"ZEROFILL": true,
}

mysqlQuoter = schemas.Quoter{'`', '`', schemas.AlwaysReserve}
)

type mysql struct {
@@ -178,6 +180,7 @@ type mysql struct {
}

func (db *mysql) Init(d *core.DB, uri *URI) error {
db.quoter = mysqlQuoter
return db.Base.Init(d, db, uri)
}

@@ -272,14 +275,10 @@ func (db *mysql) SupportInsertMany() bool {
}

func (db *mysql) IsReserved(name string) bool {
_, ok := mysqlReservedWords[name]
_, ok := mysqlReservedWords[strings.ToUpper(name)]
return ok
}

func (db *mysql) Quoter() schemas.Quoter {
return schemas.Quoter{"`", "`"}
}

func (db *mysql) SupportEngine() bool {
return true
}
@@ -458,6 +457,23 @@ func (db *mysql) GetTables(ctx context.Context) ([]*schemas.Table, error) {
return tables, nil
}

func (db *mysql) SetQuotePolicy(quotePolicy QuotePolicy) {
switch quotePolicy {
case QuotePolicyNone:
var q = mysqlQuoter
q.IsReserved = schemas.AlwaysNoReserve
db.quoter = q
case QuotePolicyReserved:
var q = mysqlQuoter
q.IsReserved = db.IsReserved
db.quoter = q
case QuotePolicyAlways:
fallthrough
default:
db.quoter = mysqlQuoter
}
}

func (db *mysql) GetIndexes(ctx context.Context, tableName string) (map[string]*schemas.Index, error) {
args := []interface{}{db.uri.DBName, tableName}
s := "SELECT `INDEX_NAME`, `NON_UNIQUE`, `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?"
@@ -538,7 +554,7 @@ func (db *mysql) CreateTableSQL(table *schemas.Table, tableName, storeEngine, ch

if len(pkList) > 1 {
sql += "PRIMARY KEY ( "
sql += quoter.Quote(strings.Join(pkList, quoter.ReverseQuote(",")))
sql += quoter.Join(pkList, ",")
sql += " ), "
}



+ 22
- 6
dialects/oracle.go View File

@@ -498,6 +498,8 @@ var (
"YEAR": true,
"ZONE": true,
}

oracleQuoter = schemas.Quoter{'[', ']', schemas.AlwaysReserve}
)

type oracle struct {
@@ -505,6 +507,7 @@ type oracle struct {
}

func (db *oracle) Init(d *core.DB, uri *URI) error {
db.quoter = oracleQuoter
return db.Base.Init(d, db, uri)
}

@@ -549,14 +552,10 @@ func (db *oracle) SupportInsertMany() bool {
}

func (db *oracle) IsReserved(name string) bool {
_, ok := oracleReservedWords[name]
_, ok := oracleReservedWords[strings.ToUpper(name)]
return ok
}

func (db *oracle) Quoter() schemas.Quoter {
return schemas.Quoter{"\"", "\""}
}

func (db *oracle) SupportEngine() bool {
return false
}
@@ -601,7 +600,7 @@ func (db *oracle) CreateTableSQL(table *schemas.Table, tableName, storeEngine, c

if len(pkList) > 0 {
sql += "PRIMARY KEY ( "
sql += quoter.Quote(strings.Join(pkList, quoter.ReverseQuote(",")))
sql += quoter.Join(pkList, ",")
sql += " ), "
}

@@ -620,6 +619,23 @@ func (db *oracle) CreateTableSQL(table *schemas.Table, tableName, storeEngine, c
return sql
}

func (db *oracle) SetQuotePolicy(quotePolicy QuotePolicy) {
switch quotePolicy {
case QuotePolicyNone:
var q = oracleQuoter
q.IsReserved = schemas.AlwaysNoReserve
db.quoter = q
case QuotePolicyReserved:
var q = oracleQuoter
q.IsReserved = db.IsReserved
db.quoter = q
case QuotePolicyAlways:
fallthrough
default:
db.quoter = oracleQuoter
}
}

func (db *oracle) IndexCheckSQL(tableName, idxName string) (string, []interface{}) {
args := []interface{}{tableName, idxName}
return `SELECT INDEX_NAME FROM USER_INDEXES ` +


+ 33
- 5
dialects/postgres.go View File

@@ -766,6 +766,8 @@ var (
"YES": true,
"ZONE": true,
}

postgresQuoter = schemas.Quoter{'"', '"', schemas.AlwaysReserve}
)

const postgresPublicSchema = "public"
@@ -775,6 +777,7 @@ type postgres struct {
}

func (db *postgres) Init(d *core.DB, uri *URI) error {
db.quoter = postgresQuoter
err := db.Base.Init(d, db, uri)
if err != nil {
return err
@@ -785,6 +788,35 @@ func (db *postgres) Init(d *core.DB, uri *URI) error {
return nil
}

func (db *postgres) needQuote(name string) bool {
if db.IsReserved(name) {
return true
}
for _, c := range name {
if c >= 'A' && c <= 'Z' {
return true
}
}
return false
}

func (db *postgres) SetQuotePolicy(quotePolicy QuotePolicy) {
switch quotePolicy {
case QuotePolicyNone:
var q = postgresQuoter
q.IsReserved = schemas.AlwaysNoReserve
db.quoter = q
case QuotePolicyReserved:
var q = postgresQuoter
q.IsReserved = db.needQuote
db.quoter = q
case QuotePolicyAlways:
fallthrough
default:
db.quoter = postgresQuoter
}
}

func (db *postgres) DefaultSchema() string {
return postgresPublicSchema
}
@@ -857,14 +889,10 @@ func (db *postgres) SupportInsertMany() bool {
}

func (db *postgres) IsReserved(name string) bool {
_, ok := postgresReservedWords[name]
_, ok := postgresReservedWords[strings.ToUpper(name)]
return ok
}

func (db *postgres) Quoter() schemas.Quoter {
return schemas.Quoter{`"`, `"`}
}

func (db *postgres) AutoIncrStr() string {
return ""
}


+ 15
- 0
dialects/quote.go View File

@@ -0,0 +1,15 @@
// Copyright 2020 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package dialects

// QuotePolicy describes quote handle policy
type QuotePolicy int

// All QuotePolicies
const (
QuotePolicyAlways QuotePolicy = iota
QuotePolicyNone
QuotePolicyReserved
)

+ 33
- 11
dialects/sqlite3.go View File

@@ -143,6 +143,8 @@ var (
"WITH": true,
"WITHOUT": true,
}

sqlite3Quoter = schemas.Quoter{'`', '`', schemas.AlwaysReserve}
)

type sqlite3 struct {
@@ -150,9 +152,27 @@ type sqlite3 struct {
}

func (db *sqlite3) Init(d *core.DB, uri *URI) error {
db.quoter = sqlite3Quoter
return db.Base.Init(d, db, uri)
}

func (db *sqlite3) SetQuotePolicy(quotePolicy QuotePolicy) {
switch quotePolicy {
case QuotePolicyNone:
var q = sqlite3Quoter
q.IsReserved = schemas.AlwaysNoReserve
db.quoter = q
case QuotePolicyReserved:
var q = sqlite3Quoter
q.IsReserved = db.IsReserved
db.quoter = q
case QuotePolicyAlways:
fallthrough
default:
db.quoter = sqlite3Quoter
}
}

func (db *sqlite3) SQLType(c *schemas.Column) string {
switch t := c.SQLType.Name; t {
case schemas.Bool:
@@ -196,14 +216,10 @@ func (db *sqlite3) SupportInsertMany() bool {
}

func (db *sqlite3) IsReserved(name string) bool {
_, ok := sqlite3ReservedWords[name]
_, ok := sqlite3ReservedWords[strings.ToUpper(name)]
return ok
}

func (db *sqlite3) Quoter() schemas.Quoter {
return schemas.Quoter{"`", "`"}
}

func (db *sqlite3) AutoIncrStr() string {
return "AUTOINCREMENT"
}
@@ -250,18 +266,24 @@ func (db *sqlite3) ForUpdateSQL(query string) string {
}

func (db *sqlite3) IsColumnExist(ctx context.Context, tableName, colName string) (bool, error) {
args := []interface{}{tableName}
query := "SELECT name FROM sqlite_master WHERE type='table' and name = ? and ((sql like '%`" + colName + "`%') or (sql like '%[" + colName + "]%'))"

rows, err := db.DB().QueryContext(ctx, query, args...)
query := "SELECT * FROM " + tableName + " LIMIT 0"
rows, err := db.DB().QueryContext(ctx, query)
if err != nil {
return false, err
}
defer rows.Close()

if rows.Next() {
return true, nil
cols, err := rows.Columns()
if err != nil {
return false, err
}

for _, col := range cols {
if strings.EqualFold(col, colName) {
return true, nil
}
}

return false, nil
}



+ 4
- 0
engine.go View File

@@ -54,6 +54,10 @@ func (engine *Engine) GetCacher(tableName string) caches.Cacher {
return engine.cacherMgr.GetCacher(tableName)
}

func (engine *Engine) SetQuotePolicy(quotePolicy dialects.QuotePolicy) {
engine.dialect.SetQuotePolicy(quotePolicy)
}

// BufferSize sets buffer size for iterate
func (engine *Engine) BufferSize(size int) *Session {
session := engine.NewSession()


+ 8
- 0
engine_group.go View File

@@ -9,6 +9,7 @@ import (
"time"

"xorm.io/xorm/caches"
"xorm.io/xorm/dialects"
"xorm.io/xorm/log"
"xorm.io/xorm/names"
)
@@ -180,6 +181,13 @@ func (eg *EngineGroup) SetPolicy(policy GroupPolicy) *EngineGroup {
return eg
}

func (eg *EngineGroup) SetQuotePolicy(quotePolicy dialects.QuotePolicy) {
eg.Engine.SetQuotePolicy(quotePolicy)
for i := 0; i < len(eg.slaves); i++ {
eg.slaves[i].SetQuotePolicy(quotePolicy)
}
}

// SetTableMapper set the table name mapping rule
func (eg *EngineGroup) SetTableMapper(mapper names.Mapper) {
eg.Engine.SetTableMapper(mapper)


+ 1
- 0
interface.go View File

@@ -104,6 +104,7 @@ type EngineInterface interface {
SetMapper(names.Mapper)
SetMaxOpenConns(int)
SetMaxIdleConns(int)
SetQuotePolicy(dialects.QuotePolicy)
SetSchema(string)
SetTableMapper(names.Mapper)
SetTZDatabase(tz *time.Location)


+ 122
- 137
schemas/quote.go View File

@@ -8,14 +8,29 @@ import (
"strings"
)

// Quoter represents two quote characters
type Quoter [2]string
// Quoter represents a quoter to the SQL table name and column name
type Quoter struct {
Prefix byte
Suffix byte
IsReserved func(string) bool
}

var (
// AlwaysFalseReverse always think it's not a reverse word
AlwaysNoReserve = func(string) bool { return false }

// AlwaysReverse always reverse the word
AlwaysReserve = func(string) bool { return true }

// CommonQuoter represetns a common quoter
var CommonQuoter = Quoter{"`", "`"}
// CommanQuoteMark represnets the common quote mark
CommanQuoteMark byte = '`'

// CommonQuoter represetns a common quoter
CommonQuoter = Quoter{CommanQuoteMark, CommanQuoteMark, AlwaysReserve}
)

func (q Quoter) IsEmpty() bool {
return q[0] == "" && q[1] == ""
return q.Prefix == 0 && q.Suffix == 0
}

func (q Quoter) Quote(s string) string {
@@ -24,42 +39,6 @@ func (q Quoter) Quote(s string) string {
return buf.String()
}

func (q Quoter) Replace(sql string, newQuoter Quoter) string {
if q.IsEmpty() {
return sql
}

if newQuoter.IsEmpty() {
var buf strings.Builder
for i := 0; i < len(sql); i = i + 1 {
if sql[i] != q[0][0] && sql[i] != q[1][0] {
_ = buf.WriteByte(sql[i])
}
}
return buf.String()
}

prefix, suffix := newQuoter[0][0], newQuoter[1][0]
var buf strings.Builder
for i, cnt := 0, 0; i < len(sql); i = i + 1 {
if cnt == 0 && sql[i] == q[0][0] {
_ = buf.WriteByte(prefix)
cnt = 1
} else if cnt == 1 && sql[i] == q[1][0] {
_ = buf.WriteByte(suffix)
cnt = 0
} else {
_ = buf.WriteByte(sql[i])
}
}
return buf.String()
}

func (q Quoter) ReverseQuote(s string) string {
reverseQuoter := Quoter{q[1], q[0]}
return reverseQuoter.Quote(s)
}

// Trim removes quotes from s
func (q Quoter) Trim(s string) string {
if len(s) < 2 {
@@ -69,10 +48,10 @@ func (q Quoter) Trim(s string) string {
var buf strings.Builder
for i := 0; i < len(s); i++ {
switch {
case i == 0 && s[i:i+1] == q[0]:
case i == len(s)-1 && s[i:i+1] == q[1]:
case s[i:i+1] == q[1] && s[i+1] == '.':
case s[i:i+1] == q[0] && s[i-1] == '.':
case i == 0 && s[i] == q.Prefix:
case i == len(s)-1 && s[i] == q.Suffix:
case s[i] == q.Suffix && s[i+1] == '.':
case s[i] == q.Prefix && s[i-1] == '.':
default:
buf.WriteByte(s[i])
}
@@ -81,31 +60,8 @@ func (q Quoter) Trim(s string) string {
}

func (q Quoter) Join(a []string, sep string) string {
switch len(a) {
case 0:
return ""
case 1:
return a[0]
}
n := len(sep) * (len(a) - 1)
for i := 0; i < len(a); i++ {
n += len(a[i])
}

var b strings.Builder
b.Grow(n)
for i, s := range a {
if i > 0 {
b.WriteString(sep)
}
if q[0] != "" && s != "*" {
b.WriteString(q[0])
}
b.WriteString(strings.TrimSpace(s))
if q[1] != "" && s != "*" {
b.WriteString(q[1])
}
}
q.JoinWrite(&b, a, sep)
return b.String()
}

@@ -126,88 +82,117 @@ func (q Quoter) JoinWrite(b *strings.Builder, a []string, sep string) error {
return err
}
}
if q[0] != "" && s != "*" && s[0] != '`' {
if _, err := b.WriteString(q[0]); err != nil {
return err
}
}
if _, err := b.WriteString(strings.TrimSpace(s)); err != nil {
return err
}
if q[1] != "" && s != "*" && s[0] != '`' {
if _, err := b.WriteString(q[1]); err != nil {
return err
}
if s != "*" {
q.QuoteTo(b, strings.TrimSpace(s))
}
}
return nil
}

func (q Quoter) Strings(s []string) []string {
var res = make([]string, 0, len(s))
for _, a := range s {
res = append(res, q.Quote(a))
func findWord(v string, start int) int {
for j := start; j < len(v); j++ {
switch v[j] {
case '.', ' ':
return j
}
}
return res
return len(v)
}

func findStart(value string, start int) int {
if value[start] == '.' {
return start + 1
}
if value[start] != ' ' {
return start
}

var k int
for j := start; j < len(value); j++ {
if value[j] != ' ' {
k = j
break
}
}
if k-1 == len(value) {
return len(value)
}
if (value[k] == 'A' || value[k] == 'a') && (value[k+1] == 'S' || value[k+1] == 's') {
k = k + 2
}

for j := k; j < len(value); j++ {
if value[j] != ' ' {
return j
}
}
return len(value)
}

func (q Quoter) QuoteTo(buf *strings.Builder, value string) {
func (q Quoter) quoteWordTo(buf *strings.Builder, word string) error {
var realWord = word
if (word[0] == CommanQuoteMark && word[len(word)-1] == CommanQuoteMark) ||
(word[0] == q.Prefix && word[len(word)-1] == q.Suffix) {
realWord = word[1 : len(word)-1]
}

if q.IsEmpty() {
buf.WriteString(value)
return
_, err := buf.WriteString(realWord)
return err
}

prefix, suffix := q[0][0], q[1][0]
lastCh := 0 // 0 prefix, 1 char, 2 suffix
i := 0
for i < len(value) {
// start of a token; might be already quoted
if value[i] == '.' {
_ = buf.WriteByte('.')
lastCh = 1
i++
} else if value[i] == prefix || value[i] == '`' {
// Has quotes; skip/normalize `name` to prefix+name+sufix
var ch byte
if value[i] == prefix {
ch = suffix
} else {
ch = '`'
}
_ = buf.WriteByte(prefix)
i++
lastCh = 0
for ; i < len(value) && value[i] != ch && value[i] != ' '; i++ {
_ = buf.WriteByte(value[i])
lastCh = 1
}
_ = buf.WriteByte(suffix)
lastCh = 2
i++
} else if value[i] == ' ' {
if lastCh != 2 {
_ = buf.WriteByte(suffix)
lastCh = 2
}
isReserved := q.IsReserved(realWord)
if isReserved {
if err := buf.WriteByte(q.Prefix); err != nil {
return err
}
}
if _, err := buf.WriteString(realWord); err != nil {
return err
}
if isReserved {
return buf.WriteByte(q.Suffix)
}

// a AS b or a b
for ; i < len(value); i++ {
if value[i] != ' ' && value[i-1] == ' ' && (len(value) > i+1 && !strings.EqualFold(value[i:i+2], "AS")) {
break
}
return nil
}

_ = buf.WriteByte(value[i])
lastCh = 1
}
} else {
// Requires quotes
_ = buf.WriteByte(prefix)
for ; i < len(value) && value[i] != '.' && value[i] != ' '; i++ {
_ = buf.WriteByte(value[i])
lastCh = 1
// QuoteTo quotes the table or column names. i.e. if the quotes are [ and ]
// name -> [name]
// `name` -> [name]
// [name] -> [name]
// schema.name -> [schema].[name]
// `schema`.`name` -> [schema].[name]
// `schema`.name -> [schema].[name]
// schema.`name` -> [schema].[name]
// [schema].name -> [schema].[name]
// schema.[name] -> [schema].[name]
// name AS a -> [name] AS a
// schema.name AS a -> [schema].[name] AS a
func (q Quoter) QuoteTo(buf *strings.Builder, value string) error {
var i int
for i < len(value) {
start := findStart(value, i)
if start > i {
if _, err := buf.WriteString(value[i:start]); err != nil {
return err
}
_ = buf.WriteByte(suffix)
lastCh = 2
}
var nextEnd = findWord(value, start)

if err := q.quoteWordTo(buf, value[start:nextEnd]); err != nil {
return err
}
i = nextEnd
}
return nil
}

// Strings quotes a slice of string
func (q Quoter) Strings(s []string) []string {
var res = make([]string, 0, len(s))
for _, a := range s {
res = append(res, q.Quote(a))
}
return res
}

+ 105
- 34
schemas/quote_test.go View File

@@ -11,54 +11,125 @@ import (
"github.com/stretchr/testify/assert"
)

func TestQuoteTo(t *testing.T) {
var quoter = Quoter{"[", "]"}

test := func(t *testing.T, expected string, value string) {
buf := &strings.Builder{}
quoter.QuoteTo(buf, value)
assert.EqualValues(t, expected, buf.String())
func TestAlwaysQuoteTo(t *testing.T) {
var (
quoter = Quoter{'[', ']', AlwaysReserve}
kases = []struct {
expected string
value string
}{
{"[mytable]", "mytable"},
{"[mytable]", "`mytable`"},
{"[mytable]", `[mytable]`},
{`["mytable"]`, `"mytable"`},
{"[myschema].[mytable]", "myschema.mytable"},
{"[myschema].[mytable]", "`myschema`.mytable"},
{"[myschema].[mytable]", "myschema.`mytable`"},
{"[myschema].[mytable]", "`myschema`.`mytable`"},
{"[myschema].[mytable]", `[myschema].mytable`},
{"[myschema].[mytable]", `myschema.[mytable]`},
{"[myschema].[mytable]", `[myschema].[mytable]`},
{`["myschema].[mytable"]`, `"myschema.mytable"`},
{"[message_user] AS [sender]", "`message_user` AS `sender`"},
{"[myschema].[mytable] AS [table]", "myschema.mytable AS table"},
}
)

for _, v := range kases {
t.Run(v.value, func(t *testing.T) {
buf := &strings.Builder{}
quoter.QuoteTo(buf, v.value)
assert.EqualValues(t, v.expected, buf.String())
})
}
}

test(t, "[mytable]", "mytable")
test(t, "[mytable]", "`mytable`")
test(t, "[mytable]", `[mytable]`)

test(t, `["mytable"]`, `"mytable"`)

test(t, "[myschema].[mytable]", "myschema.mytable")
test(t, "[myschema].[mytable]", "`myschema`.mytable")
test(t, "[myschema].[mytable]", "myschema.`mytable`")
test(t, "[myschema].[mytable]", "`myschema`.`mytable`")
test(t, "[myschema].[mytable]", `[myschema].mytable`)
test(t, "[myschema].[mytable]", `myschema.[mytable]`)
test(t, "[myschema].[mytable]", `[myschema].[mytable]`)

test(t, `["myschema].[mytable"]`, `"myschema.mytable"`)

test(t, "[message_user] AS [sender]", "`message_user` AS `sender`")

assert.EqualValues(t, "[a],[b]", quoter.Join([]string{"a", " b"}, ","))
func TestReversedQuoteTo(t *testing.T) {
var (
quoter = Quoter{'[', ']', func(s string) bool {
if s == "mytable" {
return true
}
return false
}}
kases = []struct {
expected string
value string
}{
{"[mytable]", "mytable"},
{"[mytable]", "`mytable`"},
{"[mytable]", `[mytable]`},
{`"mytable"`, `"mytable"`},
{"myschema.[mytable]", "myschema.mytable"},
{"myschema.[mytable]", "`myschema`.mytable"},
{"myschema.[mytable]", "myschema.`mytable`"},
{"myschema.[mytable]", "`myschema`.`mytable`"},
{"myschema.[mytable]", `[myschema].mytable`},
{"myschema.[mytable]", `myschema.[mytable]`},
{"myschema.[mytable]", `[myschema].[mytable]`},
{`"myschema.mytable"`, `"myschema.mytable"`},
{"message_user AS sender", "`message_user` AS `sender`"},
{"myschema.[mytable] AS table", "myschema.mytable AS table"},
}
)

for _, v := range kases {
t.Run(v.value, func(t *testing.T) {
buf := &strings.Builder{}
quoter.QuoteTo(buf, v.value)
assert.EqualValues(t, v.expected, buf.String())
})
}
}

buf := &strings.Builder{}
quoter = Quoter{"", ""}
quoter.QuoteTo(buf, "noquote")
assert.EqualValues(t, "noquote", buf.String())
func TestNoQuoteTo(t *testing.T) {
var (
quoter = Quoter{'[', ']', AlwaysNoReserve}
kases = []struct {
expected string
value string
}{
{"mytable", "mytable"},
{"mytable", "`mytable`"},
{"mytable", `[mytable]`},
{`"mytable"`, `"mytable"`},
{"myschema.mytable", "myschema.mytable"},
{"myschema.mytable", "`myschema`.mytable"},
{"myschema.mytable", "myschema.`mytable`"},
{"myschema.mytable", "`myschema`.`mytable`"},
{"myschema.mytable", `[myschema].mytable`},
{"myschema.mytable", `myschema.[mytable]`},
{"myschema.mytable", `[myschema].[mytable]`},
{`"myschema.mytable"`, `"myschema.mytable"`},
{"message_user AS sender", "`message_user` AS `sender`"},
{"myschema.mytable AS table", "myschema.mytable AS table"},
}
)

for _, v := range kases {
t.Run(v.value, func(t *testing.T) {
buf := &strings.Builder{}
quoter.QuoteTo(buf, v.value)
assert.EqualValues(t, v.expected, buf.String())
})
}
}

func TestJoin(t *testing.T) {
cols := []string{"f1", "f2", "f3"}
quoter := Quoter{"[", "]"}
quoter := Quoter{'[', ']', AlwaysReserve}

assert.EqualValues(t, "[a],[b]", quoter.Join([]string{"a", " b"}, ","))

assert.EqualValues(t, "[f1], [f2], [f3]", quoter.Join(cols, ", "))

quoter = Quoter{"", ""}
quoter.IsReserved = AlwaysNoReserve
assert.EqualValues(t, "f1, f2, f3", quoter.Join(cols, ", "))
}

func TestStrings(t *testing.T) {
cols := []string{"f1", "f2", "t3.f3"}
quoter := Quoter{"[", "]"}
quoter := Quoter{'[', ']', AlwaysReserve}

quotedCols := quoter.Strings(cols)
assert.EqualValues(t, []string{"[f1]", "[f2]", "[t3].[f3]"}, quotedCols)
@@ -72,6 +143,6 @@ func TestTrim(t *testing.T) {

for src, dst := range kases {
assert.EqualValues(t, src, CommonQuoter.Trim(src))
assert.EqualValues(t, dst, Quoter{"[", "]"}.Trim(src))
assert.EqualValues(t, dst, Quoter{'[', ']', AlwaysReserve}.Trim(src))
}
}

+ 6
- 13
types_null_test.go View File

@@ -22,7 +22,7 @@ type NullType struct {
Age sql.NullInt64
Height sql.NullFloat64
IsMan sql.NullBool `xorm:"null"`
CustomStruct CustomStruct `xorm:"valchar(64) null"`
CustomStruct CustomStruct `xorm:"varchar(64) null"`
}

type CustomStruct struct {
@@ -58,14 +58,12 @@ func (m CustomStruct) Value() (driver.Value, error) {

func TestCreateNullStructTable(t *testing.T) {
assert.NoError(t, prepareEngine())

err := testEngine.CreateTables(new(NullType))
assert.NoError(t, err)
}

func TestDropNullStructTable(t *testing.T) {
assert.NoError(t, prepareEngine())

err := testEngine.DropTables(new(NullType))
assert.NoError(t, err)
}
@@ -78,7 +76,7 @@ func TestNullStructInsert(t *testing.T) {
item := new(NullType)
_, err := testEngine.Insert(item)
assert.NoError(t, err)
assert.EqualValues(t, item.Id, 1)
assert.EqualValues(t, 1, item.Id)
}

if true {
@@ -90,12 +88,11 @@ func TestNullStructInsert(t *testing.T) {
}
_, err := testEngine.Insert(&item)
assert.NoError(t, err)
assert.EqualValues(t, item.Id, 2)
assert.EqualValues(t, 2, item.Id)
}

if true {
items := []NullType{}

for i := 0; i < 5; i++ {
item := NullType{
Name: sql.NullString{String: "haolei_" + fmt.Sprint(i+1), Valid: true},
@@ -152,7 +149,7 @@ func TestNullStructUpdate(t *testing.T) {

affected, err := testEngine.ID(2).Cols("age", "height", "is_man").Update(item)
assert.NoError(t, err)
assert.EqualValues(t, affected, 1)
assert.EqualValues(t, 1, affected)
}

if true { // 测试In update
@@ -160,7 +157,7 @@ func TestNullStructUpdate(t *testing.T) {
item.Age = sql.NullInt64{Int64: 23, Valid: true}
affected, err := testEngine.In("id", 3, 4).Cols("age", "height", "is_man").Update(item)
assert.NoError(t, err)
assert.EqualValues(t, affected, 2)
assert.EqualValues(t, 2, affected)
}

if true { // 测试where
@@ -183,9 +180,7 @@ func TestNullStructUpdate(t *testing.T) {

_, err := testEngine.AllCols().ID(6).Update(item)
assert.NoError(t, err)
fmt.Println(item)
}

}

func TestNullStructFind(t *testing.T) {
@@ -274,9 +269,8 @@ func TestNullStructCount(t *testing.T) {

if true {
item := new(NullType)
total, err := testEngine.Where("age IS NOT NULL").Count(item)
_, err := testEngine.Where("age IS NOT NULL").Count(item)
assert.NoError(t, err)
fmt.Println(total)
}
}

@@ -292,7 +286,6 @@ func TestNullStructRows(t *testing.T) {
for rows.Next() {
err = rows.Scan(item)
assert.NoError(t, err)
fmt.Println(item)
}
}



+ 10
- 0
xorm_test.go View File

@@ -18,6 +18,7 @@ import (
_ "github.com/mattn/go-sqlite3"
_ "github.com/ziutek/mymysql/godrv"
"xorm.io/xorm/caches"
"xorm.io/xorm/dialects"
"xorm.io/xorm/log"
"xorm.io/xorm/names"
"xorm.io/xorm/schemas"
@@ -38,6 +39,7 @@ var (
schema = flag.String("schema", "", "specify the schema")
ignoreSelectUpdate = flag.Bool("ignore_select_update", false, "ignore select update if implementation difference, only for tidb")
ingoreUpdateLimit = flag.Bool("ignore_update_limit", false, "ignore update limit if implementation difference, only for cockroach")
quotePolicyStr = flag.String("quote", "always", "quote could be always, none, reversed")
tableMapper names.Mapper
colMapper names.Mapper
)
@@ -131,6 +133,14 @@ func createEngine(dbType, connStr string) error {
testEngine.SetMapper(names.LintGonicMapper)
}
}

if *quotePolicyStr == "none" {
testEngine.SetQuotePolicy(dialects.QuotePolicyNone)
} else if *quotePolicyStr == "reserved" {
testEngine.SetQuotePolicy(dialects.QuotePolicyReserved)
} else {
testEngine.SetQuotePolicy(dialects.QuotePolicyAlways)
}
}

tableMapper = testEngine.GetTableMapper()


Loading…
Cancel
Save