Implement one associate belongs to #2006

Closed
lunny wants to merge 11 commits from lunny/belongs_to2 into master
11 changed files with 746 additions and 59 deletions

View File

@ -703,6 +703,13 @@ func (engine *Engine) AllCols() *Session {
return session.AllCols()
}
// Load loads associated fields from database
func (engine *Engine) Load(beanOrSlices interface{}, cols ...string) error {
session := engine.NewSession()
session.isAutoClose = true
return session.Load(beanOrSlices, cols...)
}
// MustCols specify some columns must use even if they are empty
func (engine *Engine) MustCols(columns ...string) *Session {
session := engine.NewSession()

View File

@ -0,0 +1,232 @@
// Copyright 2017 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package integrations
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestBelongsTo_Get(t *testing.T) {
assert.NoError(t, PrepareEngine())
type Face1 struct {
Id int64
Name string
}
type Nose1 struct {
Id int64
Face Face1 `xorm:"belongs_to"`
}
err := testEngine.Sync2(new(Nose1), new(Face1))
assert.NoError(t, err)
var face = Face1{
Name: "face1",
}
_, err = testEngine.Insert(&face)
assert.NoError(t, err)
var cfgFace Face1
has, err := testEngine.Get(&cfgFace)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.Equal(t, face, cfgFace)
var nose = Nose1{Face: face}
_, err = testEngine.Insert(&nose)
assert.NoError(t, err)
var cfgNose Nose1
has, err = testEngine.Get(&cfgNose)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.Equal(t, nose.Id, cfgNose.Id)
assert.Equal(t, nose.Face.Id, cfgNose.Face.Id)
assert.Equal(t, "", cfgNose.Face.Name)
err = testEngine.Load(&cfgNose)
assert.NoError(t, err)
assert.Equal(t, nose.Id, cfgNose.Id)
assert.Equal(t, nose.Face.Id, cfgNose.Face.Id)
assert.Equal(t, "face1", cfgNose.Face.Name)
var cfgNose2 Nose1
has, err = testEngine.Cascade().Get(&cfgNose2)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.Equal(t, nose.Id, cfgNose2.Id)
assert.Equal(t, nose.Face.Id, cfgNose2.Face.Id)
assert.Equal(t, "face1", cfgNose2.Face.Name)
}
func TestBelongsTo_GetPtr(t *testing.T) {
assert.NoError(t, PrepareEngine())
type Face2 struct {
Id int64
Name string
}
type Nose2 struct {
Id int64
Face *Face2 `xorm:"belongs_to"`
}
err := testEngine.Sync2(new(Nose2), new(Face2))
assert.NoError(t, err)
var face = Face2{
Name: "face1",
}
_, err = testEngine.Insert(&face)
assert.NoError(t, err)
var cfgFace Face2
has, err := testEngine.Get(&cfgFace)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.Equal(t, face, cfgFace)
var nose = Nose2{Face: &face}
_, err = testEngine.Insert(&nose)
assert.NoError(t, err)
var cfgNose Nose2
has, err = testEngine.Get(&cfgNose)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.Equal(t, nose.Id, cfgNose.Id)
assert.Equal(t, nose.Face.Id, cfgNose.Face.Id)
err = testEngine.Load(&cfgNose)
assert.NoError(t, err)
assert.Equal(t, nose.Id, cfgNose.Id)
assert.Equal(t, nose.Face.Id, cfgNose.Face.Id)
assert.Equal(t, "face1", cfgNose.Face.Name)
var cfgNose2 Nose2
has, err = testEngine.Cascade().Get(&cfgNose2)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.Equal(t, nose.Id, cfgNose2.Id)
assert.Equal(t, nose.Face.Id, cfgNose2.Face.Id)
assert.Equal(t, "face1", cfgNose2.Face.Name)
}
func TestBelongsTo_Find(t *testing.T) {
assert.NoError(t, PrepareEngine())
type Face3 struct {
Id int64
Name string
}
type Nose3 struct {
Id int64
Face Face3 `xorm:"belongs_to"`
}
err := testEngine.Sync2(new(Nose3), new(Face3))
assert.NoError(t, err)
var face1 = Face3{
Name: "face1",
}
var face2 = Face3{
Name: "face2",
}
_, err = testEngine.Insert(&face1, &face2)
assert.NoError(t, err)
var noses = []Nose3{
{Face: face1},
{Face: face2},
}
_, err = testEngine.Insert(&noses)
assert.NoError(t, err)
var noses1 []Nose3
err = testEngine.Find(&noses1)
assert.NoError(t, err)
assert.Equal(t, 2, len(noses1))
assert.Equal(t, face1.Id, noses1[0].Face.Id)
assert.Equal(t, face2.Id, noses1[1].Face.Id)
assert.Equal(t, "", noses1[0].Face.Name)
assert.Equal(t, "", noses1[1].Face.Name)
var noses2 []Nose3
err = testEngine.Cascade().Find(&noses2)
assert.NoError(t, err)
assert.Equal(t, 2, len(noses2))
assert.Equal(t, face1.Id, noses2[0].Face.Id)
assert.Equal(t, face2.Id, noses2[1].Face.Id)
assert.Equal(t, "face1", noses2[0].Face.Name)
assert.Equal(t, "face2", noses2[1].Face.Name)
err = testEngine.Load(noses1, "face_id")
assert.NoError(t, err)
assert.Equal(t, "face1", noses1[0].Face.Name)
assert.Equal(t, "face2", noses1[1].Face.Name)
}
func TestBelongsTo_FindPtr(t *testing.T) {
assert.NoError(t, PrepareEngine())
type Face4 struct {
Id int64
Name string
}
type Nose4 struct {
Id int64
Face *Face4 `xorm:"belongs_to"`
}
err := testEngine.Sync2(new(Nose4), new(Face4))
assert.NoError(t, err)
var face1 = Face4{
Name: "face1",
}
var face2 = Face4{
Name: "face2",
}
_, err = testEngine.Insert(&face1, &face2)
assert.NoError(t, err)
var noses = []Nose4{
{Face: &face1},
{Face: &face2},
}
_, err = testEngine.Insert(&noses)
assert.NoError(t, err)
var noses1 []Nose4
err = testEngine.Find(&noses1)
assert.NoError(t, err)
assert.Equal(t, 2, len(noses1))
assert.Equal(t, face1.Id, noses1[0].Face.Id)
assert.Equal(t, face2.Id, noses1[1].Face.Id)
assert.Equal(t, "", noses1[0].Face.Name)
assert.Equal(t, "", noses1[1].Face.Name)
var noses2 []Nose4
err = testEngine.Cascade().Find(&noses2)
assert.NoError(t, err)
assert.Equal(t, 2, len(noses2))
assert.NotNil(t, noses2[0].Face)
assert.NotNil(t, noses2[1].Face)
assert.Equal(t, face1.Id, noses2[0].Face.Id)
assert.Equal(t, face2.Id, noses2[1].Face.Id)
assert.Equal(t, "face1", noses2[0].Face.Name)
assert.Equal(t, "face2", noses2[1].Face.Name)
err = testEngine.Load(noses2, "face")
assert.NoError(t, err)
}

View File

@ -0,0 +1,66 @@
// Copyright 2017 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package integrations
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestExtendsTag(t *testing.T) {
assert.NoError(t, PrepareEngine())
table, err := testEngine.TableInfo(new(Userdetail))
assert.NoError(t, err)
assert.NotNil(t, table)
assert.EqualValues(t, 3, len(table.ColumnsSeq()))
assert.EqualValues(t, "id", table.ColumnsSeq()[0])
assert.EqualValues(t, "intro", table.ColumnsSeq()[1])
assert.EqualValues(t, "profile", table.ColumnsSeq()[2])
table, err = testEngine.TableInfo(new(Userinfo))
assert.NoError(t, err)
assert.NotNil(t, table)
assert.EqualValues(t, 8, len(table.ColumnsSeq()))
assert.EqualValues(t, "id", table.ColumnsSeq()[0])
assert.EqualValues(t, "username", table.ColumnsSeq()[1])
assert.EqualValues(t, "departname", table.ColumnsSeq()[2])
assert.EqualValues(t, "created", table.ColumnsSeq()[3])
assert.EqualValues(t, "detail_id", table.ColumnsSeq()[4])
assert.EqualValues(t, "height", table.ColumnsSeq()[5])
assert.EqualValues(t, "avatar", table.ColumnsSeq()[6])
assert.EqualValues(t, "is_man", table.ColumnsSeq()[7])
table, err = testEngine.TableInfo(new(UserAndDetail))
assert.NoError(t, err)
assert.NotNil(t, table)
assert.EqualValues(t, 11, len(table.ColumnsSeq()))
assert.EqualValues(t, "id", table.ColumnsSeq()[0])
assert.EqualValues(t, "username", table.ColumnsSeq()[1])
assert.EqualValues(t, "departname", table.ColumnsSeq()[2])
assert.EqualValues(t, "created", table.ColumnsSeq()[3])
assert.EqualValues(t, "detail_id", table.ColumnsSeq()[4])
assert.EqualValues(t, "height", table.ColumnsSeq()[5])
assert.EqualValues(t, "avatar", table.ColumnsSeq()[6])
assert.EqualValues(t, "is_man", table.ColumnsSeq()[7])
assert.EqualValues(t, "id", table.ColumnsSeq()[8])
assert.EqualValues(t, "intro", table.ColumnsSeq()[9])
assert.EqualValues(t, "profile", table.ColumnsSeq()[10])
cols := table.Columns()
assert.EqualValues(t, 11, len(cols))
assert.EqualValues(t, "Userinfo.Uid", cols[0].FieldName)
assert.EqualValues(t, "Userinfo.Username", cols[1].FieldName)
assert.EqualValues(t, "Userinfo.Departname", cols[2].FieldName)
assert.EqualValues(t, "Userinfo.Created", cols[3].FieldName)
assert.EqualValues(t, "Userinfo.Detail", cols[4].FieldName)
assert.EqualValues(t, "Userinfo.Height", cols[5].FieldName)
assert.EqualValues(t, "Userinfo.Avatar", cols[6].FieldName)
assert.EqualValues(t, "Userinfo.IsMan", cols[7].FieldName)
assert.EqualValues(t, "Userdetail.Id", cols[8].FieldName)
assert.EqualValues(t, "Userdetail.Intro", cols[9].FieldName)
assert.EqualValues(t, "Userdetail.Profile", cols[10].FieldName)
}

View File

@ -24,6 +24,7 @@ type Interface interface {
Alias(alias string) *Session
Asc(colNames ...string) *Session
BufferSize(size int) *Session
Cascade(trueOrFalse ...bool) *Session
Cols(columns ...string) *Session
Count(...interface{}) (int64, error)
CreateIndexes(bean interface{}) error
@ -48,6 +49,7 @@ type Interface interface {
IsTableExist(beanOrTableName interface{}) (bool, error)
Iterate(interface{}, IterFunc) error
Limit(int, ...int) *Session
Load(beanOrSlices interface{}, cols ...string) error
MustCols(columns ...string) *Session
NoAutoCondition(...bool) *Session
NotIn(string, ...interface{}) *Session

12
schemas/associate.go Normal file
View File

@ -0,0 +1,12 @@
// Copyright 2021 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package schemas
type AssociateType int
const (
AssociateNone AssociateType = iota
AssociateBelongsTo
)

View File

@ -22,8 +22,9 @@ const (
type Column struct {
Name string
TableName string
FieldName string // Available only when parsed from a struct
FieldIndex []int // Available only when parsed from a struct
FieldName string // Available only when parsed from a struct
FieldIndex []int // Available only when parsed from a struct
FieldType reflect.Type // Available only when parsed from a struct
SQLType SQLType
IsJSON bool
Length int
@ -45,6 +46,8 @@ type Column struct {
DisableTimeZone bool
TimeZone *time.Location // column specified time zone
Comment string
AssociateType
AssociateTable *Table
}
// NewColumn creates a new column
@ -71,6 +74,8 @@ func NewColumn(name, fieldName string, sqlType SQLType, len1, len2 int, nullable
DefaultIsEmpty: true, // default should be no default
EnumOptions: make(map[string]int),
Comment: "",
AssociateType: AssociateNone,
AssociateTable: nil,
}
}

View File

@ -10,7 +10,6 @@ import (
"crypto/sha256"
"database/sql"
"encoding/hex"
"errors"
"fmt"
"hash/crc32"
"io"
@ -54,6 +53,22 @@ const (
groupSession sessionType = true
)
type cascadeMode int
const (
cascadeCompitable cascadeMode = iota // load field beans with another SQL with no
cascadeEager // load field beans with another SQL
cascadeJoin // load field beans with join
cascadeLazy // don't load anything
)
type loadClosure struct {
Func func(schemas.PK, *reflect.Value) error
pk schemas.PK
fieldValue *reflect.Value
loaded bool
}
// Session keep a pointer to sql.DB and provides all execution of all
// kind of database operations.
type Session struct {
@ -86,6 +101,9 @@ type Session struct {
ctx context.Context
sessionType sessionType
cascadeMode cascadeMode
cascadeLevel int // load level
}
func newSessionID() string {
@ -134,7 +152,9 @@ func newSession(engine *Engine) *Session {
lastSQL: "",
lastSQLArgs: make([]interface{}, 0),
sessionType: engineSession,
sessionType: engineSession,
cascadeMode: cascadeCompitable,
cascadeLevel: 2,
}
if engine.logSessionID {
session.ctx = context.WithValue(session.ctx, log.SessionKey, session)
@ -241,7 +261,7 @@ func (session *Session) Alias(alias string) *Session {
// NoCascade indicate that no cascade load child object
func (session *Session) NoCascade() *Session {
session.statement.UseCascade = false
session.cascadeMode = cascadeLazy
return session
}
@ -296,9 +316,15 @@ func (session *Session) Charset(charset string) *Session {
// Cascade indicates if loading sub Struct
func (session *Session) Cascade(trueOrFalse ...bool) *Session {
var mode = cascadeEager
if len(trueOrFalse) >= 1 {
session.statement.UseCascade = trueOrFalse[0]
if trueOrFalse[0] {
mode = cascadeEager
} else {
mode = cascadeLazy
}
}
session.cascadeMode = mode
return session
}
@ -604,9 +630,7 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
}
fieldValue.Set(reflect.ValueOf(v).Elem().Convert(fieldType))
return nil
}
if fieldType.ConvertibleTo(schemas.TimeType) {
} else if fieldType.ConvertibleTo(schemas.TimeType) {
dbTZ := session.engine.DatabaseTZ
if col.TimeZone != nil {
dbTZ = col.TimeZone
@ -620,42 +644,50 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
fieldValue.Set(reflect.ValueOf(*t).Convert(fieldType))
return nil
} else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
err := nulVal.Scan(scanResult)
if err == nil {
return nil
}
session.engine.logger.Errorf("sql.Sanner error: %v", err)
} else if session.statement.UseCascade {
table, err := session.engine.tagParser.ParseWithCache(*fieldValue)
if err != nil {
return err
return nulVal.Scan(scanResult)
} else if session.cascadeLevel > 0 && ((col.AssociateType == schemas.AssociateNone &&
session.cascadeMode == cascadeCompitable) ||
(col.AssociateType == schemas.AssociateBelongsTo &&
session.cascadeMode == cascadeEager)) {
var associateTable *schemas.Table
if col.AssociateType == schemas.AssociateNone && col.AssociateTable == nil {
var err error
associateTable, err = session.engine.tagParser.ParseWithCache(*fieldValue)
if err != nil {
return err
}
} else {
associateTable = col.AssociateTable
}
if len(table.PrimaryKeys) != 1 {
return errors.New("unsupported non or composited primary key cascade")
}
var pk = make(schemas.PK, len(table.PrimaryKeys))
fmt.Println("=====", associateTable)
var pk = make(schemas.PK, len(associateTable.PrimaryKeys))
var err error
pk[0], err = asKind(vv, reflect.TypeOf(scanResult))
if err != nil {
return err
}
if !pk.IsZero() {
// !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch
// however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne
// property to be fetched lazily
structInter := reflect.New(fieldValue.Type())
has, err := session.ID(pk).NoCascade().get(structInter.Interface())
if err != nil {
return err
}
if has {
fieldValue.Set(structInter.Elem())
} else {
return errors.New("cascade obj is not exist")
}
}
session.afterProcessors = append(session.afterProcessors, executedProcessor{
fun: func(session *Session, bean interface{}) error {
fieldValue := bean.(*reflect.Value)
fmt.Println("3333333")
return session.getStructByPK(pk, fieldValue)
},
session: session,
bean: fieldValue,
})
session.cascadeLevel--
fmt.Println("222222")
return nil
} else if col.AssociateType == schemas.AssociateBelongsTo && session.cascadeMode == cascadeLazy {
pkCols := col.AssociateTable.PKColumns()
colV, err := pkCols[0].ValueOfV(fieldValue)
if err != nil {
return err
}
return convert.AssignValue(*colV, scanResult)
}
} // switch fieldType.Kind()

243
session_associate.go Normal file
View File

@ -0,0 +1,243 @@
// Copyright 2017 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import (
"errors"
"fmt"
"reflect"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas"
)
// Load loads associated fields from database
func (session *Session) Load(beanOrSlices interface{}, cols ...string) error {
v := reflect.ValueOf(beanOrSlices)
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
if v.Kind() == reflect.Slice {
return session.loadFindSlice(v, cols...)
} else if v.Kind() == reflect.Map {
return session.loadFindMap(v, cols...)
} else if v.Kind() == reflect.Struct {
return session.loadGet(v, cols...)
}
return errors.New("unsupported load type, must struct, slice or map")
}
func isStringInSlice(s string, slice []string) bool {
for _, e := range slice {
if s == e {
return true
}
}
return false
}
// loadFind load 's belongs to tag field immedicatlly
func (session *Session) loadFindSlice(v reflect.Value, cols ...string) error {
if v.Kind() != reflect.Slice {
return errors.New("only slice is supported")
}
if v.Len() <= 0 {
return nil
}
tableValue := v.Index(0)
if tableValue.Kind() == reflect.Ptr {
tableValue = tableValue.Elem()
}
tb, err := session.engine.tagParser.ParseWithCache(tableValue)
if err != nil {
return err
}
type Va struct {
v []reflect.Value
pk []interface{}
col *schemas.Column
}
var pks = make(map[*schemas.Column]*Va)
for _, col := range tb.Columns() {
if col.AssociateTable == nil || col.AssociateType != schemas.AssociateBelongsTo {
continue
}
if len(cols) > 0 && !isStringInSlice(col.Name, cols) {
continue
}
pkCols := col.AssociateTable.PKColumns()
if len(pkCols) != 1 {
return fmt.Errorf("unsupported primary key number")
}
pks[col] = &Va{
col: pkCols[0],
}
}
for i := 0; i < v.Len(); i++ {
value := v.Index(i)
for col, va := range pks {
colV, err := col.ValueOfV(&value)
if err != nil {
return err
}
pkCols := col.AssociateTable.PKColumns()
pkV, err := pkCols[0].ValueOfV(colV)
if err != nil {
return err
}
vv := pkV.Interface()
if !utils.IsZero(vv) { // TODO: duplicate primary key
va.v = append(va.v, *colV)
va.pk = append(va.pk, vv)
}
}
}
for col, va := range pks {
pkCols := col.AssociateTable.PKColumns()
mp := reflect.MakeMap(reflect.MapOf(pkCols[0].FieldType, col.FieldType))
x := reflect.New(mp.Type())
x.Elem().Set(mp)
err = session.In(va.col.Name, va.pk...).find(x.Interface())
if err != nil {
return err
}
for _, v := range va.v {
pkCols := col.AssociateTable.PKColumns()
pkV, err := pkCols[0].ValueOfV(&v)
if err != nil {
return err
}
v.Set(mp.MapIndex(*pkV))
}
}
return nil
}
// loadFindMap load 's belongs to tag field immedicatlly
func (session *Session) loadFindMap(v reflect.Value, cols ...string) error {
if v.Kind() != reflect.Map {
return errors.New("only map is supported")
}
if v.Len() <= 0 {
return nil
}
vv := v.Index(0)
if vv.Kind() == reflect.Ptr {
vv = vv.Elem()
}
tb, err := session.engine.tagParser.ParseWithCache(vv)
if err != nil {
return err
}
var pks = make(map[*schemas.Column][]interface{})
for i := 0; i < v.Len(); i++ {
ev := v.Index(i)
for _, col := range tb.Columns() {
if len(cols) > 0 && !isStringInSlice(col.Name, cols) {
continue
}
if col.AssociateTable != nil {
if col.AssociateType == schemas.AssociateBelongsTo {
colV, err := col.ValueOfV(&ev)
if err != nil {
return err
}
vv := colV.Interface()
/*var colPtr reflect.Value
if colV.Kind() == reflect.Ptr {
colPtr = *colV
} else {
colPtr = colV.Addr()
}*/
if !utils.IsZero(vv) {
pks[col] = append(pks[col], vv)
}
}
}
}
}
for col, pk := range pks {
slice := reflect.MakeSlice(col.FieldType, 0, len(pk))
err = session.In(col.Name, pk...).find(slice.Addr().Interface())
if err != nil {
return err
}
}
return nil
}
// loadGet load bean's belongs to tag field immedicatlly
func (session *Session) loadGet(v reflect.Value, cols ...string) error {
if session.isAutoClose {
defer session.Close()
}
tb, err := session.engine.tagParser.ParseWithCache(v)
if err != nil {
return err
}
for _, col := range tb.Columns() {
if len(cols) > 0 && !isStringInSlice(col.Name, cols) {
continue
}
if col.AssociateTable == nil || col.AssociateType != schemas.AssociateBelongsTo {
continue
}
colV, err := col.ValueOfV(&v)
if err != nil {
return err
}
var colPtr reflect.Value
if colV.Kind() == reflect.Ptr {
colPtr = *colV
} else {
colPtr = colV.Addr()
}
pks := col.AssociateTable.PKColumns()
pkV, err := pks[0].ValueOfV(colV)
if err != nil {
return err
}
vv := pkV.Interface()
if !utils.IsZero(vv) && session.cascadeLevel > 0 {
has, err := session.ID(vv).NoAutoCondition().get(colPtr.Interface())
if err != nil {
return err
}
if !has {
return errors.New("load bean does not exist")
}
session.cascadeLevel--
}
}
return nil
}

View File

@ -280,6 +280,42 @@ func (session *Session) getMap(rows *core.Rows, types []*sql.ColumnType, fields
}
}
func (session *Session) getStructByPK(pk schemas.PK, fieldValue *reflect.Value) error {
if pk.IsZero() {
return errors.New("getStructByPK: primary key is zero")
}
var structInter reflect.Value
if fieldValue.Kind() == reflect.Ptr {
if fieldValue.IsNil() {
structInter = reflect.New(fieldValue.Type().Elem())
} else {
structInter = *fieldValue
}
} else {
structInter = fieldValue.Addr()
}
has, err := session.ID(pk).NoAutoCondition().get(structInter.Interface())
if err != nil {
return err
}
if !has {
return errors.New("cascade obj is not exist")
}
if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() {
fieldValue.Set(structInter)
fmt.Println("getByPK value ptr:", fieldValue.Interface())
return nil
} else if fieldValue.Kind() == reflect.Struct {
fieldValue.Set(structInter.Elem())
fmt.Println("getByPK value:", fieldValue.Interface())
return nil
}
return errors.New("set value failed")
}
func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interface{}) (has bool, err error) {
// if has no reftable, then don't use cache currently
if !session.canCache() {

View File

@ -167,6 +167,7 @@ func (parser *Parser) parseFieldWithNoTag(fieldIndex int, field reflect.StructFi
field.Name, sqlType, sqlType.DefaultLength,
sqlType.DefaultLength2, true)
col.FieldIndex = []int{fieldIndex}
col.FieldType = fieldValue.Type()
if field.Type.Kind() == reflect.Int64 && (strings.ToUpper(col.FieldName) == "ID" || strings.HasSuffix(strings.ToUpper(col.FieldName), ".ID")) {
col.IsAutoIncrement = true
@ -186,6 +187,7 @@ func (parser *Parser) parseFieldWithTags(table *schemas.Table, fieldIndex int, f
MapType: schemas.TWOSIDES,
Indexes: make(map[string]int),
DefaultIsEmpty: true,
FieldType: fieldValue.Type(),
}
var ctx = Context{

View File

@ -5,6 +5,7 @@
package tags
import (
"errors"
"fmt"
"reflect"
"strconv"
@ -102,28 +103,29 @@ type Handler func(ctx *Context) error
var (
// defaultTagHandlers enumerates all the default tag handler
defaultTagHandlers = map[string]Handler{
"-": IgnoreHandler,
"<-": OnlyFromDBTagHandler,
"->": OnlyToDBTagHandler,
"PK": PKTagHandler,
"NULL": NULLTagHandler,
"NOT": NotTagHandler,
"AUTOINCR": AutoIncrTagHandler,
"DEFAULT": DefaultTagHandler,
"CREATED": CreatedTagHandler,
"UPDATED": UpdatedTagHandler,
"DELETED": DeletedTagHandler,
"VERSION": VersionTagHandler,
"UTC": UTCTagHandler,
"LOCAL": LocalTagHandler,
"NOTNULL": NotNullTagHandler,
"INDEX": IndexTagHandler,
"UNIQUE": UniqueTagHandler,
"CACHE": CacheTagHandler,
"NOCACHE": NoCacheTagHandler,
"COMMENT": CommentTagHandler,
"EXTENDS": ExtendsTagHandler,
"UNSIGNED": UnsignedTagHandler,
"-": IgnoreHandler,
"<-": OnlyFromDBTagHandler,
"->": OnlyToDBTagHandler,
"PK": PKTagHandler,
"NULL": NULLTagHandler,
"NOT": NotTagHandler,
"AUTOINCR": AutoIncrTagHandler,
"DEFAULT": DefaultTagHandler,
"CREATED": CreatedTagHandler,
"UPDATED": UpdatedTagHandler,
"DELETED": DeletedTagHandler,
"VERSION": VersionTagHandler,
"UTC": UTCTagHandler,
"LOCAL": LocalTagHandler,
"NOTNULL": NotNullTagHandler,
"INDEX": IndexTagHandler,
"UNIQUE": UniqueTagHandler,
"CACHE": CacheTagHandler,
"NOCACHE": NoCacheTagHandler,
"COMMENT": CommentTagHandler,
"EXTENDS": ExtendsTagHandler,
"UNSIGNED": UnsignedTagHandler,
"BELONGS_TO": BelongsToTagHandler,
}
)
@ -398,3 +400,51 @@ func NoCacheTagHandler(ctx *Context) error {
}
return nil
}
func isStruct(t reflect.Type) bool {
return t.Kind() == reflect.Struct || isPtrStruct(t)
}
func isPtrStruct(t reflect.Type) bool {
return t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct
}
// BelongsToTagHandler describes belongs_to tag handler
func BelongsToTagHandler(ctx *Context) error {
if !isStruct(ctx.fieldValue.Type()) {
return errors.New("tag belongs_to cannot be applied on non-struct field")
}
ctx.col.AssociateType = schemas.AssociateBelongsTo
var t reflect.Value
if ctx.fieldValue.Kind() == reflect.Struct {
t = ctx.fieldValue
} else {
if ctx.fieldValue.Type().Kind() == reflect.Ptr && ctx.fieldValue.Type().Elem().Kind() == reflect.Struct {
if ctx.fieldValue.IsNil() {
t = reflect.New(ctx.fieldValue.Type().Elem()).Elem()
} else {
t = ctx.fieldValue
}
} else {
return errors.New("only struct or ptr to struct field could add belongs_to flag")
}
}
belongsT, err := ctx.parser.ParseWithCache(t)
if err != nil {
return err
}
pks := belongsT.PKColumns()
if len(pks) != 1 {
return errors.New("blongs_to only should be as a tag of table has one primary key")
}
ctx.col.AssociateTable = belongsT
ctx.col.SQLType = pks[0].SQLType
if len(ctx.col.Name) == 0 {
ctx.col.Name = ctx.parser.columnMapper.Obj2Table(ctx.col.FieldName) + "_id"
}
return nil
}