Improve some codes #1551

Merged
lunny merged 1 commits from lunny/improve into master 2020-02-27 01:30:10 +00:00
8 changed files with 18 additions and 50 deletions

View File

@ -200,17 +200,3 @@ func sliceEq(left, right []string) bool {
func indexName(tableName, idxName string) string {
return fmt.Sprintf("IDX_%v_%v", tableName, idxName)
}
func eraseAny(value string, strToErase ...string) string {
if len(strToErase) == 0 {
return value
}
var replaceSeq []string
for _, s := range strToErase {
replaceSeq = append(replaceSeq, s, "")
}
replacer := strings.NewReplacer(replaceSeq...)
return replacer.Replace(value)
}

View File

@ -1,18 +0,0 @@
// Copyright 2017 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestEraseAny(t *testing.T) {
raw := "SELECT * FROM `table`.[table_name]"
assert.EqualValues(t, raw, eraseAny(raw))
assert.EqualValues(t, "SELECT * FROM table.[table_name]", eraseAny(raw, "`"))
assert.EqualValues(t, "SELECT * FROM table.table_name", eraseAny(raw, "`", "[", "]"))
}

View File

@ -69,7 +69,7 @@ func (q Quoter) Trim(s string) string {
if s[0:1] == q[0] {
s = s[1:]
}
if len(s) > 0 && s[len(s)-1:] == q[0] {
if len(s) > 0 && s[len(s)-1:] == q[1] {
return s[:len(s)-1]
}
return s

View File

@ -63,3 +63,9 @@ func TestStrings(t *testing.T) {
quotedCols := quoter.Strings(cols)
assert.EqualValues(t, []string{"[f1]", "[f2]", "[t3].[f3]"}, quotedCols)
}
func TestTrim(t *testing.T) {
raw := "[table_name]"
assert.EqualValues(t, raw, CommonQuoter.Trim(raw))
assert.EqualValues(t, "table_name", Quoter{"[", "]"}.Trim(raw))
}

View File

@ -135,7 +135,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
return ErrTableNotFound
}
var columnStr = session.statement.ColumnStr
var columnStr = session.statement.columnStr()
if len(session.statement.selectStr) > 0 {
columnStr = session.statement.selectStr
} else {

View File

@ -29,7 +29,7 @@ func (session *Session) genQuerySQL(sqlOrArgs ...interface{}) (string, []interfa
return "", nil, ErrTableNotFound
}
var columnStr = session.statement.ColumnStr
var columnStr = session.statement.columnStr()
if len(session.statement.selectStr) > 0 {
columnStr = session.statement.selectStr
} else {

View File

@ -103,14 +103,8 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri
sps := strings.SplitN(kv, "=", 2)
sps2 := strings.Split(sps[0], ".")
colName := sps2[len(sps2)-1]
// treat quote prefix, suffix and '`' as quotes
quotes := append(strings.Split(session.engine.Quote(""), ""), "`")
if strings.ContainsAny(colName, strings.Join(quotes, "")) {
colName = strings.TrimSpace(eraseAny(colName, quotes...))
} else {
session.engine.logger.Debug("[cacheUpdate] cannot find column", tableName, colName)
return ErrCacheFailed
}
colName = session.engine.dialect.Quoter().Trim(colName)
colName = schemas.CommonQuoter.Trim(colName)
if col := table.GetColumn(colName); col != nil {
fieldValue, err := col.ValueOf(bean)
@ -182,7 +176,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
return 0, ErrTableNotFound
}
if session.statement.ColumnStr == "" {
if session.statement.columnStr() == "" {
colNames, args = session.statement.buildUpdates(bean, false, false,
false, false, true)
} else {

View File

@ -30,7 +30,6 @@ type Statement struct {
joinArgs []interface{}
GroupByStr string
HavingStr string
ColumnStr string
selectStr string
useAllCols bool
AltTableName string
@ -86,7 +85,6 @@ func (statement *Statement) Reset() {
statement.joinArgs = make([]interface{}, 0)
statement.GroupByStr = ""
statement.HavingStr = ""
statement.ColumnStr = ""
statement.columnMap = columnMap{}
statement.omitColumnMap = columnMap{}
statement.AltTableName = ""
@ -612,11 +610,13 @@ func (statement *Statement) Cols(columns ...string) *Statement {
for _, nc := range cols {
statement.columnMap.add(nc)
}
statement.ColumnStr = statement.dialect.Quoter().Join(statement.columnMap, ", ")
return statement
}
func (statement *Statement) columnStr() string {
return statement.Engine.dialect.Quoter().Join(statement.columnMap, ", ")
}
// AllCols update use only: update all columns
func (statement *Statement) AllCols() *Statement {
statement.useAllCols = true
@ -955,7 +955,7 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{},
statement.setRefBean(bean)
}
var columnStr = statement.ColumnStr
var columnStr = statement.columnStr()
if len(statement.selectStr) > 0 {
columnStr = statement.selectStr
} else {
@ -1020,7 +1020,7 @@ func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interfa
var selectSQL = statement.selectStr
if len(selectSQL) <= 0 {
if statement.IsDistinct {
selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.ColumnStr)
selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.columnStr())
} else {
selectSQL = "count(*)"
}