diff --git a/engine.go b/engine.go index 7a57b08a..9325ffeb 100644 --- a/engine.go +++ b/engine.go @@ -703,6 +703,13 @@ func (engine *Engine) AllCols() *Session { return session.AllCols() } +// Load loads associated fields from database +func (engine *Engine) Load(beanOrSlices interface{}, cols ...string) error { + session := engine.NewSession() + session.isAutoClose = true + return session.Load(beanOrSlices, cols...) +} + // MustCols specify some columns must use even if they are empty func (engine *Engine) MustCols(columns ...string) *Session { session := engine.NewSession() diff --git a/integrations/session_associate_test.go b/integrations/session_associate_test.go new file mode 100644 index 00000000..9406d996 --- /dev/null +++ b/integrations/session_associate_test.go @@ -0,0 +1,232 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package integrations + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBelongsTo_Get(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type Face1 struct { + Id int64 + Name string + } + + type Nose1 struct { + Id int64 + Face Face1 `xorm:"belongs_to"` + } + + err := testEngine.Sync2(new(Nose1), new(Face1)) + assert.NoError(t, err) + + var face = Face1{ + Name: "face1", + } + _, err = testEngine.Insert(&face) + assert.NoError(t, err) + + var cfgFace Face1 + has, err := testEngine.Get(&cfgFace) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, face, cfgFace) + + var nose = Nose1{Face: face} + _, err = testEngine.Insert(&nose) + assert.NoError(t, err) + + var cfgNose Nose1 + has, err = testEngine.Get(&cfgNose) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, nose.Id, cfgNose.Id) + assert.Equal(t, nose.Face.Id, cfgNose.Face.Id) + assert.Equal(t, "", cfgNose.Face.Name) + + err = testEngine.Load(&cfgNose) + assert.NoError(t, err) + assert.Equal(t, nose.Id, cfgNose.Id) + assert.Equal(t, nose.Face.Id, cfgNose.Face.Id) + assert.Equal(t, "face1", cfgNose.Face.Name) + + var cfgNose2 Nose1 + has, err = testEngine.Cascade().Get(&cfgNose2) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, nose.Id, cfgNose2.Id) + assert.Equal(t, nose.Face.Id, cfgNose2.Face.Id) + assert.Equal(t, "face1", cfgNose2.Face.Name) +} + +func TestBelongsTo_GetPtr(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type Face2 struct { + Id int64 + Name string + } + + type Nose2 struct { + Id int64 + Face *Face2 `xorm:"belongs_to"` + } + + err := testEngine.Sync2(new(Nose2), new(Face2)) + assert.NoError(t, err) + + var face = Face2{ + Name: "face1", + } + _, err = testEngine.Insert(&face) + assert.NoError(t, err) + + var cfgFace Face2 + has, err := testEngine.Get(&cfgFace) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, face, cfgFace) + + var nose = Nose2{Face: &face} + _, err = testEngine.Insert(&nose) + assert.NoError(t, err) + + var cfgNose Nose2 + has, err = testEngine.Get(&cfgNose) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, nose.Id, cfgNose.Id) + assert.Equal(t, nose.Face.Id, cfgNose.Face.Id) + + err = testEngine.Load(&cfgNose) + assert.NoError(t, err) + assert.Equal(t, nose.Id, cfgNose.Id) + assert.Equal(t, nose.Face.Id, cfgNose.Face.Id) + assert.Equal(t, "face1", cfgNose.Face.Name) + + var cfgNose2 Nose2 + has, err = testEngine.Cascade().Get(&cfgNose2) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, nose.Id, cfgNose2.Id) + assert.Equal(t, nose.Face.Id, cfgNose2.Face.Id) + assert.Equal(t, "face1", cfgNose2.Face.Name) +} + +func TestBelongsTo_Find(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type Face3 struct { + Id int64 + Name string + } + + type Nose3 struct { + Id int64 + Face Face3 `xorm:"belongs_to"` + } + + err := testEngine.Sync2(new(Nose3), new(Face3)) + assert.NoError(t, err) + + var face1 = Face3{ + Name: "face1", + } + var face2 = Face3{ + Name: "face2", + } + _, err = testEngine.Insert(&face1, &face2) + assert.NoError(t, err) + + var noses = []Nose3{ + {Face: face1}, + {Face: face2}, + } + _, err = testEngine.Insert(&noses) + assert.NoError(t, err) + + var noses1 []Nose3 + err = testEngine.Find(&noses1) + assert.NoError(t, err) + assert.Equal(t, 2, len(noses1)) + assert.Equal(t, face1.Id, noses1[0].Face.Id) + assert.Equal(t, face2.Id, noses1[1].Face.Id) + assert.Equal(t, "", noses1[0].Face.Name) + assert.Equal(t, "", noses1[1].Face.Name) + + var noses2 []Nose3 + err = testEngine.Cascade().Find(&noses2) + assert.NoError(t, err) + assert.Equal(t, 2, len(noses2)) + assert.Equal(t, face1.Id, noses2[0].Face.Id) + assert.Equal(t, face2.Id, noses2[1].Face.Id) + assert.Equal(t, "face1", noses2[0].Face.Name) + assert.Equal(t, "face2", noses2[1].Face.Name) + + err = testEngine.Load(noses1, "face_id") + assert.NoError(t, err) + assert.Equal(t, "face1", noses1[0].Face.Name) + assert.Equal(t, "face2", noses1[1].Face.Name) +} + +func TestBelongsTo_FindPtr(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type Face4 struct { + Id int64 + Name string + } + + type Nose4 struct { + Id int64 + Face *Face4 `xorm:"belongs_to"` + } + + err := testEngine.Sync2(new(Nose4), new(Face4)) + assert.NoError(t, err) + + var face1 = Face4{ + Name: "face1", + } + var face2 = Face4{ + Name: "face2", + } + _, err = testEngine.Insert(&face1, &face2) + assert.NoError(t, err) + + var noses = []Nose4{ + {Face: &face1}, + {Face: &face2}, + } + _, err = testEngine.Insert(&noses) + assert.NoError(t, err) + + var noses1 []Nose4 + err = testEngine.Find(&noses1) + assert.NoError(t, err) + assert.Equal(t, 2, len(noses1)) + assert.Equal(t, face1.Id, noses1[0].Face.Id) + assert.Equal(t, face2.Id, noses1[1].Face.Id) + assert.Equal(t, "", noses1[0].Face.Name) + assert.Equal(t, "", noses1[1].Face.Name) + + var noses2 []Nose4 + err = testEngine.Cascade().Find(&noses2) + assert.NoError(t, err) + assert.Equal(t, 2, len(noses2)) + assert.NotNil(t, noses2[0].Face) + assert.NotNil(t, noses2[1].Face) + assert.Equal(t, face1.Id, noses2[0].Face.Id) + assert.Equal(t, face2.Id, noses2[1].Face.Id) + assert.Equal(t, "face1", noses2[0].Face.Name) + assert.Equal(t, "face2", noses2[1].Face.Name) + + err = testEngine.Load(noses2, "face") + assert.NoError(t, err) +} diff --git a/integrations/session_tag_test.go b/integrations/session_tag_test.go new file mode 100644 index 00000000..170e0374 --- /dev/null +++ b/integrations/session_tag_test.go @@ -0,0 +1,66 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package integrations + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestExtendsTag(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + table, err := testEngine.TableInfo(new(Userdetail)) + assert.NoError(t, err) + assert.NotNil(t, table) + assert.EqualValues(t, 3, len(table.ColumnsSeq())) + assert.EqualValues(t, "id", table.ColumnsSeq()[0]) + assert.EqualValues(t, "intro", table.ColumnsSeq()[1]) + assert.EqualValues(t, "profile", table.ColumnsSeq()[2]) + + table, err = testEngine.TableInfo(new(Userinfo)) + assert.NoError(t, err) + assert.NotNil(t, table) + assert.EqualValues(t, 8, len(table.ColumnsSeq())) + assert.EqualValues(t, "id", table.ColumnsSeq()[0]) + assert.EqualValues(t, "username", table.ColumnsSeq()[1]) + assert.EqualValues(t, "departname", table.ColumnsSeq()[2]) + assert.EqualValues(t, "created", table.ColumnsSeq()[3]) + assert.EqualValues(t, "detail_id", table.ColumnsSeq()[4]) + assert.EqualValues(t, "height", table.ColumnsSeq()[5]) + assert.EqualValues(t, "avatar", table.ColumnsSeq()[6]) + assert.EqualValues(t, "is_man", table.ColumnsSeq()[7]) + + table, err = testEngine.TableInfo(new(UserAndDetail)) + assert.NoError(t, err) + assert.NotNil(t, table) + assert.EqualValues(t, 11, len(table.ColumnsSeq())) + assert.EqualValues(t, "id", table.ColumnsSeq()[0]) + assert.EqualValues(t, "username", table.ColumnsSeq()[1]) + assert.EqualValues(t, "departname", table.ColumnsSeq()[2]) + assert.EqualValues(t, "created", table.ColumnsSeq()[3]) + assert.EqualValues(t, "detail_id", table.ColumnsSeq()[4]) + assert.EqualValues(t, "height", table.ColumnsSeq()[5]) + assert.EqualValues(t, "avatar", table.ColumnsSeq()[6]) + assert.EqualValues(t, "is_man", table.ColumnsSeq()[7]) + assert.EqualValues(t, "id", table.ColumnsSeq()[8]) + assert.EqualValues(t, "intro", table.ColumnsSeq()[9]) + assert.EqualValues(t, "profile", table.ColumnsSeq()[10]) + + cols := table.Columns() + assert.EqualValues(t, 11, len(cols)) + assert.EqualValues(t, "Userinfo.Uid", cols[0].FieldName) + assert.EqualValues(t, "Userinfo.Username", cols[1].FieldName) + assert.EqualValues(t, "Userinfo.Departname", cols[2].FieldName) + assert.EqualValues(t, "Userinfo.Created", cols[3].FieldName) + assert.EqualValues(t, "Userinfo.Detail", cols[4].FieldName) + assert.EqualValues(t, "Userinfo.Height", cols[5].FieldName) + assert.EqualValues(t, "Userinfo.Avatar", cols[6].FieldName) + assert.EqualValues(t, "Userinfo.IsMan", cols[7].FieldName) + assert.EqualValues(t, "Userdetail.Id", cols[8].FieldName) + assert.EqualValues(t, "Userdetail.Intro", cols[9].FieldName) + assert.EqualValues(t, "Userdetail.Profile", cols[10].FieldName) +} diff --git a/interface.go b/interface.go index 42dc9a0a..bbb8062f 100644 --- a/interface.go +++ b/interface.go @@ -24,6 +24,7 @@ type Interface interface { Alias(alias string) *Session Asc(colNames ...string) *Session BufferSize(size int) *Session + Cascade(trueOrFalse ...bool) *Session Cols(columns ...string) *Session Count(...interface{}) (int64, error) CreateIndexes(bean interface{}) error @@ -48,6 +49,7 @@ type Interface interface { IsTableExist(beanOrTableName interface{}) (bool, error) Iterate(interface{}, IterFunc) error Limit(int, ...int) *Session + Load(beanOrSlices interface{}, cols ...string) error MustCols(columns ...string) *Session NoAutoCondition(...bool) *Session NotIn(string, ...interface{}) *Session diff --git a/schemas/associate.go b/schemas/associate.go new file mode 100644 index 00000000..3dfd85f8 --- /dev/null +++ b/schemas/associate.go @@ -0,0 +1,12 @@ +// Copyright 2021 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package schemas + +type AssociateType int + +const ( + AssociateNone AssociateType = iota + AssociateBelongsTo +) diff --git a/schemas/column.go b/schemas/column.go index 4bbb6c2d..7ded9677 100644 --- a/schemas/column.go +++ b/schemas/column.go @@ -22,8 +22,9 @@ const ( type Column struct { Name string TableName string - FieldName string // Available only when parsed from a struct - FieldIndex []int // Available only when parsed from a struct + FieldName string // Available only when parsed from a struct + FieldIndex []int // Available only when parsed from a struct + FieldType reflect.Type // Available only when parsed from a struct SQLType SQLType IsJSON bool Length int @@ -45,6 +46,8 @@ type Column struct { DisableTimeZone bool TimeZone *time.Location // column specified time zone Comment string + AssociateType + AssociateTable *Table } // NewColumn creates a new column @@ -71,6 +74,8 @@ func NewColumn(name, fieldName string, sqlType SQLType, len1, len2 int, nullable DefaultIsEmpty: true, // default should be no default EnumOptions: make(map[string]int), Comment: "", + AssociateType: AssociateNone, + AssociateTable: nil, } } diff --git a/session.go b/session.go index f5b45a73..14989781 100644 --- a/session.go +++ b/session.go @@ -10,7 +10,6 @@ import ( "crypto/sha256" "database/sql" "encoding/hex" - "errors" "fmt" "hash/crc32" "io" @@ -54,6 +53,22 @@ const ( groupSession sessionType = true ) +type cascadeMode int + +const ( + cascadeCompitable cascadeMode = iota // load field beans with another SQL with no + cascadeEager // load field beans with another SQL + cascadeJoin // load field beans with join + cascadeLazy // don't load anything +) + +type loadClosure struct { + Func func(schemas.PK, *reflect.Value) error + pk schemas.PK + fieldValue *reflect.Value + loaded bool +} + // Session keep a pointer to sql.DB and provides all execution of all // kind of database operations. type Session struct { @@ -86,6 +101,9 @@ type Session struct { ctx context.Context sessionType sessionType + + cascadeMode cascadeMode + cascadeLevel int // load level } func newSessionID() string { @@ -134,7 +152,9 @@ func newSession(engine *Engine) *Session { lastSQL: "", lastSQLArgs: make([]interface{}, 0), - sessionType: engineSession, + sessionType: engineSession, + cascadeMode: cascadeCompitable, + cascadeLevel: 2, } if engine.logSessionID { session.ctx = context.WithValue(session.ctx, log.SessionKey, session) @@ -241,7 +261,7 @@ func (session *Session) Alias(alias string) *Session { // NoCascade indicate that no cascade load child object func (session *Session) NoCascade() *Session { - session.statement.UseCascade = false + session.cascadeMode = cascadeLazy return session } @@ -296,9 +316,15 @@ func (session *Session) Charset(charset string) *Session { // Cascade indicates if loading sub Struct func (session *Session) Cascade(trueOrFalse ...bool) *Session { + var mode = cascadeEager if len(trueOrFalse) >= 1 { - session.statement.UseCascade = trueOrFalse[0] + if trueOrFalse[0] { + mode = cascadeEager + } else { + mode = cascadeLazy + } } + session.cascadeMode = mode return session } @@ -604,9 +630,7 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec } fieldValue.Set(reflect.ValueOf(v).Elem().Convert(fieldType)) return nil - } - - if fieldType.ConvertibleTo(schemas.TimeType) { + } else if fieldType.ConvertibleTo(schemas.TimeType) { dbTZ := session.engine.DatabaseTZ if col.TimeZone != nil { dbTZ = col.TimeZone @@ -620,42 +644,50 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec fieldValue.Set(reflect.ValueOf(*t).Convert(fieldType)) return nil } else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { - err := nulVal.Scan(scanResult) - if err == nil { - return nil - } - session.engine.logger.Errorf("sql.Sanner error: %v", err) - } else if session.statement.UseCascade { - table, err := session.engine.tagParser.ParseWithCache(*fieldValue) - if err != nil { - return err + return nulVal.Scan(scanResult) + } else if session.cascadeLevel > 0 && ((col.AssociateType == schemas.AssociateNone && + session.cascadeMode == cascadeCompitable) || + (col.AssociateType == schemas.AssociateBelongsTo && + session.cascadeMode == cascadeEager)) { + var associateTable *schemas.Table + if col.AssociateType == schemas.AssociateNone && col.AssociateTable == nil { + var err error + associateTable, err = session.engine.tagParser.ParseWithCache(*fieldValue) + if err != nil { + return err + } + } else { + associateTable = col.AssociateTable } - if len(table.PrimaryKeys) != 1 { - return errors.New("unsupported non or composited primary key cascade") - } - var pk = make(schemas.PK, len(table.PrimaryKeys)) + fmt.Println("=====", associateTable) + + var pk = make(schemas.PK, len(associateTable.PrimaryKeys)) + var err error pk[0], err = asKind(vv, reflect.TypeOf(scanResult)) if err != nil { return err } - if !pk.IsZero() { - // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch - // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne - // property to be fetched lazily - structInter := reflect.New(fieldValue.Type()) - has, err := session.ID(pk).NoCascade().get(structInter.Interface()) - if err != nil { - return err - } - if has { - fieldValue.Set(structInter.Elem()) - } else { - return errors.New("cascade obj is not exist") - } - } + session.afterProcessors = append(session.afterProcessors, executedProcessor{ + fun: func(session *Session, bean interface{}) error { + fieldValue := bean.(*reflect.Value) + fmt.Println("3333333") + return session.getStructByPK(pk, fieldValue) + }, + session: session, + bean: fieldValue, + }) + session.cascadeLevel-- + fmt.Println("222222") return nil + } else if col.AssociateType == schemas.AssociateBelongsTo && session.cascadeMode == cascadeLazy { + pkCols := col.AssociateTable.PKColumns() + colV, err := pkCols[0].ValueOfV(fieldValue) + if err != nil { + return err + } + return convert.AssignValue(*colV, scanResult) } } // switch fieldType.Kind() diff --git a/session_associate.go b/session_associate.go new file mode 100644 index 00000000..9279053b --- /dev/null +++ b/session_associate.go @@ -0,0 +1,243 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "errors" + "fmt" + "reflect" + + "xorm.io/xorm/internal/utils" + "xorm.io/xorm/schemas" +) + +// Load loads associated fields from database +func (session *Session) Load(beanOrSlices interface{}, cols ...string) error { + v := reflect.ValueOf(beanOrSlices) + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + if v.Kind() == reflect.Slice { + return session.loadFindSlice(v, cols...) + } else if v.Kind() == reflect.Map { + return session.loadFindMap(v, cols...) + } else if v.Kind() == reflect.Struct { + return session.loadGet(v, cols...) + } + return errors.New("unsupported load type, must struct, slice or map") +} + +func isStringInSlice(s string, slice []string) bool { + for _, e := range slice { + if s == e { + return true + } + } + return false +} + +// loadFind load 's belongs to tag field immedicatlly +func (session *Session) loadFindSlice(v reflect.Value, cols ...string) error { + if v.Kind() != reflect.Slice { + return errors.New("only slice is supported") + } + + if v.Len() <= 0 { + return nil + } + + tableValue := v.Index(0) + if tableValue.Kind() == reflect.Ptr { + tableValue = tableValue.Elem() + } + tb, err := session.engine.tagParser.ParseWithCache(tableValue) + if err != nil { + return err + } + + type Va struct { + v []reflect.Value + pk []interface{} + col *schemas.Column + } + + var pks = make(map[*schemas.Column]*Va) + for _, col := range tb.Columns() { + if col.AssociateTable == nil || col.AssociateType != schemas.AssociateBelongsTo { + continue + } + + if len(cols) > 0 && !isStringInSlice(col.Name, cols) { + continue + } + + pkCols := col.AssociateTable.PKColumns() + if len(pkCols) != 1 { + return fmt.Errorf("unsupported primary key number") + } + + pks[col] = &Va{ + col: pkCols[0], + } + } + + for i := 0; i < v.Len(); i++ { + value := v.Index(i) + for col, va := range pks { + colV, err := col.ValueOfV(&value) + if err != nil { + return err + } + + pkCols := col.AssociateTable.PKColumns() + pkV, err := pkCols[0].ValueOfV(colV) + if err != nil { + return err + } + vv := pkV.Interface() + if !utils.IsZero(vv) { // TODO: duplicate primary key + va.v = append(va.v, *colV) + va.pk = append(va.pk, vv) + } + } + } + + for col, va := range pks { + pkCols := col.AssociateTable.PKColumns() + mp := reflect.MakeMap(reflect.MapOf(pkCols[0].FieldType, col.FieldType)) + x := reflect.New(mp.Type()) + x.Elem().Set(mp) + + err = session.In(va.col.Name, va.pk...).find(x.Interface()) + if err != nil { + return err + } + + for _, v := range va.v { + pkCols := col.AssociateTable.PKColumns() + pkV, err := pkCols[0].ValueOfV(&v) + if err != nil { + return err + } + + v.Set(mp.MapIndex(*pkV)) + } + } + return nil +} + +// loadFindMap load 's belongs to tag field immedicatlly +func (session *Session) loadFindMap(v reflect.Value, cols ...string) error { + if v.Kind() != reflect.Map { + return errors.New("only map is supported") + } + + if v.Len() <= 0 { + return nil + } + + vv := v.Index(0) + if vv.Kind() == reflect.Ptr { + vv = vv.Elem() + } + tb, err := session.engine.tagParser.ParseWithCache(vv) + if err != nil { + return err + } + + var pks = make(map[*schemas.Column][]interface{}) + for i := 0; i < v.Len(); i++ { + ev := v.Index(i) + + for _, col := range tb.Columns() { + if len(cols) > 0 && !isStringInSlice(col.Name, cols) { + continue + } + + if col.AssociateTable != nil { + if col.AssociateType == schemas.AssociateBelongsTo { + colV, err := col.ValueOfV(&ev) + if err != nil { + return err + } + + vv := colV.Interface() + /*var colPtr reflect.Value + if colV.Kind() == reflect.Ptr { + colPtr = *colV + } else { + colPtr = colV.Addr() + }*/ + + if !utils.IsZero(vv) { + pks[col] = append(pks[col], vv) + } + } + } + } + } + + for col, pk := range pks { + slice := reflect.MakeSlice(col.FieldType, 0, len(pk)) + err = session.In(col.Name, pk...).find(slice.Addr().Interface()) + if err != nil { + return err + } + } + return nil +} + +// loadGet load bean's belongs to tag field immedicatlly +func (session *Session) loadGet(v reflect.Value, cols ...string) error { + if session.isAutoClose { + defer session.Close() + } + + tb, err := session.engine.tagParser.ParseWithCache(v) + if err != nil { + return err + } + + for _, col := range tb.Columns() { + if len(cols) > 0 && !isStringInSlice(col.Name, cols) { + continue + } + + if col.AssociateTable == nil || col.AssociateType != schemas.AssociateBelongsTo { + continue + } + + colV, err := col.ValueOfV(&v) + if err != nil { + return err + } + + var colPtr reflect.Value + if colV.Kind() == reflect.Ptr { + colPtr = *colV + } else { + colPtr = colV.Addr() + } + + pks := col.AssociateTable.PKColumns() + pkV, err := pks[0].ValueOfV(colV) + if err != nil { + return err + } + vv := pkV.Interface() + + if !utils.IsZero(vv) && session.cascadeLevel > 0 { + has, err := session.ID(vv).NoAutoCondition().get(colPtr.Interface()) + if err != nil { + return err + } + if !has { + return errors.New("load bean does not exist") + } + session.cascadeLevel-- + } + } + return nil +} diff --git a/session_get.go b/session_get.go index 48616a6b..9bf916ce 100644 --- a/session_get.go +++ b/session_get.go @@ -280,6 +280,42 @@ func (session *Session) getMap(rows *core.Rows, types []*sql.ColumnType, fields } } +func (session *Session) getStructByPK(pk schemas.PK, fieldValue *reflect.Value) error { + if pk.IsZero() { + return errors.New("getStructByPK: primary key is zero") + } + + var structInter reflect.Value + if fieldValue.Kind() == reflect.Ptr { + if fieldValue.IsNil() { + structInter = reflect.New(fieldValue.Type().Elem()) + } else { + structInter = *fieldValue + } + } else { + structInter = fieldValue.Addr() + } + + has, err := session.ID(pk).NoAutoCondition().get(structInter.Interface()) + if err != nil { + return err + } + if !has { + return errors.New("cascade obj is not exist") + } + if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { + fieldValue.Set(structInter) + fmt.Println("getByPK value ptr:", fieldValue.Interface()) + return nil + } else if fieldValue.Kind() == reflect.Struct { + fieldValue.Set(structInter.Elem()) + fmt.Println("getByPK value:", fieldValue.Interface()) + return nil + } + return errors.New("set value failed") + +} + func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interface{}) (has bool, err error) { // if has no reftable, then don't use cache currently if !session.canCache() { diff --git a/tags/parser.go b/tags/parser.go index efee11e7..03aa5f63 100644 --- a/tags/parser.go +++ b/tags/parser.go @@ -167,6 +167,7 @@ func (parser *Parser) parseFieldWithNoTag(fieldIndex int, field reflect.StructFi field.Name, sqlType, sqlType.DefaultLength, sqlType.DefaultLength2, true) col.FieldIndex = []int{fieldIndex} + col.FieldType = fieldValue.Type() if field.Type.Kind() == reflect.Int64 && (strings.ToUpper(col.FieldName) == "ID" || strings.HasSuffix(strings.ToUpper(col.FieldName), ".ID")) { col.IsAutoIncrement = true @@ -186,6 +187,7 @@ func (parser *Parser) parseFieldWithTags(table *schemas.Table, fieldIndex int, f MapType: schemas.TWOSIDES, Indexes: make(map[string]int), DefaultIsEmpty: true, + FieldType: fieldValue.Type(), } var ctx = Context{ diff --git a/tags/tag.go b/tags/tag.go index 4e1f1ce7..c354617b 100644 --- a/tags/tag.go +++ b/tags/tag.go @@ -5,6 +5,7 @@ package tags import ( + "errors" "fmt" "reflect" "strconv" @@ -102,28 +103,29 @@ type Handler func(ctx *Context) error var ( // defaultTagHandlers enumerates all the default tag handler defaultTagHandlers = map[string]Handler{ - "-": IgnoreHandler, - "<-": OnlyFromDBTagHandler, - "->": OnlyToDBTagHandler, - "PK": PKTagHandler, - "NULL": NULLTagHandler, - "NOT": NotTagHandler, - "AUTOINCR": AutoIncrTagHandler, - "DEFAULT": DefaultTagHandler, - "CREATED": CreatedTagHandler, - "UPDATED": UpdatedTagHandler, - "DELETED": DeletedTagHandler, - "VERSION": VersionTagHandler, - "UTC": UTCTagHandler, - "LOCAL": LocalTagHandler, - "NOTNULL": NotNullTagHandler, - "INDEX": IndexTagHandler, - "UNIQUE": UniqueTagHandler, - "CACHE": CacheTagHandler, - "NOCACHE": NoCacheTagHandler, - "COMMENT": CommentTagHandler, - "EXTENDS": ExtendsTagHandler, - "UNSIGNED": UnsignedTagHandler, + "-": IgnoreHandler, + "<-": OnlyFromDBTagHandler, + "->": OnlyToDBTagHandler, + "PK": PKTagHandler, + "NULL": NULLTagHandler, + "NOT": NotTagHandler, + "AUTOINCR": AutoIncrTagHandler, + "DEFAULT": DefaultTagHandler, + "CREATED": CreatedTagHandler, + "UPDATED": UpdatedTagHandler, + "DELETED": DeletedTagHandler, + "VERSION": VersionTagHandler, + "UTC": UTCTagHandler, + "LOCAL": LocalTagHandler, + "NOTNULL": NotNullTagHandler, + "INDEX": IndexTagHandler, + "UNIQUE": UniqueTagHandler, + "CACHE": CacheTagHandler, + "NOCACHE": NoCacheTagHandler, + "COMMENT": CommentTagHandler, + "EXTENDS": ExtendsTagHandler, + "UNSIGNED": UnsignedTagHandler, + "BELONGS_TO": BelongsToTagHandler, } ) @@ -398,3 +400,51 @@ func NoCacheTagHandler(ctx *Context) error { } return nil } + +func isStruct(t reflect.Type) bool { + return t.Kind() == reflect.Struct || isPtrStruct(t) +} + +func isPtrStruct(t reflect.Type) bool { + return t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct +} + +// BelongsToTagHandler describes belongs_to tag handler +func BelongsToTagHandler(ctx *Context) error { + if !isStruct(ctx.fieldValue.Type()) { + return errors.New("tag belongs_to cannot be applied on non-struct field") + } + + ctx.col.AssociateType = schemas.AssociateBelongsTo + var t reflect.Value + if ctx.fieldValue.Kind() == reflect.Struct { + t = ctx.fieldValue + } else { + if ctx.fieldValue.Type().Kind() == reflect.Ptr && ctx.fieldValue.Type().Elem().Kind() == reflect.Struct { + if ctx.fieldValue.IsNil() { + t = reflect.New(ctx.fieldValue.Type().Elem()).Elem() + } else { + t = ctx.fieldValue + } + } else { + return errors.New("only struct or ptr to struct field could add belongs_to flag") + } + } + + belongsT, err := ctx.parser.ParseWithCache(t) + if err != nil { + return err + } + pks := belongsT.PKColumns() + if len(pks) != 1 { + return errors.New("blongs_to only should be as a tag of table has one primary key") + } + + ctx.col.AssociateTable = belongsT + ctx.col.SQLType = pks[0].SQLType + + if len(ctx.col.Name) == 0 { + ctx.col.Name = ctx.parser.columnMapper.Obj2Table(ctx.col.FieldName) + "_id" + } + return nil +}