Fix break session sql enable feature #1566

Merged
lunny merged 1 commits from lunny/fix_session_sql_log into master 2020-03-01 03:05:34 +00:00
8 changed files with 89 additions and 68 deletions

View File

@ -83,7 +83,7 @@ type DB struct {
Mapper names.Mapper Mapper names.Mapper
reflectCache map[reflect.Type]*cacheStruct reflectCache map[reflect.Type]*cacheStruct
reflectCacheMutex sync.RWMutex reflectCacheMutex sync.RWMutex
Logger log.SQLLogger Logger log.ContextLogger
} }
// Open opens a database // Open opens a database
@ -108,6 +108,19 @@ func FromDB(db *sql.DB) *DB {
} }
} }
// NeedLogSQL returns true if need to log SQL
func (db *DB) NeedLogSQL(ctx context.Context) bool {
if db.Logger == nil {
return false
}
v := ctx.Value("__xorm_show_sql")
if showSQL, ok := v.(bool); ok {
return showSQL
}
return db.Logger.IsShowSQL()
}
func (db *DB) reflectNew(typ reflect.Type) reflect.Value { func (db *DB) reflectNew(typ reflect.Type) reflect.Value {
db.reflectCacheMutex.Lock() db.reflectCacheMutex.Lock()
defer db.reflectCacheMutex.Unlock() defer db.reflectCacheMutex.Unlock()
@ -124,7 +137,8 @@ func (db *DB) reflectNew(typ reflect.Type) reflect.Value {
// QueryContext overwrites sql.DB.QueryContext // QueryContext overwrites sql.DB.QueryContext
func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
start := time.Now() start := time.Now()
if db.Logger != nil { showSQL := db.NeedLogSQL(ctx)
if showSQL {
db.Logger.BeforeSQL(log.LogContext{ db.Logger.BeforeSQL(log.LogContext{
Ctx: ctx, Ctx: ctx,
SQL: query, SQL: query,
@ -132,7 +146,7 @@ func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{
}) })
} }
rows, err := db.DB.QueryContext(ctx, query, args...) rows, err := db.DB.QueryContext(ctx, query, args...)
if db.Logger != nil { if showSQL {
db.Logger.AfterSQL(log.LogContext{ db.Logger.AfterSQL(log.LogContext{
Ctx: ctx, Ctx: ctx,
SQL: query, SQL: query,
@ -246,7 +260,8 @@ func (db *DB) ExecStructContext(ctx context.Context, query string, st interface{
func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
start := time.Now() start := time.Now()
if db.Logger != nil { showSQL := db.NeedLogSQL(ctx)
if showSQL {
db.Logger.BeforeSQL(log.LogContext{ db.Logger.BeforeSQL(log.LogContext{
Ctx: ctx, Ctx: ctx,
SQL: query, SQL: query,
@ -254,7 +269,7 @@ func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}
}) })
} }
res, err := db.DB.ExecContext(ctx, query, args...) res, err := db.DB.ExecContext(ctx, query, args...)
if db.Logger != nil { if showSQL {
db.Logger.AfterSQL(log.LogContext{ db.Logger.AfterSQL(log.LogContext{
Ctx: ctx, Ctx: ctx,
SQL: query, SQL: query,

View File

@ -32,14 +32,15 @@ func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
}) })
start := time.Now() start := time.Now()
if db.Logger != nil { showSQL := db.NeedLogSQL(ctx)
if showSQL {
db.Logger.BeforeSQL(log.LogContext{ db.Logger.BeforeSQL(log.LogContext{
Ctx: ctx, Ctx: ctx,
SQL: "PREPARE", SQL: "PREPARE",
}) })
} }
stmt, err := db.DB.PrepareContext(ctx, query) stmt, err := db.DB.PrepareContext(ctx, query)
if db.Logger != nil { if showSQL {
db.Logger.AfterSQL(log.LogContext{ db.Logger.AfterSQL(log.LogContext{
Ctx: ctx, Ctx: ctx,
SQL: "PREPARE", SQL: "PREPARE",
@ -94,7 +95,8 @@ func (s *Stmt) ExecStruct(st interface{}) (sql.Result, error) {
func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) { func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) {
start := time.Now() start := time.Now()
if s.db.Logger != nil { showSQL := s.db.NeedLogSQL(ctx)
if showSQL {
s.db.Logger.BeforeSQL(log.LogContext{ s.db.Logger.BeforeSQL(log.LogContext{
Ctx: ctx, Ctx: ctx,
SQL: s.query, SQL: s.query,
@ -102,7 +104,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result
}) })
} }
res, err := s.Stmt.ExecContext(ctx, args) res, err := s.Stmt.ExecContext(ctx, args)
if s.db.Logger != nil { if showSQL {
s.db.Logger.AfterSQL(log.LogContext{ s.db.Logger.AfterSQL(log.LogContext{
Ctx: ctx, Ctx: ctx,
SQL: s.query, SQL: s.query,
@ -116,7 +118,8 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result
func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, error) { func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, error) {
start := time.Now() start := time.Now()
if s.db.Logger != nil { showSQL := s.db.NeedLogSQL(ctx)
if showSQL {
s.db.Logger.BeforeSQL(log.LogContext{ s.db.Logger.BeforeSQL(log.LogContext{
Ctx: ctx, Ctx: ctx,
SQL: s.query, SQL: s.query,
@ -124,7 +127,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, er
}) })
} }
rows, err := s.Stmt.QueryContext(ctx, args...) rows, err := s.Stmt.QueryContext(ctx, args...)
if s.db.Logger != nil { if showSQL {
s.db.Logger.AfterSQL(log.LogContext{ s.db.Logger.AfterSQL(log.LogContext{
Ctx: ctx, Ctx: ctx,
SQL: s.query, SQL: s.query,

View File

@ -19,14 +19,15 @@ type Tx struct {
func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) {
start := time.Now() start := time.Now()
if db.Logger != nil { showSQL := db.NeedLogSQL(ctx)
if showSQL {
db.Logger.BeforeSQL(log.LogContext{ db.Logger.BeforeSQL(log.LogContext{
Ctx: ctx, Ctx: ctx,
SQL: "BEGIN TRANSACTION", SQL: "BEGIN TRANSACTION",
}) })
} }
tx, err := db.DB.BeginTx(ctx, opts) tx, err := db.DB.BeginTx(ctx, opts)
if db.Logger != nil { if showSQL {
db.Logger.AfterSQL(log.LogContext{ db.Logger.AfterSQL(log.LogContext{
Ctx: ctx, Ctx: ctx,
SQL: "BEGIN TRANSACTION", SQL: "BEGIN TRANSACTION",
@ -54,14 +55,15 @@ func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
}) })
start := time.Now() start := time.Now()
if tx.db.Logger != nil { showSQL := tx.db.NeedLogSQL(ctx)
if showSQL {
tx.db.Logger.BeforeSQL(log.LogContext{ tx.db.Logger.BeforeSQL(log.LogContext{
Ctx: ctx, Ctx: ctx,
SQL: "PREPARE", SQL: "PREPARE",
}) })
} }
stmt, err := tx.Tx.PrepareContext(ctx, query) stmt, err := tx.Tx.PrepareContext(ctx, query)
if tx.db.Logger != nil { if showSQL {
tx.db.Logger.AfterSQL(log.LogContext{ tx.db.Logger.AfterSQL(log.LogContext{
Ctx: ctx, Ctx: ctx,
SQL: "PREPARE", SQL: "PREPARE",
@ -110,7 +112,8 @@ func (tx *Tx) ExecStructContext(ctx context.Context, query string, st interface{
func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
start := time.Now() start := time.Now()
if tx.db.Logger != nil { showSQL := tx.db.NeedLogSQL(ctx)
if showSQL {
tx.db.Logger.BeforeSQL(log.LogContext{ tx.db.Logger.BeforeSQL(log.LogContext{
Ctx: ctx, Ctx: ctx,
SQL: query, SQL: query,
@ -118,7 +121,7 @@ func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}
}) })
} }
res, err := tx.Tx.ExecContext(ctx, query, args...) res, err := tx.Tx.ExecContext(ctx, query, args...)
if tx.db.Logger != nil { if showSQL {
tx.db.Logger.AfterSQL(log.LogContext{ tx.db.Logger.AfterSQL(log.LogContext{
Ctx: ctx, Ctx: ctx,
SQL: query, SQL: query,
@ -136,7 +139,8 @@ func (tx *Tx) ExecStruct(query string, st interface{}) (sql.Result, error) {
func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
start := time.Now() start := time.Now()
if tx.db.Logger != nil { showSQL := tx.db.NeedLogSQL(ctx)
if showSQL {
tx.db.Logger.BeforeSQL(log.LogContext{ tx.db.Logger.BeforeSQL(log.LogContext{
Ctx: ctx, Ctx: ctx,
SQL: query, SQL: query,
@ -144,7 +148,7 @@ func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{
}) })
} }
rows, err := tx.Tx.QueryContext(ctx, query, args...) rows, err := tx.Tx.QueryContext(ctx, query, args...)
if tx.db.Logger != nil { if showSQL {
tx.db.Logger.AfterSQL(log.LogContext{ tx.db.Logger.AfterSQL(log.LogContext{
Ctx: ctx, Ctx: ctx,
SQL: query, SQL: query,

View File

@ -61,11 +61,7 @@ func (engine *Engine) BufferSize(size int) *Session {
// ShowSQL show SQL statement or not on logger if log level is great than INFO // ShowSQL show SQL statement or not on logger if log level is great than INFO
func (engine *Engine) ShowSQL(show ...bool) { func (engine *Engine) ShowSQL(show ...bool) {
engine.logger.ShowSQL(show...) engine.logger.ShowSQL(show...)
if engine.logger.IsShowSQL() { engine.db.Logger = engine.logger
engine.db.Logger = engine.logger
} else {
engine.db.Logger = &log.DiscardSQLLogger{}
}
} }
// Logger return the logger interface // Logger return the logger interface
@ -83,11 +79,7 @@ func (engine *Engine) SetLogger(logger interface{}) {
realLogger = t realLogger = t
} }
engine.logger = realLogger engine.logger = realLogger
if realLogger.IsShowSQL() { engine.db.Logger = realLogger
engine.db.Logger = realLogger
} else {
engine.db.Logger = &log.DiscardSQLLogger{}
}
} }
// SetLogLevel sets the logger level // SetLogLevel sets the logger level

View File

@ -19,17 +19,10 @@ type LogContext struct {
} }
type SQLLogger interface { type SQLLogger interface {
BeforeSQL(context LogContext) BeforeSQL(context LogContext) // only invoked when IsShowSQL is true
AfterSQL(context LogContext) AfterSQL(context LogContext) // only invoked when IsShowSQL is true
} }
type DiscardSQLLogger struct{}
var _ SQLLogger = &DiscardSQLLogger{}
func (DiscardSQLLogger) BeforeSQL(LogContext) {}
func (DiscardSQLLogger) AfterSQL(LogContext) {}
// ContextLogger represents a logger interface with context // ContextLogger represents a logger interface with context
type ContextLogger interface { type ContextLogger interface {
SQLLogger SQLLogger
@ -64,10 +57,6 @@ func NewLoggerAdapter(logger Logger) ContextLogger {
func (l *LoggerAdapter) BeforeSQL(ctx LogContext) {} func (l *LoggerAdapter) BeforeSQL(ctx LogContext) {}
func (l *LoggerAdapter) AfterSQL(ctx LogContext) { func (l *LoggerAdapter) AfterSQL(ctx LogContext) {
if !l.logger.IsShowSQL() {
return
}
if ctx.ExecuteTime > 0 { if ctx.ExecuteTime > 0 {
l.logger.Infof("[SQL] %v %v - %v", ctx.SQL, ctx.Args, ctx.ExecuteTime) l.logger.Infof("[SQL] %v %v - %v", ctx.SQL, ctx.Args, ctx.ExecuteTime)
} else { } else {

View File

@ -58,8 +58,6 @@ type Session struct {
prepareStmt bool prepareStmt bool
stmtCache map[uint32]*core.Stmt //key: hash.Hash32 of (queryStr, len(queryStr)) stmtCache map[uint32]*core.Stmt //key: hash.Hash32 of (queryStr, len(queryStr))
// !evalphobia! stored the last executed query on this session
//beforeSQLExec func(string, ...interface{})
lastSQL string lastSQL string
lastSQLArgs []interface{} lastSQLArgs []interface{}
showSQL bool showSQL bool
@ -82,7 +80,6 @@ func (session *Session) Init() {
session.engine.DatabaseTZ, session.engine.DatabaseTZ,
) )
//session.showSQL = session.engine.showSQL
session.isAutoCommit = true session.isAutoCommit = true
session.isCommitedOrRollbacked = false session.isCommitedOrRollbacked = false
session.isAutoClose = false session.isAutoClose = false
@ -241,11 +238,11 @@ func (session *Session) Cascade(trueOrFalse ...bool) *Session {
// MustLogSQL means record SQL or not and don't follow engine's setting // MustLogSQL means record SQL or not and don't follow engine's setting
func (session *Session) MustLogSQL(log ...bool) *Session { func (session *Session) MustLogSQL(log ...bool) *Session {
var showSQL = true
if len(log) > 0 { if len(log) > 0 {
session.showSQL = log[0] showSQL = log[0]
} else {
session.showSQL = true
} }
session.ctx = context.WithValue(session.ctx, "__xorm_show_sql", showSQL)
return session return session
} }

View File

@ -43,3 +43,14 @@ func TestNullFloatStruct(t *testing.T) {
}) })
assert.NoError(t, err) assert.NoError(t, err)
} }
func TestMustLogSQL(t *testing.T) {
assert.NoError(t, prepareEngine())
testEngine.ShowSQL(false)
defer testEngine.ShowSQL(true)
assertSync(t, new(Userinfo))
_, err := testEngine.Table("userinfo").MustLogSQL(true).Get(new(Userinfo))
assert.NoError(t, err)
}

View File

@ -34,17 +34,22 @@ func (session *Session) Rollback() error {
session.isAutoCommit = true session.isAutoCommit = true
start := time.Now() start := time.Now()
session.engine.logger.BeforeSQL(log.LogContext{ needSQL := session.engine.db.NeedLogSQL(session.ctx)
Ctx: session.ctx, if needSQL {
SQL: "ROLL BACK", session.engine.logger.BeforeSQL(log.LogContext{
}) Ctx: session.ctx,
SQL: "ROLL BACK",
})
}
err := session.tx.Rollback() err := session.tx.Rollback()
session.engine.logger.AfterSQL(log.LogContext{ if needSQL {
Ctx: session.ctx, session.engine.logger.AfterSQL(log.LogContext{
SQL: "ROLL BACK", Ctx: session.ctx,
ExecuteTime: time.Now().Sub(start), SQL: "ROLL BACK",
Err: err, ExecuteTime: time.Now().Sub(start),
}) Err: err,
})
}
return err return err
} }
return nil return nil
@ -58,17 +63,22 @@ func (session *Session) Commit() error {
session.isAutoCommit = true session.isAutoCommit = true
start := time.Now() start := time.Now()
session.engine.logger.BeforeSQL(log.LogContext{ needSQL := session.engine.db.NeedLogSQL(session.ctx)
Ctx: session.ctx, if needSQL {
SQL: "COMMIT", session.engine.logger.BeforeSQL(log.LogContext{
}) Ctx: session.ctx,
SQL: "COMMIT",
})
}
err := session.tx.Commit() err := session.tx.Commit()
session.engine.logger.AfterSQL(log.LogContext{ if needSQL {
Ctx: session.ctx, session.engine.logger.AfterSQL(log.LogContext{
SQL: "COMMIT", Ctx: session.ctx,
ExecuteTime: time.Now().Sub(start), SQL: "COMMIT",
Err: err, ExecuteTime: time.Now().Sub(start),
}) Err: err,
})
}
if err != nil { if err != nil {
return err return err