From 9e396cdd9cac47f6686ff4652973684f79bac974 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Thu, 16 Sep 2021 22:07:53 +0800 Subject: [PATCH 1/2] Fix bug of Rows --- integrations/session_iterate_test.go | 17 ++++++++++++++--- session.go | 10 +++++----- session_get.go | 2 -- 3 files changed, 19 insertions(+), 10 deletions(-) diff --git a/integrations/session_iterate_test.go b/integrations/session_iterate_test.go index fa394d17..c5ecc593 100644 --- a/integrations/session_iterate_test.go +++ b/integrations/session_iterate_test.go @@ -26,16 +26,27 @@ func TestIterate(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, cnt) + cnt, err = testEngine.Insert(&UserIterate{ + IsMan: false, + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + cnt = 0 err = testEngine.Iterate(new(UserIterate), func(i int, bean interface{}) error { user := bean.(*UserIterate) - assert.EqualValues(t, 1, user.Id) - assert.EqualValues(t, true, user.IsMan) + if cnt == 0 { + assert.EqualValues(t, 1, user.Id) + assert.EqualValues(t, true, user.IsMan) + } else { + assert.EqualValues(t, 2, user.Id) + assert.EqualValues(t, false, user.IsMan) + } cnt++ return nil }) assert.NoError(t, err) - assert.EqualValues(t, 1, cnt) + assert.EqualValues(t, 2, cnt) } func TestBufferIterate(t *testing.T) { diff --git a/session.go b/session.go index 0e2c48ca..f5b45a73 100644 --- a/session.go +++ b/session.go @@ -370,7 +370,7 @@ func (session *Session) doPrepare(db *core.DB, sqlStr string) (stmt *core.Stmt, return } -func (session *Session) getField(dataStruct *reflect.Value, table *schemas.Table, colName string, idx int) (*schemas.Column, *reflect.Value, error) { +func getField(dataStruct *reflect.Value, table *schemas.Table, colName string, idx int) (*schemas.Column, *reflect.Value, error) { var col = table.GetColumnIdx(colName, idx) if col == nil { return nil, nil, ErrFieldIsNotExist{colName, table.Name} @@ -440,7 +440,7 @@ func (session *Session) row2Slice(rows *core.Rows, fields []string, types []*sql return scanResults, nil } -func (session *Session) setJSON(fieldValue *reflect.Value, fieldType reflect.Type, scanResult interface{}) error { +func setJSON(fieldValue *reflect.Value, fieldType reflect.Type, scanResult interface{}) error { bs, ok := convert.AsBytes(scanResult) if !ok { return fmt.Errorf("unsupported database data type: %#v", scanResult) @@ -551,7 +551,7 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec fieldType := fieldValue.Type() if col.IsJSON { - return session.setJSON(fieldValue, fieldType, scanResult) + return setJSON(fieldValue, fieldType, scanResult) } switch fieldType.Kind() { @@ -570,7 +570,7 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec } return nil case reflect.Complex64, reflect.Complex128: - return session.setJSON(fieldValue, fieldType, scanResult) + return setJSON(fieldValue, fieldType, scanResult) case reflect.Slice, reflect.Array: bs, ok := convert.AsBytes(scanResult) if ok && fieldType.Elem().Kind() == reflect.Uint8 { @@ -683,7 +683,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b } tempMap[lKey] = idx - col, fieldValue, err := session.getField(dataStruct, table, colName, idx) + col, fieldValue, err := getField(dataStruct, table, colName, idx) if _, ok := err.(ErrFieldIsNotExist); ok { continue } else if err != nil { diff --git a/session_get.go b/session_get.go index a82cae92..3307d359 100644 --- a/session_get.go +++ b/session_get.go @@ -273,8 +273,6 @@ func (session *Session) getStruct(rows *core.Rows, types []*sql.ColumnType, fiel if err != nil { return err } - // close it before convert data - rows.Close() dataStruct := utils.ReflectValue(bean) _, err = session.slice2Bean(scanResults, fields, bean, &dataStruct, table) -- 2.40.1 From 4c1dd2884ecf5c155834b0fa5a5321a66c5fc234 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Thu, 16 Sep 2021 23:25:41 +0800 Subject: [PATCH 2/2] Fix get --- session_get.go | 31 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/session_get.go b/session_get.go index 3307d359..48616a6b 100644 --- a/session_get.go +++ b/session_get.go @@ -173,7 +173,12 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, return true, err } - return true, session.scan(rows, table, beanKind, beans, types, fields) + if err := session.scan(rows, table, beanKind, beans, types, fields); err != nil { + return true, err + } + rows.Close() + + return true, session.executeProcessors() } func (session *Session) scan(rows *core.Rows, table *schemas.Table, firstBeanKind reflect.Kind, beans []interface{}, types []*sql.ColumnType, fields []string) error { @@ -184,7 +189,14 @@ func (session *Session) scan(rows *core.Rows, table *schemas.Table, firstBeanKin if !isScannableStruct(bean, len(types)) { break } - return session.getStruct(rows, types, fields, table, bean) + scanResults, err := session.row2Slice(rows, fields, types, bean) + if err != nil { + return err + } + + dataStruct := utils.ReflectValue(bean) + _, err = session.slice2Bean(scanResults, fields, bean, &dataStruct, table) + return err case reflect.Slice: return session.getSlice(rows, types, fields, bean) case reflect.Map: @@ -268,21 +280,6 @@ func (session *Session) getMap(rows *core.Rows, types []*sql.ColumnType, fields } } -func (session *Session) getStruct(rows *core.Rows, types []*sql.ColumnType, fields []string, table *schemas.Table, bean interface{}) error { - scanResults, err := session.row2Slice(rows, fields, types, bean) - if err != nil { - return err - } - - dataStruct := utils.ReflectValue(bean) - _, err = session.slice2Bean(scanResults, fields, bean, &dataStruct, table) - if err != nil { - return err - } - - return session.executeProcessors() -} - func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interface{}) (has bool, err error) { // if has no reftable, then don't use cache currently if !session.canCache() { -- 2.40.1