diff --git a/engine.go b/engine.go index 444611af..52ec1e3f 100644 --- a/engine.go +++ b/engine.go @@ -1453,6 +1453,13 @@ func (engine *Engine) Find(beans interface{}, condiBeans ...interface{}) error { return session.Find(beans, condiBeans...) } +// FindAndCount find the results and also return the counts +func (engine *Engine) FindAndCount(rowsSlicePtr interface{}, condiBean ...interface{}) (int64, error) { + session := engine.NewSession() + defer session.Close() + return session.FindAndCount(rowsSlicePtr, condiBean...) +} + // Iterate record by record handle records from table, bean's non-empty fields // are conditions. func (engine *Engine) Iterate(bean interface{}, fun IterFunc) error { diff --git a/interface.go b/interface.go index 9a3b6da0..85a46a27 100644 --- a/interface.go +++ b/interface.go @@ -30,6 +30,7 @@ type Interface interface { Exec(string, ...interface{}) (sql.Result, error) Exist(bean ...interface{}) (bool, error) Find(interface{}, ...interface{}) error + FindAndCount(interface{}, ...interface{}) (int64, error) Get(interface{}) (bool, error) GroupBy(keys string) *Session ID(interface{}) *Session diff --git a/session_find.go b/session_find.go index f95dcfef..68880d97 100644 --- a/session_find.go +++ b/session_find.go @@ -29,6 +29,29 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) return session.find(rowsSlicePtr, condiBean...) } +// FindAndCount find the results and also return the counts +func (session *Session) FindAndCount(rowsSlicePtr interface{}, condiBean ...interface{}) (int64, error) { + if session.isAutoClose { + defer session.Close() + } + + session.autoResetStatement = false + err := session.find(rowsSlicePtr, condiBean...) + if err != nil { + return 0, err + } + + sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) + if sliceValue.Kind() != reflect.Slice && sliceValue.Kind() != reflect.Map { + return 0, errors.New("needs a pointer to a slice or a map") + } + + sliceElementType := sliceValue.Type().Elem() + session.autoResetStatement = true + + return session.Count(reflect.New(sliceElementType).Interface()) +} + func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) error { sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) if sliceValue.Kind() != reflect.Slice && sliceValue.Kind() != reflect.Map { diff --git a/session_find_test.go b/session_find_test.go index 46816acc..b2e8f5a3 100644 --- a/session_find_test.go +++ b/session_find_test.go @@ -520,3 +520,39 @@ func TestFindMark(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 2, len(results)) } + +func TestFindAndCountOneFunc(t *testing.T) { + type FindAndCountStruct struct { + Id int64 + Content string + Msg bool `xorm:"bit"` + } + + assert.NoError(t, prepareEngine()) + assertSync(t, new(FindAndCountStruct)) + + cnt, err := testEngine.Insert([]FindAndCountStruct{ + { + Content: "111", + Msg: false, + }, + { + Content: "222", + Msg: true, + }, + }) + assert.NoError(t, err) + assert.EqualValues(t, 2, cnt) + + var results = make([]FindAndCountStruct, 0, 2) + cnt, err = testEngine.FindAndCount(&results) + assert.NoError(t, err) + assert.EqualValues(t, 2, len(results)) + assert.EqualValues(t, 2, cnt) + + results = make([]FindAndCountStruct, 0, 1) + cnt, err = testEngine.Where("msg = ?", true).FindAndCount(&results) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(results)) + assert.EqualValues(t, 1, cnt) +} \ No newline at end of file