Improve some codes #1551
14
helpers.go
14
helpers.go
|
@ -200,17 +200,3 @@ func sliceEq(left, right []string) bool {
|
|||
func indexName(tableName, idxName string) string {
|
||||
return fmt.Sprintf("IDX_%v_%v", tableName, idxName)
|
||||
}
|
||||
|
||||
func eraseAny(value string, strToErase ...string) string {
|
||||
if len(strToErase) == 0 {
|
||||
return value
|
||||
}
|
||||
var replaceSeq []string
|
||||
for _, s := range strToErase {
|
||||
replaceSeq = append(replaceSeq, s, "")
|
||||
}
|
||||
|
||||
replacer := strings.NewReplacer(replaceSeq...)
|
||||
|
||||
return replacer.Replace(value)
|
||||
}
|
||||
|
|
|
@ -1,18 +0,0 @@
|
|||
// Copyright 2017 The Xorm Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package xorm
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestEraseAny(t *testing.T) {
|
||||
raw := "SELECT * FROM `table`.[table_name]"
|
||||
assert.EqualValues(t, raw, eraseAny(raw))
|
||||
assert.EqualValues(t, "SELECT * FROM table.[table_name]", eraseAny(raw, "`"))
|
||||
assert.EqualValues(t, "SELECT * FROM table.table_name", eraseAny(raw, "`", "[", "]"))
|
||||
}
|
|
@ -69,7 +69,7 @@ func (q Quoter) Trim(s string) string {
|
|||
if s[0:1] == q[0] {
|
||||
s = s[1:]
|
||||
}
|
||||
if len(s) > 0 && s[len(s)-1:] == q[0] {
|
||||
if len(s) > 0 && s[len(s)-1:] == q[1] {
|
||||
return s[:len(s)-1]
|
||||
}
|
||||
return s
|
||||
|
|
|
@ -63,3 +63,9 @@ func TestStrings(t *testing.T) {
|
|||
quotedCols := quoter.Strings(cols)
|
||||
assert.EqualValues(t, []string{"[f1]", "[f2]", "[t3].[f3]"}, quotedCols)
|
||||
}
|
||||
|
||||
func TestTrim(t *testing.T) {
|
||||
raw := "[table_name]"
|
||||
assert.EqualValues(t, raw, CommonQuoter.Trim(raw))
|
||||
assert.EqualValues(t, "table_name", Quoter{"[", "]"}.Trim(raw))
|
||||
}
|
||||
|
|
|
@ -135,7 +135,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
|
|||
return ErrTableNotFound
|
||||
}
|
||||
|
||||
var columnStr = session.statement.ColumnStr
|
||||
var columnStr = session.statement.columnStr()
|
||||
if len(session.statement.selectStr) > 0 {
|
||||
columnStr = session.statement.selectStr
|
||||
} else {
|
||||
|
|
|
@ -29,7 +29,7 @@ func (session *Session) genQuerySQL(sqlOrArgs ...interface{}) (string, []interfa
|
|||
return "", nil, ErrTableNotFound
|
||||
}
|
||||
|
||||
var columnStr = session.statement.ColumnStr
|
||||
var columnStr = session.statement.columnStr()
|
||||
if len(session.statement.selectStr) > 0 {
|
||||
columnStr = session.statement.selectStr
|
||||
} else {
|
||||
|
|
|
@ -103,14 +103,8 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri
|
|||
sps := strings.SplitN(kv, "=", 2)
|
||||
sps2 := strings.Split(sps[0], ".")
|
||||
colName := sps2[len(sps2)-1]
|
||||
// treat quote prefix, suffix and '`' as quotes
|
||||
quotes := append(strings.Split(session.engine.Quote(""), ""), "`")
|
||||
if strings.ContainsAny(colName, strings.Join(quotes, "")) {
|
||||
colName = strings.TrimSpace(eraseAny(colName, quotes...))
|
||||
} else {
|
||||
session.engine.logger.Debug("[cacheUpdate] cannot find column", tableName, colName)
|
||||
return ErrCacheFailed
|
||||
}
|
||||
colName = session.engine.dialect.Quoter().Trim(colName)
|
||||
colName = schemas.CommonQuoter.Trim(colName)
|
||||
|
||||
if col := table.GetColumn(colName); col != nil {
|
||||
fieldValue, err := col.ValueOf(bean)
|
||||
|
@ -182,7 +176,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
|||
return 0, ErrTableNotFound
|
||||
}
|
||||
|
||||
if session.statement.ColumnStr == "" {
|
||||
if session.statement.columnStr() == "" {
|
||||
colNames, args = session.statement.buildUpdates(bean, false, false,
|
||||
false, false, true)
|
||||
} else {
|
||||
|
|
12
statement.go
12
statement.go
|
@ -30,7 +30,6 @@ type Statement struct {
|
|||
joinArgs []interface{}
|
||||
GroupByStr string
|
||||
HavingStr string
|
||||
ColumnStr string
|
||||
selectStr string
|
||||
useAllCols bool
|
||||
AltTableName string
|
||||
|
@ -86,7 +85,6 @@ func (statement *Statement) Reset() {
|
|||
statement.joinArgs = make([]interface{}, 0)
|
||||
statement.GroupByStr = ""
|
||||
statement.HavingStr = ""
|
||||
statement.ColumnStr = ""
|
||||
statement.columnMap = columnMap{}
|
||||
statement.omitColumnMap = columnMap{}
|
||||
statement.AltTableName = ""
|
||||
|
@ -612,11 +610,13 @@ func (statement *Statement) Cols(columns ...string) *Statement {
|
|||
for _, nc := range cols {
|
||||
statement.columnMap.add(nc)
|
||||
}
|
||||
|
||||
statement.ColumnStr = statement.dialect.Quoter().Join(statement.columnMap, ", ")
|
||||
return statement
|
||||
}
|
||||
|
||||
func (statement *Statement) columnStr() string {
|
||||
return statement.Engine.dialect.Quoter().Join(statement.columnMap, ", ")
|
||||
}
|
||||
|
||||
// AllCols update use only: update all columns
|
||||
func (statement *Statement) AllCols() *Statement {
|
||||
statement.useAllCols = true
|
||||
|
@ -955,7 +955,7 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{},
|
|||
statement.setRefBean(bean)
|
||||
}
|
||||
|
||||
var columnStr = statement.ColumnStr
|
||||
var columnStr = statement.columnStr()
|
||||
if len(statement.selectStr) > 0 {
|
||||
columnStr = statement.selectStr
|
||||
} else {
|
||||
|
@ -1020,7 +1020,7 @@ func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interfa
|
|||
var selectSQL = statement.selectStr
|
||||
if len(selectSQL) <= 0 {
|
||||
if statement.IsDistinct {
|
||||
selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.ColumnStr)
|
||||
selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.columnStr())
|
||||
} else {
|
||||
selectSQL = "count(*)"
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user