diff --git a/internal/statements/column_map.go b/internal/statements/column_map.go index 8440f821..bb764b4e 100644 --- a/internal/statements/column_map.go +++ b/internal/statements/column_map.go @@ -30,7 +30,15 @@ func (m columnMap) Contain(colName string) bool { return false } -func (m *columnMap) add(colName string) bool { +func (m columnMap) Len() int { + return len(m) +} + +func (m columnMap) IsEmpty() bool { + return len(m) == 0 +} + +func (m *columnMap) Add(colName string) bool { if m.Contain(colName) { return false } diff --git a/internal/statements/statement.go b/internal/statements/statement.go index 90771d4b..a2a356ff 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -389,7 +389,7 @@ func col2NewCols(columns ...string) []string { func (statement *Statement) Cols(columns ...string) *Statement { cols := col2NewCols(columns...) for _, nc := range cols { - statement.ColumnMap.add(nc) + statement.ColumnMap.Add(nc) } return statement } diff --git a/session_find.go b/session_find.go index 9551b767..72882a28 100644 --- a/session_find.go +++ b/session_find.go @@ -65,13 +65,14 @@ func (session *Session) FindAndCount(rowsSlicePtr interface{}, condiBean ...inte func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) error { defer session.resetStatement() - if session.statement.LastError != nil { return session.statement.LastError } sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) - if sliceValue.Kind() != reflect.Slice && sliceValue.Kind() != reflect.Map { + var isSlice = sliceValue.Kind() == reflect.Slice + var isMap = sliceValue.Kind() == reflect.Map + if !isSlice && !isMap { return errors.New("needs a pointer to a slice or a map") } @@ -127,12 +128,18 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) } } + if isMap && !session.statement.ColumnMap.IsEmpty() { + for _, k := range session.statement.RefTable.PrimaryKeys { + session.statement.ColumnMap.Add(k) + } + } + sqlStr, args, err := session.statement.GenFindSQL(autoCond) if err != nil { return err } - if session.canCache() { + if session.statement.ColumnMap.IsEmpty() && session.canCache() { if cacher := session.engine.GetCacher(session.statement.TableName()); cacher != nil && !session.statement.IsDistinct && !session.statement.GetUnscoped() { diff --git a/session_find_test.go b/session_find_test.go index 2d15eed8..28ba1d60 100644 --- a/session_find_test.go +++ b/session_find_test.go @@ -5,13 +5,13 @@ package xorm import ( - "fmt" "testing" "time" - "github.com/stretchr/testify/assert" "xorm.io/xorm/internal/utils" "xorm.io/xorm/names" + + "github.com/stretchr/testify/assert" ) func TestJoinLimit(t *testing.T) { @@ -79,11 +79,9 @@ func TestWhere(t *testing.T) { users := make([]Userinfo, 0) err := testEngine.Where("id > ?", 2).Find(&users) assert.NoError(t, err) - fmt.Println(users) err = testEngine.Where("id > ?", 2).And("id < ?", 10).Find(&users) assert.NoError(t, err) - fmt.Println(users) } func TestFind(t *testing.T) { @@ -94,9 +92,6 @@ func TestFind(t *testing.T) { err := testEngine.Find(&users) assert.NoError(t, err) - for _, user := range users { - fmt.Println(user) - } users2 := make([]Userinfo, 0) var tbName = testEngine.Quote(testEngine.TableName(new(Userinfo), true)) @@ -112,10 +107,6 @@ func TestFind2(t *testing.T) { err := testEngine.Find(&users) assert.NoError(t, err) - - for _, user := range users { - fmt.Println(user) - } } type Team struct { @@ -191,9 +182,29 @@ func TestFindMap(t *testing.T) { assert.NoError(t, prepareEngine()) assertSync(t, new(Userinfo)) - users := make(map[int64]Userinfo) - err := testEngine.Find(&users) + cnt, err := testEngine.Insert(&Userinfo{ + Username: "lunny", + Departname: "depart1", + IsMan: true, + }) assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + users := make(map[int64]Userinfo) + err = testEngine.Find(&users) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(users)) + assert.EqualValues(t, "lunny", users[1].Username) + assert.EqualValues(t, "depart1", users[1].Departname) + assert.True(t, users[1].IsMan) + + users = make(map[int64]Userinfo) + err = testEngine.Cols("username, departname").Find(&users) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(users)) + assert.EqualValues(t, "lunny", users[1].Username) + assert.EqualValues(t, "depart1", users[1].Departname) + assert.False(t, users[1].IsMan) } func TestFindMap2(t *testing.T) { diff --git a/session_get.go b/session_get.go index c468b440..76918194 100644 --- a/session_get.go +++ b/session_get.go @@ -65,7 +65,7 @@ func (session *Session) get(bean interface{}) (bool, error) { table := session.statement.RefTable - if session.canCache() && beanValue.Elem().Kind() == reflect.Struct { + if session.statement.ColumnMap.IsEmpty() && session.canCache() && beanValue.Elem().Kind() == reflect.Struct { if cacher := session.engine.GetCacher(session.statement.TableName()); cacher != nil && !session.statement.GetUnscoped() { has, err := session.cacheGet(bean, sqlStr, args...)