diff --git a/engine.go b/engine.go index 50b0958c..1bf42d15 100644 --- a/engine.go +++ b/engine.go @@ -9,7 +9,6 @@ import ( "bytes" "context" "database/sql" - "encoding/gob" "errors" "fmt" "io" @@ -17,7 +16,6 @@ import ( "reflect" "strconv" "strings" - "sync" "time" "xorm.io/builder" @@ -37,10 +35,6 @@ type Engine struct { db *core.DB dialect dialects.Dialect - Tables map[reflect.Type]*schemas.Table - - mutex *sync.RWMutex - showSQL bool showExecTime bool @@ -753,43 +747,6 @@ func (engine *Engine) Having(conditions string) *Session { return session.Having(conditions) } -// UnMapType removes the database mapper of a type -func (engine *Engine) UnMapType(t reflect.Type) { - engine.mutex.Lock() - defer engine.mutex.Unlock() - delete(engine.Tables, t) -} - -func (engine *Engine) autoMapType(v reflect.Value) (*schemas.Table, error) { - t := v.Type() - engine.mutex.Lock() - defer engine.mutex.Unlock() - table, ok := engine.Tables[t] - if !ok { - var err error - table, err = engine.tagParser.MapType(v) - if err != nil { - return nil, err - } - - engine.Tables[t] = table - if engine.GetDefaultCacher() != nil { - if v.CanAddr() { - engine.GobRegister(v.Addr().Interface()) - } else { - engine.GobRegister(v.Interface()) - } - } - } - return table, nil -} - -// GobRegister register one struct to gob for cache use -func (engine *Engine) GobRegister(v interface{}) *Engine { - gob.Register(v) - return engine -} - // Table table struct type Table struct { *schemas.Table @@ -804,7 +761,7 @@ func (t *Table) IsValid() bool { // TableInfo get table info according to bean's content func (engine *Engine) TableInfo(bean interface{}) *Table { v := rValue(bean) - tb, err := engine.autoMapType(v) + tb, err := engine.tagParser.MapType(v) if err != nil { engine.logger.Error(err) } @@ -842,7 +799,7 @@ func (engine *Engine) IDOfV(rv reflect.Value) schemas.PK { func (engine *Engine) idOfV(rv reflect.Value) (schemas.PK, error) { v := reflect.Indirect(rv) - table, err := engine.autoMapType(v) + table, err := engine.tagParser.MapType(v) if err != nil { return nil, err } @@ -938,6 +895,11 @@ func (engine *Engine) ClearCache(beans ...interface{}) error { return nil } +// UnMapType remove table from tables cache +func (engine *Engine) UnMapType(t reflect.Type) { + engine.tagParser.ClearTable(t) +} + // Sync the new struct changes to database, this method will automatically add // table, column, index, unique. but will not delete or change anything. // If you change some field, you should change the database manually. @@ -948,7 +910,7 @@ func (engine *Engine) Sync(beans ...interface{}) error { for _, bean := range beans { v := rValue(bean) tableNameNoSchema := engine.TableName(bean) - table, err := engine.autoMapType(v) + table, err := engine.tagParser.MapType(v) if err != nil { return err } diff --git a/engine_cond.go b/engine_cond.go index 00bfd59d..e757df11 100644 --- a/engine_cond.go +++ b/engine_cond.go @@ -165,8 +165,10 @@ func (engine *Engine) buildConds(table *schemas.Table, bean interface{}, val = bytes } } else { - engine.autoMapType(fieldValue) - if table, ok := engine.Tables[fieldValue.Type()]; ok { + table, err := engine.tagParser.MapType(fieldValue) + if err != nil { + val = fieldValue.Interface() + } else { if len(table.PrimaryKeys) == 1 { pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) // fix non-int pk issues @@ -180,8 +182,6 @@ func (engine *Engine) buildConds(table *schemas.Table, bean interface{}, //TODO: how to handler? return nil, fmt.Errorf("not supported %v as %v", fieldValue.Interface(), table.PrimaryKeys) } - } else { - val = fieldValue.Interface() } } } diff --git a/session.go b/session.go index d4d9f78a..0b0f56c0 100644 --- a/session.go +++ b/session.go @@ -690,7 +690,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b } } } else if session.statement.UseCascade { - table, err := session.engine.autoMapType(*fieldValue) + table, err := session.engine.tagParser.MapType(*fieldValue) if err != nil { return nil, err } diff --git a/session_convert.go b/session_convert.go index 04436ec6..e7eabecc 100644 --- a/session_convert.go +++ b/session_convert.go @@ -209,7 +209,7 @@ func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Val v = x fieldValue.Set(reflect.ValueOf(v).Convert(fieldType)) } else if session.statement.UseCascade { - table, err := session.engine.autoMapType(*fieldValue) + table, err := session.engine.tagParser.MapType(*fieldValue) if err != nil { return err } @@ -492,7 +492,7 @@ func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Val default: if session.statement.UseCascade { structInter := reflect.New(fieldType.Elem()) - table, err := session.engine.autoMapType(structInter.Elem()) + table, err := session.engine.tagParser.MapType(structInter.Elem()) if err != nil { return err } @@ -603,7 +603,7 @@ func (session *Session) value2Interface(col *schemas.Column, fieldValue reflect. return v.Value() } - fieldTable, err := session.engine.autoMapType(fieldValue) + fieldTable, err := session.engine.tagParser.MapType(fieldValue) if err != nil { return nil, err } diff --git a/session_find.go b/session_find.go index 492f19e6..6903c1b9 100644 --- a/session_find.go +++ b/session_find.go @@ -275,7 +275,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect if elemType.Kind() == reflect.Struct { var newValue = newElemFunc(fields) dataStruct := rValue(newValue.Interface()) - tb, err := session.engine.autoMapType(dataStruct) + tb, err := session.engine.tagParser.MapType(dataStruct) if err != nil { return err } diff --git a/statement.go b/statement.go index b1593621..3a823d82 100644 --- a/statement.go +++ b/statement.go @@ -225,7 +225,7 @@ func (statement *Statement) NotIn(column string, args ...interface{}) *Statement func (statement *Statement) setRefValue(v reflect.Value) error { var err error - statement.RefTable, err = statement.Engine.autoMapType(reflect.Indirect(v)) + statement.RefTable, err = statement.Engine.tagParser.MapType(reflect.Indirect(v)) if err != nil { return err } @@ -235,7 +235,7 @@ func (statement *Statement) setRefValue(v reflect.Value) error { func (statement *Statement) setRefBean(bean interface{}) error { var err error - statement.RefTable, err = statement.Engine.autoMapType(rValue(bean)) + statement.RefTable, err = statement.Engine.tagParser.MapType(rValue(bean)) if err != nil { return err } @@ -414,8 +414,10 @@ func (statement *Statement) buildUpdates(bean interface{}, val, _ = nulType.Value() } else { if !col.SQLType.IsJson() { - engine.autoMapType(fieldValue) - if table, ok := engine.Tables[fieldValue.Type()]; ok { + table, err := engine.tagParser.MapType(fieldValue) + if err != nil { + val = fieldValue.Interface() + } else { if len(table.PrimaryKeys) == 1 { pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) // fix non-int pk issues @@ -428,8 +430,6 @@ func (statement *Statement) buildUpdates(bean interface{}, // TODO: how to handler? panic("not supported") } - } else { - val = fieldValue.Interface() } } else { // Blank struct could not be as update data @@ -723,7 +723,7 @@ func (statement *Statement) Table(tableNameOrBean interface{}) *Statement { t := v.Type() if t.Kind() == reflect.Struct { var err error - statement.RefTable, err = statement.Engine.autoMapType(v) + statement.RefTable, err = statement.Engine.tagParser.MapType(v) if err != nil { statement.Engine.logger.Error(err) return statement diff --git a/tags/parser.go b/tags/parser.go index 15dcaa30..5c94c55b 100644 --- a/tags/parser.go +++ b/tags/parser.go @@ -5,10 +5,12 @@ package tags import ( + "encoding/gob" "errors" "fmt" "reflect" "strings" + "sync" "time" "xorm.io/xorm/caches" @@ -25,6 +27,7 @@ type Parser struct { TableMapper names.Mapper handlers map[string]Handler cacherMgr *caches.Manager + tableCache sync.Map // map[reflect.Type]*schemas.Table } func NewParser(identifier string, dialect dialects.Dialect, tableMapper, columnMapper names.Mapper, cacherMgr *caches.Manager) *Parser { @@ -51,6 +54,36 @@ func addIndex(indexName string, table *schemas.Table, col *schemas.Column, index } func (parser *Parser) MapType(v reflect.Value) (*schemas.Table, error) { + t := v.Type() + tableI, ok := parser.tableCache.Load(t) + if ok { + return tableI.(*schemas.Table), nil + } + + table, err := parser.mapType(v) + if err != nil { + return nil, err + } + + parser.tableCache.Store(t, table) + + if parser.cacherMgr.GetDefaultCacher() != nil { + if v.CanAddr() { + gob.Register(v.Addr().Interface()) + } else { + gob.Register(v.Interface()) + } + } + + return table, nil +} + +// ClearTable removes the database mapper of a type from the cache +func (parser *Parser) ClearTable(t reflect.Type) { + parser.tableCache.Delete(t) +} + +func (parser *Parser) mapType(v reflect.Value) (*schemas.Table, error) { t := v.Type() table := schemas.NewEmptyTable() table.Type = t diff --git a/tags/tag.go b/tags/tag.go index 3222615a..a043ed77 100644 --- a/tags/tag.go +++ b/tags/tag.go @@ -280,7 +280,7 @@ func ExtendsTagHandler(ctx *Context) error { isPtr = true fallthrough case reflect.Struct: - parentTable, err := ctx.parser.MapType(fieldValue) + parentTable, err := ctx.parser.mapType(fieldValue) if err != nil { return err } diff --git a/tags_test.go b/tags_test.go index 2d90948b..b8a43670 100644 --- a/tags_test.go +++ b/tags_test.go @@ -92,12 +92,8 @@ func TestExtends(t *testing.T) { tu9 := &tempUser4{} _, err = testEngine.Get(tu9) assert.NoError(t, err) - - if tu9.TempUser2.TempUser.Username != tu8.TempUser2.TempUser.Username || tu9.TempUser2.Departname != tu8.TempUser2.Departname { - err = errors.New(fmt.Sprintln("not equal for", tu8, tu9)) - t.Error(err) - panic(err) - } + assert.EqualValues(t, tu8.TempUser2.TempUser.Username, tu9.TempUser2.TempUser.Username) + assert.EqualValues(t, tu8.TempUser2.Departname, tu9.TempUser2.Departname) tu10 := &tempUser4{tempUser2{tempUser{0, "extends update"}, ""}} _, err = testEngine.ID(tu9.TempUser2.TempUser.Id).Update(tu10) @@ -117,17 +113,10 @@ func TestExtends(t *testing.T) { _, err = testEngine.Get(tu5) assert.NoError(t, err) - if tu5.Temp == nil { - err = errors.New("error get data extends") - t.Error(err) - panic(err) - } - if tu5.Temp.Id != 1 || tu5.Temp.Username != "extends" || - tu5.Departname != "dev depart" { - err = errors.New("error get data extends") - t.Error(err) - panic(err) - } + assert.NotNil(t, tu5.Temp) + assert.EqualValues(t, 1, tu5.Temp.Id) + assert.EqualValues(t, "extends", tu5.Temp.Username) + assert.EqualValues(t, "dev depart", tu5.Departname) tu6 := &tempUser3{&tempUser{0, "extends update"}, ""} _, err = testEngine.ID(tu5.Temp.Id).Update(tu6) @@ -162,47 +151,25 @@ func TestExtends(t *testing.T) { qt(ui), qt(ud), qt(ui), qt(udid), qt(ud), qt(uiid)) b, err := testEngine.SQL(sql).NoCascade().Get(&info) assert.NoError(t, err) - if !b { - err = errors.New("should has lest one record") - t.Error(err) - panic(err) - } - fmt.Println(info) - if info.Userinfo.Uid == 0 || info.Userdetail.Id == 0 { - err = errors.New("all of the id should has value") - t.Error(err) - panic(err) - } + assert.True(t, b, "should has lest one record") + assert.True(t, info.Userinfo.Uid > 0, "all of the id should has value") + assert.True(t, info.Userdetail.Id > 0, "all of the id should has value") - fmt.Println("----join--info2") var info2 UserAndDetail b, err = testEngine.Table(&Userinfo{}). Join("LEFT", qt(ud), qt(ui)+"."+qt("detail_id")+" = "+qt(ud)+"."+qt(uiid)). NoCascade().Get(&info2) - if err != nil { - t.Error(err) - panic(err) - } - if !b { - err = errors.New("should has lest one record") - t.Error(err) - panic(err) - } - if info2.Userinfo.Uid == 0 || info2.Userdetail.Id == 0 { - err = errors.New("all of the id should has value") - t.Error(err) - panic(err) - } - fmt.Println(info2) + assert.NoError(t, err) + assert.True(t, b) + assert.True(t, info2.Userinfo.Uid > 0, "all of the id should has value") + assert.True(t, info2.Userdetail.Id > 0, "all of the id should has value") - fmt.Println("----join--infos2") var infos2 = make([]UserAndDetail, 0) err = testEngine.Table(&Userinfo{}). Join("LEFT", qt(ud), qt(ui)+"."+qt("detail_id")+" = "+qt(ud)+"."+qt(uiid)). NoCascade(). Find(&infos2) assert.NoError(t, err) - fmt.Println(infos2) } type MessageBase struct { diff --git a/xorm.go b/xorm.go index f3230aa1..2946b7c9 100644 --- a/xorm.go +++ b/xorm.go @@ -10,9 +10,7 @@ import ( "context" "fmt" "os" - "reflect" "runtime" - "sync" "time" "xorm.io/xorm/caches" @@ -61,14 +59,17 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) { return nil, err } + cacherMgr := caches.NewManager() + mapper := names.NewCacheMapper(new(names.SnakeMapper)) + tagParser := tags.NewParser("xorm", dialect, mapper, mapper, cacherMgr) + engine := &Engine{ db: db, dialect: dialect, - Tables: make(map[reflect.Type]*schemas.Table), - mutex: &sync.RWMutex{}, TZLocation: time.Local, defaultContext: context.Background(), - cacherMgr: caches.NewManager(), + cacherMgr: cacherMgr, + tagParser: tagParser, } if uri.DBType == schemas.SQLITE { @@ -80,8 +81,6 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) { logger := log.NewSimpleLogger(os.Stdout) logger.SetLevel(log.LOG_INFO) engine.SetLogger(logger) - mapper := names.NewCacheMapper(new(names.SnakeMapper)) - engine.tagParser = tags.NewParser("xorm", dialect, mapper, mapper, engine.cacherMgr) runtime.SetFinalizer(engine, close)