From f735b38bd477bdad15c17488326fd422d956bcd5 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Thu, 27 Feb 2020 08:33:29 +0800 Subject: [PATCH] Improve code --- helpers.go | 14 -------------- helpers_test.go | 18 ------------------ schemas/quote.go | 2 +- schemas/quote_test.go | 6 ++++++ session_find.go | 2 +- session_query.go | 2 +- session_update.go | 12 +++--------- statement.go | 12 ++++++------ 8 files changed, 18 insertions(+), 50 deletions(-) delete mode 100644 helpers_test.go diff --git a/helpers.go b/helpers.go index 75393ae3..32505c2c 100644 --- a/helpers.go +++ b/helpers.go @@ -200,17 +200,3 @@ func sliceEq(left, right []string) bool { func indexName(tableName, idxName string) string { return fmt.Sprintf("IDX_%v_%v", tableName, idxName) } - -func eraseAny(value string, strToErase ...string) string { - if len(strToErase) == 0 { - return value - } - var replaceSeq []string - for _, s := range strToErase { - replaceSeq = append(replaceSeq, s, "") - } - - replacer := strings.NewReplacer(replaceSeq...) - - return replacer.Replace(value) -} diff --git a/helpers_test.go b/helpers_test.go deleted file mode 100644 index fc9ece27..00000000 --- a/helpers_test.go +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright 2017 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestEraseAny(t *testing.T) { - raw := "SELECT * FROM `table`.[table_name]" - assert.EqualValues(t, raw, eraseAny(raw)) - assert.EqualValues(t, "SELECT * FROM table.[table_name]", eraseAny(raw, "`")) - assert.EqualValues(t, "SELECT * FROM table.table_name", eraseAny(raw, "`", "[", "]")) -} diff --git a/schemas/quote.go b/schemas/quote.go index 21327eb0..d10a5dc8 100644 --- a/schemas/quote.go +++ b/schemas/quote.go @@ -69,7 +69,7 @@ func (q Quoter) Trim(s string) string { if s[0:1] == q[0] { s = s[1:] } - if len(s) > 0 && s[len(s)-1:] == q[0] { + if len(s) > 0 && s[len(s)-1:] == q[1] { return s[:len(s)-1] } return s diff --git a/schemas/quote_test.go b/schemas/quote_test.go index 0c87d3a8..174d1a0d 100644 --- a/schemas/quote_test.go +++ b/schemas/quote_test.go @@ -63,3 +63,9 @@ func TestStrings(t *testing.T) { quotedCols := quoter.Strings(cols) assert.EqualValues(t, []string{"[f1]", "[f2]", "[t3].[f3]"}, quotedCols) } + +func TestTrim(t *testing.T) { + raw := "[table_name]" + assert.EqualValues(t, raw, CommonQuoter.Trim(raw)) + assert.EqualValues(t, "table_name", Quoter{"[", "]"}.Trim(raw)) +} diff --git a/session_find.go b/session_find.go index 566e83dd..251691b1 100644 --- a/session_find.go +++ b/session_find.go @@ -135,7 +135,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) return ErrTableNotFound } - var columnStr = session.statement.ColumnStr + var columnStr = session.statement.columnStr() if len(session.statement.selectStr) > 0 { columnStr = session.statement.selectStr } else { diff --git a/session_query.go b/session_query.go index afed4bcb..1783e154 100644 --- a/session_query.go +++ b/session_query.go @@ -29,7 +29,7 @@ func (session *Session) genQuerySQL(sqlOrArgs ...interface{}) (string, []interfa return "", nil, ErrTableNotFound } - var columnStr = session.statement.ColumnStr + var columnStr = session.statement.columnStr() if len(session.statement.selectStr) > 0 { columnStr = session.statement.selectStr } else { diff --git a/session_update.go b/session_update.go index 74b180d5..c95b23c5 100644 --- a/session_update.go +++ b/session_update.go @@ -103,14 +103,8 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri sps := strings.SplitN(kv, "=", 2) sps2 := strings.Split(sps[0], ".") colName := sps2[len(sps2)-1] - // treat quote prefix, suffix and '`' as quotes - quotes := append(strings.Split(session.engine.Quote(""), ""), "`") - if strings.ContainsAny(colName, strings.Join(quotes, "")) { - colName = strings.TrimSpace(eraseAny(colName, quotes...)) - } else { - session.engine.logger.Debug("[cacheUpdate] cannot find column", tableName, colName) - return ErrCacheFailed - } + colName = session.engine.dialect.Quoter().Trim(colName) + colName = schemas.CommonQuoter.Trim(colName) if col := table.GetColumn(colName); col != nil { fieldValue, err := col.ValueOf(bean) @@ -182,7 +176,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 return 0, ErrTableNotFound } - if session.statement.ColumnStr == "" { + if session.statement.columnStr() == "" { colNames, args = session.statement.buildUpdates(bean, false, false, false, false, true) } else { diff --git a/statement.go b/statement.go index fd6b3962..651ce175 100644 --- a/statement.go +++ b/statement.go @@ -30,7 +30,6 @@ type Statement struct { joinArgs []interface{} GroupByStr string HavingStr string - ColumnStr string selectStr string useAllCols bool AltTableName string @@ -86,7 +85,6 @@ func (statement *Statement) Reset() { statement.joinArgs = make([]interface{}, 0) statement.GroupByStr = "" statement.HavingStr = "" - statement.ColumnStr = "" statement.columnMap = columnMap{} statement.omitColumnMap = columnMap{} statement.AltTableName = "" @@ -612,11 +610,13 @@ func (statement *Statement) Cols(columns ...string) *Statement { for _, nc := range cols { statement.columnMap.add(nc) } - - statement.ColumnStr = statement.dialect.Quoter().Join(statement.columnMap, ", ") return statement } +func (statement *Statement) columnStr() string { + return statement.Engine.dialect.Quoter().Join(statement.columnMap, ", ") +} + // AllCols update use only: update all columns func (statement *Statement) AllCols() *Statement { statement.useAllCols = true @@ -955,7 +955,7 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, statement.setRefBean(bean) } - var columnStr = statement.ColumnStr + var columnStr = statement.columnStr() if len(statement.selectStr) > 0 { columnStr = statement.selectStr } else { @@ -1020,7 +1020,7 @@ func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interfa var selectSQL = statement.selectStr if len(selectSQL) <= 0 { if statement.IsDistinct { - selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.ColumnStr) + selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.columnStr()) } else { selectSQL = "count(*)" } -- 2.40.1