Improve statement #1549
|
@ -130,7 +130,7 @@ func NewSimpleLogger3(out io.Writer, prefix string, flag int, l LogLevel) *Simpl
|
|||
// Error implement ILogger
|
||||
func (s *SimpleLogger) Error(v ...interface{}) {
|
||||
if s.level <= LOG_ERR {
|
||||
s.ERR.Output(2, fmt.Sprint(v...))
|
||||
s.ERR.Output(2, fmt.Sprintln(v...))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -146,7 +146,7 @@ func (s *SimpleLogger) Errorf(format string, v ...interface{}) {
|
|||
// Debug implement ILogger
|
||||
func (s *SimpleLogger) Debug(v ...interface{}) {
|
||||
if s.level <= LOG_DEBUG {
|
||||
s.DEBUG.Output(2, fmt.Sprint(v...))
|
||||
s.DEBUG.Output(2, fmt.Sprintln(v...))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -162,7 +162,7 @@ func (s *SimpleLogger) Debugf(format string, v ...interface{}) {
|
|||
// Info implement ILogger
|
||||
func (s *SimpleLogger) Info(v ...interface{}) {
|
||||
if s.level <= LOG_INFO {
|
||||
s.INFO.Output(2, fmt.Sprint(v...))
|
||||
s.INFO.Output(2, fmt.Sprintln(v...))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -178,7 +178,7 @@ func (s *SimpleLogger) Infof(format string, v ...interface{}) {
|
|||
// Warn implement ILogger
|
||||
func (s *SimpleLogger) Warn(v ...interface{}) {
|
||||
if s.level <= LOG_WARNING {
|
||||
s.WARN.Output(2, fmt.Sprint(v...))
|
||||
s.WARN.Output(2, fmt.Sprintln(v...))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
|
@ -93,17 +93,25 @@ func (q Quoter) Join(a []string, sep string) string {
|
|||
if i > 0 {
|
||||
b.WriteString(sep)
|
||||
}
|
||||
if q[0] != "" {
|
||||
if q[0] != "" && s != "*" {
|
||||
b.WriteString(q[0])
|
||||
}
|
||||
b.WriteString(strings.TrimSpace(s))
|
||||
if q[1] != "" {
|
||||
if q[1] != "" && s != "*" {
|
||||
b.WriteString(q[1])
|
||||
}
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (q Quoter) Strings(s []string) []string {
|
||||
var res = make([]string, 0, len(s))
|
||||
for _, a := range s {
|
||||
res = append(res, q.Quote(a))
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func (q Quoter) QuoteTo(buf *strings.Builder, value string) {
|
||||
if q.IsEmpty() {
|
||||
buf.WriteString(value)
|
||||
|
|
|
@ -55,3 +55,11 @@ func TestJoin(t *testing.T) {
|
|||
quoter = Quoter{"", ""}
|
||||
assert.EqualValues(t, "f1, f2, f3", quoter.Join(cols, ", "))
|
||||
}
|
||||
|
||||
func TestStrings(t *testing.T) {
|
||||
cols := []string{"f1", "f2", "t3.f3"}
|
||||
quoter := Quoter{"[", "]"}
|
||||
|
||||
quotedCols := quoter.Strings(cols)
|
||||
assert.EqualValues(t, []string{"[f1]", "[f2]", "[t3].[f3]"}, quotedCols)
|
||||
}
|
||||
|
|
|
@ -72,7 +72,8 @@ func (session *Session) Clone() *Session {
|
|||
|
||||
// Init reset the session as the init status.
|
||||
func (session *Session) Init() {
|
||||
session.statement.Init()
|
||||
session.statement.Reset()
|
||||
session.statement.dialect = session.engine.dialect
|
||||
session.statement.Engine = session.engine
|
||||
session.showSQL = session.engine.showSQL
|
||||
session.isAutoCommit = true
|
||||
|
@ -128,7 +129,7 @@ func (session *Session) IsClosed() bool {
|
|||
|
||||
func (session *Session) resetStatement() {
|
||||
if session.autoResetStatement {
|
||||
session.statement.Init()
|
||||
session.statement.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -34,7 +34,7 @@ func (session *Session) Exist(bean ...interface{}) (bool, error) {
|
|||
return false, ErrTableNotFound
|
||||
}
|
||||
|
||||
tableName = session.statement.Engine.Quote(tableName)
|
||||
tableName = session.statement.quote(tableName)
|
||||
if len(session.statement.JoinStr) > 0 {
|
||||
joinStr = session.statement.JoinStr
|
||||
}
|
||||
|
|
|
@ -282,7 +282,7 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
|
|||
tableName := session.statement.TableName()
|
||||
cacher := session.engine.getCacher(tableName)
|
||||
|
||||
session.engine.logger.Debug("[cacheGet] find sql:", newsql, args)
|
||||
session.engine.logger.Debug("[cache] Get SQL:", newsql, args)
|
||||
table := session.statement.RefTable
|
||||
ids, err := caches.GetCacheSql(cacher, tableName, newsql, args)
|
||||
if err != nil {
|
||||
|
@ -318,19 +318,19 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
|
|||
}
|
||||
|
||||
ids = []schemas.PK{pk}
|
||||
session.engine.logger.Debug("[cacheGet] cache ids:", newsql, ids)
|
||||
session.engine.logger.Debug("[cache] cache ids:", newsql, ids)
|
||||
err = caches.PutCacheSql(cacher, ids, tableName, newsql, args)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
} else {
|
||||
session.engine.logger.Debug("[cacheGet] cache hit sql:", newsql, ids)
|
||||
session.engine.logger.Debug("[cache] cache hit:", newsql, ids)
|
||||
}
|
||||
|
||||
if len(ids) > 0 {
|
||||
structValue := reflect.Indirect(reflect.ValueOf(bean))
|
||||
id := ids[0]
|
||||
session.engine.logger.Debug("[cacheGet] get bean:", tableName, id)
|
||||
session.engine.logger.Debug("[cache] get bean:", tableName, id)
|
||||
sid, err := id.ToString()
|
||||
if err != nil {
|
||||
return false, err
|
||||
|
@ -343,10 +343,10 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
|
|||
return has, err
|
||||
}
|
||||
|
||||
session.engine.logger.Debug("[cacheGet] cache bean:", tableName, id, cacheBean)
|
||||
session.engine.logger.Debug("[cache] cache bean:", tableName, id, cacheBean)
|
||||
cacher.PutBean(tableName, sid, cacheBean)
|
||||
} else {
|
||||
session.engine.logger.Debug("[cacheGet] cache hit bean:", tableName, id, cacheBean)
|
||||
session.engine.logger.Debug("[cache] cache hit:", tableName, id, cacheBean)
|
||||
has = true
|
||||
}
|
||||
structValue.Set(reflect.Indirect(reflect.ValueOf(cacheBean)))
|
||||
|
|
|
@ -341,9 +341,9 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
|||
var top string
|
||||
if st.LimitN != nil {
|
||||
limitValue := *st.LimitN
|
||||
if st.Engine.dialect.DBType() == schemas.MYSQL {
|
||||
if st.dialect.DBType() == schemas.MYSQL {
|
||||
condSQL = condSQL + fmt.Sprintf(" LIMIT %d", limitValue)
|
||||
} else if st.Engine.dialect.DBType() == schemas.SQLITE {
|
||||
} else if st.dialect.DBType() == schemas.SQLITE {
|
||||
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue)
|
||||
cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)",
|
||||
session.engine.Quote(tableName), tempCondSQL), condArgs...))
|
||||
|
@ -354,7 +354,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
|||
if len(condSQL) > 0 {
|
||||
condSQL = "WHERE " + condSQL
|
||||
}
|
||||
} else if st.Engine.dialect.DBType() == schemas.POSTGRES {
|
||||
} else if st.dialect.DBType() == schemas.POSTGRES {
|
||||
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue)
|
||||
cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)",
|
||||
session.engine.Quote(tableName), tempCondSQL), condArgs...))
|
||||
|
@ -366,8 +366,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
|||
if len(condSQL) > 0 {
|
||||
condSQL = "WHERE " + condSQL
|
||||
}
|
||||
} else if st.Engine.dialect.DBType() == schemas.MSSQL {
|
||||
if st.OrderStr != "" && st.Engine.dialect.DBType() == schemas.MSSQL &&
|
||||
} else if st.dialect.DBType() == schemas.MSSQL {
|
||||
if st.OrderStr != "" && st.dialect.DBType() == schemas.MSSQL &&
|
||||
table != nil && len(table.PrimaryKeys) == 1 {
|
||||
cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)",
|
||||
table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0],
|
||||
|
|
79
statement.go
79
statement.go
|
@ -20,6 +20,7 @@ import (
|
|||
// Statement save all the sql info for executing SQL
|
||||
type Statement struct {
|
||||
RefTable *schemas.Table
|
||||
dialect dialects.Dialect
|
||||
Engine *Engine
|
||||
Start int
|
||||
LimitN *int
|
||||
|
@ -32,7 +33,6 @@ type Statement struct {
|
|||
ColumnStr string
|
||||
selectStr string
|
||||
useAllCols bool
|
||||
OmitStr string
|
||||
AltTableName string
|
||||
tableName string
|
||||
RawSQL string
|
||||
|
@ -63,8 +63,20 @@ type Statement struct {
|
|||
lastError error
|
||||
}
|
||||
|
||||
func newStatement(dialect dialects.Dialect) *Statement {
|
||||
statement := &Statement{
|
||||
dialect: dialect,
|
||||
}
|
||||
statement.Reset()
|
||||
return statement
|
||||
}
|
||||
|
||||
func (statement *Statement) omitStr() string {
|
||||
return statement.dialect.Quoter().Join(statement.omitColumnMap, " ,")
|
||||
}
|
||||
|
||||
// Init reset all the statement's fields
|
||||
func (statement *Statement) Init() {
|
||||
func (statement *Statement) Reset() {
|
||||
statement.RefTable = nil
|
||||
statement.Start = 0
|
||||
statement.LimitN = nil
|
||||
|
@ -75,7 +87,6 @@ func (statement *Statement) Init() {
|
|||
statement.GroupByStr = ""
|
||||
statement.HavingStr = ""
|
||||
statement.ColumnStr = ""
|
||||
statement.OmitStr = ""
|
||||
statement.columnMap = columnMap{}
|
||||
statement.omitColumnMap = columnMap{}
|
||||
statement.AltTableName = ""
|
||||
|
@ -144,6 +155,10 @@ func (statement *Statement) Where(query interface{}, args ...interface{}) *State
|
|||
return statement.And(query, args...)
|
||||
}
|
||||
|
||||
func (statement *Statement) quote(s string) string {
|
||||
return statement.dialect.Quoter().Quote(s)
|
||||
}
|
||||
|
||||
// And add Where & and statement
|
||||
func (statement *Statement) And(query interface{}, args ...interface{}) *Statement {
|
||||
switch query.(type) {
|
||||
|
@ -154,7 +169,7 @@ func (statement *Statement) And(query interface{}, args ...interface{}) *Stateme
|
|||
queryMap := query.(map[string]interface{})
|
||||
newMap := make(map[string]interface{})
|
||||
for k, v := range queryMap {
|
||||
newMap[statement.Engine.Quote(k)] = v
|
||||
newMap[statement.quote(k)] = v
|
||||
}
|
||||
statement.cond = statement.cond.And(builder.Eq(newMap))
|
||||
case builder.Cond:
|
||||
|
@ -197,14 +212,14 @@ func (statement *Statement) Or(query interface{}, args ...interface{}) *Statemen
|
|||
|
||||
// In generate "Where column IN (?) " statement
|
||||
func (statement *Statement) In(column string, args ...interface{}) *Statement {
|
||||
in := builder.In(statement.Engine.Quote(column), args...)
|
||||
in := builder.In(statement.quote(column), args...)
|
||||
statement.cond = statement.cond.And(in)
|
||||
return statement
|
||||
}
|
||||
|
||||
// NotIn generate "Where column NOT IN (?) " statement
|
||||
func (statement *Statement) NotIn(column string, args ...interface{}) *Statement {
|
||||
notIn := builder.NotIn(statement.Engine.Quote(column), args...)
|
||||
notIn := builder.NotIn(statement.quote(column), args...)
|
||||
statement.cond = statement.cond.And(notIn)
|
||||
return statement
|
||||
}
|
||||
|
@ -341,7 +356,7 @@ func (statement *Statement) buildUpdates(bean interface{},
|
|||
if fieldValue.IsNil() {
|
||||
if includeNil {
|
||||
args = append(args, nil)
|
||||
colNames = append(colNames, fmt.Sprintf("%v=?", engine.Quote(col.Name)))
|
||||
colNames = append(colNames, fmt.Sprintf("%v=?", statement.quote(col.Name)))
|
||||
}
|
||||
continue
|
||||
} else if !fieldValue.IsValid() {
|
||||
|
@ -485,10 +500,10 @@ func (statement *Statement) buildUpdates(bean interface{},
|
|||
|
||||
APPEND:
|
||||
args = append(args, val)
|
||||
if col.IsPrimaryKey && engine.dialect.DBType() == "ql" {
|
||||
if col.IsPrimaryKey {
|
||||
continue
|
||||
}
|
||||
colNames = append(colNames, fmt.Sprintf("%v = ?", engine.Quote(col.Name)))
|
||||
colNames = append(colNames, fmt.Sprintf("%v = ?", statement.quote(col.Name)))
|
||||
}
|
||||
|
||||
return colNames, args
|
||||
|
@ -504,9 +519,9 @@ func (statement *Statement) colName(col *schemas.Column, tableName string) strin
|
|||
if len(statement.TableAlias) > 0 {
|
||||
nm = statement.TableAlias
|
||||
}
|
||||
return statement.Engine.Quote(nm) + "." + statement.Engine.Quote(col.Name)
|
||||
return statement.quote(nm) + "." + statement.quote(col.Name)
|
||||
}
|
||||
return statement.Engine.Quote(col.Name)
|
||||
return statement.quote(col.Name)
|
||||
}
|
||||
|
||||
// TableName return current tableName
|
||||
|
@ -572,24 +587,6 @@ func (statement *Statement) SetExpr(column string, expression interface{}) *Stat
|
|||
return statement
|
||||
}
|
||||
|
||||
func (statement *Statement) col2NewColsWithQuote(columns ...string) []string {
|
||||
newColumns := make([]string, 0)
|
||||
quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`")
|
||||
for _, col := range columns {
|
||||
newColumns = append(newColumns, statement.Engine.Quote(eraseAny(col, quotes...)))
|
||||
}
|
||||
return newColumns
|
||||
}
|
||||
|
||||
func (statement *Statement) colmap2NewColsWithQuote() []string {
|
||||
newColumns := make([]string, len(statement.columnMap), len(statement.columnMap))
|
||||
copy(newColumns, statement.columnMap)
|
||||
for i := 0; i < len(statement.columnMap); i++ {
|
||||
newColumns[i] = statement.Engine.Quote(newColumns[i])
|
||||
}
|
||||
return newColumns
|
||||
}
|
||||
|
||||
// Distinct generates "DISTINCT col1, col2 " statement
|
||||
func (statement *Statement) Distinct(columns ...string) *Statement {
|
||||
statement.IsDistinct = true
|
||||
|
@ -616,10 +613,7 @@ func (statement *Statement) Cols(columns ...string) *Statement {
|
|||
statement.columnMap.add(nc)
|
||||
}
|
||||
|
||||
newColumns := statement.colmap2NewColsWithQuote()
|
||||
|
||||
statement.ColumnStr = strings.Join(newColumns, ", ")
|
||||
statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.dialect.Quoter().Quote("*"), "*", -1)
|
||||
statement.ColumnStr = statement.dialect.Quoter().Join(statement.columnMap, ", ")
|
||||
return statement
|
||||
}
|
||||
|
||||
|
@ -654,7 +648,6 @@ func (statement *Statement) Omit(columns ...string) {
|
|||
for _, nc := range newColumns {
|
||||
statement.omitColumnMap = append(statement.omitColumnMap, nc)
|
||||
}
|
||||
statement.OmitStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", ")))
|
||||
}
|
||||
|
||||
// Nullable Update use only: update columns to null when value is nullable and zero-value
|
||||
|
@ -695,8 +688,13 @@ func (statement *Statement) Desc(colNames ...string) *Statement {
|
|||
if len(statement.OrderStr) > 0 {
|
||||
fmt.Fprint(&buf, statement.OrderStr, ", ")
|
||||
}
|
||||
newColNames := statement.col2NewColsWithQuote(colNames...)
|
||||
fmt.Fprintf(&buf, "%v DESC", strings.Join(newColNames, " DESC, "))
|
||||
for i, col := range colNames {
|
||||
if i > 0 {
|
||||
fmt.Fprint(&buf, ", ")
|
||||
}
|
||||
statement.dialect.Quoter().QuoteTo(&buf, col)
|
||||
fmt.Fprint(&buf, " DESC")
|
||||
}
|
||||
statement.OrderStr = buf.String()
|
||||
return statement
|
||||
}
|
||||
|
@ -707,8 +705,13 @@ func (statement *Statement) Asc(colNames ...string) *Statement {
|
|||
if len(statement.OrderStr) > 0 {
|
||||
fmt.Fprint(&buf, statement.OrderStr, ", ")
|
||||
}
|
||||
newColNames := statement.col2NewColsWithQuote(colNames...)
|
||||
fmt.Fprintf(&buf, "%v ASC", strings.Join(newColNames, " ASC, "))
|
||||
for i, col := range colNames {
|
||||
if i > 0 {
|
||||
fmt.Fprint(&buf, ", ")
|
||||
}
|
||||
statement.dialect.Quoter().QuoteTo(&buf, col)
|
||||
fmt.Fprint(&buf, " ASC")
|
||||
}
|
||||
statement.OrderStr = buf.String()
|
||||
return statement
|
||||
}
|
||||
|
|
|
@ -166,15 +166,17 @@ func (TestType) TableName() string {
|
|||
func createTestStatement() *Statement {
|
||||
if engine, ok := testEngine.(*Engine); ok {
|
||||
statement := &Statement{}
|
||||
statement.Init()
|
||||
statement.Reset()
|
||||
statement.Engine = engine
|
||||
statement.dialect = engine.dialect
|
||||
statement.setRefValue(reflect.ValueOf(TestType{}))
|
||||
|
||||
return statement
|
||||
} else if eg, ok := testEngine.(*EngineGroup); ok {
|
||||
statement := &Statement{}
|
||||
statement.Init()
|
||||
statement.Reset()
|
||||
statement.Engine = eg.Engine
|
||||
statement.dialect = eg.Engine.dialect
|
||||
statement.setRefValue(reflect.ValueOf(TestType{}))
|
||||
|
||||
return statement
|
||||
|
@ -232,17 +234,12 @@ func TestUpdateIgnoreOnlyFromDBFields(t *testing.T) {
|
|||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
record := assertGetRecord()
|
||||
record.OnlyFromDBField = "test"
|
||||
testEngine.Update(record)
|
||||
assertGetRecord()
|
||||
|
||||
_, err = testEngine.ID(1).Update(&TestOnlyFromDBField{
|
||||
OnlyToDBField: "b",
|
||||
OnlyFromDBField: "test",
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assertGetRecord()
|
||||
}
|
||||
|
||||
func TestCol2NewColsWithQuote(t *testing.T) {
|
||||
cols := []string{"f1", "f2", "t3.f3"}
|
||||
|
||||
statement := createTestStatement()
|
||||
|
||||
quotedCols := statement.col2NewColsWithQuote(cols...)
|
||||
assert.EqualValues(t, []string{statement.Engine.Quote("f1"), statement.Engine.Quote("f2"), statement.Engine.Quote("t3.f3")}, quotedCols)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user