Browse Source

Add Hook (#1644)

move hook to standalone package

add hook for engine

Co-authored-by: yuxiao.lu <yuxiao.lu@liulishuo.com>
Reviewed-on: #1644
Reviewed-by: Lunny Xiao <xiaolunwen@gmail.com>
tags/godror
limo.creed Lunny Xiao <xiaolunwen@gmail.com> 1 month ago
parent
commit
34dc7f8791
10 changed files with 323 additions and 173 deletions
  1. +75
    -0
      contexts/hook.go
  2. +140
    -0
      contexts/hook_test.go
  3. +40
    -37
      core/db.go
  4. +21
    -55
      core/stmt.go
  5. +26
    -69
      core/tx.go
  6. +6
    -1
      engine.go
  7. +9
    -1
      engine_group.go
  8. +2
    -0
      interface.go
  9. +3
    -9
      log/logger_context.go
  10. +1
    -1
      session.go

+ 75
- 0
contexts/hook.go View File

@@ -0,0 +1,75 @@
// Copyright 2020 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package contexts

import (
"context"
"database/sql"
"time"
)

// ContextHook represents a hook context
type ContextHook struct {
start time.Time
Ctx context.Context
SQL string // log content or SQL
Args []interface{} // if it's a SQL, it's the arguments
Result sql.Result
ExecuteTime time.Duration
Err error // SQL executed error
}

// NewContextHook return context for hook
func NewContextHook(ctx context.Context, sql string, args []interface{}) *ContextHook {
return &ContextHook{
start: time.Now(),
Ctx: ctx,
SQL: sql,
Args: args,
}
}

func (c *ContextHook) End(ctx context.Context, result sql.Result, err error) {
c.Ctx = ctx
c.Result = result
c.Err = err
c.ExecuteTime = time.Now().Sub(c.start)
}

type Hook interface {
BeforeProcess(c *ContextHook) (context.Context, error)
AfterProcess(c *ContextHook) error
}

type Hooks struct {
hooks []Hook
}

func (h *Hooks) AddHook(hooks ...Hook) {
h.hooks = append(h.hooks, hooks...)
}

func (h *Hooks) BeforeProcess(c *ContextHook) (context.Context, error) {
ctx := c.Ctx
for _, h := range h.hooks {
var err error
ctx, err = h.BeforeProcess(c)
if err != nil {
return nil, err
}
}
return ctx, nil
}

func (h *Hooks) AfterProcess(c *ContextHook) error {
firstErr := c.Err
for _, h := range h.hooks {
err := h.AfterProcess(c)
if err != nil && firstErr == nil {
firstErr = err
}
}
return firstErr
}

+ 140
- 0
contexts/hook_test.go View File

@@ -0,0 +1,140 @@
package contexts

import (
"context"
"errors"
"testing"
)

type testHook struct {
before func(c *ContextHook) (context.Context, error)
after func(c *ContextHook) error
}

func (h *testHook) BeforeProcess(c *ContextHook) (context.Context, error) {
if h.before != nil {
return h.before(c)
}
return c.Ctx, nil
}

func (h *testHook) AfterProcess(c *ContextHook) error {
if h.after != nil {
return h.after(c)
}
return c.Err
}

var _ Hook = &testHook{}

func TestBeforeProcess(t *testing.T) {
expectErr := errors.New("before error")
tests := []struct {
msg string
hooks []Hook
expect error
}{
{
msg: "first hook return err",
hooks: []Hook{
&testHook{
before: func(c *ContextHook) (ctx context.Context, err error) {
return c.Ctx, expectErr
},
},
&testHook{
before: func(c *ContextHook) (ctx context.Context, err error) {
return c.Ctx, nil
},
},
},
expect: expectErr,
},
{
msg: "second hook return err",
hooks: []Hook{
&testHook{
before: func(c *ContextHook) (ctx context.Context, err error) {
return c.Ctx, nil
},
},
&testHook{
before: func(c *ContextHook) (ctx context.Context, err error) {
return c.Ctx, expectErr
},
},
},
expect: expectErr,
},
}

for _, tt := range tests {
t.Run(tt.msg, func(t *testing.T) {
hooks := Hooks{}
hooks.AddHook(tt.hooks...)
_, err := hooks.BeforeProcess(&ContextHook{
Ctx: context.Background(),
})
if err != tt.expect {
t.Errorf("got %v, expect %v", err, tt.expect)
}
})
}
}

func TestAfterProcess(t *testing.T) {
expectErr := errors.New("expect err")
tests := []struct {
msg string
ctx *ContextHook
hooks []Hook
expect error
}{
{
msg: "context has err",
ctx: &ContextHook{
Ctx: context.Background(),
Err: expectErr,
},
hooks: []Hook{
&testHook{
after: func(c *ContextHook) error {
return errors.New("hook err")
},
},
},
expect: expectErr,
},
{
msg: "last hook has err",
ctx: &ContextHook{
Ctx: context.Background(),
Err: nil,
},
hooks: []Hook{
&testHook{
after: func(c *ContextHook) error {
return nil
},
},
&testHook{
after: func(c *ContextHook) error {
return expectErr
},
},
},
expect: expectErr,
},
}

for _, tt := range tests {
t.Run(tt.msg, func(t *testing.T) {
hooks := Hooks{}
hooks.AddHook(tt.hooks...)
err := hooks.AfterProcess(tt.ctx)
if err != tt.expect {
t.Errorf("got %v, expect %v", err, tt.expect)
}
})
}
}

+ 40
- 37
core/db.go View File

@@ -12,8 +12,8 @@ import (
"reflect"
"regexp"
"sync"
"time"

"xorm.io/xorm/contexts"
"xorm.io/xorm/log"
"xorm.io/xorm/names"
)
@@ -88,6 +88,7 @@ type DB struct {
reflectCache map[reflect.Type]*cacheStruct
reflectCacheMutex sync.RWMutex
Logger log.ContextLogger
hooks contexts.Hooks
}

// Open opens a database
@@ -140,26 +141,14 @@ func (db *DB) reflectNew(typ reflect.Type) reflect.Value {

// QueryContext overwrites sql.DB.QueryContext
func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
start := time.Now()
showSQL := db.NeedLogSQL(ctx)
if showSQL {
db.Logger.BeforeSQL(log.LogContext{
Ctx: ctx,
SQL: query,
Args: args,
})
hookCtx := contexts.NewContextHook(ctx, query, args)
ctx, err := db.beforeProcess(hookCtx)
if err != nil {
return nil, err
}
rows, err := db.DB.QueryContext(ctx, query, args...)
if showSQL {
db.Logger.AfterSQL(log.LogContext{
Ctx: ctx,
SQL: query,
Args: args,
ExecuteTime: time.Now().Sub(start),
Err: err,
})
}
if err != nil {
hookCtx.End(ctx, nil, err)
if err := db.afterProcess(hookCtx); err != nil {
if rows != nil {
rows.Close()
}
@@ -239,7 +228,7 @@ var (
re = regexp.MustCompile(`[?](\w+)`)
)

// ExecMapContext exec map with context.Context
// ExecMapContext exec map with context.ContextHook
// insert into (name) values (?)
// insert into (name) values (?name)
func (db *DB) ExecMapContext(ctx context.Context, query string, mp interface{}) (sql.Result, error) {
@@ -263,28 +252,42 @@ func (db *DB) ExecStructContext(ctx context.Context, query string, st interface{
}

func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
start := time.Now()
showSQL := db.NeedLogSQL(ctx)
if showSQL {
db.Logger.BeforeSQL(log.LogContext{
Ctx: ctx,
SQL: query,
Args: args,
})
hookCtx := contexts.NewContextHook(ctx, query, args)
ctx, err := db.beforeProcess(hookCtx)
if err != nil {
return nil, err
}
res, err := db.DB.ExecContext(ctx, query, args...)
if showSQL {
db.Logger.AfterSQL(log.LogContext{
Ctx: ctx,
SQL: query,
Args: args,
ExecuteTime: time.Now().Sub(start),
Err: err,
})
hookCtx.End(ctx, res, err)
if err := db.afterProcess(hookCtx); err != nil {
return nil, err
}
return res, err
return res, nil
}

func (db *DB) ExecStruct(query string, st interface{}) (sql.Result, error) {
return db.ExecStructContext(context.Background(), query, st)
}

func (db *DB) beforeProcess(c *contexts.ContextHook) (context.Context, error) {
if db.NeedLogSQL(c.Ctx) {
db.Logger.BeforeSQL(log.LogContext(*c))
}
ctx, err := db.hooks.BeforeProcess(c)
if err != nil {
return nil, err
}
return ctx, nil
}

func (db *DB) afterProcess(c *contexts.ContextHook) error {
err := db.hooks.AfterProcess(c)
if db.NeedLogSQL(c.Ctx) {
db.Logger.AfterSQL(log.LogContext(*c))
}
return err
}

func (db *DB) AddHook(h ...contexts.Hook) {
db.hooks.AddHook(h...)
}

+ 21
- 55
core/stmt.go View File

@@ -9,9 +9,8 @@ import (
"database/sql"
"errors"
"reflect"
"time"

"xorm.io/xorm/log"
"xorm.io/xorm/contexts"
)

// Stmt reprents a stmt objects
@@ -30,28 +29,16 @@ func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
i++
return "?"
})

start := time.Now()
showSQL := db.NeedLogSQL(ctx)
if showSQL {
db.Logger.BeforeSQL(log.LogContext{
Ctx: ctx,
SQL: "PREPARE",
})
hookCtx := contexts.NewContextHook(ctx, "PREPARE", nil)
ctx, err := db.beforeProcess(hookCtx)
if err != nil {
return nil, err
}
stmt, err := db.DB.PrepareContext(ctx, query)
if showSQL {
db.Logger.AfterSQL(log.LogContext{
Ctx: ctx,
SQL: "PREPARE",
ExecuteTime: time.Now().Sub(start),
Err: err,
})
}
if err != nil {
hookCtx.End(ctx, nil, err)
if err := db.afterProcess(hookCtx); err != nil {
return nil, err
}

return &Stmt{stmt, db, names, query}, nil
}

@@ -94,49 +81,28 @@ func (s *Stmt) ExecStruct(st interface{}) (sql.Result, error) {
}

func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) {
start := time.Now()
showSQL := s.db.NeedLogSQL(ctx)
if showSQL {
s.db.Logger.BeforeSQL(log.LogContext{
Ctx: ctx,
SQL: s.query,
Args: args,
})
hookCtx := contexts.NewContextHook(ctx, s.query, args)
ctx, err := s.db.beforeProcess(hookCtx)
if err != nil {
return nil, err
}
res, err := s.Stmt.ExecContext(ctx, args)
if showSQL {
s.db.Logger.AfterSQL(log.LogContext{
Ctx: ctx,
SQL: s.query,
Args: args,
ExecuteTime: time.Now().Sub(start),
Err: err,
})
hookCtx.End(ctx, res, err)
if err := s.db.afterProcess(hookCtx); err != nil {
return nil, err
}
return res, err
return res, nil
}

func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, error) {
start := time.Now()
showSQL := s.db.NeedLogSQL(ctx)
if showSQL {
s.db.Logger.BeforeSQL(log.LogContext{
Ctx: ctx,
SQL: s.query,
Args: args,
})
hookCtx := contexts.NewContextHook(ctx, s.query, args)
ctx, err := s.db.beforeProcess(hookCtx)
if err != nil {
return nil, err
}
rows, err := s.Stmt.QueryContext(ctx, args...)
if showSQL {
s.db.Logger.AfterSQL(log.LogContext{
Ctx: ctx,
SQL: s.query,
Args: args,
ExecuteTime: time.Now().Sub(start),
Err: err,
})
}
if err != nil {
hookCtx.End(ctx, nil, err)
if err := s.db.afterProcess(hookCtx); err != nil {
return nil, err
}
return &Rows{rows, s.db}, nil


+ 26
- 69
core/tx.go View File

@@ -7,9 +7,8 @@ package core
import (
"context"
"database/sql"
"time"

"xorm.io/xorm/log"
"xorm.io/xorm/contexts"
)

var (
@@ -23,24 +22,14 @@ type Tx struct {
}

func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) {
start := time.Now()
showSQL := db.NeedLogSQL(ctx)
if showSQL {
db.Logger.BeforeSQL(log.LogContext{
Ctx: ctx,
SQL: "BEGIN TRANSACTION",
})
hookCtx := contexts.NewContextHook(ctx, "BEGIN TRANSACTION", nil)
ctx, err := db.beforeProcess(hookCtx)
if err != nil {
return nil, err
}
tx, err := db.DB.BeginTx(ctx, opts)
if showSQL {
db.Logger.AfterSQL(log.LogContext{
Ctx: ctx,
SQL: "BEGIN TRANSACTION",
ExecuteTime: time.Now().Sub(start),
Err: err,
})
}
if err != nil {
hookCtx.End(ctx, nil, err)
if err := db.afterProcess(hookCtx); err != nil {
return nil, err
}
return &Tx{tx, db}, nil
@@ -58,25 +47,14 @@ func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
i++
return "?"
})

start := time.Now()
showSQL := tx.db.NeedLogSQL(ctx)
if showSQL {
tx.db.Logger.BeforeSQL(log.LogContext{
Ctx: ctx,
SQL: "PREPARE",
})
hookCtx := contexts.NewContextHook(ctx, "PREPARE", nil)
ctx, err := tx.db.beforeProcess(hookCtx)
if err != nil {
return nil, err
}
stmt, err := tx.Tx.PrepareContext(ctx, query)
if showSQL {
tx.db.Logger.AfterSQL(log.LogContext{
Ctx: ctx,
SQL: "PREPARE",
ExecuteTime: time.Now().Sub(start),
Err: err,
})
}
if err != nil {
hookCtx.End(ctx, nil, err)
if err := tx.db.afterProcess(hookCtx); err != nil {
return nil, err
}
return &Stmt{stmt, tx.db, names, query}, nil
@@ -116,24 +94,15 @@ func (tx *Tx) ExecStructContext(ctx context.Context, query string, st interface{
}

func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
start := time.Now()
showSQL := tx.db.NeedLogSQL(ctx)
if showSQL {
tx.db.Logger.BeforeSQL(log.LogContext{
Ctx: ctx,
SQL: query,
Args: args,
})
hookCtx := contexts.NewContextHook(ctx, query, args)
ctx, err := tx.db.beforeProcess(hookCtx)
if err != nil {
return nil, err
}
res, err := tx.Tx.ExecContext(ctx, query, args...)
if showSQL {
tx.db.Logger.AfterSQL(log.LogContext{
Ctx: ctx,
SQL: query,
Args: args,
ExecuteTime: time.Now().Sub(start),
Err: err,
})
hookCtx.End(ctx, res, err)
if err := tx.db.afterProcess(hookCtx); err != nil {
return nil, err
}
return res, err
}
@@ -143,26 +112,14 @@ func (tx *Tx) ExecStruct(query string, st interface{}) (sql.Result, error) {
}

func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
start := time.Now()
showSQL := tx.db.NeedLogSQL(ctx)
if showSQL {
tx.db.Logger.BeforeSQL(log.LogContext{
Ctx: ctx,
SQL: query,
Args: args,
})
hookCtx := contexts.NewContextHook(ctx, query, args)
ctx, err := tx.db.beforeProcess(hookCtx)
if err != nil {
return nil, err
}
rows, err := tx.Tx.QueryContext(ctx, query, args...)
if showSQL {
tx.db.Logger.AfterSQL(log.LogContext{
Ctx: ctx,
SQL: query,
Args: args,
ExecuteTime: time.Now().Sub(start),
Err: err,
})
}
if err != nil {
hookCtx.End(ctx, nil, err)
if err := tx.db.afterProcess(hookCtx); err != nil {
if rows != nil {
rows.Close()
}


+ 6
- 1
engine.go View File

@@ -18,6 +18,7 @@ import (
"time"

"xorm.io/xorm/caches"
"xorm.io/xorm/contexts"
"xorm.io/xorm/core"
"xorm.io/xorm/dialects"
"xorm.io/xorm/internal/utils"
@@ -1287,6 +1288,10 @@ func (engine *Engine) SetSchema(schema string) {
engine.dialect.URI().SetSchema(schema)
}

func (engine *Engine) AddHook(hook contexts.Hook) {
engine.db.AddHook(hook)
}

// Unscoped always disable struct tag "deleted"
func (engine *Engine) Unscoped() *Session {
session := engine.NewSession()
@@ -1298,7 +1303,7 @@ func (engine *Engine) tbNameWithSchema(v string) string {
return dialects.TableNameWithSchema(engine.dialect, v)
}

// Context creates a session with the context
// ContextHook creates a session with the context
func (engine *Engine) Context(ctx context.Context) *Session {
session := engine.NewSession()
session.isAutoClose = true


+ 9
- 1
engine_group.go View File

@@ -9,6 +9,7 @@ import (
"time"

"xorm.io/xorm/caches"
"xorm.io/xorm/contexts"
"xorm.io/xorm/dialects"
"xorm.io/xorm/log"
"xorm.io/xorm/names"
@@ -78,7 +79,7 @@ func (eg *EngineGroup) Close() error {
return nil
}

// Context returned a group session
// ContextHook returned a group session
func (eg *EngineGroup) Context(ctx context.Context) *Session {
sess := eg.NewSession()
sess.isAutoClose = true
@@ -143,6 +144,13 @@ func (eg *EngineGroup) SetLogger(logger interface{}) {
}
}

func (eg *EngineGroup) AddHook(hook contexts.Hook) {
eg.Engine.AddHook(hook)
for i := 0; i < len(eg.slaves); i++ {
eg.slaves[i].AddHook(hook)
}
}

// SetLogLevel sets the logger level
func (eg *EngineGroup) SetLogLevel(level log.LogLevel) {
eg.Engine.SetLogLevel(level)


+ 2
- 0
interface.go View File

@@ -11,6 +11,7 @@ import (
"time"

"xorm.io/xorm/caches"
"xorm.io/xorm/contexts"
"xorm.io/xorm/dialects"
"xorm.io/xorm/log"
"xorm.io/xorm/names"
@@ -111,6 +112,7 @@ type EngineInterface interface {
SetTableMapper(names.Mapper)
SetTZDatabase(tz *time.Location)
SetTZLocation(tz *time.Location)
AddHook(hook contexts.Hook)
ShowSQL(show ...bool)
Sync(...interface{}) error
Sync2(...interface{}) error


+ 3
- 9
log/logger_context.go View File

@@ -5,19 +5,13 @@
package log

import (
"context"
"fmt"
"time"

"xorm.io/xorm/contexts"
)

// LogContext represents a log context
type LogContext struct {
Ctx context.Context
SQL string // log content or SQL
Args []interface{} // if it's a SQL, it's the arguments
ExecuteTime time.Duration
Err error // SQL executed error
}
type LogContext contexts.ContextHook

// SQLLogger represents an interface to log SQL
type SQLLogger interface {


+ 1
- 1
session.go View File

@@ -887,7 +887,7 @@ func (session *Session) incrVersionFieldValue(fieldValue *reflect.Value) {
}
}

// Context sets the context on this session
// ContextHook sets the context on this session
func (session *Session) Context(ctx context.Context) *Session {
session.ctx = ctx
return session


Loading…
Cancel
Save