From d90767bcb797f9c12a0d99b7d287fa9ec387ea93 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Sun, 4 Jul 2021 23:49:59 +0800 Subject: [PATCH 1/7] refactor get --- convert.go | 244 +++++++++++++++++++++++++++++++++++- dialects/driver.go | 11 ++ dialects/mysql.go | 1 + dialects/sqlite3.go | 6 + scan.go | 198 ++++++++++++++++++++++++++++- session_find.go | 2 +- session_get.go | 299 ++++++++++++++++++++++++-------------------- session_insert.go | 6 +- session_query.go | 2 +- 9 files changed, 616 insertions(+), 153 deletions(-) diff --git a/convert.go b/convert.go index b7f30cad..c4774d97 100644 --- a/convert.go +++ b/convert.go @@ -5,12 +5,15 @@ package xorm import ( + "database/sql" "database/sql/driver" "errors" "fmt" "reflect" "strconv" "time" + + "xorm.io/xorm/convert" ) var errNilPtr = errors.New("destination pointer is nil") // embedded in descriptive error @@ -76,7 +79,7 @@ func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) { // convertAssign copies to dest the value in src, converting it if possible. // An error is returned if the copy would result in loss of information. // dest should be a pointer type. -func convertAssign(dest, src interface{}) error { +func convertAssign(dest, src interface{}, originalLocation *time.Location, convertedLocation *time.Location) error { // Common cases, without reflect. switch s := src.(type) { case string: @@ -143,6 +146,163 @@ func convertAssign(dest, src interface{}) error { *d = nil return nil } + case *sql.NullString: + switch d := dest.(type) { + case *int: + if s.Valid { + *d, _ = strconv.Atoi(s.String) + } + case *int64: + if s.Valid { + *d, _ = strconv.ParseInt(s.String, 10, 64) + } + case *string: + if s.Valid { + *d = s.String + } + return nil + case *time.Time: + if s.Valid { + var err error + dt, err := convert.String2Time(s.String, originalLocation, convertedLocation) + if err != nil { + return err + } + *d = *dt + } + return nil + case *sql.NullTime: + if s.Valid { + var err error + dt, err := convert.String2Time(s.String, originalLocation, convertedLocation) + if err != nil { + return err + } + d.Valid = true + d.Time = *dt + } + } + case *sql.NullInt32: + switch d := dest.(type) { + case *int: + if s.Valid { + *d = int(s.Int32) + } + return nil + case *int8: + if s.Valid { + *d = int8(s.Int32) + } + return nil + case *int16: + if s.Valid { + *d = int16(s.Int32) + } + return nil + case *int32: + if s.Valid { + *d = s.Int32 + } + return nil + case *int64: + if s.Valid { + *d = int64(s.Int32) + } + return nil + } + case *sql.NullInt64: + switch d := dest.(type) { + case *int: + if s.Valid { + *d = int(s.Int64) + } + return nil + case *int8: + if s.Valid { + *d = int8(s.Int64) + } + return nil + case *int16: + if s.Valid { + *d = int16(s.Int64) + } + return nil + case *int32: + if s.Valid { + *d = int32(s.Int64) + } + return nil + case *int64: + if s.Valid { + *d = s.Int64 + } + return nil + } + case *sql.NullFloat64: + switch d := dest.(type) { + case *int: + if s.Valid { + *d = int(s.Float64) + } + return nil + case *float64: + if s.Valid { + *d = s.Float64 + } + return nil + } + case *sql.NullBool: + switch d := dest.(type) { + case *bool: + if s.Valid { + *d = s.Bool + } + return nil + } + case *sql.NullTime: + switch d := dest.(type) { + case *time.Time: + if s.Valid { + *d = s.Time + } + return nil + case *string: + if s.Valid { + *d = s.Time.In(convertedLocation).Format("2006-01-02 15:04:05") + } + return nil + } + case *NullUint32: + switch d := dest.(type) { + case *uint8: + if s.Valid { + *d = uint8(s.Uint32) + } + return nil + case *uint16: + if s.Valid { + *d = uint16(s.Uint32) + } + return nil + case *uint: + if s.Valid { + *d = uint(s.Uint32) + } + return nil + } + case *NullUint64: + switch d := dest.(type) { + case *uint64: + if s.Valid { + *d = s.Uint64 + } + return nil + } + case *sql.RawBytes: + switch d := dest.(type) { + case convert.Conversion: + return d.FromDB(*s) + } } var sv reflect.Value @@ -175,10 +335,10 @@ func convertAssign(dest, src interface{}) error { return nil } - return convertAssignV(reflect.ValueOf(dest), src) + return convertAssignV(reflect.ValueOf(dest), src, originalLocation, convertedLocation) } -func convertAssignV(dpv reflect.Value, src interface{}) error { +func convertAssignV(dpv reflect.Value, src interface{}, originalLocation, convertedLocation *time.Location) error { if dpv.Kind() != reflect.Ptr { return errors.New("destination not a pointer") } @@ -212,7 +372,7 @@ func convertAssignV(dpv reflect.Value, src interface{}) error { } dv.Set(reflect.New(dv.Type().Elem())) - return convertAssign(dv.Interface(), src) + return convertAssign(dv.Interface(), src, originalLocation, convertedLocation) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: s := asString(src) i64, err := strconv.ParseInt(s, 10, dv.Type().Bits()) @@ -376,3 +536,79 @@ func str2PK(s string, tp reflect.Type) (interface{}, error) { } return v.Interface(), nil } + +var ( + _ sql.Scanner = &NullUint64{} +) + +// NullUint64 represents an uint64 that may be null. +// NullUint64 implements the Scanner interface so +// it can be used as a scan destination, similar to NullString. +type NullUint64 struct { + Uint64 uint64 + Valid bool // Valid is true if Uint64 is not NULL + OriginalLocation *time.Location + ConvertedLocation *time.Location +} + +// Scan implements the Scanner interface. +func (n *NullUint64) Scan(value interface{}) error { + if value == nil { + n.Uint64, n.Valid = 0, false + return nil + } + n.Valid = true + fmt.Println("======44444") + return convertAssign(&n.Uint64, value, n.OriginalLocation, n.ConvertedLocation) +} + +// Value implements the driver Valuer interface. +func (n NullUint64) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Uint64, nil +} + +var ( + _ sql.Scanner = &NullUint32{} +) + +// NullUint32 represents an uint32 that may be null. +// NullUint32 implements the Scanner interface so +// it can be used as a scan destination, similar to NullString. +type NullUint32 struct { + Uint32 uint32 + Valid bool // Valid is true if Uint32 is not NULL + OriginalLocation *time.Location + ConvertedLocation *time.Location +} + +// Scan implements the Scanner interface. +func (n *NullUint32) Scan(value interface{}) error { + if value == nil { + n.Uint32, n.Valid = 0, false + return nil + } + n.Valid = true + fmt.Println("555555") + return convertAssign(&n.Uint32, value, n.OriginalLocation, n.ConvertedLocation) +} + +// Value implements the driver Valuer interface. +func (n NullUint32) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return int64(n.Uint32), nil +} + +var ( + _ sql.Scanner = &EmptyScanner{} +) + +type EmptyScanner struct{} + +func (EmptyScanner) Scan(value interface{}) error { + return nil +} diff --git a/dialects/driver.go b/dialects/driver.go index c511b665..0b6187d3 100644 --- a/dialects/driver.go +++ b/dialects/driver.go @@ -18,9 +18,14 @@ type ScanContext struct { UserLocation *time.Location } +type DriverFeatures struct { + SupportNullable bool +} + // Driver represents a database driver type Driver interface { Parse(string, string) (*URI, error) + Features() DriverFeatures GenScanResult(string) (interface{}, error) // according given column type generating a suitable scan interface Scan(*ScanContext, *core.Rows, []*sql.ColumnType, ...interface{}) error } @@ -77,3 +82,9 @@ type baseDriver struct{} func (b *baseDriver) Scan(ctx *ScanContext, rows *core.Rows, types []*sql.ColumnType, v ...interface{}) error { return rows.Scan(v...) } + +func (b *baseDriver) Features() DriverFeatures { + return DriverFeatures{ + SupportNullable: true, + } +} diff --git a/dialects/mysql.go b/dialects/mysql.go index 03bc9a4b..a341ce05 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -633,6 +633,7 @@ func (db *mysql) Filters() []Filter { } type mysqlDriver struct { + baseDriver } func (p *mysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) { diff --git a/dialects/sqlite3.go b/dialects/sqlite3.go index 306f377c..1bc0b218 100644 --- a/dialects/sqlite3.go +++ b/dialects/sqlite3.go @@ -576,3 +576,9 @@ func (p *sqlite3Driver) GenScanResult(colType string) (interface{}, error) { return &r, nil } } + +func (b *sqlite3Driver) Features() DriverFeatures { + return DriverFeatures{ + SupportNullable: false, + } +} diff --git a/scan.go b/scan.go index e11d6e8d..b23785d8 100644 --- a/scan.go +++ b/scan.go @@ -6,12 +6,121 @@ package xorm import ( "database/sql" + "fmt" + "reflect" + "time" "xorm.io/xorm/convert" "xorm.io/xorm/core" "xorm.io/xorm/dialects" ) +// genScanResultsByBeanNullabale generates scan result +func genScanResultsByBeanNullable(bean interface{}, originalLocation, convertedLocation *time.Location) (interface{}, bool, error) { + switch t := bean.(type) { + case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString, *sql.RawBytes: + return t, false, nil + case *time.Time: + return &sql.NullTime{}, true, nil + case *string: + return &sql.NullString{}, true, nil + case *int, *int8, *int16, *int32: + return &sql.NullInt32{}, true, nil + case *int64: + return &sql.NullInt64{}, true, nil + case *uint, *uint8, *uint16, *uint32: + return &NullUint32{ + OriginalLocation: originalLocation, + ConvertedLocation: convertedLocation, + }, true, nil + case *uint64: + return &NullUint64{ + OriginalLocation: originalLocation, + ConvertedLocation: convertedLocation, + }, true, nil + case *float32, *float64: + return &sql.NullFloat64{}, true, nil + case *bool: + return &sql.NullBool{}, true, nil + case sql.NullInt64, sql.NullBool, sql.NullFloat64, sql.NullString, + time.Time, + string, + int, int8, int16, int32, int64, + uint, uint8, uint16, uint32, uint64, + float32, float64, + bool: + return nil, false, fmt.Errorf("unsupported scan type: %t", t) + case convert.Conversion: + return &sql.RawBytes{}, true, nil + } + + tp := reflect.TypeOf(bean).Elem() + switch tp.Kind() { + case reflect.String: + return &sql.NullString{}, true, nil + case reflect.Int64: + return &sql.NullInt64{}, true, nil + case reflect.Int32, reflect.Int, reflect.Int16, reflect.Int8: + return &sql.NullInt32{}, true, nil + case reflect.Uint64: + return &NullUint64{}, true, nil + case reflect.Uint32, reflect.Uint, reflect.Uint16, reflect.Uint8: + return &NullUint32{}, true, nil + default: + return nil, false, fmt.Errorf("unsupported type: %#v", bean) + } +} + +func genScanResultsByBean(bean interface{}) (interface{}, bool, error) { + switch t := bean.(type) { + case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString, + *string, + *int, *int8, *int16, *int32, *int64, + *uint, *uint8, *uint16, *uint32, *uint64, + *bool: + return t, false, nil + case *time.Time: + return &sql.NullTime{}, true, nil + case sql.NullInt64, sql.NullBool, sql.NullFloat64, sql.NullString, + time.Time, + string, + int, int8, int16, int32, int64, + uint, uint8, uint16, uint32, uint64, + bool: + return nil, false, fmt.Errorf("unsupported scan type: %t", t) + case convert.Conversion: + return &sql.RawBytes{}, true, nil + } + + tp := reflect.TypeOf(bean).Elem() + switch tp.Kind() { + case reflect.String: + return new(string), true, nil + case reflect.Int64: + return new(int64), true, nil + case reflect.Int32: + return new(int32), true, nil + case reflect.Int: + return new(int32), true, nil + case reflect.Int16: + return new(int32), true, nil + case reflect.Int8: + return new(int32), true, nil + case reflect.Uint64: + return new(uint64), true, nil + case reflect.Uint32: + return new(uint32), true, nil + case reflect.Uint: + return new(uint), true, nil + case reflect.Uint16: + return new(uint16), true, nil + case reflect.Uint8: + return new(uint8), true, nil + default: + return nil, false, fmt.Errorf("unsupported type: %#v", bean) + } +} + func row2mapStr(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[string]string, error) { var scanResults = make([]interface{}, len(fields)) for i := 0; i < len(fields); i++ { @@ -50,18 +159,97 @@ func row2mapBytes(rows *core.Rows, types []*sql.ColumnType, fields []string) (ma return result, nil } -func row2sliceStr(rows *core.Rows, types []*sql.ColumnType, fields []string) ([]string, error) { - results := make([]string, 0, len(fields)) - var scanResults = make([]interface{}, len(fields)) - for i := 0; i < len(fields); i++ { +func (engine *Engine) scanStringInterface(rows *core.Rows, types []*sql.ColumnType) ([]interface{}, error) { + var scanResults = make([]interface{}, len(types)) + for i := 0; i < len(types); i++ { var s sql.NullString scanResults[i] = &s } - if err := rows.Scan(scanResults...); err != nil { + if err := engine.driver.Scan(&dialects.ScanContext{ + DBLocation: engine.DatabaseTZ, + UserLocation: engine.TZLocation, + }, rows, types, scanResults...); err != nil { + return nil, err + } + return scanResults, nil +} + +// scan is a wrap of driver.Scan but will automatically change the input values according requirements +func (engine *Engine) scan(rows *core.Rows, types []*sql.ColumnType, vv ...interface{}) error { + var scanResults = make([]interface{}, 0, len(types)) + var replaces = make([]bool, 0, len(types)) + var err error + for _, v := range vv { + var replaced bool + var scanResult interface{} + if _, ok := v.(sql.Scanner); !ok { + var useNullable = true + if engine.driver.Features().SupportNullable { + nullable, ok := types[0].Nullable() + useNullable = ok && !nullable + } + + if useNullable { + scanResult, replaced, err = genScanResultsByBeanNullable(v, engine.DatabaseTZ, engine.TZLocation) + } else { + scanResult, replaced, err = genScanResultsByBean(v) + } + if err != nil { + return err + } + } else { + scanResult = v + } + scanResults = append(scanResults, scanResult) + replaces = append(replaces, replaced) + } + + var scanCtx = dialects.ScanContext{ + DBLocation: engine.DatabaseTZ, + UserLocation: engine.TZLocation, + } + + if err = engine.driver.Scan(&scanCtx, rows, types, scanResults...); err != nil { + return err + } + + for i, replaced := range replaces { + if replaced { + if err = convertAssign(vv[i], scanResults[i], scanCtx.DBLocation, engine.TZLocation); err != nil { + return err + } + } + } + + return nil +} + +func (engine *Engine) scanInterfaces(rows *core.Rows, types []*sql.ColumnType) ([]interface{}, error) { + var scanResultContainers = make([]interface{}, len(types)) + for i := 0; i < len(types); i++ { + scanResult, err := engine.driver.GenScanResult(types[i].DatabaseTypeName()) + if err != nil { + return nil, err + } + scanResultContainers[i] = scanResult + } + if err := engine.driver.Scan(&dialects.ScanContext{ + DBLocation: engine.DatabaseTZ, + UserLocation: engine.TZLocation, + }, rows, types, scanResultContainers...); err != nil { + return nil, err + } + return scanResultContainers, nil +} + +func (engine *Engine) row2sliceStr(rows *core.Rows, types []*sql.ColumnType, fields []string) ([]string, error) { + scanResults, err := engine.scanStringInterface(rows, types) + if err != nil { return nil, err } + var results = make([]string, 0, len(fields)) for i := 0; i < len(fields); i++ { results = append(results, scanResults[i].(*sql.NullString).String) } diff --git a/session_find.go b/session_find.go index 0daea005..261e6b7f 100644 --- a/session_find.go +++ b/session_find.go @@ -276,7 +276,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect func convertPKToValue(table *schemas.Table, dst interface{}, pk schemas.PK) error { cols := table.PKColumns() if len(cols) == 1 { - return convertAssign(dst, pk[0]) + return convertAssign(dst, pk[0], nil, nil) } dst = pk diff --git a/session_get.go b/session_get.go index e303176d..a84d3745 100644 --- a/session_get.go +++ b/session_get.go @@ -6,12 +6,16 @@ package xorm import ( "database/sql" + "database/sql/driver" "errors" "fmt" "reflect" "strconv" + "time" "xorm.io/xorm/caches" + "xorm.io/xorm/convert" + "xorm.io/xorm/core" "xorm.io/xorm/internal/utils" "xorm.io/xorm/schemas" ) @@ -108,6 +112,17 @@ func (session *Session) get(bean interface{}) (bool, error) { return true, nil } +var ( + valuerTypePlaceHolder driver.Valuer + valuerType = reflect.TypeOf(&valuerTypePlaceHolder).Elem() + + scannerTypePlaceHolder sql.Scanner + scannerType = reflect.TypeOf(&scannerTypePlaceHolder).Elem() + + conversionTypePlaceHolder convert.Conversion + conversionType = reflect.TypeOf(&conversionTypePlaceHolder).Elem() +) + func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, bean interface{}, sqlStr string, args ...interface{}) (bool, error) { rows, err := session.queryRows(sqlStr, args...) if err != nil { @@ -122,155 +137,161 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, return false, nil } - switch bean.(type) { - case sql.NullInt64, sql.NullBool, sql.NullFloat64, sql.NullString: - return true, rows.Scan(&bean) - case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString: - return true, rows.Scan(bean) - case *string: - var res sql.NullString - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*string)) = res.String - } - return true, nil - case *int: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*int)) = int(res.Int64) - } - return true, nil - case *int8: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*int8)) = int8(res.Int64) - } - return true, nil - case *int16: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*int16)) = int16(res.Int64) - } - return true, nil - case *int32: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*int32)) = int32(res.Int64) - } - return true, nil - case *int64: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*int64)) = int64(res.Int64) - } - return true, nil - case *uint: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*uint)) = uint(res.Int64) - } - return true, nil - case *uint8: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*uint8)) = uint8(res.Int64) - } - return true, nil - case *uint16: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*uint16)) = uint16(res.Int64) - } - return true, nil - case *uint32: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*uint32)) = uint32(res.Int64) - } - return true, nil - case *uint64: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*uint64)) = uint64(res.Int64) - } - return true, nil - case *bool: - var res sql.NullBool - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*bool)) = res.Bool - } - return true, nil + // WARN: Alougth rows return true, but we may also return error. + types, err := rows.ColumnTypes() + if err != nil { + return true, err + } + fields, err := rows.Columns() + if err != nil { + return true, err } - switch beanKind { case reflect.Struct: - fields, err := rows.Columns() - if err != nil { - // WARN: Alougth rows return true, but get fields failed - return true, err + if _, ok := bean.(*time.Time); ok { + break } - - scanResults, err := session.row2Slice(rows, fields, bean) - if err != nil { - return false, err + if _, ok := bean.(sql.Scanner); ok { + break } - // close it before convert data - rows.Close() - - dataStruct := utils.ReflectValue(bean) - _, err = session.slice2Bean(scanResults, fields, bean, &dataStruct, table) - if err != nil { - return true, err + if _, ok := bean.(convert.Conversion); len(types) == 1 && ok { + break } - - return true, session.executeProcessors() + return session.getStruct(rows, types, fields, table, bean) case reflect.Slice: - err = rows.ScanSlice(bean) + return session.getSlice(rows, types, fields, bean) case reflect.Map: - err = rows.ScanMap(bean) - case reflect.String, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, - reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - err = rows.Scan(bean) - default: - err = rows.Scan(bean) + return session.getMap(rows, types, fields, bean) } - return true, err + return session.getVars(rows, types, fields, bean) +} + +func (session *Session) getSlice(rows *core.Rows, types []*sql.ColumnType, fields []string, bean interface{}) (bool, error) { + switch t := bean.(type) { + case *[]string: + res, err := session.engine.scanStringInterface(rows, types) + if err != nil { + return true, err + } + + var needAppend = len(*t) == 0 // both support slice is empty or has been initlized + for i, r := range res { + if needAppend { + *t = append(*t, r.(*sql.NullString).String) + } else { + (*t)[i] = r.(*sql.NullString).String + } + } + return true, nil + case *[]interface{}: + scanResults, err := session.engine.scanInterfaces(rows, types) + if err != nil { + return true, err + } + var needAppend = len(*t) == 0 + for ii := range fields { + s, err := convert.Interface2Interface(session.engine.DatabaseTZ, scanResults[ii]) + if err != nil { + return true, err + } + if needAppend { + *t = append(*t, s) + } else { + (*t)[ii] = s + } + } + return true, nil + default: + return true, fmt.Errorf("unspoorted slice type: %t", t) + } +} + +func (session *Session) getMap(rows *core.Rows, types []*sql.ColumnType, fields []string, bean interface{}) (bool, error) { + switch t := bean.(type) { + case *map[string]string: + scanResults, err := session.engine.scanStringInterface(rows, types) + if err != nil { + return true, err + } + for ii, key := range fields { + (*t)[key] = scanResults[ii].(*sql.NullString).String + } + return true, nil + case *map[string]interface{}: + scanResults, err := session.engine.scanInterfaces(rows, types) + if err != nil { + return true, err + } + for ii, key := range fields { + s, err := convert.Interface2Interface(session.engine.DatabaseTZ, scanResults[ii]) + if err != nil { + return true, err + } + (*t)[key] = s + } + return true, nil + default: + return true, fmt.Errorf("unspoorted map type: %t", t) + } +} + +func (session *Session) getVars(rows *core.Rows, types []*sql.ColumnType, fields []string, beans ...interface{}) (bool, error) { + if len(beans) != len(types) { + return false, fmt.Errorf("expected columns %d, but only %d variables", len(types), len(beans)) + } + var scanResults = make([]interface{}, 0, len(types)) + var replaceds = make([]bool, 0, len(types)) + for _, bean := range beans { + switch t := bean.(type) { + case sql.Scanner: + scanResults = append(scanResults, t) + replaceds = append(replaceds, false) + case convert.Conversion: + scanResults = append(scanResults, &sql.RawBytes{}) + replaceds = append(replaceds, true) + default: + scanResults = append(scanResults, bean) + replaceds = append(replaceds, false) + } + } + + err := session.engine.scan(rows, types, scanResults...) + if err != nil { + return true, err + } + for i, replaced := range replaceds { + if replaced { + err = convertAssign(beans[i], scanResults[i], session.engine.DatabaseTZ, session.engine.TZLocation) + if err != nil { + return true, err + } + } + } + return true, nil +} + +func (session *Session) getStruct(rows *core.Rows, types []*sql.ColumnType, fields []string, table *schemas.Table, bean interface{}) (bool, error) { + fields, err := rows.Columns() + if err != nil { + // WARN: Alougth rows return true, but get fields failed + return true, err + } + + scanResults, err := session.row2Slice(rows, fields, bean) + if err != nil { + return false, err + } + // close it before convert data + rows.Close() + + dataStruct := utils.ReflectValue(bean) + _, err = session.slice2Bean(scanResults, fields, bean, &dataStruct, table) + if err != nil { + return true, err + } + + return true, session.executeProcessors() } func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interface{}) (has bool, err error) { diff --git a/session_insert.go b/session_insert.go index e733e06e..7f8f3008 100644 --- a/session_insert.go +++ b/session_insert.go @@ -375,7 +375,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return 1, nil } - return 1, convertAssignV(aiValue.Addr(), id) + return 1, convertAssignV(aiValue.Addr(), id, session.engine.DatabaseTZ, session.engine.TZLocation) } else if len(table.AutoIncrement) > 0 && (session.engine.dialect.URI().DBType == schemas.POSTGRES || session.engine.dialect.URI().DBType == schemas.MSSQL) { res, err := session.queryBytes(sqlStr, args...) @@ -415,7 +415,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return 1, nil } - return 1, convertAssignV(aiValue.Addr(), id) + return 1, convertAssignV(aiValue.Addr(), id, session.engine.DatabaseTZ, session.engine.TZLocation) } res, err := session.exec(sqlStr, args...) @@ -455,7 +455,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return res.RowsAffected() } - if err := convertAssignV(aiValue.Addr(), id); err != nil { + if err := convertAssignV(aiValue.Addr(), id, session.engine.DatabaseTZ, session.engine.TZLocation); err != nil { return 0, err } diff --git a/session_query.go b/session_query.go index 01cd6f44..fa33496d 100644 --- a/session_query.go +++ b/session_query.go @@ -54,7 +54,7 @@ func (session *Session) rows2SliceString(rows *core.Rows) (resultsSlice [][]stri } for rows.Next() { - record, err := row2sliceStr(rows, types, fields) + record, err := session.engine.row2sliceStr(rows, types, fields) if err != nil { return nil, err } -- 2.40.1 From dfdef4b68f934dcbbcdc0ce9d28ff6d872d4b9a8 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Mon, 5 Jul 2021 14:16:02 +0800 Subject: [PATCH 2/7] Fix test --- convert.go | 7 +++++++ integrations/session_get_test.go | 2 +- scan.go | 1 + 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/convert.go b/convert.go index c4774d97..1be0852d 100644 --- a/convert.go +++ b/convert.go @@ -40,6 +40,12 @@ func asString(src interface{}) string { return v case []byte: return string(v) + case *sql.NullString: + return v.String + case *sql.NullInt32: + return fmt.Sprintf("%d", v.Int32) + case *sql.NullInt64: + return fmt.Sprintf("%d", v.Int64) } rv := reflect.ValueOf(src) switch rv.Kind() { @@ -401,6 +407,7 @@ func convertAssignV(dpv reflect.Value, src interface{}, originalLocation, conver dv.SetFloat(f64) return nil case reflect.String: + fmt.Println("=====", src) dv.SetString(asString(src)) return nil } diff --git a/integrations/session_get_test.go b/integrations/session_get_test.go index 99db98fc..f60a7f7b 100644 --- a/integrations/session_get_test.go +++ b/integrations/session_get_test.go @@ -719,7 +719,7 @@ func TestCustomTypes(t *testing.T) { has, err := testEngine.Table(new(TestCustomizeStruct)).ID(s.Id).Cols("name").Get(&name) assert.NoError(t, err) assert.True(t, has) - assert.EqualValues(t, "test", name) + assert.EqualValues(t, "test", string(name)) var age MyInt has, err = testEngine.Table(new(TestCustomizeStruct)).ID(s.Id).Select("age").Get(&age) diff --git a/scan.go b/scan.go index b23785d8..ce1fcf98 100644 --- a/scan.go +++ b/scan.go @@ -57,6 +57,7 @@ func genScanResultsByBeanNullable(bean interface{}, originalLocation, convertedL tp := reflect.TypeOf(bean).Elem() switch tp.Kind() { case reflect.String: + fmt.Println("=====", tp) return &sql.NullString{}, true, nil case reflect.Int64: return &sql.NullInt64{}, true, nil -- 2.40.1 From 9e5bf5e20463c1913ff45f3fd0e2d599381c4fa7 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Mon, 5 Jul 2021 14:43:11 +0800 Subject: [PATCH 3/7] performance optimization --- convert.go | 166 +++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 156 insertions(+), 10 deletions(-) diff --git a/convert.go b/convert.go index 1be0852d..ba834a8d 100644 --- a/convert.go +++ b/convert.go @@ -63,6 +63,156 @@ func asString(src interface{}) string { return fmt.Sprintf("%v", src) } +func asInt64(src interface{}) (int64, error) { + switch v := src.(type) { + case int: + return int64(v), nil + case int16: + return int64(v), nil + case int32: + return int64(v), nil + case int8: + return int64(v), nil + case int64: + return v, nil + case uint: + return int64(v), nil + case uint8: + return int64(v), nil + case uint16: + return int64(v), nil + case uint32: + return int64(v), nil + case uint64: + return int64(v), nil + case []byte: + return strconv.ParseInt(string(v), 10, 64) + case string: + return strconv.ParseInt(v, 10, 64) + case *sql.NullString: + return strconv.ParseInt(v.String, 10, 64) + case *sql.NullInt32: + return int64(v.Int32), nil + case *sql.NullInt64: + return int64(v.Int64), nil + } + + rv := reflect.ValueOf(src) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return rv.Int(), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return int64(rv.Uint()), nil + case reflect.Float64: + return int64(rv.Float()), nil + case reflect.Float32: + return int64(rv.Float()), nil + case reflect.String: + return strconv.ParseInt(rv.String(), 10, 64) + } + return 0, fmt.Errorf("unsupported value %T as int64", src) +} + +func asUint64(src interface{}) (uint64, error) { + switch v := src.(type) { + case int: + return uint64(v), nil + case int16: + return uint64(v), nil + case int32: + return uint64(v), nil + case int8: + return uint64(v), nil + case int64: + return uint64(v), nil + case uint: + return uint64(v), nil + case uint8: + return uint64(v), nil + case uint16: + return uint64(v), nil + case uint32: + return uint64(v), nil + case uint64: + return v, nil + case []byte: + return strconv.ParseUint(string(v), 10, 64) + case string: + return strconv.ParseUint(v, 10, 64) + case *sql.NullString: + return strconv.ParseUint(v.String, 10, 64) + case *sql.NullInt32: + return uint64(v.Int32), nil + case *sql.NullInt64: + return uint64(v.Int64), nil + } + + rv := reflect.ValueOf(src) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return uint64(rv.Int()), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return uint64(rv.Uint()), nil + case reflect.Float64: + return uint64(rv.Float()), nil + case reflect.Float32: + return uint64(rv.Float()), nil + case reflect.String: + return strconv.ParseUint(rv.String(), 10, 64) + } + return 0, fmt.Errorf("unsupported value %T as uint64", src) +} + +func asFloat64(src interface{}) (float64, error) { + switch v := src.(type) { + case int: + return float64(v), nil + case int16: + return float64(v), nil + case int32: + return float64(v), nil + case int8: + return float64(v), nil + case int64: + return float64(v), nil + case uint: + return float64(v), nil + case uint8: + return float64(v), nil + case uint16: + return float64(v), nil + case uint32: + return float64(v), nil + case uint64: + return float64(v), nil + case []byte: + return strconv.ParseFloat(string(v), 64) + case string: + return strconv.ParseFloat(v, 64) + case *sql.NullString: + return strconv.ParseFloat(v.String, 64) + case *sql.NullInt32: + return float64(v.Int32), nil + case *sql.NullInt64: + return float64(v.Int64), nil + } + + rv := reflect.ValueOf(src) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return float64(rv.Int()), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return float64(rv.Uint()), nil + case reflect.Float64: + return float64(rv.Float()), nil + case reflect.Float32: + return float64(rv.Float()), nil + case reflect.String: + return strconv.ParseFloat(rv.String(), 64) + } + return 0, fmt.Errorf("unsupported value %T as int64", src) +} + func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) { switch rv.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: @@ -380,34 +530,30 @@ func convertAssignV(dpv reflect.Value, src interface{}, originalLocation, conver dv.Set(reflect.New(dv.Type().Elem())) return convertAssign(dv.Interface(), src, originalLocation, convertedLocation) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - s := asString(src) - i64, err := strconv.ParseInt(s, 10, dv.Type().Bits()) + i64, err := asInt64(src) if err != nil { err = strconvErr(err) - return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + return fmt.Errorf("converting driver.Value type %T to a %s: %v", src, dv.Kind(), err) } dv.SetInt(i64) return nil case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - s := asString(src) - u64, err := strconv.ParseUint(s, 10, dv.Type().Bits()) + u64, err := asUint64(src) if err != nil { err = strconvErr(err) - return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + return fmt.Errorf("converting driver.Value type %T to a %s: %v", src, dv.Kind(), err) } dv.SetUint(u64) return nil case reflect.Float32, reflect.Float64: - s := asString(src) - f64, err := strconv.ParseFloat(s, dv.Type().Bits()) + f64, err := asFloat64(src) if err != nil { err = strconvErr(err) - return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + return fmt.Errorf("converting driver.Value type %T to a %s: %v", src, dv.Kind(), err) } dv.SetFloat(f64) return nil case reflect.String: - fmt.Println("=====", src) dv.SetString(asString(src)) return nil } -- 2.40.1 From 8c862e11d7da0f7ff6fd4875ba98d7ec8068e42c Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Mon, 5 Jul 2021 14:45:57 +0800 Subject: [PATCH 4/7] revert unnecessary change --- integrations/session_get_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/session_get_test.go b/integrations/session_get_test.go index f60a7f7b..99db98fc 100644 --- a/integrations/session_get_test.go +++ b/integrations/session_get_test.go @@ -719,7 +719,7 @@ func TestCustomTypes(t *testing.T) { has, err := testEngine.Table(new(TestCustomizeStruct)).ID(s.Id).Cols("name").Get(&name) assert.NoError(t, err) assert.True(t, has) - assert.EqualValues(t, "test", string(name)) + assert.EqualValues(t, "test", name) var age MyInt has, err = testEngine.Table(new(TestCustomizeStruct)).ID(s.Id).Select("age").Get(&age) -- 2.40.1 From 44f6fc09abedca5adb4f74848331dca9ee13e116 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Tue, 6 Jul 2021 10:57:06 +0800 Subject: [PATCH 5/7] Fix bug --- scan.go | 9 +++++++-- session_get.go | 2 +- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/scan.go b/scan.go index ce1fcf98..7b242ec9 100644 --- a/scan.go +++ b/scan.go @@ -78,6 +78,7 @@ func genScanResultsByBean(bean interface{}) (interface{}, bool, error) { *string, *int, *int8, *int16, *int32, *int64, *uint, *uint8, *uint16, *uint32, *uint64, + *float32, *float64, *bool: return t, false, nil case *time.Time: @@ -117,6 +118,10 @@ func genScanResultsByBean(bean interface{}) (interface{}, bool, error) { return new(uint16), true, nil case reflect.Uint8: return new(uint8), true, nil + case reflect.Float32: + return new(float32), true, nil + case reflect.Float64: + return new(float64), true, nil default: return nil, false, fmt.Errorf("unsupported type: %#v", bean) } @@ -177,7 +182,7 @@ func (engine *Engine) scanStringInterface(rows *core.Rows, types []*sql.ColumnTy } // scan is a wrap of driver.Scan but will automatically change the input values according requirements -func (engine *Engine) scan(rows *core.Rows, types []*sql.ColumnType, vv ...interface{}) error { +func (engine *Engine) scan(rows *core.Rows, fields []string, types []*sql.ColumnType, vv ...interface{}) error { var scanResults = make([]interface{}, 0, len(types)) var replaces = make([]bool, 0, len(types)) var err error @@ -188,7 +193,7 @@ func (engine *Engine) scan(rows *core.Rows, types []*sql.ColumnType, vv ...inter var useNullable = true if engine.driver.Features().SupportNullable { nullable, ok := types[0].Nullable() - useNullable = ok && !nullable + useNullable = ok && nullable } if useNullable { diff --git a/session_get.go b/session_get.go index a84d3745..cb2bda75 100644 --- a/session_get.go +++ b/session_get.go @@ -256,7 +256,7 @@ func (session *Session) getVars(rows *core.Rows, types []*sql.ColumnType, fields } } - err := session.engine.scan(rows, types, scanResults...) + err := session.engine.scan(rows, fields, types, scanResults...) if err != nil { return true, err } -- 2.40.1 From a56f7cf5eb99e72ccdb570f4da949a6410f1ed97 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Tue, 6 Jul 2021 13:15:56 +0800 Subject: [PATCH 6/7] Fix postgres --- dialects/postgres.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/dialects/postgres.go b/dialects/postgres.go index e4641509..a2611c60 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -1302,6 +1302,12 @@ type pqDriver struct { baseDriver } +func (b *pqDriver) Features() DriverFeatures { + return DriverFeatures{ + SupportNullable: false, + } +} + type values map[string]string func (vs values) Set(k, v string) { -- 2.40.1 From faa9ae8be8c7bd5b146ffe9e0a73cae729fed5f3 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Tue, 6 Jul 2021 14:11:52 +0800 Subject: [PATCH 7/7] Improve code --- convert.go | 25 +++++++++++++------------ scan.go | 15 ++++----------- 2 files changed, 17 insertions(+), 23 deletions(-) diff --git a/convert.go b/convert.go index ba834a8d..f7d733ad 100644 --- a/convert.go +++ b/convert.go @@ -698,10 +698,8 @@ var ( // NullUint64 implements the Scanner interface so // it can be used as a scan destination, similar to NullString. type NullUint64 struct { - Uint64 uint64 - Valid bool // Valid is true if Uint64 is not NULL - OriginalLocation *time.Location - ConvertedLocation *time.Location + Uint64 uint64 + Valid bool } // Scan implements the Scanner interface. @@ -711,8 +709,9 @@ func (n *NullUint64) Scan(value interface{}) error { return nil } n.Valid = true - fmt.Println("======44444") - return convertAssign(&n.Uint64, value, n.OriginalLocation, n.ConvertedLocation) + var err error + n.Uint64, err = asUint64(value) + return err } // Value implements the driver Valuer interface. @@ -731,10 +730,8 @@ var ( // NullUint32 implements the Scanner interface so // it can be used as a scan destination, similar to NullString. type NullUint32 struct { - Uint32 uint32 - Valid bool // Valid is true if Uint32 is not NULL - OriginalLocation *time.Location - ConvertedLocation *time.Location + Uint32 uint32 + Valid bool // Valid is true if Uint32 is not NULL } // Scan implements the Scanner interface. @@ -744,8 +741,12 @@ func (n *NullUint32) Scan(value interface{}) error { return nil } n.Valid = true - fmt.Println("555555") - return convertAssign(&n.Uint32, value, n.OriginalLocation, n.ConvertedLocation) + i64, err := asUint64(value) + if err != nil { + return err + } + n.Uint32 = uint32(i64) + return nil } // Value implements the driver Valuer interface. diff --git a/scan.go b/scan.go index 7b242ec9..c5cb77ff 100644 --- a/scan.go +++ b/scan.go @@ -16,7 +16,7 @@ import ( ) // genScanResultsByBeanNullabale generates scan result -func genScanResultsByBeanNullable(bean interface{}, originalLocation, convertedLocation *time.Location) (interface{}, bool, error) { +func genScanResultsByBeanNullable(bean interface{}) (interface{}, bool, error) { switch t := bean.(type) { case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString, *sql.RawBytes: return t, false, nil @@ -29,15 +29,9 @@ func genScanResultsByBeanNullable(bean interface{}, originalLocation, convertedL case *int64: return &sql.NullInt64{}, true, nil case *uint, *uint8, *uint16, *uint32: - return &NullUint32{ - OriginalLocation: originalLocation, - ConvertedLocation: convertedLocation, - }, true, nil + return &NullUint32{}, true, nil case *uint64: - return &NullUint64{ - OriginalLocation: originalLocation, - ConvertedLocation: convertedLocation, - }, true, nil + return &NullUint64{}, true, nil case *float32, *float64: return &sql.NullFloat64{}, true, nil case *bool: @@ -57,7 +51,6 @@ func genScanResultsByBeanNullable(bean interface{}, originalLocation, convertedL tp := reflect.TypeOf(bean).Elem() switch tp.Kind() { case reflect.String: - fmt.Println("=====", tp) return &sql.NullString{}, true, nil case reflect.Int64: return &sql.NullInt64{}, true, nil @@ -197,7 +190,7 @@ func (engine *Engine) scan(rows *core.Rows, fields []string, types []*sql.Column } if useNullable { - scanResult, replaced, err = genScanResultsByBeanNullable(v, engine.DatabaseTZ, engine.TZLocation) + scanResult, replaced, err = genScanResultsByBeanNullable(v) } else { scanResult, replaced, err = genScanResultsByBean(v) } -- 2.40.1