Add Hook #1644

Merged
lunny merged 2 commits from yxlimo/xorm:master into master 2020-04-09 06:03:50 +00:00
10 changed files with 324 additions and 174 deletions

75
contexts/hook.go Normal file
View File

@ -0,0 +1,75 @@
// Copyright 2020 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package contexts
import (
"context"
"database/sql"
"time"
)
// ContextHook represents a hook context
type ContextHook struct {
start time.Time
Ctx context.Context
SQL string // log content or SQL
Args []interface{} // if it's a SQL, it's the arguments
Result sql.Result
ExecuteTime time.Duration
Err error // SQL executed error
}
// NewContextHook return context for hook
func NewContextHook(ctx context.Context, sql string, args []interface{}) *ContextHook {
return &ContextHook{
start: time.Now(),
Ctx: ctx,
SQL: sql,
Args: args,
}
}
func (c *ContextHook) End(ctx context.Context, result sql.Result, err error) {
c.Ctx = ctx
c.Result = result
c.Err = err
c.ExecuteTime = time.Now().Sub(c.start)
}
type Hook interface {
BeforeProcess(c *ContextHook) (context.Context, error)
AfterProcess(c *ContextHook) error
}
type Hooks struct {
hooks []Hook
}
func (h *Hooks) AddHook(hooks ...Hook) {
h.hooks = append(h.hooks, hooks...)
}
func (h *Hooks) BeforeProcess(c *ContextHook) (context.Context, error) {
ctx := c.Ctx
for _, h := range h.hooks {
var err error
ctx, err = h.BeforeProcess(c)
if err != nil {
return nil, err
}
}
return ctx, nil
}
func (h *Hooks) AfterProcess(c *ContextHook) error {
firstErr := c.Err
for _, h := range h.hooks {
err := h.AfterProcess(c)
if err != nil && firstErr == nil {
firstErr = err
}
}
return firstErr
}

140
contexts/hook_test.go Normal file
View File

@ -0,0 +1,140 @@
package contexts
import (
"context"
"errors"
"testing"
)
type testHook struct {
before func(c *ContextHook) (context.Context, error)
after func(c *ContextHook) error
}
func (h *testHook) BeforeProcess(c *ContextHook) (context.Context, error) {
if h.before != nil {
return h.before(c)
}
return c.Ctx, nil
}
func (h *testHook) AfterProcess(c *ContextHook) error {
if h.after != nil {
return h.after(c)
}
return c.Err
}
var _ Hook = &testHook{}
func TestBeforeProcess(t *testing.T) {
expectErr := errors.New("before error")
tests := []struct {
msg string
hooks []Hook
expect error
}{
{
msg: "first hook return err",
hooks: []Hook{
&testHook{
before: func(c *ContextHook) (ctx context.Context, err error) {
return c.Ctx, expectErr
},
},
&testHook{
before: func(c *ContextHook) (ctx context.Context, err error) {
return c.Ctx, nil
},
},
},
expect: expectErr,
},
{
msg: "second hook return err",
hooks: []Hook{
&testHook{
before: func(c *ContextHook) (ctx context.Context, err error) {
return c.Ctx, nil
},
},
&testHook{
before: func(c *ContextHook) (ctx context.Context, err error) {
return c.Ctx, expectErr
},
},
},
expect: expectErr,
},
}
for _, tt := range tests {
t.Run(tt.msg, func(t *testing.T) {
hooks := Hooks{}
hooks.AddHook(tt.hooks...)
_, err := hooks.BeforeProcess(&ContextHook{
Ctx: context.Background(),
})
if err != tt.expect {
t.Errorf("got %v, expect %v", err, tt.expect)
}
})
}
}
func TestAfterProcess(t *testing.T) {
expectErr := errors.New("expect err")
tests := []struct {
msg string
ctx *ContextHook
hooks []Hook
expect error
}{
{
msg: "context has err",
ctx: &ContextHook{
Ctx: context.Background(),
Err: expectErr,
},
hooks: []Hook{
&testHook{
after: func(c *ContextHook) error {
return errors.New("hook err")
},
},
},
expect: expectErr,
},
{
msg: "last hook has err",
ctx: &ContextHook{
Ctx: context.Background(),
Err: nil,
},
hooks: []Hook{
&testHook{
after: func(c *ContextHook) error {
return nil
},
},
&testHook{
after: func(c *ContextHook) error {
return expectErr
},
},
},
expect: expectErr,
},
}
for _, tt := range tests {
t.Run(tt.msg, func(t *testing.T) {
hooks := Hooks{}
hooks.AddHook(tt.hooks...)
err := hooks.AfterProcess(tt.ctx)
if err != tt.expect {
t.Errorf("got %v, expect %v", err, tt.expect)
}
})
}
}

View File

@ -12,8 +12,8 @@ import (
"reflect"
"regexp"
"sync"
"time"
"xorm.io/xorm/contexts"
"xorm.io/xorm/log"
"xorm.io/xorm/names"
)
@ -88,6 +88,7 @@ type DB struct {
reflectCache map[reflect.Type]*cacheStruct
reflectCacheMutex sync.RWMutex
Logger log.ContextLogger
hooks contexts.Hooks
}
// Open opens a database
@ -140,26 +141,14 @@ func (db *DB) reflectNew(typ reflect.Type) reflect.Value {
// QueryContext overwrites sql.DB.QueryContext
func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
start := time.Now()
showSQL := db.NeedLogSQL(ctx)
if showSQL {
db.Logger.BeforeSQL(log.LogContext{
Ctx: ctx,
SQL: query,
Args: args,
})
hookCtx := contexts.NewContextHook(ctx, query, args)
ctx, err := db.beforeProcess(hookCtx)
if err != nil {
return nil, err
}
rows, err := db.DB.QueryContext(ctx, query, args...)
if showSQL {
db.Logger.AfterSQL(log.LogContext{
Ctx: ctx,
SQL: query,
Args: args,
ExecuteTime: time.Now().Sub(start),
Err: err,
})
}
if err != nil {
hookCtx.End(ctx, nil, err)
if err := db.afterProcess(hookCtx); err != nil {
if rows != nil {
rows.Close()
}
@ -239,7 +228,7 @@ var (
re = regexp.MustCompile(`[?](\w+)`)
)
// ExecMapContext exec map with context.Context
// ExecMapContext exec map with context.ContextHook
// insert into (name) values (?)
// insert into (name) values (?name)
func (db *DB) ExecMapContext(ctx context.Context, query string, mp interface{}) (sql.Result, error) {
@ -263,28 +252,42 @@ func (db *DB) ExecStructContext(ctx context.Context, query string, st interface{
}
func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
start := time.Now()
showSQL := db.NeedLogSQL(ctx)
if showSQL {
db.Logger.BeforeSQL(log.LogContext{
Ctx: ctx,
SQL: query,
Args: args,
})
hookCtx := contexts.NewContextHook(ctx, query, args)
ctx, err := db.beforeProcess(hookCtx)
if err != nil {
return nil, err
}
res, err := db.DB.ExecContext(ctx, query, args...)
if showSQL {
db.Logger.AfterSQL(log.LogContext{
Ctx: ctx,
SQL: query,
Args: args,
ExecuteTime: time.Now().Sub(start),
Err: err,
})
hookCtx.End(ctx, res, err)
if err := db.afterProcess(hookCtx); err != nil {
return nil, err
}
return res, err
return res, nil
}
func (db *DB) ExecStruct(query string, st interface{}) (sql.Result, error) {
return db.ExecStructContext(context.Background(), query, st)
}
func (db *DB) beforeProcess(c *contexts.ContextHook) (context.Context, error) {
if db.NeedLogSQL(c.Ctx) {
db.Logger.BeforeSQL(log.LogContext(*c))
}
ctx, err := db.hooks.BeforeProcess(c)
if err != nil {
return nil, err
}
return ctx, nil
}
func (db *DB) afterProcess(c *contexts.ContextHook) error {
err := db.hooks.AfterProcess(c)
if db.NeedLogSQL(c.Ctx) {
db.Logger.AfterSQL(log.LogContext(*c))
}
return err
}
func (db *DB) AddHook(h ...contexts.Hook) {
db.hooks.AddHook(h...)
}

View File

@ -9,9 +9,8 @@ import (
"database/sql"
"errors"
"reflect"
"time"
"xorm.io/xorm/log"
"xorm.io/xorm/contexts"
)
// Stmt reprents a stmt objects
@ -30,28 +29,16 @@ func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
i++
return "?"
})
start := time.Now()
showSQL := db.NeedLogSQL(ctx)
if showSQL {
db.Logger.BeforeSQL(log.LogContext{
Ctx: ctx,
SQL: "PREPARE",
})
}
stmt, err := db.DB.PrepareContext(ctx, query)
if showSQL {
db.Logger.AfterSQL(log.LogContext{
Ctx: ctx,
SQL: "PREPARE",
ExecuteTime: time.Now().Sub(start),
Err: err,
})
}
hookCtx := contexts.NewContextHook(ctx, "PREPARE", nil)
ctx, err := db.beforeProcess(hookCtx)
if err != nil {
return nil, err
}
stmt, err := db.DB.PrepareContext(ctx, query)
hookCtx.End(ctx, nil, err)
if err := db.afterProcess(hookCtx); err != nil {
return nil, err
}
return &Stmt{stmt, db, names, query}, nil
}
@ -94,49 +81,28 @@ func (s *Stmt) ExecStruct(st interface{}) (sql.Result, error) {
}
func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) {
start := time.Now()
showSQL := s.db.NeedLogSQL(ctx)
if showSQL {
s.db.Logger.BeforeSQL(log.LogContext{
Ctx: ctx,
SQL: s.query,
Args: args,
})
hookCtx := contexts.NewContextHook(ctx, s.query, args)
ctx, err := s.db.beforeProcess(hookCtx)
if err != nil {
return nil, err
}
res, err := s.Stmt.ExecContext(ctx, args)
if showSQL {
s.db.Logger.AfterSQL(log.LogContext{
Ctx: ctx,
SQL: s.query,
Args: args,
ExecuteTime: time.Now().Sub(start),
Err: err,
})
hookCtx.End(ctx, res, err)
if err := s.db.afterProcess(hookCtx); err != nil {
return nil, err
}
return res, err
return res, nil
}
func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, error) {
start := time.Now()
showSQL := s.db.NeedLogSQL(ctx)
if showSQL {
s.db.Logger.BeforeSQL(log.LogContext{
Ctx: ctx,
SQL: s.query,
Args: args,
})
hookCtx := contexts.NewContextHook(ctx, s.query, args)
ctx, err := s.db.beforeProcess(hookCtx)
if err != nil {
return nil, err
}
rows, err := s.Stmt.QueryContext(ctx, args...)
if showSQL {
s.db.Logger.AfterSQL(log.LogContext{
Ctx: ctx,
SQL: s.query,
Args: args,
ExecuteTime: time.Now().Sub(start),
Err: err,
})
}
if err != nil {
hookCtx.End(ctx, nil, err)
if err := s.db.afterProcess(hookCtx); err != nil {
return nil, err
}
return &Rows{rows, s.db}, nil

View File

@ -7,9 +7,8 @@ package core
import (
"context"
"database/sql"
"time"
"xorm.io/xorm/log"
"xorm.io/xorm/contexts"
)
var (
@ -23,24 +22,14 @@ type Tx struct {
}
func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) {
start := time.Now()
showSQL := db.NeedLogSQL(ctx)
if showSQL {
db.Logger.BeforeSQL(log.LogContext{
Ctx: ctx,
SQL: "BEGIN TRANSACTION",
})
hookCtx := contexts.NewContextHook(ctx, "BEGIN TRANSACTION", nil)
ctx, err := db.beforeProcess(hookCtx)
if err != nil {
return nil, err
}
tx, err := db.DB.BeginTx(ctx, opts)
if showSQL {
db.Logger.AfterSQL(log.LogContext{
Ctx: ctx,
SQL: "BEGIN TRANSACTION",
ExecuteTime: time.Now().Sub(start),
Err: err,
})
}
if err != nil {
hookCtx.End(ctx, nil, err)
if err := db.afterProcess(hookCtx); err != nil {
return nil, err
}
return &Tx{tx, db}, nil
@ -58,25 +47,14 @@ func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
i++
return "?"
})
start := time.Now()
showSQL := tx.db.NeedLogSQL(ctx)
if showSQL {
tx.db.Logger.BeforeSQL(log.LogContext{
Ctx: ctx,
SQL: "PREPARE",
})
hookCtx := contexts.NewContextHook(ctx, "PREPARE", nil)
ctx, err := tx.db.beforeProcess(hookCtx)
if err != nil {
return nil, err
}
stmt, err := tx.Tx.PrepareContext(ctx, query)
if showSQL {
tx.db.Logger.AfterSQL(log.LogContext{
Ctx: ctx,
SQL: "PREPARE",
ExecuteTime: time.Now().Sub(start),
Err: err,
})
}
if err != nil {
hookCtx.End(ctx, nil, err)
if err := tx.db.afterProcess(hookCtx); err != nil {
return nil, err
}
return &Stmt{stmt, tx.db, names, query}, nil
@ -116,24 +94,15 @@ func (tx *Tx) ExecStructContext(ctx context.Context, query string, st interface{
}
func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
start := time.Now()
showSQL := tx.db.NeedLogSQL(ctx)
if showSQL {
tx.db.Logger.BeforeSQL(log.LogContext{
Ctx: ctx,
SQL: query,
Args: args,
})
hookCtx := contexts.NewContextHook(ctx, query, args)
ctx, err := tx.db.beforeProcess(hookCtx)
if err != nil {
return nil, err
}
res, err := tx.Tx.ExecContext(ctx, query, args...)
if showSQL {
tx.db.Logger.AfterSQL(log.LogContext{
Ctx: ctx,
SQL: query,
Args: args,
ExecuteTime: time.Now().Sub(start),
Err: err,
})
hookCtx.End(ctx, res, err)
if err := tx.db.afterProcess(hookCtx); err != nil {
return nil, err
}
return res, err
}
@ -143,26 +112,14 @@ func (tx *Tx) ExecStruct(query string, st interface{}) (sql.Result, error) {
}
func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
start := time.Now()
showSQL := tx.db.NeedLogSQL(ctx)
if showSQL {
tx.db.Logger.BeforeSQL(log.LogContext{
Ctx: ctx,
SQL: query,
Args: args,
})
hookCtx := contexts.NewContextHook(ctx, query, args)
ctx, err := tx.db.beforeProcess(hookCtx)
if err != nil {
return nil, err
}
rows, err := tx.Tx.QueryContext(ctx, query, args...)
if showSQL {
tx.db.Logger.AfterSQL(log.LogContext{
Ctx: ctx,
SQL: query,
Args: args,
ExecuteTime: time.Now().Sub(start),
Err: err,
})
}
if err != nil {
hookCtx.End(ctx, nil, err)
if err := tx.db.afterProcess(hookCtx); err != nil {
if rows != nil {
rows.Close()
}

View File

@ -18,6 +18,7 @@ import (
"time"
"xorm.io/xorm/caches"
"xorm.io/xorm/contexts"
"xorm.io/xorm/core"
"xorm.io/xorm/dialects"
"xorm.io/xorm/internal/utils"
@ -1287,6 +1288,10 @@ func (engine *Engine) SetSchema(schema string) {
engine.dialect.URI().SetSchema(schema)
}
func (engine *Engine) AddHook(hook contexts.Hook) {
engine.db.AddHook(hook)
}
// Unscoped always disable struct tag "deleted"
func (engine *Engine) Unscoped() *Session {
session := engine.NewSession()
@ -1298,7 +1303,7 @@ func (engine *Engine) tbNameWithSchema(v string) string {
return dialects.TableNameWithSchema(engine.dialect, v)
}
// Context creates a session with the context
// ContextHook creates a session with the context
func (engine *Engine) Context(ctx context.Context) *Session {
session := engine.NewSession()
session.isAutoClose = true

View File

@ -9,6 +9,7 @@ import (
"time"
"xorm.io/xorm/caches"
"xorm.io/xorm/contexts"
"xorm.io/xorm/dialects"
"xorm.io/xorm/log"
"xorm.io/xorm/names"
@ -78,7 +79,7 @@ func (eg *EngineGroup) Close() error {
return nil
}
// Context returned a group session
// ContextHook returned a group session
func (eg *EngineGroup) Context(ctx context.Context) *Session {
sess := eg.NewSession()
sess.isAutoClose = true
@ -143,6 +144,13 @@ func (eg *EngineGroup) SetLogger(logger interface{}) {
}
}
func (eg *EngineGroup) AddHook(hook contexts.Hook) {
eg.Engine.AddHook(hook)
for i := 0; i < len(eg.slaves); i++ {
eg.slaves[i].AddHook(hook)
}
}
// SetLogLevel sets the logger level
func (eg *EngineGroup) SetLogLevel(level log.LogLevel) {
eg.Engine.SetLogLevel(level)

View File

@ -11,6 +11,7 @@ import (
"time"
"xorm.io/xorm/caches"
"xorm.io/xorm/contexts"
"xorm.io/xorm/dialects"
"xorm.io/xorm/log"
"xorm.io/xorm/names"
@ -111,6 +112,7 @@ type EngineInterface interface {
SetTableMapper(names.Mapper)
SetTZDatabase(tz *time.Location)
SetTZLocation(tz *time.Location)
AddHook(hook contexts.Hook)
ShowSQL(show ...bool)
Sync(...interface{}) error
Sync2(...interface{}) error

View File

@ -5,19 +5,13 @@
package log
import (
"context"
"fmt"
"time"
"xorm.io/xorm/contexts"
)
// LogContext represents a log context
type LogContext struct {
Ctx context.Context
SQL string // log content or SQL
Args []interface{} // if it's a SQL, it's the arguments
ExecuteTime time.Duration
Err error // SQL executed error
}
type LogContext contexts.ContextHook
// SQLLogger represents an interface to log SQL
type SQLLogger interface {

View File

@ -887,7 +887,7 @@ func (session *Session) incrVersionFieldValue(fieldValue *reflect.Value) {
}
}
// Context sets the context on this session
// ContextHook sets the context on this session
func (session *Session) Context(ctx context.Context) *Session {
session.ctx = ctx
return session