From 7e4dc9cc57bcba704259219264ca9d40c75efaa7 Mon Sep 17 00:00:00 2001 From: "yuxiao.lu" Date: Thu, 2 Apr 2020 15:58:04 +0800 Subject: [PATCH 1/2] add hook for engine --- core/db.go | 99 +++++++++++++++++++++++++++---------------- core/interface.go | 7 +++ core/stmt.go | 86 +++++++++++++++---------------------- core/tx.go | 81 +++++++++++++++-------------------- engine.go | 4 ++ engine_group.go | 8 ++++ interface.go | 2 + log/logger_context.go | 2 + 8 files changed, 154 insertions(+), 135 deletions(-) diff --git a/core/db.go b/core/db.go index 671d1dc2..931b062a 100644 --- a/core/db.go +++ b/core/db.go @@ -88,6 +88,7 @@ type DB struct { reflectCache map[reflect.Type]*cacheStruct reflectCacheMutex sync.RWMutex Logger log.ContextLogger + hooks []Hook } // Open opens a database @@ -139,27 +140,22 @@ func (db *DB) reflectNew(typ reflect.Type) reflect.Value { } // QueryContext overwrites sql.DB.QueryContext -func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { +func (db *DB) QueryContext(parentCtx context.Context, query string, args ...interface{}) (*Rows, error) { + logCtx := log.LogContext{ + Ctx: parentCtx, + SQL: query, + Args: args, + } start := time.Now() - showSQL := db.NeedLogSQL(ctx) - if showSQL { - db.Logger.BeforeSQL(log.LogContext{ - Ctx: ctx, - SQL: query, - Args: args, - }) - } - rows, err := db.DB.QueryContext(ctx, query, args...) - if showSQL { - db.Logger.AfterSQL(log.LogContext{ - Ctx: ctx, - SQL: query, - Args: args, - ExecuteTime: time.Now().Sub(start), - Err: err, - }) - } + ctx, err := db.beforeProcess(logCtx) if err != nil { + return nil, err + } + logCtx.Ctx = ctx + rows, err := db.DB.QueryContext(ctx, query, args...) + logCtx.ExecuteTime = time.Now().Sub(start) + logCtx.Err = err + if err := db.afterProcess(logCtx); err != nil { if rows != nil { rows.Close() } @@ -262,29 +258,60 @@ func (db *DB) ExecStructContext(ctx context.Context, query string, st interface{ return db.ExecContext(ctx, query, args...) } -func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { +func (db *DB) ExecContext(parentCtx context.Context, query string, args ...interface{}) (sql.Result, error) { + logCtx := log.LogContext{ + Ctx: parentCtx, + SQL: query, + Args: args, + } start := time.Now() - showSQL := db.NeedLogSQL(ctx) - if showSQL { - db.Logger.BeforeSQL(log.LogContext{ - Ctx: ctx, - SQL: query, - Args: args, - }) + ctx, err := db.beforeProcess(logCtx) + if err != nil { + return nil, err } + logCtx.Ctx = ctx res, err := db.DB.ExecContext(ctx, query, args...) - if showSQL { - db.Logger.AfterSQL(log.LogContext{ - Ctx: ctx, - SQL: query, - Args: args, - ExecuteTime: time.Now().Sub(start), - Err: err, - }) + logCtx.Err = err + logCtx.ExecuteTime = time.Now().Sub(start) + if err := db.afterProcess(logCtx); err != nil { + return nil, err } - return res, err + return res, nil } func (db *DB) ExecStruct(query string, st interface{}) (sql.Result, error) { return db.ExecStructContext(context.Background(), query, st) } + +func (db *DB) beforeProcess(logCtx log.LogContext) (context.Context, error) { + ctx := logCtx.Ctx + if db.NeedLogSQL(ctx) { + db.Logger.BeforeSQL(logCtx) + } + for _, h := range db.hooks { + var err error + ctx, err = h.BeforeProcess(ctx, logCtx.SQL, logCtx.Args...) + if err != nil { + return nil, err + } + } + return ctx, nil +} + +func (db *DB) afterProcess(logCtx log.LogContext) error { + firstErr := logCtx.Err + for _, h := range db.hooks { + err := h.AfterProcess(&logCtx) + if err != nil && firstErr == nil { + firstErr = err + } + } + if db.NeedLogSQL(logCtx.Ctx) { + db.Logger.AfterSQL(logCtx) + } + return firstErr +} + +func (db *DB) AddHook(hook Hook) { + db.hooks = append(db.hooks, hook) +} diff --git a/core/interface.go b/core/interface.go index a5c8e4e2..73ca2d2b 100644 --- a/core/interface.go +++ b/core/interface.go @@ -3,8 +3,15 @@ package core import ( "context" "database/sql" + + "xorm.io/xorm/log" ) +type Hook interface { + BeforeProcess(ctx context.Context, query string, args ...interface{}) (context.Context, error) + AfterProcess(logContext *log.LogContext) error +} + // Queryer represents an interface to query a SQL to get data from database type Queryer interface { QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) diff --git a/core/stmt.go b/core/stmt.go index ebf2af73..754a8f89 100644 --- a/core/stmt.go +++ b/core/stmt.go @@ -30,28 +30,21 @@ func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) { i++ return "?" }) - + logCtx := log.LogContext{ + Ctx: ctx, + SQL: "PREPARE", + } start := time.Now() - showSQL := db.NeedLogSQL(ctx) - if showSQL { - db.Logger.BeforeSQL(log.LogContext{ - Ctx: ctx, - SQL: "PREPARE", - }) - } - stmt, err := db.DB.PrepareContext(ctx, query) - if showSQL { - db.Logger.AfterSQL(log.LogContext{ - Ctx: ctx, - SQL: "PREPARE", - ExecuteTime: time.Now().Sub(start), - Err: err, - }) - } + ctx, err := db.beforeProcess(logCtx) if err != nil { return nil, err } - + stmt, err := db.DB.PrepareContext(ctx, query) + logCtx.Err = err + logCtx.ExecuteTime = time.Now().Sub(start) + if err := db.afterProcess(logCtx); err != nil { + return nil, err + } return &Stmt{stmt, db, names, query}, nil } @@ -94,49 +87,40 @@ func (s *Stmt) ExecStruct(st interface{}) (sql.Result, error) { } func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) { + logCtx := log.LogContext{ + Ctx: ctx, + SQL: s.query, + Args: args, + } start := time.Now() - showSQL := s.db.NeedLogSQL(ctx) - if showSQL { - s.db.Logger.BeforeSQL(log.LogContext{ - Ctx: ctx, - SQL: s.query, - Args: args, - }) + ctx, err := s.db.beforeProcess(logCtx) + if err != nil { + return nil, err } res, err := s.Stmt.ExecContext(ctx, args) - if showSQL { - s.db.Logger.AfterSQL(log.LogContext{ - Ctx: ctx, - SQL: s.query, - Args: args, - ExecuteTime: time.Now().Sub(start), - Err: err, - }) + logCtx.ExecuteTime = time.Now().Sub(start) + logCtx.Err = err + if err := s.db.afterProcess(logCtx); err != nil { + return nil, err } - return res, err + return res, nil } func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, error) { + logCtx := log.LogContext{ + Ctx: ctx, + SQL: s.query, + Args: args, + } start := time.Now() - showSQL := s.db.NeedLogSQL(ctx) - if showSQL { - s.db.Logger.BeforeSQL(log.LogContext{ - Ctx: ctx, - SQL: s.query, - Args: args, - }) + ctx, err := s.db.beforeProcess(logCtx) + if err != nil { + return nil, err } rows, err := s.Stmt.QueryContext(ctx, args...) - if showSQL { - s.db.Logger.AfterSQL(log.LogContext{ - Ctx: ctx, - SQL: s.query, - Args: args, - ExecuteTime: time.Now().Sub(start), - Err: err, - }) - } - if err != nil { + logCtx.ExecuteTime = time.Now().Sub(start) + logCtx.Err = err + if err := s.db.afterProcess(logCtx); err != nil { return nil, err } return &Rows{rows, s.db}, nil diff --git a/core/tx.go b/core/tx.go index 99a8097d..0b491659 100644 --- a/core/tx.go +++ b/core/tx.go @@ -58,25 +58,19 @@ func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) { i++ return "?" }) - + logCtx := log.LogContext{ + Ctx: ctx, + SQL: "PREPARE", + } start := time.Now() - showSQL := tx.db.NeedLogSQL(ctx) - if showSQL { - tx.db.Logger.BeforeSQL(log.LogContext{ - Ctx: ctx, - SQL: "PREPARE", - }) + ctx, err := tx.db.beforeProcess(logCtx) + if err != nil { + return nil, err } stmt, err := tx.Tx.PrepareContext(ctx, query) - if showSQL { - tx.db.Logger.AfterSQL(log.LogContext{ - Ctx: ctx, - SQL: "PREPARE", - ExecuteTime: time.Now().Sub(start), - Err: err, - }) - } - if err != nil { + logCtx.Err = err + logCtx.ExecuteTime = time.Now().Sub(start) + if err := tx.db.afterProcess(logCtx); err != nil { return nil, err } return &Stmt{stmt, tx.db, names, query}, nil @@ -117,23 +111,20 @@ func (tx *Tx) ExecStructContext(ctx context.Context, query string, st interface{ func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { start := time.Now() - showSQL := tx.db.NeedLogSQL(ctx) - if showSQL { - tx.db.Logger.BeforeSQL(log.LogContext{ - Ctx: ctx, - SQL: query, - Args: args, - }) + logCtx := log.LogContext{ + Ctx: ctx, + SQL: query, + Args: args, + } + ctx, err := tx.db.beforeProcess(logCtx) + if err != nil { + return nil, err } res, err := tx.Tx.ExecContext(ctx, query, args...) - if showSQL { - tx.db.Logger.AfterSQL(log.LogContext{ - Ctx: ctx, - SQL: query, - Args: args, - ExecuteTime: time.Now().Sub(start), - Err: err, - }) + logCtx.ExecuteTime = time.Now().Sub(start) + logCtx.Err = err + if err := tx.db.afterProcess(logCtx); err != nil { + return nil, err } return res, err } @@ -143,26 +134,20 @@ func (tx *Tx) ExecStruct(query string, st interface{}) (sql.Result, error) { } func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { + logCtx := log.LogContext{ + Ctx: ctx, + SQL: query, + Args: args, + } start := time.Now() - showSQL := tx.db.NeedLogSQL(ctx) - if showSQL { - tx.db.Logger.BeforeSQL(log.LogContext{ - Ctx: ctx, - SQL: query, - Args: args, - }) + ctx, err := tx.db.beforeProcess(logCtx) + if err != nil { + return nil, err } rows, err := tx.Tx.QueryContext(ctx, query, args...) - if showSQL { - tx.db.Logger.AfterSQL(log.LogContext{ - Ctx: ctx, - SQL: query, - Args: args, - ExecuteTime: time.Now().Sub(start), - Err: err, - }) - } - if err != nil { + logCtx.Err = err + logCtx.ExecuteTime = time.Now().Sub(start) + if err := tx.db.afterProcess(logCtx); err != nil { if rows != nil { rows.Close() } diff --git a/engine.go b/engine.go index d99e15db..7d196249 100644 --- a/engine.go +++ b/engine.go @@ -1287,6 +1287,10 @@ func (engine *Engine) SetSchema(schema string) { engine.dialect.URI().SetSchema(schema) } +func (engine *Engine) AddHook(hook core.Hook) { + engine.db.AddHook(hook) +} + // Unscoped always disable struct tag "deleted" func (engine *Engine) Unscoped() *Session { session := engine.NewSession() diff --git a/engine_group.go b/engine_group.go index 02a57ab4..a4a1f6a6 100644 --- a/engine_group.go +++ b/engine_group.go @@ -9,6 +9,7 @@ import ( "time" "xorm.io/xorm/caches" + "xorm.io/xorm/core" "xorm.io/xorm/dialects" "xorm.io/xorm/log" "xorm.io/xorm/names" @@ -143,6 +144,13 @@ func (eg *EngineGroup) SetLogger(logger interface{}) { } } +func (eg *EngineGroup) AddHook(hook core.Hook) { + eg.Engine.AddHook(hook) + for i := 0; i < len(eg.slaves); i++ { + eg.slaves[i].AddHook(hook) + } +} + // SetLogLevel sets the logger level func (eg *EngineGroup) SetLogLevel(level log.LogLevel) { eg.Engine.SetLogLevel(level) diff --git a/interface.go b/interface.go index 262a2cfe..5f8c4b2e 100644 --- a/interface.go +++ b/interface.go @@ -11,6 +11,7 @@ import ( "time" "xorm.io/xorm/caches" + "xorm.io/xorm/core" "xorm.io/xorm/dialects" "xorm.io/xorm/log" "xorm.io/xorm/names" @@ -111,6 +112,7 @@ type EngineInterface interface { SetTableMapper(names.Mapper) SetTZDatabase(tz *time.Location) SetTZLocation(tz *time.Location) + AddHook(hook core.Hook) ShowSQL(show ...bool) Sync(...interface{}) error Sync2(...interface{}) error diff --git a/log/logger_context.go b/log/logger_context.go index faed26d6..715f86c6 100644 --- a/log/logger_context.go +++ b/log/logger_context.go @@ -6,6 +6,7 @@ package log import ( "context" + "database/sql" "fmt" "time" ) @@ -15,6 +16,7 @@ type LogContext struct { Ctx context.Context SQL string // log content or SQL Args []interface{} // if it's a SQL, it's the arguments + Result sql.Result ExecuteTime time.Duration Err error // SQL executed error } -- 2.40.1 From bb8148bf290f9cab845312572dda0bf0cd2f669b Mon Sep 17 00:00:00 2001 From: "yuxiao.lu" Date: Fri, 3 Apr 2020 16:37:18 +0800 Subject: [PATCH 2/2] move hook to standalone package --- contexts/hook.go | 75 ++++++++++++++++++++++ contexts/hook_test.go | 140 ++++++++++++++++++++++++++++++++++++++++++ core/db.go | 76 ++++++++--------------- core/interface.go | 7 --- core/stmt.go | 44 ++++--------- core/tx.go | 66 ++++++-------------- engine.go | 5 +- engine_group.go | 6 +- interface.go | 4 +- log/logger_context.go | 14 +---- session.go | 2 +- 11 files changed, 285 insertions(+), 154 deletions(-) create mode 100644 contexts/hook.go create mode 100644 contexts/hook_test.go diff --git a/contexts/hook.go b/contexts/hook.go new file mode 100644 index 00000000..71ad8e87 --- /dev/null +++ b/contexts/hook.go @@ -0,0 +1,75 @@ +// Copyright 2020 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package contexts + +import ( + "context" + "database/sql" + "time" +) + +// ContextHook represents a hook context +type ContextHook struct { + start time.Time + Ctx context.Context + SQL string // log content or SQL + Args []interface{} // if it's a SQL, it's the arguments + Result sql.Result + ExecuteTime time.Duration + Err error // SQL executed error +} + +// NewContextHook return context for hook +func NewContextHook(ctx context.Context, sql string, args []interface{}) *ContextHook { + return &ContextHook{ + start: time.Now(), + Ctx: ctx, + SQL: sql, + Args: args, + } +} + +func (c *ContextHook) End(ctx context.Context, result sql.Result, err error) { + c.Ctx = ctx + c.Result = result + c.Err = err + c.ExecuteTime = time.Now().Sub(c.start) +} + +type Hook interface { + BeforeProcess(c *ContextHook) (context.Context, error) + AfterProcess(c *ContextHook) error +} + +type Hooks struct { + hooks []Hook +} + +func (h *Hooks) AddHook(hooks ...Hook) { + h.hooks = append(h.hooks, hooks...) +} + +func (h *Hooks) BeforeProcess(c *ContextHook) (context.Context, error) { + ctx := c.Ctx + for _, h := range h.hooks { + var err error + ctx, err = h.BeforeProcess(c) + if err != nil { + return nil, err + } + } + return ctx, nil +} + +func (h *Hooks) AfterProcess(c *ContextHook) error { + firstErr := c.Err + for _, h := range h.hooks { + err := h.AfterProcess(c) + if err != nil && firstErr == nil { + firstErr = err + } + } + return firstErr +} diff --git a/contexts/hook_test.go b/contexts/hook_test.go new file mode 100644 index 00000000..96c54e33 --- /dev/null +++ b/contexts/hook_test.go @@ -0,0 +1,140 @@ +package contexts + +import ( + "context" + "errors" + "testing" +) + +type testHook struct { + before func(c *ContextHook) (context.Context, error) + after func(c *ContextHook) error +} + +func (h *testHook) BeforeProcess(c *ContextHook) (context.Context, error) { + if h.before != nil { + return h.before(c) + } + return c.Ctx, nil +} + +func (h *testHook) AfterProcess(c *ContextHook) error { + if h.after != nil { + return h.after(c) + } + return c.Err +} + +var _ Hook = &testHook{} + +func TestBeforeProcess(t *testing.T) { + expectErr := errors.New("before error") + tests := []struct { + msg string + hooks []Hook + expect error + }{ + { + msg: "first hook return err", + hooks: []Hook{ + &testHook{ + before: func(c *ContextHook) (ctx context.Context, err error) { + return c.Ctx, expectErr + }, + }, + &testHook{ + before: func(c *ContextHook) (ctx context.Context, err error) { + return c.Ctx, nil + }, + }, + }, + expect: expectErr, + }, + { + msg: "second hook return err", + hooks: []Hook{ + &testHook{ + before: func(c *ContextHook) (ctx context.Context, err error) { + return c.Ctx, nil + }, + }, + &testHook{ + before: func(c *ContextHook) (ctx context.Context, err error) { + return c.Ctx, expectErr + }, + }, + }, + expect: expectErr, + }, + } + + for _, tt := range tests { + t.Run(tt.msg, func(t *testing.T) { + hooks := Hooks{} + hooks.AddHook(tt.hooks...) + _, err := hooks.BeforeProcess(&ContextHook{ + Ctx: context.Background(), + }) + if err != tt.expect { + t.Errorf("got %v, expect %v", err, tt.expect) + } + }) + } +} + +func TestAfterProcess(t *testing.T) { + expectErr := errors.New("expect err") + tests := []struct { + msg string + ctx *ContextHook + hooks []Hook + expect error + }{ + { + msg: "context has err", + ctx: &ContextHook{ + Ctx: context.Background(), + Err: expectErr, + }, + hooks: []Hook{ + &testHook{ + after: func(c *ContextHook) error { + return errors.New("hook err") + }, + }, + }, + expect: expectErr, + }, + { + msg: "last hook has err", + ctx: &ContextHook{ + Ctx: context.Background(), + Err: nil, + }, + hooks: []Hook{ + &testHook{ + after: func(c *ContextHook) error { + return nil + }, + }, + &testHook{ + after: func(c *ContextHook) error { + return expectErr + }, + }, + }, + expect: expectErr, + }, + } + + for _, tt := range tests { + t.Run(tt.msg, func(t *testing.T) { + hooks := Hooks{} + hooks.AddHook(tt.hooks...) + err := hooks.AfterProcess(tt.ctx) + if err != tt.expect { + t.Errorf("got %v, expect %v", err, tt.expect) + } + }) + } +} diff --git a/core/db.go b/core/db.go index 931b062a..50c64c6f 100644 --- a/core/db.go +++ b/core/db.go @@ -12,8 +12,8 @@ import ( "reflect" "regexp" "sync" - "time" + "xorm.io/xorm/contexts" "xorm.io/xorm/log" "xorm.io/xorm/names" ) @@ -88,7 +88,7 @@ type DB struct { reflectCache map[reflect.Type]*cacheStruct reflectCacheMutex sync.RWMutex Logger log.ContextLogger - hooks []Hook + hooks contexts.Hooks } // Open opens a database @@ -140,22 +140,15 @@ func (db *DB) reflectNew(typ reflect.Type) reflect.Value { } // QueryContext overwrites sql.DB.QueryContext -func (db *DB) QueryContext(parentCtx context.Context, query string, args ...interface{}) (*Rows, error) { - logCtx := log.LogContext{ - Ctx: parentCtx, - SQL: query, - Args: args, - } - start := time.Now() - ctx, err := db.beforeProcess(logCtx) +func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { + hookCtx := contexts.NewContextHook(ctx, query, args) + ctx, err := db.beforeProcess(hookCtx) if err != nil { return nil, err } - logCtx.Ctx = ctx rows, err := db.DB.QueryContext(ctx, query, args...) - logCtx.ExecuteTime = time.Now().Sub(start) - logCtx.Err = err - if err := db.afterProcess(logCtx); err != nil { + hookCtx.End(ctx, nil, err) + if err := db.afterProcess(hookCtx); err != nil { if rows != nil { rows.Close() } @@ -235,7 +228,7 @@ var ( re = regexp.MustCompile(`[?](\w+)`) ) -// ExecMapContext exec map with context.Context +// ExecMapContext exec map with context.ContextHook // insert into (name) values (?) // insert into (name) values (?name) func (db *DB) ExecMapContext(ctx context.Context, query string, mp interface{}) (sql.Result, error) { @@ -258,22 +251,15 @@ func (db *DB) ExecStructContext(ctx context.Context, query string, st interface{ return db.ExecContext(ctx, query, args...) } -func (db *DB) ExecContext(parentCtx context.Context, query string, args ...interface{}) (sql.Result, error) { - logCtx := log.LogContext{ - Ctx: parentCtx, - SQL: query, - Args: args, - } - start := time.Now() - ctx, err := db.beforeProcess(logCtx) +func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + hookCtx := contexts.NewContextHook(ctx, query, args) + ctx, err := db.beforeProcess(hookCtx) if err != nil { return nil, err } - logCtx.Ctx = ctx res, err := db.DB.ExecContext(ctx, query, args...) - logCtx.Err = err - logCtx.ExecuteTime = time.Now().Sub(start) - if err := db.afterProcess(logCtx); err != nil { + hookCtx.End(ctx, res, err) + if err := db.afterProcess(hookCtx); err != nil { return nil, err } return res, nil @@ -283,35 +269,25 @@ func (db *DB) ExecStruct(query string, st interface{}) (sql.Result, error) { return db.ExecStructContext(context.Background(), query, st) } -func (db *DB) beforeProcess(logCtx log.LogContext) (context.Context, error) { - ctx := logCtx.Ctx - if db.NeedLogSQL(ctx) { - db.Logger.BeforeSQL(logCtx) +func (db *DB) beforeProcess(c *contexts.ContextHook) (context.Context, error) { + if db.NeedLogSQL(c.Ctx) { + db.Logger.BeforeSQL(log.LogContext(*c)) } - for _, h := range db.hooks { - var err error - ctx, err = h.BeforeProcess(ctx, logCtx.SQL, logCtx.Args...) - if err != nil { - return nil, err - } + ctx, err := db.hooks.BeforeProcess(c) + if err != nil { + return nil, err } return ctx, nil } -func (db *DB) afterProcess(logCtx log.LogContext) error { - firstErr := logCtx.Err - for _, h := range db.hooks { - err := h.AfterProcess(&logCtx) - if err != nil && firstErr == nil { - firstErr = err - } +func (db *DB) afterProcess(c *contexts.ContextHook) error { + err := db.hooks.AfterProcess(c) + if db.NeedLogSQL(c.Ctx) { + db.Logger.AfterSQL(log.LogContext(*c)) } - if db.NeedLogSQL(logCtx.Ctx) { - db.Logger.AfterSQL(logCtx) - } - return firstErr + return err } -func (db *DB) AddHook(hook Hook) { - db.hooks = append(db.hooks, hook) +func (db *DB) AddHook(h ...contexts.Hook) { + db.hooks.AddHook(h...) } diff --git a/core/interface.go b/core/interface.go index 73ca2d2b..a5c8e4e2 100644 --- a/core/interface.go +++ b/core/interface.go @@ -3,15 +3,8 @@ package core import ( "context" "database/sql" - - "xorm.io/xorm/log" ) -type Hook interface { - BeforeProcess(ctx context.Context, query string, args ...interface{}) (context.Context, error) - AfterProcess(logContext *log.LogContext) error -} - // Queryer represents an interface to query a SQL to get data from database type Queryer interface { QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) diff --git a/core/stmt.go b/core/stmt.go index 754a8f89..d46ac9c6 100644 --- a/core/stmt.go +++ b/core/stmt.go @@ -9,9 +9,8 @@ import ( "database/sql" "errors" "reflect" - "time" - "xorm.io/xorm/log" + "xorm.io/xorm/contexts" ) // Stmt reprents a stmt objects @@ -30,19 +29,14 @@ func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) { i++ return "?" }) - logCtx := log.LogContext{ - Ctx: ctx, - SQL: "PREPARE", - } - start := time.Now() - ctx, err := db.beforeProcess(logCtx) + hookCtx := contexts.NewContextHook(ctx, "PREPARE", nil) + ctx, err := db.beforeProcess(hookCtx) if err != nil { return nil, err } stmt, err := db.DB.PrepareContext(ctx, query) - logCtx.Err = err - logCtx.ExecuteTime = time.Now().Sub(start) - if err := db.afterProcess(logCtx); err != nil { + hookCtx.End(ctx, nil, err) + if err := db.afterProcess(hookCtx); err != nil { return nil, err } return &Stmt{stmt, db, names, query}, nil @@ -87,40 +81,28 @@ func (s *Stmt) ExecStruct(st interface{}) (sql.Result, error) { } func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) { - logCtx := log.LogContext{ - Ctx: ctx, - SQL: s.query, - Args: args, - } - start := time.Now() - ctx, err := s.db.beforeProcess(logCtx) + hookCtx := contexts.NewContextHook(ctx, s.query, args) + ctx, err := s.db.beforeProcess(hookCtx) if err != nil { return nil, err } res, err := s.Stmt.ExecContext(ctx, args) - logCtx.ExecuteTime = time.Now().Sub(start) - logCtx.Err = err - if err := s.db.afterProcess(logCtx); err != nil { + hookCtx.End(ctx, res, err) + if err := s.db.afterProcess(hookCtx); err != nil { return nil, err } return res, nil } func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, error) { - logCtx := log.LogContext{ - Ctx: ctx, - SQL: s.query, - Args: args, - } - start := time.Now() - ctx, err := s.db.beforeProcess(logCtx) + hookCtx := contexts.NewContextHook(ctx, s.query, args) + ctx, err := s.db.beforeProcess(hookCtx) if err != nil { return nil, err } rows, err := s.Stmt.QueryContext(ctx, args...) - logCtx.ExecuteTime = time.Now().Sub(start) - logCtx.Err = err - if err := s.db.afterProcess(logCtx); err != nil { + hookCtx.End(ctx, nil, err) + if err := s.db.afterProcess(hookCtx); err != nil { return nil, err } return &Rows{rows, s.db}, nil diff --git a/core/tx.go b/core/tx.go index 0b491659..9b2988af 100644 --- a/core/tx.go +++ b/core/tx.go @@ -7,9 +7,8 @@ package core import ( "context" "database/sql" - "time" - "xorm.io/xorm/log" + "xorm.io/xorm/contexts" ) var ( @@ -23,24 +22,14 @@ type Tx struct { } func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { - start := time.Now() - showSQL := db.NeedLogSQL(ctx) - if showSQL { - db.Logger.BeforeSQL(log.LogContext{ - Ctx: ctx, - SQL: "BEGIN TRANSACTION", - }) + hookCtx := contexts.NewContextHook(ctx, "BEGIN TRANSACTION", nil) + ctx, err := db.beforeProcess(hookCtx) + if err != nil { + return nil, err } tx, err := db.DB.BeginTx(ctx, opts) - if showSQL { - db.Logger.AfterSQL(log.LogContext{ - Ctx: ctx, - SQL: "BEGIN TRANSACTION", - ExecuteTime: time.Now().Sub(start), - Err: err, - }) - } - if err != nil { + hookCtx.End(ctx, nil, err) + if err := db.afterProcess(hookCtx); err != nil { return nil, err } return &Tx{tx, db}, nil @@ -58,19 +47,14 @@ func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) { i++ return "?" }) - logCtx := log.LogContext{ - Ctx: ctx, - SQL: "PREPARE", - } - start := time.Now() - ctx, err := tx.db.beforeProcess(logCtx) + hookCtx := contexts.NewContextHook(ctx, "PREPARE", nil) + ctx, err := tx.db.beforeProcess(hookCtx) if err != nil { return nil, err } stmt, err := tx.Tx.PrepareContext(ctx, query) - logCtx.Err = err - logCtx.ExecuteTime = time.Now().Sub(start) - if err := tx.db.afterProcess(logCtx); err != nil { + hookCtx.End(ctx, nil, err) + if err := tx.db.afterProcess(hookCtx); err != nil { return nil, err } return &Stmt{stmt, tx.db, names, query}, nil @@ -110,20 +94,14 @@ func (tx *Tx) ExecStructContext(ctx context.Context, query string, st interface{ } func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { - start := time.Now() - logCtx := log.LogContext{ - Ctx: ctx, - SQL: query, - Args: args, - } - ctx, err := tx.db.beforeProcess(logCtx) + hookCtx := contexts.NewContextHook(ctx, query, args) + ctx, err := tx.db.beforeProcess(hookCtx) if err != nil { return nil, err } res, err := tx.Tx.ExecContext(ctx, query, args...) - logCtx.ExecuteTime = time.Now().Sub(start) - logCtx.Err = err - if err := tx.db.afterProcess(logCtx); err != nil { + hookCtx.End(ctx, res, err) + if err := tx.db.afterProcess(hookCtx); err != nil { return nil, err } return res, err @@ -134,20 +112,14 @@ func (tx *Tx) ExecStruct(query string, st interface{}) (sql.Result, error) { } func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { - logCtx := log.LogContext{ - Ctx: ctx, - SQL: query, - Args: args, - } - start := time.Now() - ctx, err := tx.db.beforeProcess(logCtx) + hookCtx := contexts.NewContextHook(ctx, query, args) + ctx, err := tx.db.beforeProcess(hookCtx) if err != nil { return nil, err } rows, err := tx.Tx.QueryContext(ctx, query, args...) - logCtx.Err = err - logCtx.ExecuteTime = time.Now().Sub(start) - if err := tx.db.afterProcess(logCtx); err != nil { + hookCtx.End(ctx, nil, err) + if err := tx.db.afterProcess(hookCtx); err != nil { if rows != nil { rows.Close() } diff --git a/engine.go b/engine.go index 7d196249..7399f41a 100644 --- a/engine.go +++ b/engine.go @@ -18,6 +18,7 @@ import ( "time" "xorm.io/xorm/caches" + "xorm.io/xorm/contexts" "xorm.io/xorm/core" "xorm.io/xorm/dialects" "xorm.io/xorm/internal/utils" @@ -1287,7 +1288,7 @@ func (engine *Engine) SetSchema(schema string) { engine.dialect.URI().SetSchema(schema) } -func (engine *Engine) AddHook(hook core.Hook) { +func (engine *Engine) AddHook(hook contexts.Hook) { engine.db.AddHook(hook) } @@ -1302,7 +1303,7 @@ func (engine *Engine) tbNameWithSchema(v string) string { return dialects.TableNameWithSchema(engine.dialect, v) } -// Context creates a session with the context +// ContextHook creates a session with the context func (engine *Engine) Context(ctx context.Context) *Session { session := engine.NewSession() session.isAutoClose = true diff --git a/engine_group.go b/engine_group.go index a4a1f6a6..cdd9dd44 100644 --- a/engine_group.go +++ b/engine_group.go @@ -9,7 +9,7 @@ import ( "time" "xorm.io/xorm/caches" - "xorm.io/xorm/core" + "xorm.io/xorm/contexts" "xorm.io/xorm/dialects" "xorm.io/xorm/log" "xorm.io/xorm/names" @@ -79,7 +79,7 @@ func (eg *EngineGroup) Close() error { return nil } -// Context returned a group session +// ContextHook returned a group session func (eg *EngineGroup) Context(ctx context.Context) *Session { sess := eg.NewSession() sess.isAutoClose = true @@ -144,7 +144,7 @@ func (eg *EngineGroup) SetLogger(logger interface{}) { } } -func (eg *EngineGroup) AddHook(hook core.Hook) { +func (eg *EngineGroup) AddHook(hook contexts.Hook) { eg.Engine.AddHook(hook) for i := 0; i < len(eg.slaves); i++ { eg.slaves[i].AddHook(hook) diff --git a/interface.go b/interface.go index 5f8c4b2e..6aac4ae8 100644 --- a/interface.go +++ b/interface.go @@ -11,7 +11,7 @@ import ( "time" "xorm.io/xorm/caches" - "xorm.io/xorm/core" + "xorm.io/xorm/contexts" "xorm.io/xorm/dialects" "xorm.io/xorm/log" "xorm.io/xorm/names" @@ -112,7 +112,7 @@ type EngineInterface interface { SetTableMapper(names.Mapper) SetTZDatabase(tz *time.Location) SetTZLocation(tz *time.Location) - AddHook(hook core.Hook) + AddHook(hook contexts.Hook) ShowSQL(show ...bool) Sync(...interface{}) error Sync2(...interface{}) error diff --git a/log/logger_context.go b/log/logger_context.go index 715f86c6..6b7252ef 100644 --- a/log/logger_context.go +++ b/log/logger_context.go @@ -5,21 +5,13 @@ package log import ( - "context" - "database/sql" "fmt" - "time" + + "xorm.io/xorm/contexts" ) // LogContext represents a log context -type LogContext struct { - Ctx context.Context - SQL string // log content or SQL - Args []interface{} // if it's a SQL, it's the arguments - Result sql.Result - ExecuteTime time.Duration - Err error // SQL executed error -} +type LogContext contexts.ContextHook // SQLLogger represents an interface to log SQL type SQLLogger interface { diff --git a/session.go b/session.go index 9f47d9b4..761b1415 100644 --- a/session.go +++ b/session.go @@ -887,7 +887,7 @@ func (session *Session) incrVersionFieldValue(fieldValue *reflect.Value) { } } -// Context sets the context on this session +// ContextHook sets the context on this session func (session *Session) Context(ctx context.Context) *Session { session.ctx = ctx return session -- 2.40.1