From 407d542e3e7ed08c8c2eb7daaa318fe5fb921ec5 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Mon, 23 Mar 2020 19:58:25 +0800 Subject: [PATCH 1/3] Fix find and count bug --- engine.go | 70 ------------------------------- integrations/session_find_test.go | 42 ++++++++++++++++++- schemas/column.go | 16 +++++++ schemas/table.go | 49 ++++++++++++++++++++++ session_find.go | 7 ++-- 5 files changed, 108 insertions(+), 76 deletions(-) diff --git a/engine.go b/engine.go index 7399f41a..b5cb6558 100644 --- a/engine.go +++ b/engine.go @@ -816,81 +816,11 @@ func (engine *Engine) IsTableExist(beanOrTableName interface{}) (bool, error) { return session.IsTableExist(beanOrTableName) } -// IDOf get id from one struct -func (engine *Engine) IDOf(bean interface{}) (schemas.PK, error) { - return engine.IDOfV(reflect.ValueOf(bean)) -} - // TableName returns table name with schema prefix if has func (engine *Engine) TableName(bean interface{}, includeSchema ...bool) string { return dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean, includeSchema...) } -// IDOfV get id from one value of struct -func (engine *Engine) IDOfV(rv reflect.Value) (schemas.PK, error) { - return engine.idOfV(rv) -} - -func (engine *Engine) idOfV(rv reflect.Value) (schemas.PK, error) { - v := reflect.Indirect(rv) - table, err := engine.tagParser.ParseWithCache(v) - if err != nil { - return nil, err - } - - pk := make([]interface{}, len(table.PrimaryKeys)) - for i, col := range table.PKColumns() { - var err error - - fieldName := col.FieldName - for { - parts := strings.SplitN(fieldName, ".", 2) - if len(parts) == 1 { - break - } - - v = v.FieldByName(parts[0]) - if v.Kind() == reflect.Ptr { - v = v.Elem() - } - if v.Kind() != reflect.Struct { - return nil, ErrUnSupportedType - } - fieldName = parts[1] - } - - pkField := v.FieldByName(fieldName) - switch pkField.Kind() { - case reflect.String: - pk[i], err = engine.idTypeAssertion(col, pkField.String()) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - pk[i], err = engine.idTypeAssertion(col, strconv.FormatInt(pkField.Int(), 10)) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - // id of uint will be converted to int64 - pk[i], err = engine.idTypeAssertion(col, strconv.FormatUint(pkField.Uint(), 10)) - } - - if err != nil { - return nil, err - } - } - return schemas.PK(pk), nil -} - -func (engine *Engine) idTypeAssertion(col *schemas.Column, sid string) (interface{}, error) { - if col.SQLType.IsNumeric() { - n, err := strconv.ParseInt(sid, 10, 64) - if err != nil { - return nil, err - } - return n, nil - } else if col.SQLType.IsText() { - return sid, nil - } else { - return nil, errors.New("not supported") - } -} - // CreateIndexes create indexes func (engine *Engine) CreateIndexes(bean interface{}) error { session := engine.NewSession() diff --git a/integrations/session_find_test.go b/integrations/session_find_test.go index b9d722ba..95cf9384 100644 --- a/integrations/session_find_test.go +++ b/integrations/session_find_test.go @@ -502,10 +502,48 @@ func TestFindAndCountOneFunc(t *testing.T) { assert.EqualValues(t, 1, cnt) results = make([]FindAndCountStruct, 0, 1) - cnt, err = testEngine.Where("msg = ?", true).Limit(1).FindAndCount(&results) + cnt, err = testEngine.Where("1=1").Limit(1).FindAndCount(&results) assert.NoError(t, err) assert.EqualValues(t, 1, len(results)) - assert.EqualValues(t, 1, cnt) + assert.EqualValues(t, 2, cnt) + assert.EqualValues(t, FindAndCountStruct{ + Id: 1, + Content: "111", + Msg: false, + }, results[0]) + + results = make([]FindAndCountStruct, 0, 1) + cnt, err = testEngine.Where("1=1").Limit(1).FindAndCount(&results) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(results)) + assert.EqualValues(t, 2, cnt) + assert.EqualValues(t, FindAndCountStruct{ + Id: 1, + Content: "111", + Msg: false, + }, results[0]) + + results = make([]FindAndCountStruct, 0, 1) + cnt, err = testEngine.Where("1=1").Limit(1, 1).FindAndCount(&results) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(results)) + assert.EqualValues(t, 2, cnt) + assert.EqualValues(t, FindAndCountStruct{ + Id: 2, + Content: "222", + Msg: true, + }, results[0]) + + results = make([]FindAndCountStruct, 0, 1) + cnt, err = testEngine.Where("1=1").Limit(1, 1).FindAndCount(&results) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(results)) + assert.EqualValues(t, 2, cnt) + assert.EqualValues(t, FindAndCountStruct{ + Id: 2, + Content: "222", + Msg: true, + }, results[0]) results = make([]FindAndCountStruct, 0, 1) cnt, err = testEngine.Where("msg = ?", true).Select("id, content, msg"). diff --git a/schemas/column.go b/schemas/column.go index 418629ac..db66a3a6 100644 --- a/schemas/column.go +++ b/schemas/column.go @@ -5,8 +5,10 @@ package schemas import ( + "errors" "fmt" "reflect" + "strconv" "strings" "time" ) @@ -115,3 +117,17 @@ func (col *Column) ValueOfV(dataStruct *reflect.Value) (*reflect.Value, error) { return &fieldValue, nil } + +// ConvertID converts id content to suitable type according column type +func (col *Column) ConvertID(sid string) (interface{}, error) { + if col.SQLType.IsNumeric() { + n, err := strconv.ParseInt(sid, 10, 64) + if err != nil { + return nil, err + } + return n, nil + } else if col.SQLType.IsText() { + return sid, nil + } + return nil, errors.New("not supported") +} diff --git a/schemas/table.go b/schemas/table.go index 38596991..6c57a7e3 100644 --- a/schemas/table.go +++ b/schemas/table.go @@ -5,7 +5,9 @@ package schemas import ( + "fmt" "reflect" + "strconv" "strings" ) @@ -28,6 +30,7 @@ type Table struct { Comment string } +// NewEmptyTable creates an empty table func NewEmptyTable() *Table { return NewTable("", nil) } @@ -44,10 +47,12 @@ func NewTable(name string, t reflect.Type) *Table { } } +// Columns returns table's columns func (table *Table) Columns() []*Column { return table.columns } +// ColumnsSeq returns table's column names according sequence func (table *Table) ColumnsSeq() []string { return table.columnsSeq } @@ -61,6 +66,7 @@ func (table *Table) columnsByName(name string) []*Column { return nil } +// GetColumn returns column according column name, if column not found, return nil func (table *Table) GetColumn(name string) *Column { cols := table.columnsByName(name) if cols != nil { @@ -70,6 +76,7 @@ func (table *Table) GetColumn(name string) *Column { return nil } +// GetColumnIdx returns column according name and idx func (table *Table) GetColumnIdx(name string, idx int) *Column { cols := table.columnsByName(name) if cols != nil && idx < len(cols) { @@ -144,3 +151,45 @@ func (table *Table) AddColumn(col *Column) { func (table *Table) AddIndex(index *Index) { table.Indexes[index.Name] = index } + +// IDOfV get id from one value of struct +func (table *Table) IDOfV(rv reflect.Value) (PK, error) { + v := reflect.Indirect(rv) + pk := make([]interface{}, len(table.PrimaryKeys)) + for i, col := range table.PKColumns() { + var err error + + fieldName := col.FieldName + for { + parts := strings.SplitN(fieldName, ".", 2) + if len(parts) == 1 { + break + } + + v = v.FieldByName(parts[0]) + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + if v.Kind() != reflect.Struct { + return nil, fmt.Errorf("Unsupported read value of column %s from field %s", col.Name, col.FieldName) + } + fieldName = parts[1] + } + + pkField := v.FieldByName(fieldName) + switch pkField.Kind() { + case reflect.String: + pk[i], err = col.ConvertID(pkField.String()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + pk[i], err = col.ConvertID(strconv.FormatInt(pkField.Int(), 10)) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + // id of uint will be converted to int64 + pk[i], err = col.ConvertID(strconv.FormatUint(pkField.Uint(), 10)) + } + + if err != nil { + return nil, err + } + } + return PK(pk), nil +} diff --git a/session_find.go b/session_find.go index 3bc6a642..c1dbe4ca 100644 --- a/session_find.go +++ b/session_find.go @@ -320,7 +320,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in } var pk schemas.PK = make([]interface{}, len(table.PrimaryKeys)) for i, col := range table.PKColumns() { - pk[i], err = session.engine.idTypeAssertion(col, res[i]) + pk[i], err = col.ConvertID(res[i]) if err != nil { return err } @@ -370,7 +370,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in } else { session.engine.logger.Debugf("[cache] cache hit bean: %v, %v, %v", tableName, id, bean) - pk, err := session.engine.IDOf(bean) + pk, err := table.IDOfV(reflect.ValueOf(bean)) if err != nil { return err } @@ -419,7 +419,6 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in if err != nil { return err } - session.statement = statement vs := reflect.Indirect(reflect.ValueOf(beans)) @@ -428,7 +427,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in if rv.Kind() != reflect.Ptr { rv = rv.Addr() } - id, err := session.engine.idOfV(rv) + id, err := table.IDOfV(rv) if err != nil { return err } -- 2.40.1 From 378fb1e642c8f725e4a5445efb82072cc3283172 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Sat, 13 Jun 2020 11:11:38 +0800 Subject: [PATCH 2/3] fix mssql findandcount --- session_find.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/session_find.go b/session_find.go index c1dbe4ca..90c2b6dc 100644 --- a/session_find.go +++ b/session_find.go @@ -60,6 +60,9 @@ func (session *Session) FindAndCount(rowsSlicePtr interface{}, condiBean ...inte if session.statement.OrderStr != "" { session.statement.OrderStr = "" } + if session.statement.LimitN != nil { + session.statement.LimitN = nil + } // session has stored the conditions so we use `unscoped` to avoid duplicated condition. return session.Unscoped().Count(reflect.New(sliceElementType).Interface()) -- 2.40.1 From e39cc55b03be86a3de33e04cd7bc76d7e05e9ba5 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Sat, 13 Jun 2020 12:57:46 +0800 Subject: [PATCH 3/3] Fix bug --- session_find.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/session_find.go b/session_find.go index 90c2b6dc..642093f2 100644 --- a/session_find.go +++ b/session_find.go @@ -63,6 +63,9 @@ func (session *Session) FindAndCount(rowsSlicePtr interface{}, condiBean ...inte if session.statement.LimitN != nil { session.statement.LimitN = nil } + if session.statement.Start > 0 { + session.statement.Start = 0 + } // session has stored the conditions so we use `unscoped` to avoid duplicated condition. return session.Unscoped().Count(reflect.New(sliceElementType).Interface()) -- 2.40.1