diff --git a/core/db.go b/core/db.go index 592ccf18..9aa771ba 100644 --- a/core/db.go +++ b/core/db.go @@ -83,7 +83,7 @@ type DB struct { Mapper names.Mapper reflectCache map[reflect.Type]*cacheStruct reflectCacheMutex sync.RWMutex - Logger log.SQLLogger + Logger log.ContextLogger } // Open opens a database @@ -108,6 +108,19 @@ func FromDB(db *sql.DB) *DB { } } +// NeedLogSQL returns true if need to log SQL +func (db *DB) NeedLogSQL(ctx context.Context) bool { + if db.Logger == nil { + return false + } + + v := ctx.Value("__xorm_show_sql") + if showSQL, ok := v.(bool); ok { + return showSQL + } + return db.Logger.IsShowSQL() +} + func (db *DB) reflectNew(typ reflect.Type) reflect.Value { db.reflectCacheMutex.Lock() defer db.reflectCacheMutex.Unlock() @@ -124,7 +137,8 @@ func (db *DB) reflectNew(typ reflect.Type) reflect.Value { // QueryContext overwrites sql.DB.QueryContext func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { start := time.Now() - if db.Logger != nil { + showSQL := db.NeedLogSQL(ctx) + if showSQL { db.Logger.BeforeSQL(log.LogContext{ Ctx: ctx, SQL: query, @@ -132,7 +146,7 @@ func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{ }) } rows, err := db.DB.QueryContext(ctx, query, args...) - if db.Logger != nil { + if showSQL { db.Logger.AfterSQL(log.LogContext{ Ctx: ctx, SQL: query, @@ -246,7 +260,8 @@ func (db *DB) ExecStructContext(ctx context.Context, query string, st interface{ func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { start := time.Now() - if db.Logger != nil { + showSQL := db.NeedLogSQL(ctx) + if showSQL { db.Logger.BeforeSQL(log.LogContext{ Ctx: ctx, SQL: query, @@ -254,7 +269,7 @@ func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{} }) } res, err := db.DB.ExecContext(ctx, query, args...) - if db.Logger != nil { + if showSQL { db.Logger.AfterSQL(log.LogContext{ Ctx: ctx, SQL: query, diff --git a/core/stmt.go b/core/stmt.go index d3c46977..9d5954bd 100644 --- a/core/stmt.go +++ b/core/stmt.go @@ -32,14 +32,15 @@ func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) { }) start := time.Now() - if db.Logger != nil { + showSQL := db.NeedLogSQL(ctx) + if showSQL { db.Logger.BeforeSQL(log.LogContext{ Ctx: ctx, SQL: "PREPARE", }) } stmt, err := db.DB.PrepareContext(ctx, query) - if db.Logger != nil { + if showSQL { db.Logger.AfterSQL(log.LogContext{ Ctx: ctx, SQL: "PREPARE", @@ -94,7 +95,8 @@ func (s *Stmt) ExecStruct(st interface{}) (sql.Result, error) { func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) { start := time.Now() - if s.db.Logger != nil { + showSQL := s.db.NeedLogSQL(ctx) + if showSQL { s.db.Logger.BeforeSQL(log.LogContext{ Ctx: ctx, SQL: s.query, @@ -102,7 +104,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result }) } res, err := s.Stmt.ExecContext(ctx, args) - if s.db.Logger != nil { + if showSQL { s.db.Logger.AfterSQL(log.LogContext{ Ctx: ctx, SQL: s.query, @@ -116,7 +118,8 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, error) { start := time.Now() - if s.db.Logger != nil { + showSQL := s.db.NeedLogSQL(ctx) + if showSQL { s.db.Logger.BeforeSQL(log.LogContext{ Ctx: ctx, SQL: s.query, @@ -124,7 +127,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, er }) } rows, err := s.Stmt.QueryContext(ctx, args...) - if s.db.Logger != nil { + if showSQL { s.db.Logger.AfterSQL(log.LogContext{ Ctx: ctx, SQL: s.query, diff --git a/core/tx.go b/core/tx.go index 10022efc..07713267 100644 --- a/core/tx.go +++ b/core/tx.go @@ -19,14 +19,15 @@ type Tx struct { func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { start := time.Now() - if db.Logger != nil { + showSQL := db.NeedLogSQL(ctx) + if showSQL { db.Logger.BeforeSQL(log.LogContext{ Ctx: ctx, SQL: "BEGIN TRANSACTION", }) } tx, err := db.DB.BeginTx(ctx, opts) - if db.Logger != nil { + if showSQL { db.Logger.AfterSQL(log.LogContext{ Ctx: ctx, SQL: "BEGIN TRANSACTION", @@ -54,14 +55,15 @@ func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) { }) start := time.Now() - if tx.db.Logger != nil { + showSQL := tx.db.NeedLogSQL(ctx) + if showSQL { tx.db.Logger.BeforeSQL(log.LogContext{ Ctx: ctx, SQL: "PREPARE", }) } stmt, err := tx.Tx.PrepareContext(ctx, query) - if tx.db.Logger != nil { + if showSQL { tx.db.Logger.AfterSQL(log.LogContext{ Ctx: ctx, SQL: "PREPARE", @@ -110,7 +112,8 @@ func (tx *Tx) ExecStructContext(ctx context.Context, query string, st interface{ func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { start := time.Now() - if tx.db.Logger != nil { + showSQL := tx.db.NeedLogSQL(ctx) + if showSQL { tx.db.Logger.BeforeSQL(log.LogContext{ Ctx: ctx, SQL: query, @@ -118,7 +121,7 @@ func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{} }) } res, err := tx.Tx.ExecContext(ctx, query, args...) - if tx.db.Logger != nil { + if showSQL { tx.db.Logger.AfterSQL(log.LogContext{ Ctx: ctx, SQL: query, @@ -136,7 +139,8 @@ func (tx *Tx) ExecStruct(query string, st interface{}) (sql.Result, error) { func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { start := time.Now() - if tx.db.Logger != nil { + showSQL := tx.db.NeedLogSQL(ctx) + if showSQL { tx.db.Logger.BeforeSQL(log.LogContext{ Ctx: ctx, SQL: query, @@ -144,7 +148,7 @@ func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{ }) } rows, err := tx.Tx.QueryContext(ctx, query, args...) - if tx.db.Logger != nil { + if showSQL { tx.db.Logger.AfterSQL(log.LogContext{ Ctx: ctx, SQL: query, diff --git a/engine.go b/engine.go index b34f0716..421e89e0 100644 --- a/engine.go +++ b/engine.go @@ -61,11 +61,7 @@ func (engine *Engine) BufferSize(size int) *Session { // ShowSQL show SQL statement or not on logger if log level is great than INFO func (engine *Engine) ShowSQL(show ...bool) { engine.logger.ShowSQL(show...) - if engine.logger.IsShowSQL() { - engine.db.Logger = engine.logger - } else { - engine.db.Logger = &log.DiscardSQLLogger{} - } + engine.db.Logger = engine.logger } // Logger return the logger interface @@ -83,11 +79,7 @@ func (engine *Engine) SetLogger(logger interface{}) { realLogger = t } engine.logger = realLogger - if realLogger.IsShowSQL() { - engine.db.Logger = realLogger - } else { - engine.db.Logger = &log.DiscardSQLLogger{} - } + engine.db.Logger = realLogger } // SetLogLevel sets the logger level diff --git a/log/logger_context.go b/log/logger_context.go index b05f1c52..f80091f3 100644 --- a/log/logger_context.go +++ b/log/logger_context.go @@ -19,17 +19,10 @@ type LogContext struct { } type SQLLogger interface { - BeforeSQL(context LogContext) - AfterSQL(context LogContext) + BeforeSQL(context LogContext) // only invoked when IsShowSQL is true + AfterSQL(context LogContext) // only invoked when IsShowSQL is true } -type DiscardSQLLogger struct{} - -var _ SQLLogger = &DiscardSQLLogger{} - -func (DiscardSQLLogger) BeforeSQL(LogContext) {} -func (DiscardSQLLogger) AfterSQL(LogContext) {} - // ContextLogger represents a logger interface with context type ContextLogger interface { SQLLogger @@ -64,10 +57,6 @@ func NewLoggerAdapter(logger Logger) ContextLogger { func (l *LoggerAdapter) BeforeSQL(ctx LogContext) {} func (l *LoggerAdapter) AfterSQL(ctx LogContext) { - if !l.logger.IsShowSQL() { - return - } - if ctx.ExecuteTime > 0 { l.logger.Infof("[SQL] %v %v - %v", ctx.SQL, ctx.Args, ctx.ExecuteTime) } else { diff --git a/session.go b/session.go index 287465ca..4cab103f 100644 --- a/session.go +++ b/session.go @@ -58,8 +58,6 @@ type Session struct { prepareStmt bool stmtCache map[uint32]*core.Stmt //key: hash.Hash32 of (queryStr, len(queryStr)) - // !evalphobia! stored the last executed query on this session - //beforeSQLExec func(string, ...interface{}) lastSQL string lastSQLArgs []interface{} showSQL bool @@ -82,7 +80,6 @@ func (session *Session) Init() { session.engine.DatabaseTZ, ) - //session.showSQL = session.engine.showSQL session.isAutoCommit = true session.isCommitedOrRollbacked = false session.isAutoClose = false @@ -241,11 +238,11 @@ func (session *Session) Cascade(trueOrFalse ...bool) *Session { // MustLogSQL means record SQL or not and don't follow engine's setting func (session *Session) MustLogSQL(log ...bool) *Session { + var showSQL = true if len(log) > 0 { - session.showSQL = log[0] - } else { - session.showSQL = true + showSQL = log[0] } + session.ctx = context.WithValue(session.ctx, "__xorm_show_sql", showSQL) return session } diff --git a/session_test.go b/session_test.go index 343f9baa..968842c3 100644 --- a/session_test.go +++ b/session_test.go @@ -43,3 +43,14 @@ func TestNullFloatStruct(t *testing.T) { }) assert.NoError(t, err) } + +func TestMustLogSQL(t *testing.T) { + assert.NoError(t, prepareEngine()) + testEngine.ShowSQL(false) + defer testEngine.ShowSQL(true) + + assertSync(t, new(Userinfo)) + + _, err := testEngine.Table("userinfo").MustLogSQL(true).Get(new(Userinfo)) + assert.NoError(t, err) +} diff --git a/session_tx.go b/session_tx.go index 7a4861c6..489489f3 100644 --- a/session_tx.go +++ b/session_tx.go @@ -34,17 +34,22 @@ func (session *Session) Rollback() error { session.isAutoCommit = true start := time.Now() - session.engine.logger.BeforeSQL(log.LogContext{ - Ctx: session.ctx, - SQL: "ROLL BACK", - }) + needSQL := session.engine.db.NeedLogSQL(session.ctx) + if needSQL { + session.engine.logger.BeforeSQL(log.LogContext{ + Ctx: session.ctx, + SQL: "ROLL BACK", + }) + } err := session.tx.Rollback() - session.engine.logger.AfterSQL(log.LogContext{ - Ctx: session.ctx, - SQL: "ROLL BACK", - ExecuteTime: time.Now().Sub(start), - Err: err, - }) + if needSQL { + session.engine.logger.AfterSQL(log.LogContext{ + Ctx: session.ctx, + SQL: "ROLL BACK", + ExecuteTime: time.Now().Sub(start), + Err: err, + }) + } return err } return nil @@ -58,17 +63,22 @@ func (session *Session) Commit() error { session.isAutoCommit = true start := time.Now() - session.engine.logger.BeforeSQL(log.LogContext{ - Ctx: session.ctx, - SQL: "COMMIT", - }) + needSQL := session.engine.db.NeedLogSQL(session.ctx) + if needSQL { + session.engine.logger.BeforeSQL(log.LogContext{ + Ctx: session.ctx, + SQL: "COMMIT", + }) + } err := session.tx.Commit() - session.engine.logger.AfterSQL(log.LogContext{ - Ctx: session.ctx, - SQL: "COMMIT", - ExecuteTime: time.Now().Sub(start), - Err: err, - }) + if needSQL { + session.engine.logger.AfterSQL(log.LogContext{ + Ctx: session.ctx, + SQL: "COMMIT", + ExecuteTime: time.Now().Sub(start), + Err: err, + }) + } if err != nil { return err