Make Get and Rows.Scan accept multiple parameters #2029

Merged
lunny merged 2 commits from lunny/get_vars into master 2021-08-24 07:42:35 +00:00
9 changed files with 209 additions and 72 deletions

View File

@ -161,6 +161,11 @@ has, err := engine.Table(&user).Where("name = ?", name).Cols("id").Get(&id)
has, err := engine.SQL("select id from user").Get(&id)
// SELECT id FROM user WHERE name = ?
var id int64
var name string
has, err := engine.Table(&user).Cols("id", "name").Get(&id, &name)
// SELECT id, name FROM user LIMIT 1
var valuesMap = make(map[string]string)
has, err := engine.Table(&user).Where("id = ?", id).Get(&valuesMap)
// SELECT * FROM user WHERE id = ?
@ -234,7 +239,11 @@ err := engine.BufferSize(100).Iterate(&User{Name:name}, func(idx int, bean inter
})
// SELECT * FROM user Limit 0, 100
// SELECT * FROM user Limit 101, 100
```
You can use rows which is similiar with `sql.Rows`
```Go
rows, err := engine.Rows(&User{Name:name})
// SELECT * FROM user
defer rows.Close()
@ -244,6 +253,19 @@ for rows.Next() {
}
```
or
```Go
rows, err := engine.Cols("name", "age").Rows(&User{Name:name})
// SELECT * FROM user
defer rows.Close()
for rows.Next() {
var name string
var age int
err = rows.Scan(&name, &age)
}
```
* `Update` update one or more records, default will update non-empty and non-zero fields except when you use Cols, AllCols and so on.
```Go

View File

@ -158,6 +158,11 @@ has, err := engine.Table(&user).Where("name = ?", name).Cols("id").Get(&id)
has, err := engine.SQL("select id from user").Get(&id)
// SELECT id FROM user WHERE name = ?
var id int64
var name string
has, err := engine.Table(&user).Cols("id", "name").Get(&id, &name)
// SELECT id, name FROM user LIMIT 1
var valuesMap = make(map[string]string)
has, err := engine.Table(&user).Where("id = ?", id).Get(&valuesMap)
// SELECT * FROM user WHERE id = ?
@ -231,7 +236,11 @@ err := engine.BufferSize(100).Iterate(&User{Name:name}, func(idx int, bean inter
})
// SELECT * FROM user Limit 0, 100
// SELECT * FROM user Limit 101, 100
```
Rows 的用法类似 `sql.Rows`
```Go
rows, err := engine.Rows(&User{Name:name})
// SELECT * FROM user
defer rows.Close()
@ -241,6 +250,19 @@ for rows.Next() {
}
```
或者
```Go
rows, err := engine.Cols("name", "age").Rows(&User{Name:name})
// SELECT * FROM user
defer rows.Close()
for rows.Next() {
var name string
var age int
err = rows.Scan(&name, &age)
}
```
* `Update` 更新数据除非使用Cols,AllCols函数指明默认只更新非空和非0的字段
```Go

16
doc.go
View File

@ -67,6 +67,11 @@ There are 8 major ORM methods and many helpful methods to use to operate databas
has, err := engine.Table("user").Where("name = ?", name).Get(&id)
// SELECT id FROM user WHERE name = ? LIMIT 1
var id int64
var name string
has, err := engine.Table(&user).Cols("id", "name").Get(&id, &name)
// SELECT id, name FROM user LIMIT 1
3. Query multiple records from database
var sliceOfStructs []Struct
@ -97,6 +102,17 @@ another is Rows
err = rows.Scan(bean)
}
or
rows, err := engine.Cols("name", "age").Rows(...)
// SELECT * FROM user
defer rows.Close()
for rows.Next() {
var name string
var age int
err = rows.Scan(&name, &age)
}
5. Update one or more records
affected, err := engine.ID(...).Update(&user)

View File

@ -1135,10 +1135,10 @@ func (engine *Engine) Delete(beans ...interface{}) (int64, error) {
// Get retrieve one record from table, bean's non-empty fields
// are conditions
func (engine *Engine) Get(bean interface{}) (bool, error) {
func (engine *Engine) Get(beans ...interface{}) (bool, error) {
session := engine.NewSession()
defer session.Close()
return session.Get(bean)
return session.Get(beans...)
}
// Exist returns true if the record exist otherwise return false

View File

@ -160,5 +160,49 @@ func TestRowsSpecTableName(t *testing.T) {
assert.NoError(t, err)
cnt++
}
assert.NoError(t, rows.Err())
assert.EqualValues(t, 1, cnt)
}
func TestRowsScanVars(t *testing.T) {
type RowsScanVars struct {
Id int64
Name string
Age int
}
assert.NoError(t, PrepareEngine())
assert.NoError(t, testEngine.Sync2(new(RowsScanVars)))
cnt, err := testEngine.Insert(&RowsScanVars{
Name: "xlw",
Age: 42,
}, &RowsScanVars{
Name: "xlw2",
Age: 24,
})
assert.NoError(t, err)
assert.EqualValues(t, 2, cnt)
rows, err := testEngine.Cols("name", "age").Rows(new(RowsScanVars))
assert.NoError(t, err)
defer rows.Close()
cnt = 0
for rows.Next() {
var name string
var age int
err = rows.Scan(&name, &age)
assert.NoError(t, err)
if cnt == 0 {
assert.EqualValues(t, "xlw", name)
assert.EqualValues(t, 42, age)
} else if cnt == 1 {
assert.EqualValues(t, "xlw2", name)
assert.EqualValues(t, 24, age)
}
cnt++
}
assert.NoError(t, rows.Err())
assert.EqualValues(t, 2, cnt)
}

View File

@ -890,3 +890,28 @@ func TestGetTime(t *testing.T) {
assert.True(t, has)
assert.EqualValues(t, gts.CreateTime.Format(time.RFC3339), gn.Format(time.RFC3339))
}
func TestGetVars(t *testing.T) {
type GetVars struct {
Id int64
Name string
Age int
}
assert.NoError(t, PrepareEngine())
assertSync(t, new(GetVars))
_, err := testEngine.Insert(&GetVars{
Name: "xlw",
Age: 42,
})
assert.NoError(t, err)
var name string
var age int
has, err := testEngine.Table(new(GetVars)).Cols("name", "age").Get(&name, &age)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, "xlw", name)
assert.EqualValues(t, 42, age)
}

View File

@ -37,7 +37,7 @@ type Interface interface {
Exist(bean ...interface{}) (bool, error)
Find(interface{}, ...interface{}) error
FindAndCount(interface{}, ...interface{}) (int64, error)
Get(interface{}) (bool, error)
Get(...interface{}) (bool, error)
GroupBy(keys string) *Session
ID(interface{}) *Session
In(string, ...interface{}) *Session

29
rows.go
View File

@ -11,7 +11,6 @@ import (
"xorm.io/builder"
"xorm.io/xorm/core"
"xorm.io/xorm/internal/utils"
)
// Rows rows wrapper a rows to
@ -93,17 +92,26 @@ func (rows *Rows) Err() error {
}
// Scan row record to bean properties
func (rows *Rows) Scan(bean interface{}) error {
func (rows *Rows) Scan(beans ...interface{}) error {
if rows.Err() != nil {
return rows.Err()
}
if reflect.Indirect(reflect.ValueOf(bean)).Type() != rows.beanType {
return fmt.Errorf("scan arg is incompatible type to [%v]", rows.beanType)
var bean = beans[0]
var tp = reflect.TypeOf(bean)
if tp.Kind() == reflect.Ptr {
tp = tp.Elem()
}
var beanKind = tp.Kind()
if err := rows.session.statement.SetRefBean(bean); err != nil {
return err
if len(beans) == 1 {
if reflect.Indirect(reflect.ValueOf(bean)).Type() != rows.beanType {
return fmt.Errorf("scan arg is incompatible type to [%v]", rows.beanType)
}
if err := rows.session.statement.SetRefBean(bean); err != nil {
return err
}
}
fields, err := rows.rows.Columns()
@ -115,14 +123,7 @@ func (rows *Rows) Scan(bean interface{}) error {
return err
}
scanResults, err := rows.session.row2Slice(rows.rows, fields, types, bean)
if err != nil {
return err
}
dataStruct := utils.ReflectValue(bean)
_, err = rows.session.slice2Bean(scanResults, fields, bean, &dataStruct, rows.session.statement.RefTable)
if err != nil {
if err := rows.session.scan(rows.rows, rows.session.statement.RefTable, beanKind, beans, types, fields); err != nil {
return err
}

View File

@ -28,11 +28,11 @@ var (
// Get retrieve one record from database, bean's non-empty fields
// will be as conditions
func (session *Session) Get(bean interface{}) (bool, error) {
func (session *Session) Get(beans ...interface{}) (bool, error) {
if session.isAutoClose {
defer session.Close()
}
return session.get(bean)
return session.get(beans...)
}
func isPtrOfTime(v interface{}) bool {
@ -48,14 +48,17 @@ func isPtrOfTime(v interface{}) bool {
return el.Type().ConvertibleTo(schemas.TimeType)
}
func (session *Session) get(bean interface{}) (bool, error) {
func (session *Session) get(beans ...interface{}) (bool, error) {
defer session.resetStatement()
if session.statement.LastError != nil {
return false, session.statement.LastError
}
if len(beans) == 0 {
return false, errors.New("needs at least one parameter for get")
}
beanValue := reflect.ValueOf(bean)
beanValue := reflect.ValueOf(beans[0])
if beanValue.Kind() != reflect.Ptr {
return false, errors.New("needs a pointer to a value")
} else if beanValue.Elem().Kind() == reflect.Ptr {
@ -64,8 +67,9 @@ func (session *Session) get(bean interface{}) (bool, error) {
return false, ErrObjectIsNil
}
if beanValue.Elem().Kind() == reflect.Struct && !isPtrOfTime(bean) {
if err := session.statement.SetRefBean(bean); err != nil {
var isStruct = beanValue.Elem().Kind() == reflect.Struct && !isPtrOfTime(beans[0])
if isStruct {
if err := session.statement.SetRefBean(beans[0]); err != nil {
return false, err
}
}
@ -79,7 +83,7 @@ func (session *Session) get(bean interface{}) (bool, error) {
return false, ErrTableNotFound
}
session.statement.Limit(1)
sqlStr, args, err = session.statement.GenGetSQL(bean)
sqlStr, args, err = session.statement.GenGetSQL(beans[0])
if err != nil {
return false, err
}
@ -90,10 +94,10 @@ func (session *Session) get(bean interface{}) (bool, error) {
table := session.statement.RefTable
if session.statement.ColumnMap.IsEmpty() && session.canCache() && beanValue.Elem().Kind() == reflect.Struct {
if session.statement.ColumnMap.IsEmpty() && session.canCache() && isStruct {
if cacher := session.engine.GetCacher(session.statement.TableName()); cacher != nil &&
!session.statement.GetUnscoped() {
has, err := session.cacheGet(bean, sqlStr, args...)
has, err := session.cacheGet(beans[0], sqlStr, args...)
if err != ErrCacheFailed {
return has, err
}
@ -101,12 +105,12 @@ func (session *Session) get(bean interface{}) (bool, error) {
}
context := session.statement.Context
if context != nil {
if context != nil && isStruct {
res := context.Get(fmt.Sprintf("%v-%v", sqlStr, args))
if res != nil {
session.engine.logger.Debugf("hit context cache: %s", sqlStr)
structValue := reflect.Indirect(reflect.ValueOf(bean))
structValue := reflect.Indirect(reflect.ValueOf(beans[0]))
structValue.Set(reflect.Indirect(reflect.ValueOf(res)))
session.lastSQL = ""
session.lastSQLArgs = nil
@ -114,13 +118,13 @@ func (session *Session) get(bean interface{}) (bool, error) {
}
}
has, err := session.nocacheGet(beanValue.Elem().Kind(), table, bean, sqlStr, args...)
has, err := session.nocacheGet(beanValue.Elem().Kind(), table, beans, sqlStr, args...)
if err != nil || !has {
return has, err
}
if context != nil {
context.Put(fmt.Sprintf("%v-%v", sqlStr, args), bean)
if context != nil && isStruct {
context.Put(fmt.Sprintf("%v-%v", sqlStr, args), beans[0])
}
return true, nil
@ -148,7 +152,7 @@ func isScannableStruct(bean interface{}, typeLen int) bool {
return true
}
func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, bean interface{}, sqlStr string, args ...interface{}) (bool, error) {
func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, beans []interface{}, sqlStr string, args ...interface{}) (bool, error) {
rows, err := session.queryRows(sqlStr, args...)
if err != nil {
return false, err
@ -168,27 +172,39 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table,
if err != nil {
return true, err
}
switch beanKind {
case reflect.Struct:
if !isScannableStruct(bean, len(types)) {
break
}
return session.getStruct(rows, types, fields, table, bean)
case reflect.Slice:
return session.getSlice(rows, types, fields, bean)
case reflect.Map:
return session.getMap(rows, types, fields, bean)
}
return session.getVars(rows, types, fields, bean)
return true, session.scan(rows, table, beanKind, beans, types, fields)
}
func (session *Session) getSlice(rows *core.Rows, types []*sql.ColumnType, fields []string, bean interface{}) (bool, error) {
func (session *Session) scan(rows *core.Rows, table *schemas.Table, firstBeanKind reflect.Kind, beans []interface{}, types []*sql.ColumnType, fields []string) error {
if len(beans) == 1 {
bean := beans[0]
switch firstBeanKind {
case reflect.Struct:
if !isScannableStruct(bean, len(types)) {
break
}
return session.getStruct(rows, types, fields, table, bean)
case reflect.Slice:
return session.getSlice(rows, types, fields, bean)
case reflect.Map:
return session.getMap(rows, types, fields, bean)
}
}
if len(beans) != len(types) {
return fmt.Errorf("expected columns %d, but only %d variables", len(types), len(beans))
}
return session.engine.scan(rows, fields, types, beans...)
}
func (session *Session) getSlice(rows *core.Rows, types []*sql.ColumnType, fields []string, bean interface{}) error {
switch t := bean.(type) {
case *[]string:
res, err := session.engine.scanStringInterface(rows, fields, types)
if err != nil {
return true, err
return err
}
var needAppend = len(*t) == 0 // both support slice is empty or has been initlized
@ -199,17 +215,17 @@ func (session *Session) getSlice(rows *core.Rows, types []*sql.ColumnType, field
(*t)[i] = r.(*sql.NullString).String
}
}
return true, nil
return nil
case *[]interface{}:
scanResults, err := session.engine.scanInterfaces(rows, fields, types)
if err != nil {
return true, err
return err
}
var needAppend = len(*t) == 0
for ii := range fields {
s, err := convert.Interface2Interface(session.engine.DatabaseTZ, scanResults[ii])
if err != nil {
return true, err
return err
}
if needAppend {
*t = append(*t, s)
@ -217,54 +233,45 @@ func (session *Session) getSlice(rows *core.Rows, types []*sql.ColumnType, field
(*t)[ii] = s
}
}
return true, nil
return nil
default:
return true, fmt.Errorf("unspoorted slice type: %t", t)
return fmt.Errorf("unspoorted slice type: %t", t)
}
}
func (session *Session) getMap(rows *core.Rows, types []*sql.ColumnType, fields []string, bean interface{}) (bool, error) {
func (session *Session) getMap(rows *core.Rows, types []*sql.ColumnType, fields []string, bean interface{}) error {
switch t := bean.(type) {
case *map[string]string:
scanResults, err := session.engine.scanStringInterface(rows, fields, types)
if err != nil {
return true, err
return err
}
for ii, key := range fields {
(*t)[key] = scanResults[ii].(*sql.NullString).String
}
return true, nil
return nil
case *map[string]interface{}:
scanResults, err := session.engine.scanInterfaces(rows, fields, types)
if err != nil {
return true, err
return err
}
for ii, key := range fields {
s, err := convert.Interface2Interface(session.engine.DatabaseTZ, scanResults[ii])
if err != nil {
return true, err
return err
}
(*t)[key] = s
}
return true, nil
return nil
default:
return true, fmt.Errorf("unspoorted map type: %t", t)
return fmt.Errorf("unspoorted map type: %t", t)
}
}
func (session *Session) getVars(rows *core.Rows, types []*sql.ColumnType, fields []string, beans ...interface{}) (bool, error) {
if len(beans) != len(types) {
return false, fmt.Errorf("expected columns %d, but only %d variables", len(types), len(beans))
}
err := session.engine.scan(rows, fields, types, beans...)
return true, err
}
func (session *Session) getStruct(rows *core.Rows, types []*sql.ColumnType, fields []string, table *schemas.Table, bean interface{}) (bool, error) {
func (session *Session) getStruct(rows *core.Rows, types []*sql.ColumnType, fields []string, table *schemas.Table, bean interface{}) error {
scanResults, err := session.row2Slice(rows, fields, types, bean)
if err != nil {
return false, err
return err
}
// close it before convert data
rows.Close()
@ -272,10 +279,10 @@ func (session *Session) getStruct(rows *core.Rows, types []*sql.ColumnType, fiel
dataStruct := utils.ReflectValue(bean)
_, err = session.slice2Bean(scanResults, fields, bean, &dataStruct, table)
if err != nil {
return true, err
return err
}
return true, session.executeProcessors()
return session.executeProcessors()
}
func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interface{}) (has bool, err error) {
@ -354,7 +361,7 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
cacheBean := cacher.GetBean(tableName, sid)
if cacheBean == nil {
cacheBean = bean
has, err = session.nocacheGet(reflect.Struct, table, cacheBean, sqlStr, args...)
has, err = session.nocacheGet(reflect.Struct, table, []interface{}{cacheBean}, sqlStr, args...)
if err != nil || !has {
return has, err
}