From 489d996183048905233096ff37b2628d6aa89cad Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Wed, 26 Feb 2020 23:00:12 +0800 Subject: [PATCH 1/2] Improve statement --- schemas/quote.go | 12 +++++-- schemas/quote_test.go | 8 +++++ session.go | 5 +-- session_exist.go | 2 +- session_update.go | 10 +++--- statement.go | 79 +++++++++++++++++++++++-------------------- statement_test.go | 15 +++----- 7 files changed, 74 insertions(+), 57 deletions(-) diff --git a/schemas/quote.go b/schemas/quote.go index 5dac6d27..21327eb0 100644 --- a/schemas/quote.go +++ b/schemas/quote.go @@ -93,17 +93,25 @@ func (q Quoter) Join(a []string, sep string) string { if i > 0 { b.WriteString(sep) } - if q[0] != "" { + if q[0] != "" && s != "*" { b.WriteString(q[0]) } b.WriteString(strings.TrimSpace(s)) - if q[1] != "" { + if q[1] != "" && s != "*" { b.WriteString(q[1]) } } return b.String() } +func (q Quoter) Strings(s []string) []string { + var res = make([]string, 0, len(s)) + for _, a := range s { + res = append(res, q.Quote(a)) + } + return res +} + func (q Quoter) QuoteTo(buf *strings.Builder, value string) { if q.IsEmpty() { buf.WriteString(value) diff --git a/schemas/quote_test.go b/schemas/quote_test.go index 5eea05d3..0c87d3a8 100644 --- a/schemas/quote_test.go +++ b/schemas/quote_test.go @@ -55,3 +55,11 @@ func TestJoin(t *testing.T) { quoter = Quoter{"", ""} assert.EqualValues(t, "f1, f2, f3", quoter.Join(cols, ", ")) } + +func TestStrings(t *testing.T) { + cols := []string{"f1", "f2", "t3.f3"} + quoter := Quoter{"[", "]"} + + quotedCols := quoter.Strings(cols) + assert.EqualValues(t, []string{"[f1]", "[f2]", "[t3].[f3]"}, quotedCols) +} diff --git a/session.go b/session.go index 3e31d3d7..703aa873 100644 --- a/session.go +++ b/session.go @@ -72,8 +72,9 @@ func (session *Session) Clone() *Session { // Init reset the session as the init status. func (session *Session) Init() { - session.statement.Init() + session.statement.dialect = session.engine.dialect session.statement.Engine = session.engine + session.statement.Reset() session.showSQL = session.engine.showSQL session.isAutoCommit = true session.isCommitedOrRollbacked = false @@ -128,7 +129,7 @@ func (session *Session) IsClosed() bool { func (session *Session) resetStatement() { if session.autoResetStatement { - session.statement.Init() + session.statement.Reset() } } diff --git a/session_exist.go b/session_exist.go index 153bb219..d5b0c1d8 100644 --- a/session_exist.go +++ b/session_exist.go @@ -34,7 +34,7 @@ func (session *Session) Exist(bean ...interface{}) (bool, error) { return false, ErrTableNotFound } - tableName = session.statement.Engine.Quote(tableName) + tableName = session.statement.quote(tableName) if len(session.statement.JoinStr) > 0 { joinStr = session.statement.JoinStr } diff --git a/session_update.go b/session_update.go index 427d452d..74b180d5 100644 --- a/session_update.go +++ b/session_update.go @@ -341,9 +341,9 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 var top string if st.LimitN != nil { limitValue := *st.LimitN - if st.Engine.dialect.DBType() == schemas.MYSQL { + if st.dialect.DBType() == schemas.MYSQL { condSQL = condSQL + fmt.Sprintf(" LIMIT %d", limitValue) - } else if st.Engine.dialect.DBType() == schemas.SQLITE { + } else if st.dialect.DBType() == schemas.SQLITE { tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue) cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)", session.engine.Quote(tableName), tempCondSQL), condArgs...)) @@ -354,7 +354,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 if len(condSQL) > 0 { condSQL = "WHERE " + condSQL } - } else if st.Engine.dialect.DBType() == schemas.POSTGRES { + } else if st.dialect.DBType() == schemas.POSTGRES { tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue) cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)", session.engine.Quote(tableName), tempCondSQL), condArgs...)) @@ -366,8 +366,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 if len(condSQL) > 0 { condSQL = "WHERE " + condSQL } - } else if st.Engine.dialect.DBType() == schemas.MSSQL { - if st.OrderStr != "" && st.Engine.dialect.DBType() == schemas.MSSQL && + } else if st.dialect.DBType() == schemas.MSSQL { + if st.OrderStr != "" && st.dialect.DBType() == schemas.MSSQL && table != nil && len(table.PrimaryKeys) == 1 { cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)", table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0], diff --git a/statement.go b/statement.go index 9dc5bf52..78e252b9 100644 --- a/statement.go +++ b/statement.go @@ -20,6 +20,7 @@ import ( // Statement save all the sql info for executing SQL type Statement struct { RefTable *schemas.Table + dialect dialects.Dialect Engine *Engine Start int LimitN *int @@ -32,7 +33,6 @@ type Statement struct { ColumnStr string selectStr string useAllCols bool - OmitStr string AltTableName string tableName string RawSQL string @@ -63,8 +63,20 @@ type Statement struct { lastError error } +func newStatement(dialect dialects.Dialect) *Statement { + statement := &Statement{ + dialect: dialect, + } + statement.Reset() + return statement +} + +func (statement *Statement) omitStr() string { + return statement.dialect.Quoter().Join(statement.omitColumnMap, " ,") +} + // Init reset all the statement's fields -func (statement *Statement) Init() { +func (statement *Statement) Reset() { statement.RefTable = nil statement.Start = 0 statement.LimitN = nil @@ -75,7 +87,6 @@ func (statement *Statement) Init() { statement.GroupByStr = "" statement.HavingStr = "" statement.ColumnStr = "" - statement.OmitStr = "" statement.columnMap = columnMap{} statement.omitColumnMap = columnMap{} statement.AltTableName = "" @@ -144,6 +155,10 @@ func (statement *Statement) Where(query interface{}, args ...interface{}) *State return statement.And(query, args...) } +func (statement *Statement) quote(s string) string { + return statement.dialect.Quoter().Quote(s) +} + // And add Where & and statement func (statement *Statement) And(query interface{}, args ...interface{}) *Statement { switch query.(type) { @@ -154,7 +169,7 @@ func (statement *Statement) And(query interface{}, args ...interface{}) *Stateme queryMap := query.(map[string]interface{}) newMap := make(map[string]interface{}) for k, v := range queryMap { - newMap[statement.Engine.Quote(k)] = v + newMap[statement.quote(k)] = v } statement.cond = statement.cond.And(builder.Eq(newMap)) case builder.Cond: @@ -197,14 +212,14 @@ func (statement *Statement) Or(query interface{}, args ...interface{}) *Statemen // In generate "Where column IN (?) " statement func (statement *Statement) In(column string, args ...interface{}) *Statement { - in := builder.In(statement.Engine.Quote(column), args...) + in := builder.In(statement.quote(column), args...) statement.cond = statement.cond.And(in) return statement } // NotIn generate "Where column NOT IN (?) " statement func (statement *Statement) NotIn(column string, args ...interface{}) *Statement { - notIn := builder.NotIn(statement.Engine.Quote(column), args...) + notIn := builder.NotIn(statement.quote(column), args...) statement.cond = statement.cond.And(notIn) return statement } @@ -341,7 +356,7 @@ func (statement *Statement) buildUpdates(bean interface{}, if fieldValue.IsNil() { if includeNil { args = append(args, nil) - colNames = append(colNames, fmt.Sprintf("%v=?", engine.Quote(col.Name))) + colNames = append(colNames, fmt.Sprintf("%v=?", statement.quote(col.Name))) } continue } else if !fieldValue.IsValid() { @@ -485,10 +500,10 @@ func (statement *Statement) buildUpdates(bean interface{}, APPEND: args = append(args, val) - if col.IsPrimaryKey && engine.dialect.DBType() == "ql" { + if col.IsPrimaryKey { continue } - colNames = append(colNames, fmt.Sprintf("%v = ?", engine.Quote(col.Name))) + colNames = append(colNames, fmt.Sprintf("%v = ?", statement.quote(col.Name))) } return colNames, args @@ -504,9 +519,9 @@ func (statement *Statement) colName(col *schemas.Column, tableName string) strin if len(statement.TableAlias) > 0 { nm = statement.TableAlias } - return statement.Engine.Quote(nm) + "." + statement.Engine.Quote(col.Name) + return statement.quote(nm) + "." + statement.quote(col.Name) } - return statement.Engine.Quote(col.Name) + return statement.quote(col.Name) } // TableName return current tableName @@ -572,22 +587,8 @@ func (statement *Statement) SetExpr(column string, expression interface{}) *Stat return statement } -func (statement *Statement) col2NewColsWithQuote(columns ...string) []string { - newColumns := make([]string, 0) - quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`") - for _, col := range columns { - newColumns = append(newColumns, statement.Engine.Quote(eraseAny(col, quotes...))) - } - return newColumns -} - -func (statement *Statement) colmap2NewColsWithQuote() []string { - newColumns := make([]string, len(statement.columnMap), len(statement.columnMap)) - copy(newColumns, statement.columnMap) - for i := 0; i < len(statement.columnMap); i++ { - newColumns[i] = statement.Engine.Quote(newColumns[i]) - } - return newColumns +func (statement *Statement) quoteColumnMap() []string { + return statement.dialect.Quoter().Strings(statement.columnMap) } // Distinct generates "DISTINCT col1, col2 " statement @@ -616,10 +617,7 @@ func (statement *Statement) Cols(columns ...string) *Statement { statement.columnMap.add(nc) } - newColumns := statement.colmap2NewColsWithQuote() - - statement.ColumnStr = strings.Join(newColumns, ", ") - statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.dialect.Quoter().Quote("*"), "*", -1) + statement.ColumnStr = statement.dialect.Quoter().Join(statement.columnMap, ", ") return statement } @@ -654,7 +652,6 @@ func (statement *Statement) Omit(columns ...string) { for _, nc := range newColumns { statement.omitColumnMap = append(statement.omitColumnMap, nc) } - statement.OmitStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", "))) } // Nullable Update use only: update columns to null when value is nullable and zero-value @@ -695,8 +692,13 @@ func (statement *Statement) Desc(colNames ...string) *Statement { if len(statement.OrderStr) > 0 { fmt.Fprint(&buf, statement.OrderStr, ", ") } - newColNames := statement.col2NewColsWithQuote(colNames...) - fmt.Fprintf(&buf, "%v DESC", strings.Join(newColNames, " DESC, ")) + for i, col := range colNames { + if i > 0 { + fmt.Fprint(&buf, ", ") + } + statement.dialect.Quoter().QuoteTo(&buf, col) + fmt.Fprint(&buf, " DESC") + } statement.OrderStr = buf.String() return statement } @@ -707,8 +709,13 @@ func (statement *Statement) Asc(colNames ...string) *Statement { if len(statement.OrderStr) > 0 { fmt.Fprint(&buf, statement.OrderStr, ", ") } - newColNames := statement.col2NewColsWithQuote(colNames...) - fmt.Fprintf(&buf, "%v ASC", strings.Join(newColNames, " ASC, ")) + for i, col := range colNames { + if i > 0 { + fmt.Fprint(&buf, ", ") + } + statement.dialect.Quoter().QuoteTo(&buf, col) + fmt.Fprint(&buf, " ASC") + } statement.OrderStr = buf.String() return statement } diff --git a/statement_test.go b/statement_test.go index bc5fc5dd..f048e1e8 100644 --- a/statement_test.go +++ b/statement_test.go @@ -166,15 +166,17 @@ func (TestType) TableName() string { func createTestStatement() *Statement { if engine, ok := testEngine.(*Engine); ok { statement := &Statement{} - statement.Init() + statement.Reset() statement.Engine = engine + statement.dialect = engine.dialect statement.setRefValue(reflect.ValueOf(TestType{})) return statement } else if eg, ok := testEngine.(*EngineGroup); ok { statement := &Statement{} - statement.Init() + statement.Reset() statement.Engine = eg.Engine + statement.dialect = eg.Engine.dialect statement.setRefValue(reflect.ValueOf(TestType{})) return statement @@ -237,12 +239,3 @@ func TestUpdateIgnoreOnlyFromDBFields(t *testing.T) { testEngine.Update(record) assertGetRecord() } - -func TestCol2NewColsWithQuote(t *testing.T) { - cols := []string{"f1", "f2", "t3.f3"} - - statement := createTestStatement() - - quotedCols := statement.col2NewColsWithQuote(cols...) - assert.EqualValues(t, []string{statement.Engine.Quote("f1"), statement.Engine.Quote("f2"), statement.Engine.Quote("t3.f3")}, quotedCols) -} -- 2.40.1 From fe31f0c9065c769164e2c4502ea4dcd908838701 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Thu, 27 Feb 2020 08:09:29 +0800 Subject: [PATCH 2/2] Fix cache bug --- log/logger.go | 8 ++++---- session.go | 2 +- session_get.go | 12 ++++++------ statement.go | 4 ---- statement_test.go | 10 +++++++--- 5 files changed, 18 insertions(+), 18 deletions(-) diff --git a/log/logger.go b/log/logger.go index b5ab9019..eeb63693 100644 --- a/log/logger.go +++ b/log/logger.go @@ -130,7 +130,7 @@ func NewSimpleLogger3(out io.Writer, prefix string, flag int, l LogLevel) *Simpl // Error implement ILogger func (s *SimpleLogger) Error(v ...interface{}) { if s.level <= LOG_ERR { - s.ERR.Output(2, fmt.Sprint(v...)) + s.ERR.Output(2, fmt.Sprintln(v...)) } return } @@ -146,7 +146,7 @@ func (s *SimpleLogger) Errorf(format string, v ...interface{}) { // Debug implement ILogger func (s *SimpleLogger) Debug(v ...interface{}) { if s.level <= LOG_DEBUG { - s.DEBUG.Output(2, fmt.Sprint(v...)) + s.DEBUG.Output(2, fmt.Sprintln(v...)) } return } @@ -162,7 +162,7 @@ func (s *SimpleLogger) Debugf(format string, v ...interface{}) { // Info implement ILogger func (s *SimpleLogger) Info(v ...interface{}) { if s.level <= LOG_INFO { - s.INFO.Output(2, fmt.Sprint(v...)) + s.INFO.Output(2, fmt.Sprintln(v...)) } return } @@ -178,7 +178,7 @@ func (s *SimpleLogger) Infof(format string, v ...interface{}) { // Warn implement ILogger func (s *SimpleLogger) Warn(v ...interface{}) { if s.level <= LOG_WARNING { - s.WARN.Output(2, fmt.Sprint(v...)) + s.WARN.Output(2, fmt.Sprintln(v...)) } return } diff --git a/session.go b/session.go index 703aa873..8c692879 100644 --- a/session.go +++ b/session.go @@ -72,9 +72,9 @@ func (session *Session) Clone() *Session { // Init reset the session as the init status. func (session *Session) Init() { + session.statement.Reset() session.statement.dialect = session.engine.dialect session.statement.Engine = session.engine - session.statement.Reset() session.showSQL = session.engine.showSQL session.isAutoCommit = true session.isCommitedOrRollbacked = false diff --git a/session_get.go b/session_get.go index 376ac2c1..c42361a8 100644 --- a/session_get.go +++ b/session_get.go @@ -282,7 +282,7 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf tableName := session.statement.TableName() cacher := session.engine.getCacher(tableName) - session.engine.logger.Debug("[cacheGet] find sql:", newsql, args) + session.engine.logger.Debug("[cache] Get SQL:", newsql, args) table := session.statement.RefTable ids, err := caches.GetCacheSql(cacher, tableName, newsql, args) if err != nil { @@ -318,19 +318,19 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf } ids = []schemas.PK{pk} - session.engine.logger.Debug("[cacheGet] cache ids:", newsql, ids) + session.engine.logger.Debug("[cache] cache ids:", newsql, ids) err = caches.PutCacheSql(cacher, ids, tableName, newsql, args) if err != nil { return false, err } } else { - session.engine.logger.Debug("[cacheGet] cache hit sql:", newsql, ids) + session.engine.logger.Debug("[cache] cache hit:", newsql, ids) } if len(ids) > 0 { structValue := reflect.Indirect(reflect.ValueOf(bean)) id := ids[0] - session.engine.logger.Debug("[cacheGet] get bean:", tableName, id) + session.engine.logger.Debug("[cache] get bean:", tableName, id) sid, err := id.ToString() if err != nil { return false, err @@ -343,10 +343,10 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf return has, err } - session.engine.logger.Debug("[cacheGet] cache bean:", tableName, id, cacheBean) + session.engine.logger.Debug("[cache] cache bean:", tableName, id, cacheBean) cacher.PutBean(tableName, sid, cacheBean) } else { - session.engine.logger.Debug("[cacheGet] cache hit bean:", tableName, id, cacheBean) + session.engine.logger.Debug("[cache] cache hit:", tableName, id, cacheBean) has = true } structValue.Set(reflect.Indirect(reflect.ValueOf(cacheBean))) diff --git a/statement.go b/statement.go index 78e252b9..fd6b3962 100644 --- a/statement.go +++ b/statement.go @@ -587,10 +587,6 @@ func (statement *Statement) SetExpr(column string, expression interface{}) *Stat return statement } -func (statement *Statement) quoteColumnMap() []string { - return statement.dialect.Quoter().Strings(statement.columnMap) -} - // Distinct generates "DISTINCT col1, col2 " statement func (statement *Statement) Distinct(columns ...string) *Statement { statement.IsDistinct = true diff --git a/statement_test.go b/statement_test.go index f048e1e8..6e5564b0 100644 --- a/statement_test.go +++ b/statement_test.go @@ -234,8 +234,12 @@ func TestUpdateIgnoreOnlyFromDBFields(t *testing.T) { }) assert.NoError(t, err) - record := assertGetRecord() - record.OnlyFromDBField = "test" - testEngine.Update(record) + assertGetRecord() + + _, err = testEngine.ID(1).Update(&TestOnlyFromDBField{ + OnlyToDBField: "b", + OnlyFromDBField: "test", + }) + assert.NoError(t, err) assertGetRecord() } -- 2.40.1