Browse Source

remove QuoteStr() usage (#1360)

tags/v0.7.5
helong zhang Lunny Xiao <xiaolunwen@gmail.com> 4 months ago
parent
commit
18b32486cf
8 changed files with 105 additions and 63 deletions
  1. +24
    -16
      engine.go
  2. +22
    -1
      helpers.go
  3. +21
    -1
      helpers_test.go
  4. +8
    -16
      session_insert.go
  5. +9
    -8
      session_update.go
  6. +1
    -1
      session_update_test.go
  7. +10
    -19
      statement.go
  8. +10
    -1
      statement_test.go

+ 24
- 16
engine.go View File

@@ -177,6 +177,7 @@ func (engine *Engine) SupportInsertMany() bool {

// QuoteStr Engine's database use which character as quote.
// mysql, sqlite use ` and postgres use "
// Deprecated, use Quote() instead
func (engine *Engine) QuoteStr() string {
return engine.dialect.QuoteStr()
}
@@ -196,13 +197,10 @@ func (engine *Engine) Quote(value string) string {
return value
}

if string(value[0]) == engine.dialect.QuoteStr() || value[0] == '`' {
return value
}

value = strings.Replace(value, ".", engine.dialect.QuoteStr()+"."+engine.dialect.QuoteStr(), -1)
buf := builder.StringBuilder{}
engine.QuoteTo(&buf, value)

return engine.dialect.QuoteStr() + value + engine.dialect.QuoteStr()
return buf.String()
}

// QuoteTo quotes string and writes into the buffer
@@ -216,20 +214,30 @@ func (engine *Engine) QuoteTo(buf *builder.StringBuilder, value string) {
return
}

if string(value[0]) == engine.dialect.QuoteStr() || value[0] == '`' {
buf.WriteString(value)
quotePair := engine.dialect.Quote("")

if value[0] == '`' || len(quotePair) < 2 || value[0] == quotePair[0] { // no quote
_, _ = buf.WriteString(value)
return
} else {
prefix, suffix := quotePair[0], quotePair[1]

_ = buf.WriteByte(prefix)
for i := 0; i < len(value); i++ {
if value[i] == '.' {
_ = buf.WriteByte(suffix)
_ = buf.WriteByte('.')
_ = buf.WriteByte(prefix)
} else {
_ = buf.WriteByte(value[i])
}
}
_ = buf.WriteByte(suffix)
}

value = strings.Replace(value, ".", engine.dialect.QuoteStr()+"."+engine.dialect.QuoteStr(), -1)

buf.WriteString(engine.dialect.QuoteStr())
buf.WriteString(value)
buf.WriteString(engine.dialect.QuoteStr())
}

func (engine *Engine) quote(sql string) string {
return engine.dialect.QuoteStr() + sql + engine.dialect.QuoteStr()
return engine.dialect.Quote(sql)
}

// SqlType will be deprecated, please use SQLType instead
@@ -1581,7 +1589,7 @@ func (engine *Engine) formatColTime(col *core.Column, t time.Time) (v interface{
func (engine *Engine) formatTime(sqlTypeName string, t time.Time) (v interface{}) {
switch sqlTypeName {
case core.Time:
s := t.Format("2006-01-02 15:04:05") //time.RFC3339
s := t.Format("2006-01-02 15:04:05") // time.RFC3339
v = s[11:19]
case core.Date:
v = t.Format("2006-01-02")


+ 22
- 1
helpers.go View File

@@ -281,7 +281,7 @@ func rValue(bean interface{}) reflect.Value {

func rType(bean interface{}) reflect.Type {
sliceValue := reflect.Indirect(reflect.ValueOf(bean))
//return reflect.TypeOf(sliceValue.Interface())
// return reflect.TypeOf(sliceValue.Interface())
return sliceValue.Type()
}

@@ -309,3 +309,24 @@ func sliceEq(left, right []string) bool {
func indexName(tableName, idxName string) string {
return fmt.Sprintf("IDX_%v_%v", tableName, idxName)
}

func eraseAny(value string, strToErase ...string) string {
if len(strToErase) == 0 {
return value
}
var replaceSeq []string
for _, s := range strToErase {
replaceSeq = append(replaceSeq, s, "")
}

replacer := strings.NewReplacer(replaceSeq...)

return replacer.Replace(value)
}

func quoteColumns(cols []string, quoteFunc func(string) string, sep string) string {
for i := range cols {
cols[i] = quoteFunc(cols[i])
}
return strings.Join(cols, sep+" ")
}

+ 21
- 1
helpers_test.go View File

@@ -4,7 +4,11 @@

package xorm

import "testing"
import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestSplitTag(t *testing.T) {
var cases = []struct {
@@ -24,3 +28,19 @@ func TestSplitTag(t *testing.T) {
}
}
}

func TestEraseAny(t *testing.T) {
raw := "SELECT * FROM `table`.[table_name]"
assert.EqualValues(t, raw, eraseAny(raw))
assert.EqualValues(t, "SELECT * FROM table.[table_name]", eraseAny(raw, "`"))
assert.EqualValues(t, "SELECT * FROM table.table_name", eraseAny(raw, "`", "[", "]"))
}

func TestQuoteColumns(t *testing.T) {
cols := []string{"f1", "f2", "f3"}
quoteFunc := func(value string) string {
return "[" + value + "]"
}

assert.EqualValues(t, "[f1], [f2], [f3]", quoteColumns(cols, quoteFunc, ","))
}

+ 8
- 16
session_insert.go View File

@@ -242,23 +242,17 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error

var sql string
if session.engine.dialect.DBType() == core.ORACLE {
temp := fmt.Sprintf(") INTO %s (%v%v%v) VALUES (",
temp := fmt.Sprintf(") INTO %s (%v) VALUES (",
session.engine.Quote(tableName),
session.engine.QuoteStr(),
strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()),
session.engine.QuoteStr())
sql = fmt.Sprintf("INSERT ALL INTO %s (%v%v%v) VALUES (%v) SELECT 1 FROM DUAL",
quoteColumns(colNames, session.engine.Quote, ","))
sql = fmt.Sprintf("INSERT ALL INTO %s (%v) VALUES (%v) SELECT 1 FROM DUAL",
session.engine.Quote(tableName),
session.engine.QuoteStr(),
strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()),
session.engine.QuoteStr(),
quoteColumns(colNames, session.engine.Quote, ","),
strings.Join(colMultiPlaces, temp))
} else {
sql = fmt.Sprintf("INSERT INTO %s (%v%v%v) VALUES (%v)",
sql = fmt.Sprintf("INSERT INTO %s (%v) VALUES (%v)",
session.engine.Quote(tableName),
session.engine.QuoteStr(),
strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()),
session.engine.QuoteStr(),
quoteColumns(colNames, session.engine.Quote, ","),
strings.Join(colMultiPlaces, "),("))
}
res, err := session.exec(sql, args...)
@@ -378,11 +372,9 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
output = fmt.Sprintf(" OUTPUT Inserted.%s", table.AutoIncrement)
}
if len(colPlaces) > 0 {
sqlStr = fmt.Sprintf("INSERT INTO %s (%v%v%v)%s VALUES (%v)",
sqlStr = fmt.Sprintf("INSERT INTO %s (%v)%s VALUES (%v)",
session.engine.Quote(tableName),
session.engine.QuoteStr(),
strings.Join(colNames, session.engine.Quote(", ")),
session.engine.QuoteStr(),
quoteColumns(colNames, session.engine.Quote, ","),
output,
colPlaces)
} else {


+ 9
- 8
session_update.go View File

@@ -96,14 +96,15 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string,
return ErrCacheFailed
}
kvs := strings.Split(strings.TrimSpace(sqls[1]), ",")

for idx, kv := range kvs {
sps := strings.SplitN(kv, "=", 2)
sps2 := strings.Split(sps[0], ".")
colName := sps2[len(sps2)-1]
if strings.Contains(colName, "`") {
colName = strings.TrimSpace(strings.Replace(colName, "`", "", -1))
} else if strings.Contains(colName, session.engine.QuoteStr()) {
colName = strings.TrimSpace(strings.Replace(colName, session.engine.QuoteStr(), "", -1))
// treat quote prefix, suffix and '`' as quotes
quotes := append(strings.Split(session.engine.Quote(""), ""), "`")
if strings.ContainsAny(colName, strings.Join(quotes, "")) {
colName = strings.TrimSpace(eraseAny(colName, quotes...))
} else {
session.engine.logger.Debug("[cacheUpdate] cannot find column", tableName, colName)
return ErrCacheFailed
@@ -221,19 +222,19 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
}
}

//for update action to like "column = column + ?"
// for update action to like "column = column + ?"
incColumns := session.statement.getInc()
for _, v := range incColumns {
colNames = append(colNames, session.engine.Quote(v.colName)+" = "+session.engine.Quote(v.colName)+" + ?")
args = append(args, v.arg)
}
//for update action to like "column = column - ?"
// for update action to like "column = column - ?"
decColumns := session.statement.getDec()
for _, v := range decColumns {
colNames = append(colNames, session.engine.Quote(v.colName)+" = "+session.engine.Quote(v.colName)+" - ?")
args = append(args, v.arg)
}
//for update action to like "column = expression"
// for update action to like "column = expression"
exprColumns := session.statement.getExpr()
for _, v := range exprColumns {
colNames = append(colNames, session.engine.Quote(v.colName)+" = "+v.expr)
@@ -382,7 +383,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
}

if cacher := session.engine.getCacher(tableName); cacher != nil && session.statement.UseCache {
//session.cacheUpdate(table, tableName, sqlStr, args...)
// session.cacheUpdate(table, tableName, sqlStr, args...)
session.engine.logger.Debug("[cacheUpdate] clear table ", tableName)
cacher.ClearIds(tableName)
cacher.ClearBeans(tableName)


+ 1
- 1
session_update_test.go View File

@@ -11,8 +11,8 @@ import (
"testing"
"time"

"xorm.io/core"
"github.com/stretchr/testify/assert"
"xorm.io/core"
)

func TestUpdateMap(t *testing.T) {


+ 10
- 19
statement.go View File

@@ -6,7 +6,6 @@ package xorm

import (
"database/sql/driver"
"errors"
"fmt"
"reflect"
"strings"
@@ -398,7 +397,7 @@ func (statement *Statement) buildUpdates(bean interface{},
continue
}
} else {
//TODO: how to handler?
// TODO: how to handler?
panic("not supported")
}
} else {
@@ -579,21 +578,9 @@ func (statement *Statement) getExpr() map[string]exprParam {

func (statement *Statement) col2NewColsWithQuote(columns ...string) []string {
newColumns := make([]string, 0)
quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`")
for _, col := range columns {
col = strings.Replace(col, "`", "", -1)
col = strings.Replace(col, statement.Engine.QuoteStr(), "", -1)
ccols := strings.Split(col, ",")
for _, c := range ccols {
fields := strings.Split(strings.TrimSpace(c), ".")
if len(fields) == 1 {
newColumns = append(newColumns, statement.Engine.quote(fields[0]))
} else if len(fields) == 2 {
newColumns = append(newColumns, statement.Engine.quote(fields[0])+"."+
statement.Engine.quote(fields[1]))
} else {
panic(errors.New("unwanted colnames"))
}
}
newColumns = append(newColumns, statement.Engine.Quote(eraseAny(col, quotes...)))
}
return newColumns
}
@@ -764,7 +751,9 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
return statement
}
tbs := strings.Split(tp.TableName(), ".")
var aliasName = strings.Trim(tbs[len(tbs)-1], statement.Engine.QuoteStr())
quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`")

var aliasName = strings.Trim(tbs[len(tbs)-1], strings.Join(quotes, ""))
fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
case *builder.Builder:
@@ -774,7 +763,9 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
return statement
}
tbs := strings.Split(tp.TableName(), ".")
var aliasName = strings.Trim(tbs[len(tbs)-1], statement.Engine.QuoteStr())
quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`")

var aliasName = strings.Trim(tbs[len(tbs)-1], strings.Join(quotes, ""))
fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
default:
@@ -1246,7 +1237,7 @@ func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) {

var whereStr = sqls[1]

//TODO: for postgres only, if any other database?
// TODO: for postgres only, if any other database?
var paraStr string
if statement.Engine.dialect.DBType() == core.POSTGRES {
paraStr = "$"


+ 10
- 1
statement_test.go View File

@@ -9,8 +9,8 @@ import (
"strings"
"testing"

"xorm.io/core"
"github.com/stretchr/testify/assert"
"xorm.io/core"
)

var colStrTests = []struct {
@@ -237,3 +237,12 @@ func TestUpdateIgnoreOnlyFromDBFields(t *testing.T) {
testEngine.Update(record)
assertGetRecord()
}

func TestCol2NewColsWithQuote(t *testing.T) {
cols := []string{"f1", "f2", "t3.f3"}

statement := createTestStatement()

quotedCols := statement.col2NewColsWithQuote(cols...)
assert.EqualValues(t, []string{statement.Engine.Quote("f1"), statement.Engine.Quote("f2"), statement.Engine.Quote("t3.f3")}, quotedCols)
}

Loading…
Cancel
Save