diff --git a/.drone.yml b/.drone.yml index 9a62c6bd..0863cce2 100644 --- a/.drone.yml +++ b/.drone.yml @@ -22,9 +22,10 @@ steps: commands: - make test-sqlite - TEST_CACHE_ENABLE=true make test-sqlite + - TEST_QUOTE_POLICY=reserved make test-sqlite - go test ./caches/... ./contexts/... ./convert/... ./core/... ./dialects/... \ + ./log/... ./migrate/... ./names/... ./schemas/... ./tags/... \ ./internal/json/... ./internal/statements/... ./internal/utils/... \ - ./log/... ./migrate/... ./names/... ./schemas/... ./tags/... when: event: @@ -44,6 +45,7 @@ steps: commands: - make test-mysql - TEST_CACHE_ENABLE=true make test-mysql + - TEST_QUOTE_POLICY=reserved make test-mysql when: event: - push @@ -62,6 +64,7 @@ steps: commands: - make test-mysql - TEST_CACHE_ENABLE=true make test-mysql + - TEST_QUOTE_POLICY=reserved make test-mysql when: event: - push @@ -82,6 +85,7 @@ steps: commands: - make test-mysql - TEST_CACHE_ENABLE=true make test-mysql + - TEST_QUOTE_POLICY=reserved make test-mysql when: event: - push @@ -102,6 +106,7 @@ steps: commands: - make test-mymysql - TEST_CACHE_ENABLE=true make test-mymysql + - TEST_QUOTE_POLICY=reserved make test-mymysql when: event: - push @@ -120,6 +125,7 @@ steps: commands: - make test-postgres - TEST_CACHE_ENABLE=true make test-postgres + - TEST_QUOTE_POLICY=reserved make test-postgres when: event: - push @@ -141,6 +147,7 @@ steps: commands: - make test-postgres - TEST_CACHE_ENABLE=true make test-postgres + - TEST_QUOTE_POLICY=reserved make test-postgres when: event: - push @@ -159,6 +166,7 @@ steps: commands: - make test-mssql - TEST_CACHE_ENABLE=true make test-mssql + - TEST_QUOTE_POLICY=reserved make test-mssql when: event: - push @@ -177,6 +185,7 @@ steps: commands: - make test-tidb - TEST_CACHE_ENABLE=true make test-tidb + - TEST_QUOTE_POLICY=reserved make test-tidb when: event: - push diff --git a/Makefile b/Makefile index faad978f..4444ebd0 100644 --- a/Makefile +++ b/Makefile @@ -39,6 +39,7 @@ TEST_TIDB_USERNAME ?= root TEST_TIDB_PASSWORD ?= TEST_CACHE_ENABLE ?= false +TEST_QUOTE_POLICY ?= always .PHONY: all all: build @@ -135,73 +136,73 @@ test-cockroach\#%: go-check .PNONY: test-mssql test-mssql: go-check - $(GO) test -v -race -db=mssql -cache=$(TEST_CACHE_ENABLE) \ + $(GO) test -v -race -db=mssql -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \ -conn_str="server=$(TEST_MSSQL_HOST);user id=$(TEST_MSSQL_USERNAME);password=$(TEST_MSSQL_PASSWORD);database=$(TEST_MSSQL_DBNAME)" \ - -coverprofile=mssql.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic + -coverprofile=mssql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic .PNONY: test-mssql\#% test-mssql\#%: go-check - $(GO) test -v -race -run $* -db=mssql -cache=$(TEST_CACHE_ENABLE) \ + $(GO) test -v -race -run $* -db=mssql -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \ -conn_str="server=$(TEST_MSSQL_HOST);user id=$(TEST_MSSQL_USERNAME);password=$(TEST_MSSQL_PASSWORD);database=$(TEST_MSSQL_DBNAME)" \ - -coverprofile=mssql.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic + -coverprofile=mssql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic .PNONY: test-mymysql test-mymysql: go-check - $(GO) test -v -race -db=mymysql -cache=$(TEST_CACHE_ENABLE) \ + $(GO) test -v -race -db=mymysql -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \ -conn_str="tcp:$(TEST_MYSQL_HOST)*$(TEST_MYSQL_DBNAME)/$(TEST_MYSQL_USERNAME)/$(TEST_MYSQL_PASSWORD)" \ - -coverprofile=mymysql.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic + -coverprofile=mymysql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic .PNONY: test-mymysql\#% test-mymysql\#%: go-check - $(GO) test -v -race -run $* -db=mymysql -cache=$(TEST_CACHE_ENABLE) \ + $(GO) test -v -race -run $* -db=mymysql -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \ -conn_str="tcp:$(TEST_MYSQL_HOST)*$(TEST_MYSQL_DBNAME)/$(TEST_MYSQL_USERNAME)/$(TEST_MYSQL_PASSWORD)" \ - -coverprofile=mymysql.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic + -coverprofile=mymysql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic .PNONY: test-mysql test-mysql: go-check - $(GO) test -v -race -db=mysql -cache=$(TEST_CACHE_ENABLE) \ + $(GO) test -v -race -db=mysql -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \ -conn_str="$(TEST_MYSQL_USERNAME):$(TEST_MYSQL_PASSWORD)@tcp($(TEST_MYSQL_HOST))/$(TEST_MYSQL_DBNAME)?charset=$(TEST_MYSQL_CHARSET)" \ - -coverprofile=mysql.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic + -coverprofile=mysql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic .PHONY: test-mysql\#% test-mysql\#%: go-check - $(GO) test -v -race -run $* -db=mysql -cache=$(TEST_CACHE_ENABLE) \ + $(GO) test -v -race -run $* -db=mysql -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \ -conn_str="$(TEST_MYSQL_USERNAME):$(TEST_MYSQL_PASSWORD)@tcp($(TEST_MYSQL_HOST))/$(TEST_MYSQL_DBNAME)?charset=$(TEST_MYSQL_CHARSET)" \ - -coverprofile=mysql.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic + -coverprofile=mysql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic .PNONY: test-postgres test-postgres: go-check $(GO) test -v -race -db=postgres -schema='$(TEST_PGSQL_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ -conn_str="postgres://$(TEST_PGSQL_USERNAME):$(TEST_PGSQL_PASSWORD)@$(TEST_PGSQL_HOST)/$(TEST_PGSQL_DBNAME)?sslmode=disable" \ - -coverprofile=postgres.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic + -quote=$(TEST_QUOTE_POLICY) -coverprofile=postgres.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic .PHONY: test-postgres\#% test-postgres\#%: go-check $(GO) test -v -race -run $* -db=postgres -schema='$(TEST_PGSQL_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ -conn_str="postgres://$(TEST_PGSQL_USERNAME):$(TEST_PGSQL_PASSWORD)@$(TEST_PGSQL_HOST)/$(TEST_PGSQL_DBNAME)?sslmode=disable" \ - -coverprofile=postgres.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic + -quote=$(TEST_QUOTE_POLICY) -coverprofile=postgres.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic .PHONY: test-sqlite test-sqlite: go-check $(GO) test -v -race -cache=$(TEST_CACHE_ENABLE) -db=sqlite3 -conn_str="./test.db?cache=shared&mode=rwc" \ - -coverprofile=sqlite.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic + -quote=$(TEST_QUOTE_POLICY) -coverprofile=sqlite.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic .PHONY: test-sqlite\#% test-sqlite\#%: go-check $(GO) test -v -race -run $* -cache=$(TEST_CACHE_ENABLE) -db=sqlite3 -conn_str="./test.db?cache=shared&mode=rwc" \ - -coverprofile=sqlite.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic + -quote=$(TEST_QUOTE_POLICY) -coverprofile=sqlite.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic .PNONY: test-tidb test-tidb: go-check $(GO) test -v -race -db=mysql -cache=$(TEST_CACHE_ENABLE) -ignore_select_update=true \ -conn_str="$(TEST_TIDB_USERNAME):$(TEST_TIDB_PASSWORD)@tcp($(TEST_TIDB_HOST))/$(TEST_TIDB_DBNAME)" \ - -coverprofile=tidb.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic + -quote=$(TEST_QUOTE_POLICY) -coverprofile=tidb.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic .PHONY: test-tidb\#% test-tidb\#%: go-check $(GO) test -v -race -run $* -db=mysql -cache=$(TEST_CACHE_ENABLE) -ignore_select_update=true \ -conn_str="$(TEST_TIDB_USERNAME):$(TEST_TIDB_PASSWORD)@tcp($(TEST_TIDB_HOST))/$(TEST_TIDB_DBNAME)" \ - -coverprofile=tidb.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic + -quote=$(TEST_QUOTE_POLICY) -coverprofile=tidb.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic .PHONY: vet vet: diff --git a/dialects/dialect.go b/dialects/dialect.go index c591cc7b..d89f1ebe 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -41,6 +41,7 @@ type Dialect interface { IsReserved(string) bool Quoter() schemas.Quoter + SetQuotePolicy(quotePolicy QuotePolicy) AutoIncrStr() string @@ -79,6 +80,11 @@ type Base struct { db *core.DB dialect Dialect uri *URI + quoter schemas.Quoter +} + +func (b *Base) Quoter() schemas.Quoter { + return b.quoter } func (b *Base) DB() *core.DB { @@ -210,7 +216,7 @@ func (db *Base) CreateIndexSQL(tableName string, index *schemas.Index) string { idxName = index.XName(tableName) return fmt.Sprintf("CREATE%s INDEX %v ON %v (%v)", unique, quoter.Quote(idxName), quoter.Quote(tableName), - quoter.Quote(strings.Join(index.Cols, quoter.ReverseQuote(",")))) + quoter.Join(index.Cols, ",")) } func (db *Base) DropIndexSQL(tableName string, index *schemas.Index) string { @@ -258,7 +264,7 @@ func (b *Base) CreateTableSQL(table *schemas.Table, tableName, storeEngine, char if len(pkList) > 1 { sql += "PRIMARY KEY ( " - sql += quoter.Quote(strings.Join(pkList, quoter.ReverseQuote(","))) + sql += quoter.Join(pkList, ",") sql += " ), " } diff --git a/dialects/filter.go b/dialects/filter.go index 0f9b4107..add8cc7d 100644 --- a/dialects/filter.go +++ b/dialects/filter.go @@ -26,20 +26,36 @@ func (s *QuoteFilter) Do(sql string) string { return sql } - prefix, suffix := s.quoter[0][0], s.quoter[1][0] - raw := []byte(sql) - for i, cnt := 0, 0; i < len(raw); i = i + 1 { - if raw[i] == '`' { - if cnt%2 == 0 { - raw[i] = prefix - } else { - raw[i] = suffix + var buf strings.Builder + buf.Grow(len(sql)) + + var beginSingleQuote bool + for i := 0; i < len(sql); i++ { + if !beginSingleQuote && sql[i] == '`' { + var j = i + 1 + for ; j < len(sql); j++ { + if sql[j] == '`' { + break + } } - cnt++ + word := sql[i+1 : j] + isReserved := s.quoter.IsReserved(word) + if isReserved { + buf.WriteByte(s.quoter.Prefix) + } + buf.WriteString(word) + if isReserved { + buf.WriteByte(s.quoter.Suffix) + } + i = j + } else { + if sql[i] == '\'' { + beginSingleQuote = !beginSingleQuote + } + buf.WriteByte(sql[i]) } } - return string(raw) - + return buf.String() } // SeqFilter filter SQL replace ?, ? ... to $1, $2 ... diff --git a/dialects/filter_test.go b/dialects/filter_test.go index ac110a69..e8395156 100644 --- a/dialects/filter_test.go +++ b/dialects/filter_test.go @@ -9,13 +9,30 @@ import ( ) func TestQuoteFilter_Do(t *testing.T) { - f := QuoteFilter{schemas.Quoter{"[", "]"}} - sql := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?" - res := f.Do(sql) - assert.EqualValues(t, - "SELECT [COLUMN_NAME] FROM [INFORMATION_SCHEMA].[COLUMNS] WHERE [TABLE_SCHEMA] = ? AND [TABLE_NAME] = ? AND [COLUMN_NAME] = ?", - res, - ) + f := QuoteFilter{schemas.Quoter{'[', ']', schemas.AlwaysReserve}} + var kases = []struct { + source string + expected string + }{ + { + "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?", + "SELECT [COLUMN_NAME] FROM [INFORMATION_SCHEMA].[COLUMNS] WHERE [TABLE_SCHEMA] = ? AND [TABLE_NAME] = ? AND [COLUMN_NAME] = ?", + }, + { + "SELECT 'abc```test```''', `a` FROM b", + "SELECT 'abc```test```''', [a] FROM b", + }, + { + "UPDATE table SET `a` = ~ `a`, `b`='abc`'", + "UPDATE table SET [a] = ~ [a], [b]='abc`'", + }, + } + + for _, kase := range kases { + t.Run(kase.source, func(t *testing.T) { + assert.EqualValues(t, kase.expected, f.Do(kase.source)) + }) + } } func TestSeqFilter(t *testing.T) { diff --git a/dialects/mssql.go b/dialects/mssql.go index 558abdfc..a2cbb361 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -204,6 +204,8 @@ var ( "EXIT": true, "PROC": true, } + + mssqlQuoter = schemas.Quoter{'[', ']', schemas.AlwaysReserve} ) type mssql struct { @@ -211,6 +213,7 @@ type mssql struct { } func (db *mssql) Init(d *core.DB, uri *URI) error { + db.quoter = mssqlQuoter return db.Base.Init(d, db, uri) } @@ -283,12 +286,25 @@ func (db *mssql) SupportInsertMany() bool { } func (db *mssql) IsReserved(name string) bool { - _, ok := mssqlReservedWords[name] + _, ok := mssqlReservedWords[strings.ToUpper(name)] return ok } -func (db *mssql) Quoter() schemas.Quoter { - return schemas.Quoter{"[", "]"} +func (db *mssql) SetQuotePolicy(quotePolicy QuotePolicy) { + switch quotePolicy { + case QuotePolicyNone: + var q = mssqlQuoter + q.IsReserved = schemas.AlwaysNoReserve + db.quoter = q + case QuotePolicyReserved: + var q = mssqlQuoter + q.IsReserved = db.IsReserved + db.quoter = q + case QuotePolicyAlways: + fallthrough + default: + db.quoter = mssqlQuoter + } } func (db *mssql) SupportEngine() bool { diff --git a/dialects/mysql.go b/dialects/mysql.go index 939a7cf1..5f36ed31 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -161,6 +161,8 @@ var ( "YEAR_MONTH": true, "ZEROFILL": true, } + + mysqlQuoter = schemas.Quoter{'`', '`', schemas.AlwaysReserve} ) type mysql struct { @@ -178,6 +180,7 @@ type mysql struct { } func (db *mysql) Init(d *core.DB, uri *URI) error { + db.quoter = mysqlQuoter return db.Base.Init(d, db, uri) } @@ -272,14 +275,10 @@ func (db *mysql) SupportInsertMany() bool { } func (db *mysql) IsReserved(name string) bool { - _, ok := mysqlReservedWords[name] + _, ok := mysqlReservedWords[strings.ToUpper(name)] return ok } -func (db *mysql) Quoter() schemas.Quoter { - return schemas.Quoter{"`", "`"} -} - func (db *mysql) SupportEngine() bool { return true } @@ -458,6 +457,23 @@ func (db *mysql) GetTables(ctx context.Context) ([]*schemas.Table, error) { return tables, nil } +func (db *mysql) SetQuotePolicy(quotePolicy QuotePolicy) { + switch quotePolicy { + case QuotePolicyNone: + var q = mysqlQuoter + q.IsReserved = schemas.AlwaysNoReserve + db.quoter = q + case QuotePolicyReserved: + var q = mysqlQuoter + q.IsReserved = db.IsReserved + db.quoter = q + case QuotePolicyAlways: + fallthrough + default: + db.quoter = mysqlQuoter + } +} + func (db *mysql) GetIndexes(ctx context.Context, tableName string) (map[string]*schemas.Index, error) { args := []interface{}{db.uri.DBName, tableName} s := "SELECT `INDEX_NAME`, `NON_UNIQUE`, `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" @@ -538,7 +554,7 @@ func (db *mysql) CreateTableSQL(table *schemas.Table, tableName, storeEngine, ch if len(pkList) > 1 { sql += "PRIMARY KEY ( " - sql += quoter.Quote(strings.Join(pkList, quoter.ReverseQuote(","))) + sql += quoter.Join(pkList, ",") sql += " ), " } diff --git a/dialects/oracle.go b/dialects/oracle.go index 4a8162ac..d54ca80c 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -498,6 +498,8 @@ var ( "YEAR": true, "ZONE": true, } + + oracleQuoter = schemas.Quoter{'[', ']', schemas.AlwaysReserve} ) type oracle struct { @@ -505,6 +507,7 @@ type oracle struct { } func (db *oracle) Init(d *core.DB, uri *URI) error { + db.quoter = oracleQuoter return db.Base.Init(d, db, uri) } @@ -549,14 +552,10 @@ func (db *oracle) SupportInsertMany() bool { } func (db *oracle) IsReserved(name string) bool { - _, ok := oracleReservedWords[name] + _, ok := oracleReservedWords[strings.ToUpper(name)] return ok } -func (db *oracle) Quoter() schemas.Quoter { - return schemas.Quoter{"\"", "\""} -} - func (db *oracle) SupportEngine() bool { return false } @@ -601,7 +600,7 @@ func (db *oracle) CreateTableSQL(table *schemas.Table, tableName, storeEngine, c if len(pkList) > 0 { sql += "PRIMARY KEY ( " - sql += quoter.Quote(strings.Join(pkList, quoter.ReverseQuote(","))) + sql += quoter.Join(pkList, ",") sql += " ), " } @@ -620,6 +619,23 @@ func (db *oracle) CreateTableSQL(table *schemas.Table, tableName, storeEngine, c return sql } +func (db *oracle) SetQuotePolicy(quotePolicy QuotePolicy) { + switch quotePolicy { + case QuotePolicyNone: + var q = oracleQuoter + q.IsReserved = schemas.AlwaysNoReserve + db.quoter = q + case QuotePolicyReserved: + var q = oracleQuoter + q.IsReserved = db.IsReserved + db.quoter = q + case QuotePolicyAlways: + fallthrough + default: + db.quoter = oracleQuoter + } +} + func (db *oracle) IndexCheckSQL(tableName, idxName string) (string, []interface{}) { args := []interface{}{tableName, idxName} return `SELECT INDEX_NAME FROM USER_INDEXES ` + diff --git a/dialects/postgres.go b/dialects/postgres.go index f92202cd..0049cee6 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -766,6 +766,8 @@ var ( "YES": true, "ZONE": true, } + + postgresQuoter = schemas.Quoter{'"', '"', schemas.AlwaysReserve} ) const postgresPublicSchema = "public" @@ -775,6 +777,7 @@ type postgres struct { } func (db *postgres) Init(d *core.DB, uri *URI) error { + db.quoter = postgresQuoter err := db.Base.Init(d, db, uri) if err != nil { return err @@ -785,6 +788,35 @@ func (db *postgres) Init(d *core.DB, uri *URI) error { return nil } +func (db *postgres) needQuote(name string) bool { + if db.IsReserved(name) { + return true + } + for _, c := range name { + if c >= 'A' && c <= 'Z' { + return true + } + } + return false +} + +func (db *postgres) SetQuotePolicy(quotePolicy QuotePolicy) { + switch quotePolicy { + case QuotePolicyNone: + var q = postgresQuoter + q.IsReserved = schemas.AlwaysNoReserve + db.quoter = q + case QuotePolicyReserved: + var q = postgresQuoter + q.IsReserved = db.needQuote + db.quoter = q + case QuotePolicyAlways: + fallthrough + default: + db.quoter = postgresQuoter + } +} + func (db *postgres) DefaultSchema() string { return postgresPublicSchema } @@ -857,14 +889,10 @@ func (db *postgres) SupportInsertMany() bool { } func (db *postgres) IsReserved(name string) bool { - _, ok := postgresReservedWords[name] + _, ok := postgresReservedWords[strings.ToUpper(name)] return ok } -func (db *postgres) Quoter() schemas.Quoter { - return schemas.Quoter{`"`, `"`} -} - func (db *postgres) AutoIncrStr() string { return "" } diff --git a/dialects/quote.go b/dialects/quote.go new file mode 100644 index 00000000..da4e0dd6 --- /dev/null +++ b/dialects/quote.go @@ -0,0 +1,15 @@ +// Copyright 2020 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dialects + +// QuotePolicy describes quote handle policy +type QuotePolicy int + +// All QuotePolicies +const ( + QuotePolicyAlways QuotePolicy = iota + QuotePolicyNone + QuotePolicyReserved +) diff --git a/dialects/sqlite3.go b/dialects/sqlite3.go index 39138b13..4af9b27e 100644 --- a/dialects/sqlite3.go +++ b/dialects/sqlite3.go @@ -143,6 +143,8 @@ var ( "WITH": true, "WITHOUT": true, } + + sqlite3Quoter = schemas.Quoter{'`', '`', schemas.AlwaysReserve} ) type sqlite3 struct { @@ -150,9 +152,27 @@ type sqlite3 struct { } func (db *sqlite3) Init(d *core.DB, uri *URI) error { + db.quoter = sqlite3Quoter return db.Base.Init(d, db, uri) } +func (db *sqlite3) SetQuotePolicy(quotePolicy QuotePolicy) { + switch quotePolicy { + case QuotePolicyNone: + var q = sqlite3Quoter + q.IsReserved = schemas.AlwaysNoReserve + db.quoter = q + case QuotePolicyReserved: + var q = sqlite3Quoter + q.IsReserved = db.IsReserved + db.quoter = q + case QuotePolicyAlways: + fallthrough + default: + db.quoter = sqlite3Quoter + } +} + func (db *sqlite3) SQLType(c *schemas.Column) string { switch t := c.SQLType.Name; t { case schemas.Bool: @@ -196,14 +216,10 @@ func (db *sqlite3) SupportInsertMany() bool { } func (db *sqlite3) IsReserved(name string) bool { - _, ok := sqlite3ReservedWords[name] + _, ok := sqlite3ReservedWords[strings.ToUpper(name)] return ok } -func (db *sqlite3) Quoter() schemas.Quoter { - return schemas.Quoter{"`", "`"} -} - func (db *sqlite3) AutoIncrStr() string { return "AUTOINCREMENT" } @@ -250,18 +266,24 @@ func (db *sqlite3) ForUpdateSQL(query string) string { } func (db *sqlite3) IsColumnExist(ctx context.Context, tableName, colName string) (bool, error) { - args := []interface{}{tableName} - query := "SELECT name FROM sqlite_master WHERE type='table' and name = ? and ((sql like '%`" + colName + "`%') or (sql like '%[" + colName + "]%'))" - - rows, err := db.DB().QueryContext(ctx, query, args...) + query := "SELECT * FROM " + tableName + " LIMIT 0" + rows, err := db.DB().QueryContext(ctx, query) if err != nil { return false, err } defer rows.Close() - if rows.Next() { - return true, nil + cols, err := rows.Columns() + if err != nil { + return false, err } + + for _, col := range cols { + if strings.EqualFold(col, colName) { + return true, nil + } + } + return false, nil } diff --git a/engine.go b/engine.go index cc8a74a0..c657cd1f 100644 --- a/engine.go +++ b/engine.go @@ -54,6 +54,10 @@ func (engine *Engine) GetCacher(tableName string) caches.Cacher { return engine.cacherMgr.GetCacher(tableName) } +func (engine *Engine) SetQuotePolicy(quotePolicy dialects.QuotePolicy) { + engine.dialect.SetQuotePolicy(quotePolicy) +} + // BufferSize sets buffer size for iterate func (engine *Engine) BufferSize(size int) *Session { session := engine.NewSession() diff --git a/engine_group.go b/engine_group.go index 8177697e..38b64ca2 100644 --- a/engine_group.go +++ b/engine_group.go @@ -9,6 +9,7 @@ import ( "time" "xorm.io/xorm/caches" + "xorm.io/xorm/dialects" "xorm.io/xorm/log" "xorm.io/xorm/names" ) @@ -180,6 +181,13 @@ func (eg *EngineGroup) SetPolicy(policy GroupPolicy) *EngineGroup { return eg } +func (eg *EngineGroup) SetQuotePolicy(quotePolicy dialects.QuotePolicy) { + eg.Engine.SetQuotePolicy(quotePolicy) + for i := 0; i < len(eg.slaves); i++ { + eg.slaves[i].SetQuotePolicy(quotePolicy) + } +} + // SetTableMapper set the table name mapping rule func (eg *EngineGroup) SetTableMapper(mapper names.Mapper) { eg.Engine.SetTableMapper(mapper) diff --git a/interface.go b/interface.go index 8d2402f0..67b8d4b1 100644 --- a/interface.go +++ b/interface.go @@ -104,6 +104,7 @@ type EngineInterface interface { SetMapper(names.Mapper) SetMaxOpenConns(int) SetMaxIdleConns(int) + SetQuotePolicy(dialects.QuotePolicy) SetSchema(string) SetTableMapper(names.Mapper) SetTZDatabase(tz *time.Location) diff --git a/schemas/quote.go b/schemas/quote.go index 736b774a..10436270 100644 --- a/schemas/quote.go +++ b/schemas/quote.go @@ -8,14 +8,29 @@ import ( "strings" ) -// Quoter represents two quote characters -type Quoter [2]string +// Quoter represents a quoter to the SQL table name and column name +type Quoter struct { + Prefix byte + Suffix byte + IsReserved func(string) bool +} -// CommonQuoter represetns a common quoter -var CommonQuoter = Quoter{"`", "`"} +var ( + // AlwaysFalseReverse always think it's not a reverse word + AlwaysNoReserve = func(string) bool { return false } + + // AlwaysReverse always reverse the word + AlwaysReserve = func(string) bool { return true } + + // CommanQuoteMark represnets the common quote mark + CommanQuoteMark byte = '`' + + // CommonQuoter represetns a common quoter + CommonQuoter = Quoter{CommanQuoteMark, CommanQuoteMark, AlwaysReserve} +) func (q Quoter) IsEmpty() bool { - return q[0] == "" && q[1] == "" + return q.Prefix == 0 && q.Suffix == 0 } func (q Quoter) Quote(s string) string { @@ -24,42 +39,6 @@ func (q Quoter) Quote(s string) string { return buf.String() } -func (q Quoter) Replace(sql string, newQuoter Quoter) string { - if q.IsEmpty() { - return sql - } - - if newQuoter.IsEmpty() { - var buf strings.Builder - for i := 0; i < len(sql); i = i + 1 { - if sql[i] != q[0][0] && sql[i] != q[1][0] { - _ = buf.WriteByte(sql[i]) - } - } - return buf.String() - } - - prefix, suffix := newQuoter[0][0], newQuoter[1][0] - var buf strings.Builder - for i, cnt := 0, 0; i < len(sql); i = i + 1 { - if cnt == 0 && sql[i] == q[0][0] { - _ = buf.WriteByte(prefix) - cnt = 1 - } else if cnt == 1 && sql[i] == q[1][0] { - _ = buf.WriteByte(suffix) - cnt = 0 - } else { - _ = buf.WriteByte(sql[i]) - } - } - return buf.String() -} - -func (q Quoter) ReverseQuote(s string) string { - reverseQuoter := Quoter{q[1], q[0]} - return reverseQuoter.Quote(s) -} - // Trim removes quotes from s func (q Quoter) Trim(s string) string { if len(s) < 2 { @@ -69,10 +48,10 @@ func (q Quoter) Trim(s string) string { var buf strings.Builder for i := 0; i < len(s); i++ { switch { - case i == 0 && s[i:i+1] == q[0]: - case i == len(s)-1 && s[i:i+1] == q[1]: - case s[i:i+1] == q[1] && s[i+1] == '.': - case s[i:i+1] == q[0] && s[i-1] == '.': + case i == 0 && s[i] == q.Prefix: + case i == len(s)-1 && s[i] == q.Suffix: + case s[i] == q.Suffix && s[i+1] == '.': + case s[i] == q.Prefix && s[i-1] == '.': default: buf.WriteByte(s[i]) } @@ -81,31 +60,8 @@ func (q Quoter) Trim(s string) string { } func (q Quoter) Join(a []string, sep string) string { - switch len(a) { - case 0: - return "" - case 1: - return a[0] - } - n := len(sep) * (len(a) - 1) - for i := 0; i < len(a); i++ { - n += len(a[i]) - } - var b strings.Builder - b.Grow(n) - for i, s := range a { - if i > 0 { - b.WriteString(sep) - } - if q[0] != "" && s != "*" { - b.WriteString(q[0]) - } - b.WriteString(strings.TrimSpace(s)) - if q[1] != "" && s != "*" { - b.WriteString(q[1]) - } - } + q.JoinWrite(&b, a, sep) return b.String() } @@ -126,23 +82,113 @@ func (q Quoter) JoinWrite(b *strings.Builder, a []string, sep string) error { return err } } - if q[0] != "" && s != "*" && s[0] != '`' { - if _, err := b.WriteString(q[0]); err != nil { - return err - } - } - if _, err := b.WriteString(strings.TrimSpace(s)); err != nil { - return err - } - if q[1] != "" && s != "*" && s[0] != '`' { - if _, err := b.WriteString(q[1]); err != nil { - return err - } + if s != "*" { + q.QuoteTo(b, strings.TrimSpace(s)) } } return nil } +func findWord(v string, start int) int { + for j := start; j < len(v); j++ { + switch v[j] { + case '.', ' ': + return j + } + } + return len(v) +} + +func findStart(value string, start int) int { + if value[start] == '.' { + return start + 1 + } + if value[start] != ' ' { + return start + } + + var k int + for j := start; j < len(value); j++ { + if value[j] != ' ' { + k = j + break + } + } + if k-1 == len(value) { + return len(value) + } + if (value[k] == 'A' || value[k] == 'a') && (value[k+1] == 'S' || value[k+1] == 's') { + k = k + 2 + } + + for j := k; j < len(value); j++ { + if value[j] != ' ' { + return j + } + } + return len(value) +} + +func (q Quoter) quoteWordTo(buf *strings.Builder, word string) error { + var realWord = word + if (word[0] == CommanQuoteMark && word[len(word)-1] == CommanQuoteMark) || + (word[0] == q.Prefix && word[len(word)-1] == q.Suffix) { + realWord = word[1 : len(word)-1] + } + + if q.IsEmpty() { + _, err := buf.WriteString(realWord) + return err + } + + isReserved := q.IsReserved(realWord) + if isReserved { + if err := buf.WriteByte(q.Prefix); err != nil { + return err + } + } + if _, err := buf.WriteString(realWord); err != nil { + return err + } + if isReserved { + return buf.WriteByte(q.Suffix) + } + + return nil +} + +// QuoteTo quotes the table or column names. i.e. if the quotes are [ and ] +// name -> [name] +// `name` -> [name] +// [name] -> [name] +// schema.name -> [schema].[name] +// `schema`.`name` -> [schema].[name] +// `schema`.name -> [schema].[name] +// schema.`name` -> [schema].[name] +// [schema].name -> [schema].[name] +// schema.[name] -> [schema].[name] +// name AS a -> [name] AS a +// schema.name AS a -> [schema].[name] AS a +func (q Quoter) QuoteTo(buf *strings.Builder, value string) error { + var i int + for i < len(value) { + start := findStart(value, i) + if start > i { + if _, err := buf.WriteString(value[i:start]); err != nil { + return err + } + } + var nextEnd = findWord(value, start) + + if err := q.quoteWordTo(buf, value[start:nextEnd]); err != nil { + return err + } + i = nextEnd + } + return nil +} + +// Strings quotes a slice of string func (q Quoter) Strings(s []string) []string { var res = make([]string, 0, len(s)) for _, a := range s { @@ -150,64 +196,3 @@ func (q Quoter) Strings(s []string) []string { } return res } - -func (q Quoter) QuoteTo(buf *strings.Builder, value string) { - if q.IsEmpty() { - buf.WriteString(value) - return - } - - prefix, suffix := q[0][0], q[1][0] - lastCh := 0 // 0 prefix, 1 char, 2 suffix - i := 0 - for i < len(value) { - // start of a token; might be already quoted - if value[i] == '.' { - _ = buf.WriteByte('.') - lastCh = 1 - i++ - } else if value[i] == prefix || value[i] == '`' { - // Has quotes; skip/normalize `name` to prefix+name+sufix - var ch byte - if value[i] == prefix { - ch = suffix - } else { - ch = '`' - } - _ = buf.WriteByte(prefix) - i++ - lastCh = 0 - for ; i < len(value) && value[i] != ch && value[i] != ' '; i++ { - _ = buf.WriteByte(value[i]) - lastCh = 1 - } - _ = buf.WriteByte(suffix) - lastCh = 2 - i++ - } else if value[i] == ' ' { - if lastCh != 2 { - _ = buf.WriteByte(suffix) - lastCh = 2 - } - - // a AS b or a b - for ; i < len(value); i++ { - if value[i] != ' ' && value[i-1] == ' ' && (len(value) > i+1 && !strings.EqualFold(value[i:i+2], "AS")) { - break - } - - _ = buf.WriteByte(value[i]) - lastCh = 1 - } - } else { - // Requires quotes - _ = buf.WriteByte(prefix) - for ; i < len(value) && value[i] != '.' && value[i] != ' '; i++ { - _ = buf.WriteByte(value[i]) - lastCh = 1 - } - _ = buf.WriteByte(suffix) - lastCh = 2 - } - } -} diff --git a/schemas/quote_test.go b/schemas/quote_test.go index 24739377..c7990f92 100644 --- a/schemas/quote_test.go +++ b/schemas/quote_test.go @@ -11,54 +11,125 @@ import ( "github.com/stretchr/testify/assert" ) -func TestQuoteTo(t *testing.T) { - var quoter = Quoter{"[", "]"} +func TestAlwaysQuoteTo(t *testing.T) { + var ( + quoter = Quoter{'[', ']', AlwaysReserve} + kases = []struct { + expected string + value string + }{ + {"[mytable]", "mytable"}, + {"[mytable]", "`mytable`"}, + {"[mytable]", `[mytable]`}, + {`["mytable"]`, `"mytable"`}, + {"[myschema].[mytable]", "myschema.mytable"}, + {"[myschema].[mytable]", "`myschema`.mytable"}, + {"[myschema].[mytable]", "myschema.`mytable`"}, + {"[myschema].[mytable]", "`myschema`.`mytable`"}, + {"[myschema].[mytable]", `[myschema].mytable`}, + {"[myschema].[mytable]", `myschema.[mytable]`}, + {"[myschema].[mytable]", `[myschema].[mytable]`}, + {`["myschema].[mytable"]`, `"myschema.mytable"`}, + {"[message_user] AS [sender]", "`message_user` AS `sender`"}, + {"[myschema].[mytable] AS [table]", "myschema.mytable AS table"}, + } + ) - test := func(t *testing.T, expected string, value string) { - buf := &strings.Builder{} - quoter.QuoteTo(buf, value) - assert.EqualValues(t, expected, buf.String()) + for _, v := range kases { + t.Run(v.value, func(t *testing.T) { + buf := &strings.Builder{} + quoter.QuoteTo(buf, v.value) + assert.EqualValues(t, v.expected, buf.String()) + }) } +} - test(t, "[mytable]", "mytable") - test(t, "[mytable]", "`mytable`") - test(t, "[mytable]", `[mytable]`) +func TestReversedQuoteTo(t *testing.T) { + var ( + quoter = Quoter{'[', ']', func(s string) bool { + if s == "mytable" { + return true + } + return false + }} + kases = []struct { + expected string + value string + }{ + {"[mytable]", "mytable"}, + {"[mytable]", "`mytable`"}, + {"[mytable]", `[mytable]`}, + {`"mytable"`, `"mytable"`}, + {"myschema.[mytable]", "myschema.mytable"}, + {"myschema.[mytable]", "`myschema`.mytable"}, + {"myschema.[mytable]", "myschema.`mytable`"}, + {"myschema.[mytable]", "`myschema`.`mytable`"}, + {"myschema.[mytable]", `[myschema].mytable`}, + {"myschema.[mytable]", `myschema.[mytable]`}, + {"myschema.[mytable]", `[myschema].[mytable]`}, + {`"myschema.mytable"`, `"myschema.mytable"`}, + {"message_user AS sender", "`message_user` AS `sender`"}, + {"myschema.[mytable] AS table", "myschema.mytable AS table"}, + } + ) - test(t, `["mytable"]`, `"mytable"`) + for _, v := range kases { + t.Run(v.value, func(t *testing.T) { + buf := &strings.Builder{} + quoter.QuoteTo(buf, v.value) + assert.EqualValues(t, v.expected, buf.String()) + }) + } +} - test(t, "[myschema].[mytable]", "myschema.mytable") - test(t, "[myschema].[mytable]", "`myschema`.mytable") - test(t, "[myschema].[mytable]", "myschema.`mytable`") - test(t, "[myschema].[mytable]", "`myschema`.`mytable`") - test(t, "[myschema].[mytable]", `[myschema].mytable`) - test(t, "[myschema].[mytable]", `myschema.[mytable]`) - test(t, "[myschema].[mytable]", `[myschema].[mytable]`) +func TestNoQuoteTo(t *testing.T) { + var ( + quoter = Quoter{'[', ']', AlwaysNoReserve} + kases = []struct { + expected string + value string + }{ + {"mytable", "mytable"}, + {"mytable", "`mytable`"}, + {"mytable", `[mytable]`}, + {`"mytable"`, `"mytable"`}, + {"myschema.mytable", "myschema.mytable"}, + {"myschema.mytable", "`myschema`.mytable"}, + {"myschema.mytable", "myschema.`mytable`"}, + {"myschema.mytable", "`myschema`.`mytable`"}, + {"myschema.mytable", `[myschema].mytable`}, + {"myschema.mytable", `myschema.[mytable]`}, + {"myschema.mytable", `[myschema].[mytable]`}, + {`"myschema.mytable"`, `"myschema.mytable"`}, + {"message_user AS sender", "`message_user` AS `sender`"}, + {"myschema.mytable AS table", "myschema.mytable AS table"}, + } + ) - test(t, `["myschema].[mytable"]`, `"myschema.mytable"`) - - test(t, "[message_user] AS [sender]", "`message_user` AS `sender`") - - assert.EqualValues(t, "[a],[b]", quoter.Join([]string{"a", " b"}, ",")) - - buf := &strings.Builder{} - quoter = Quoter{"", ""} - quoter.QuoteTo(buf, "noquote") - assert.EqualValues(t, "noquote", buf.String()) + for _, v := range kases { + t.Run(v.value, func(t *testing.T) { + buf := &strings.Builder{} + quoter.QuoteTo(buf, v.value) + assert.EqualValues(t, v.expected, buf.String()) + }) + } } func TestJoin(t *testing.T) { cols := []string{"f1", "f2", "f3"} - quoter := Quoter{"[", "]"} + quoter := Quoter{'[', ']', AlwaysReserve} + + assert.EqualValues(t, "[a],[b]", quoter.Join([]string{"a", " b"}, ",")) assert.EqualValues(t, "[f1], [f2], [f3]", quoter.Join(cols, ", ")) - quoter = Quoter{"", ""} + quoter.IsReserved = AlwaysNoReserve assert.EqualValues(t, "f1, f2, f3", quoter.Join(cols, ", ")) } func TestStrings(t *testing.T) { cols := []string{"f1", "f2", "t3.f3"} - quoter := Quoter{"[", "]"} + quoter := Quoter{'[', ']', AlwaysReserve} quotedCols := quoter.Strings(cols) assert.EqualValues(t, []string{"[f1]", "[f2]", "[t3].[f3]"}, quotedCols) @@ -72,6 +143,6 @@ func TestTrim(t *testing.T) { for src, dst := range kases { assert.EqualValues(t, src, CommonQuoter.Trim(src)) - assert.EqualValues(t, dst, Quoter{"[", "]"}.Trim(src)) + assert.EqualValues(t, dst, Quoter{'[', ']', AlwaysReserve}.Trim(src)) } } diff --git a/types_null_test.go b/types_null_test.go index 1d5d005e..665849ca 100644 --- a/types_null_test.go +++ b/types_null_test.go @@ -22,7 +22,7 @@ type NullType struct { Age sql.NullInt64 Height sql.NullFloat64 IsMan sql.NullBool `xorm:"null"` - CustomStruct CustomStruct `xorm:"valchar(64) null"` + CustomStruct CustomStruct `xorm:"varchar(64) null"` } type CustomStruct struct { @@ -58,14 +58,12 @@ func (m CustomStruct) Value() (driver.Value, error) { func TestCreateNullStructTable(t *testing.T) { assert.NoError(t, prepareEngine()) - err := testEngine.CreateTables(new(NullType)) assert.NoError(t, err) } func TestDropNullStructTable(t *testing.T) { assert.NoError(t, prepareEngine()) - err := testEngine.DropTables(new(NullType)) assert.NoError(t, err) } @@ -78,7 +76,7 @@ func TestNullStructInsert(t *testing.T) { item := new(NullType) _, err := testEngine.Insert(item) assert.NoError(t, err) - assert.EqualValues(t, item.Id, 1) + assert.EqualValues(t, 1, item.Id) } if true { @@ -90,12 +88,11 @@ func TestNullStructInsert(t *testing.T) { } _, err := testEngine.Insert(&item) assert.NoError(t, err) - assert.EqualValues(t, item.Id, 2) + assert.EqualValues(t, 2, item.Id) } if true { items := []NullType{} - for i := 0; i < 5; i++ { item := NullType{ Name: sql.NullString{String: "haolei_" + fmt.Sprint(i+1), Valid: true}, @@ -152,7 +149,7 @@ func TestNullStructUpdate(t *testing.T) { affected, err := testEngine.ID(2).Cols("age", "height", "is_man").Update(item) assert.NoError(t, err) - assert.EqualValues(t, affected, 1) + assert.EqualValues(t, 1, affected) } if true { // 测试In update @@ -160,7 +157,7 @@ func TestNullStructUpdate(t *testing.T) { item.Age = sql.NullInt64{Int64: 23, Valid: true} affected, err := testEngine.In("id", 3, 4).Cols("age", "height", "is_man").Update(item) assert.NoError(t, err) - assert.EqualValues(t, affected, 2) + assert.EqualValues(t, 2, affected) } if true { // 测试where @@ -183,9 +180,7 @@ func TestNullStructUpdate(t *testing.T) { _, err := testEngine.AllCols().ID(6).Update(item) assert.NoError(t, err) - fmt.Println(item) } - } func TestNullStructFind(t *testing.T) { @@ -274,9 +269,8 @@ func TestNullStructCount(t *testing.T) { if true { item := new(NullType) - total, err := testEngine.Where("age IS NOT NULL").Count(item) + _, err := testEngine.Where("age IS NOT NULL").Count(item) assert.NoError(t, err) - fmt.Println(total) } } @@ -292,7 +286,6 @@ func TestNullStructRows(t *testing.T) { for rows.Next() { err = rows.Scan(item) assert.NoError(t, err) - fmt.Println(item) } } diff --git a/xorm_test.go b/xorm_test.go index c1f38757..c9a1a74b 100644 --- a/xorm_test.go +++ b/xorm_test.go @@ -18,6 +18,7 @@ import ( _ "github.com/mattn/go-sqlite3" _ "github.com/ziutek/mymysql/godrv" "xorm.io/xorm/caches" + "xorm.io/xorm/dialects" "xorm.io/xorm/log" "xorm.io/xorm/names" "xorm.io/xorm/schemas" @@ -38,6 +39,7 @@ var ( schema = flag.String("schema", "", "specify the schema") ignoreSelectUpdate = flag.Bool("ignore_select_update", false, "ignore select update if implementation difference, only for tidb") ingoreUpdateLimit = flag.Bool("ignore_update_limit", false, "ignore update limit if implementation difference, only for cockroach") + quotePolicyStr = flag.String("quote", "always", "quote could be always, none, reversed") tableMapper names.Mapper colMapper names.Mapper ) @@ -131,6 +133,14 @@ func createEngine(dbType, connStr string) error { testEngine.SetMapper(names.LintGonicMapper) } } + + if *quotePolicyStr == "none" { + testEngine.SetQuotePolicy(dialects.QuotePolicyNone) + } else if *quotePolicyStr == "reserved" { + testEngine.SetQuotePolicy(dialects.QuotePolicyReserved) + } else { + testEngine.SetQuotePolicy(dialects.QuotePolicyAlways) + } } tableMapper = testEngine.GetTableMapper()