Fix find and count bug #1651

Merged
lunny merged 3 commits from lunny/fix_bug into master 2020-06-13 05:31:36 +00:00
5 changed files with 114 additions and 76 deletions

View File

@ -816,81 +816,11 @@ func (engine *Engine) IsTableExist(beanOrTableName interface{}) (bool, error) {
return session.IsTableExist(beanOrTableName)
}
// IDOf get id from one struct
func (engine *Engine) IDOf(bean interface{}) (schemas.PK, error) {
return engine.IDOfV(reflect.ValueOf(bean))
}
// TableName returns table name with schema prefix if has
func (engine *Engine) TableName(bean interface{}, includeSchema ...bool) string {
return dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean, includeSchema...)
}
// IDOfV get id from one value of struct
func (engine *Engine) IDOfV(rv reflect.Value) (schemas.PK, error) {
return engine.idOfV(rv)
}
func (engine *Engine) idOfV(rv reflect.Value) (schemas.PK, error) {
v := reflect.Indirect(rv)
table, err := engine.tagParser.ParseWithCache(v)
if err != nil {
return nil, err
}
pk := make([]interface{}, len(table.PrimaryKeys))
for i, col := range table.PKColumns() {
var err error
fieldName := col.FieldName
for {
parts := strings.SplitN(fieldName, ".", 2)
if len(parts) == 1 {
break
}
v = v.FieldByName(parts[0])
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
if v.Kind() != reflect.Struct {
return nil, ErrUnSupportedType
}
fieldName = parts[1]
}
pkField := v.FieldByName(fieldName)
switch pkField.Kind() {
case reflect.String:
pk[i], err = engine.idTypeAssertion(col, pkField.String())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
pk[i], err = engine.idTypeAssertion(col, strconv.FormatInt(pkField.Int(), 10))
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
// id of uint will be converted to int64
pk[i], err = engine.idTypeAssertion(col, strconv.FormatUint(pkField.Uint(), 10))
}
if err != nil {
return nil, err
}
}
return schemas.PK(pk), nil
}
func (engine *Engine) idTypeAssertion(col *schemas.Column, sid string) (interface{}, error) {
if col.SQLType.IsNumeric() {
n, err := strconv.ParseInt(sid, 10, 64)
if err != nil {
return nil, err
}
return n, nil
} else if col.SQLType.IsText() {
return sid, nil
} else {
return nil, errors.New("not supported")
}
}
// CreateIndexes create indexes
func (engine *Engine) CreateIndexes(bean interface{}) error {
session := engine.NewSession()

View File

@ -502,10 +502,48 @@ func TestFindAndCountOneFunc(t *testing.T) {
assert.EqualValues(t, 1, cnt)
results = make([]FindAndCountStruct, 0, 1)
cnt, err = testEngine.Where("msg = ?", true).Limit(1).FindAndCount(&results)
cnt, err = testEngine.Where("1=1").Limit(1).FindAndCount(&results)
assert.NoError(t, err)
assert.EqualValues(t, 1, len(results))
assert.EqualValues(t, 1, cnt)
assert.EqualValues(t, 2, cnt)
assert.EqualValues(t, FindAndCountStruct{
Id: 1,
Content: "111",
Msg: false,
}, results[0])
results = make([]FindAndCountStruct, 0, 1)
cnt, err = testEngine.Where("1=1").Limit(1).FindAndCount(&results)
assert.NoError(t, err)
assert.EqualValues(t, 1, len(results))
assert.EqualValues(t, 2, cnt)
assert.EqualValues(t, FindAndCountStruct{
Id: 1,
Content: "111",
Msg: false,
}, results[0])
results = make([]FindAndCountStruct, 0, 1)
cnt, err = testEngine.Where("1=1").Limit(1, 1).FindAndCount(&results)
assert.NoError(t, err)
assert.EqualValues(t, 1, len(results))
assert.EqualValues(t, 2, cnt)
assert.EqualValues(t, FindAndCountStruct{
Id: 2,
Content: "222",
Msg: true,
}, results[0])
results = make([]FindAndCountStruct, 0, 1)
cnt, err = testEngine.Where("1=1").Limit(1, 1).FindAndCount(&results)
assert.NoError(t, err)
assert.EqualValues(t, 1, len(results))
assert.EqualValues(t, 2, cnt)
assert.EqualValues(t, FindAndCountStruct{
Id: 2,
Content: "222",
Msg: true,
}, results[0])
results = make([]FindAndCountStruct, 0, 1)
cnt, err = testEngine.Where("msg = ?", true).Select("id, content, msg").

View File

@ -5,8 +5,10 @@
package schemas
import (
"errors"
"fmt"
"reflect"
"strconv"
"strings"
"time"
)
@ -115,3 +117,17 @@ func (col *Column) ValueOfV(dataStruct *reflect.Value) (*reflect.Value, error) {
return &fieldValue, nil
}
// ConvertID converts id content to suitable type according column type
func (col *Column) ConvertID(sid string) (interface{}, error) {
if col.SQLType.IsNumeric() {
n, err := strconv.ParseInt(sid, 10, 64)
if err != nil {
return nil, err
}
return n, nil
} else if col.SQLType.IsText() {
return sid, nil
}
return nil, errors.New("not supported")
}

View File

@ -5,7 +5,9 @@
package schemas
import (
"fmt"
"reflect"
"strconv"
"strings"
)
@ -28,6 +30,7 @@ type Table struct {
Comment string
}
// NewEmptyTable creates an empty table
func NewEmptyTable() *Table {
return NewTable("", nil)
}
@ -44,10 +47,12 @@ func NewTable(name string, t reflect.Type) *Table {
}
}
// Columns returns table's columns
func (table *Table) Columns() []*Column {
return table.columns
}
// ColumnsSeq returns table's column names according sequence
func (table *Table) ColumnsSeq() []string {
return table.columnsSeq
}
@ -61,6 +66,7 @@ func (table *Table) columnsByName(name string) []*Column {
return nil
}
// GetColumn returns column according column name, if column not found, return nil
func (table *Table) GetColumn(name string) *Column {
cols := table.columnsByName(name)
if cols != nil {
@ -70,6 +76,7 @@ func (table *Table) GetColumn(name string) *Column {
return nil
}
// GetColumnIdx returns column according name and idx
func (table *Table) GetColumnIdx(name string, idx int) *Column {
cols := table.columnsByName(name)
if cols != nil && idx < len(cols) {
@ -144,3 +151,45 @@ func (table *Table) AddColumn(col *Column) {
func (table *Table) AddIndex(index *Index) {
table.Indexes[index.Name] = index
}
// IDOfV get id from one value of struct
func (table *Table) IDOfV(rv reflect.Value) (PK, error) {
v := reflect.Indirect(rv)
pk := make([]interface{}, len(table.PrimaryKeys))
for i, col := range table.PKColumns() {
var err error
fieldName := col.FieldName
for {
parts := strings.SplitN(fieldName, ".", 2)
if len(parts) == 1 {
break
}
v = v.FieldByName(parts[0])
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
if v.Kind() != reflect.Struct {
return nil, fmt.Errorf("Unsupported read value of column %s from field %s", col.Name, col.FieldName)
}
fieldName = parts[1]
}
pkField := v.FieldByName(fieldName)
switch pkField.Kind() {
case reflect.String:
pk[i], err = col.ConvertID(pkField.String())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
pk[i], err = col.ConvertID(strconv.FormatInt(pkField.Int(), 10))
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
// id of uint will be converted to int64
pk[i], err = col.ConvertID(strconv.FormatUint(pkField.Uint(), 10))
}
if err != nil {
return nil, err
}
}
return PK(pk), nil
}

View File

@ -60,6 +60,12 @@ func (session *Session) FindAndCount(rowsSlicePtr interface{}, condiBean ...inte
if session.statement.OrderStr != "" {
session.statement.OrderStr = ""
}
if session.statement.LimitN != nil {
session.statement.LimitN = nil
}
if session.statement.Start > 0 {
session.statement.Start = 0
}
// session has stored the conditions so we use `unscoped` to avoid duplicated condition.
return session.Unscoped().Count(reflect.New(sliceElementType).Interface())
@ -320,7 +326,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
}
var pk schemas.PK = make([]interface{}, len(table.PrimaryKeys))
for i, col := range table.PKColumns() {
pk[i], err = session.engine.idTypeAssertion(col, res[i])
pk[i], err = col.ConvertID(res[i])
if err != nil {
return err
}
@ -370,7 +376,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
} else {
session.engine.logger.Debugf("[cache] cache hit bean: %v, %v, %v", tableName, id, bean)
pk, err := session.engine.IDOf(bean)
pk, err := table.IDOfV(reflect.ValueOf(bean))
if err != nil {
return err
}
@ -419,7 +425,6 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
if err != nil {
return err
}
session.statement = statement
vs := reflect.Indirect(reflect.ValueOf(beans))
@ -428,7 +433,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
if rv.Kind() != reflect.Ptr {
rv = rv.Addr()
}
id, err := session.engine.idOfV(rv)
id, err := table.IDOfV(rv)
if err != nil {
return err
}