Improve statement #1549

Merged
lunny merged 2 commits from lunny/statement into master 2020-02-27 00:34:18 +00:00
9 changed files with 89 additions and 72 deletions

View File

@ -130,7 +130,7 @@ func NewSimpleLogger3(out io.Writer, prefix string, flag int, l LogLevel) *Simpl
// Error implement ILogger
func (s *SimpleLogger) Error(v ...interface{}) {
if s.level <= LOG_ERR {
s.ERR.Output(2, fmt.Sprint(v...))
s.ERR.Output(2, fmt.Sprintln(v...))
}
return
}
@ -146,7 +146,7 @@ func (s *SimpleLogger) Errorf(format string, v ...interface{}) {
// Debug implement ILogger
func (s *SimpleLogger) Debug(v ...interface{}) {
if s.level <= LOG_DEBUG {
s.DEBUG.Output(2, fmt.Sprint(v...))
s.DEBUG.Output(2, fmt.Sprintln(v...))
}
return
}
@ -162,7 +162,7 @@ func (s *SimpleLogger) Debugf(format string, v ...interface{}) {
// Info implement ILogger
func (s *SimpleLogger) Info(v ...interface{}) {
if s.level <= LOG_INFO {
s.INFO.Output(2, fmt.Sprint(v...))
s.INFO.Output(2, fmt.Sprintln(v...))
}
return
}
@ -178,7 +178,7 @@ func (s *SimpleLogger) Infof(format string, v ...interface{}) {
// Warn implement ILogger
func (s *SimpleLogger) Warn(v ...interface{}) {
if s.level <= LOG_WARNING {
s.WARN.Output(2, fmt.Sprint(v...))
s.WARN.Output(2, fmt.Sprintln(v...))
}
return
}

View File

@ -93,17 +93,25 @@ func (q Quoter) Join(a []string, sep string) string {
if i > 0 {
b.WriteString(sep)
}
if q[0] != "" {
if q[0] != "" && s != "*" {
b.WriteString(q[0])
}
b.WriteString(strings.TrimSpace(s))
if q[1] != "" {
if q[1] != "" && s != "*" {
b.WriteString(q[1])
}
}
return b.String()
}
func (q Quoter) Strings(s []string) []string {
var res = make([]string, 0, len(s))
for _, a := range s {
res = append(res, q.Quote(a))
}
return res
}
func (q Quoter) QuoteTo(buf *strings.Builder, value string) {
if q.IsEmpty() {
buf.WriteString(value)

View File

@ -55,3 +55,11 @@ func TestJoin(t *testing.T) {
quoter = Quoter{"", ""}
assert.EqualValues(t, "f1, f2, f3", quoter.Join(cols, ", "))
}
func TestStrings(t *testing.T) {
cols := []string{"f1", "f2", "t3.f3"}
quoter := Quoter{"[", "]"}
quotedCols := quoter.Strings(cols)
assert.EqualValues(t, []string{"[f1]", "[f2]", "[t3].[f3]"}, quotedCols)
}

View File

@ -72,7 +72,8 @@ func (session *Session) Clone() *Session {
// Init reset the session as the init status.
func (session *Session) Init() {
session.statement.Init()
session.statement.Reset()
session.statement.dialect = session.engine.dialect
session.statement.Engine = session.engine
session.showSQL = session.engine.showSQL
session.isAutoCommit = true
@ -128,7 +129,7 @@ func (session *Session) IsClosed() bool {
func (session *Session) resetStatement() {
if session.autoResetStatement {
session.statement.Init()
session.statement.Reset()
}
}

View File

@ -34,7 +34,7 @@ func (session *Session) Exist(bean ...interface{}) (bool, error) {
return false, ErrTableNotFound
}
tableName = session.statement.Engine.Quote(tableName)
tableName = session.statement.quote(tableName)
if len(session.statement.JoinStr) > 0 {
joinStr = session.statement.JoinStr
}

View File

@ -282,7 +282,7 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
tableName := session.statement.TableName()
cacher := session.engine.getCacher(tableName)
session.engine.logger.Debug("[cacheGet] find sql:", newsql, args)
session.engine.logger.Debug("[cache] Get SQL:", newsql, args)
table := session.statement.RefTable
ids, err := caches.GetCacheSql(cacher, tableName, newsql, args)
if err != nil {
@ -318,19 +318,19 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
}
ids = []schemas.PK{pk}
session.engine.logger.Debug("[cacheGet] cache ids:", newsql, ids)
session.engine.logger.Debug("[cache] cache ids:", newsql, ids)
err = caches.PutCacheSql(cacher, ids, tableName, newsql, args)
if err != nil {
return false, err
}
} else {
session.engine.logger.Debug("[cacheGet] cache hit sql:", newsql, ids)
session.engine.logger.Debug("[cache] cache hit:", newsql, ids)
}
if len(ids) > 0 {
structValue := reflect.Indirect(reflect.ValueOf(bean))
id := ids[0]
session.engine.logger.Debug("[cacheGet] get bean:", tableName, id)
session.engine.logger.Debug("[cache] get bean:", tableName, id)
sid, err := id.ToString()
if err != nil {
return false, err
@ -343,10 +343,10 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
return has, err
}
session.engine.logger.Debug("[cacheGet] cache bean:", tableName, id, cacheBean)
session.engine.logger.Debug("[cache] cache bean:", tableName, id, cacheBean)
cacher.PutBean(tableName, sid, cacheBean)
} else {
session.engine.logger.Debug("[cacheGet] cache hit bean:", tableName, id, cacheBean)
session.engine.logger.Debug("[cache] cache hit:", tableName, id, cacheBean)
has = true
}
structValue.Set(reflect.Indirect(reflect.ValueOf(cacheBean)))

View File

@ -341,9 +341,9 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
var top string
if st.LimitN != nil {
limitValue := *st.LimitN
if st.Engine.dialect.DBType() == schemas.MYSQL {
if st.dialect.DBType() == schemas.MYSQL {
condSQL = condSQL + fmt.Sprintf(" LIMIT %d", limitValue)
} else if st.Engine.dialect.DBType() == schemas.SQLITE {
} else if st.dialect.DBType() == schemas.SQLITE {
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue)
cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)",
session.engine.Quote(tableName), tempCondSQL), condArgs...))
@ -354,7 +354,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
if len(condSQL) > 0 {
condSQL = "WHERE " + condSQL
}
} else if st.Engine.dialect.DBType() == schemas.POSTGRES {
} else if st.dialect.DBType() == schemas.POSTGRES {
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue)
cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)",
session.engine.Quote(tableName), tempCondSQL), condArgs...))
@ -366,8 +366,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
if len(condSQL) > 0 {
condSQL = "WHERE " + condSQL
}
} else if st.Engine.dialect.DBType() == schemas.MSSQL {
if st.OrderStr != "" && st.Engine.dialect.DBType() == schemas.MSSQL &&
} else if st.dialect.DBType() == schemas.MSSQL {
if st.OrderStr != "" && st.dialect.DBType() == schemas.MSSQL &&
table != nil && len(table.PrimaryKeys) == 1 {
cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)",
table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0],

View File

@ -20,6 +20,7 @@ import (
// Statement save all the sql info for executing SQL
type Statement struct {
RefTable *schemas.Table
dialect dialects.Dialect
Engine *Engine
Start int
LimitN *int
@ -32,7 +33,6 @@ type Statement struct {
ColumnStr string
selectStr string
useAllCols bool
OmitStr string
AltTableName string
tableName string
RawSQL string
@ -63,8 +63,20 @@ type Statement struct {
lastError error
}
func newStatement(dialect dialects.Dialect) *Statement {
statement := &Statement{
dialect: dialect,
}
statement.Reset()
return statement
}
func (statement *Statement) omitStr() string {
return statement.dialect.Quoter().Join(statement.omitColumnMap, " ,")
}
// Init reset all the statement's fields
func (statement *Statement) Init() {
func (statement *Statement) Reset() {
statement.RefTable = nil
statement.Start = 0
statement.LimitN = nil
@ -75,7 +87,6 @@ func (statement *Statement) Init() {
statement.GroupByStr = ""
statement.HavingStr = ""
statement.ColumnStr = ""
statement.OmitStr = ""
statement.columnMap = columnMap{}
statement.omitColumnMap = columnMap{}
statement.AltTableName = ""
@ -144,6 +155,10 @@ func (statement *Statement) Where(query interface{}, args ...interface{}) *State
return statement.And(query, args...)
}
func (statement *Statement) quote(s string) string {
return statement.dialect.Quoter().Quote(s)
}
// And add Where & and statement
func (statement *Statement) And(query interface{}, args ...interface{}) *Statement {
switch query.(type) {
@ -154,7 +169,7 @@ func (statement *Statement) And(query interface{}, args ...interface{}) *Stateme
queryMap := query.(map[string]interface{})
newMap := make(map[string]interface{})
for k, v := range queryMap {
newMap[statement.Engine.Quote(k)] = v
newMap[statement.quote(k)] = v
}
statement.cond = statement.cond.And(builder.Eq(newMap))
case builder.Cond:
@ -197,14 +212,14 @@ func (statement *Statement) Or(query interface{}, args ...interface{}) *Statemen
// In generate "Where column IN (?) " statement
func (statement *Statement) In(column string, args ...interface{}) *Statement {
in := builder.In(statement.Engine.Quote(column), args...)
in := builder.In(statement.quote(column), args...)
statement.cond = statement.cond.And(in)
return statement
}
// NotIn generate "Where column NOT IN (?) " statement
func (statement *Statement) NotIn(column string, args ...interface{}) *Statement {
notIn := builder.NotIn(statement.Engine.Quote(column), args...)
notIn := builder.NotIn(statement.quote(column), args...)
statement.cond = statement.cond.And(notIn)
return statement
}
@ -341,7 +356,7 @@ func (statement *Statement) buildUpdates(bean interface{},
if fieldValue.IsNil() {
if includeNil {
args = append(args, nil)
colNames = append(colNames, fmt.Sprintf("%v=?", engine.Quote(col.Name)))
colNames = append(colNames, fmt.Sprintf("%v=?", statement.quote(col.Name)))
}
continue
} else if !fieldValue.IsValid() {
@ -485,10 +500,10 @@ func (statement *Statement) buildUpdates(bean interface{},
APPEND:
args = append(args, val)
if col.IsPrimaryKey && engine.dialect.DBType() == "ql" {
if col.IsPrimaryKey {
continue
}
colNames = append(colNames, fmt.Sprintf("%v = ?", engine.Quote(col.Name)))
colNames = append(colNames, fmt.Sprintf("%v = ?", statement.quote(col.Name)))
}
return colNames, args
@ -504,9 +519,9 @@ func (statement *Statement) colName(col *schemas.Column, tableName string) strin
if len(statement.TableAlias) > 0 {
nm = statement.TableAlias
}
return statement.Engine.Quote(nm) + "." + statement.Engine.Quote(col.Name)
return statement.quote(nm) + "." + statement.quote(col.Name)
}
return statement.Engine.Quote(col.Name)
return statement.quote(col.Name)
}
// TableName return current tableName
@ -572,24 +587,6 @@ func (statement *Statement) SetExpr(column string, expression interface{}) *Stat
return statement
}
func (statement *Statement) col2NewColsWithQuote(columns ...string) []string {
newColumns := make([]string, 0)
quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`")
for _, col := range columns {
newColumns = append(newColumns, statement.Engine.Quote(eraseAny(col, quotes...)))
}
return newColumns
}
func (statement *Statement) colmap2NewColsWithQuote() []string {
newColumns := make([]string, len(statement.columnMap), len(statement.columnMap))
copy(newColumns, statement.columnMap)
for i := 0; i < len(statement.columnMap); i++ {
newColumns[i] = statement.Engine.Quote(newColumns[i])
}
return newColumns
}
// Distinct generates "DISTINCT col1, col2 " statement
func (statement *Statement) Distinct(columns ...string) *Statement {
statement.IsDistinct = true
@ -616,10 +613,7 @@ func (statement *Statement) Cols(columns ...string) *Statement {
statement.columnMap.add(nc)
}
newColumns := statement.colmap2NewColsWithQuote()
statement.ColumnStr = strings.Join(newColumns, ", ")
statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.dialect.Quoter().Quote("*"), "*", -1)
statement.ColumnStr = statement.dialect.Quoter().Join(statement.columnMap, ", ")
return statement
}
@ -654,7 +648,6 @@ func (statement *Statement) Omit(columns ...string) {
for _, nc := range newColumns {
statement.omitColumnMap = append(statement.omitColumnMap, nc)
}
statement.OmitStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", ")))
}
// Nullable Update use only: update columns to null when value is nullable and zero-value
@ -695,8 +688,13 @@ func (statement *Statement) Desc(colNames ...string) *Statement {
if len(statement.OrderStr) > 0 {
fmt.Fprint(&buf, statement.OrderStr, ", ")
}
newColNames := statement.col2NewColsWithQuote(colNames...)
fmt.Fprintf(&buf, "%v DESC", strings.Join(newColNames, " DESC, "))
for i, col := range colNames {
if i > 0 {
fmt.Fprint(&buf, ", ")
}
statement.dialect.Quoter().QuoteTo(&buf, col)
fmt.Fprint(&buf, " DESC")
}
statement.OrderStr = buf.String()
return statement
}
@ -707,8 +705,13 @@ func (statement *Statement) Asc(colNames ...string) *Statement {
if len(statement.OrderStr) > 0 {
fmt.Fprint(&buf, statement.OrderStr, ", ")
}
newColNames := statement.col2NewColsWithQuote(colNames...)
fmt.Fprintf(&buf, "%v ASC", strings.Join(newColNames, " ASC, "))
for i, col := range colNames {
if i > 0 {
fmt.Fprint(&buf, ", ")
}
statement.dialect.Quoter().QuoteTo(&buf, col)
fmt.Fprint(&buf, " ASC")
}
statement.OrderStr = buf.String()
return statement
}

View File

@ -166,15 +166,17 @@ func (TestType) TableName() string {
func createTestStatement() *Statement {
if engine, ok := testEngine.(*Engine); ok {
statement := &Statement{}
statement.Init()
statement.Reset()
statement.Engine = engine
statement.dialect = engine.dialect
statement.setRefValue(reflect.ValueOf(TestType{}))
return statement
} else if eg, ok := testEngine.(*EngineGroup); ok {
statement := &Statement{}
statement.Init()
statement.Reset()
statement.Engine = eg.Engine
statement.dialect = eg.Engine.dialect
statement.setRefValue(reflect.ValueOf(TestType{}))
return statement
@ -232,17 +234,12 @@ func TestUpdateIgnoreOnlyFromDBFields(t *testing.T) {
})
assert.NoError(t, err)
record := assertGetRecord()
record.OnlyFromDBField = "test"
testEngine.Update(record)
assertGetRecord()
_, err = testEngine.ID(1).Update(&TestOnlyFromDBField{
OnlyToDBField: "b",
OnlyFromDBField: "test",
})
assert.NoError(t, err)
assertGetRecord()
}
func TestCol2NewColsWithQuote(t *testing.T) {
cols := []string{"f1", "f2", "t3.f3"}
statement := createTestStatement()
quotedCols := statement.col2NewColsWithQuote(cols...)
assert.EqualValues(t, []string{statement.Engine.Quote("f1"), statement.Engine.Quote("f2"), statement.Engine.Quote("t3.f3")}, quotedCols)
}