Move maptype to tag parser #1561

Merged
lunny merged 1 commits from lunny/move_maptype into master 2 years ago
  1. 54
      engine.go
  2. 8
      engine_cond.go
  3. 2
      session.go
  4. 6
      session_convert.go
  5. 2
      session_find.go
  6. 14
      statement.go
  7. 33
      tags/parser.go
  8. 2
      tags/tag.go
  9. 59
      tags_test.go
  10. 13
      xorm.go

54
engine.go

@ -9,7 +9,6 @@ import ( @@ -9,7 +9,6 @@ import (
"bytes"
"context"
"database/sql"
"encoding/gob"
"errors"
"fmt"
"io"
@ -17,7 +16,6 @@ import ( @@ -17,7 +16,6 @@ import (
"reflect"
"strconv"
"strings"
"sync"
"time"
"xorm.io/builder"
@ -37,10 +35,6 @@ type Engine struct { @@ -37,10 +35,6 @@ type Engine struct {
db *core.DB
dialect dialects.Dialect
Tables map[reflect.Type]*schemas.Table
mutex *sync.RWMutex
showSQL bool
showExecTime bool
@ -753,43 +747,6 @@ func (engine *Engine) Having(conditions string) *Session { @@ -753,43 +747,6 @@ func (engine *Engine) Having(conditions string) *Session {
return session.Having(conditions)
}
// UnMapType removes the database mapper of a type
func (engine *Engine) UnMapType(t reflect.Type) {
engine.mutex.Lock()
defer engine.mutex.Unlock()
delete(engine.Tables, t)
}
func (engine *Engine) autoMapType(v reflect.Value) (*schemas.Table, error) {
t := v.Type()
engine.mutex.Lock()
defer engine.mutex.Unlock()
table, ok := engine.Tables[t]
if !ok {
var err error
table, err = engine.tagParser.MapType(v)
if err != nil {
return nil, err
}
engine.Tables[t] = table
if engine.GetDefaultCacher() != nil {
if v.CanAddr() {
engine.GobRegister(v.Addr().Interface())
} else {
engine.GobRegister(v.Interface())
}
}
}
return table, nil
}
// GobRegister register one struct to gob for cache use
func (engine *Engine) GobRegister(v interface{}) *Engine {
gob.Register(v)
return engine
}
// Table table struct
type Table struct {
*schemas.Table
@ -804,7 +761,7 @@ func (t *Table) IsValid() bool { @@ -804,7 +761,7 @@ func (t *Table) IsValid() bool {
// TableInfo get table info according to bean's content
func (engine *Engine) TableInfo(bean interface{}) *Table {
v := rValue(bean)
tb, err := engine.autoMapType(v)
tb, err := engine.tagParser.MapType(v)
if err != nil {
engine.logger.Error(err)
}
@ -842,7 +799,7 @@ func (engine *Engine) IDOfV(rv reflect.Value) schemas.PK { @@ -842,7 +799,7 @@ func (engine *Engine) IDOfV(rv reflect.Value) schemas.PK {
func (engine *Engine) idOfV(rv reflect.Value) (schemas.PK, error) {
v := reflect.Indirect(rv)
table, err := engine.autoMapType(v)
table, err := engine.tagParser.MapType(v)
if err != nil {
return nil, err
}
@ -938,6 +895,11 @@ func (engine *Engine) ClearCache(beans ...interface{}) error { @@ -938,6 +895,11 @@ func (engine *Engine) ClearCache(beans ...interface{}) error {
return nil
}
// UnMapType remove table from tables cache
func (engine *Engine) UnMapType(t reflect.Type) {
engine.tagParser.ClearTable(t)
}
// Sync the new struct changes to database, this method will automatically add
// table, column, index, unique. but will not delete or change anything.
// If you change some field, you should change the database manually.
@ -948,7 +910,7 @@ func (engine *Engine) Sync(beans ...interface{}) error { @@ -948,7 +910,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
for _, bean := range beans {
v := rValue(bean)
tableNameNoSchema := engine.TableName(bean)
table, err := engine.autoMapType(v)
table, err := engine.tagParser.MapType(v)
if err != nil {
return err
}

8
engine_cond.go

@ -165,8 +165,10 @@ func (engine *Engine) buildConds(table *schemas.Table, bean interface{}, @@ -165,8 +165,10 @@ func (engine *Engine) buildConds(table *schemas.Table, bean interface{},
val = bytes
}
} else {
engine.autoMapType(fieldValue)
if table, ok := engine.Tables[fieldValue.Type()]; ok {
table, err := engine.tagParser.MapType(fieldValue)
if err != nil {
val = fieldValue.Interface()
} else {
if len(table.PrimaryKeys) == 1 {
pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName)
// fix non-int pk issues
@ -180,8 +182,6 @@ func (engine *Engine) buildConds(table *schemas.Table, bean interface{}, @@ -180,8 +182,6 @@ func (engine *Engine) buildConds(table *schemas.Table, bean interface{},
//TODO: how to handler?
return nil, fmt.Errorf("not supported %v as %v", fieldValue.Interface(), table.PrimaryKeys)
}
} else {
val = fieldValue.Interface()
}
}
}

2
session.go

@ -690,7 +690,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b @@ -690,7 +690,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
}
}
} else if session.statement.UseCascade {
table, err := session.engine.autoMapType(*fieldValue)
table, err := session.engine.tagParser.MapType(*fieldValue)
if err != nil {
return nil, err
}

6
session_convert.go

@ -209,7 +209,7 @@ func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Val @@ -209,7 +209,7 @@ func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Val
v = x
fieldValue.Set(reflect.ValueOf(v).Convert(fieldType))
} else if session.statement.UseCascade {
table, err := session.engine.autoMapType(*fieldValue)
table, err := session.engine.tagParser.MapType(*fieldValue)
if err != nil {
return err
}
@ -492,7 +492,7 @@ func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Val @@ -492,7 +492,7 @@ func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Val
default:
if session.statement.UseCascade {
structInter := reflect.New(fieldType.Elem())
table, err := session.engine.autoMapType(structInter.Elem())
table, err := session.engine.tagParser.MapType(structInter.Elem())
if err != nil {
return err
}
@ -603,7 +603,7 @@ func (session *Session) value2Interface(col *schemas.Column, fieldValue reflect. @@ -603,7 +603,7 @@ func (session *Session) value2Interface(col *schemas.Column, fieldValue reflect.
return v.Value()
}
fieldTable, err := session.engine.autoMapType(fieldValue)
fieldTable, err := session.engine.tagParser.MapType(fieldValue)
if err != nil {
return nil, err
}

2
session_find.go

@ -275,7 +275,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect @@ -275,7 +275,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect
if elemType.Kind() == reflect.Struct {
var newValue = newElemFunc(fields)
dataStruct := rValue(newValue.Interface())
tb, err := session.engine.autoMapType(dataStruct)
tb, err := session.engine.tagParser.MapType(dataStruct)
if err != nil {
return err
}

14
statement.go

@ -225,7 +225,7 @@ func (statement *Statement) NotIn(column string, args ...interface{}) *Statement @@ -225,7 +225,7 @@ func (statement *Statement) NotIn(column string, args ...interface{}) *Statement
func (statement *Statement) setRefValue(v reflect.Value) error {
var err error
statement.RefTable, err = statement.Engine.autoMapType(reflect.Indirect(v))
statement.RefTable, err = statement.Engine.tagParser.MapType(reflect.Indirect(v))
if err != nil {
return err
}
@ -235,7 +235,7 @@ func (statement *Statement) setRefValue(v reflect.Value) error { @@ -235,7 +235,7 @@ func (statement *Statement) setRefValue(v reflect.Value) error {
func (statement *Statement) setRefBean(bean interface{}) error {
var err error
statement.RefTable, err = statement.Engine.autoMapType(rValue(bean))
statement.RefTable, err = statement.Engine.tagParser.MapType(rValue(bean))
if err != nil {
return err
}
@ -414,8 +414,10 @@ func (statement *Statement) buildUpdates(bean interface{}, @@ -414,8 +414,10 @@ func (statement *Statement) buildUpdates(bean interface{},
val, _ = nulType.Value()
} else {
if !col.SQLType.IsJson() {
engine.autoMapType(fieldValue)
if table, ok := engine.Tables[fieldValue.Type()]; ok {
table, err := engine.tagParser.MapType(fieldValue)
if err != nil {
val = fieldValue.Interface()
} else {
if len(table.PrimaryKeys) == 1 {
pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName)
// fix non-int pk issues
@ -428,8 +430,6 @@ func (statement *Statement) buildUpdates(bean interface{}, @@ -428,8 +430,6 @@ func (statement *Statement) buildUpdates(bean interface{},
// TODO: how to handler?
panic("not supported")
}
} else {
val = fieldValue.Interface()
}
} else {
// Blank struct could not be as update data
@ -723,7 +723,7 @@ func (statement *Statement) Table(tableNameOrBean interface{}) *Statement { @@ -723,7 +723,7 @@ func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
t := v.Type()
if t.Kind() == reflect.Struct {
var err error
statement.RefTable, err = statement.Engine.autoMapType(v)
statement.RefTable, err = statement.Engine.tagParser.MapType(v)
if err != nil {
statement.Engine.logger.Error(err)
return statement

33
tags/parser.go

@ -5,10 +5,12 @@ @@ -5,10 +5,12 @@
package tags
import (
"encoding/gob"
"errors"
"fmt"
"reflect"
"strings"
"sync"
"time"
"xorm.io/xorm/caches"
@ -25,6 +27,7 @@ type Parser struct { @@ -25,6 +27,7 @@ type Parser struct {
TableMapper names.Mapper
handlers map[string]Handler
cacherMgr *caches.Manager
tableCache sync.Map // map[reflect.Type]*schemas.Table
}
func NewParser(identifier string, dialect dialects.Dialect, tableMapper, columnMapper names.Mapper, cacherMgr *caches.Manager) *Parser {
@ -51,6 +54,36 @@ func addIndex(indexName string, table *schemas.Table, col *schemas.Column, index @@ -51,6 +54,36 @@ func addIndex(indexName string, table *schemas.Table, col *schemas.Column, index
}
func (parser *Parser) MapType(v reflect.Value) (*schemas.Table, error) {
t := v.Type()
tableI, ok := parser.tableCache.Load(t)
if ok {
return tableI.(*schemas.Table), nil
}
table, err := parser.mapType(v)
if err != nil {
return nil, err
}
parser.tableCache.Store(t, table)
if parser.cacherMgr.GetDefaultCacher() != nil {
if v.CanAddr() {
gob.Register(v.Addr().Interface())
} else {
gob.Register(v.Interface())
}
}
return table, nil
}
// ClearTable removes the database mapper of a type from the cache
func (parser *Parser) ClearTable(t reflect.Type) {
parser.tableCache.Delete(t)
}
func (parser *Parser) mapType(v reflect.Value) (*schemas.Table, error) {
t := v.Type()
table := schemas.NewEmptyTable()
table.Type = t

2
tags/tag.go

@ -280,7 +280,7 @@ func ExtendsTagHandler(ctx *Context) error { @@ -280,7 +280,7 @@ func ExtendsTagHandler(ctx *Context) error {
isPtr = true
fallthrough
case reflect.Struct:
parentTable, err := ctx.parser.MapType(fieldValue)
parentTable, err := ctx.parser.mapType(fieldValue)
if err != nil {
return err
}

59
tags_test.go

@ -92,12 +92,8 @@ func TestExtends(t *testing.T) { @@ -92,12 +92,8 @@ func TestExtends(t *testing.T) {
tu9 := &tempUser4{}
_, err = testEngine.Get(tu9)
assert.NoError(t, err)
if tu9.TempUser2.TempUser.Username != tu8.TempUser2.TempUser.Username || tu9.TempUser2.Departname != tu8.TempUser2.Departname {
err = errors.New(fmt.Sprintln("not equal for", tu8, tu9))
t.Error(err)
panic(err)
}
assert.EqualValues(t, tu8.TempUser2.TempUser.Username, tu9.TempUser2.TempUser.Username)
assert.EqualValues(t, tu8.TempUser2.Departname, tu9.TempUser2.Departname)
tu10 := &tempUser4{tempUser2{tempUser{0, "extends update"}, ""}}
_, err = testEngine.ID(tu9.TempUser2.TempUser.Id).Update(tu10)
@ -117,17 +113,10 @@ func TestExtends(t *testing.T) { @@ -117,17 +113,10 @@ func TestExtends(t *testing.T) {
_, err = testEngine.Get(tu5)
assert.NoError(t, err)
if tu5.Temp == nil {
err = errors.New("error get data extends")
t.Error(err)
panic(err)
}
if tu5.Temp.Id != 1 || tu5.Temp.Username != "extends" ||
tu5.Departname != "dev depart" {
err = errors.New("error get data extends")
t.Error(err)
panic(err)
}
assert.NotNil(t, tu5.Temp)
assert.EqualValues(t, 1, tu5.Temp.Id)
assert.EqualValues(t, "extends", tu5.Temp.Username)
assert.EqualValues(t, "dev depart", tu5.Departname)
tu6 := &tempUser3{&tempUser{0, "extends update"}, ""}
_, err = testEngine.ID(tu5.Temp.Id).Update(tu6)
@ -162,47 +151,25 @@ func TestExtends(t *testing.T) { @@ -162,47 +151,25 @@ func TestExtends(t *testing.T) {
qt(ui), qt(ud), qt(ui), qt(udid), qt(ud), qt(uiid))
b, err := testEngine.SQL(sql).NoCascade().Get(&info)
assert.NoError(t, err)
if !b {
err = errors.New("should has lest one record")
t.Error(err)
panic(err)
}
fmt.Println(info)
if info.Userinfo.Uid == 0 || info.Userdetail.Id == 0 {
err = errors.New("all of the id should has value")
t.Error(err)
panic(err)
}
assert.True(t, b, "should has lest one record")
assert.True(t, info.Userinfo.Uid > 0, "all of the id should has value")
assert.True(t, info.Userdetail.Id > 0, "all of the id should has value")
fmt.Println("----join--info2")
var info2 UserAndDetail
b, err = testEngine.Table(&Userinfo{}).
Join("LEFT", qt(ud), qt(ui)+"."+qt("detail_id")+" = "+qt(ud)+"."+qt(uiid)).
NoCascade().Get(&info2)
if err != nil {
t.Error(err)
panic(err)
}
if !b {
err = errors.New("should has lest one record")
t.Error(err)
panic(err)
}
if info2.Userinfo.Uid == 0 || info2.Userdetail.Id == 0 {
err = errors.New("all of the id should has value")
t.Error(err)
panic(err)
}
fmt.Println(info2)
assert.NoError(t, err)
assert.True(t, b)
assert.True(t, info2.Userinfo.Uid > 0, "all of the id should has value")
assert.True(t, info2.Userdetail.Id > 0, "all of the id should has value")
fmt.Println("----join--infos2")
var infos2 = make([]UserAndDetail, 0)
err = testEngine.Table(&Userinfo{}).
Join("LEFT", qt(ud), qt(ui)+"."+qt("detail_id")+" = "+qt(ud)+"."+qt(uiid)).
NoCascade().
Find(&infos2)
assert.NoError(t, err)
fmt.Println(infos2)
}
type MessageBase struct {

13
xorm.go

@ -10,9 +10,7 @@ import ( @@ -10,9 +10,7 @@ import (
"context"
"fmt"
"os"
"reflect"
"runtime"
"sync"
"time"
"xorm.io/xorm/caches"
@ -61,14 +59,17 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) { @@ -61,14 +59,17 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
return nil, err
}
cacherMgr := caches.NewManager()
mapper := names.NewCacheMapper(new(names.SnakeMapper))
tagParser := tags.NewParser("xorm", dialect, mapper, mapper, cacherMgr)
engine := &Engine{
db: db,
dialect: dialect,
Tables: make(map[reflect.Type]*schemas.Table),
mutex: &sync.RWMutex{},
TZLocation: time.Local,
defaultContext: context.Background(),
cacherMgr: caches.NewManager(),
cacherMgr: cacherMgr,
tagParser: tagParser,
}
if uri.DBType == schemas.SQLITE {
@ -80,8 +81,6 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) { @@ -80,8 +81,6 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
logger := log.NewSimpleLogger(os.Stdout)
logger.SetLevel(log.LOG_INFO)
engine.SetLogger(logger)
mapper := names.NewCacheMapper(new(names.SnakeMapper))
engine.tagParser = tags.NewParser("xorm", dialect, mapper, mapper, engine.cacherMgr)
runtime.SetFinalizer(engine, close)

Loading…
Cancel
Save