Fix map with cols #1575

Merged
lunny merged 2 commits from lunny/fix_map_cols into master 2020-03-06 08:55:19 +00:00
5 changed files with 45 additions and 19 deletions

View File

@ -30,7 +30,15 @@ func (m columnMap) Contain(colName string) bool {
return false return false
} }
func (m *columnMap) add(colName string) bool { func (m columnMap) Len() int {
return len(m)
}
func (m columnMap) IsEmpty() bool {
return len(m) == 0
}
func (m *columnMap) Add(colName string) bool {
if m.Contain(colName) { if m.Contain(colName) {
return false return false
} }

View File

@ -389,7 +389,7 @@ func col2NewCols(columns ...string) []string {
func (statement *Statement) Cols(columns ...string) *Statement { func (statement *Statement) Cols(columns ...string) *Statement {
cols := col2NewCols(columns...) cols := col2NewCols(columns...)
for _, nc := range cols { for _, nc := range cols {
statement.ColumnMap.add(nc) statement.ColumnMap.Add(nc)
} }
return statement return statement
} }

View File

@ -65,13 +65,14 @@ func (session *Session) FindAndCount(rowsSlicePtr interface{}, condiBean ...inte
func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) error { func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) error {
defer session.resetStatement() defer session.resetStatement()
if session.statement.LastError != nil { if session.statement.LastError != nil {
return session.statement.LastError return session.statement.LastError
} }
sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
if sliceValue.Kind() != reflect.Slice && sliceValue.Kind() != reflect.Map { var isSlice = sliceValue.Kind() == reflect.Slice
var isMap = sliceValue.Kind() == reflect.Map
if !isSlice && !isMap {
return errors.New("needs a pointer to a slice or a map") return errors.New("needs a pointer to a slice or a map")
} }
@ -127,12 +128,18 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
} }
} }
if isMap && !session.statement.ColumnMap.IsEmpty() {
for _, k := range session.statement.RefTable.PrimaryKeys {
session.statement.ColumnMap.Add(k)
}
}
sqlStr, args, err := session.statement.GenFindSQL(autoCond) sqlStr, args, err := session.statement.GenFindSQL(autoCond)
if err != nil { if err != nil {
return err return err
} }
if session.canCache() { if session.statement.ColumnMap.IsEmpty() && session.canCache() {
if cacher := session.engine.GetCacher(session.statement.TableName()); cacher != nil && if cacher := session.engine.GetCacher(session.statement.TableName()); cacher != nil &&
!session.statement.IsDistinct && !session.statement.IsDistinct &&
!session.statement.GetUnscoped() { !session.statement.GetUnscoped() {

View File

@ -5,13 +5,13 @@
package xorm package xorm
import ( import (
"fmt"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert"
"xorm.io/xorm/internal/utils" "xorm.io/xorm/internal/utils"
"xorm.io/xorm/names" "xorm.io/xorm/names"
"github.com/stretchr/testify/assert"
) )
func TestJoinLimit(t *testing.T) { func TestJoinLimit(t *testing.T) {
@ -79,11 +79,9 @@ func TestWhere(t *testing.T) {
users := make([]Userinfo, 0) users := make([]Userinfo, 0)
err := testEngine.Where("id > ?", 2).Find(&users) err := testEngine.Where("id > ?", 2).Find(&users)
assert.NoError(t, err) assert.NoError(t, err)
fmt.Println(users)
err = testEngine.Where("id > ?", 2).And("id < ?", 10).Find(&users) err = testEngine.Where("id > ?", 2).And("id < ?", 10).Find(&users)
assert.NoError(t, err) assert.NoError(t, err)
fmt.Println(users)
} }
func TestFind(t *testing.T) { func TestFind(t *testing.T) {
@ -94,9 +92,6 @@ func TestFind(t *testing.T) {
err := testEngine.Find(&users) err := testEngine.Find(&users)
assert.NoError(t, err) assert.NoError(t, err)
for _, user := range users {
fmt.Println(user)
}
users2 := make([]Userinfo, 0) users2 := make([]Userinfo, 0)
var tbName = testEngine.Quote(testEngine.TableName(new(Userinfo), true)) var tbName = testEngine.Quote(testEngine.TableName(new(Userinfo), true))
@ -112,10 +107,6 @@ func TestFind2(t *testing.T) {
err := testEngine.Find(&users) err := testEngine.Find(&users)
assert.NoError(t, err) assert.NoError(t, err)
for _, user := range users {
fmt.Println(user)
}
} }
type Team struct { type Team struct {
@ -191,9 +182,29 @@ func TestFindMap(t *testing.T) {
assert.NoError(t, prepareEngine()) assert.NoError(t, prepareEngine())
assertSync(t, new(Userinfo)) assertSync(t, new(Userinfo))
users := make(map[int64]Userinfo) cnt, err := testEngine.Insert(&Userinfo{
err := testEngine.Find(&users) Username: "lunny",
Departname: "depart1",
IsMan: true,
})
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
users := make(map[int64]Userinfo)
err = testEngine.Find(&users)
assert.NoError(t, err)
assert.EqualValues(t, 1, len(users))
assert.EqualValues(t, "lunny", users[1].Username)
assert.EqualValues(t, "depart1", users[1].Departname)
assert.True(t, users[1].IsMan)
users = make(map[int64]Userinfo)
err = testEngine.Cols("username, departname").Find(&users)
assert.NoError(t, err)
assert.EqualValues(t, 1, len(users))
assert.EqualValues(t, "lunny", users[1].Username)
assert.EqualValues(t, "depart1", users[1].Departname)
assert.False(t, users[1].IsMan)
} }
func TestFindMap2(t *testing.T) { func TestFindMap2(t *testing.T) {

View File

@ -65,7 +65,7 @@ func (session *Session) get(bean interface{}) (bool, error) {
table := session.statement.RefTable table := session.statement.RefTable
if session.canCache() && beanValue.Elem().Kind() == reflect.Struct { if session.statement.ColumnMap.IsEmpty() && session.canCache() && beanValue.Elem().Kind() == reflect.Struct {
if cacher := session.engine.GetCacher(session.statement.TableName()); cacher != nil && if cacher := session.engine.GetCacher(session.statement.TableName()); cacher != nil &&
!session.statement.GetUnscoped() { !session.statement.GetUnscoped() {
has, err := session.cacheGet(bean, sqlStr, args...) has, err := session.cacheGet(bean, sqlStr, args...)