Improve statement #1549

Merged
lunny merged 2 commits from lunny/statement into master 2020-02-27 00:34:18 +00:00
9 changed files with 89 additions and 72 deletions

View File

@ -130,7 +130,7 @@ func NewSimpleLogger3(out io.Writer, prefix string, flag int, l LogLevel) *Simpl
// Error implement ILogger // Error implement ILogger
func (s *SimpleLogger) Error(v ...interface{}) { func (s *SimpleLogger) Error(v ...interface{}) {
if s.level <= LOG_ERR { if s.level <= LOG_ERR {
s.ERR.Output(2, fmt.Sprint(v...)) s.ERR.Output(2, fmt.Sprintln(v...))
} }
return return
} }
@ -146,7 +146,7 @@ func (s *SimpleLogger) Errorf(format string, v ...interface{}) {
// Debug implement ILogger // Debug implement ILogger
func (s *SimpleLogger) Debug(v ...interface{}) { func (s *SimpleLogger) Debug(v ...interface{}) {
if s.level <= LOG_DEBUG { if s.level <= LOG_DEBUG {
s.DEBUG.Output(2, fmt.Sprint(v...)) s.DEBUG.Output(2, fmt.Sprintln(v...))
} }
return return
} }
@ -162,7 +162,7 @@ func (s *SimpleLogger) Debugf(format string, v ...interface{}) {
// Info implement ILogger // Info implement ILogger
func (s *SimpleLogger) Info(v ...interface{}) { func (s *SimpleLogger) Info(v ...interface{}) {
if s.level <= LOG_INFO { if s.level <= LOG_INFO {
s.INFO.Output(2, fmt.Sprint(v...)) s.INFO.Output(2, fmt.Sprintln(v...))
} }
return return
} }
@ -178,7 +178,7 @@ func (s *SimpleLogger) Infof(format string, v ...interface{}) {
// Warn implement ILogger // Warn implement ILogger
func (s *SimpleLogger) Warn(v ...interface{}) { func (s *SimpleLogger) Warn(v ...interface{}) {
if s.level <= LOG_WARNING { if s.level <= LOG_WARNING {
s.WARN.Output(2, fmt.Sprint(v...)) s.WARN.Output(2, fmt.Sprintln(v...))
} }
return return
} }

View File

@ -93,17 +93,25 @@ func (q Quoter) Join(a []string, sep string) string {
if i > 0 { if i > 0 {
b.WriteString(sep) b.WriteString(sep)
} }
if q[0] != "" { if q[0] != "" && s != "*" {
b.WriteString(q[0]) b.WriteString(q[0])
} }
b.WriteString(strings.TrimSpace(s)) b.WriteString(strings.TrimSpace(s))
if q[1] != "" { if q[1] != "" && s != "*" {
b.WriteString(q[1]) b.WriteString(q[1])
} }
} }
return b.String() return b.String()
} }
func (q Quoter) Strings(s []string) []string {
var res = make([]string, 0, len(s))
for _, a := range s {
res = append(res, q.Quote(a))
}
return res
}
func (q Quoter) QuoteTo(buf *strings.Builder, value string) { func (q Quoter) QuoteTo(buf *strings.Builder, value string) {
if q.IsEmpty() { if q.IsEmpty() {
buf.WriteString(value) buf.WriteString(value)

View File

@ -55,3 +55,11 @@ func TestJoin(t *testing.T) {
quoter = Quoter{"", ""} quoter = Quoter{"", ""}
assert.EqualValues(t, "f1, f2, f3", quoter.Join(cols, ", ")) assert.EqualValues(t, "f1, f2, f3", quoter.Join(cols, ", "))
} }
func TestStrings(t *testing.T) {
cols := []string{"f1", "f2", "t3.f3"}
quoter := Quoter{"[", "]"}
quotedCols := quoter.Strings(cols)
assert.EqualValues(t, []string{"[f1]", "[f2]", "[t3].[f3]"}, quotedCols)
}

View File

@ -72,7 +72,8 @@ func (session *Session) Clone() *Session {
// Init reset the session as the init status. // Init reset the session as the init status.
func (session *Session) Init() { func (session *Session) Init() {
session.statement.Init() session.statement.Reset()
session.statement.dialect = session.engine.dialect
session.statement.Engine = session.engine session.statement.Engine = session.engine
session.showSQL = session.engine.showSQL session.showSQL = session.engine.showSQL
session.isAutoCommit = true session.isAutoCommit = true
@ -128,7 +129,7 @@ func (session *Session) IsClosed() bool {
func (session *Session) resetStatement() { func (session *Session) resetStatement() {
if session.autoResetStatement { if session.autoResetStatement {
session.statement.Init() session.statement.Reset()
} }
} }

View File

@ -34,7 +34,7 @@ func (session *Session) Exist(bean ...interface{}) (bool, error) {
return false, ErrTableNotFound return false, ErrTableNotFound
} }
tableName = session.statement.Engine.Quote(tableName) tableName = session.statement.quote(tableName)
if len(session.statement.JoinStr) > 0 { if len(session.statement.JoinStr) > 0 {
joinStr = session.statement.JoinStr joinStr = session.statement.JoinStr
} }

View File

@ -282,7 +282,7 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
tableName := session.statement.TableName() tableName := session.statement.TableName()
cacher := session.engine.getCacher(tableName) cacher := session.engine.getCacher(tableName)
session.engine.logger.Debug("[cacheGet] find sql:", newsql, args) session.engine.logger.Debug("[cache] Get SQL:", newsql, args)
table := session.statement.RefTable table := session.statement.RefTable
ids, err := caches.GetCacheSql(cacher, tableName, newsql, args) ids, err := caches.GetCacheSql(cacher, tableName, newsql, args)
if err != nil { if err != nil {
@ -318,19 +318,19 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
} }
ids = []schemas.PK{pk} ids = []schemas.PK{pk}
session.engine.logger.Debug("[cacheGet] cache ids:", newsql, ids) session.engine.logger.Debug("[cache] cache ids:", newsql, ids)
err = caches.PutCacheSql(cacher, ids, tableName, newsql, args) err = caches.PutCacheSql(cacher, ids, tableName, newsql, args)
if err != nil { if err != nil {
return false, err return false, err
} }
} else { } else {
session.engine.logger.Debug("[cacheGet] cache hit sql:", newsql, ids) session.engine.logger.Debug("[cache] cache hit:", newsql, ids)
} }
if len(ids) > 0 { if len(ids) > 0 {
structValue := reflect.Indirect(reflect.ValueOf(bean)) structValue := reflect.Indirect(reflect.ValueOf(bean))
id := ids[0] id := ids[0]
session.engine.logger.Debug("[cacheGet] get bean:", tableName, id) session.engine.logger.Debug("[cache] get bean:", tableName, id)
sid, err := id.ToString() sid, err := id.ToString()
if err != nil { if err != nil {
return false, err return false, err
@ -343,10 +343,10 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
return has, err return has, err
} }
session.engine.logger.Debug("[cacheGet] cache bean:", tableName, id, cacheBean) session.engine.logger.Debug("[cache] cache bean:", tableName, id, cacheBean)
cacher.PutBean(tableName, sid, cacheBean) cacher.PutBean(tableName, sid, cacheBean)
} else { } else {
session.engine.logger.Debug("[cacheGet] cache hit bean:", tableName, id, cacheBean) session.engine.logger.Debug("[cache] cache hit:", tableName, id, cacheBean)
has = true has = true
} }
structValue.Set(reflect.Indirect(reflect.ValueOf(cacheBean))) structValue.Set(reflect.Indirect(reflect.ValueOf(cacheBean)))

View File

@ -341,9 +341,9 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
var top string var top string
if st.LimitN != nil { if st.LimitN != nil {
limitValue := *st.LimitN limitValue := *st.LimitN
if st.Engine.dialect.DBType() == schemas.MYSQL { if st.dialect.DBType() == schemas.MYSQL {
condSQL = condSQL + fmt.Sprintf(" LIMIT %d", limitValue) condSQL = condSQL + fmt.Sprintf(" LIMIT %d", limitValue)
} else if st.Engine.dialect.DBType() == schemas.SQLITE { } else if st.dialect.DBType() == schemas.SQLITE {
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue) tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue)
cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)", cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)",
session.engine.Quote(tableName), tempCondSQL), condArgs...)) session.engine.Quote(tableName), tempCondSQL), condArgs...))
@ -354,7 +354,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
if len(condSQL) > 0 { if len(condSQL) > 0 {
condSQL = "WHERE " + condSQL condSQL = "WHERE " + condSQL
} }
} else if st.Engine.dialect.DBType() == schemas.POSTGRES { } else if st.dialect.DBType() == schemas.POSTGRES {
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue) tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue)
cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)", cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)",
session.engine.Quote(tableName), tempCondSQL), condArgs...)) session.engine.Quote(tableName), tempCondSQL), condArgs...))
@ -366,8 +366,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
if len(condSQL) > 0 { if len(condSQL) > 0 {
condSQL = "WHERE " + condSQL condSQL = "WHERE " + condSQL
} }
} else if st.Engine.dialect.DBType() == schemas.MSSQL { } else if st.dialect.DBType() == schemas.MSSQL {
if st.OrderStr != "" && st.Engine.dialect.DBType() == schemas.MSSQL && if st.OrderStr != "" && st.dialect.DBType() == schemas.MSSQL &&
table != nil && len(table.PrimaryKeys) == 1 { table != nil && len(table.PrimaryKeys) == 1 {
cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)", cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)",
table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0], table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0],

View File

@ -20,6 +20,7 @@ import (
// Statement save all the sql info for executing SQL // Statement save all the sql info for executing SQL
type Statement struct { type Statement struct {
RefTable *schemas.Table RefTable *schemas.Table
dialect dialects.Dialect
Engine *Engine Engine *Engine
Start int Start int
LimitN *int LimitN *int
@ -32,7 +33,6 @@ type Statement struct {
ColumnStr string ColumnStr string
selectStr string selectStr string
useAllCols bool useAllCols bool
OmitStr string
AltTableName string AltTableName string
tableName string tableName string
RawSQL string RawSQL string
@ -63,8 +63,20 @@ type Statement struct {
lastError error lastError error
} }
func newStatement(dialect dialects.Dialect) *Statement {
statement := &Statement{
dialect: dialect,
}
statement.Reset()
return statement
}
func (statement *Statement) omitStr() string {
return statement.dialect.Quoter().Join(statement.omitColumnMap, " ,")
}
// Init reset all the statement's fields // Init reset all the statement's fields
func (statement *Statement) Init() { func (statement *Statement) Reset() {
statement.RefTable = nil statement.RefTable = nil
statement.Start = 0 statement.Start = 0
statement.LimitN = nil statement.LimitN = nil
@ -75,7 +87,6 @@ func (statement *Statement) Init() {
statement.GroupByStr = "" statement.GroupByStr = ""
statement.HavingStr = "" statement.HavingStr = ""
statement.ColumnStr = "" statement.ColumnStr = ""
statement.OmitStr = ""
statement.columnMap = columnMap{} statement.columnMap = columnMap{}
statement.omitColumnMap = columnMap{} statement.omitColumnMap = columnMap{}
statement.AltTableName = "" statement.AltTableName = ""
@ -144,6 +155,10 @@ func (statement *Statement) Where(query interface{}, args ...interface{}) *State
return statement.And(query, args...) return statement.And(query, args...)
} }
func (statement *Statement) quote(s string) string {
return statement.dialect.Quoter().Quote(s)
}
// And add Where & and statement // And add Where & and statement
func (statement *Statement) And(query interface{}, args ...interface{}) *Statement { func (statement *Statement) And(query interface{}, args ...interface{}) *Statement {
switch query.(type) { switch query.(type) {
@ -154,7 +169,7 @@ func (statement *Statement) And(query interface{}, args ...interface{}) *Stateme
queryMap := query.(map[string]interface{}) queryMap := query.(map[string]interface{})
newMap := make(map[string]interface{}) newMap := make(map[string]interface{})
for k, v := range queryMap { for k, v := range queryMap {
newMap[statement.Engine.Quote(k)] = v newMap[statement.quote(k)] = v
} }
statement.cond = statement.cond.And(builder.Eq(newMap)) statement.cond = statement.cond.And(builder.Eq(newMap))
case builder.Cond: case builder.Cond:
@ -197,14 +212,14 @@ func (statement *Statement) Or(query interface{}, args ...interface{}) *Statemen
// In generate "Where column IN (?) " statement // In generate "Where column IN (?) " statement
func (statement *Statement) In(column string, args ...interface{}) *Statement { func (statement *Statement) In(column string, args ...interface{}) *Statement {
in := builder.In(statement.Engine.Quote(column), args...) in := builder.In(statement.quote(column), args...)
statement.cond = statement.cond.And(in) statement.cond = statement.cond.And(in)
return statement return statement
} }
// NotIn generate "Where column NOT IN (?) " statement // NotIn generate "Where column NOT IN (?) " statement
func (statement *Statement) NotIn(column string, args ...interface{}) *Statement { func (statement *Statement) NotIn(column string, args ...interface{}) *Statement {
notIn := builder.NotIn(statement.Engine.Quote(column), args...) notIn := builder.NotIn(statement.quote(column), args...)
statement.cond = statement.cond.And(notIn) statement.cond = statement.cond.And(notIn)
return statement return statement
} }
@ -341,7 +356,7 @@ func (statement *Statement) buildUpdates(bean interface{},
if fieldValue.IsNil() { if fieldValue.IsNil() {
if includeNil { if includeNil {
args = append(args, nil) args = append(args, nil)
colNames = append(colNames, fmt.Sprintf("%v=?", engine.Quote(col.Name))) colNames = append(colNames, fmt.Sprintf("%v=?", statement.quote(col.Name)))
} }
continue continue
} else if !fieldValue.IsValid() { } else if !fieldValue.IsValid() {
@ -485,10 +500,10 @@ func (statement *Statement) buildUpdates(bean interface{},
APPEND: APPEND:
args = append(args, val) args = append(args, val)
if col.IsPrimaryKey && engine.dialect.DBType() == "ql" { if col.IsPrimaryKey {
continue continue
} }
colNames = append(colNames, fmt.Sprintf("%v = ?", engine.Quote(col.Name))) colNames = append(colNames, fmt.Sprintf("%v = ?", statement.quote(col.Name)))
} }
return colNames, args return colNames, args
@ -504,9 +519,9 @@ func (statement *Statement) colName(col *schemas.Column, tableName string) strin
if len(statement.TableAlias) > 0 { if len(statement.TableAlias) > 0 {
nm = statement.TableAlias nm = statement.TableAlias
} }
return statement.Engine.Quote(nm) + "." + statement.Engine.Quote(col.Name) return statement.quote(nm) + "." + statement.quote(col.Name)
} }
return statement.Engine.Quote(col.Name) return statement.quote(col.Name)
} }
// TableName return current tableName // TableName return current tableName
@ -572,24 +587,6 @@ func (statement *Statement) SetExpr(column string, expression interface{}) *Stat
return statement return statement
} }
func (statement *Statement) col2NewColsWithQuote(columns ...string) []string {
newColumns := make([]string, 0)
quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`")
for _, col := range columns {
newColumns = append(newColumns, statement.Engine.Quote(eraseAny(col, quotes...)))
}
return newColumns
}
func (statement *Statement) colmap2NewColsWithQuote() []string {
newColumns := make([]string, len(statement.columnMap), len(statement.columnMap))
copy(newColumns, statement.columnMap)
for i := 0; i < len(statement.columnMap); i++ {
newColumns[i] = statement.Engine.Quote(newColumns[i])
}
return newColumns
}
// Distinct generates "DISTINCT col1, col2 " statement // Distinct generates "DISTINCT col1, col2 " statement
func (statement *Statement) Distinct(columns ...string) *Statement { func (statement *Statement) Distinct(columns ...string) *Statement {
statement.IsDistinct = true statement.IsDistinct = true
@ -616,10 +613,7 @@ func (statement *Statement) Cols(columns ...string) *Statement {
statement.columnMap.add(nc) statement.columnMap.add(nc)
} }
newColumns := statement.colmap2NewColsWithQuote() statement.ColumnStr = statement.dialect.Quoter().Join(statement.columnMap, ", ")
statement.ColumnStr = strings.Join(newColumns, ", ")
statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.dialect.Quoter().Quote("*"), "*", -1)
return statement return statement
} }
@ -654,7 +648,6 @@ func (statement *Statement) Omit(columns ...string) {
for _, nc := range newColumns { for _, nc := range newColumns {
statement.omitColumnMap = append(statement.omitColumnMap, nc) statement.omitColumnMap = append(statement.omitColumnMap, nc)
} }
statement.OmitStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", ")))
} }
// Nullable Update use only: update columns to null when value is nullable and zero-value // Nullable Update use only: update columns to null when value is nullable and zero-value
@ -695,8 +688,13 @@ func (statement *Statement) Desc(colNames ...string) *Statement {
if len(statement.OrderStr) > 0 { if len(statement.OrderStr) > 0 {
fmt.Fprint(&buf, statement.OrderStr, ", ") fmt.Fprint(&buf, statement.OrderStr, ", ")
} }
newColNames := statement.col2NewColsWithQuote(colNames...) for i, col := range colNames {
fmt.Fprintf(&buf, "%v DESC", strings.Join(newColNames, " DESC, ")) if i > 0 {
fmt.Fprint(&buf, ", ")
}
statement.dialect.Quoter().QuoteTo(&buf, col)
fmt.Fprint(&buf, " DESC")
}
statement.OrderStr = buf.String() statement.OrderStr = buf.String()
return statement return statement
} }
@ -707,8 +705,13 @@ func (statement *Statement) Asc(colNames ...string) *Statement {
if len(statement.OrderStr) > 0 { if len(statement.OrderStr) > 0 {
fmt.Fprint(&buf, statement.OrderStr, ", ") fmt.Fprint(&buf, statement.OrderStr, ", ")
} }
newColNames := statement.col2NewColsWithQuote(colNames...) for i, col := range colNames {
fmt.Fprintf(&buf, "%v ASC", strings.Join(newColNames, " ASC, ")) if i > 0 {
fmt.Fprint(&buf, ", ")
}
statement.dialect.Quoter().QuoteTo(&buf, col)
fmt.Fprint(&buf, " ASC")
}
statement.OrderStr = buf.String() statement.OrderStr = buf.String()
return statement return statement
} }

View File

@ -166,15 +166,17 @@ func (TestType) TableName() string {
func createTestStatement() *Statement { func createTestStatement() *Statement {
if engine, ok := testEngine.(*Engine); ok { if engine, ok := testEngine.(*Engine); ok {
statement := &Statement{} statement := &Statement{}
statement.Init() statement.Reset()
statement.Engine = engine statement.Engine = engine
statement.dialect = engine.dialect
statement.setRefValue(reflect.ValueOf(TestType{})) statement.setRefValue(reflect.ValueOf(TestType{}))
return statement return statement
} else if eg, ok := testEngine.(*EngineGroup); ok { } else if eg, ok := testEngine.(*EngineGroup); ok {
statement := &Statement{} statement := &Statement{}
statement.Init() statement.Reset()
statement.Engine = eg.Engine statement.Engine = eg.Engine
statement.dialect = eg.Engine.dialect
statement.setRefValue(reflect.ValueOf(TestType{})) statement.setRefValue(reflect.ValueOf(TestType{}))
return statement return statement
@ -232,17 +234,12 @@ func TestUpdateIgnoreOnlyFromDBFields(t *testing.T) {
}) })
assert.NoError(t, err) assert.NoError(t, err)
record := assertGetRecord() assertGetRecord()
record.OnlyFromDBField = "test"
testEngine.Update(record) _, err = testEngine.ID(1).Update(&TestOnlyFromDBField{
OnlyToDBField: "b",
OnlyFromDBField: "test",
})
assert.NoError(t, err)
assertGetRecord() assertGetRecord()
} }
func TestCol2NewColsWithQuote(t *testing.T) {
cols := []string{"f1", "f2", "t3.f3"}
statement := createTestStatement()
quotedCols := statement.col2NewColsWithQuote(cols...)
assert.EqualValues(t, []string{statement.Engine.Quote("f1"), statement.Engine.Quote("f2"), statement.Engine.Quote("t3.f3")}, quotedCols)
}