Move maptype to tag parser #1561

Merged
lunny merged 1 commits from lunny/move_maptype into master 2020-02-28 02:00:27 +00:00
10 changed files with 77 additions and 116 deletions

View File

@ -9,7 +9,6 @@ import (
"bytes" "bytes"
"context" "context"
"database/sql" "database/sql"
"encoding/gob"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -17,7 +16,6 @@ import (
"reflect" "reflect"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
"xorm.io/builder" "xorm.io/builder"
@ -37,10 +35,6 @@ type Engine struct {
db *core.DB db *core.DB
dialect dialects.Dialect dialect dialects.Dialect
Tables map[reflect.Type]*schemas.Table
mutex *sync.RWMutex
showSQL bool showSQL bool
showExecTime bool showExecTime bool
@ -753,43 +747,6 @@ func (engine *Engine) Having(conditions string) *Session {
return session.Having(conditions) return session.Having(conditions)
} }
// UnMapType removes the database mapper of a type
func (engine *Engine) UnMapType(t reflect.Type) {
engine.mutex.Lock()
defer engine.mutex.Unlock()
delete(engine.Tables, t)
}
func (engine *Engine) autoMapType(v reflect.Value) (*schemas.Table, error) {
t := v.Type()
engine.mutex.Lock()
defer engine.mutex.Unlock()
table, ok := engine.Tables[t]
if !ok {
var err error
table, err = engine.tagParser.MapType(v)
if err != nil {
return nil, err
}
engine.Tables[t] = table
if engine.GetDefaultCacher() != nil {
if v.CanAddr() {
engine.GobRegister(v.Addr().Interface())
} else {
engine.GobRegister(v.Interface())
}
}
}
return table, nil
}
// GobRegister register one struct to gob for cache use
func (engine *Engine) GobRegister(v interface{}) *Engine {
gob.Register(v)
return engine
}
// Table table struct // Table table struct
type Table struct { type Table struct {
*schemas.Table *schemas.Table
@ -804,7 +761,7 @@ func (t *Table) IsValid() bool {
// TableInfo get table info according to bean's content // TableInfo get table info according to bean's content
func (engine *Engine) TableInfo(bean interface{}) *Table { func (engine *Engine) TableInfo(bean interface{}) *Table {
v := rValue(bean) v := rValue(bean)
tb, err := engine.autoMapType(v) tb, err := engine.tagParser.MapType(v)
if err != nil { if err != nil {
engine.logger.Error(err) engine.logger.Error(err)
} }
@ -842,7 +799,7 @@ func (engine *Engine) IDOfV(rv reflect.Value) schemas.PK {
func (engine *Engine) idOfV(rv reflect.Value) (schemas.PK, error) { func (engine *Engine) idOfV(rv reflect.Value) (schemas.PK, error) {
v := reflect.Indirect(rv) v := reflect.Indirect(rv)
table, err := engine.autoMapType(v) table, err := engine.tagParser.MapType(v)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -938,6 +895,11 @@ func (engine *Engine) ClearCache(beans ...interface{}) error {
return nil return nil
} }
// UnMapType remove table from tables cache
func (engine *Engine) UnMapType(t reflect.Type) {
engine.tagParser.ClearTable(t)
}
// Sync the new struct changes to database, this method will automatically add // Sync the new struct changes to database, this method will automatically add
// table, column, index, unique. but will not delete or change anything. // table, column, index, unique. but will not delete or change anything.
// If you change some field, you should change the database manually. // If you change some field, you should change the database manually.
@ -948,7 +910,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
for _, bean := range beans { for _, bean := range beans {
v := rValue(bean) v := rValue(bean)
tableNameNoSchema := engine.TableName(bean) tableNameNoSchema := engine.TableName(bean)
table, err := engine.autoMapType(v) table, err := engine.tagParser.MapType(v)
if err != nil { if err != nil {
return err return err
} }

View File

@ -165,8 +165,10 @@ func (engine *Engine) buildConds(table *schemas.Table, bean interface{},
val = bytes val = bytes
} }
} else { } else {
engine.autoMapType(fieldValue) table, err := engine.tagParser.MapType(fieldValue)
if table, ok := engine.Tables[fieldValue.Type()]; ok { if err != nil {
val = fieldValue.Interface()
} else {
if len(table.PrimaryKeys) == 1 { if len(table.PrimaryKeys) == 1 {
pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName)
// fix non-int pk issues // fix non-int pk issues
@ -180,8 +182,6 @@ func (engine *Engine) buildConds(table *schemas.Table, bean interface{},
//TODO: how to handler? //TODO: how to handler?
return nil, fmt.Errorf("not supported %v as %v", fieldValue.Interface(), table.PrimaryKeys) return nil, fmt.Errorf("not supported %v as %v", fieldValue.Interface(), table.PrimaryKeys)
} }
} else {
val = fieldValue.Interface()
} }
} }
} }

View File

@ -690,7 +690,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
} }
} }
} else if session.statement.UseCascade { } else if session.statement.UseCascade {
table, err := session.engine.autoMapType(*fieldValue) table, err := session.engine.tagParser.MapType(*fieldValue)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -209,7 +209,7 @@ func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Val
v = x v = x
fieldValue.Set(reflect.ValueOf(v).Convert(fieldType)) fieldValue.Set(reflect.ValueOf(v).Convert(fieldType))
} else if session.statement.UseCascade { } else if session.statement.UseCascade {
table, err := session.engine.autoMapType(*fieldValue) table, err := session.engine.tagParser.MapType(*fieldValue)
if err != nil { if err != nil {
return err return err
} }
@ -492,7 +492,7 @@ func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Val
default: default:
if session.statement.UseCascade { if session.statement.UseCascade {
structInter := reflect.New(fieldType.Elem()) structInter := reflect.New(fieldType.Elem())
table, err := session.engine.autoMapType(structInter.Elem()) table, err := session.engine.tagParser.MapType(structInter.Elem())
if err != nil { if err != nil {
return err return err
} }
@ -603,7 +603,7 @@ func (session *Session) value2Interface(col *schemas.Column, fieldValue reflect.
return v.Value() return v.Value()
} }
fieldTable, err := session.engine.autoMapType(fieldValue) fieldTable, err := session.engine.tagParser.MapType(fieldValue)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -275,7 +275,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect
if elemType.Kind() == reflect.Struct { if elemType.Kind() == reflect.Struct {
var newValue = newElemFunc(fields) var newValue = newElemFunc(fields)
dataStruct := rValue(newValue.Interface()) dataStruct := rValue(newValue.Interface())
tb, err := session.engine.autoMapType(dataStruct) tb, err := session.engine.tagParser.MapType(dataStruct)
if err != nil { if err != nil {
return err return err
} }

View File

@ -225,7 +225,7 @@ func (statement *Statement) NotIn(column string, args ...interface{}) *Statement
func (statement *Statement) setRefValue(v reflect.Value) error { func (statement *Statement) setRefValue(v reflect.Value) error {
var err error var err error
statement.RefTable, err = statement.Engine.autoMapType(reflect.Indirect(v)) statement.RefTable, err = statement.Engine.tagParser.MapType(reflect.Indirect(v))
if err != nil { if err != nil {
return err return err
} }
@ -235,7 +235,7 @@ func (statement *Statement) setRefValue(v reflect.Value) error {
func (statement *Statement) setRefBean(bean interface{}) error { func (statement *Statement) setRefBean(bean interface{}) error {
var err error var err error
statement.RefTable, err = statement.Engine.autoMapType(rValue(bean)) statement.RefTable, err = statement.Engine.tagParser.MapType(rValue(bean))
if err != nil { if err != nil {
return err return err
} }
@ -414,8 +414,10 @@ func (statement *Statement) buildUpdates(bean interface{},
val, _ = nulType.Value() val, _ = nulType.Value()
} else { } else {
if !col.SQLType.IsJson() { if !col.SQLType.IsJson() {
engine.autoMapType(fieldValue) table, err := engine.tagParser.MapType(fieldValue)
if table, ok := engine.Tables[fieldValue.Type()]; ok { if err != nil {
val = fieldValue.Interface()
} else {
if len(table.PrimaryKeys) == 1 { if len(table.PrimaryKeys) == 1 {
pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName)
// fix non-int pk issues // fix non-int pk issues
@ -428,8 +430,6 @@ func (statement *Statement) buildUpdates(bean interface{},
// TODO: how to handler? // TODO: how to handler?
panic("not supported") panic("not supported")
} }
} else {
val = fieldValue.Interface()
} }
} else { } else {
// Blank struct could not be as update data // Blank struct could not be as update data
@ -723,7 +723,7 @@ func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
t := v.Type() t := v.Type()
if t.Kind() == reflect.Struct { if t.Kind() == reflect.Struct {
var err error var err error
statement.RefTable, err = statement.Engine.autoMapType(v) statement.RefTable, err = statement.Engine.tagParser.MapType(v)
if err != nil { if err != nil {
statement.Engine.logger.Error(err) statement.Engine.logger.Error(err)
return statement return statement

View File

@ -5,10 +5,12 @@
package tags package tags
import ( import (
"encoding/gob"
"errors" "errors"
"fmt" "fmt"
"reflect" "reflect"
"strings" "strings"
"sync"
"time" "time"
"xorm.io/xorm/caches" "xorm.io/xorm/caches"
@ -25,6 +27,7 @@ type Parser struct {
TableMapper names.Mapper TableMapper names.Mapper
handlers map[string]Handler handlers map[string]Handler
cacherMgr *caches.Manager cacherMgr *caches.Manager
tableCache sync.Map // map[reflect.Type]*schemas.Table
} }
func NewParser(identifier string, dialect dialects.Dialect, tableMapper, columnMapper names.Mapper, cacherMgr *caches.Manager) *Parser { func NewParser(identifier string, dialect dialects.Dialect, tableMapper, columnMapper names.Mapper, cacherMgr *caches.Manager) *Parser {
@ -51,6 +54,36 @@ func addIndex(indexName string, table *schemas.Table, col *schemas.Column, index
} }
func (parser *Parser) MapType(v reflect.Value) (*schemas.Table, error) { func (parser *Parser) MapType(v reflect.Value) (*schemas.Table, error) {
t := v.Type()
tableI, ok := parser.tableCache.Load(t)
if ok {
return tableI.(*schemas.Table), nil
}
table, err := parser.mapType(v)
if err != nil {
return nil, err
}
parser.tableCache.Store(t, table)
if parser.cacherMgr.GetDefaultCacher() != nil {
if v.CanAddr() {
gob.Register(v.Addr().Interface())
} else {
gob.Register(v.Interface())
}
}
return table, nil
}
// ClearTable removes the database mapper of a type from the cache
func (parser *Parser) ClearTable(t reflect.Type) {
parser.tableCache.Delete(t)
}
func (parser *Parser) mapType(v reflect.Value) (*schemas.Table, error) {
t := v.Type() t := v.Type()
table := schemas.NewEmptyTable() table := schemas.NewEmptyTable()
table.Type = t table.Type = t

View File

@ -280,7 +280,7 @@ func ExtendsTagHandler(ctx *Context) error {
isPtr = true isPtr = true
fallthrough fallthrough
case reflect.Struct: case reflect.Struct:
parentTable, err := ctx.parser.MapType(fieldValue) parentTable, err := ctx.parser.mapType(fieldValue)
if err != nil { if err != nil {
return err return err
} }

View File

@ -92,12 +92,8 @@ func TestExtends(t *testing.T) {
tu9 := &tempUser4{} tu9 := &tempUser4{}
_, err = testEngine.Get(tu9) _, err = testEngine.Get(tu9)
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, tu8.TempUser2.TempUser.Username, tu9.TempUser2.TempUser.Username)
if tu9.TempUser2.TempUser.Username != tu8.TempUser2.TempUser.Username || tu9.TempUser2.Departname != tu8.TempUser2.Departname { assert.EqualValues(t, tu8.TempUser2.Departname, tu9.TempUser2.Departname)
err = errors.New(fmt.Sprintln("not equal for", tu8, tu9))
t.Error(err)
panic(err)
}
tu10 := &tempUser4{tempUser2{tempUser{0, "extends update"}, ""}} tu10 := &tempUser4{tempUser2{tempUser{0, "extends update"}, ""}}
_, err = testEngine.ID(tu9.TempUser2.TempUser.Id).Update(tu10) _, err = testEngine.ID(tu9.TempUser2.TempUser.Id).Update(tu10)
@ -117,17 +113,10 @@ func TestExtends(t *testing.T) {
_, err = testEngine.Get(tu5) _, err = testEngine.Get(tu5)
assert.NoError(t, err) assert.NoError(t, err)
if tu5.Temp == nil { assert.NotNil(t, tu5.Temp)
err = errors.New("error get data extends") assert.EqualValues(t, 1, tu5.Temp.Id)
t.Error(err) assert.EqualValues(t, "extends", tu5.Temp.Username)
panic(err) assert.EqualValues(t, "dev depart", tu5.Departname)
}
if tu5.Temp.Id != 1 || tu5.Temp.Username != "extends" ||
tu5.Departname != "dev depart" {
err = errors.New("error get data extends")
t.Error(err)
panic(err)
}
tu6 := &tempUser3{&tempUser{0, "extends update"}, ""} tu6 := &tempUser3{&tempUser{0, "extends update"}, ""}
_, err = testEngine.ID(tu5.Temp.Id).Update(tu6) _, err = testEngine.ID(tu5.Temp.Id).Update(tu6)
@ -162,47 +151,25 @@ func TestExtends(t *testing.T) {
qt(ui), qt(ud), qt(ui), qt(udid), qt(ud), qt(uiid)) qt(ui), qt(ud), qt(ui), qt(udid), qt(ud), qt(uiid))
b, err := testEngine.SQL(sql).NoCascade().Get(&info) b, err := testEngine.SQL(sql).NoCascade().Get(&info)
assert.NoError(t, err) assert.NoError(t, err)
if !b { assert.True(t, b, "should has lest one record")
err = errors.New("should has lest one record") assert.True(t, info.Userinfo.Uid > 0, "all of the id should has value")
t.Error(err) assert.True(t, info.Userdetail.Id > 0, "all of the id should has value")
panic(err)
}
fmt.Println(info)
if info.Userinfo.Uid == 0 || info.Userdetail.Id == 0 {
err = errors.New("all of the id should has value")
t.Error(err)
panic(err)
}
fmt.Println("----join--info2")
var info2 UserAndDetail var info2 UserAndDetail
b, err = testEngine.Table(&Userinfo{}). b, err = testEngine.Table(&Userinfo{}).
Join("LEFT", qt(ud), qt(ui)+"."+qt("detail_id")+" = "+qt(ud)+"."+qt(uiid)). Join("LEFT", qt(ud), qt(ui)+"."+qt("detail_id")+" = "+qt(ud)+"."+qt(uiid)).
NoCascade().Get(&info2) NoCascade().Get(&info2)
if err != nil { assert.NoError(t, err)
t.Error(err) assert.True(t, b)
panic(err) assert.True(t, info2.Userinfo.Uid > 0, "all of the id should has value")
} assert.True(t, info2.Userdetail.Id > 0, "all of the id should has value")
if !b {
err = errors.New("should has lest one record")
t.Error(err)
panic(err)
}
if info2.Userinfo.Uid == 0 || info2.Userdetail.Id == 0 {
err = errors.New("all of the id should has value")
t.Error(err)
panic(err)
}
fmt.Println(info2)
fmt.Println("----join--infos2")
var infos2 = make([]UserAndDetail, 0) var infos2 = make([]UserAndDetail, 0)
err = testEngine.Table(&Userinfo{}). err = testEngine.Table(&Userinfo{}).
Join("LEFT", qt(ud), qt(ui)+"."+qt("detail_id")+" = "+qt(ud)+"."+qt(uiid)). Join("LEFT", qt(ud), qt(ui)+"."+qt("detail_id")+" = "+qt(ud)+"."+qt(uiid)).
NoCascade(). NoCascade().
Find(&infos2) Find(&infos2)
assert.NoError(t, err) assert.NoError(t, err)
fmt.Println(infos2)
} }
type MessageBase struct { type MessageBase struct {

13
xorm.go
View File

@ -10,9 +10,7 @@ import (
"context" "context"
"fmt" "fmt"
"os" "os"
"reflect"
"runtime" "runtime"
"sync"
"time" "time"
"xorm.io/xorm/caches" "xorm.io/xorm/caches"
@ -61,14 +59,17 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
return nil, err return nil, err
} }
cacherMgr := caches.NewManager()
mapper := names.NewCacheMapper(new(names.SnakeMapper))
tagParser := tags.NewParser("xorm", dialect, mapper, mapper, cacherMgr)
engine := &Engine{ engine := &Engine{
db: db, db: db,
dialect: dialect, dialect: dialect,
Tables: make(map[reflect.Type]*schemas.Table),
mutex: &sync.RWMutex{},
TZLocation: time.Local, TZLocation: time.Local,
defaultContext: context.Background(), defaultContext: context.Background(),
cacherMgr: caches.NewManager(), cacherMgr: cacherMgr,
tagParser: tagParser,
} }
if uri.DBType == schemas.SQLITE { if uri.DBType == schemas.SQLITE {
@ -80,8 +81,6 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
logger := log.NewSimpleLogger(os.Stdout) logger := log.NewSimpleLogger(os.Stdout)
logger.SetLevel(log.LOG_INFO) logger.SetLevel(log.LOG_INFO)
engine.SetLogger(logger) engine.SetLogger(logger)
mapper := names.NewCacheMapper(new(names.SnakeMapper))
engine.tagParser = tags.NewParser("xorm", dialect, mapper, mapper, engine.cacherMgr)
runtime.SetFinalizer(engine, close) runtime.SetFinalizer(engine, close)