refactor get #1967

Merged
lunny merged 7 commits from lunny/refactor_get into master 7 months ago
  1. 416
      convert.go
  2. 11
      dialects/driver.go
  3. 1
      dialects/mysql.go
  4. 6
      dialects/postgres.go
  5. 6
      dialects/sqlite3.go
  6. 197
      scan.go
  7. 2
      session_find.go
  8. 263
      session_get.go
  9. 6
      session_insert.go
  10. 2
      session_query.go

416
convert.go

@ -5,12 +5,15 @@ @@ -5,12 +5,15 @@
package xorm
import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"reflect"
"strconv"
"time"
"xorm.io/xorm/convert"
)
var errNilPtr = errors.New("destination pointer is nil") // embedded in descriptive error
@ -37,6 +40,12 @@ func asString(src interface{}) string { @@ -37,6 +40,12 @@ func asString(src interface{}) string {
return v
case []byte:
return string(v)
case *sql.NullString:
return v.String
case *sql.NullInt32:
return fmt.Sprintf("%d", v.Int32)
case *sql.NullInt64:
return fmt.Sprintf("%d", v.Int64)
}
rv := reflect.ValueOf(src)
switch rv.Kind() {
@ -54,6 +63,156 @@ func asString(src interface{}) string { @@ -54,6 +63,156 @@ func asString(src interface{}) string {
return fmt.Sprintf("%v", src)
}
func asInt64(src interface{}) (int64, error) {
switch v := src.(type) {
case int:
return int64(v), nil
case int16:
return int64(v), nil
case int32:
return int64(v), nil
case int8:
return int64(v), nil
case int64:
return v, nil
case uint:
return int64(v), nil
case uint8:
return int64(v), nil
case uint16:
return int64(v), nil
case uint32:
return int64(v), nil
case uint64:
return int64(v), nil
case []byte:
return strconv.ParseInt(string(v), 10, 64)
case string:
return strconv.ParseInt(v, 10, 64)
case *sql.NullString:
return strconv.ParseInt(v.String, 10, 64)
case *sql.NullInt32:
return int64(v.Int32), nil
case *sql.NullInt64:
return int64(v.Int64), nil
}
rv := reflect.ValueOf(src)
switch rv.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return rv.Int(), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return int64(rv.Uint()), nil
case reflect.Float64:
return int64(rv.Float()), nil
case reflect.Float32:
return int64(rv.Float()), nil
case reflect.String:
return strconv.ParseInt(rv.String(), 10, 64)
}
return 0, fmt.Errorf("unsupported value %T as int64", src)
}
func asUint64(src interface{}) (uint64, error) {
switch v := src.(type) {
case int:
return uint64(v), nil
case int16:
return uint64(v), nil
case int32:
return uint64(v), nil
case int8:
return uint64(v), nil
case int64:
return uint64(v), nil
case uint:
return uint64(v), nil
case uint8:
return uint64(v), nil
case uint16:
return uint64(v), nil
case uint32:
return uint64(v), nil
case uint64:
return v, nil
case []byte:
return strconv.ParseUint(string(v), 10, 64)
case string:
return strconv.ParseUint(v, 10, 64)
case *sql.NullString:
return strconv.ParseUint(v.String, 10, 64)
case *sql.NullInt32:
return uint64(v.Int32), nil
case *sql.NullInt64:
return uint64(v.Int64), nil
}
rv := reflect.ValueOf(src)
switch rv.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return uint64(rv.Int()), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return uint64(rv.Uint()), nil
case reflect.Float64:
return uint64(rv.Float()), nil
case reflect.Float32:
return uint64(rv.Float()), nil
case reflect.String:
return strconv.ParseUint(rv.String(), 10, 64)
}
return 0, fmt.Errorf("unsupported value %T as uint64", src)
}
func asFloat64(src interface{}) (float64, error) {
switch v := src.(type) {
case int:
return float64(v), nil
case int16:
return float64(v), nil
case int32:
return float64(v), nil
case int8:
return float64(v), nil
case int64:
return float64(v), nil
case uint:
return float64(v), nil
case uint8:
return float64(v), nil
case uint16:
return float64(v), nil
case uint32:
return float64(v), nil
case uint64:
return float64(v), nil
case []byte:
return strconv.ParseFloat(string(v), 64)
case string:
return strconv.ParseFloat(v, 64)
case *sql.NullString:
return strconv.ParseFloat(v.String, 64)
case *sql.NullInt32:
return float64(v.Int32), nil
case *sql.NullInt64:
return float64(v.Int64), nil
}
rv := reflect.ValueOf(src)
switch rv.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return float64(rv.Int()), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return float64(rv.Uint()), nil
case reflect.Float64:
return float64(rv.Float()), nil
case reflect.Float32:
return float64(rv.Float()), nil
case reflect.String:
return strconv.ParseFloat(rv.String(), 64)
}
return 0, fmt.Errorf("unsupported value %T as int64", src)
}
func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) {
switch rv.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
@ -76,7 +235,7 @@ func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) { @@ -76,7 +235,7 @@ func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) {
// convertAssign copies to dest the value in src, converting it if possible.
// An error is returned if the copy would result in loss of information.
// dest should be a pointer type.
func convertAssign(dest, src interface{}) error {
func convertAssign(dest, src interface{}, originalLocation *time.Location, convertedLocation *time.Location) error {
// Common cases, without reflect.
switch s := src.(type) {
case string:
@ -143,6 +302,163 @@ func convertAssign(dest, src interface{}) error { @@ -143,6 +302,163 @@ func convertAssign(dest, src interface{}) error {
*d = nil
return nil
}
case *sql.NullString:
switch d := dest.(type) {
case *int:
if s.Valid {
*d, _ = strconv.Atoi(s.String)
}
case *int64:
if s.Valid {
*d, _ = strconv.ParseInt(s.String, 10, 64)
}
case *string:
if s.Valid {
*d = s.String
}
return nil
case *time.Time:
if s.Valid {
var err error
dt, err := convert.String2Time(s.String, originalLocation, convertedLocation)
if err != nil {
return err
}
*d = *dt
}
return nil
case *sql.NullTime:
if s.Valid {
var err error
dt, err := convert.String2Time(s.String, originalLocation, convertedLocation)
if err != nil {
return err
}
d.Valid = true
d.Time = *dt
}
}
case *sql.NullInt32:
switch d := dest.(type) {
case *int:
if s.Valid {
*d = int(s.Int32)
}
return nil
case *int8:
if s.Valid {
*d = int8(s.Int32)
}
return nil
case *int16:
if s.Valid {
*d = int16(s.Int32)
}
return nil
case *int32:
if s.Valid {
*d = s.Int32
}
return nil
case *int64:
if s.Valid {
*d = int64(s.Int32)
}
return nil
}
case *sql.NullInt64:
switch d := dest.(type) {
case *int:
if s.Valid {
*d = int(s.Int64)
}
return nil
case *int8:
if s.Valid {
*d = int8(s.Int64)
}
return nil
case *int16:
if s.Valid {
*d = int16(s.Int64)
}
return nil
case *int32:
if s.Valid {
*d = int32(s.Int64)
}
return nil
case *int64:
if s.Valid {
*d = s.Int64
}
return nil
}
case *sql.NullFloat64:
switch d := dest.(type) {
case *int:
if s.Valid {
*d = int(s.Float64)
}
return nil
case *float64:
if s.Valid {
*d = s.Float64
}
return nil
}
case *sql.NullBool:
switch d := dest.(type) {
case *bool:
if s.Valid {
*d = s.Bool
}
return nil
}
case *sql.NullTime:
switch d := dest.(type) {
case *time.Time:
if s.Valid {
*d = s.Time
}
return nil
case *string:
if s.Valid {
*d = s.Time.In(convertedLocation).Format("2006-01-02 15:04:05")
}
return nil
}
case *NullUint32:
switch d := dest.(type) {
case *uint8:
if s.Valid {
*d = uint8(s.Uint32)
}
return nil
case *uint16:
if s.Valid {
*d = uint16(s.Uint32)
}
return nil
case *uint:
if s.Valid {
*d = uint(s.Uint32)
}
return nil
}
case *NullUint64:
switch d := dest.(type) {
case *uint64:
if s.Valid {
*d = s.Uint64
}
return nil
}
case *sql.RawBytes:
switch d := dest.(type) {
case convert.Conversion:
return d.FromDB(*s)
}
}
var sv reflect.Value
@ -175,10 +491,10 @@ func convertAssign(dest, src interface{}) error { @@ -175,10 +491,10 @@ func convertAssign(dest, src interface{}) error {
return nil
}
return convertAssignV(reflect.ValueOf(dest), src)
return convertAssignV(reflect.ValueOf(dest), src, originalLocation, convertedLocation)
}
func convertAssignV(dpv reflect.Value, src interface{}) error {
func convertAssignV(dpv reflect.Value, src interface{}, originalLocation, convertedLocation *time.Location) error {
if dpv.Kind() != reflect.Ptr {
return errors.New("destination not a pointer")
}
@ -212,31 +528,28 @@ func convertAssignV(dpv reflect.Value, src interface{}) error { @@ -212,31 +528,28 @@ func convertAssignV(dpv reflect.Value, src interface{}) error {
}
dv.Set(reflect.New(dv.Type().Elem()))
return convertAssign(dv.Interface(), src)
return convertAssign(dv.Interface(), src, originalLocation, convertedLocation)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
s := asString(src)
i64, err := strconv.ParseInt(s, 10, dv.Type().Bits())
i64, err := asInt64(src)
if err != nil {
err = strconvErr(err)
return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
return fmt.Errorf("converting driver.Value type %T to a %s: %v", src, dv.Kind(), err)
}
dv.SetInt(i64)
return nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
s := asString(src)
u64, err := strconv.ParseUint(s, 10, dv.Type().Bits())
u64, err := asUint64(src)
if err != nil {
err = strconvErr(err)
return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
return fmt.Errorf("converting driver.Value type %T to a %s: %v", src, dv.Kind(), err)
}
dv.SetUint(u64)
return nil
case reflect.Float32, reflect.Float64:
s := asString(src)
f64, err := strconv.ParseFloat(s, dv.Type().Bits())
f64, err := asFloat64(src)
if err != nil {
err = strconvErr(err)
return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
return fmt.Errorf("converting driver.Value type %T to a %s: %v", src, dv.Kind(), err)
}
dv.SetFloat(f64)
return nil
@ -376,3 +689,80 @@ func str2PK(s string, tp reflect.Type) (interface{}, error) { @@ -376,3 +689,80 @@ func str2PK(s string, tp reflect.Type) (interface{}, error) {
}
return v.Interface(), nil
}
var (
_ sql.Scanner = &NullUint64{}
)
// NullUint64 represents an uint64 that may be null.
// NullUint64 implements the Scanner interface so
// it can be used as a scan destination, similar to NullString.
type NullUint64 struct {
Uint64 uint64
Valid bool
}
// Scan implements the Scanner interface.
func (n *NullUint64) Scan(value interface{}) error {
if value == nil {
n.Uint64, n.Valid = 0, false
return nil
}
n.Valid = true
var err error
n.Uint64, err = asUint64(value)
return err
}
// Value implements the driver Valuer interface.
func (n NullUint64) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return n.Uint64, nil
}
var (
_ sql.Scanner = &NullUint32{}
)
// NullUint32 represents an uint32 that may be null.
// NullUint32 implements the Scanner interface so
// it can be used as a scan destination, similar to NullString.
type NullUint32 struct {
Uint32 uint32
Valid bool // Valid is true if Uint32 is not NULL
}
// Scan implements the Scanner interface.
func (n *NullUint32) Scan(value interface{}) error {
if value == nil {
n.Uint32, n.Valid = 0, false
return nil
}
n.Valid = true
i64, err := asUint64(value)
if err != nil {
return err
}
n.Uint32 = uint32(i64)
return nil
}
// Value implements the driver Valuer interface.
func (n NullUint32) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return int64(n.Uint32), nil
}
var (
_ sql.Scanner = &EmptyScanner{}
)
type EmptyScanner struct{}
func (EmptyScanner) Scan(value interface{}) error {
return nil
}

11
dialects/driver.go

@ -18,9 +18,14 @@ type ScanContext struct { @@ -18,9 +18,14 @@ type ScanContext struct {
UserLocation *time.Location
}
type DriverFeatures struct {
SupportNullable bool
}
// Driver represents a database driver
type Driver interface {
Parse(string, string) (*URI, error)
Features() DriverFeatures
GenScanResult(string) (interface{}, error) // according given column type generating a suitable scan interface
Scan(*ScanContext, *core.Rows, []*sql.ColumnType, ...interface{}) error
}
@ -77,3 +82,9 @@ type baseDriver struct{} @@ -77,3 +82,9 @@ type baseDriver struct{}
func (b *baseDriver) Scan(ctx *ScanContext, rows *core.Rows, types []*sql.ColumnType, v ...interface{}) error {
return rows.Scan(v...)
}
func (b *baseDriver) Features() DriverFeatures {
return DriverFeatures{
SupportNullable: true,
}
}

1
dialects/mysql.go

@ -633,6 +633,7 @@ func (db *mysql) Filters() []Filter { @@ -633,6 +633,7 @@ func (db *mysql) Filters() []Filter {
}
type mysqlDriver struct {
baseDriver
}
func (p *mysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) {

6
dialects/postgres.go

@ -1302,6 +1302,12 @@ type pqDriver struct { @@ -1302,6 +1302,12 @@ type pqDriver struct {
baseDriver
}
func (b *pqDriver) Features() DriverFeatures {
return DriverFeatures{
SupportNullable: false,
}
}
type values map[string]string
func (vs values) Set(k, v string) {

6
dialects/sqlite3.go

@ -576,3 +576,9 @@ func (p *sqlite3Driver) GenScanResult(colType string) (interface{}, error) { @@ -576,3 +576,9 @@ func (p *sqlite3Driver) GenScanResult(colType string) (interface{}, error) {
return &r, nil
}
}
func (b *sqlite3Driver) Features() DriverFeatures {
return DriverFeatures{
SupportNullable: false,
}
}

197
scan.go

@ -6,12 +6,120 @@ package xorm @@ -6,12 +6,120 @@ package xorm
import (
"database/sql"
"fmt"
"reflect"
"time"
"xorm.io/xorm/convert"
"xorm.io/xorm/core"
"xorm.io/xorm/dialects"
)
// genScanResultsByBeanNullabale generates scan result
func genScanResultsByBeanNullable(bean interface{}) (interface{}, bool, error) {
switch t := bean.(type) {
case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString, *sql.RawBytes:
return t, false, nil
case *time.Time:
return &sql.NullTime{}, true, nil
case *string:
return &sql.NullString{}, true, nil
case *int, *int8, *int16, *int32:
return &sql.NullInt32{}, true, nil
case *int64:
return &sql.NullInt64{}, true, nil
case *uint, *uint8, *uint16, *uint32:
return &NullUint32{}, true, nil
case *uint64:
return &NullUint64{}, true, nil
case *float32, *float64:
return &sql.NullFloat64{}, true, nil
case *bool:
return &sql.NullBool{}, true, nil
case sql.NullInt64, sql.NullBool, sql.NullFloat64, sql.NullString,
time.Time,
string,
int, int8, int16, int32, int64,
uint, uint8, uint16, uint32, uint64,
float32, float64,
bool:
return nil, false, fmt.Errorf("unsupported scan type: %t", t)
case convert.Conversion:
return &sql.RawBytes{}, true, nil
}
tp := reflect.TypeOf(bean).Elem()
switch tp.Kind() {
case reflect.String:
return &sql.NullString{}, true, nil
case reflect.Int64:
return &sql.NullInt64{}, true, nil
case reflect.Int32, reflect.Int, reflect.Int16, reflect.Int8:
return &sql.NullInt32{}, true, nil
case reflect.Uint64:
return &NullUint64{}, true, nil
case reflect.Uint32, reflect.Uint, reflect.Uint16, reflect.Uint8:
return &NullUint32{}, true, nil
default:
return nil, false, fmt.Errorf("unsupported type: %#v", bean)
}
}
func genScanResultsByBean(bean interface{}) (interface{}, bool, error) {
switch t := bean.(type) {
case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString,
*string,
*int, *int8, *int16, *int32, *int64,
*uint, *uint8, *uint16, *uint32, *uint64,
*float32, *float64,
*bool:
return t, false, nil
case *time.Time:
return &sql.NullTime{}, true, nil
case sql.NullInt64, sql.NullBool, sql.NullFloat64, sql.NullString,
time.Time,
string,
int, int8, int16, int32, int64,
uint, uint8, uint16, uint32, uint64,
bool:
return nil, false, fmt.Errorf("unsupported scan type: %t", t)
case convert.Conversion:
return &sql.RawBytes{}, true, nil
}
tp := reflect.TypeOf(bean).Elem()
switch tp.Kind() {
case reflect.String:
return new(string), true, nil
case reflect.Int64:
return new(int64), true, nil
case reflect.Int32:
return new(int32), true, nil
case reflect.Int:
return new(int32), true, nil
case reflect.Int16:
return new(int32), true, nil
case reflect.Int8:
return new(int32), true, nil
case reflect.Uint64:
return new(uint64), true, nil
case reflect.Uint32:
return new(uint32), true, nil
case reflect.Uint:
return new(uint), true, nil
case reflect.Uint16:
return new(uint16), true, nil
case reflect.Uint8:
return new(uint8), true, nil
case reflect.Float32:
return new(float32), true, nil
case reflect.Float64:
return new(float64), true, nil
default:
return nil, false, fmt.Errorf("unsupported type: %#v", bean)
}
}
func row2mapStr(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[string]string, error) {
var scanResults = make([]interface{}, len(fields))
for i := 0; i < len(fields); i++ {
@ -50,18 +158,97 @@ func row2mapBytes(rows *core.Rows, types []*sql.ColumnType, fields []string) (ma @@ -50,18 +158,97 @@ func row2mapBytes(rows *core.Rows, types []*sql.ColumnType, fields []string) (ma
return result, nil
}
func row2sliceStr(rows *core.Rows, types []*sql.ColumnType, fields []string) ([]string, error) {
results := make([]string, 0, len(fields))
var scanResults = make([]interface{}, len(fields))
for i := 0; i < len(fields); i++ {
func (engine *Engine) scanStringInterface(rows *core.Rows, types []*sql.ColumnType) ([]interface{}, error) {
var scanResults = make([]interface{}, len(types))
for i := 0; i < len(types); i++ {
var s sql.NullString
scanResults[i] = &s
}
if err := rows.Scan(scanResults...); err != nil {
if err := engine.driver.Scan(&dialects.ScanContext{
DBLocation: engine.DatabaseTZ,
UserLocation: engine.TZLocation,
}, rows, types, scanResults...); err != nil {
return nil, err
}
return scanResults, nil
}
// scan is a wrap of driver.Scan but will automatically change the input values according requirements
func (engine *Engine) scan(rows *core.Rows, fields []string, types []*sql.ColumnType, vv ...interface{}) error {
var scanResults = make([]interface{}, 0, len(types))
var replaces = make([]bool, 0, len(types))
var err error
for _, v := range vv {
var replaced bool
var scanResult interface{}
if _, ok := v.(sql.Scanner); !ok {
var useNullable = true
if engine.driver.Features().SupportNullable {
nullable, ok := types[0].Nullable()
useNullable = ok && nullable
}
if useNullable {
scanResult, replaced, err = genScanResultsByBeanNullable(v)
} else {
scanResult, replaced, err = genScanResultsByBean(v)
}
if err != nil {
return err
}
} else {
scanResult = v
}
scanResults = append(scanResults, scanResult)
replaces = append(replaces, replaced)
}
var scanCtx = dialects.ScanContext{
DBLocation: engine.DatabaseTZ,
UserLocation: engine.TZLocation,
}
if err = engine.driver.Scan(&scanCtx, rows, types, scanResults...); err != nil {
return err
}
for i, replaced := range replaces {
if replaced {
if err = convertAssign(vv[i], scanResults[i], scanCtx.DBLocation, engine.TZLocation); err != nil {
return err
}
}
}
return nil
}
func (engine *Engine) scanInterfaces(rows *core.Rows, types []*sql.ColumnType) ([]interface{}, error) {
var scanResultContainers = make([]interface{}, len(types))
for i := 0; i < len(types); i++ {
scanResult, err := engine.driver.GenScanResult(types[i].DatabaseTypeName())
if err != nil {
return nil, err
}
scanResultContainers[i] = scanResult
}
if err := engine.driver.Scan(&dialects.ScanContext{
DBLocation: engine.DatabaseTZ,
UserLocation: engine.TZLocation,
}, rows, types, scanResultContainers...); err != nil {
return nil, err
}
return scanResultContainers, nil
}
func (engine *Engine) row2sliceStr(rows *core.Rows, types []*sql.ColumnType, fields []string) ([]string, error) {
scanResults, err := engine.scanStringInterface(rows, types)
if err != nil {
return nil, err
}
var results = make([]string, 0, len(fields))
for i := 0; i < len(fields); i++ {
results = append(results, scanResults[i].(*sql.NullString).String)
}

2
session_find.go

@ -276,7 +276,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect @@ -276,7 +276,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect
func convertPKToValue(table *schemas.Table, dst interface{}, pk schemas.PK) error {
cols := table.PKColumns()
if len(cols) == 1 {
return convertAssign(dst, pk[0])
return convertAssign(dst, pk[0], nil, nil)
}
dst = pk

263
session_get.go

@ -6,12 +6,16 @@ package xorm @@ -6,12 +6,16 @@ package xorm
import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"reflect"
"strconv"
"time"
"xorm.io/xorm/caches"
"xorm.io/xorm/convert"
"xorm.io/xorm/core"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas"
)
@ -108,6 +112,17 @@ func (session *Session) get(bean interface{}) (bool, error) { @@ -108,6 +112,17 @@ func (session *Session) get(bean interface{}) (bool, error) {
return true, nil
}
var (
valuerTypePlaceHolder driver.Valuer
valuerType = reflect.TypeOf(&valuerTypePlaceHolder).Elem()
scannerTypePlaceHolder sql.Scanner
scannerType = reflect.TypeOf(&scannerTypePlaceHolder).Elem()
conversionTypePlaceHolder convert.Conversion
conversionType = reflect.TypeOf(&conversionTypePlaceHolder).Elem()
)
func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, bean interface{}, sqlStr string, args ...interface{}) (bool, error) {
rows, err := session.queryRows(sqlStr, args...)
if err != nil {
@ -122,155 +137,161 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, @@ -122,155 +137,161 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table,
return false, nil
}
switch bean.(type) {
case sql.NullInt64, sql.NullBool, sql.NullFloat64, sql.NullString:
return true, rows.Scan(&bean)
case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString:
return true, rows.Scan(bean)
case *string:
var res sql.NullString
if err := rows.Scan(&res); err != nil {
return true, err
}
if res.Valid {
*(bean.(*string)) = res.String
}
return true, nil
case *int:
var res sql.NullInt64
if err := rows.Scan(&res); err != nil {
return true, err
}
if res.Valid {
*(bean.(*int)) = int(res.Int64)
}
return true, nil
case *int8:
var res sql.NullInt64
if err := rows.Scan(&res); err != nil {
return true, err
}
if res.Valid {
*(bean.(*int8)) = int8(res.Int64)
}
return true, nil
case *int16:
var res sql.NullInt64
if err := rows.Scan(&res); err != nil {
return true, err
}
if res.Valid {
*(bean.(*int16)) = int16(res.Int64)
}
return true, nil
case *int32:
var res sql.NullInt64
if err := rows.Scan(&res); err != nil {
return true, err
}
if res.Valid {
*(bean.(*int32)) = int32(res.Int64)
}
return true, nil
case *int64:
var res sql.NullInt64
if err := rows.Scan(&res); err != nil {
return true, err
}
if res.Valid {
*(bean.(*int64)) = int64(res.Int64)
}
return true, nil
case *uint:
var res sql.NullInt64
if err := rows.Scan(&res); err != nil {
return true, err
}
if res.Valid {
*(bean.(*uint)) = uint(res.Int64)
// WARN: Alougth rows return true, but we may also return error.
types, err := rows.ColumnTypes()
if err != nil {
return true, err
}
fields, err := rows.Columns()
if err != nil {
return true, err
}
switch beanKind {
case reflect.Struct:
if _, ok := bean.(*time.Time); ok {
break
}
return true, nil
case *uint8:
var res sql.NullInt64
if err := rows.Scan(&res); err != nil {
return true, err
if _, ok := bean.(sql.Scanner); ok {
break
}
if res.Valid {
*(bean.(*uint8)) = uint8(res.Int64)
if _, ok := bean.(convert.Conversion); len(types) == 1 && ok {
break
}
return true, nil
case *uint16:
var res sql.NullInt64
if err := rows.Scan(&res); err != nil {
return session.getStruct(rows, types, fields, table, bean)
case reflect.Slice:
return session.getSlice(rows, types, fields, bean)
case reflect.Map:
return session.getMap(rows, types, fields, bean)
}
return session.getVars(rows, types, fields, bean)
}
func (session *Session) getSlice(rows *core.Rows, types []*sql.ColumnType, fields []string, bean interface{}) (bool, error) {
switch t := bean.(type) {
case *[]string:
res, err := session.engine.scanStringInterface(rows, types)
if err != nil {
return true, err
}
if res.Valid {
*(bean.(*uint16)) = uint16(res.Int64)
var needAppend = len(*t) == 0 // both support slice is empty or has been initlized
for i, r := range res {
if needAppend {
*t = append(*t, r.(*sql.NullString).String)
} else {
(*t)[i] = r.(*sql.NullString).String
}
}
return true, nil
case *uint32:
var res sql.NullInt64
if err := rows.Scan(&res); err != nil {
case *[]interface{}:
scanResults, err := session.engine.scanInterfaces(rows, types)
if err != nil {
return true, err
}
if res.Valid {
*(bean.(*uint32)) = uint32(res.Int64)
var needAppend = len(*t) == 0
for ii := range fields {
s, err := convert.Interface2Interface(session.engine.DatabaseTZ, scanResults[ii])
if err != nil {
return true, err
}
if needAppend {
*t = append(*t, s)
} else {
(*t)[ii] = s
}
}
return true, nil
case *uint64:
var res sql.NullInt64
if err := rows.Scan(&res); err != nil {
default:
return true, fmt.Errorf("unspoorted slice type: %t", t)
}
}
func (session *Session) getMap(rows *core.Rows, types []*sql.ColumnType, fields []string, bean interface{}) (bool, error) {
switch t := bean.(type) {
case *map[string]string:
scanResults, err := session.engine.scanStringInterface(rows, types)
if err != nil {
return true, err
}
if res.Valid {
*(bean.(*uint64)) = uint64(res.Int64)
for ii, key := range fields {
(*t)[key] = scanResults[ii].(*sql.NullString).String
}
return true, nil
case *bool:
var res sql.NullBool
if err := rows.Scan(&res); err != nil {
case *map[string]interface{}:
scanResults, err := session.engine.scanInterfaces(rows, types)
if err != nil {
return true, err
}
if res.Valid {
*(bean.(*bool)) = res.Bool
for ii, key := range fields {
s, err := convert.Interface2Interface(session.engine.DatabaseTZ, scanResults[ii])
if err != nil {
return true, err
}
(*t)[key] = s
}
return true, nil
default:
return true, fmt.Errorf("unspoorted map type: %t", t)
}
}
switch beanKind {
case reflect.Struct:
fields, err := rows.Columns()
if err != nil {
// WARN: Alougth rows return true, but get fields failed
return true, err
func (session *Session) getVars(rows *core.Rows, types []*sql.ColumnType, fields []string, beans ...interface{}) (bool, error) {
if len(beans) != len(types) {
return false, fmt.Errorf("expected columns %d, but only %d variables", len(types), len(beans))
}
var scanResults = make([]interface{}, 0, len(types))
var replaceds = make([]bool, 0, len(types))
for _, bean := range beans {
switch t := bean.(type) {
case sql.Scanner:
scanResults = append(scanResults, t)
replaceds = append(replaceds, false)
case convert.Conversion:
scanResults = append(scanResults, &sql.RawBytes{})
replaceds = append(replaceds, true)
default:
scanResults = append(scanResults, bean)
replaceds = append(replaceds, false)
}
}
scanResults, err := session.row2Slice(rows, fields, bean)
if err != nil {
return false, err
err := session.engine.scan(rows, fields, types, scanResults...)
if err != nil {
return true, err
}
for i, replaced := range replaceds {
if replaced {
err = convertAssign(beans[i], scanResults[i], session.engine.DatabaseTZ, session.engine.TZLocation)
if err != nil {
return true, err
}
}
// close it before convert data
rows.Close()
}
return true, nil
}
dataStruct := utils.ReflectValue(bean)
_, err = session.slice2Bean(scanResults, fields, bean, &dataStruct, table)
if err != nil {
return true, err
}
func (session *Session) getStruct(rows *core.Rows, types []*sql.ColumnType, fields []string, table *schemas.Table, bean interface{}) (bool, error) {
fields, err := rows.Columns()
if err != nil {
// WARN: Alougth rows return true, but get fields failed
return true, err
}
return true, session.executeProcessors()
case reflect.Slice:
err = rows.ScanSlice(bean)
case reflect.Map:
err = rows.ScanMap(bean)
case reflect.String, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
err = rows.Scan(bean)
default:
err = rows.Scan(bean)
scanResults, err := session.row2Slice(rows, fields, bean)
if err != nil {
return false, err
}
// close it before convert data
rows.Close()
dataStruct := utils.ReflectValue(bean)
_, err = session.slice2Bean(scanResults, fields, bean, &dataStruct, table)
if err != nil {
return true, err
}
return true, err
return true, session.executeProcessors()
}
func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interface{}) (has bool, err error) {

6
session_insert.go

@ -375,7 +375,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { @@ -375,7 +375,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
return 1, nil
}
return 1, convertAssignV(aiValue.Addr(), id)
return 1, convertAssignV(aiValue.Addr(), id, session.engine.DatabaseTZ, session.engine.TZLocation)
} else if len(table.AutoIncrement) > 0 && (session.engine.dialect.URI().DBType == schemas.POSTGRES ||
session.engine.dialect.URI().DBType == schemas.MSSQL) {
res, err := session.queryBytes(sqlStr, args...)
@ -415,7 +415,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { @@ -415,7 +415,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
return 1, nil
}
return 1, convertAssignV(aiValue.Addr(), id)
return 1, convertAssignV(aiValue.Addr(), id, session.engine.DatabaseTZ, session.engine.TZLocation)
}
res, err := session.exec(sqlStr, args...)
@ -455,7 +455,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { @@ -455,7 +455,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
return res.RowsAffected()
}
if err := convertAssignV(aiValue.Addr(), id); err != nil {
if err := convertAssignV(aiValue.Addr(), id, session.engine.DatabaseTZ, session.engine.TZLocation); err != nil {
return 0, err
}

2
session_query.go

@ -54,7 +54,7 @@ func (session *Session) rows2SliceString(rows *core.Rows) (resultsSlice [][]stri @@ -54,7 +54,7 @@ func (session *Session) rows2SliceString(rows *core.Rows) (resultsSlice [][]stri
}
for rows.Next() {
record, err := row2sliceStr(rows, types, fields)
record, err := session.engine.row2sliceStr(rows, types, fields)
if err != nil {
return nil, err
}

Loading…
Cancel
Save