Browse Source

Use a new ContextLogger interface to implement logger (#1557)

Fix bug

Add log track on prepare & tx

Some improvements

remove unused codes

refactor logger

Fix bug

log context

add ContextLogger interface

Reviewed-on: #1557
tags/v1.0.0
Lunny Xiao 4 weeks ago
parent
commit
41388c2f56
33 changed files with 609 additions and 355 deletions
  1. +44
    -2
      core/db.go
  2. +62
    -3
      core/stmt.go
  3. +79
    -8
      core/tx.go
  4. +8
    -38
      dialects/dialect.go
  5. +0
    -3
      dialects/mssql.go
  6. +10
    -9
      dialects/mysql.go
  7. +0
    -4
      dialects/oracle.go
  8. +0
    -6
      dialects/postgres.go
  9. +2
    -10
      dialects/sqlite3.go
  10. +44
    -70
      engine.go
  11. +1
    -9
      engine_group.go
  12. +3
    -4
      interface.go
  13. +9
    -9
      internal/statements/statement.go
  14. +1
    -1
      internal/statements/update.go
  15. +108
    -0
      log/logger_context.go
  16. +3
    -5
      schemas/table.go
  17. +7
    -5
      schemas/type.go
  18. +7
    -18
      session.go
  19. +4
    -14
      session_convert.go
  20. +2
    -2
      session_delete.go
  21. +14
    -10
      session_find.go
  22. +7
    -7
      session_get.go
  23. +7
    -7
      session_insert.go
  24. +4
    -41
      session_raw.go
  25. +1
    -1
      session_schema.go
  26. +67
    -33
      session_tx.go
  27. +10
    -10
      session_update.go
  28. +57
    -22
      tags/parser.go
  29. +44
    -0
      tags/parser_test.go
  30. +1
    -1
      tags/tag.go
  31. +1
    -1
      tags_test.go
  32. +1
    -1
      xorm.go
  33. +1
    -1
      xorm_test.go

+ 44
- 2
core/db.go View File

@@ -12,7 +12,9 @@ import (
"reflect"
"regexp"
"sync"
"time"

"xorm.io/xorm/log"
"xorm.io/xorm/names"
)

@@ -81,6 +83,7 @@ type DB struct {
Mapper names.Mapper
reflectCache map[reflect.Type]*cacheStruct
reflectCacheMutex sync.RWMutex
Logger log.SQLLogger
}

// Open opens a database
@@ -120,7 +123,24 @@ func (db *DB) reflectNew(typ reflect.Type) reflect.Value {

// QueryContext overwrites sql.DB.QueryContext
func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
start := time.Now()
if db.Logger != nil {
db.Logger.BeforeSQL(log.LogContext{
Ctx: ctx,
SQL: query,
Args: args,
})
}
rows, err := db.DB.QueryContext(ctx, query, args...)
if db.Logger != nil {
db.Logger.AfterSQL(log.LogContext{
Ctx: ctx,
SQL: query,
Args: args,
ExecuteTime: time.Now().Sub(start),
Err: err,
})
}
if err != nil {
if rows != nil {
rows.Close()
@@ -209,7 +229,7 @@ func (db *DB) ExecMapContext(ctx context.Context, query string, mp interface{})
if err != nil {
return nil, err
}
return db.DB.ExecContext(ctx, query, args...)
return db.ExecContext(ctx, query, args...)
}

func (db *DB) ExecMap(query string, mp interface{}) (sql.Result, error) {
@@ -221,7 +241,29 @@ func (db *DB) ExecStructContext(ctx context.Context, query string, st interface{
if err != nil {
return nil, err
}
return db.DB.ExecContext(ctx, query, args...)
return db.ExecContext(ctx, query, args...)
}

func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
start := time.Now()
if db.Logger != nil {
db.Logger.BeforeSQL(log.LogContext{
Ctx: ctx,
SQL: query,
Args: args,
})
}
res, err := db.DB.ExecContext(ctx, query, args...)
if db.Logger != nil {
db.Logger.AfterSQL(log.LogContext{
Ctx: ctx,
SQL: query,
Args: args,
ExecuteTime: time.Now().Sub(start),
Err: err,
})
}
return res, err
}

func (db *DB) ExecStruct(query string, st interface{}) (sql.Result, error) {


+ 62
- 3
core/stmt.go View File

@@ -9,6 +9,9 @@ import (
"database/sql"
"errors"
"reflect"
"time"

"xorm.io/xorm/log"
)

// Stmt reprents a stmt objects
@@ -16,6 +19,7 @@ type Stmt struct {
*sql.Stmt
db *DB
names map[string]int
query string
}

func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
@@ -27,11 +31,27 @@ func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
return "?"
})

start := time.Now()
if db.Logger != nil {
db.Logger.BeforeSQL(log.LogContext{
Ctx: ctx,
SQL: "PREPARE",
})
}
stmt, err := db.DB.PrepareContext(ctx, query)
if db.Logger != nil {
db.Logger.AfterSQL(log.LogContext{
Ctx: ctx,
SQL: "PREPARE",
ExecuteTime: time.Now().Sub(start),
Err: err,
})
}
if err != nil {
return nil, err
}
return &Stmt{stmt, db, names}, nil

return &Stmt{stmt, db, names, query}, nil
}

func (db *DB) Prepare(query string) (*Stmt, error) {
@@ -48,7 +68,7 @@ func (s *Stmt) ExecMapContext(ctx context.Context, mp interface{}) (sql.Result,
for k, i := range s.names {
args[i] = vv.Elem().MapIndex(reflect.ValueOf(k)).Interface()
}
return s.Stmt.ExecContext(ctx, args...)
return s.ExecContext(ctx, args...)
}

func (s *Stmt) ExecMap(mp interface{}) (sql.Result, error) {
@@ -65,15 +85,54 @@ func (s *Stmt) ExecStructContext(ctx context.Context, st interface{}) (sql.Resul
for k, i := range s.names {
args[i] = vv.Elem().FieldByName(k).Interface()
}
return s.Stmt.ExecContext(ctx, args...)
return s.ExecContext(ctx, args...)
}

func (s *Stmt) ExecStruct(st interface{}) (sql.Result, error) {
return s.ExecStructContext(context.Background(), st)
}

func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) {
start := time.Now()
if s.db.Logger != nil {
s.db.Logger.BeforeSQL(log.LogContext{
Ctx: ctx,
SQL: s.query,
Args: args,
})
}
res, err := s.Stmt.ExecContext(ctx, args)
if s.db.Logger != nil {
s.db.Logger.AfterSQL(log.LogContext{
Ctx: ctx,
SQL: s.query,
Args: args,
ExecuteTime: time.Now().Sub(start),
Err: err,
})
}
return res, err
}

func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, error) {
start := time.Now()
if s.db.Logger != nil {
s.db.Logger.BeforeSQL(log.LogContext{
Ctx: ctx,
SQL: s.query,
Args: args,
})
}
rows, err := s.Stmt.QueryContext(ctx, args...)
if s.db.Logger != nil {
s.db.Logger.AfterSQL(log.LogContext{
Ctx: ctx,
SQL: s.query,
Args: args,
ExecuteTime: time.Now().Sub(start),
Err: err,
})
}
if err != nil {
return nil, err
}


+ 79
- 8
core/tx.go View File

@@ -7,6 +7,9 @@ package core
import (
"context"
"database/sql"
"time"

"xorm.io/xorm/log"
)

type Tx struct {
@@ -15,7 +18,22 @@ type Tx struct {
}

func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) {
start := time.Now()
if db.Logger != nil {
db.Logger.BeforeSQL(log.LogContext{
Ctx: ctx,
SQL: "BEGIN TRANSACTION",
})
}
tx, err := db.DB.BeginTx(ctx, opts)
if db.Logger != nil {
db.Logger.AfterSQL(log.LogContext{
Ctx: ctx,
SQL: "BEGIN TRANSACTION",
ExecuteTime: time.Now().Sub(start),
Err: err,
})
}
if err != nil {
return nil, err
}
@@ -23,11 +41,7 @@ func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) {
}

func (db *DB) Begin() (*Tx, error) {
tx, err := db.DB.Begin()
if err != nil {
return nil, err
}
return &Tx{tx, db}, nil
return db.BeginTx(context.Background(), nil)
}

func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
@@ -39,11 +53,26 @@ func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
return "?"
})

start := time.Now()
if tx.db.Logger != nil {
tx.db.Logger.BeforeSQL(log.LogContext{
Ctx: ctx,
SQL: "PREPARE",
})
}
stmt, err := tx.Tx.PrepareContext(ctx, query)
if tx.db.Logger != nil {
tx.db.Logger.AfterSQL(log.LogContext{
Ctx: ctx,
SQL: "PREPARE",
ExecuteTime: time.Now().Sub(start),
Err: err,
})
}
if err != nil {
return nil, err
}
return &Stmt{stmt, tx.db, names}, nil
return &Stmt{stmt, tx.db, names, query}, nil
}

func (tx *Tx) Prepare(query string) (*Stmt, error) {
@@ -64,7 +93,7 @@ func (tx *Tx) ExecMapContext(ctx context.Context, query string, mp interface{})
if err != nil {
return nil, err
}
return tx.Tx.ExecContext(ctx, query, args...)
return tx.ExecContext(ctx, query, args...)
}

func (tx *Tx) ExecMap(query string, mp interface{}) (sql.Result, error) {
@@ -76,7 +105,29 @@ func (tx *Tx) ExecStructContext(ctx context.Context, query string, st interface{
if err != nil {
return nil, err
}
return tx.Tx.ExecContext(ctx, query, args...)
return tx.ExecContext(ctx, query, args...)
}

func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
start := time.Now()
if tx.db.Logger != nil {
tx.db.Logger.BeforeSQL(log.LogContext{
Ctx: ctx,
SQL: query,
Args: args,
})
}
res, err := tx.Tx.ExecContext(ctx, query, args...)
if tx.db.Logger != nil {
tx.db.Logger.AfterSQL(log.LogContext{
Ctx: ctx,
SQL: query,
Args: args,
ExecuteTime: time.Now().Sub(start),
Err: err,
})
}
return res, err
}

func (tx *Tx) ExecStruct(query string, st interface{}) (sql.Result, error) {
@@ -84,8 +135,28 @@ func (tx *Tx) ExecStruct(query string, st interface{}) (sql.Result, error) {
}

func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
start := time.Now()
if tx.db.Logger != nil {
tx.db.Logger.BeforeSQL(log.LogContext{
Ctx: ctx,
SQL: query,
Args: args,
})
}
rows, err := tx.Tx.QueryContext(ctx, query, args...)
if tx.db.Logger != nil {
tx.db.Logger.AfterSQL(log.LogContext{
Ctx: ctx,
SQL: query,
Args: args,
ExecuteTime: time.Now().Sub(start),
Err: err,
})
}
if err != nil {
if rows != nil {
rows.Close()
}
return nil, err
}
return &Rows{rows, tx.db}, nil


+ 8
- 38
dialects/dialect.go View File

@@ -11,14 +11,11 @@ import (
"time"

"xorm.io/xorm/core"
"xorm.io/xorm/log"
"xorm.io/xorm/schemas"
)

type DBType string

type URI struct {
DBType DBType
DBType schemas.DBType
Proto string
Host string
Port string
@@ -32,13 +29,12 @@ type URI struct {
Schema string
}

// a dialect is a driver's wrapper
// Dialect represents a kind of database
type Dialect interface {
SetLogger(logger log.Logger)
Init(*core.DB, *URI, string, string) error
URI() *URI
DB() *core.DB
DBType() DBType
DBType() schemas.DBType
SQLType(*schemas.Column) string
FormatBytes(b []byte) string
DefaultSchema() string
@@ -49,7 +45,6 @@ type Dialect interface {
IsReserved(string) bool
Quoter() schemas.Quoter

RollBackStr() string
AutoIncrStr() string

SupportInsertMany() bool
@@ -92,7 +87,6 @@ type Base struct {
dialect Dialect
driverName string
dataSourceName string
logger log.Logger
uri *URI
}

@@ -100,10 +94,6 @@ func (b *Base) DB() *core.DB {
return b.db
}

func (b *Base) SetLogger(logger log.Logger) {
b.logger = logger
}

func (b *Base) DefaultSchema() string {
return ""
}
@@ -118,7 +108,7 @@ func (b *Base) URI() *URI {
return b.uri
}

func (b *Base) DBType() DBType {
func (b *Base) DBType() schemas.DBType {
return b.uri.DBType
}

@@ -187,10 +177,6 @@ func (b *Base) DataSourceName() string {
return b.dataSourceName
}

func (db *Base) RollBackStr() string {
return "ROLL BACK"
}

func (db *Base) SupportDropIfExists() bool {
return true
}
@@ -201,7 +187,6 @@ func (db *Base) DropTableSQL(tableName string) string {
}

func (db *Base) HasRecords(ctx context.Context, query string, args ...interface{}) (bool, error) {
db.LogSQL(query, args)
rows, err := db.DB().QueryContext(ctx, query, args...)
if err != nil {
return false, err
@@ -229,13 +214,8 @@ func (db *Base) IsColumnExist(ctx context.Context, tableName, colName string) (b
}

func (db *Base) AddColumnSQL(tableName string, col *schemas.Column) string {
quoter := db.dialect.Quoter()
sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quoter.Quote(tableName),
return fmt.Sprintf("ALTER TABLE %v ADD %v", db.dialect.Quoter().Quote(tableName),
db.String(col))
if db.dialect.DBType() == schemas.MYSQL && len(col.Comment) > 0 {
sql += " COMMENT '" + col.Comment + "'"
}
return sql
}

func (db *Base) CreateIndexSQL(tableName string, index *schemas.Index) string {
@@ -323,16 +303,6 @@ func (b *Base) ForUpdateSQL(query string) string {
return query + " FOR UPDATE"
}

func (b *Base) LogSQL(sql string, args []interface{}) {
if b.logger != nil && b.logger.IsShowSQL() {
if len(args) > 0 {
b.logger.Infof("[SQL] %v %v", sql, args)
} else {
b.logger.Infof("[SQL] %v", sql)
}
}
}

func (b *Base) SetParams(params map[string]string) {
}

@@ -341,7 +311,7 @@ var (
)

// RegisterDialect register database dialect
func RegisterDialect(dbName DBType, dialectFunc func() Dialect) {
func RegisterDialect(dbName schemas.DBType, dialectFunc func() Dialect) {
if dialectFunc == nil {
panic("core: Register dialect is nil")
}
@@ -349,7 +319,7 @@ func RegisterDialect(dbName DBType, dialectFunc func() Dialect) {
}

// QueryDialect query if registered database dialect
func QueryDialect(dbName DBType) Dialect {
func QueryDialect(dbName schemas.DBType) Dialect {
if d, ok := dialects[strings.ToLower(string(dbName))]; ok {
return d()
}
@@ -358,7 +328,7 @@ func QueryDialect(dbName DBType) Dialect {

func regDrvsNDialects() bool {
providedDrvsNDialects := map[string]struct {
dbType DBType
dbType schemas.DBType
getDriver func() Driver
getDialect func() Dialect
}{


+ 0
- 3
dialects/mssql.go View File

@@ -351,7 +351,6 @@ func (db *mssql) GetColumns(ctx context.Context, tableName string) ([]string, ma
LEFT OUTER JOIN
sys.indexes i ON ic.object_id = i.object_id AND ic.index_id = i.index_id
where a.object_id=object_id('` + tableName + `')`
db.LogSQL(s, args)

rows, err := db.DB().QueryContext(ctx, s, args...)
if err != nil {
@@ -411,7 +410,6 @@ func (db *mssql) GetColumns(ctx context.Context, tableName string) ([]string, ma
func (db *mssql) GetTables(ctx context.Context) ([]*schemas.Table, error) {
args := []interface{}{}
s := `select name from sysobjects where xtype ='U'`
db.LogSQL(s, args)

rows, err := db.DB().QueryContext(ctx, s, args...)
if err != nil {
@@ -446,7 +444,6 @@ INNER JOIN SYS.COLUMNS C ON IXS.OBJECT_ID=C.OBJECT_ID
AND IXCS.COLUMN_ID=C.COLUMN_ID
WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =?
`
db.LogSQL(s, args)

rows, err := db.DB().QueryContext(ctx, s, args...)
if err != nil {


+ 10
- 9
dialects/mysql.go View File

@@ -303,23 +303,26 @@ func (db *mysql) IndexCheckSQL(tableName, idxName string) (string, []interface{}
return sql, args
}

/*func (db *mysql) ColumnCheckSql(tableName, colName string) (string, []interface{}) {
args := []interface{}{db.DbName, tableName, colName}
sql := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?"
return sql, args
}*/

func (db *mysql) TableCheckSQL(tableName string) (string, []interface{}) {
args := []interface{}{db.uri.DBName, tableName}
sql := "SELECT `TABLE_NAME` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? and `TABLE_NAME`=?"
return sql, args
}

func (db *mysql) AddColumnSQL(tableName string, col *schemas.Column) string {
quoter := db.dialect.Quoter()
sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quoter.Quote(tableName),
db.String(col))
if len(col.Comment) > 0 {
sql += " COMMENT '" + col.Comment + "'"
}
return sql
}

func (db *mysql) GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
args := []interface{}{db.uri.DBName, tableName}
s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," +
" `COLUMN_KEY`, `EXTRA`,`COLUMN_COMMENT` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?"
db.LogSQL(s, args)

rows, err := db.DB().QueryContext(ctx, s, args...)
if err != nil {
@@ -430,7 +433,6 @@ func (db *mysql) GetTables(ctx context.Context) ([]*schemas.Table, error) {
args := []interface{}{db.uri.DBName}
s := "SELECT `TABLE_NAME`, `ENGINE`, `TABLE_ROWS`, `AUTO_INCREMENT`, `TABLE_COMMENT` from " +
"`INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? AND (`ENGINE`='MyISAM' OR `ENGINE` = 'InnoDB' OR `ENGINE` = 'TokuDB')"
db.LogSQL(s, args)

rows, err := db.DB().QueryContext(ctx, s, args...)
if err != nil {
@@ -459,7 +461,6 @@ func (db *mysql) GetTables(ctx context.Context) ([]*schemas.Table, error) {
func (db *mysql) GetIndexes(ctx context.Context, tableName string) (map[string]*schemas.Index, error) {
args := []interface{}{db.uri.DBName, tableName}
s := "SELECT `INDEX_NAME`, `NON_UNIQUE`, `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?"
db.LogSQL(s, args)

rows, err := db.DB().QueryContext(ctx, s, args...)
if err != nil {


+ 0
- 4
dialects/oracle.go View File

@@ -635,7 +635,6 @@ func (db *oracle) IsColumnExist(ctx context.Context, tableName, colName string)
args := []interface{}{tableName, colName}
query := "SELECT column_name FROM USER_TAB_COLUMNS WHERE table_name = :1" +
" AND column_name = :2"
db.LogSQL(query, args)

rows, err := db.DB().QueryContext(ctx, query, args...)
if err != nil {
@@ -653,7 +652,6 @@ func (db *oracle) GetColumns(ctx context.Context, tableName string) ([]string, m
args := []interface{}{tableName}
s := "SELECT column_name,data_default,data_type,data_length,data_precision,data_scale," +
"nullable FROM USER_TAB_COLUMNS WHERE table_name = :1"
db.LogSQL(s, args)

rows, err := db.DB().QueryContext(ctx, s, args...)
if err != nil {
@@ -750,7 +748,6 @@ func (db *oracle) GetColumns(ctx context.Context, tableName string) ([]string, m
func (db *oracle) GetTables(ctx context.Context) ([]*schemas.Table, error) {
args := []interface{}{}
s := "SELECT table_name FROM user_tables"
db.LogSQL(s, args)

rows, err := db.DB().QueryContext(ctx, s, args...)
if err != nil {
@@ -775,7 +772,6 @@ func (db *oracle) GetIndexes(ctx context.Context, tableName string) (map[string]
args := []interface{}{tableName}
s := "SELECT t.column_name,i.uniqueness,i.index_name FROM user_ind_columns t,user_indexes i " +
"WHERE t.index_name = i.index_name and t.table_name = i.table_name and t.table_name =:1"
db.LogSQL(s, args)

rows, err := db.DB().QueryContext(ctx, s, args...)
if err != nil {


+ 0
- 6
dialects/postgres.go View File

@@ -943,7 +943,6 @@ func (db *postgres) IsColumnExist(ctx context.Context, tableName, colName string
query = "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" +
" AND column_name = $2"
}
db.LogSQL(query, args)

rows, err := db.DB().QueryContext(ctx, query, args...)
if err != nil {
@@ -975,8 +974,6 @@ WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.att
}
s = fmt.Sprintf(s, f)

db.LogSQL(s, args)

rows, err := db.DB().QueryContext(ctx, s, args...)
if err != nil {
return nil, nil, err
@@ -1077,8 +1074,6 @@ func (db *postgres) GetTables(ctx context.Context) ([]*schemas.Table, error) {
s = s + " WHERE schemaname = $1"
}

db.LogSQL(s, args)

rows, err := db.DB().QueryContext(ctx, s, args...)
if err != nil {
return nil, err
@@ -1117,7 +1112,6 @@ func (db *postgres) GetIndexes(ctx context.Context, tableName string) (map[strin
args = append(args, db.uri.Schema)
s = s + " AND schemaname=$2"
}
db.LogSQL(s, args)

rows, err := db.DB().QueryContext(ctx, s, args...)
if err != nil {


+ 2
- 10
dialects/sqlite3.go View File

@@ -249,16 +249,10 @@ func (db *sqlite3) ForUpdateSQL(query string) string {
return query
}

/*func (db *sqlite3) ColumnCheckSql(tableName, colName string) (string, []interface{}) {
args := []interface{}{tableName}
sql := "SELECT name FROM sqlite_master WHERE type='table' and name = ? and ((sql like '%`" + colName + "`%') or (sql like '%[" + colName + "]%'))"
return sql, args
}*/

func (db *sqlite3) IsColumnExist(ctx context.Context, tableName, colName string) (bool, error) {
args := []interface{}{tableName}
query := "SELECT name FROM sqlite_master WHERE type='table' and name = ? and ((sql like '%`" + colName + "`%') or (sql like '%[" + colName + "]%'))"
db.LogSQL(query, args)

rows, err := db.DB().QueryContext(ctx, query, args...)
if err != nil {
return false, err
@@ -336,7 +330,7 @@ func parseString(colStr string) (*schemas.Column, error) {
func (db *sqlite3) GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
args := []interface{}{tableName}
s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?"
db.LogSQL(s, args)
rows, err := db.DB().QueryContext(ctx, s, args...)
if err != nil {
return nil, nil, err
@@ -393,7 +387,6 @@ func (db *sqlite3) GetColumns(ctx context.Context, tableName string) ([]string,
func (db *sqlite3) GetTables(ctx context.Context) ([]*schemas.Table, error) {
args := []interface{}{}
s := "SELECT name FROM sqlite_master WHERE type='table'"
db.LogSQL(s, args)

rows, err := db.DB().QueryContext(ctx, s, args...)
if err != nil {
@@ -419,7 +412,6 @@ func (db *sqlite3) GetTables(ctx context.Context) ([]*schemas.Table, error) {
func (db *sqlite3) GetIndexes(ctx context.Context, tableName string) (map[string]*schemas.Index, error) {
args := []interface{}{tableName}
s := "SELECT sql FROM sqlite_master WHERE type='index' and tbl_name = ?"
db.LogSQL(s, args)

rows, err := db.DB().QueryContext(ctx, s, args...)
if err != nil {


+ 44
- 70
engine.go View File

@@ -31,22 +31,16 @@ import (
// Engine is the major struct of xorm, it means a database manager.
// Commonly, an application only need one engine
type Engine struct {
db *core.DB
dialect dialects.Dialect

showSQL bool
showExecTime bool
cacherMgr *caches.Manager
db *core.DB
defaultContext context.Context
dialect dialects.Dialect
engineGroup *EngineGroup
logger log.ContextLogger
tagParser *tags.Parser

logger log.Logger
TZLocation *time.Location // The timezone of the application
DatabaseTZ *time.Location // The timezone of the database

engineGroup *EngineGroup

defaultContext context.Context

tagParser *tags.Parser
cacherMgr *caches.Manager
}

func (engine *Engine) SetCacher(tableName string, cacher caches.Cacher) {
@@ -67,32 +61,33 @@ func (engine *Engine) BufferSize(size int) *Session {
// ShowSQL show SQL statement or not on logger if log level is great than INFO
func (engine *Engine) ShowSQL(show ...bool) {
engine.logger.ShowSQL(show...)
if len(show) == 0 {
engine.showSQL = true
if engine.logger.IsShowSQL() {
engine.db.Logger = engine.logger
} else {
engine.showSQL = show[0]
}
}

// ShowExecTime show SQL statement and execute time or not on logger if log level is great than INFO
func (engine *Engine) ShowExecTime(show ...bool) {
if len(show) == 0 {
engine.showExecTime = true
} else {
engine.showExecTime = show[0]
engine.db.Logger = &log.DiscardSQLLogger{}
}
}

// Logger return the logger interface
func (engine *Engine) Logger() log.Logger {
func (engine *Engine) Logger() log.ContextLogger {
return engine.logger
}

// SetLogger set the new logger
func (engine *Engine) SetLogger(logger log.Logger) {
engine.logger = logger
engine.showSQL = logger.IsShowSQL()
engine.dialect.SetLogger(logger)
func (engine *Engine) SetLogger(logger interface{}) {
var realLogger log.ContextLogger
switch t := logger.(type) {
case log.Logger:
realLogger = log.NewLoggerAdapter(t)
case log.ContextLogger:
realLogger = t
}
engine.logger = realLogger
if realLogger.IsShowSQL() {
engine.db.Logger = realLogger
} else {
engine.db.Logger = &log.DiscardSQLLogger{}
}
}

// SetLogLevel sets the logger level
@@ -123,12 +118,12 @@ func (engine *Engine) SetMapper(mapper names.Mapper) {

// SetTableMapper set the table name mapping rule
func (engine *Engine) SetTableMapper(mapper names.Mapper) {
engine.tagParser.TableMapper = mapper
engine.tagParser.SetTableMapper(mapper)
}

// SetColumnMapper set the column name mapping rule
func (engine *Engine) SetColumnMapper(mapper names.Mapper) {
engine.tagParser.ColumnMapper = mapper
engine.tagParser.SetColumnMapper(mapper)
}

// SupportInsertMany If engine's database support batch insert records like
@@ -255,17 +250,6 @@ func (engine *Engine) Ping() error {
return session.Ping()
}

// logSQL save sql
func (engine *Engine) logSQL(sqlStr string, sqlArgs ...interface{}) {
if engine.showSQL && !engine.showExecTime {
if len(sqlArgs) > 0 {
engine.logger.Infof("[SQL] %v %#v", sqlStr, sqlArgs)
} else {
engine.logger.Infof("[SQL] %v", sqlStr)
}
}
}

// SQL method let's you manually write raw SQL and operate
// For example:
//
@@ -336,7 +320,7 @@ func (engine *Engine) DBMetas() ([]*schemas.Table, error) {
}

// DumpAllToFile dump database all table structs and data to a file
func (engine *Engine) DumpAllToFile(fp string, tp ...dialects.DBType) error {
func (engine *Engine) DumpAllToFile(fp string, tp ...schemas.DBType) error {
f, err := os.Create(fp)
if err != nil {
return err
@@ -346,7 +330,7 @@ func (engine *Engine) DumpAllToFile(fp string, tp ...dialects.DBType) error {
}

// DumpAll dump database all table structs and data to w
func (engine *Engine) DumpAll(w io.Writer, tp ...dialects.DBType) error {
func (engine *Engine) DumpAll(w io.Writer, tp ...schemas.DBType) error {
tables, err := engine.DBMetas()
if err != nil {
return err
@@ -355,7 +339,7 @@ func (engine *Engine) DumpAll(w io.Writer, tp ...dialects.DBType) error {
}

// DumpTablesToFile dump specified tables to SQL file.
func (engine *Engine) DumpTablesToFile(tables []*schemas.Table, fp string, tp ...dialects.DBType) error {
func (engine *Engine) DumpTablesToFile(tables []*schemas.Table, fp string, tp ...schemas.DBType) error {
f, err := os.Create(fp)
if err != nil {
return err
@@ -365,12 +349,12 @@ func (engine *Engine) DumpTablesToFile(tables []*schemas.Table, fp string, tp ..
}

// DumpTables dump specify tables to io.Writer
func (engine *Engine) DumpTables(tables []*schemas.Table, w io.Writer, tp ...dialects.DBType) error {
func (engine *Engine) DumpTables(tables []*schemas.Table, w io.Writer, tp ...schemas.DBType) error {
return engine.dumpTables(tables, w, tp...)
}

// dumpTables dump database all table structs and data to w with specify db type
func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...dialects.DBType) error {
func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...schemas.DBType) error {
var dialect dialects.Dialect
var distDBName string
if len(tp) == 0 {
@@ -496,7 +480,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...dia
}

// FIXME: Hack for postgres
if string(dialect.DBType()) == schemas.POSTGRES && table.AutoIncrColumn() != nil {
if dialect.DBType() == schemas.POSTGRES && table.AutoIncrColumn() != nil {
_, err = io.WriteString(w, "SELECT setval('"+table.Name+"_id_seq', COALESCE((SELECT MAX("+table.AutoIncrColumn().Name+") + 1 FROM "+dialect.Quoter().Quote(table.Name)+"), 1), false);\n")
if err != nil {
return err
@@ -739,13 +723,9 @@ func (t *Table) IsValid() bool {
}

// TableInfo get table info according to bean's content
func (engine *Engine) TableInfo(bean interface{}) (*Table, error) {
func (engine *Engine) TableInfo(bean interface{}) (*schemas.Table, error) {
v := utils.ReflectValue(bean)
tb, err := engine.tagParser.MapType(v)
if err != nil {
return nil, err
}
return &Table{tb, dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean)}, nil
return engine.tagParser.ParseWithCache(v)
}

// IsTableEmpty if a table has any reocrd
@@ -763,7 +743,7 @@ func (engine *Engine) IsTableExist(beanOrTableName interface{}) (bool, error) {
}

// IDOf get id from one struct
func (engine *Engine) IDOf(bean interface{}) schemas.PK {
func (engine *Engine) IDOf(bean interface{}) (schemas.PK, error) {
return engine.IDOfV(reflect.ValueOf(bean))
}

@@ -773,18 +753,13 @@ func (engine *Engine) TableName(bean interface{}, includeSchema ...bool) string
}

// IDOfV get id from one value of struct
func (engine *Engine) IDOfV(rv reflect.Value) schemas.PK {
pk, err := engine.idOfV(rv)
if err != nil {
engine.logger.Error(err)
return nil
}
return pk
func (engine *Engine) IDOfV(rv reflect.Value) (schemas.PK, error) {
return engine.idOfV(rv)
}

func (engine *Engine) idOfV(rv reflect.Value) (schemas.PK, error) {
v := reflect.Indirect(rv)
table, err := engine.tagParser.MapType(v)
table, err := engine.tagParser.ParseWithCache(v)
if err != nil {
return nil, err
}
@@ -882,7 +857,7 @@ func (engine *Engine) ClearCache(beans ...interface{}) error {

// UnMapType remove table from tables cache
func (engine *Engine) UnMapType(t reflect.Type) {
engine.tagParser.ClearTable(t)
engine.tagParser.ClearCacheTable(t)
}

// Sync the new struct changes to database, this method will automatically add
@@ -895,7 +870,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
for _, bean := range beans {
v := utils.ReflectValue(bean)
tableNameNoSchema := dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean)
table, err := engine.tagParser.MapType(v)
table, err := engine.tagParser.ParseWithCache(v)
if err != nil {
return err
}
@@ -1216,8 +1191,7 @@ func (engine *Engine) Import(r io.Reader) ([]sql.Result, error) {
for scanner.Scan() {
query := strings.Trim(scanner.Text(), " \t\n\r")
if len(query) > 0 {
engine.logSQL(query)
result, err := engine.DB().Exec(query)
result, err := engine.DB().ExecContext(engine.defaultContext, query)
results = append(results, result)
if err != nil {
return nil, err
@@ -1244,12 +1218,12 @@ func (engine *Engine) formatColTime(col *schemas.Column, t time.Time) (v interfa

// GetColumnMapper returns the column name mapper
func (engine *Engine) GetColumnMapper() names.Mapper {
return engine.tagParser.ColumnMapper
return engine.tagParser.GetColumnMapper()
}

// GetTableMapper returns the table name mapper
func (engine *Engine) GetTableMapper() names.Mapper {
return engine.tagParser.TableMapper
return engine.tagParser.GetTableMapper()
}

// GetTZLocation returns time zone of the application


+ 1
- 9
engine_group.go View File

@@ -135,7 +135,7 @@ func (eg *EngineGroup) SetDefaultCacher(cacher caches.Cacher) {
}

// SetLogger set the new logger
func (eg *EngineGroup) SetLogger(logger log.Logger) {
func (eg *EngineGroup) SetLogger(logger interface{}) {
eg.Engine.SetLogger(logger)
for i := 0; i < len(eg.slaves); i++ {
eg.slaves[i].SetLogger(logger)
@@ -188,14 +188,6 @@ func (eg *EngineGroup) SetTableMapper(mapper names.Mapper) {
}
}

// ShowExecTime show SQL statement and execute time or not on logger if log level is great than INFO
func (eg *EngineGroup) ShowExecTime(show ...bool) {
eg.Engine.ShowExecTime(show...)
for i := 0; i < len(eg.slaves); i++ {
eg.slaves[i].ShowExecTime(show...)
}
}

// ShowSQL show SQL statement or not on logger if log level is great than INFO
func (eg *EngineGroup) ShowSQL(show ...bool) {
eg.Engine.ShowSQL(show...)


+ 3
- 4
interface.go View File

@@ -83,7 +83,7 @@ type EngineInterface interface {
DBMetas() ([]*schemas.Table, error)
Dialect() dialects.Dialect
DropTables(...interface{}) error
DumpAllToFile(fp string, tp ...dialects.DBType) error
DumpAllToFile(fp string, tp ...schemas.DBType) error
GetCacher(string) caches.Cacher
GetColumnMapper() names.Mapper
GetDefaultCacher() caches.Cacher
@@ -98,7 +98,7 @@ type EngineInterface interface {
SetConnMaxLifetime(time.Duration)
SetColumnMapper(names.Mapper)
SetDefaultCacher(caches.Cacher)
SetLogger(logger log.Logger)
SetLogger(logger interface{})
SetLogLevel(log.LogLevel)
SetMapper(names.Mapper)
SetMaxOpenConns(int)
@@ -107,12 +107,11 @@ type EngineInterface interface {
SetTableMapper(names.Mapper)
SetTZDatabase(tz *time.Location)
SetTZLocation(tz *time.Location)
ShowExecTime(...bool)
ShowSQL(show ...bool)
Sync(...interface{}) error
Sync2(...interface{}) error
StoreEngine(storeEngine string) *Session
TableInfo(bean interface{}) (*Table, error)
TableInfo(bean interface{}) (*schemas.Table, error)
TableName(interface{}, ...bool) string
UnMapType(reflect.Type)
}


+ 9
- 9
internal/statements/statement.go View File

@@ -253,11 +253,11 @@ func (statement *Statement) NotIn(column string, args ...interface{}) *Statement

func (statement *Statement) SetRefValue(v reflect.Value) error {
var err error
statement.RefTable, err = statement.tagParser.MapType(reflect.Indirect(v))
statement.RefTable, err = statement.tagParser.ParseWithCache(reflect.Indirect(v))
if err != nil {
return err
}
statement.tableName = dialects.FullTableName(statement.dialect, statement.tagParser.TableMapper, v, true)
statement.tableName = dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), v, true)
return nil
}

@@ -267,11 +267,11 @@ func rValue(bean interface{}) reflect.Value {

func (statement *Statement) SetRefBean(bean interface{}) error {
var err error
statement.RefTable, err = statement.tagParser.MapType(rValue(bean))
statement.RefTable, err = statement.tagParser.ParseWithCache(rValue(bean))
if err != nil {
return err
}
statement.tableName = dialects.FullTableName(statement.dialect, statement.tagParser.TableMapper, bean, true)
statement.tableName = dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), bean, true)
return nil
}

@@ -507,13 +507,13 @@ func (statement *Statement) SetTable(tableNameOrBean interface{}) error {
t := v.Type()
if t.Kind() == reflect.Struct {
var err error
statement.RefTable, err = statement.tagParser.MapType(v)
statement.RefTable, err = statement.tagParser.ParseWithCache(v)
if err != nil {
return err
}
}

statement.AltTableName = dialects.FullTableName(statement.dialect, statement.tagParser.TableMapper, tableNameOrBean, true)
statement.AltTableName = dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), tableNameOrBean, true)
return nil
}

@@ -554,7 +554,7 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
default:
tbName := dialects.FullTableName(statement.dialect, statement.tagParser.TableMapper, tablename, true)
tbName := dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), tablename, true)
if !utils.IsSubQuery(tbName) {
var buf strings.Builder
statement.dialect.Quoter().QuoteTo(&buf, tbName)
@@ -689,7 +689,7 @@ func (statement *Statement) GenDelIndexSQL() []string {
} else if index.Type == schemas.IndexType {
rIdxName = utils.IndexName(idxPrefixName, idxName)
}
sql := fmt.Sprintf("DROP INDEX %v", statement.quote(dialects.FullTableName(statement.dialect, statement.tagParser.TableMapper, rIdxName, true)))
sql := fmt.Sprintf("DROP INDEX %v", statement.quote(dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), rIdxName, true)))
if statement.dialect.IndexOnTable() {
sql += fmt.Sprintf(" ON %v", statement.quote(tbName))
}
@@ -844,7 +844,7 @@ func (statement *Statement) buildConds2(table *schemas.Table, bean interface{},
val = bytes
}
} else {
table, err := statement.tagParser.MapType(fieldValue)
table, err := statement.tagParser.ParseWithCache(fieldValue)
if err != nil {
val = fieldValue.Interface()
} else {


+ 1
- 1
internal/statements/update.go View File

@@ -187,7 +187,7 @@ func (statement *Statement) BuildUpdates(bean interface{},
val, _ = nulType.Value()
} else {
if !col.SQLType.IsJson() {
table, err := statement.tagParser.MapType(fieldValue)
table, err := statement.tagParser.ParseWithCache(fieldValue)
if err != nil {
val = fieldValue.Interface()
} else {


+ 108
- 0
log/logger_context.go View File

@@ -0,0 +1,108 @@
// Copyright 2020 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package log

import (
"context"
"time"
)

// LogContext represents a log context
type LogContext struct {
Ctx context.Context
SQL string // log content or SQL
Args []interface{} // if it's a SQL, it's the arguments
ExecuteTime time.Duration
Err error // SQL executed error
}

type SQLLogger interface {
BeforeSQL(context LogContext)
AfterSQL(context LogContext)
}

type DiscardSQLLogger struct{}

var _ SQLLogger = &DiscardSQLLogger{}

func (DiscardSQLLogger) BeforeSQL(LogContext) {}
func (DiscardSQLLogger) AfterSQL(LogContext) {}

// ContextLogger represents a logger interface with context
type ContextLogger interface {
SQLLogger

Debugf(format string, v ...interface{})
Errorf(format string, v ...interface{})
Infof(format string, v ...interface{})
Warnf(format string, v ...interface{})

Level() LogLevel
SetLevel(l LogLevel)

ShowSQL(show ...bool)
IsShowSQL() bool
}

var (
_ ContextLogger = &LoggerAdapter{}
)

// LoggerAdapter wraps a Logger interafce as LoggerContext interface
type LoggerAdapter struct {
logger Logger
}

func NewLoggerAdapter(logger Logger) ContextLogger {
return &LoggerAdapter{
logger: logger,
}
}

func (l *LoggerAdapter) BeforeSQL(ctx LogContext) {}

func (l *LoggerAdapter) AfterSQL(ctx LogContext) {
if !l.logger.IsShowSQL() {
return
}

if ctx.ExecuteTime > 0 {
l.logger.Infof("[SQL] %v %v - %v", ctx.SQL, ctx.Args, ctx.ExecuteTime)
} else {
l.logger.Infof("[SQL] %v %v", ctx.SQL, ctx.Args)
}
}

func (l *LoggerAdapter) Debugf(format string, v ...interface{}) {
l.logger.Debugf(format, v...)
}

func (l *LoggerAdapter) Errorf(format string, v ...interface{}) {
l.logger.Errorf(format, v...)
}

func (l *LoggerAdapter) Infof(format string, v ...interface{}) {
l.logger.Infof(format, v...)
}

func (l *LoggerAdapter) Warnf(format string, v ...interface{}) {
l.logger.Warnf(format, v...)
}

func (l *LoggerAdapter) Level() LogLevel {
return l.logger.Level()
}

func (l *LoggerAdapter) SetLevel(lv LogLevel) {
l.logger.SetLevel(lv)
}

func (l *LoggerAdapter) ShowSQL(show ...bool) {
l.logger.ShowSQL(show...)
}

func (l *LoggerAdapter) IsShowSQL() bool {
return l.logger.IsShowSQL()
}

+ 3
- 5
schemas/table.go View File

@@ -7,7 +7,6 @@ package schemas
import (
"reflect"
"strings"
//"xorm.io/xorm/cache"
)

// Table represents a database table
@@ -24,10 +23,9 @@ type Table struct {
Updated string
Deleted string
Version string
//Cacher caches.Cacher
StoreEngine string
Charset string
Comment string
StoreEngine string
Charset string
Comment string
}

func NewEmptyTable() *Table {


+ 7
- 5
schemas/type.go View File

@@ -11,12 +11,14 @@ import (
"time"
)

type DBType string

const (
POSTGRES = "postgres"
SQLITE = "sqlite3"
MYSQL = "mysql"
MSSQL = "mssql"
ORACLE = "oracle"
POSTGRES DBType = "postgres"
SQLITE DBType = "sqlite3"
MYSQL DBType = "mysql"
MSSQL DBType = "mssql"
ORACLE DBType = "oracle"
)

// SQLType represents SQL types


+ 7
- 18
session.go View File

@@ -82,7 +82,7 @@ func (session *Session) Init() {
session.engine.DatabaseTZ,
)

session.showSQL = session.engine.showSQL
//session.showSQL = session.engine.showSQL
session.isAutoCommit = true
session.isCommitedOrRollbacked = false
session.isAutoClose = false
@@ -165,7 +165,7 @@ func (session *Session) After(closures func(interface{})) *Session {
// Table can input a string or pointer to struct for special a table to operate.
func (session *Session) Table(tableNameOrBean interface{}) *Session {
if err := session.statement.SetTable(tableNameOrBean); err != nil {
session.engine.logger.Error(err)
session.statement.LastError = err
}
return session
}
@@ -447,7 +447,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
fieldValue, err := session.getField(dataStruct, key, table, idx)
if err != nil {
if !strings.Contains(err.Error(), "is not valid") {
session.engine.logger.Warn(err)
session.engine.logger.Warnf("%v", err)
}
continue
}
@@ -650,7 +650,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
hasAssigned = true
t, err := session.byte2Time(col, d)
if err != nil {
session.engine.logger.Error("byte2Time error:", err.Error())
session.engine.logger.Errorf("byte2Time error: %v", err)
hasAssigned = false
} else {
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
@@ -659,7 +659,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
hasAssigned = true
t, err := session.str2Time(col, d)
if err != nil {
session.engine.logger.Error("byte2Time error:", err.Error())
session.engine.logger.Errorf("byte2Time error: %v", err)
hasAssigned = false
} else {
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
@@ -672,7 +672,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
// !<winxxp>! 增加支持sql.Scanner接口的结构,如sql.NullString
hasAssigned = true
if err := nulVal.Scan(vv.Interface()); err != nil {
session.engine.logger.Error("sql.Sanner error:", err.Error())
session.engine.logger.Errorf("sql.Sanner error: %v", err)
hasAssigned = false
}
} else if col.SQLType.IsJson() {
@@ -698,7 +698,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
}
}
} else if session.statement.UseCascade {
table, err := session.engine.tagParser.MapType(*fieldValue)
table, err := session.engine.tagParser.ParseWithCache(*fieldValue)
if err != nil {
return nil, err
}
@@ -865,17 +865,6 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
func (session *Session) saveLastSQL(sql string, args ...interface{}) {
session.lastSQL = sql
session.lastSQLArgs = args
session.logSQL(sql, args...)
}

func (session *Session) logSQL(sqlStr string, sqlArgs ...interface{}) {
if session.showSQL && !session.engine.showExecTime {
if len(sqlArgs) > 0 {
session.engine.logger.Infof("[SQL] %v %#v", sqlStr, sqlArgs)
} else {
session.engine.logger.Infof("[SQL] %v", sqlStr)
}
}
}

// LastSQL returns last query information


+ 4
- 14
session_convert.go View File

@@ -111,7 +111,6 @@ func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Val
if len(data) > 0 {
err := json.DefaultJSONHandler.Unmarshal(data, x.Interface())
if err != nil {
session.engine.logger.Error(err)
return err
}
fieldValue.Set(x.Elem())
@@ -125,7 +124,6 @@ func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Val
if len(data) > 0 {
err := json.DefaultJSONHandler.Unmarshal(data, x.Interface())
if err != nil {
session.engine.logger.Error(err)
return err
}
fieldValue.Set(x.Elem())
@@ -138,7 +136,6 @@ func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Val
if len(data) > 0 {
err := json.DefaultJSONHandler.Unmarshal(data, x.Interface())
if err != nil {
session.engine.logger.Error(err)
return err
}
fieldValue.Set(x.Elem())
@@ -210,7 +207,7 @@ func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Val
v = x
fieldValue.Set(reflect.ValueOf(v).Convert(fieldType))
} else if session.statement.UseCascade {
table, err := session.engine.tagParser.MapType(*fieldValue)
table, err := session.engine.tagParser.ParseWithCache(*fieldValue)
if err != nil {
return err
}
@@ -267,7 +264,6 @@ func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Val
if len(data) > 0 {
err := json.DefaultJSONHandler.Unmarshal(data, &x)
if err != nil {
session.engine.logger.Error(err)
return err
}
fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType))
@@ -278,7 +274,6 @@ func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Val
if len(data) > 0 {
err := json.DefaultJSONHandler.Unmarshal(data, &x)
if err != nil {
session.engine.logger.Error(err)
return err
}
fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType))
@@ -493,7 +488,7 @@ func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Val
default:
if session.statement.UseCascade {
structInter := reflect.New(fieldType.Elem())
table, err := session.engine.tagParser.MapType(structInter.Elem())
table, err := session.engine.tagParser.ParseWithCache(structInter.Elem())
if err != nil {
return err
}
@@ -570,7 +565,7 @@ func (session *Session) value2Interface(col *schemas.Column, fieldValue reflect.
if fieldValue.IsNil() {
return nil, nil
} else if !fieldValue.IsValid() {
session.engine.logger.Warn("the field[", col.FieldName, "] is invalid")
session.engine.logger.Warnf("the field [%s] is invalid", col.FieldName)
return nil, nil
} else {
// !nashtsai! deference pointer type to instance type
@@ -604,7 +599,7 @@ func (session *Session) value2Interface(col *schemas.Column, fieldValue reflect.
return v.Value()
}

fieldTable, err := session.engine.tagParser.MapType(fieldValue)
fieldTable, err := session.engine.tagParser.ParseWithCache(fieldValue)
if err != nil {
return nil, err
}
@@ -618,14 +613,12 @@ func (session *Session) value2Interface(col *schemas.Column, fieldValue reflect.
if col.SQLType.IsText() {
bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil {
session.engine.logger.Error(err)
return 0, err
}
return string(bytes), nil
} else if col.SQLType.IsBlob() {
bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil {
session.engine.logger.Error(err)
return 0, err
}
return bytes, nil
@@ -634,7 +627,6 @@ func (session *Session) value2Interface(col *schemas.Column, fieldValue reflect.
case reflect.Complex64, reflect.Complex128:
bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil {
session.engine.logger.Error(err)
return 0, err
}
return string(bytes), nil
@@ -646,7 +638,6 @@ func (session *Session) value2Interface(col *schemas.Column, fieldValue reflect.
if col.SQLType.IsText() {
bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil {
session.engine.logger.Error(err)
return 0, err
}
return string(bytes), nil
@@ -659,7 +650,6 @@ func (session *Session) value2Interface(col *schemas.Column, fieldValue reflect.
} else {
bytes, err = json.DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil {
session.engine.logger.Error(err)
return 0, err
}
}


+ 2
- 2
session_delete.go View File

@@ -62,14 +62,14 @@ func (session *Session) cacheDelete(table *schemas.Table, tableName, sqlStr stri
}

for _, id := range ids {
session.engine.logger.Debug("[cacheDelete] delete cache obj:", tableName, id)
session.engine.logger.Debugf("[cache] delete cache obj: %v, %v", tableName, id)
sid, err := id.ToString()
if err != nil {
return err
}
cacher.DelBean(tableName, sid)
}
session.engine.logger.Debug("[cacheDelete] clear cache table:", tableName)
session.engine.logger.Debugf("[cache] clear cache table: %v", tableName)
cacher.ClearIds(tableName)
return nil
}


+ 14
- 10
session_find.go View File

@@ -141,7 +141,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
return err
}
err = nil // !nashtsai! reset err to nil for ErrCacheFailed
session.engine.logger.Warn("Cache Find Failed")
session.engine.logger.Warnf("Cache Find Failed")
}
}

@@ -225,7 +225,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect
if elemType.Kind() == reflect.Struct {
var newValue = newElemFunc(fields)
dataStruct := utils.ReflectValue(newValue.Interface())
tb, err := session.engine.tagParser.MapType(dataStruct)
tb, err := session.engine.tagParser.ParseWithCache(dataStruct)
if err != nil {
return err
}
@@ -307,7 +307,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
for rows.Next() {
i++
if i > 500 {
session.engine.logger.Debug("[cacheFind] ids length > 500, no cache")
session.engine.logger.Debugf("[cacheFind] ids length > 500, no cache")
return ErrCacheFailed
}
var res = make([]string, len(table.PrimaryKeys))
@@ -326,13 +326,13 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
ids = append(ids, pk)
}

session.engine.logger.Debug("[cacheFind] cache sql:", ids, tableName, sqlStr, newsql, args)
session.engine.logger.Debugf("[cache] cache sql: %v, %v, %v, %v, %v", ids, tableName, sqlStr, newsql, args)
err = caches.PutCacheSql(cacher, ids, tableName, newsql, args)
if err != nil {
return err
}
} else {
session.engine.logger.Debug("[cacheFind] cache hit sql:", tableName, sqlStr, newsql, args)
session.engine.logger.Debugf("[cache] cache hit sql: %v, %v, %v, %v", tableName, sqlStr, newsql, args)
}

sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
@@ -365,16 +365,20 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
ides = append(ides, id)
ididxes[sid] = idx
} else {
session.engine.logger.Debug("[cacheFind] cache hit bean:", tableName, id, bean)
session.engine.logger.Debugf("[cache] cache hit bean: %v, %v, %v", tableName, id, bean)

pk, err := session.engine.IDOf(bean)
if err != nil {
return err
}

pk := session.engine.IDOf(bean)
xid, err := pk.ToString()
if err != nil {
return err
}

if sid != xid {
session.engine.logger.Error("[cacheFind] error cache", xid, sid, bean)
session.engine.logger.Errorf("[cache] error cache: %v, %v, %v", xid, sid, bean)
return ErrCacheFailed
}
temps[idx] = bean
@@ -424,7 +428,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in

bean := rv.Interface()
temps[ididxes[sid]] = bean
session.engine.logger.Debug("[cacheFind] cache bean:", tableName, id, bean, temps)
session.engine.logger.Debugf("[cache] cache bean: %v, %v, %v, %v", tableName, id, bean, temps)
cacher.PutBean(tableName, sid, bean)
}
}
@@ -432,7 +436,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
for j := 0; j < len(temps); j++ {
bean := temps[j]
if bean == nil {
session.engine.logger.Warn("[cacheFind] cache no hit:", tableName, ids[j], temps)
session.engine.logger.Warnf("[cache] cache no hit: %v, %v, %v", tableName, ids[j], temps)
// return errors.New("cache error") // !nashtsai! no need to return error, but continue instead
continue
}


+ 7
- 7
session_get.go View File

@@ -79,7 +79,7 @@ func (session *Session) get(bean interface{}) (bool, error) {
if context != nil {
res := context.Get(fmt.Sprintf("%v-%v", sqlStr, args))
if res != nil {
session.engine.logger.Debug("hit context cache", sqlStr)
session.engine.logger.Debugf("hit context cache: %s", sqlStr)

structValue := reflect.Indirect(reflect.ValueOf(bean))
structValue.Set(reflect.Indirect(reflect.ValueOf(res)))
@@ -283,7 +283,7 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
tableName := session.statement.TableName()
cacher := session.engine.cacherMgr.GetCacher(tableName)

session.engine.logger.Debug("[cache] Get SQL:", newsql, args)
session.engine.logger.Debugf("[cache] Get SQL: %s, %v", newsql, args)
table := session.statement.RefTable
ids, err := caches.GetCacheSql(cacher, tableName, newsql, args)
if err != nil {
@@ -319,19 +319,19 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
}

ids = []schemas.PK{pk}
session.engine.logger.Debug("[cache] cache ids:", newsql, ids)
session.engine.logger.Debugf("[cache] cache ids: %s, %v", newsql, ids)
err = caches.PutCacheSql(cacher, ids, tableName, newsql, args)
if err != nil {
return false, err
}
} else {
session.engine.logger.Debug("[cache] cache hit:", newsql, ids)
session.engine.logger.Debugf("[cache] cache hit: %s, %v", newsql, ids)
}

if len(ids) > 0 {
structValue := reflect.Indirect(reflect.ValueOf(bean))
id := ids[0]
session.engine.logger.Debug("[cache] get bean:", tableName, id)
session.engine.logger.Debugf("[cache] get bean: %s, %v", tableName, id)
sid, err := id.ToString()
if err != nil {
return false, err
@@ -344,10 +344,10 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
return has, err
}

session.engine.logger.Debug("[cache] cache bean:", tableName, id, cacheBean)
session.engine.logger.Debugf("[cache] cache bean: %s, %v, %v", tableName, id, cacheBean)
cacher.PutBean(tableName, sid, cacheBean)
} else {
session.engine.logger.Debug("[cache] cache hit:", tableName, id, cacheBean)
session.engine.logger.Debugf("[cache] cache hit: %s, %v, %v", tableName, id, cacheBean)
has = true
}
structValue.Set(reflect.Indirect(reflect.ValueOf(cacheBean)))


+ 7
- 7
session_insert.go View File

@@ -485,7 +485,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
if table.Version != "" && session.statement.CheckVersion {
verValue, err := table.VersionColumn().ValueOf(bean)
if err != nil {
session.engine.logger.Error(err)
session.engine.logger.Errorf("%v", err)
} else if verValue.IsValid() && verValue.CanSet() {
session.incrVersionFieldValue(verValue)
}
@@ -503,7 +503,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {

aiValue, err := table.AutoIncrColumn().ValueOf(bean)
if err != nil {
session.engine.logger.Error(err)
session.engine.logger.Errorf("%v", err)
}

if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
@@ -526,7 +526,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
if table.Version != "" && session.statement.CheckVersion {
verValue, err := table.VersionColumn().ValueOf(bean)
if err != nil {
session.engine.logger.Error(err)
session.engine.logger.Errorf("%v", err)
} else if verValue.IsValid() && verValue.CanSet() {
session.incrVersionFieldValue(verValue)
}
@@ -544,7 +544,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {

aiValue, err := table.AutoIncrColumn().ValueOf(bean)
if err != nil {
session.engine.logger.Error(err)
session.engine.logger.Errorf("%v", err)
}

if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
@@ -567,7 +567,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
if table.Version != "" && session.statement.CheckVersion {
verValue, err := table.VersionColumn().ValueOf(bean)
if err != nil {
session.engine.logger.Error(err)
session.engine.logger.Errorf("%v", err)
} else if verValue.IsValid() && verValue.CanSet() {
session.incrVersionFieldValue(verValue)
}
@@ -585,7 +585,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {

aiValue, err := table.AutoIncrColumn().ValueOf(bean)
if err != nil {
session.engine.logger.Error(err)
session.engine.logger.Errorf("%v", err)
}

if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
@@ -617,7 +617,7 @@ func (session *Session) cacheInsert(table string) error {
if cacher == nil {
return nil
}
session.engine.logger.Debug("[cache] clear sql:", table)
session.engine.logger.Debugf("[cache] clear sql: %v", table)
cacher.ClearIds(table)
return nil
}


+ 4
- 41
session_raw.go View File

@@ -7,7 +7,6 @@ package xorm
import (
"database/sql"
"reflect"
"time"

"xorm.io/xorm/core"
"xorm.io/xorm/internal/statements"
@@ -27,27 +26,8 @@ func (session *Session) queryRows(sqlStr string, args ...interface{}) (*core.Row

session.queryPreprocess(&sqlStr, args...)

if session.showSQL {
session.lastSQL = sqlStr
session.lastSQLArgs = args
if session.engine.showExecTime {
b4ExecTime := time.Now()
defer func() {
execDuration := time.Since(b4ExecTime)
if len(args) > 0 {
session.engine.logger.Infof("[SQL] %s %#v - took: %v", sqlStr, args, execDuration)
} else {
session.engine.logger.Infof("[SQL] %s - took: %v", sqlStr, execDuration)
}
}()
} else {
if len(args) > 0 {
session.engine.logger.Infof("[SQL] %v %#v", sqlStr, args)
} else {
session.engine.logger.Infof("[SQL] %v", sqlStr)
}
}
}
session.lastSQL = sqlStr
session.lastSQLArgs = args

if session.isAutoCommit {
var db *core.DB
@@ -156,25 +136,8 @@ func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, er

session.queryPreprocess(&sqlStr, args...)

if session.engine.showSQL {
if session.engine.showExecTime {
b4ExecTime := time.Now()
defer func() {
execDuration := time.Since(b4ExecTime)
if len(args) > 0 {
session.engine.logger.Infof("[SQL] %s %#v - took: %v", sqlStr, args, execDuration)
} else {
session.engine.logger.Infof("[SQL] %s - took: %v", sqlStr, execDuration)
}
}()
} else {
if len(args) > 0 {
session.engine.logger.Infof("[SQL] %v %#v", sqlStr, args)
} else {
session.engine.logger.Infof("[SQL] %v", sqlStr)
}
}
}
session.lastSQL = sqlStr
session.lastSQLArgs = args

if !session.isAutoCommit {
return session.tx.ExecContext(session.ctx, sqlStr, args...)


+ 1
- 1
session_schema.go View File

@@ -242,7 +242,7 @@ func (session *Session) Sync2(beans ...interface{}) error {

for _, bean := range beans {
v := utils.ReflectValue(bean)
table, err := engine.tagParser.MapType(v)
table, err := engine.tagParser.ParseWithCache(v)
if err != nil {
return err
}


+ 67
- 33
session_tx.go View File

@@ -4,6 +4,12 @@

package xorm

import (
"time"

"xorm.io/xorm/log"
)

// Begin a transaction
func (session *Session) Begin() error {
if session.isAutoCommit {
@@ -14,6 +20,7 @@ func (session *Session) Begin() error {
session.isAutoCommit = false
session.isCommitedOrRollbacked = false
session.tx = tx

session.saveLastSQL("BEGIN TRANSACTION")
}
return nil
@@ -22,10 +29,23 @@ func (session *Session) Begin() error {
// Rollback When using transaction, you can rollback if any error
func (session *Session) Rollback() error {
if !session.isAutoCommit && !session.isCommitedOrRollbacked {
session.saveLastSQL(session.engine.dialect.RollBackStr())
session.saveLastSQL("ROLL BACK")
session.isCommitedOrRollbacked = true
session.isAutoCommit = true
return session.tx.Rollback()

start := time.Now()
session.engine.logger.BeforeSQL(log.LogContext{
Ctx: session.ctx,
SQL: "ROLL BACK",
})
err := session.tx.Rollback()
session.engine.logger.AfterSQL(log.LogContext{
Ctx: session.ctx,
SQL: "ROLL BACK",
ExecuteTime: time.Now().Sub(start),
Err: err,
})
return err
}
return nil
}
@@ -36,48 +56,62 @@ func (session *Session) Commit() error {
session.saveLastSQL("COMMIT")
session.isCommitedOrRollbacked = true
session.isAutoCommit = true
var err error
if err = session.tx.Commit(); err == nil {
// handle processors after tx committed
closureCallFunc := func(closuresPtr *[]func(interface{}), bean interface{}) {
if closuresPtr != nil {
for _, closure := range *closuresPtr {
closure(bean)
}

start := time.Now()
session.engine.logger.BeforeSQL(log.LogContext{
Ctx: session.ctx,
SQL: "COMMIT",
})
err := session.tx.Commit()
session.engine.logger.AfterSQL(log.LogContext{
Ctx: session.ctx,
SQL: "COMMIT",
ExecuteTime: time.Now().Sub(start),
Err: err,
})

if err != nil {
return err
}

// handle processors after tx committed
closureCallFunc := func(closuresPtr *[]func(interface{}), bean interface{}) {
if closuresPtr != nil {
for _, closure := range *closuresPtr {
closure(bean)
}
}
}

for bean, closuresPtr := range session.afterInsertBeans {
closureCallFunc(closuresPtr, bean)
for bean, closuresPtr := range session.afterInsertBeans {
closureCallFunc(closuresPtr, bean)

if processor, ok := interface{}(bean).(AfterInsertProcessor); ok {
processor.AfterInsert()
}
if processor, ok := interface{}(bean).(AfterInsertProcessor); ok {
processor.AfterInsert()
}
for bean, closuresPtr := range session.afterUpdateBeans {
closureCallFunc(closuresPtr, bean)
}
for bean, closuresPtr := range session.afterUpdateBeans {
closureCallFunc(closuresPtr, bean)

if processor, ok := interface{}(bean).(AfterUpdateProcessor); ok {
processor.AfterUpdate()
}
if processor, ok := interface{}(bean).(AfterUpdateProcessor); ok {
processor.AfterUpdate()
}
for bean, closuresPtr := range session.afterDeleteBeans {
closureCallFunc(closuresPtr, bean)
}
for bean, closuresPtr := range session.afterDeleteBeans {
closureCallFunc(closuresPtr, bean)

if processor, ok := interface{}(bean).(AfterDeleteProcessor); ok {
processor.AfterDelete()
}
if processor, ok := interface{}(bean).(AfterDeleteProcessor); ok {
processor.AfterDelete()
}
cleanUpFunc := func(slices *map[interface{}]*[]func(interface{})) {
if len(*slices) > 0 {
*slices = make(map[interface{}]*[]func(interface{}), 0)
}
}
cleanUpFunc := func(slices *map[interface{}]*[]func(interface{})) {
if len(*slices) > 0 {
*slices = make(map[interface{}]*[]func(interface{}), 0)
}
cleanUpFunc(&session.afterInsertBeans)
cleanUpFunc(&session.afterUpdateBeans)
cleanUpFunc(&session.afterDeleteBeans)
}
return err
cleanUpFunc(&session.afterInsertBeans)
cleanUpFunc(&session.afterUpdateBeans)
cleanUpFunc(&session.afterDeleteBeans)
}
return nil
}

+ 10
- 10
session_update.go View File

@@ -30,7 +30,7 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri
for _, filter := range session.engine.dialect.Filters() {
newsql = filter.Do(newsql)
}
session.engine.logger.Debug("[cacheUpdate] new sql", oldhead, newsql)
session.engine.logger.Debugf("[cache] new sql: %v, %v", oldhead, newsql)