From 06d4d50e8274267bd1c5e4c8ea20e5869d40186e Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Mon, 7 Jun 2021 09:13:45 +0800 Subject: [PATCH 01/11] Add belongs to --- integrations/session_associate_test.go | 232 +++++++++++++++++++++++++ internal/statements/associate.go | 14 ++ session_associate.go | 145 ++++++++++++++++ tags/tag.go | 87 +++++++--- 4 files changed, 456 insertions(+), 22 deletions(-) create mode 100644 integrations/session_associate_test.go create mode 100644 internal/statements/associate.go create mode 100644 session_associate.go diff --git a/integrations/session_associate_test.go b/integrations/session_associate_test.go new file mode 100644 index 00000000..d34549d9 --- /dev/null +++ b/integrations/session_associate_test.go @@ -0,0 +1,232 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBelongsTo_Get(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type Face1 struct { + Id int64 + Name string + } + + type Nose1 struct { + Id int64 + Face Face1 `xorm:"belongs_to"` + } + + err := testEngine.Sync2(new(Nose1), new(Face1)) + assert.NoError(t, err) + + var face = Face1{ + Name: "face1", + } + _, err = testEngine.Insert(&face) + assert.NoError(t, err) + + var cfgFace Face1 + has, err := testEngine.Get(&cfgFace) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, face, cfgFace) + + var nose = Nose1{Face: face} + _, err = testEngine.Insert(&nose) + assert.NoError(t, err) + + var cfgNose Nose1 + has, err = testEngine.Get(&cfgNose) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, nose.Id, cfgNose.Id) + assert.Equal(t, nose.Face.Id, cfgNose.Face.Id) + assert.Equal(t, "", cfgNose.Face.Name) + + err = testEngine.Load(&cfgNose) + assert.NoError(t, err) + assert.Equal(t, nose.Id, cfgNose.Id) + assert.Equal(t, nose.Face.Id, cfgNose.Face.Id) + assert.Equal(t, "face1", cfgNose.Face.Name) + + var cfgNose2 Nose1 + has, err = testEngine.Cascade().Get(&cfgNose2) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, nose.Id, cfgNose2.Id) + assert.Equal(t, nose.Face.Id, cfgNose2.Face.Id) + assert.Equal(t, "face1", cfgNose2.Face.Name) +} + +func TestBelongsTo_GetPtr(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type Face2 struct { + Id int64 + Name string + } + + type Nose2 struct { + Id int64 + Face *Face2 `xorm:"belongs_to"` + } + + err := testEngine.Sync2(new(Nose2), new(Face2)) + assert.NoError(t, err) + + var face = Face2{ + Name: "face1", + } + _, err = testEngine.Insert(&face) + assert.NoError(t, err) + + var cfgFace Face2 + has, err := testEngine.Get(&cfgFace) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, face, cfgFace) + + var nose = Nose2{Face: &face} + _, err = testEngine.Insert(&nose) + assert.NoError(t, err) + + var cfgNose Nose2 + has, err = testEngine.Get(&cfgNose) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, nose.Id, cfgNose.Id) + assert.Equal(t, nose.Face.Id, cfgNose.Face.Id) + + err = testEngine.Load(&cfgNose) + assert.NoError(t, err) + assert.Equal(t, nose.Id, cfgNose.Id) + assert.Equal(t, nose.Face.Id, cfgNose.Face.Id) + assert.Equal(t, "face1", cfgNose.Face.Name) + + var cfgNose2 Nose2 + has, err = testEngine.Cascade().Get(&cfgNose2) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, nose.Id, cfgNose2.Id) + assert.Equal(t, nose.Face.Id, cfgNose2.Face.Id) + assert.Equal(t, "face1", cfgNose2.Face.Name) +} + +func TestBelongsTo_Find(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type Face3 struct { + Id int64 + Name string + } + + type Nose3 struct { + Id int64 + Face Face3 `xorm:"belongs_to"` + } + + err := testEngine.Sync2(new(Nose3), new(Face3)) + assert.NoError(t, err) + + var face1 = Face3{ + Name: "face1", + } + var face2 = Face3{ + Name: "face2", + } + _, err = testEngine.Insert(&face1, &face2) + assert.NoError(t, err) + + var noses = []Nose3{ + {Face: face1}, + {Face: face2}, + } + _, err = testEngine.Insert(&noses) + assert.NoError(t, err) + + var noses1 []Nose3 + err = testEngine.Find(&noses1) + assert.NoError(t, err) + assert.Equal(t, 2, len(noses1)) + assert.Equal(t, face1.Id, noses1[0].Face.Id) + assert.Equal(t, face2.Id, noses1[1].Face.Id) + assert.Equal(t, "", noses1[0].Face.Name) + assert.Equal(t, "", noses1[1].Face.Name) + + var noses2 []Nose3 + err = testEngine.Cascade().Find(&noses2) + assert.NoError(t, err) + assert.Equal(t, 2, len(noses2)) + assert.Equal(t, face1.Id, noses2[0].Face.Id) + assert.Equal(t, face2.Id, noses2[1].Face.Id) + assert.Equal(t, "face1", noses2[0].Face.Name) + assert.Equal(t, "face2", noses2[1].Face.Name) + + err = testEngine.Load(noses1, "face") + assert.NoError(t, err) + assert.Equal(t, "face1", noses1[0].Face.Name) + assert.Equal(t, "face2", noses1[1].Face.Name) +} + +func TestBelongsTo_FindPtr(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type Face4 struct { + Id int64 + Name string + } + + type Nose4 struct { + Id int64 + Face *Face4 `xorm:"belongs_to"` + } + + err := testEngine.Sync2(new(Nose4), new(Face4)) + assert.NoError(t, err) + + var face1 = Face4{ + Name: "face1", + } + var face2 = Face4{ + Name: "face2", + } + _, err = testEngine.Insert(&face1, &face2) + assert.NoError(t, err) + + var noses = []Nose4{ + {Face: &face1}, + {Face: &face2}, + } + _, err = testEngine.Insert(&noses) + assert.NoError(t, err) + + var noses1 []Nose4 + err = testEngine.Find(&noses1) + assert.NoError(t, err) + assert.Equal(t, 2, len(noses1)) + assert.Equal(t, face1.Id, noses1[0].Face.Id) + assert.Equal(t, face2.Id, noses1[1].Face.Id) + assert.Equal(t, "", noses1[0].Face.Name) + assert.Equal(t, "", noses1[1].Face.Name) + + var noses2 []Nose4 + err = testEngine.Cascade().Find(&noses2) + assert.NoError(t, err) + assert.Equal(t, 2, len(noses2)) + assert.NotNil(t, noses2[0].Face) + assert.NotNil(t, noses2[1].Face) + assert.Equal(t, face1.Id, noses2[0].Face.Id) + assert.Equal(t, face2.Id, noses2[1].Face.Id) + assert.Equal(t, "face1", noses2[0].Face.Name) + assert.Equal(t, "face2", noses2[1].Face.Name) + + err = testEngine.Load(noses2, "face") + assert.NoError(t, err) +} diff --git a/internal/statements/associate.go b/internal/statements/associate.go new file mode 100644 index 00000000..5659ddc9 --- /dev/null +++ b/internal/statements/associate.go @@ -0,0 +1,14 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package statements + +type cascadeMode int + +const ( + cascadeCompitable cascadeMode = iota // load field beans with another SQL with no + cascadeEager // load field beans with another SQL + cascadeJoin // load field beans with join + cascadeLazy // don't load anything +) diff --git a/session_associate.go b/session_associate.go new file mode 100644 index 00000000..62e924a3 --- /dev/null +++ b/session_associate.go @@ -0,0 +1,145 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "errors" + "reflect" + + "github.com/go-xorm/core" +) + +// Load loads associated fields from database +func (session *Session) Load(beanOrSlices interface{}, cols ...string) error { + v := reflect.ValueOf(beanOrSlices) + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + if v.Kind() == reflect.Slice { + return session.loadFind(beanOrSlices, cols...) + } else if v.Kind() == reflect.Struct { + return session.loadGet(beanOrSlices, cols...) + } + return errors.New("unsupported load type, must struct or slice") +} + +// loadFind load 's belongs to tag field immedicatlly +func (session *Session) loadFind(slices interface{}, cols ...string) error { + v := reflect.ValueOf(slices) + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + if v.Kind() != reflect.Slice { + return errors.New("only slice is supported") + } + + if v.Len() <= 0 { + return nil + } + + vv := v.Index(0) + if vv.Kind() == reflect.Ptr { + vv = vv.Elem() + } + tb, err := session.engine.autoMapType(vv) + if err != nil { + return err + } + + var pks = make(map[*core.Column][]interface{}) + for i := 0; i < v.Len(); i++ { + ev := v.Index(i) + + for _, col := range tb.Columns() { + if len(cols) > 0 && !isStringInSlice(col.Name, cols) { + continue + } + + if col.AssociateTable != nil { + if col.AssociateType == core.AssociateBelongsTo { + colV, err := col.ValueOfV(&ev) + if err != nil { + return err + } + + pk, err := session.engine.idOfV(*colV) + if err != nil { + return err + } + /*var colPtr reflect.Value + if colV.Kind() == reflect.Ptr { + colPtr = *colV + } else { + colPtr = colV.Addr() + }*/ + + if !isZero(pk[0]) { + pks[col] = append(pks[col], pk[0]) + } + } + } + } + } + + for col, pk := range pks { + slice := reflect.MakeSlice(col.FieldType, 0, len(pk)) + err = session.In(col.Name, pk...).find(slice.Addr().Interface()) + if err != nil { + return err + } + } + return nil +} + +// loadGet load bean's belongs to tag field immedicatlly +func (session *Session) loadGet(bean interface{}, cols ...string) error { + if session.isAutoClose { + defer session.Close() + } + + v := rValue(bean) + tb, err := session.engine.autoMapType(v) + if err != nil { + return err + } + + for _, col := range tb.Columns() { + if len(cols) > 0 && !isStringInSlice(col.Name, cols) { + continue + } + + if col.AssociateTable != nil { + if col.AssociateType == core.AssociateBelongsTo { + colV, err := col.ValueOfV(&v) + if err != nil { + return err + } + + pk, err := session.engine.idOfV(*colV) + if err != nil { + return err + } + var colPtr reflect.Value + if colV.Kind() == reflect.Ptr { + colPtr = *colV + } else { + colPtr = colV.Addr() + } + + if !isZero(pk[0]) && session.cascadeLevel > 0 { + has, err := session.ID(pk).NoAutoCondition().get(colPtr.Interface()) + if err != nil { + return err + } + if !has { + return errors.New("load bean does not exist") + } + session.cascadeLevel-- + } + } + } + } + return nil +} diff --git a/tags/tag.go b/tags/tag.go index 4e1f1ce7..cb5dde79 100644 --- a/tags/tag.go +++ b/tags/tag.go @@ -5,6 +5,7 @@ package tags import ( + "errors" "fmt" "reflect" "strconv" @@ -102,28 +103,29 @@ type Handler func(ctx *Context) error var ( // defaultTagHandlers enumerates all the default tag handler defaultTagHandlers = map[string]Handler{ - "-": IgnoreHandler, - "<-": OnlyFromDBTagHandler, - "->": OnlyToDBTagHandler, - "PK": PKTagHandler, - "NULL": NULLTagHandler, - "NOT": NotTagHandler, - "AUTOINCR": AutoIncrTagHandler, - "DEFAULT": DefaultTagHandler, - "CREATED": CreatedTagHandler, - "UPDATED": UpdatedTagHandler, - "DELETED": DeletedTagHandler, - "VERSION": VersionTagHandler, - "UTC": UTCTagHandler, - "LOCAL": LocalTagHandler, - "NOTNULL": NotNullTagHandler, - "INDEX": IndexTagHandler, - "UNIQUE": UniqueTagHandler, - "CACHE": CacheTagHandler, - "NOCACHE": NoCacheTagHandler, - "COMMENT": CommentTagHandler, - "EXTENDS": ExtendsTagHandler, - "UNSIGNED": UnsignedTagHandler, + "-": IgnoreHandler, + "<-": OnlyFromDBTagHandler, + "->": OnlyToDBTagHandler, + "PK": PKTagHandler, + "NULL": NULLTagHandler, + "NOT": NotTagHandler, + "AUTOINCR": AutoIncrTagHandler, + "DEFAULT": DefaultTagHandler, + "CREATED": CreatedTagHandler, + "UPDATED": UpdatedTagHandler, + "DELETED": DeletedTagHandler, + "VERSION": VersionTagHandler, + "UTC": UTCTagHandler, + "LOCAL": LocalTagHandler, + "NOTNULL": NotNullTagHandler, + "INDEX": IndexTagHandler, + "UNIQUE": UniqueTagHandler, + "CACHE": CacheTagHandler, + "NOCACHE": NoCacheTagHandler, + "COMMENT": CommentTagHandler, + "EXTENDS": ExtendsTagHandler, + "UNSIGNED": UnsignedTagHandler, + "BELONGS_TO": BelongsToTagHandler, } ) @@ -398,3 +400,44 @@ func NoCacheTagHandler(ctx *Context) error { } return nil } + +// BelongsToTagHandler describes belongs_to tag handler +func BelongsToTagHandler(ctx *Context) error { + if !isStruct(ctx.fieldValue.Type()) { + return errors.New("Tag belongs_to cannot be applied on non-struct field") + } + + ctx.col.AssociateType = core.AssociateBelongsTo + var t reflect.Value + if ctx.fieldValue.Kind() == reflect.Struct { + t = ctx.fieldValue + } else { + if ctx.fieldValue.Type().Kind() == reflect.Ptr && ctx.fieldValue.Type().Elem().Kind() == reflect.Struct { + if ctx.fieldValue.IsNil() { + t = reflect.New(ctx.fieldValue.Type().Elem()).Elem() + } else { + t = ctx.fieldValue + } + } else { + return errors.New("Only struct or ptr to struct field could add belongs_to flag") + } + } + + belongsT, err := ctx.engine.mapType(ctx.parsingTables, t) + if err != nil { + return err + } + pks := belongsT.PKColumns() + if len(pks) != 1 { + panic("unsupported non or composited primary key cascade") + return errors.New("blongs_to only should be as a tag of table has one primary key") + } + + ctx.col.AssociateTable = belongsT + ctx.col.SQLType = pks[0].SQLType + + if len(ctx.col.Name) == 0 { + ctx.col.Name = ctx.engine.ColumnMapper.Obj2Table(ctx.col.FieldName) + "_id" + } + return nil +} -- 2.40.1 From 68f18c80e2f031a99744275088b25339b7410cc6 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Thu, 22 Jul 2021 11:25:37 +0800 Subject: [PATCH 02/11] Compile pass --- integrations/session_tag_test.go | 63 ++++++++++++++++++++++++++++++++ schemas/associate.go | 12 ++++++ schemas/column.go | 9 ++++- session.go | 33 +++++++++++++++-- session_associate.go | 42 +++++++++++---------- tags/tag.go | 19 +++++++--- 6 files changed, 148 insertions(+), 30 deletions(-) create mode 100644 integrations/session_tag_test.go create mode 100644 schemas/associate.go diff --git a/integrations/session_tag_test.go b/integrations/session_tag_test.go new file mode 100644 index 00000000..0e273b3e --- /dev/null +++ b/integrations/session_tag_test.go @@ -0,0 +1,63 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package integrations + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestExtendsTag(t *testing.T) { + assert.NoError(t, prepareEngine()) + + table := testEngine.TableInfo(new(Userdetail)) + assert.NotNil(t, table) + assert.EqualValues(t, 3, len(table.ColumnsSeq())) + assert.EqualValues(t, "id", table.ColumnsSeq()[0]) + assert.EqualValues(t, "intro", table.ColumnsSeq()[1]) + assert.EqualValues(t, "profile", table.ColumnsSeq()[2]) + + table = testEngine.TableInfo(new(Userinfo)) + assert.NotNil(t, table) + assert.EqualValues(t, 8, len(table.ColumnsSeq())) + assert.EqualValues(t, "id", table.ColumnsSeq()[0]) + assert.EqualValues(t, "username", table.ColumnsSeq()[1]) + assert.EqualValues(t, "departname", table.ColumnsSeq()[2]) + assert.EqualValues(t, "created", table.ColumnsSeq()[3]) + assert.EqualValues(t, "detail_id", table.ColumnsSeq()[4]) + assert.EqualValues(t, "height", table.ColumnsSeq()[5]) + assert.EqualValues(t, "avatar", table.ColumnsSeq()[6]) + assert.EqualValues(t, "is_man", table.ColumnsSeq()[7]) + + table = testEngine.TableInfo(new(UserAndDetail)) + assert.NotNil(t, table) + assert.EqualValues(t, 11, len(table.ColumnsSeq())) + assert.EqualValues(t, "id", table.ColumnsSeq()[0]) + assert.EqualValues(t, "username", table.ColumnsSeq()[1]) + assert.EqualValues(t, "departname", table.ColumnsSeq()[2]) + assert.EqualValues(t, "created", table.ColumnsSeq()[3]) + assert.EqualValues(t, "detail_id", table.ColumnsSeq()[4]) + assert.EqualValues(t, "height", table.ColumnsSeq()[5]) + assert.EqualValues(t, "avatar", table.ColumnsSeq()[6]) + assert.EqualValues(t, "is_man", table.ColumnsSeq()[7]) + assert.EqualValues(t, "id", table.ColumnsSeq()[8]) + assert.EqualValues(t, "intro", table.ColumnsSeq()[9]) + assert.EqualValues(t, "profile", table.ColumnsSeq()[10]) + + cols := table.Columns() + assert.EqualValues(t, 11, len(cols)) + assert.EqualValues(t, "Userinfo.Uid", cols[0].FieldName) + assert.EqualValues(t, "Userinfo.Username", cols[1].FieldName) + assert.EqualValues(t, "Userinfo.Departname", cols[2].FieldName) + assert.EqualValues(t, "Userinfo.Created", cols[3].FieldName) + assert.EqualValues(t, "Userinfo.Detail", cols[4].FieldName) + assert.EqualValues(t, "Userinfo.Height", cols[5].FieldName) + assert.EqualValues(t, "Userinfo.Avatar", cols[6].FieldName) + assert.EqualValues(t, "Userinfo.IsMan", cols[7].FieldName) + assert.EqualValues(t, "Userdetail.Id", cols[8].FieldName) + assert.EqualValues(t, "Userdetail.Intro", cols[9].FieldName) + assert.EqualValues(t, "Userdetail.Profile", cols[10].FieldName) +} diff --git a/schemas/associate.go b/schemas/associate.go new file mode 100644 index 00000000..3dfd85f8 --- /dev/null +++ b/schemas/associate.go @@ -0,0 +1,12 @@ +// Copyright 2021 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package schemas + +type AssociateType int + +const ( + AssociateNone AssociateType = iota + AssociateBelongsTo +) diff --git a/schemas/column.go b/schemas/column.go index 4bbb6c2d..7ded9677 100644 --- a/schemas/column.go +++ b/schemas/column.go @@ -22,8 +22,9 @@ const ( type Column struct { Name string TableName string - FieldName string // Available only when parsed from a struct - FieldIndex []int // Available only when parsed from a struct + FieldName string // Available only when parsed from a struct + FieldIndex []int // Available only when parsed from a struct + FieldType reflect.Type // Available only when parsed from a struct SQLType SQLType IsJSON bool Length int @@ -45,6 +46,8 @@ type Column struct { DisableTimeZone bool TimeZone *time.Location // column specified time zone Comment string + AssociateType + AssociateTable *Table } // NewColumn creates a new column @@ -71,6 +74,8 @@ func NewColumn(name, fieldName string, sqlType SQLType, len1, len2 int, nullable DefaultIsEmpty: true, // default should be no default EnumOptions: make(map[string]int), Comment: "", + AssociateType: AssociateNone, + AssociateTable: nil, } } diff --git a/session.go b/session.go index f5b45a73..6891864f 100644 --- a/session.go +++ b/session.go @@ -54,6 +54,22 @@ const ( groupSession sessionType = true ) +type cascadeMode int + +const ( + cascadeCompitable cascadeMode = iota // load field beans with another SQL with no + cascadeEager // load field beans with another SQL + cascadeJoin // load field beans with join + cascadeLazy // don't load anything +) + +type loadClosure struct { + Func func(schemas.PK, *reflect.Value) error + pk schemas.PK + fieldValue *reflect.Value + loaded bool +} + // Session keep a pointer to sql.DB and provides all execution of all // kind of database operations. type Session struct { @@ -86,6 +102,9 @@ type Session struct { ctx context.Context sessionType sessionType + + cascadeMode cascadeMode + cascadeLevel int // load level } func newSessionID() string { @@ -134,7 +153,9 @@ func newSession(engine *Engine) *Session { lastSQL: "", lastSQLArgs: make([]interface{}, 0), - sessionType: engineSession, + sessionType: engineSession, + cascadeMode: cascadeCompitable, + cascadeLevel: 2, } if engine.logSessionID { session.ctx = context.WithValue(session.ctx, log.SessionKey, session) @@ -241,7 +262,7 @@ func (session *Session) Alias(alias string) *Session { // NoCascade indicate that no cascade load child object func (session *Session) NoCascade() *Session { - session.statement.UseCascade = false + session.cascadeMode = cascadeLazy return session } @@ -296,9 +317,15 @@ func (session *Session) Charset(charset string) *Session { // Cascade indicates if loading sub Struct func (session *Session) Cascade(trueOrFalse ...bool) *Session { + var mode = cascadeEager if len(trueOrFalse) >= 1 { - session.statement.UseCascade = trueOrFalse[0] + if trueOrFalse[0] { + mode = cascadeEager + } else { + mode = cascadeLazy + } } + session.cascadeMode = mode return session } diff --git a/session_associate.go b/session_associate.go index 62e924a3..7e6041d2 100644 --- a/session_associate.go +++ b/session_associate.go @@ -8,7 +8,8 @@ import ( "errors" "reflect" - "github.com/go-xorm/core" + "xorm.io/xorm/internal/utils" + "xorm.io/xorm/schemas" ) // Load loads associated fields from database @@ -25,6 +26,15 @@ func (session *Session) Load(beanOrSlices interface{}, cols ...string) error { return errors.New("unsupported load type, must struct or slice") } +func isStringInSlice(s string, slice []string) bool { + for _, e := range slice { + if s == e { + return true + } + } + return false +} + // loadFind load 's belongs to tag field immedicatlly func (session *Session) loadFind(slices interface{}, cols ...string) error { v := reflect.ValueOf(slices) @@ -43,12 +53,12 @@ func (session *Session) loadFind(slices interface{}, cols ...string) error { if vv.Kind() == reflect.Ptr { vv = vv.Elem() } - tb, err := session.engine.autoMapType(vv) + tb, err := session.engine.tagParser.ParseWithCache(vv) if err != nil { return err } - var pks = make(map[*core.Column][]interface{}) + var pks = make(map[*schemas.Column][]interface{}) for i := 0; i < v.Len(); i++ { ev := v.Index(i) @@ -58,16 +68,13 @@ func (session *Session) loadFind(slices interface{}, cols ...string) error { } if col.AssociateTable != nil { - if col.AssociateType == core.AssociateBelongsTo { + if col.AssociateType == schemas.AssociateBelongsTo { colV, err := col.ValueOfV(&ev) if err != nil { return err } - pk, err := session.engine.idOfV(*colV) - if err != nil { - return err - } + vv := colV.Interface() /*var colPtr reflect.Value if colV.Kind() == reflect.Ptr { colPtr = *colV @@ -75,8 +82,8 @@ func (session *Session) loadFind(slices interface{}, cols ...string) error { colPtr = colV.Addr() }*/ - if !isZero(pk[0]) { - pks[col] = append(pks[col], pk[0]) + if !utils.IsZero(vv) { + pks[col] = append(pks[col], vv) } } } @@ -99,8 +106,8 @@ func (session *Session) loadGet(bean interface{}, cols ...string) error { defer session.Close() } - v := rValue(bean) - tb, err := session.engine.autoMapType(v) + v := reflect.Indirect(reflect.ValueOf(bean)) + tb, err := session.engine.tagParser.ParseWithCache(v) if err != nil { return err } @@ -111,16 +118,13 @@ func (session *Session) loadGet(bean interface{}, cols ...string) error { } if col.AssociateTable != nil { - if col.AssociateType == core.AssociateBelongsTo { + if col.AssociateType == schemas.AssociateBelongsTo { colV, err := col.ValueOfV(&v) if err != nil { return err } - pk, err := session.engine.idOfV(*colV) - if err != nil { - return err - } + vv := colV.Interface() var colPtr reflect.Value if colV.Kind() == reflect.Ptr { colPtr = *colV @@ -128,8 +132,8 @@ func (session *Session) loadGet(bean interface{}, cols ...string) error { colPtr = colV.Addr() } - if !isZero(pk[0]) && session.cascadeLevel > 0 { - has, err := session.ID(pk).NoAutoCondition().get(colPtr.Interface()) + if !utils.IsZero(vv) && session.cascadeLevel > 0 { + has, err := session.ID(vv).NoAutoCondition().get(colPtr.Interface()) if err != nil { return err } diff --git a/tags/tag.go b/tags/tag.go index cb5dde79..c354617b 100644 --- a/tags/tag.go +++ b/tags/tag.go @@ -401,13 +401,21 @@ func NoCacheTagHandler(ctx *Context) error { return nil } +func isStruct(t reflect.Type) bool { + return t.Kind() == reflect.Struct || isPtrStruct(t) +} + +func isPtrStruct(t reflect.Type) bool { + return t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct +} + // BelongsToTagHandler describes belongs_to tag handler func BelongsToTagHandler(ctx *Context) error { if !isStruct(ctx.fieldValue.Type()) { - return errors.New("Tag belongs_to cannot be applied on non-struct field") + return errors.New("tag belongs_to cannot be applied on non-struct field") } - ctx.col.AssociateType = core.AssociateBelongsTo + ctx.col.AssociateType = schemas.AssociateBelongsTo var t reflect.Value if ctx.fieldValue.Kind() == reflect.Struct { t = ctx.fieldValue @@ -419,17 +427,16 @@ func BelongsToTagHandler(ctx *Context) error { t = ctx.fieldValue } } else { - return errors.New("Only struct or ptr to struct field could add belongs_to flag") + return errors.New("only struct or ptr to struct field could add belongs_to flag") } } - belongsT, err := ctx.engine.mapType(ctx.parsingTables, t) + belongsT, err := ctx.parser.ParseWithCache(t) if err != nil { return err } pks := belongsT.PKColumns() if len(pks) != 1 { - panic("unsupported non or composited primary key cascade") return errors.New("blongs_to only should be as a tag of table has one primary key") } @@ -437,7 +444,7 @@ func BelongsToTagHandler(ctx *Context) error { ctx.col.SQLType = pks[0].SQLType if len(ctx.col.Name) == 0 { - ctx.col.Name = ctx.engine.ColumnMapper.Obj2Table(ctx.col.FieldName) + "_id" + ctx.col.Name = ctx.parser.columnMapper.Obj2Table(ctx.col.FieldName) + "_id" } return nil } -- 2.40.1 From b5c26997dfe9e4e68485de506080e2444af0fc43 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Thu, 22 Jul 2021 11:29:16 +0800 Subject: [PATCH 03/11] Add fieldtype for column --- tags/parser.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tags/parser.go b/tags/parser.go index efee11e7..03aa5f63 100644 --- a/tags/parser.go +++ b/tags/parser.go @@ -167,6 +167,7 @@ func (parser *Parser) parseFieldWithNoTag(fieldIndex int, field reflect.StructFi field.Name, sqlType, sqlType.DefaultLength, sqlType.DefaultLength2, true) col.FieldIndex = []int{fieldIndex} + col.FieldType = fieldValue.Type() if field.Type.Kind() == reflect.Int64 && (strings.ToUpper(col.FieldName) == "ID" || strings.HasSuffix(strings.ToUpper(col.FieldName), ".ID")) { col.IsAutoIncrement = true @@ -186,6 +187,7 @@ func (parser *Parser) parseFieldWithTags(table *schemas.Table, fieldIndex int, f MapType: schemas.TWOSIDES, Indexes: make(map[string]int), DefaultIsEmpty: true, + FieldType: fieldValue.Type(), } var ctx = Context{ -- 2.40.1 From dca8a72b643f7ec7cab3dbb2ac8f392311c7a57a Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Thu, 22 Jul 2021 11:41:09 +0800 Subject: [PATCH 04/11] Fix test --- engine.go | 7 +++++++ integrations/session_associate_test.go | 2 +- integrations/session_tag_test.go | 11 +++++++---- interface.go | 2 ++ 4 files changed, 17 insertions(+), 5 deletions(-) diff --git a/engine.go b/engine.go index 7a57b08a..9325ffeb 100644 --- a/engine.go +++ b/engine.go @@ -703,6 +703,13 @@ func (engine *Engine) AllCols() *Session { return session.AllCols() } +// Load loads associated fields from database +func (engine *Engine) Load(beanOrSlices interface{}, cols ...string) error { + session := engine.NewSession() + session.isAutoClose = true + return session.Load(beanOrSlices, cols...) +} + // MustCols specify some columns must use even if they are empty func (engine *Engine) MustCols(columns ...string) *Session { session := engine.NewSession() diff --git a/integrations/session_associate_test.go b/integrations/session_associate_test.go index d34549d9..3386bef1 100644 --- a/integrations/session_associate_test.go +++ b/integrations/session_associate_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package integrations import ( "testing" diff --git a/integrations/session_tag_test.go b/integrations/session_tag_test.go index 0e273b3e..170e0374 100644 --- a/integrations/session_tag_test.go +++ b/integrations/session_tag_test.go @@ -11,16 +11,18 @@ import ( ) func TestExtendsTag(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) - table := testEngine.TableInfo(new(Userdetail)) + table, err := testEngine.TableInfo(new(Userdetail)) + assert.NoError(t, err) assert.NotNil(t, table) assert.EqualValues(t, 3, len(table.ColumnsSeq())) assert.EqualValues(t, "id", table.ColumnsSeq()[0]) assert.EqualValues(t, "intro", table.ColumnsSeq()[1]) assert.EqualValues(t, "profile", table.ColumnsSeq()[2]) - table = testEngine.TableInfo(new(Userinfo)) + table, err = testEngine.TableInfo(new(Userinfo)) + assert.NoError(t, err) assert.NotNil(t, table) assert.EqualValues(t, 8, len(table.ColumnsSeq())) assert.EqualValues(t, "id", table.ColumnsSeq()[0]) @@ -32,7 +34,8 @@ func TestExtendsTag(t *testing.T) { assert.EqualValues(t, "avatar", table.ColumnsSeq()[6]) assert.EqualValues(t, "is_man", table.ColumnsSeq()[7]) - table = testEngine.TableInfo(new(UserAndDetail)) + table, err = testEngine.TableInfo(new(UserAndDetail)) + assert.NoError(t, err) assert.NotNil(t, table) assert.EqualValues(t, 11, len(table.ColumnsSeq())) assert.EqualValues(t, "id", table.ColumnsSeq()[0]) diff --git a/interface.go b/interface.go index 42dc9a0a..bbb8062f 100644 --- a/interface.go +++ b/interface.go @@ -24,6 +24,7 @@ type Interface interface { Alias(alias string) *Session Asc(colNames ...string) *Session BufferSize(size int) *Session + Cascade(trueOrFalse ...bool) *Session Cols(columns ...string) *Session Count(...interface{}) (int64, error) CreateIndexes(bean interface{}) error @@ -48,6 +49,7 @@ type Interface interface { IsTableExist(beanOrTableName interface{}) (bool, error) Iterate(interface{}, IterFunc) error Limit(int, ...int) *Session + Load(beanOrSlices interface{}, cols ...string) error MustCols(columns ...string) *Session NoAutoCondition(...bool) *Session NotIn(string, ...interface{}) *Session -- 2.40.1 From b88257902642a881683bf2ebc555d45fe5b71fd5 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Thu, 22 Jul 2021 13:22:35 +0800 Subject: [PATCH 05/11] Fix test --- session.go | 58 ++++++++----------- session_associate.go | 130 ++++++++++++++++++++++++++++++++----------- session_get.go | 36 ++++++++++++ 3 files changed, 157 insertions(+), 67 deletions(-) diff --git a/session.go b/session.go index 6891864f..a2781d75 100644 --- a/session.go +++ b/session.go @@ -10,7 +10,6 @@ import ( "crypto/sha256" "database/sql" "encoding/hex" - "errors" "fmt" "hash/crc32" "io" @@ -631,9 +630,7 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec } fieldValue.Set(reflect.ValueOf(v).Elem().Convert(fieldType)) return nil - } - - if fieldType.ConvertibleTo(schemas.TimeType) { + } else if fieldType.ConvertibleTo(schemas.TimeType) { dbTZ := session.engine.DatabaseTZ if col.TimeZone != nil { dbTZ = col.TimeZone @@ -647,42 +644,35 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec fieldValue.Set(reflect.ValueOf(*t).Convert(fieldType)) return nil } else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { - err := nulVal.Scan(scanResult) - if err == nil { - return nil - } - session.engine.logger.Errorf("sql.Sanner error: %v", err) - } else if session.statement.UseCascade { - table, err := session.engine.tagParser.ParseWithCache(*fieldValue) - if err != nil { - return err - } - - if len(table.PrimaryKeys) != 1 { - return errors.New("unsupported non or composited primary key cascade") - } - var pk = make(schemas.PK, len(table.PrimaryKeys)) + return nulVal.Scan(scanResult) + } else if session.cascadeLevel > 0 && ((col.AssociateType == schemas.AssociateNone && + session.cascadeMode == cascadeCompitable) || + (col.AssociateType == schemas.AssociateBelongsTo && + session.cascadeMode == cascadeEager)) { + var pk = make(schemas.PK, len(col.AssociateTable.PrimaryKeys)) + var err error pk[0], err = asKind(vv, reflect.TypeOf(scanResult)) if err != nil { return err } - if !pk.IsZero() { - // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch - // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne - // property to be fetched lazily - structInter := reflect.New(fieldValue.Type()) - has, err := session.ID(pk).NoCascade().get(structInter.Interface()) - if err != nil { - return err - } - if has { - fieldValue.Set(structInter.Elem()) - } else { - return errors.New("cascade obj is not exist") - } - } + session.afterProcessors = append(session.afterProcessors, executedProcessor{ + fun: func(session *Session, bean interface{}) error { + fieldValue := bean.(*reflect.Value) + return session.getStructByPK(pk, fieldValue) + }, + session: session, + bean: fieldValue, + }) + session.cascadeLevel-- return nil + } else if col.AssociateType == schemas.AssociateBelongsTo { + pkCols := col.AssociateTable.PKColumns() + colV, err := pkCols[0].ValueOfV(fieldValue) + if err != nil { + return err + } + return convertAssignV(*colV, scanResult) } } // switch fieldType.Kind() diff --git a/session_associate.go b/session_associate.go index 7e6041d2..2951b6fc 100644 --- a/session_associate.go +++ b/session_associate.go @@ -19,11 +19,13 @@ func (session *Session) Load(beanOrSlices interface{}, cols ...string) error { v = v.Elem() } if v.Kind() == reflect.Slice { - return session.loadFind(beanOrSlices, cols...) + return session.loadFindSlice(v, cols...) + } else if v.Kind() == reflect.Map { + return session.loadFindMap(v, cols...) } else if v.Kind() == reflect.Struct { - return session.loadGet(beanOrSlices, cols...) + return session.loadGet(v, cols...) } - return errors.New("unsupported load type, must struct or slice") + return errors.New("unsupported load type, must struct, slice or map") } func isStringInSlice(s string, slice []string) bool { @@ -36,11 +38,7 @@ func isStringInSlice(s string, slice []string) bool { } // loadFind load 's belongs to tag field immedicatlly -func (session *Session) loadFind(slices interface{}, cols ...string) error { - v := reflect.ValueOf(slices) - if v.Kind() == reflect.Ptr { - v = v.Elem() - } +func (session *Session) loadFindSlice(v reflect.Value, cols ...string) error { if v.Kind() != reflect.Slice { return errors.New("only slice is supported") } @@ -100,13 +98,73 @@ func (session *Session) loadFind(slices interface{}, cols ...string) error { return nil } +// loadFindMap load 's belongs to tag field immedicatlly +func (session *Session) loadFindMap(v reflect.Value, cols ...string) error { + if v.Kind() != reflect.Map { + return errors.New("only map is supported") + } + + if v.Len() <= 0 { + return nil + } + + vv := v.Index(0) + if vv.Kind() == reflect.Ptr { + vv = vv.Elem() + } + tb, err := session.engine.tagParser.ParseWithCache(vv) + if err != nil { + return err + } + + var pks = make(map[*schemas.Column][]interface{}) + for i := 0; i < v.Len(); i++ { + ev := v.Index(i) + + for _, col := range tb.Columns() { + if len(cols) > 0 && !isStringInSlice(col.Name, cols) { + continue + } + + if col.AssociateTable != nil { + if col.AssociateType == schemas.AssociateBelongsTo { + colV, err := col.ValueOfV(&ev) + if err != nil { + return err + } + + vv := colV.Interface() + /*var colPtr reflect.Value + if colV.Kind() == reflect.Ptr { + colPtr = *colV + } else { + colPtr = colV.Addr() + }*/ + + if !utils.IsZero(vv) { + pks[col] = append(pks[col], vv) + } + } + } + } + } + + for col, pk := range pks { + slice := reflect.MakeSlice(col.FieldType, 0, len(pk)) + err = session.In(col.Name, pk...).find(slice.Addr().Interface()) + if err != nil { + return err + } + } + return nil +} + // loadGet load bean's belongs to tag field immedicatlly -func (session *Session) loadGet(bean interface{}, cols ...string) error { +func (session *Session) loadGet(v reflect.Value, cols ...string) error { if session.isAutoClose { defer session.Close() } - v := reflect.Indirect(reflect.ValueOf(bean)) tb, err := session.engine.tagParser.ParseWithCache(v) if err != nil { return err @@ -117,32 +175,38 @@ func (session *Session) loadGet(bean interface{}, cols ...string) error { continue } - if col.AssociateTable != nil { - if col.AssociateType == schemas.AssociateBelongsTo { - colV, err := col.ValueOfV(&v) - if err != nil { - return err - } + if col.AssociateTable == nil || col.AssociateType != schemas.AssociateBelongsTo { + continue + } - vv := colV.Interface() - var colPtr reflect.Value - if colV.Kind() == reflect.Ptr { - colPtr = *colV - } else { - colPtr = colV.Addr() - } + colV, err := col.ValueOfV(&v) + if err != nil { + return err + } - if !utils.IsZero(vv) && session.cascadeLevel > 0 { - has, err := session.ID(vv).NoAutoCondition().get(colPtr.Interface()) - if err != nil { - return err - } - if !has { - return errors.New("load bean does not exist") - } - session.cascadeLevel-- - } + var colPtr reflect.Value + if colV.Kind() == reflect.Ptr { + colPtr = *colV + } else { + colPtr = colV.Addr() + } + + pks := col.AssociateTable.PKColumns() + pkV, err := pks[0].ValueOfV(colV) + if err != nil { + return err + } + vv := pkV.Interface() + + if !utils.IsZero(vv) && session.cascadeLevel > 0 { + has, err := session.ID(vv).NoAutoCondition().get(colPtr.Interface()) + if err != nil { + return err } + if !has { + return errors.New("load bean does not exist") + } + session.cascadeLevel-- } } return nil diff --git a/session_get.go b/session_get.go index 48616a6b..9bf916ce 100644 --- a/session_get.go +++ b/session_get.go @@ -280,6 +280,42 @@ func (session *Session) getMap(rows *core.Rows, types []*sql.ColumnType, fields } } +func (session *Session) getStructByPK(pk schemas.PK, fieldValue *reflect.Value) error { + if pk.IsZero() { + return errors.New("getStructByPK: primary key is zero") + } + + var structInter reflect.Value + if fieldValue.Kind() == reflect.Ptr { + if fieldValue.IsNil() { + structInter = reflect.New(fieldValue.Type().Elem()) + } else { + structInter = *fieldValue + } + } else { + structInter = fieldValue.Addr() + } + + has, err := session.ID(pk).NoAutoCondition().get(structInter.Interface()) + if err != nil { + return err + } + if !has { + return errors.New("cascade obj is not exist") + } + if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { + fieldValue.Set(structInter) + fmt.Println("getByPK value ptr:", fieldValue.Interface()) + return nil + } else if fieldValue.Kind() == reflect.Struct { + fieldValue.Set(structInter.Elem()) + fmt.Println("getByPK value:", fieldValue.Interface()) + return nil + } + return errors.New("set value failed") + +} + func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interface{}) (has bool, err error) { // if has no reftable, then don't use cache currently if !session.canCache() { -- 2.40.1 From f2a1e6ea2ffafe217fd6aa5d9924be907458f550 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Thu, 22 Jul 2021 15:17:38 +0800 Subject: [PATCH 06/11] Fix test --- integrations/session_associate_test.go | 2 +- session_associate.go | 71 +++++++++++++++++++------- 2 files changed, 53 insertions(+), 20 deletions(-) diff --git a/integrations/session_associate_test.go b/integrations/session_associate_test.go index 3386bef1..9406d996 100644 --- a/integrations/session_associate_test.go +++ b/integrations/session_associate_test.go @@ -169,7 +169,7 @@ func TestBelongsTo_Find(t *testing.T) { assert.Equal(t, "face1", noses2[0].Face.Name) assert.Equal(t, "face2", noses2[1].Face.Name) - err = testEngine.Load(noses1, "face") + err = testEngine.Load(noses1, "face_id") assert.NoError(t, err) assert.Equal(t, "face1", noses1[0].Face.Name) assert.Equal(t, "face2", noses1[1].Face.Name) diff --git a/session_associate.go b/session_associate.go index 2951b6fc..3877d89a 100644 --- a/session_associate.go +++ b/session_associate.go @@ -6,6 +6,7 @@ package xorm import ( "errors" + "fmt" "reflect" "xorm.io/xorm/internal/utils" @@ -56,44 +57,76 @@ func (session *Session) loadFindSlice(v reflect.Value, cols ...string) error { return err } - var pks = make(map[*schemas.Column][]interface{}) + type Va struct { + v reflect.Value + pk []interface{} + col *schemas.Column + } + + var pks = make(map[*schemas.Column]*Va) for i := 0; i < v.Len(); i++ { ev := v.Index(i) + fmt.Println("1====", ev.Interface(), tb.Name, len(tb.Columns())) + for _, col := range tb.Columns() { + fmt.Println("====", cols, col.Name) if len(cols) > 0 && !isStringInSlice(col.Name, cols) { continue } - if col.AssociateTable != nil { - if col.AssociateType == schemas.AssociateBelongsTo { - colV, err := col.ValueOfV(&ev) - if err != nil { - return err - } + fmt.Println("3------", col.Name, col.AssociateTable) - vv := colV.Interface() - /*var colPtr reflect.Value - if colV.Kind() == reflect.Ptr { - colPtr = *colV - } else { - colPtr = colV.Addr() - }*/ + if col.AssociateTable == nil || col.AssociateType != schemas.AssociateBelongsTo { + continue + } - if !utils.IsZero(vv) { - pks[col] = append(pks[col], vv) + colV, err := col.ValueOfV(&ev) + if err != nil { + return err + } + + pkCols := col.AssociateTable.PKColumns() + pkV, err := pkCols[0].ValueOfV(colV) + if err != nil { + return err + } + vv := pkV.Interface() + + fmt.Println("2====", vv) + + if !utils.IsZero(vv) { + va, ok := pks[col] + if !ok { + va = &Va{ + v: ev, + col: pkCols[0], } + pks[col] = va } + va.pk = append(va.pk, vv) } } } - for col, pk := range pks { - slice := reflect.MakeSlice(col.FieldType, 0, len(pk)) - err = session.In(col.Name, pk...).find(slice.Addr().Interface()) + for col, va := range pks { + slice := reflect.MakeSlice(reflect.SliceOf(col.FieldType), 0, len(va.pk)) + err = session.In(va.col.Name, va.pk...).find(slice.Interface()) if err != nil { return err } + + /*vv, err := col.ValueOfV(&va.v) + if err != nil { + return err + } + vv.Set() + + for i := 0; i < slice.Len(); i++ { + + + va.col.ValueOfV(slice.Index(i)) + }*/ } return nil } -- 2.40.1 From 934df04ca440bdf8e4e89c637b1a1e06e9a5f710 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Thu, 29 Jul 2021 11:15:55 +0800 Subject: [PATCH 07/11] Rebase --- session.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/session.go b/session.go index a2781d75..9df2764b 100644 --- a/session.go +++ b/session.go @@ -672,7 +672,7 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec if err != nil { return err } - return convertAssignV(*colV, scanResult) + return convert.AssignValue(*colV, scanResult) } } // switch fieldType.Kind() -- 2.40.1 From 9fe6702252227edfbc75954e64da96cd08a9df46 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Thu, 29 Jul 2021 12:47:42 +0800 Subject: [PATCH 08/11] Remove duplicated code --- internal/statements/associate.go | 14 -------------- 1 file changed, 14 deletions(-) delete mode 100644 internal/statements/associate.go diff --git a/internal/statements/associate.go b/internal/statements/associate.go deleted file mode 100644 index 5659ddc9..00000000 --- a/internal/statements/associate.go +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright 2017 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package statements - -type cascadeMode int - -const ( - cascadeCompitable cascadeMode = iota // load field beans with another SQL with no - cascadeEager // load field beans with another SQL - cascadeJoin // load field beans with join - cascadeLazy // don't load anything -) -- 2.40.1 From 6dc2f9eec79786ec144fe156965c5e6eab35892e Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Tue, 10 Aug 2021 20:31:37 +0800 Subject: [PATCH 09/11] Fix bug --- session_associate.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/session_associate.go b/session_associate.go index 3877d89a..f699a7b1 100644 --- a/session_associate.go +++ b/session_associate.go @@ -110,8 +110,14 @@ func (session *Session) loadFindSlice(v reflect.Value, cols ...string) error { } for col, va := range pks { - slice := reflect.MakeSlice(reflect.SliceOf(col.FieldType), 0, len(va.pk)) - err = session.In(va.col.Name, va.pk...).find(slice.Interface()) + //slice := reflect.New(reflect.SliceOf(col.FieldType)) + pkCols := col.AssociateTable.PKColumns() + if len(pkCols) != 1 { + return fmt.Errorf("unsupported primary key number") + } + mp := reflect.MakeMap(reflect.MapOf(pkCols[0].FieldType, col.FieldType)) + //slice := reflect.MakeSlice(, 0, len(va.pk)) + err = session.In(va.col.Name, va.pk...).find(mp.Addr().Interface()) if err != nil { return err } -- 2.40.1 From 59521246d3db71af4c57200ff88a21be7bea6bc9 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Wed, 11 Aug 2021 14:09:26 +0800 Subject: [PATCH 10/11] Fix load slice --- session_associate.go | 85 ++++++++++++++++++++------------------------ 1 file changed, 38 insertions(+), 47 deletions(-) diff --git a/session_associate.go b/session_associate.go index f699a7b1..9279053b 100644 --- a/session_associate.go +++ b/session_associate.go @@ -48,40 +48,45 @@ func (session *Session) loadFindSlice(v reflect.Value, cols ...string) error { return nil } - vv := v.Index(0) - if vv.Kind() == reflect.Ptr { - vv = vv.Elem() + tableValue := v.Index(0) + if tableValue.Kind() == reflect.Ptr { + tableValue = tableValue.Elem() } - tb, err := session.engine.tagParser.ParseWithCache(vv) + tb, err := session.engine.tagParser.ParseWithCache(tableValue) if err != nil { return err } type Va struct { - v reflect.Value + v []reflect.Value pk []interface{} col *schemas.Column } var pks = make(map[*schemas.Column]*Va) + for _, col := range tb.Columns() { + if col.AssociateTable == nil || col.AssociateType != schemas.AssociateBelongsTo { + continue + } + + if len(cols) > 0 && !isStringInSlice(col.Name, cols) { + continue + } + + pkCols := col.AssociateTable.PKColumns() + if len(pkCols) != 1 { + return fmt.Errorf("unsupported primary key number") + } + + pks[col] = &Va{ + col: pkCols[0], + } + } + for i := 0; i < v.Len(); i++ { - ev := v.Index(i) - - fmt.Println("1====", ev.Interface(), tb.Name, len(tb.Columns())) - - for _, col := range tb.Columns() { - fmt.Println("====", cols, col.Name) - if len(cols) > 0 && !isStringInSlice(col.Name, cols) { - continue - } - - fmt.Println("3------", col.Name, col.AssociateTable) - - if col.AssociateTable == nil || col.AssociateType != schemas.AssociateBelongsTo { - continue - } - - colV, err := col.ValueOfV(&ev) + value := v.Index(i) + for col, va := range pks { + colV, err := col.ValueOfV(&value) if err != nil { return err } @@ -92,47 +97,33 @@ func (session *Session) loadFindSlice(v reflect.Value, cols ...string) error { return err } vv := pkV.Interface() - - fmt.Println("2====", vv) - - if !utils.IsZero(vv) { - va, ok := pks[col] - if !ok { - va = &Va{ - v: ev, - col: pkCols[0], - } - pks[col] = va - } + if !utils.IsZero(vv) { // TODO: duplicate primary key + va.v = append(va.v, *colV) va.pk = append(va.pk, vv) } } } for col, va := range pks { - //slice := reflect.New(reflect.SliceOf(col.FieldType)) pkCols := col.AssociateTable.PKColumns() - if len(pkCols) != 1 { - return fmt.Errorf("unsupported primary key number") - } mp := reflect.MakeMap(reflect.MapOf(pkCols[0].FieldType, col.FieldType)) - //slice := reflect.MakeSlice(, 0, len(va.pk)) - err = session.In(va.col.Name, va.pk...).find(mp.Addr().Interface()) + x := reflect.New(mp.Type()) + x.Elem().Set(mp) + + err = session.In(va.col.Name, va.pk...).find(x.Interface()) if err != nil { return err } - /*vv, err := col.ValueOfV(&va.v) + for _, v := range va.v { + pkCols := col.AssociateTable.PKColumns() + pkV, err := pkCols[0].ValueOfV(&v) if err != nil { return err } - vv.Set() - for i := 0; i < slice.Len(); i++ { - - - va.col.ValueOfV(slice.Index(i)) - }*/ + v.Set(mp.MapIndex(*pkV)) + } } return nil } -- 2.40.1 From 3b77cdd8083a9d832349f9dd8394f59e0a7c58ae Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Wed, 11 Aug 2021 16:20:19 +0800 Subject: [PATCH 11/11] Fix --- session.go | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/session.go b/session.go index 9df2764b..14989781 100644 --- a/session.go +++ b/session.go @@ -649,7 +649,20 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec session.cascadeMode == cascadeCompitable) || (col.AssociateType == schemas.AssociateBelongsTo && session.cascadeMode == cascadeEager)) { - var pk = make(schemas.PK, len(col.AssociateTable.PrimaryKeys)) + var associateTable *schemas.Table + if col.AssociateType == schemas.AssociateNone && col.AssociateTable == nil { + var err error + associateTable, err = session.engine.tagParser.ParseWithCache(*fieldValue) + if err != nil { + return err + } + } else { + associateTable = col.AssociateTable + } + + fmt.Println("=====", associateTable) + + var pk = make(schemas.PK, len(associateTable.PrimaryKeys)) var err error pk[0], err = asKind(vv, reflect.TypeOf(scanResult)) if err != nil { @@ -659,14 +672,16 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec session.afterProcessors = append(session.afterProcessors, executedProcessor{ fun: func(session *Session, bean interface{}) error { fieldValue := bean.(*reflect.Value) + fmt.Println("3333333") return session.getStructByPK(pk, fieldValue) }, session: session, bean: fieldValue, }) session.cascadeLevel-- + fmt.Println("222222") return nil - } else if col.AssociateType == schemas.AssociateBelongsTo { + } else if col.AssociateType == schemas.AssociateBelongsTo && session.cascadeMode == cascadeLazy { pkCols := col.AssociateTable.PKColumns() colV, err := pkCols[0].ValueOfV(fieldValue) if err != nil { -- 2.40.1