From 808ff05591ad0fc0159714975f37bfaee2bf60bf Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Fri, 21 Feb 2020 11:28:26 +0800 Subject: [PATCH 1/4] Fix join table name quote bug --- session_find_test.go | 10 +++++++++- statement.go | 5 +++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/session_find_test.go b/session_find_test.go index 991fadf2..71116472 100644 --- a/session_find_test.go +++ b/session_find_test.go @@ -790,8 +790,12 @@ func TestFindJoin(t *testing.T) { DeviceId int64 } + type Order struct { + Id int64 + } + assert.NoError(t, prepareEngine()) - assertSync(t, new(SceneItem), new(DeviceUserPrivrels)) + assertSync(t, new(SceneItem), new(DeviceUserPrivrels), new(Order)) var scenes []SceneItem err := testEngine.Join("LEFT OUTER", "device_user_privrels", "device_user_privrels.device_id=scene_item.device_id"). @@ -802,6 +806,10 @@ func TestFindJoin(t *testing.T) { err = testEngine.Join("LEFT OUTER", new(DeviceUserPrivrels), "device_user_privrels.device_id=scene_item.device_id"). Where("scene_item.type=?", 3).Or("device_user_privrels.user_id=?", 339).Find(&scenes) assert.NoError(t, err) + + scenes = make([]SceneItem, 0) + err = testEngine.Join("INNER", "order", "`scene_item`.device_id=`order`.id").Find(&scenes) + assert.NoError(t, err) } func TestJoinFindLimit(t *testing.T) { diff --git a/statement.go b/statement.go index 87cab7cc..671a699e 100644 --- a/statement.go +++ b/statement.go @@ -765,6 +765,11 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition statement.joinArgs = append(statement.joinArgs, subQueryArgs...) default: tbName := statement.Engine.TableName(tablename, true) + if !isSubQuery(tbName) { + var buf strings.Builder + statement.Engine.QuoteTo(&buf, tbName) + tbName = buf.String() + } fmt.Fprintf(&buf, "%s ON %v", tbName, condition) } -- 2.40.1 From 381a45d38af5f796a40784ac7ab031c24644ce2a Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Mon, 24 Feb 2020 21:36:09 +0800 Subject: [PATCH 2/4] Add new Quoter object to handle quote --- dialects/dialect.go | 52 ++++------- dialects/filter.go | 45 +--------- dialects/mssql.go | 8 +- dialects/mysql.go | 12 +-- dialects/oracle.go | 13 ++- dialects/postgres.go | 10 +-- dialects/sqlite3.go | 9 +- engine.go | 54 ++---------- schemas/quote.go | 110 ++++++++++++++++++++++++ engine_test.go => schemas/quote_test.go | 10 ++- session_cond_test.go | 14 +-- session_find_test.go | 8 +- session_insert_test.go | 20 ++--- session_tx_test.go | 4 +- session_update_test.go | 8 +- statement.go | 2 +- statement_exprparam.go | 3 +- statement_quote.go | 19 ---- 18 files changed, 198 insertions(+), 203 deletions(-) create mode 100644 schemas/quote.go rename engine_test.go => schemas/quote_test.go (84%) delete mode 100644 statement_quote.go diff --git a/dialects/dialect.go b/dialects/dialect.go index ac9ff301..3ed867f4 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -45,11 +45,8 @@ type Dialect interface { DataSourceName() string IsReserved(string) bool - Quote(string) string + Quoter() schemas.Quoter - AndStr() string - OrStr() string - EqStr() string RollBackStr() string AutoIncrStr() string @@ -101,7 +98,7 @@ type Base struct { // String generate column description string according dialect func String(d Dialect, col *schemas.Column) string { - sql := d.Quote(col.Name) + " " + sql := d.Quoter().Quote(col.Name) + " " sql += d.SQLType(col) + " " @@ -129,7 +126,7 @@ func String(d Dialect, col *schemas.Column) string { // StringNoPk generate column description string according dialect without primary keys func StringNoPk(d Dialect, col *schemas.Column) string { - sql := d.Quote(col.Name) + " " + sql := d.Quoter().Quote(col.Name) + " " sql += d.SQLType(col) + " " @@ -186,18 +183,6 @@ func (b *Base) DataSourceName() string { return b.dataSourceName } -func (b *Base) AndStr() string { - return "AND" -} - -func (b *Base) OrStr() string { - return "OR" -} - -func (b *Base) EqStr() string { - return "=" -} - func (db *Base) RollBackStr() string { return "ROLL BACK" } @@ -207,7 +192,7 @@ func (db *Base) SupportDropIfExists() bool { } func (db *Base) DropTableSQL(tableName string) string { - quote := db.dialect.Quote + quote := db.dialect.Quoter().Quote return fmt.Sprintf("DROP TABLE IF EXISTS %s", quote(tableName)) } @@ -226,14 +211,15 @@ func (db *Base) HasRecords(query string, args ...interface{}) (bool, error) { } func (db *Base) IsColumnExist(tableName, colName string) (bool, error) { + quote := db.dialect.Quoter().Quote query := fmt.Sprintf( "SELECT %v FROM %v.%v WHERE %v = ? AND %v = ? AND %v = ?", - db.dialect.Quote("COLUMN_NAME"), - db.dialect.Quote("INFORMATION_SCHEMA"), - db.dialect.Quote("COLUMNS"), - db.dialect.Quote("TABLE_SCHEMA"), - db.dialect.Quote("TABLE_NAME"), - db.dialect.Quote("COLUMN_NAME"), + quote("COLUMN_NAME"), + quote("INFORMATION_SCHEMA"), + quote("COLUMNS"), + quote("TABLE_SCHEMA"), + quote("TABLE_NAME"), + quote("COLUMN_NAME"), ) return db.HasRecords(query, db.uri.DBName, tableName, colName) } @@ -263,8 +249,7 @@ func (db *Base) CreateTableIfNotExists(table *Table, tableName, storeEngine, cha }*/ func (db *Base) CreateIndexSQL(tableName string, index *schemas.Index) string { - quotes := db.dialect.Quote("") - quote := db.dialect.Quote + quoter := db.dialect.Quoter() var unique string var idxName string if index.Type == schemas.UniqueType { @@ -272,12 +257,12 @@ func (db *Base) CreateIndexSQL(tableName string, index *schemas.Index) string { } idxName = index.XName(tableName) return fmt.Sprintf("CREATE%s INDEX %v ON %v (%v)", unique, - quote(idxName), quote(tableName), - quote(strings.Join(index.Cols, fmt.Sprintf("%c,%c", quotes[1], quotes[0])))) + quoter.Quote(idxName), quoter.Quote(tableName), + quoter.Quote(strings.Join(index.Cols, quoter.ReverseQuote(",")))) } func (db *Base) DropIndexSQL(tableName string, index *schemas.Index) string { - quote := db.dialect.Quote + quote := db.dialect.Quoter().Quote var name string if index.IsRegular { name = index.XName(tableName) @@ -298,11 +283,10 @@ func (b *Base) CreateTableSQL(table *schemas.Table, tableName, storeEngine, char tableName = table.Name } - sql += b.dialect.Quote(tableName) + quoter := b.dialect.Quoter() + sql += quoter.Quote(tableName) sql += " (" - quotes := b.dialect.Quote("") - if len(table.ColumnsSeq()) > 0 { pkList := table.PrimaryKeys @@ -322,7 +306,7 @@ func (b *Base) CreateTableSQL(table *schemas.Table, tableName, storeEngine, char if len(pkList) > 1 { sql += "PRIMARY KEY ( " - sql += b.dialect.Quote(strings.Join(pkList, fmt.Sprintf("%c,%c", quotes[1], quotes[0]))) + sql += quoter.Quote(strings.Join(pkList, quoter.ReverseQuote(","))) sql += " ), " } diff --git a/dialects/filter.go b/dialects/filter.go index f7bad1a9..15044e1f 100644 --- a/dialects/filter.go +++ b/dialects/filter.go @@ -21,49 +21,8 @@ type QuoteFilter struct { } func (s *QuoteFilter) Do(sql string, dialect Dialect, table *schemas.Table) string { - dummy := dialect.Quote("") - if len(dummy) != 2 { - return sql - } - prefix, suffix := dummy[0], dummy[1] - raw := []byte(sql) - for i, cnt := 0, 0; i < len(raw); i = i + 1 { - if raw[i] == '`' { - if cnt%2 == 0 { - raw[i] = prefix - } else { - raw[i] = suffix - } - cnt++ - } - } - return string(raw) -} - -// IdFilter filter SQL replace (id) to primary key column name -type IdFilter struct { -} - -type Quoter struct { - dialect Dialect -} - -func NewQuoter(dialect Dialect) *Quoter { - return &Quoter{dialect} -} - -func (q *Quoter) Quote(content string) string { - return q.dialect.Quote(content) -} - -func (i *IdFilter) Do(sql string, dialect Dialect, table *schemas.Table) string { - quoter := NewQuoter(dialect) - if table != nil && len(table.PrimaryKeys) == 1 { - sql = strings.Replace(sql, " `(id)` ", " "+quoter.Quote(table.PrimaryKeys[0])+" ", -1) - sql = strings.Replace(sql, " "+quoter.Quote("(id)")+" ", " "+quoter.Quote(table.PrimaryKeys[0])+" ", -1) - return strings.Replace(sql, " (id) ", " "+quoter.Quote(table.PrimaryKeys[0])+" ", -1) - } - return sql + quoter := dialect.Quoter() + return quoter.Quote(sql) } // SeqFilter filter SQL replace ?, ? ... to $1, $2 ... diff --git a/dialects/mssql.go b/dialects/mssql.go index 5a91e9ef..d473d975 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -286,8 +286,8 @@ func (db *mssql) IsReserved(name string) bool { return ok } -func (db *mssql) Quote(name string) string { - return "[" + name + "]" +func (db *mssql) Quoter() schemas.Quoter { + return schemas.Quoter{"[", "]"} } func (db *mssql) SupportEngine() bool { @@ -503,7 +503,7 @@ func (db *mssql) CreateTableSQL(table *schemas.Table, tableName, storeEngine, ch sql = "IF NOT EXISTS (SELECT [name] FROM sys.tables WHERE [name] = '" + tableName + "' ) CREATE TABLE " - sql += db.Quote(tableName) + " (" + sql += db.Quoter().Quote(tableName) + " (" pkList := table.PrimaryKeys @@ -534,7 +534,7 @@ func (db *mssql) ForUpdateSQL(query string) string { } func (db *mssql) Filters() []Filter { - return []Filter{&IdFilter{}, &QuoteFilter{}} + return []Filter{&QuoteFilter{}} } type odbcDriver struct { diff --git a/dialects/mysql.go b/dialects/mysql.go index 82d7b5f4..32dc25b7 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -275,8 +275,8 @@ func (db *mysql) IsReserved(name string) bool { return ok } -func (db *mysql) Quote(name string) string { - return "`" + name + "`" +func (db *mysql) Quoter() schemas.Quoter { + return schemas.Quoter{"`", "`"} } func (db *mysql) SupportEngine() bool { @@ -512,9 +512,9 @@ func (db *mysql) CreateTableSQL(table *schemas.Table, tableName, storeEngine, ch tableName = table.Name } - quotes := db.Quote("") + quoter := db.Quoter() - sql += db.Quote(tableName) + sql += quoter.Quote(tableName) sql += " (" if len(table.ColumnsSeq()) > 0 { @@ -536,7 +536,7 @@ func (db *mysql) CreateTableSQL(table *schemas.Table, tableName, storeEngine, ch if len(pkList) > 1 { sql += "PRIMARY KEY ( " - sql += db.Quote(strings.Join(pkList, fmt.Sprintf("%c,%c", quotes[1], quotes[0]))) + sql += quoter.Quote(strings.Join(pkList, quoter.ReverseQuote(","))) sql += " ), " } @@ -562,7 +562,7 @@ func (db *mysql) CreateTableSQL(table *schemas.Table, tableName, storeEngine, ch } func (db *mysql) Filters() []Filter { - return []Filter{&IdFilter{}} + return []Filter{} } type mymysqlDriver struct { diff --git a/dialects/oracle.go b/dialects/oracle.go index 2f903331..bf9ee2af 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -552,8 +552,8 @@ func (db *oracle) IsReserved(name string) bool { return ok } -func (db *oracle) Quote(name string) string { - return "[" + name + "]" +func (db *oracle) Quoter() schemas.Quoter { + return schemas.Quoter{"[", "]"} } func (db *oracle) SupportEngine() bool { @@ -582,7 +582,8 @@ func (db *oracle) CreateTableSQL(table *schemas.Table, tableName, storeEngine, c tableName = table.Name } - sql += db.Quote(tableName) + " (" + quoter := db.Quoter() + sql += quoter.Quote(tableName) + " (" pkList := table.PrimaryKeys @@ -597,11 +598,9 @@ func (db *oracle) CreateTableSQL(table *schemas.Table, tableName, storeEngine, c sql += ", " } - quotes := db.Quote("") - if len(pkList) > 0 { sql += "PRIMARY KEY ( " - sql += db.Quote(strings.Join(pkList, fmt.Sprintf("%c,%c", quotes[1], quotes[0]))) + sql += quoter.Quote(strings.Join(pkList, quoter.ReverseQuote(","))) sql += " ), " } @@ -849,7 +848,7 @@ func (db *oracle) GetIndexes(tableName string) (map[string]*schemas.Index, error } func (db *oracle) Filters() []Filter { - return []Filter{&QuoteFilter{}, &SeqFilter{Prefix: ":", Start: 1}, &IdFilter{}} + return []Filter{&QuoteFilter{}, &SeqFilter{Prefix: ":", Start: 1}} } type goracleDriver struct { diff --git a/dialects/postgres.go b/dialects/postgres.go index e4f4b89b..f161fdfa 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -859,9 +859,8 @@ func (db *postgres) IsReserved(name string) bool { return ok } -func (db *postgres) Quote(name string) string { - name = strings.Replace(name, ".", `"."`, -1) - return "\"" + name + "\"" +func (db *postgres) Quoter() schemas.Quoter { + return schemas.Quoter{`"`, `"`} } func (db *postgres) AutoIncrStr() string { @@ -911,7 +910,6 @@ func (db *postgres) ModifyColumnSQL(tableName string, col *schemas.Column) strin } func (db *postgres) DropIndexSQL(tableName string, index *schemas.Index) string { - quote := db.Quote idxName := index.Name tableParts := strings.Split(strings.Replace(tableName, `"`, "", -1), ".") @@ -928,7 +926,7 @@ func (db *postgres) DropIndexSQL(tableName string, index *schemas.Index) string if db.uri.Schema != "" { idxName = db.uri.Schema + "." + idxName } - return fmt.Sprintf("DROP INDEX %v", quote(idxName)) + return fmt.Sprintf("DROP INDEX %v", db.Quoter().Quote(idxName)) } func (db *postgres) IsColumnExist(tableName, colName string) (bool, error) { @@ -1161,7 +1159,7 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*schemas.Index, err } func (db *postgres) Filters() []Filter { - return []Filter{&IdFilter{}, &QuoteFilter{}, &SeqFilter{Prefix: "$", Start: 1}} + return []Filter{&QuoteFilter{}, &SeqFilter{Prefix: "$", Start: 1}} } type pqDriver struct { diff --git a/dialects/sqlite3.go b/dialects/sqlite3.go index b7ff2147..0fd80b73 100644 --- a/dialects/sqlite3.go +++ b/dialects/sqlite3.go @@ -199,8 +199,8 @@ func (db *sqlite3) IsReserved(name string) bool { return ok } -func (db *sqlite3) Quote(name string) string { - return "`" + name + "`" +func (db *sqlite3) Quoter() schemas.Quoter { + return schemas.Quoter{"`", "`"} } func (db *sqlite3) AutoIncrStr() string { @@ -231,7 +231,6 @@ func (db *sqlite3) TableCheckSQL(tableName string) (string, []interface{}) { func (db *sqlite3) DropIndexSQL(tableName string, index *schemas.Index) string { // var unique string - quote := db.Quote idxName := index.Name if !strings.HasPrefix(idxName, "UQE_") && @@ -242,7 +241,7 @@ func (db *sqlite3) DropIndexSQL(tableName string, index *schemas.Index) string { idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name) } } - return fmt.Sprintf("DROP INDEX %v", quote(idxName)) + return fmt.Sprintf("DROP INDEX %v", db.Quoter().Quote(idxName)) } func (db *sqlite3) ForUpdateSQL(query string) string { @@ -478,7 +477,7 @@ func (db *sqlite3) GetIndexes(tableName string) (map[string]*schemas.Index, erro } func (db *sqlite3) Filters() []Filter { - return []Filter{&IdFilter{}} + return []Filter{} } type sqlite3Driver struct { diff --git a/engine.go b/engine.go index c0136bba..e505ac1f 100644 --- a/engine.go +++ b/engine.go @@ -222,53 +222,13 @@ func (engine *Engine) QuoteTo(buf *strings.Builder, value string) { if value == "" { return } - - quoteTo(buf, engine.dialect.Quote(""), value) -} - -func quoteTo(buf *strings.Builder, quotePair string, value string) { - if len(quotePair) < 2 { // no quote - _, _ = buf.WriteString(value) - return - } - - prefix, suffix := quotePair[0], quotePair[1] - - i := 0 - for i < len(value) { - // start of a token; might be already quoted - if value[i] == '.' { - _ = buf.WriteByte('.') - i++ - } else if value[i] == prefix || value[i] == '`' { - // Has quotes; skip/normalize `name` to prefix+name+sufix - var ch byte - if value[i] == prefix { - ch = suffix - } else { - ch = '`' - } - i++ - _ = buf.WriteByte(prefix) - for ; i < len(value) && value[i] != ch; i++ { - _ = buf.WriteByte(value[i]) - } - _ = buf.WriteByte(suffix) - i++ - } else { - // Requires quotes - _ = buf.WriteByte(prefix) - for ; i < len(value) && value[i] != '.'; i++ { - _ = buf.WriteByte(value[i]) - } - _ = buf.WriteByte(suffix) - } - } + engine.dialect.Quoter().QuoteTo(buf, value) } +/* func (engine *Engine) quote(sql string) string { return engine.dialect.Quote(sql) -} +}*/ // SqlType will be deprecated, please use SQLType instead // @@ -530,8 +490,8 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...dia } cols := table.ColumnsSeq() - colNames := engine.dialect.Quote(strings.Join(cols, engine.dialect.Quote(", "))) - destColNames := dialect.Quote(strings.Join(cols, dialect.Quote(", "))) + colNames := engine.dialect.Quoter().Join(cols, ", ") + destColNames := dialect.Quoter().Join(cols, ", ") rows, err := engine.DB().Query("SELECT " + colNames + " FROM " + engine.Quote(table.Name)) if err != nil { @@ -546,7 +506,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...dia return err } - _, err = io.WriteString(w, "INSERT INTO "+dialect.Quote(table.Name)+" ("+destColNames+") VALUES (") + _, err = io.WriteString(w, "INSERT INTO "+dialect.Quoter().Quote(table.Name)+" ("+destColNames+") VALUES (") if err != nil { return err } @@ -617,7 +577,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...dia // FIXME: Hack for postgres if string(dialect.DBType()) == schemas.POSTGRES && table.AutoIncrColumn() != nil { - _, err = io.WriteString(w, "SELECT setval('"+table.Name+"_id_seq', COALESCE((SELECT MAX("+table.AutoIncrColumn().Name+") + 1 FROM "+dialect.Quote(table.Name)+"), 1), false);\n") + _, err = io.WriteString(w, "SELECT setval('"+table.Name+"_id_seq', COALESCE((SELECT MAX("+table.AutoIncrColumn().Name+") + 1 FROM "+dialect.Quoter().Quote(table.Name)+"), 1), false);\n") if err != nil { return err } diff --git a/schemas/quote.go b/schemas/quote.go new file mode 100644 index 00000000..5230cec8 --- /dev/null +++ b/schemas/quote.go @@ -0,0 +1,110 @@ +// Copyright 2020 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package schemas + +import ( + "strings" +) + +// Quoter represents two quote characters +type Quoter [2]string + +// CommonQuoter represetns a common quoter +var CommonQuoter = Quoter{"`", "`"} + +func (q Quoter) IsEmpty() bool { + return q[0] == "" && q[1] == "" +} + +func (q Quoter) Quote(s string) string { + var buf strings.Builder + q.QuoteTo(&buf, s) + return buf.String() +} + +func (q Quoter) ReverseQuote(s string) string { + reverseQuoter := Quoter{q[1], q[0]} + return reverseQuoter.Quote(s) +} + +// Trim removes quotes from s +func (q Quoter) Trim(s string) string { + if len(s) < 2 { + return s + } + + if s[0:1] == q[0] { + s = s[1:] + } + if len(s) > 0 && s[len(s)-1:] == q[0] { + return s[:len(s)-1] + } + return s +} + +func (q Quoter) Join(s []string, splitter string) string { + return q.Quote(strings.Join(s, q.ReverseQuote(splitter))) +} + +func (q Quoter) QuoteTo(buf *strings.Builder, value string) { + if q.IsEmpty() { + buf.WriteString(value) + return + } + + prefix, suffix := q[0][0], q[1][0] + lastCh := 0 // 0 prefix, 1 char, 2 suffix + i := 0 + for i < len(value) { + // start of a token; might be already quoted + if value[i] == '.' { + _ = buf.WriteByte('.') + lastCh = 1 + i++ + } else if value[i] == prefix || value[i] == '`' { + // Has quotes; skip/normalize `name` to prefix+name+sufix + var ch byte + if value[i] == prefix { + ch = suffix + } else { + ch = '`' + } + _ = buf.WriteByte(prefix) + i++ + lastCh = 0 + for ; i < len(value) && value[i] != ch && value[i] != ' '; i++ { + _ = buf.WriteByte(value[i]) + lastCh = 1 + } + _ = buf.WriteByte(suffix) + lastCh = 2 + i++ + } else if value[i] == ' ' { + if lastCh != 2 { + _ = buf.WriteByte(suffix) + lastCh = 2 + } + + // a AS b or a b + for ; i < len(value); i++ { + if value[i] != ' ' && value[i-1] == ' ' && (len(value) > i+1 && !strings.EqualFold(value[i:i+2], "AS")) { + break + } + + _ = buf.WriteByte(value[i]) + lastCh = 1 + } + } else { + // Requires quotes + _ = buf.WriteByte(prefix) + for ; i < len(value) && value[i] != '.' && value[i] != ' '; i++ { + _ = buf.WriteByte(value[i]) + lastCh = 1 + } + _ = buf.WriteByte(suffix) + lastCh = 2 + } + } +} diff --git a/engine_test.go b/schemas/quote_test.go similarity index 84% rename from engine_test.go rename to schemas/quote_test.go index 50522f5f..f89c6258 100644 --- a/engine_test.go +++ b/schemas/quote_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package schemas import ( "strings" @@ -12,10 +12,11 @@ import ( ) func TestQuoteTo(t *testing.T) { + var quoter = Quoter{"[", "]"} test := func(t *testing.T, expected string, value string) { buf := &strings.Builder{} - quoteTo(buf, "[]", value) + quoter.QuoteTo(buf, value) assert.EqualValues(t, expected, buf.String()) } @@ -35,7 +36,10 @@ func TestQuoteTo(t *testing.T) { test(t, `["myschema].[mytable"]`, `"myschema.mytable"`) + test(t, "[message_user] AS [sender]", "`message_user` AS `sender`") + buf := &strings.Builder{} - quoteTo(buf, "", "noquote") + quoter = Quoter{"", ""} + quoter.QuoteTo(buf, "noquote") assert.EqualValues(t, "noquote", buf.String()) } diff --git a/session_cond_test.go b/session_cond_test.go index 865890d0..6c9ab960 100644 --- a/session_cond_test.go +++ b/session_cond_test.go @@ -137,13 +137,13 @@ func TestIn(t *testing.T) { idsStr = idsStr[:len(idsStr)-1] users := make([]Userinfo, 0) - err = testEngine.In("(id)", ids[0], ids[1], ids[2]).Find(&users) + err = testEngine.In("uid", ids[0], ids[1], ids[2]).Find(&users) assert.NoError(t, err) fmt.Println(users) assert.EqualValues(t, 3, len(users)) users = make([]Userinfo, 0) - err = testEngine.In("(id)", ids).Find(&users) + err = testEngine.In("uid", ids).Find(&users) assert.NoError(t, err) fmt.Println(users) assert.EqualValues(t, 3, len(users)) @@ -161,7 +161,7 @@ func TestIn(t *testing.T) { idsInterface = append(idsInterface, id) } - err = testEngine.Where(department+" = ?", "dev").In("(id)", idsInterface...).Find(&users) + err = testEngine.Where(department+" = ?", "dev").In("uid", idsInterface...).Find(&users) assert.NoError(t, err) fmt.Println(users) assert.EqualValues(t, 3, len(users)) @@ -175,11 +175,11 @@ func TestIn(t *testing.T) { dev := testEngine.GetColumnMapper().Obj2Table("Dev") - err = testEngine.In("(id)", 1).In("(id)", 2).In(department, dev).Find(&users) + err = testEngine.In("uid", 1).In("uid", 2).In(department, dev).Find(&users) assert.NoError(t, err) fmt.Println(users) - cnt, err = testEngine.In("(id)", ids[0]).Update(&Userinfo{Departname: "dev-"}) + cnt, err = testEngine.In("uid", ids[0]).Update(&Userinfo{Departname: "dev-"}) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) @@ -189,11 +189,11 @@ func TestIn(t *testing.T) { assert.True(t, has) assert.EqualValues(t, "dev-", user.Departname) - cnt, err = testEngine.In("(id)", ids[0]).Update(&Userinfo{Departname: "dev"}) + cnt, err = testEngine.In("uid", ids[0]).Update(&Userinfo{Departname: "dev"}) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - cnt, err = testEngine.In("(id)", ids[1]).Delete(&Userinfo{}) + cnt, err = testEngine.In("uid", ids[1]).Delete(&Userinfo{}) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) } diff --git a/session_find_test.go b/session_find_test.go index 71116472..94b6b153 100644 --- a/session_find_test.go +++ b/session_find_test.go @@ -77,14 +77,14 @@ func TestWhere(t *testing.T) { assertSync(t, new(Userinfo)) users := make([]Userinfo, 0) - err := testEngine.Where("(id) > ?", 2).Find(&users) + err := testEngine.Where("uid > ?", 2).Find(&users) if err != nil { t.Error(err) panic(err) } fmt.Println(users) - err = testEngine.Where("(id) > ?", 2).And("(id) < ?", 10).Find(&users) + err = testEngine.Where("uid > ?", 2).And("uid < ?", 10).Find(&users) if err != nil { t.Error(err) panic(err) @@ -312,12 +312,12 @@ func TestOrderSameMapper(t *testing.T) { assertSync(t, new(Userinfo)) users := make([]Userinfo, 0) - err := testEngine.OrderBy("(id) desc").Find(&users) + err := testEngine.OrderBy("Uid desc").Find(&users) assert.NoError(t, err) fmt.Println(users) users2 := make([]Userinfo, 0) - err = testEngine.Asc("(id)", "Username").Desc("Height").Find(&users2) + err = testEngine.Asc("Uid", "Username").Desc("Height").Find(&users2) assert.NoError(t, err) fmt.Println(users2) } diff --git a/session_insert_test.go b/session_insert_test.go index 6190d18a..72b89e09 100644 --- a/session_insert_test.go +++ b/session_insert_test.go @@ -201,7 +201,7 @@ func TestInsertDefault(t *testing.T) { _, err = testEngine.Omit(testEngine.GetColumnMapper().Obj2Table("Status")).Insert(&di2) assert.NoError(t, err) - has, err := testEngine.Desc("(id)").Get(di) + has, err := testEngine.Desc("id").Get(di) assert.NoError(t, err) if !has { err = errors.New("error with no data") @@ -247,7 +247,7 @@ func TestInsertDefault2(t *testing.T) { t.Error(err) } - has, err := testEngine.Desc("(id)").Get(di) + has, err := testEngine.Desc("id").Get(di) if err != nil { t.Error(err) } @@ -257,7 +257,7 @@ func TestInsertDefault2(t *testing.T) { panic(err) } - has, err = testEngine.NoAutoCondition().Desc("(id)").Get(&di2) + has, err = testEngine.NoAutoCondition().Desc("id").Get(&di2) if err != nil { t.Error(err) } @@ -330,7 +330,7 @@ func TestInsertCreated(t *testing.T) { t.Fatal(err) } - has, err := testEngine.Desc("(id)").Get(di) + has, err := testEngine.Desc("id").Get(di) if err != nil { t.Fatal(err) } @@ -352,7 +352,7 @@ func TestInsertCreated(t *testing.T) { if err != nil { t.Fatal(err) } - has, err = testEngine.Desc("(id)").Get(di2) + has, err = testEngine.Desc("id").Get(di2) if err != nil { t.Fatal(err) } @@ -374,7 +374,7 @@ func TestInsertCreated(t *testing.T) { if err != nil { t.Fatal(err) } - has, err = testEngine.Desc("(id)").Get(di3) + has, err = testEngine.Desc("id").Get(di3) if err != nil { t.Fatal(err) } @@ -396,7 +396,7 @@ func TestInsertCreated(t *testing.T) { if err != nil { t.Fatal(err) } - has, err = testEngine.Desc("(id)").Get(di4) + has, err = testEngine.Desc("id").Get(di4) if err != nil { t.Fatal(err) } @@ -418,7 +418,7 @@ func TestInsertCreated(t *testing.T) { if err != nil { t.Fatal(err) } - has, err = testEngine.Desc("(id)").Get(di5) + has, err = testEngine.Desc("id").Get(di5) if err != nil { t.Fatal(err) } @@ -442,7 +442,7 @@ func TestInsertCreated(t *testing.T) { t.Fatal(err) } - has, err = testEngine.Desc("(id)").Get(di6) + has, err = testEngine.Desc("id").Get(di6) if err != nil { t.Fatal(err) } @@ -517,7 +517,7 @@ func TestCreatedJsonTime(t *testing.T) { if err != nil { t.Fatal(err) } - has, err := testEngine.Desc("(id)").Get(di5) + has, err := testEngine.Desc("id").Get(di5) if err != nil { t.Fatal(err) } diff --git a/session_tx_test.go b/session_tx_test.go index da3f0f04..ae83fe30 100644 --- a/session_tx_test.go +++ b/session_tx_test.go @@ -39,7 +39,7 @@ func TestTransaction(t *testing.T) { assert.NoError(t, err) user2 := Userinfo{Username: "yyy"} - _, err = session.Where("(id) = ?", 0).Update(&user2) + _, err = session.Where("uid = ?", 0).Update(&user2) assert.NoError(t, err) _, err = session.Delete(&user2) @@ -119,7 +119,7 @@ func TestCombineTransactionSameMapper(t *testing.T) { assert.NoError(t, err) user2 := Userinfo{Username: "zzz"} - _, err = session.Where("(id) = ?", 0).Update(&user2) + _, err = session.Where("uid = ?", 0).Update(&user2) assert.NoError(t, err) _, err = session.Exec("delete from "+testEngine.TableName("`Userinfo`", true)+" where `Username` = ?", user2.Username) diff --git a/session_update_test.go b/session_update_test.go index cb79bad0..2d310aa1 100644 --- a/session_update_test.go +++ b/session_update_test.go @@ -137,7 +137,7 @@ func TestForUpdate(t *testing.T) { // use lock fList := make([]ForUpdate, 0) session1.ForUpdate() - session1.Where("(id) = ?", 1) + session1.Where("id = ?", 1) err = session1.Find(&fList) switch { case err != nil: @@ -158,7 +158,7 @@ func TestForUpdate(t *testing.T) { wg.Add(1) go func() { f2 := new(ForUpdate) - session2.Where("(id) = ?", 1).ForUpdate() + session2.Where("id = ?", 1).ForUpdate() has, err := session2.Get(f2) // wait release lock switch { case err != nil: @@ -175,7 +175,7 @@ func TestForUpdate(t *testing.T) { wg.Add(1) go func() { f3 := new(ForUpdate) - session3.Where("(id) = ?", 1) + session3.Where("id = ?", 1) has, err := session3.Get(f3) // wait release lock switch { case err != nil: @@ -193,7 +193,7 @@ func TestForUpdate(t *testing.T) { f := new(ForUpdate) f.Name = "updated by session1" - session1.Where("(id) = ?", 1) + session1.Where("id = ?", 1) session1.Update(f) // release lock diff --git a/statement.go b/statement.go index 671a699e..26c5bd1b 100644 --- a/statement.go +++ b/statement.go @@ -618,7 +618,7 @@ func (statement *Statement) Cols(columns ...string) *Statement { newColumns := statement.colmap2NewColsWithQuote() statement.ColumnStr = strings.Join(newColumns, ", ") - statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.quote("*"), "*", -1) + statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.dialect.Quoter().Quote("*"), "*", -1) return statement } diff --git a/statement_exprparam.go b/statement_exprparam.go index fc62e36f..3231f86a 100644 --- a/statement_exprparam.go +++ b/statement_exprparam.go @@ -9,6 +9,7 @@ import ( "strings" "xorm.io/builder" + "xorm.io/xorm/schemas" ) type ErrUnsupportedExprType struct { @@ -40,7 +41,7 @@ func (exprs *exprParams) addParam(colName string, arg interface{}) { func (exprs *exprParams) isColExist(colName string) bool { for _, name := range exprs.colNames { - if strings.EqualFold(trimQuote(name), trimQuote(colName)) { + if strings.EqualFold(schemas.CommonQuoter.Trim(name), schemas.CommonQuoter.Trim(colName)) { return true } } diff --git a/statement_quote.go b/statement_quote.go deleted file mode 100644 index e22e0d14..00000000 --- a/statement_quote.go +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright 2019 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -func trimQuote(s string) string { - if len(s) == 0 { - return s - } - - if s[0] == '`' { - s = s[1:] - } - if len(s) > 0 && s[len(s)-1] == '`' { - return s[:len(s)-1] - } - return s -} -- 2.40.1 From dd8dc7dd2c08ff6adb4cc018301d4f7d6f2f1d60 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Mon, 24 Feb 2020 23:06:35 +0800 Subject: [PATCH 3/4] Fix test --- session_cond_test.go | 14 +++++++------- session_find_test.go | 8 ++++---- session_tx_test.go | 15 ++++++--------- 3 files changed, 17 insertions(+), 20 deletions(-) diff --git a/session_cond_test.go b/session_cond_test.go index 6c9ab960..30b9f778 100644 --- a/session_cond_test.go +++ b/session_cond_test.go @@ -137,13 +137,13 @@ func TestIn(t *testing.T) { idsStr = idsStr[:len(idsStr)-1] users := make([]Userinfo, 0) - err = testEngine.In("uid", ids[0], ids[1], ids[2]).Find(&users) + err = testEngine.In("id", ids[0], ids[1], ids[2]).Find(&users) assert.NoError(t, err) fmt.Println(users) assert.EqualValues(t, 3, len(users)) users = make([]Userinfo, 0) - err = testEngine.In("uid", ids).Find(&users) + err = testEngine.In("id", ids).Find(&users) assert.NoError(t, err) fmt.Println(users) assert.EqualValues(t, 3, len(users)) @@ -161,7 +161,7 @@ func TestIn(t *testing.T) { idsInterface = append(idsInterface, id) } - err = testEngine.Where(department+" = ?", "dev").In("uid", idsInterface...).Find(&users) + err = testEngine.Where(department+" = ?", "dev").In("id", idsInterface...).Find(&users) assert.NoError(t, err) fmt.Println(users) assert.EqualValues(t, 3, len(users)) @@ -175,11 +175,11 @@ func TestIn(t *testing.T) { dev := testEngine.GetColumnMapper().Obj2Table("Dev") - err = testEngine.In("uid", 1).In("uid", 2).In(department, dev).Find(&users) + err = testEngine.In("id", 1).In("id", 2).In(department, dev).Find(&users) assert.NoError(t, err) fmt.Println(users) - cnt, err = testEngine.In("uid", ids[0]).Update(&Userinfo{Departname: "dev-"}) + cnt, err = testEngine.In("id", ids[0]).Update(&Userinfo{Departname: "dev-"}) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) @@ -189,11 +189,11 @@ func TestIn(t *testing.T) { assert.True(t, has) assert.EqualValues(t, "dev-", user.Departname) - cnt, err = testEngine.In("uid", ids[0]).Update(&Userinfo{Departname: "dev"}) + cnt, err = testEngine.In("id", ids[0]).Update(&Userinfo{Departname: "dev"}) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - cnt, err = testEngine.In("uid", ids[1]).Delete(&Userinfo{}) + cnt, err = testEngine.In("id", ids[1]).Delete(&Userinfo{}) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) } diff --git a/session_find_test.go b/session_find_test.go index 94b6b153..8df3bc84 100644 --- a/session_find_test.go +++ b/session_find_test.go @@ -77,14 +77,14 @@ func TestWhere(t *testing.T) { assertSync(t, new(Userinfo)) users := make([]Userinfo, 0) - err := testEngine.Where("uid > ?", 2).Find(&users) + err := testEngine.Where("id > ?", 2).Find(&users) if err != nil { t.Error(err) panic(err) } fmt.Println(users) - err = testEngine.Where("uid > ?", 2).And("uid < ?", 10).Find(&users) + err = testEngine.Where("id > ?", 2).And("id < ?", 10).Find(&users) if err != nil { t.Error(err) panic(err) @@ -312,12 +312,12 @@ func TestOrderSameMapper(t *testing.T) { assertSync(t, new(Userinfo)) users := make([]Userinfo, 0) - err := testEngine.OrderBy("Uid desc").Find(&users) + err := testEngine.OrderBy("id desc").Find(&users) assert.NoError(t, err) fmt.Println(users) users2 := make([]Userinfo, 0) - err = testEngine.Asc("Uid", "Username").Desc("Height").Find(&users2) + err = testEngine.Asc("id", "Username").Desc("Height").Find(&users2) assert.NoError(t, err) fmt.Println(users2) } diff --git a/session_tx_test.go b/session_tx_test.go index ae83fe30..1e3dcabf 100644 --- a/session_tx_test.go +++ b/session_tx_test.go @@ -17,15 +17,12 @@ func TestTransaction(t *testing.T) { assert.NoError(t, prepareEngine()) assertSync(t, new(Userinfo)) - counter := func() { - total, err := testEngine.Count(&Userinfo{}) - if err != nil { - t.Error(err) - } - fmt.Printf("----now total %v records\n", total) + counter := func(t *testing.T) { + _, err := testEngine.Count(&Userinfo{}) + assert.NoError(t, err) } - counter() + counter(t) //defer counter() session := testEngine.NewSession() @@ -39,7 +36,7 @@ func TestTransaction(t *testing.T) { assert.NoError(t, err) user2 := Userinfo{Username: "yyy"} - _, err = session.Where("uid = ?", 0).Update(&user2) + _, err = session.Where("id = ?", 0).Update(&user2) assert.NoError(t, err) _, err = session.Delete(&user2) @@ -119,7 +116,7 @@ func TestCombineTransactionSameMapper(t *testing.T) { assert.NoError(t, err) user2 := Userinfo{Username: "zzz"} - _, err = session.Where("uid = ?", 0).Update(&user2) + _, err = session.Where("id = ?", 0).Update(&user2) assert.NoError(t, err) _, err = session.Exec("delete from "+testEngine.TableName("`Userinfo`", true)+" where `Username` = ?", user2.Username) -- 2.40.1 From 499098db4ec5de532003aea5e73f8feaec83a3e5 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Mon, 24 Feb 2020 23:57:03 +0800 Subject: [PATCH 4/4] Fix test --- dialects/filter.go | 19 ++++++++++++++- engine.go | 5 +--- schemas/quote.go | 57 ++++++++++++++++++++++++++++++++++++++++++- schemas/quote_test.go | 2 ++ 4 files changed, 77 insertions(+), 6 deletions(-) diff --git a/dialects/filter.go b/dialects/filter.go index 15044e1f..4795edb7 100644 --- a/dialects/filter.go +++ b/dialects/filter.go @@ -22,7 +22,24 @@ type QuoteFilter struct { func (s *QuoteFilter) Do(sql string, dialect Dialect, table *schemas.Table) string { quoter := dialect.Quoter() - return quoter.Quote(sql) + if quoter.IsEmpty() { + return sql + } + + prefix, suffix := quoter[0][0], quoter[1][0] + raw := []byte(sql) + for i, cnt := 0, 0; i < len(raw); i = i + 1 { + if raw[i] == '`' { + if cnt%2 == 0 { + raw[i] = prefix + } else { + raw[i] = suffix + } + cnt++ + } + } + return string(raw) + } // SeqFilter filter SQL replace ?, ? ... to $1, $2 ... diff --git a/engine.go b/engine.go index e505ac1f..b60234e0 100644 --- a/engine.go +++ b/engine.go @@ -193,10 +193,7 @@ func (engine *Engine) SupportInsertMany() bool { func (engine *Engine) quoteColumns(columnStr string) string { columns := strings.Split(columnStr, ",") - for i := 0; i < len(columns); i++ { - columns[i] = engine.Quote(strings.TrimSpace(columns[i])) - } - return strings.Join(columns, ",") + return engine.dialect.Quoter().Join(columns, ",") } // Quote Use QuoteStr quote the string sql diff --git a/schemas/quote.go b/schemas/quote.go index 5230cec8..e3571e34 100644 --- a/schemas/quote.go +++ b/schemas/quote.go @@ -5,6 +5,7 @@ package schemas import ( + "fmt" "strings" ) @@ -24,6 +25,37 @@ func (q Quoter) Quote(s string) string { return buf.String() } +func (q Quoter) Replace(sql string, newQuoter Quoter) string { + if q.IsEmpty() { + return sql + } + + if newQuoter.IsEmpty() { + var buf strings.Builder + for i := 0; i < len(sql); i = i + 1 { + if sql[i] != q[0][0] && sql[i] != q[1][0] { + _ = buf.WriteByte(sql[i]) + } + } + return buf.String() + } + + prefix, suffix := newQuoter[0][0], newQuoter[1][0] + var buf strings.Builder + for i, cnt := 0, 0; i < len(sql); i = i + 1 { + if cnt == 0 && sql[i] == q[0][0] { + _ = buf.WriteByte(prefix) + cnt = 1 + } else if cnt == 1 && sql[i] == q[1][0] { + _ = buf.WriteByte(suffix) + cnt = 0 + } else { + _ = buf.WriteByte(sql[i]) + } + } + return buf.String() +} + func (q Quoter) ReverseQuote(s string) string { reverseQuoter := Quoter{q[1], q[0]} return reverseQuoter.Quote(s) @@ -44,8 +76,31 @@ func (q Quoter) Trim(s string) string { return s } +func TrimSpaceJoin(a []string, sep string) string { + switch len(a) { + case 0: + return "" + case 1: + return a[0] + } + n := len(sep) * (len(a) - 1) + for i := 0; i < len(a); i++ { + n += len(a[i]) + } + + var b strings.Builder + b.Grow(n) + b.WriteString(strings.TrimSpace(a[0])) + for _, s := range a[1:] { + b.WriteString(sep) + b.WriteString(strings.TrimSpace(s)) + } + return b.String() +} + func (q Quoter) Join(s []string, splitter string) string { - return q.Quote(strings.Join(s, q.ReverseQuote(splitter))) + //return fmt.Sprintf("%s%s%s", q[0], TrimSpaceJoin(s, fmt.Sprintf("%s%s%s", q[1], splitter, q[0])), q[1]) + return q.Quote(TrimSpaceJoin(s, fmt.Sprintf("%s%s%s", q[1], splitter, q[0]))) } func (q Quoter) QuoteTo(buf *strings.Builder, value string) { diff --git a/schemas/quote_test.go b/schemas/quote_test.go index f89c6258..af773c8b 100644 --- a/schemas/quote_test.go +++ b/schemas/quote_test.go @@ -38,6 +38,8 @@ func TestQuoteTo(t *testing.T) { test(t, "[message_user] AS [sender]", "`message_user` AS `sender`") + assert.EqualValues(t, "[a],[b]", quoter.Join([]string{"a", " b"}, ",")) + buf := &strings.Builder{} quoter = Quoter{"", ""} quoter.QuoteTo(buf, "noquote") -- 2.40.1