From a5143f230b96daed8dcdc37a7a33c189669e5eb9 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Wed, 28 Jul 2021 13:24:44 +0800 Subject: [PATCH] Move assign functions to convert package --- convert.go | 414 --------------------------------- internal/convert/conversion.go | 376 ++++++++++++++++++++++++++++++ scan.go | 2 +- session.go | 41 +++- session_find.go | 5 +- session_insert.go | 5 +- 6 files changed, 423 insertions(+), 420 deletions(-) delete mode 100644 convert.go diff --git a/convert.go b/convert.go deleted file mode 100644 index c4fc7867..00000000 --- a/convert.go +++ /dev/null @@ -1,414 +0,0 @@ -// Copyright 2017 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import ( - "database/sql" - "database/sql/driver" - "encoding/json" - "errors" - "fmt" - "math/big" - "reflect" - "strconv" - "time" - - "xorm.io/xorm/internal/convert" -) - -var errNilPtr = errors.New("destination pointer is nil") // embedded in descriptive error - -func strconvErr(err error) error { - if ne, ok := err.(*strconv.NumError); ok { - return ne.Err - } - return err -} - -func cloneBytes(b []byte) []byte { - if b == nil { - return nil - } - c := make([]byte, len(b)) - copy(c, b) - return c -} - -// convertAssign copies to dest the value in src, converting it if possible. -// An error is returned if the copy would result in loss of information. -// dest should be a pointer type. -func convertAssign(dest, src interface{}, originalLocation *time.Location, convertedLocation *time.Location) error { - // Common cases, without reflect. - switch s := src.(type) { - case *interface{}: - return convertAssign(dest, *s, originalLocation, convertedLocation) - case string: - switch d := dest.(type) { - case *string: - if d == nil { - return errNilPtr - } - *d = s - return nil - case *[]byte: - if d == nil { - return errNilPtr - } - *d = []byte(s) - return nil - } - case []byte: - switch d := dest.(type) { - case *string: - if d == nil { - return errNilPtr - } - *d = string(s) - return nil - case *interface{}: - if d == nil { - return errNilPtr - } - *d = cloneBytes(s) - return nil - case *[]byte: - if d == nil { - return errNilPtr - } - *d = cloneBytes(s) - return nil - } - case time.Time: - switch d := dest.(type) { - case *string: - *d = s.Format(time.RFC3339Nano) - return nil - case *[]byte: - if d == nil { - return errNilPtr - } - *d = []byte(s.Format(time.RFC3339Nano)) - return nil - } - case nil: - switch d := dest.(type) { - case *interface{}: - if d == nil { - return errNilPtr - } - *d = nil - return nil - case *[]byte: - if d == nil { - return errNilPtr - } - *d = nil - return nil - } - case *sql.NullString: - switch d := dest.(type) { - case *int: - if s.Valid { - *d, _ = strconv.Atoi(s.String) - } - return nil - case *int64: - if s.Valid { - *d, _ = strconv.ParseInt(s.String, 10, 64) - } - return nil - case *string: - if s.Valid { - *d = s.String - } - return nil - case *time.Time: - if s.Valid { - var err error - dt, err := convert.String2Time(s.String, originalLocation, convertedLocation) - if err != nil { - return err - } - *d = *dt - } - return nil - case *sql.NullTime: - if s.Valid { - var err error - dt, err := convert.String2Time(s.String, originalLocation, convertedLocation) - if err != nil { - return err - } - d.Valid = true - d.Time = *dt - } - return nil - case *big.Float: - if s.Valid { - if d == nil { - d = big.NewFloat(0) - } - d.SetString(s.String) - } - return nil - } - case *sql.NullInt32: - switch d := dest.(type) { - case *int: - if s.Valid { - *d = int(s.Int32) - } - return nil - case *int8: - if s.Valid { - *d = int8(s.Int32) - } - return nil - case *int16: - if s.Valid { - *d = int16(s.Int32) - } - return nil - case *int32: - if s.Valid { - *d = s.Int32 - } - return nil - case *int64: - if s.Valid { - *d = int64(s.Int32) - } - return nil - } - case *sql.NullInt64: - switch d := dest.(type) { - case *int: - if s.Valid { - *d = int(s.Int64) - } - return nil - case *int8: - if s.Valid { - *d = int8(s.Int64) - } - return nil - case *int16: - if s.Valid { - *d = int16(s.Int64) - } - return nil - case *int32: - if s.Valid { - *d = int32(s.Int64) - } - return nil - case *int64: - if s.Valid { - *d = s.Int64 - } - return nil - } - case *sql.NullFloat64: - switch d := dest.(type) { - case *int: - if s.Valid { - *d = int(s.Float64) - } - return nil - case *float64: - if s.Valid { - *d = s.Float64 - } - return nil - } - case *sql.NullBool: - switch d := dest.(type) { - case *bool: - if s.Valid { - *d = s.Bool - } - return nil - } - case *sql.NullTime: - switch d := dest.(type) { - case *time.Time: - if s.Valid { - *d = s.Time - } - return nil - case *string: - if s.Valid { - *d = s.Time.In(convertedLocation).Format("2006-01-02 15:04:05") - } - return nil - } - case *convert.NullUint32: - switch d := dest.(type) { - case *uint8: - if s.Valid { - *d = uint8(s.Uint32) - } - return nil - case *uint16: - if s.Valid { - *d = uint16(s.Uint32) - } - return nil - case *uint: - if s.Valid { - *d = uint(s.Uint32) - } - return nil - } - case *convert.NullUint64: - switch d := dest.(type) { - case *uint64: - if s.Valid { - *d = s.Uint64 - } - return nil - } - case *sql.RawBytes: - switch d := dest.(type) { - case convert.Conversion: - return d.FromDB(*s) - } - } - - var sv reflect.Value - - switch d := dest.(type) { - case *string: - sv = reflect.ValueOf(src) - switch sv.Kind() { - case reflect.Bool, - reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, - reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, - reflect.Float32, reflect.Float64: - *d = convert.AsString(src) - return nil - } - case *[]byte: - if b, ok := convert.AsBytes(src); ok { - *d = b - return nil - } - case *bool: - bv, err := driver.Bool.ConvertValue(src) - if err == nil { - *d = bv.(bool) - } - return err - case *interface{}: - *d = src - return nil - } - - return convertAssignV(reflect.ValueOf(dest), src) -} - -func convertAssignV(dv reflect.Value, src interface{}) error { - if src == nil { - return nil - } - - if dv.Type().Implements(scannerType) { - return dv.Interface().(sql.Scanner).Scan(src) - } - - switch dv.Kind() { - case reflect.Ptr: - if dv.IsNil() { - dv.Set(reflect.New(dv.Type().Elem())) - } - return convertAssignV(dv.Elem(), src) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - i64, err := convert.AsInt64(src) - if err != nil { - err = strconvErr(err) - return fmt.Errorf("converting driver.Value type %T to a %s: %v", src, dv.Kind(), err) - } - dv.SetInt(i64) - return nil - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - u64, err := convert.AsUint64(src) - if err != nil { - err = strconvErr(err) - return fmt.Errorf("converting driver.Value type %T to a %s: %v", src, dv.Kind(), err) - } - dv.SetUint(u64) - return nil - case reflect.Float32, reflect.Float64: - f64, err := convert.AsFloat64(src) - if err != nil { - err = strconvErr(err) - return fmt.Errorf("converting driver.Value type %T to a %s: %v", src, dv.Kind(), err) - } - dv.SetFloat(f64) - return nil - case reflect.String: - dv.SetString(convert.AsString(src)) - return nil - case reflect.Bool: - b, err := convert.AsBool(src) - if err != nil { - return err - } - dv.SetBool(b) - return nil - case reflect.Slice, reflect.Map, reflect.Struct, reflect.Array: - data, ok := convert.AsBytes(src) - if !ok { - return fmt.Errorf("onvertAssignV: src cannot be as bytes %#v", src) - } - if data == nil { - return nil - } - if dv.Kind() != reflect.Ptr { - dv = dv.Addr() - } - return json.Unmarshal(data, dv.Interface()) - default: - return fmt.Errorf("convertAssignV: unsupported Scan, storing driver.Value type %T into type %T", src, dv.Interface()) - } -} - -func asKind(vv reflect.Value, tp reflect.Type) (interface{}, error) { - switch tp.Kind() { - case reflect.Ptr: - return asKind(vv.Elem(), tp.Elem()) - case reflect.Int64: - return vv.Int(), nil - case reflect.Int: - return int(vv.Int()), nil - case reflect.Int32: - return int32(vv.Int()), nil - case reflect.Int16: - return int16(vv.Int()), nil - case reflect.Int8: - return int8(vv.Int()), nil - case reflect.Uint64: - return vv.Uint(), nil - case reflect.Uint: - return uint(vv.Uint()), nil - case reflect.Uint32: - return uint32(vv.Uint()), nil - case reflect.Uint16: - return uint16(vv.Uint()), nil - case reflect.Uint8: - return uint8(vv.Uint()), nil - case reflect.String: - return vv.String(), nil - case reflect.Slice: - if tp.Elem().Kind() == reflect.Uint8 { - v, err := strconv.ParseInt(string(vv.Interface().([]byte)), 10, 64) - if err != nil { - return nil, err - } - return v, nil - } - } - return nil, fmt.Errorf("unsupported primary key type: %v, %v", tp, vv) -} diff --git a/internal/convert/conversion.go b/internal/convert/conversion.go index 16f1a92a..096fcfaf 100644 --- a/internal/convert/conversion.go +++ b/internal/convert/conversion.go @@ -4,9 +4,385 @@ package convert +import ( + "database/sql" + "database/sql/driver" + "encoding/json" + "errors" + "fmt" + "math/big" + "reflect" + "strconv" + "time" +) + // Conversion is an interface. A type implements Conversion will according // the custom method to fill into database and retrieve from database. type Conversion interface { FromDB([]byte) error ToDB() ([]byte, error) } + +// ErrNilPtr represents an error +var ErrNilPtr = errors.New("destination pointer is nil") // embedded in descriptive error + +func strconvErr(err error) error { + if ne, ok := err.(*strconv.NumError); ok { + return ne.Err + } + return err +} + +func cloneBytes(b []byte) []byte { + if b == nil { + return nil + } + c := make([]byte, len(b)) + copy(c, b) + return c +} + +// Assign copies to dest the value in src, converting it if possible. +// An error is returned if the copy would result in loss of information. +// dest should be a pointer type. +func Assign(dest, src interface{}, originalLocation *time.Location, convertedLocation *time.Location) error { + // Common cases, without reflect. + switch s := src.(type) { + case *interface{}: + return Assign(dest, *s, originalLocation, convertedLocation) + case string: + switch d := dest.(type) { + case *string: + if d == nil { + return ErrNilPtr + } + *d = s + return nil + case *[]byte: + if d == nil { + return ErrNilPtr + } + *d = []byte(s) + return nil + } + case []byte: + switch d := dest.(type) { + case *string: + if d == nil { + return ErrNilPtr + } + *d = string(s) + return nil + case *interface{}: + if d == nil { + return ErrNilPtr + } + *d = cloneBytes(s) + return nil + case *[]byte: + if d == nil { + return ErrNilPtr + } + *d = cloneBytes(s) + return nil + } + case time.Time: + switch d := dest.(type) { + case *string: + *d = s.Format(time.RFC3339Nano) + return nil + case *[]byte: + if d == nil { + return ErrNilPtr + } + *d = []byte(s.Format(time.RFC3339Nano)) + return nil + } + case nil: + switch d := dest.(type) { + case *interface{}: + if d == nil { + return ErrNilPtr + } + *d = nil + return nil + case *[]byte: + if d == nil { + return ErrNilPtr + } + *d = nil + return nil + } + case *sql.NullString: + switch d := dest.(type) { + case *int: + if s.Valid { + *d, _ = strconv.Atoi(s.String) + } + return nil + case *int64: + if s.Valid { + *d, _ = strconv.ParseInt(s.String, 10, 64) + } + return nil + case *string: + if s.Valid { + *d = s.String + } + return nil + case *time.Time: + if s.Valid { + var err error + dt, err := String2Time(s.String, originalLocation, convertedLocation) + if err != nil { + return err + } + *d = *dt + } + return nil + case *sql.NullTime: + if s.Valid { + var err error + dt, err := String2Time(s.String, originalLocation, convertedLocation) + if err != nil { + return err + } + d.Valid = true + d.Time = *dt + } + return nil + case *big.Float: + if s.Valid { + if d == nil { + d = big.NewFloat(0) + } + d.SetString(s.String) + } + return nil + } + case *sql.NullInt32: + switch d := dest.(type) { + case *int: + if s.Valid { + *d = int(s.Int32) + } + return nil + case *int8: + if s.Valid { + *d = int8(s.Int32) + } + return nil + case *int16: + if s.Valid { + *d = int16(s.Int32) + } + return nil + case *int32: + if s.Valid { + *d = s.Int32 + } + return nil + case *int64: + if s.Valid { + *d = int64(s.Int32) + } + return nil + } + case *sql.NullInt64: + switch d := dest.(type) { + case *int: + if s.Valid { + *d = int(s.Int64) + } + return nil + case *int8: + if s.Valid { + *d = int8(s.Int64) + } + return nil + case *int16: + if s.Valid { + *d = int16(s.Int64) + } + return nil + case *int32: + if s.Valid { + *d = int32(s.Int64) + } + return nil + case *int64: + if s.Valid { + *d = s.Int64 + } + return nil + } + case *sql.NullFloat64: + switch d := dest.(type) { + case *int: + if s.Valid { + *d = int(s.Float64) + } + return nil + case *float64: + if s.Valid { + *d = s.Float64 + } + return nil + } + case *sql.NullBool: + switch d := dest.(type) { + case *bool: + if s.Valid { + *d = s.Bool + } + return nil + } + case *sql.NullTime: + switch d := dest.(type) { + case *time.Time: + if s.Valid { + *d = s.Time + } + return nil + case *string: + if s.Valid { + *d = s.Time.In(convertedLocation).Format("2006-01-02 15:04:05") + } + return nil + } + case *NullUint32: + switch d := dest.(type) { + case *uint8: + if s.Valid { + *d = uint8(s.Uint32) + } + return nil + case *uint16: + if s.Valid { + *d = uint16(s.Uint32) + } + return nil + case *uint: + if s.Valid { + *d = uint(s.Uint32) + } + return nil + } + case *NullUint64: + switch d := dest.(type) { + case *uint64: + if s.Valid { + *d = s.Uint64 + } + return nil + } + case *sql.RawBytes: + switch d := dest.(type) { + case Conversion: + return d.FromDB(*s) + } + } + + var sv reflect.Value + + switch d := dest.(type) { + case *string: + sv = reflect.ValueOf(src) + switch sv.Kind() { + case reflect.Bool, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + *d = AsString(src) + return nil + } + case *[]byte: + if b, ok := AsBytes(src); ok { + *d = b + return nil + } + case *bool: + bv, err := driver.Bool.ConvertValue(src) + if err == nil { + *d = bv.(bool) + } + return err + case *interface{}: + *d = src + return nil + } + + return AssignValue(reflect.ValueOf(dest), src) +} + +var ( + scannerTypePlaceHolder sql.Scanner + scannerType = reflect.TypeOf(&scannerTypePlaceHolder).Elem() +) + +// AssignValue assign src as dv +func AssignValue(dv reflect.Value, src interface{}) error { + if src == nil { + return nil + } + + if dv.Type().Implements(scannerType) { + return dv.Interface().(sql.Scanner).Scan(src) + } + + switch dv.Kind() { + case reflect.Ptr: + if dv.IsNil() { + dv.Set(reflect.New(dv.Type().Elem())) + } + return AssignValue(dv.Elem(), src) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + i64, err := AsInt64(src) + if err != nil { + err = strconvErr(err) + return fmt.Errorf("converting driver.Value type %T to a %s: %v", src, dv.Kind(), err) + } + dv.SetInt(i64) + return nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + u64, err := AsUint64(src) + if err != nil { + err = strconvErr(err) + return fmt.Errorf("converting driver.Value type %T to a %s: %v", src, dv.Kind(), err) + } + dv.SetUint(u64) + return nil + case reflect.Float32, reflect.Float64: + f64, err := AsFloat64(src) + if err != nil { + err = strconvErr(err) + return fmt.Errorf("converting driver.Value type %T to a %s: %v", src, dv.Kind(), err) + } + dv.SetFloat(f64) + return nil + case reflect.String: + dv.SetString(AsString(src)) + return nil + case reflect.Bool: + b, err := AsBool(src) + if err != nil { + return err + } + dv.SetBool(b) + return nil + case reflect.Slice, reflect.Map, reflect.Struct, reflect.Array: + data, ok := AsBytes(src) + if !ok { + return fmt.Errorf("convert.AssignValue: src cannot be as bytes %#v", src) + } + if data == nil { + return nil + } + if dv.Kind() != reflect.Ptr { + dv = dv.Addr() + } + return json.Unmarshal(data, dv.Interface()) + default: + return fmt.Errorf("convert.AssignValue: unsupported Scan, storing driver.Value type %T into type %T", src, dv.Interface()) + } +} diff --git a/scan.go b/scan.go index 83ad0b02..b712f18a 100644 --- a/scan.go +++ b/scan.go @@ -235,7 +235,7 @@ func (engine *Engine) scan(rows *core.Rows, fields []string, types []*sql.Column for i, replaced := range replaces { if replaced { - if err = convertAssign(vv[i], scanResults[i], engine.DatabaseTZ, engine.TZLocation); err != nil { + if err = convert.Assign(vv[i], scanResults[i], engine.DatabaseTZ, engine.TZLocation); err != nil { return err } } diff --git a/session.go b/session.go index a15f5c3c..304d1079 100644 --- a/session.go +++ b/session.go @@ -15,6 +15,7 @@ import ( "hash/crc32" "io" "reflect" + "strconv" "strings" "xorm.io/xorm/contexts" @@ -464,6 +465,44 @@ func (session *Session) setJSON(fieldValue *reflect.Value, fieldType reflect.Typ return nil } +func asKind(vv reflect.Value, tp reflect.Type) (interface{}, error) { + switch tp.Kind() { + case reflect.Ptr: + return asKind(vv.Elem(), tp.Elem()) + case reflect.Int64: + return vv.Int(), nil + case reflect.Int: + return int(vv.Int()), nil + case reflect.Int32: + return int32(vv.Int()), nil + case reflect.Int16: + return int16(vv.Int()), nil + case reflect.Int8: + return int8(vv.Int()), nil + case reflect.Uint64: + return vv.Uint(), nil + case reflect.Uint: + return uint(vv.Uint()), nil + case reflect.Uint32: + return uint32(vv.Uint()), nil + case reflect.Uint16: + return uint16(vv.Uint()), nil + case reflect.Uint8: + return uint8(vv.Uint()), nil + case reflect.String: + return vv.String(), nil + case reflect.Slice: + if tp.Elem().Kind() == reflect.Uint8 { + v, err := strconv.ParseInt(string(vv.Interface().([]byte)), 10, 64) + if err != nil { + return nil, err + } + return v, nil + } + } + return nil, fmt.Errorf("unsupported primary key type: %v, %v", tp, vv) +} + func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflect.Value, scanResult interface{}, table *schemas.Table) error { v, ok := scanResult.(*interface{}) @@ -612,7 +651,7 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec } } // switch fieldType.Kind() - return convertAssignV(fieldValue.Addr(), scanResult) + return convert.AssignValue(fieldValue.Addr(), scanResult) } func (session *Session) slice2Bean(scanResults []interface{}, fields []string, bean interface{}, dataStruct *reflect.Value, table *schemas.Table) (schemas.PK, error) { diff --git a/session_find.go b/session_find.go index 2f5438fb..82a302b7 100644 --- a/session_find.go +++ b/session_find.go @@ -10,6 +10,7 @@ import ( "xorm.io/builder" "xorm.io/xorm/caches" + "xorm.io/xorm/internal/convert" "xorm.io/xorm/internal/statements" "xorm.io/xorm/internal/utils" "xorm.io/xorm/schemas" @@ -280,7 +281,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect func convertPKToValue(table *schemas.Table, dst interface{}, pk schemas.PK) error { cols := table.PKColumns() if len(cols) == 1 { - return convertAssign(dst, pk[0], nil, nil) + return convert.Assign(dst, pk[0], nil, nil) } dst = pk @@ -478,7 +479,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in keyValue := reflect.New(keyType) var ikey interface{} if len(key) == 1 { - if err := convertAssignV(keyValue, key[0]); err != nil { + if err := convert.AssignValue(keyValue, key[0]); err != nil { return err } ikey = keyValue.Elem().Interface() diff --git a/session_insert.go b/session_insert.go index f35cca53..1583858e 100644 --- a/session_insert.go +++ b/session_insert.go @@ -12,6 +12,7 @@ import ( "strings" "time" + "xorm.io/xorm/internal/convert" "xorm.io/xorm/internal/utils" "xorm.io/xorm/schemas" ) @@ -378,7 +379,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return 1, nil } - return 1, convertAssignV(*aiValue, id) + return 1, convert.AssignValue(*aiValue, id) } res, err := session.exec(sqlStr, args...) @@ -418,7 +419,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return res.RowsAffected() } - if err := convertAssignV(*aiValue, id); err != nil { + if err := convert.AssignValue(*aiValue, id); err != nil { return 0, err } -- 2.40.1