diff --git a/Makefile b/Makefile index 1bdd44c9..5675589d 100644 --- a/Makefile +++ b/Makefile @@ -6,7 +6,7 @@ GOFMT ?= gofmt -s TAGS ?= SED_INPLACE := sed -i -GO_DIRS := caches contexts integrations convert core dialects internal log migrate names schemas tags +GO_DIRS := caches contexts integrations core dialects internal log migrate names schemas tags GOFILES := $(wildcard *.go) GOFILES += $(shell find $(GO_DIRS) -name "*.go" -type f) INTEGRATION_PACKAGES := xorm.io/xorm/integrations diff --git a/convert.go b/convert.go index c3eb4de9..c4fc7867 100644 --- a/convert.go +++ b/convert.go @@ -15,7 +15,7 @@ import ( "strconv" "time" - "xorm.io/xorm/convert" + "xorm.io/xorm/internal/convert" ) var errNilPtr = errors.New("destination pointer is nil") // embedded in descriptive error @@ -36,347 +36,6 @@ func cloneBytes(b []byte) []byte { return c } -func asString(src interface{}) string { - switch v := src.(type) { - case string: - return v - case []byte: - return string(v) - case *sql.NullString: - return v.String - case *sql.NullInt32: - return fmt.Sprintf("%d", v.Int32) - case *sql.NullInt64: - return fmt.Sprintf("%d", v.Int64) - } - rv := reflect.ValueOf(src) - switch rv.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return strconv.FormatInt(rv.Int(), 10) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return strconv.FormatUint(rv.Uint(), 10) - case reflect.Float64: - return strconv.FormatFloat(rv.Float(), 'g', -1, 64) - case reflect.Float32: - return strconv.FormatFloat(rv.Float(), 'g', -1, 32) - case reflect.Bool: - return strconv.FormatBool(rv.Bool()) - } - return fmt.Sprintf("%v", src) -} - -func asInt64(src interface{}) (int64, error) { - switch v := src.(type) { - case int: - return int64(v), nil - case int16: - return int64(v), nil - case int32: - return int64(v), nil - case int8: - return int64(v), nil - case int64: - return v, nil - case uint: - return int64(v), nil - case uint8: - return int64(v), nil - case uint16: - return int64(v), nil - case uint32: - return int64(v), nil - case uint64: - return int64(v), nil - case []byte: - return strconv.ParseInt(string(v), 10, 64) - case string: - return strconv.ParseInt(v, 10, 64) - case *sql.NullString: - return strconv.ParseInt(v.String, 10, 64) - case *sql.NullInt32: - return int64(v.Int32), nil - case *sql.NullInt64: - return int64(v.Int64), nil - } - - rv := reflect.ValueOf(src) - switch rv.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return rv.Int(), nil - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return int64(rv.Uint()), nil - case reflect.Float64, reflect.Float32: - return int64(rv.Float()), nil - case reflect.String: - return strconv.ParseInt(rv.String(), 10, 64) - } - return 0, fmt.Errorf("unsupported value %T as int64", src) -} - -func asUint64(src interface{}) (uint64, error) { - switch v := src.(type) { - case int: - return uint64(v), nil - case int16: - return uint64(v), nil - case int32: - return uint64(v), nil - case int8: - return uint64(v), nil - case int64: - return uint64(v), nil - case uint: - return uint64(v), nil - case uint8: - return uint64(v), nil - case uint16: - return uint64(v), nil - case uint32: - return uint64(v), nil - case uint64: - return v, nil - case []byte: - return strconv.ParseUint(string(v), 10, 64) - case string: - return strconv.ParseUint(v, 10, 64) - case *sql.NullString: - return strconv.ParseUint(v.String, 10, 64) - case *sql.NullInt32: - return uint64(v.Int32), nil - case *sql.NullInt64: - return uint64(v.Int64), nil - } - - rv := reflect.ValueOf(src) - switch rv.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return uint64(rv.Int()), nil - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return uint64(rv.Uint()), nil - case reflect.Float64, reflect.Float32: - return uint64(rv.Float()), nil - case reflect.String: - return strconv.ParseUint(rv.String(), 10, 64) - } - return 0, fmt.Errorf("unsupported value %T as uint64", src) -} - -func asFloat64(src interface{}) (float64, error) { - switch v := src.(type) { - case int: - return float64(v), nil - case int16: - return float64(v), nil - case int32: - return float64(v), nil - case int8: - return float64(v), nil - case int64: - return float64(v), nil - case uint: - return float64(v), nil - case uint8: - return float64(v), nil - case uint16: - return float64(v), nil - case uint32: - return float64(v), nil - case uint64: - return float64(v), nil - case []byte: - return strconv.ParseFloat(string(v), 64) - case string: - return strconv.ParseFloat(v, 64) - case *sql.NullString: - return strconv.ParseFloat(v.String, 64) - case *sql.NullInt32: - return float64(v.Int32), nil - case *sql.NullInt64: - return float64(v.Int64), nil - case *sql.NullFloat64: - return v.Float64, nil - } - - rv := reflect.ValueOf(src) - switch rv.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return float64(rv.Int()), nil - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return float64(rv.Uint()), nil - case reflect.Float64, reflect.Float32: - return float64(rv.Float()), nil - case reflect.String: - return strconv.ParseFloat(rv.String(), 64) - } - return 0, fmt.Errorf("unsupported value %T as int64", src) -} - -func asBigFloat(src interface{}) (*big.Float, error) { - res := big.NewFloat(0) - switch v := src.(type) { - case int: - res.SetInt64(int64(v)) - return res, nil - case int16: - res.SetInt64(int64(v)) - return res, nil - case int32: - res.SetInt64(int64(v)) - return res, nil - case int8: - res.SetInt64(int64(v)) - return res, nil - case int64: - res.SetInt64(int64(v)) - return res, nil - case uint: - res.SetUint64(uint64(v)) - return res, nil - case uint8: - res.SetUint64(uint64(v)) - return res, nil - case uint16: - res.SetUint64(uint64(v)) - return res, nil - case uint32: - res.SetUint64(uint64(v)) - return res, nil - case uint64: - res.SetUint64(uint64(v)) - return res, nil - case []byte: - res.SetString(string(v)) - return res, nil - case string: - res.SetString(v) - return res, nil - case *sql.NullString: - if v.Valid { - res.SetString(v.String) - return res, nil - } - return nil, nil - case *sql.NullInt32: - if v.Valid { - res.SetInt64(int64(v.Int32)) - return res, nil - } - return nil, nil - case *sql.NullInt64: - if v.Valid { - res.SetInt64(int64(v.Int64)) - return res, nil - } - return nil, nil - } - - rv := reflect.ValueOf(src) - switch rv.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - res.SetInt64(rv.Int()) - return res, nil - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - res.SetUint64(rv.Uint()) - return res, nil - case reflect.Float64, reflect.Float32: - res.SetFloat64(rv.Float()) - return res, nil - case reflect.String: - res.SetString(rv.String()) - return res, nil - } - return nil, fmt.Errorf("unsupported value %T as big.Float", src) -} - -func asBytes(src interface{}) ([]byte, bool) { - switch t := src.(type) { - case []byte: - return t, true - case *sql.NullString: - if !t.Valid { - return nil, true - } - return []byte(t.String), true - case *sql.RawBytes: - return *t, true - } - - rv := reflect.ValueOf(src) - - switch rv.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return strconv.AppendInt(nil, rv.Int(), 10), true - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return strconv.AppendUint(nil, rv.Uint(), 10), true - case reflect.Float32: - return strconv.AppendFloat(nil, rv.Float(), 'g', -1, 32), true - case reflect.Float64: - return strconv.AppendFloat(nil, rv.Float(), 'g', -1, 64), true - case reflect.Bool: - return strconv.AppendBool(nil, rv.Bool()), true - case reflect.String: - return []byte(rv.String()), true - } - return nil, false -} - -func asTime(src interface{}, dbLoc *time.Location, uiLoc *time.Location) (*time.Time, error) { - switch t := src.(type) { - case string: - return convert.String2Time(t, dbLoc, uiLoc) - case *sql.NullString: - if !t.Valid { - return nil, nil - } - return convert.String2Time(t.String, dbLoc, uiLoc) - case []uint8: - if t == nil { - return nil, nil - } - return convert.String2Time(string(t), dbLoc, uiLoc) - case *sql.NullTime: - if !t.Valid { - return nil, nil - } - z, _ := t.Time.Zone() - if len(z) == 0 || t.Time.Year() == 0 || t.Time.Location().String() != dbLoc.String() { - tm := time.Date(t.Time.Year(), t.Time.Month(), t.Time.Day(), t.Time.Hour(), - t.Time.Minute(), t.Time.Second(), t.Time.Nanosecond(), dbLoc).In(uiLoc) - return &tm, nil - } - tm := t.Time.In(uiLoc) - return &tm, nil - case *time.Time: - z, _ := t.Zone() - if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbLoc.String() { - tm := time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), - t.Minute(), t.Second(), t.Nanosecond(), dbLoc).In(uiLoc) - return &tm, nil - } - tm := t.In(uiLoc) - return &tm, nil - case time.Time: - z, _ := t.Zone() - if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbLoc.String() { - tm := time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), - t.Minute(), t.Second(), t.Nanosecond(), dbLoc).In(uiLoc) - return &tm, nil - } - tm := t.In(uiLoc) - return &tm, nil - case int: - tm := time.Unix(int64(t), 0).In(uiLoc) - return &tm, nil - case int64: - tm := time.Unix(t, 0).In(uiLoc) - return &tm, nil - case *sql.NullInt64: - tm := time.Unix(t.Int64, 0).In(uiLoc) - return &tm, nil - } - return nil, fmt.Errorf("unsupported value %#v as time", src) -} - // convertAssign copies to dest the value in src, converting it if possible. // An error is returned if the copy would result in loss of information. // dest should be a pointer type. @@ -585,7 +244,7 @@ func convertAssign(dest, src interface{}, originalLocation *time.Location, conve } return nil } - case *NullUint32: + case *convert.NullUint32: switch d := dest.(type) { case *uint8: if s.Valid { @@ -603,7 +262,7 @@ func convertAssign(dest, src interface{}, originalLocation *time.Location, conve } return nil } - case *NullUint64: + case *convert.NullUint64: switch d := dest.(type) { case *uint64: if s.Valid { @@ -628,11 +287,11 @@ func convertAssign(dest, src interface{}, originalLocation *time.Location, conve reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64: - *d = asString(src) + *d = convert.AsString(src) return nil } case *[]byte: - if b, ok := asBytes(src); ok { + if b, ok := convert.AsBytes(src); ok { *d = b return nil } @@ -666,7 +325,7 @@ func convertAssignV(dv reflect.Value, src interface{}) error { } return convertAssignV(dv.Elem(), src) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - i64, err := asInt64(src) + i64, err := convert.AsInt64(src) if err != nil { err = strconvErr(err) return fmt.Errorf("converting driver.Value type %T to a %s: %v", src, dv.Kind(), err) @@ -674,7 +333,7 @@ func convertAssignV(dv reflect.Value, src interface{}) error { dv.SetInt(i64) return nil case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - u64, err := asUint64(src) + u64, err := convert.AsUint64(src) if err != nil { err = strconvErr(err) return fmt.Errorf("converting driver.Value type %T to a %s: %v", src, dv.Kind(), err) @@ -682,7 +341,7 @@ func convertAssignV(dv reflect.Value, src interface{}) error { dv.SetUint(u64) return nil case reflect.Float32, reflect.Float64: - f64, err := asFloat64(src) + f64, err := convert.AsFloat64(src) if err != nil { err = strconvErr(err) return fmt.Errorf("converting driver.Value type %T to a %s: %v", src, dv.Kind(), err) @@ -690,17 +349,17 @@ func convertAssignV(dv reflect.Value, src interface{}) error { dv.SetFloat(f64) return nil case reflect.String: - dv.SetString(asString(src)) + dv.SetString(convert.AsString(src)) return nil case reflect.Bool: - b, err := asBool(src) + b, err := convert.AsBool(src) if err != nil { return err } dv.SetBool(b) return nil case reflect.Slice, reflect.Map, reflect.Struct, reflect.Array: - data, ok := asBytes(src) + data, ok := convert.AsBytes(src) if !ok { return fmt.Errorf("onvertAssignV: src cannot be as bytes %#v", src) } @@ -753,201 +412,3 @@ func asKind(vv reflect.Value, tp reflect.Type) (interface{}, error) { } return nil, fmt.Errorf("unsupported primary key type: %v, %v", tp, vv) } - -func asBool(src interface{}) (bool, error) { - switch v := src.(type) { - case bool: - return v, nil - case *bool: - return *v, nil - case *sql.NullBool: - return v.Bool, nil - case int64: - return v > 0, nil - case int: - return v > 0, nil - case int8: - return v > 0, nil - case int16: - return v > 0, nil - case int32: - return v > 0, nil - case []byte: - if len(v) == 0 { - return false, nil - } - if v[0] == 0x00 { - return false, nil - } else if v[0] == 0x01 { - return true, nil - } - return strconv.ParseBool(string(v)) - case string: - return strconv.ParseBool(v) - case *sql.NullInt64: - return v.Int64 > 0, nil - case *sql.NullInt32: - return v.Int32 > 0, nil - default: - return false, fmt.Errorf("unknow type %T as bool", src) - } -} - -// str2PK convert string value to primary key value according to tp -func str2PKValue(s string, tp reflect.Type) (reflect.Value, error) { - var err error - var result interface{} - var defReturn = reflect.Zero(tp) - - switch tp.Kind() { - case reflect.Int: - result, err = strconv.Atoi(s) - if err != nil { - return defReturn, fmt.Errorf("convert %s as int: %s", s, err.Error()) - } - case reflect.Int8: - x, err := strconv.Atoi(s) - if err != nil { - return defReturn, fmt.Errorf("convert %s as int8: %s", s, err.Error()) - } - result = int8(x) - case reflect.Int16: - x, err := strconv.Atoi(s) - if err != nil { - return defReturn, fmt.Errorf("convert %s as int16: %s", s, err.Error()) - } - result = int16(x) - case reflect.Int32: - x, err := strconv.Atoi(s) - if err != nil { - return defReturn, fmt.Errorf("convert %s as int32: %s", s, err.Error()) - } - result = int32(x) - case reflect.Int64: - result, err = strconv.ParseInt(s, 10, 64) - if err != nil { - return defReturn, fmt.Errorf("convert %s as int64: %s", s, err.Error()) - } - case reflect.Uint: - x, err := strconv.ParseUint(s, 10, 64) - if err != nil { - return defReturn, fmt.Errorf("convert %s as uint: %s", s, err.Error()) - } - result = uint(x) - case reflect.Uint8: - x, err := strconv.ParseUint(s, 10, 64) - if err != nil { - return defReturn, fmt.Errorf("convert %s as uint8: %s", s, err.Error()) - } - result = uint8(x) - case reflect.Uint16: - x, err := strconv.ParseUint(s, 10, 64) - if err != nil { - return defReturn, fmt.Errorf("convert %s as uint16: %s", s, err.Error()) - } - result = uint16(x) - case reflect.Uint32: - x, err := strconv.ParseUint(s, 10, 64) - if err != nil { - return defReturn, fmt.Errorf("convert %s as uint32: %s", s, err.Error()) - } - result = uint32(x) - case reflect.Uint64: - result, err = strconv.ParseUint(s, 10, 64) - if err != nil { - return defReturn, fmt.Errorf("convert %s as uint64: %s", s, err.Error()) - } - case reflect.String: - result = s - default: - return defReturn, errors.New("unsupported convert type") - } - return reflect.ValueOf(result).Convert(tp), nil -} - -func str2PK(s string, tp reflect.Type) (interface{}, error) { - v, err := str2PKValue(s, tp) - if err != nil { - return nil, err - } - return v.Interface(), nil -} - -var ( - _ sql.Scanner = &NullUint64{} -) - -// NullUint64 represents an uint64 that may be null. -// NullUint64 implements the Scanner interface so -// it can be used as a scan destination, similar to NullString. -type NullUint64 struct { - Uint64 uint64 - Valid bool -} - -// Scan implements the Scanner interface. -func (n *NullUint64) Scan(value interface{}) error { - if value == nil { - n.Uint64, n.Valid = 0, false - return nil - } - n.Valid = true - var err error - n.Uint64, err = asUint64(value) - return err -} - -// Value implements the driver Valuer interface. -func (n NullUint64) Value() (driver.Value, error) { - if !n.Valid { - return nil, nil - } - return n.Uint64, nil -} - -var ( - _ sql.Scanner = &NullUint32{} -) - -// NullUint32 represents an uint32 that may be null. -// NullUint32 implements the Scanner interface so -// it can be used as a scan destination, similar to NullString. -type NullUint32 struct { - Uint32 uint32 - Valid bool // Valid is true if Uint32 is not NULL -} - -// Scan implements the Scanner interface. -func (n *NullUint32) Scan(value interface{}) error { - if value == nil { - n.Uint32, n.Valid = 0, false - return nil - } - n.Valid = true - i64, err := asUint64(value) - if err != nil { - return err - } - n.Uint32 = uint32(i64) - return nil -} - -// Value implements the driver Valuer interface. -func (n NullUint32) Value() (driver.Value, error) { - if !n.Valid { - return nil, nil - } - return int64(n.Uint32), nil -} - -var ( - _ sql.Scanner = &EmptyScanner{} -) - -// EmptyScanner represents an empty scanner which will ignore the scan -type EmptyScanner struct{} - -// Scan implements sql.Scanner -func (EmptyScanner) Scan(value interface{}) error { - return nil -} diff --git a/convert/time.go b/convert/time.go deleted file mode 100644 index 6a53171b..00000000 --- a/convert/time.go +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2021 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package convert - -import ( - "fmt" - "strconv" - "time" - - "xorm.io/xorm/internal/utils" -) - -// String2Time converts a string to time with original location -func String2Time(s string, originalLocation *time.Location, convertedLocation *time.Location) (*time.Time, error) { - if len(s) == 19 { - if s == utils.ZeroTime0 || s == utils.ZeroTime1 { - return &time.Time{}, nil - } - dt, err := time.ParseInLocation("2006-01-02 15:04:05", s, originalLocation) - if err != nil { - return nil, err - } - dt = dt.In(convertedLocation) - return &dt, nil - } else if len(s) == 20 && s[10] == 'T' && s[19] == 'Z' { - dt, err := time.ParseInLocation("2006-01-02T15:04:05", s[:19], originalLocation) - if err != nil { - return nil, err - } - dt = dt.In(convertedLocation) - return &dt, nil - } else if len(s) == 25 && s[10] == 'T' && s[19] == '+' && s[22] == ':' { - dt, err := time.Parse(time.RFC3339, s) - if err != nil { - return nil, err - } - dt = dt.In(convertedLocation) - return &dt, nil - } else { - i, err := strconv.ParseInt(s, 10, 64) - if err == nil { - tm := time.Unix(i, 0).In(convertedLocation) - return &tm, nil - } - } - return nil, fmt.Errorf("unsupported conversion from %s to time", s) -} diff --git a/integrations/session_get_test.go b/integrations/session_get_test.go index d3ce2a11..c0376c70 100644 --- a/integrations/session_get_test.go +++ b/integrations/session_get_test.go @@ -9,68 +9,18 @@ import ( "errors" "fmt" "math/big" - "strconv" "testing" "time" "xorm.io/xorm" "xorm.io/xorm/contexts" + "xorm.io/xorm/internal/convert" "xorm.io/xorm/schemas" "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" ) -func convertInt(v interface{}) (int64, error) { - switch v.(type) { - case int: - return int64(v.(int)), nil - case int8: - return int64(v.(int8)), nil - case int16: - return int64(v.(int16)), nil - case int32: - return int64(v.(int32)), nil - case int64: - return v.(int64), nil - case []byte: - i, err := strconv.ParseInt(string(v.([]byte)), 10, 64) - if err != nil { - return 0, err - } - return i, nil - case string: - i, err := strconv.ParseInt(v.(string), 10, 64) - if err != nil { - return 0, err - } - return i, nil - } - return 0, fmt.Errorf("unsupported type: %v", v) -} - -func convertFloat(v interface{}) (float64, error) { - switch v.(type) { - case float32: - return float64(v.(float32)), nil - case float64: - return v.(float64), nil - case string: - i, err := strconv.ParseFloat(v.(string), 64) - if err != nil { - return 0, err - } - return i, nil - case []byte: - i, err := strconv.ParseFloat(string(v.([]byte)), 64) - if err != nil { - return 0, err - } - return i, nil - } - return 0, fmt.Errorf("unsupported type: %v", v) -} - func TestGetVar(t *testing.T) { assert.NoError(t, PrepareEngine()) @@ -261,17 +211,17 @@ func TestGetVar(t *testing.T) { assert.NoError(t, err) assert.Equal(t, true, has) - v1, err := convertInt(valuesSliceInter[0]) + v1, err := convert.AsInt64(valuesSliceInter[0]) assert.NoError(t, err) assert.EqualValues(t, 1, v1) assert.Equal(t, "hi", fmt.Sprintf("%s", valuesSliceInter[1])) - v3, err := convertInt(valuesSliceInter[2]) + v3, err := convert.AsInt64(valuesSliceInter[2]) assert.NoError(t, err) assert.EqualValues(t, 28, v3) - v4, err := convertFloat(valuesSliceInter[3]) + v4, err := convert.AsFloat64(valuesSliceInter[3]) assert.NoError(t, err) assert.Equal(t, "1.5", fmt.Sprintf("%v", v4)) } diff --git a/integrations/types_test.go b/integrations/types_test.go index f192c1ff..9d4e46a0 100644 --- a/integrations/types_test.go +++ b/integrations/types_test.go @@ -13,7 +13,7 @@ import ( "testing" "xorm.io/xorm" - "xorm.io/xorm/convert" + "xorm.io/xorm/internal/convert" "xorm.io/xorm/internal/json" "xorm.io/xorm/schemas" diff --git a/internal/convert/bool.go b/internal/convert/bool.go new file mode 100644 index 00000000..58b23f4b --- /dev/null +++ b/internal/convert/bool.go @@ -0,0 +1,51 @@ +// Copyright 2021 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package convert + +import ( + "database/sql" + "fmt" + "strconv" +) + +// AsBool convert interface as bool +func AsBool(src interface{}) (bool, error) { + switch v := src.(type) { + case bool: + return v, nil + case *bool: + return *v, nil + case *sql.NullBool: + return v.Bool, nil + case int64: + return v > 0, nil + case int: + return v > 0, nil + case int8: + return v > 0, nil + case int16: + return v > 0, nil + case int32: + return v > 0, nil + case []byte: + if len(v) == 0 { + return false, nil + } + if v[0] == 0x00 { + return false, nil + } else if v[0] == 0x01 { + return true, nil + } + return strconv.ParseBool(string(v)) + case string: + return strconv.ParseBool(v) + case *sql.NullInt64: + return v.Int64 > 0, nil + case *sql.NullInt32: + return v.Int32 > 0, nil + default: + return false, fmt.Errorf("unknow type %T as bool", src) + } +} diff --git a/convert/conversion.go b/internal/convert/conversion.go similarity index 100% rename from convert/conversion.go rename to internal/convert/conversion.go diff --git a/internal/convert/float.go b/internal/convert/float.go new file mode 100644 index 00000000..51b441ce --- /dev/null +++ b/internal/convert/float.go @@ -0,0 +1,142 @@ +// Copyright 2021 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package convert + +import ( + "database/sql" + "fmt" + "math/big" + "reflect" + "strconv" +) + +// AsFloat64 convets interface as float64 +func AsFloat64(src interface{}) (float64, error) { + switch v := src.(type) { + case int: + return float64(v), nil + case int16: + return float64(v), nil + case int32: + return float64(v), nil + case int8: + return float64(v), nil + case int64: + return float64(v), nil + case uint: + return float64(v), nil + case uint8: + return float64(v), nil + case uint16: + return float64(v), nil + case uint32: + return float64(v), nil + case uint64: + return float64(v), nil + case []byte: + return strconv.ParseFloat(string(v), 64) + case string: + return strconv.ParseFloat(v, 64) + case *sql.NullString: + return strconv.ParseFloat(v.String, 64) + case *sql.NullInt32: + return float64(v.Int32), nil + case *sql.NullInt64: + return float64(v.Int64), nil + case *sql.NullFloat64: + return v.Float64, nil + } + + rv := reflect.ValueOf(src) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return float64(rv.Int()), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return float64(rv.Uint()), nil + case reflect.Float64, reflect.Float32: + return float64(rv.Float()), nil + case reflect.String: + return strconv.ParseFloat(rv.String(), 64) + } + return 0, fmt.Errorf("unsupported value %T as int64", src) +} + +// AsBigFloat converts interface as big.Float +func AsBigFloat(src interface{}) (*big.Float, error) { + res := big.NewFloat(0) + switch v := src.(type) { + case int: + res.SetInt64(int64(v)) + return res, nil + case int16: + res.SetInt64(int64(v)) + return res, nil + case int32: + res.SetInt64(int64(v)) + return res, nil + case int8: + res.SetInt64(int64(v)) + return res, nil + case int64: + res.SetInt64(int64(v)) + return res, nil + case uint: + res.SetUint64(uint64(v)) + return res, nil + case uint8: + res.SetUint64(uint64(v)) + return res, nil + case uint16: + res.SetUint64(uint64(v)) + return res, nil + case uint32: + res.SetUint64(uint64(v)) + return res, nil + case uint64: + res.SetUint64(uint64(v)) + return res, nil + case []byte: + res.SetString(string(v)) + return res, nil + case string: + res.SetString(v) + return res, nil + case *sql.NullString: + if v.Valid { + res.SetString(v.String) + return res, nil + } + return nil, nil + case *sql.NullInt32: + if v.Valid { + res.SetInt64(int64(v.Int32)) + return res, nil + } + return nil, nil + case *sql.NullInt64: + if v.Valid { + res.SetInt64(int64(v.Int64)) + return res, nil + } + return nil, nil + } + + rv := reflect.ValueOf(src) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + res.SetInt64(rv.Int()) + return res, nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + res.SetUint64(rv.Uint()) + return res, nil + case reflect.Float64, reflect.Float32: + res.SetFloat64(rv.Float()) + return res, nil + case reflect.String: + res.SetString(rv.String()) + return res, nil + } + return nil, fmt.Errorf("unsupported value %T as big.Float", src) +} diff --git a/internal/convert/int.go b/internal/convert/int.go new file mode 100644 index 00000000..af8d4f75 --- /dev/null +++ b/internal/convert/int.go @@ -0,0 +1,178 @@ +// Copyright 2021 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package convert + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "reflect" + "strconv" +) + +// AsInt64 converts interface as int64 +func AsInt64(src interface{}) (int64, error) { + switch v := src.(type) { + case int: + return int64(v), nil + case int16: + return int64(v), nil + case int32: + return int64(v), nil + case int8: + return int64(v), nil + case int64: + return v, nil + case uint: + return int64(v), nil + case uint8: + return int64(v), nil + case uint16: + return int64(v), nil + case uint32: + return int64(v), nil + case uint64: + return int64(v), nil + case []byte: + return strconv.ParseInt(string(v), 10, 64) + case string: + return strconv.ParseInt(v, 10, 64) + case *sql.NullString: + return strconv.ParseInt(v.String, 10, 64) + case *sql.NullInt32: + return int64(v.Int32), nil + case *sql.NullInt64: + return int64(v.Int64), nil + } + + rv := reflect.ValueOf(src) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return rv.Int(), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return int64(rv.Uint()), nil + case reflect.Float64, reflect.Float32: + return int64(rv.Float()), nil + case reflect.String: + return strconv.ParseInt(rv.String(), 10, 64) + } + return 0, fmt.Errorf("unsupported value %T as int64", src) +} + +// AsUint64 converts interface as uint64 +func AsUint64(src interface{}) (uint64, error) { + switch v := src.(type) { + case int: + return uint64(v), nil + case int16: + return uint64(v), nil + case int32: + return uint64(v), nil + case int8: + return uint64(v), nil + case int64: + return uint64(v), nil + case uint: + return uint64(v), nil + case uint8: + return uint64(v), nil + case uint16: + return uint64(v), nil + case uint32: + return uint64(v), nil + case uint64: + return v, nil + case []byte: + return strconv.ParseUint(string(v), 10, 64) + case string: + return strconv.ParseUint(v, 10, 64) + case *sql.NullString: + return strconv.ParseUint(v.String, 10, 64) + case *sql.NullInt32: + return uint64(v.Int32), nil + case *sql.NullInt64: + return uint64(v.Int64), nil + } + + rv := reflect.ValueOf(src) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return uint64(rv.Int()), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return uint64(rv.Uint()), nil + case reflect.Float64, reflect.Float32: + return uint64(rv.Float()), nil + case reflect.String: + return strconv.ParseUint(rv.String(), 10, 64) + } + return 0, fmt.Errorf("unsupported value %T as uint64", src) +} + +var ( + _ sql.Scanner = &NullUint64{} +) + +// NullUint64 represents an uint64 that may be null. +// NullUint64 implements the Scanner interface so +// it can be used as a scan destination, similar to NullString. +type NullUint64 struct { + Uint64 uint64 + Valid bool +} + +// Scan implements the Scanner interface. +func (n *NullUint64) Scan(value interface{}) error { + if value == nil { + n.Uint64, n.Valid = 0, false + return nil + } + n.Valid = true + var err error + n.Uint64, err = AsUint64(value) + return err +} + +// Value implements the driver Valuer interface. +func (n NullUint64) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Uint64, nil +} + +var ( + _ sql.Scanner = &NullUint32{} +) + +// NullUint32 represents an uint32 that may be null. +// NullUint32 implements the Scanner interface so +// it can be used as a scan destination, similar to NullString. +type NullUint32 struct { + Uint32 uint32 + Valid bool // Valid is true if Uint32 is not NULL +} + +// Scan implements the Scanner interface. +func (n *NullUint32) Scan(value interface{}) error { + if value == nil { + n.Uint32, n.Valid = 0, false + return nil + } + n.Valid = true + i64, err := AsUint64(value) + if err != nil { + return err + } + n.Uint32 = uint32(i64) + return nil +} + +// Value implements the driver Valuer interface. +func (n NullUint32) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return int64(n.Uint32), nil +} diff --git a/convert/interface.go b/internal/convert/interface.go similarity index 100% rename from convert/interface.go rename to internal/convert/interface.go diff --git a/internal/convert/scanner.go b/internal/convert/scanner.go new file mode 100644 index 00000000..505d3be0 --- /dev/null +++ b/internal/convert/scanner.go @@ -0,0 +1,19 @@ +// Copyright 2021 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package convert + +import "database/sql" + +var ( + _ sql.Scanner = &EmptyScanner{} +) + +// EmptyScanner represents an empty scanner which will ignore the scan +type EmptyScanner struct{} + +// Scan implements sql.Scanner +func (EmptyScanner) Scan(value interface{}) error { + return nil +} diff --git a/internal/convert/string.go b/internal/convert/string.go new file mode 100644 index 00000000..de11fa01 --- /dev/null +++ b/internal/convert/string.go @@ -0,0 +1,75 @@ +// Copyright 2021 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package convert + +import ( + "database/sql" + "fmt" + "reflect" + "strconv" +) + +// AsString converts interface as string +func AsString(src interface{}) string { + switch v := src.(type) { + case string: + return v + case []byte: + return string(v) + case *sql.NullString: + return v.String + case *sql.NullInt32: + return fmt.Sprintf("%d", v.Int32) + case *sql.NullInt64: + return fmt.Sprintf("%d", v.Int64) + } + rv := reflect.ValueOf(src) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return strconv.FormatInt(rv.Int(), 10) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return strconv.FormatUint(rv.Uint(), 10) + case reflect.Float64: + return strconv.FormatFloat(rv.Float(), 'g', -1, 64) + case reflect.Float32: + return strconv.FormatFloat(rv.Float(), 'g', -1, 32) + case reflect.Bool: + return strconv.FormatBool(rv.Bool()) + } + return fmt.Sprintf("%v", src) +} + +// AsBytes converts interface as bytes +func AsBytes(src interface{}) ([]byte, bool) { + switch t := src.(type) { + case []byte: + return t, true + case *sql.NullString: + if !t.Valid { + return nil, true + } + return []byte(t.String), true + case *sql.RawBytes: + return *t, true + } + + rv := reflect.ValueOf(src) + + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return strconv.AppendInt(nil, rv.Int(), 10), true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return strconv.AppendUint(nil, rv.Uint(), 10), true + case reflect.Float32: + return strconv.AppendFloat(nil, rv.Float(), 'g', -1, 32), true + case reflect.Float64: + return strconv.AppendFloat(nil, rv.Float(), 'g', -1, 64), true + case reflect.Bool: + return strconv.AppendBool(nil, rv.Bool()), true + case reflect.String: + return []byte(rv.String()), true + } + return nil, false +} diff --git a/internal/convert/time.go b/internal/convert/time.go new file mode 100644 index 00000000..ecb30a3f --- /dev/null +++ b/internal/convert/time.go @@ -0,0 +1,108 @@ +// Copyright 2021 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package convert + +import ( + "database/sql" + "fmt" + "strconv" + "time" + + "xorm.io/xorm/internal/utils" +) + +// String2Time converts a string to time with original location +func String2Time(s string, originalLocation *time.Location, convertedLocation *time.Location) (*time.Time, error) { + if len(s) == 19 { + if s == utils.ZeroTime0 || s == utils.ZeroTime1 { + return &time.Time{}, nil + } + dt, err := time.ParseInLocation("2006-01-02 15:04:05", s, originalLocation) + if err != nil { + return nil, err + } + dt = dt.In(convertedLocation) + return &dt, nil + } else if len(s) == 20 && s[10] == 'T' && s[19] == 'Z' { + dt, err := time.ParseInLocation("2006-01-02T15:04:05", s[:19], originalLocation) + if err != nil { + return nil, err + } + dt = dt.In(convertedLocation) + return &dt, nil + } else if len(s) == 25 && s[10] == 'T' && s[19] == '+' && s[22] == ':' { + dt, err := time.Parse(time.RFC3339, s) + if err != nil { + return nil, err + } + dt = dt.In(convertedLocation) + return &dt, nil + } else { + i, err := strconv.ParseInt(s, 10, 64) + if err == nil { + tm := time.Unix(i, 0).In(convertedLocation) + return &tm, nil + } + } + return nil, fmt.Errorf("unsupported conversion from %s to time", s) +} + +// AsTime converts interface as time +func AsTime(src interface{}, dbLoc *time.Location, uiLoc *time.Location) (*time.Time, error) { + switch t := src.(type) { + case string: + return String2Time(t, dbLoc, uiLoc) + case *sql.NullString: + if !t.Valid { + return nil, nil + } + return String2Time(t.String, dbLoc, uiLoc) + case []uint8: + if t == nil { + return nil, nil + } + return String2Time(string(t), dbLoc, uiLoc) + case *sql.NullTime: + if !t.Valid { + return nil, nil + } + z, _ := t.Time.Zone() + if len(z) == 0 || t.Time.Year() == 0 || t.Time.Location().String() != dbLoc.String() { + tm := time.Date(t.Time.Year(), t.Time.Month(), t.Time.Day(), t.Time.Hour(), + t.Time.Minute(), t.Time.Second(), t.Time.Nanosecond(), dbLoc).In(uiLoc) + return &tm, nil + } + tm := t.Time.In(uiLoc) + return &tm, nil + case *time.Time: + z, _ := t.Zone() + if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbLoc.String() { + tm := time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), + t.Minute(), t.Second(), t.Nanosecond(), dbLoc).In(uiLoc) + return &tm, nil + } + tm := t.In(uiLoc) + return &tm, nil + case time.Time: + z, _ := t.Zone() + if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbLoc.String() { + tm := time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), + t.Minute(), t.Second(), t.Nanosecond(), dbLoc).In(uiLoc) + return &tm, nil + } + tm := t.In(uiLoc) + return &tm, nil + case int: + tm := time.Unix(int64(t), 0).In(uiLoc) + return &tm, nil + case int64: + tm := time.Unix(t, 0).In(uiLoc) + return &tm, nil + case *sql.NullInt64: + tm := time.Unix(t.Int64, 0).In(uiLoc) + return &tm, nil + } + return nil, fmt.Errorf("unsupported value %#v as time", src) +} diff --git a/convert/time_test.go b/internal/convert/time_test.go similarity index 100% rename from convert/time_test.go rename to internal/convert/time_test.go diff --git a/internal/statements/statement.go b/internal/statements/statement.go index 0e245a96..8e3c083c 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -15,8 +15,8 @@ import ( "xorm.io/builder" "xorm.io/xorm/contexts" - "xorm.io/xorm/convert" "xorm.io/xorm/dialects" + "xorm.io/xorm/internal/convert" "xorm.io/xorm/internal/json" "xorm.io/xorm/internal/utils" "xorm.io/xorm/schemas" diff --git a/internal/statements/update.go b/internal/statements/update.go index 3020595b..be6ed885 100644 --- a/internal/statements/update.go +++ b/internal/statements/update.go @@ -11,8 +11,8 @@ import ( "reflect" "time" - "xorm.io/xorm/convert" "xorm.io/xorm/dialects" + "xorm.io/xorm/internal/convert" "xorm.io/xorm/internal/json" "xorm.io/xorm/internal/utils" "xorm.io/xorm/schemas" diff --git a/internal/statements/values.go b/internal/statements/values.go index c572ead5..ada01755 100644 --- a/internal/statements/values.go +++ b/internal/statements/values.go @@ -12,8 +12,8 @@ import ( "reflect" "time" - "xorm.io/xorm/convert" "xorm.io/xorm/dialects" + "xorm.io/xorm/internal/convert" "xorm.io/xorm/internal/json" "xorm.io/xorm/schemas" ) diff --git a/scan.go b/scan.go index ccd6938d..83ad0b02 100644 --- a/scan.go +++ b/scan.go @@ -11,9 +11,9 @@ import ( "reflect" "time" - "xorm.io/xorm/convert" "xorm.io/xorm/core" "xorm.io/xorm/dialects" + "xorm.io/xorm/internal/convert" "xorm.io/xorm/schemas" ) @@ -35,9 +35,9 @@ func genScanResultsByBeanNullable(bean interface{}) (interface{}, bool, error) { case *int64: return &sql.NullInt64{}, true, nil case *uint, *uint8, *uint16, *uint32: - return &NullUint32{}, true, nil + return &convert.NullUint32{}, true, nil case *uint64: - return &NullUint64{}, true, nil + return &convert.NullUint64{}, true, nil case *float32, *float64: return &sql.NullFloat64{}, true, nil case *bool: @@ -63,9 +63,9 @@ func genScanResultsByBeanNullable(bean interface{}) (interface{}, bool, error) { case reflect.Int32, reflect.Int, reflect.Int16, reflect.Int8: return &sql.NullInt32{}, true, nil case reflect.Uint64: - return &NullUint64{}, true, nil + return &convert.NullUint64{}, true, nil case reflect.Uint32, reflect.Uint, reflect.Uint16, reflect.Uint8: - return &NullUint32{}, true, nil + return &convert.NullUint32{}, true, nil default: return nil, false, fmt.Errorf("unsupported type: %#v", bean) } diff --git a/session.go b/session.go index 62d6a770..a15f5c3c 100644 --- a/session.go +++ b/session.go @@ -18,8 +18,8 @@ import ( "strings" "xorm.io/xorm/contexts" - "xorm.io/xorm/convert" "xorm.io/xorm/core" + "xorm.io/xorm/internal/convert" "xorm.io/xorm/internal/json" "xorm.io/xorm/internal/statements" "xorm.io/xorm/log" @@ -435,7 +435,7 @@ func (session *Session) row2Slice(rows *core.Rows, fields []string, types []*sql } func (session *Session) setJSON(fieldValue *reflect.Value, fieldType reflect.Type, scanResult interface{}) error { - bs, ok := asBytes(scanResult) + bs, ok := convert.AsBytes(scanResult) if !ok { return fmt.Errorf("unsupported database data type: %#v", scanResult) } @@ -476,7 +476,7 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec if fieldValue.CanAddr() { if structConvert, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { - data, ok := asBytes(scanResult) + data, ok := convert.AsBytes(scanResult) if !ok { return fmt.Errorf("cannot convert %#v as bytes", scanResult) } @@ -485,7 +485,7 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec } if structConvert, ok := fieldValue.Interface().(convert.Conversion); ok { - data, ok := asBytes(scanResult) + data, ok := convert.AsBytes(scanResult) if !ok { return fmt.Errorf("cannot convert %#v as bytes", scanResult) } @@ -525,7 +525,7 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec case reflect.Complex64, reflect.Complex128: return session.setJSON(fieldValue, fieldType, scanResult) case reflect.Slice, reflect.Array: - bs, ok := asBytes(scanResult) + bs, ok := convert.AsBytes(scanResult) if ok && fieldType.Elem().Kind() == reflect.Uint8 { if col.SQLType.IsText() { x := reflect.New(fieldType) @@ -551,7 +551,7 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec } case reflect.Struct: if fieldType.ConvertibleTo(schemas.BigFloatType) { - v, err := asBigFloat(scanResult) + v, err := convert.AsBigFloat(scanResult) if err != nil { return err } @@ -565,7 +565,7 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec dbTZ = col.TimeZone } - t, err := asTime(scanResult, dbTZ, session.engine.TZLocation) + t, err := convert.AsTime(scanResult, dbTZ, session.engine.TZLocation) if err != nil { return err } diff --git a/session_find.go b/session_find.go index 010ecd6c..2f5438fb 100644 --- a/session_find.go +++ b/session_find.go @@ -6,7 +6,6 @@ package xorm import ( "errors" - "fmt" "reflect" "xorm.io/builder" @@ -476,12 +475,13 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in } else if sliceValue.Kind() == reflect.Map { var key = ids[j] keyType := sliceValue.Type().Key() + keyValue := reflect.New(keyType) var ikey interface{} if len(key) == 1 { - ikey, err = str2PK(fmt.Sprintf("%v", key[0]), keyType) - if err != nil { + if err := convertAssignV(keyValue, key[0]); err != nil { return err } + ikey = keyValue.Elem().Interface() } else { if keyType.Kind() != reflect.Slice { return errors.New("table have multiple primary keys, key is not schemas.PK or slice") diff --git a/session_get.go b/session_get.go index 08172524..617ca169 100644 --- a/session_get.go +++ b/session_get.go @@ -15,8 +15,8 @@ import ( "time" "xorm.io/xorm/caches" - "xorm.io/xorm/convert" "xorm.io/xorm/core" + "xorm.io/xorm/internal/convert" "xorm.io/xorm/internal/utils" "xorm.io/xorm/schemas" ) diff --git a/tags/parser.go b/tags/parser.go index 9f9a8f62..7d8c3bd6 100644 --- a/tags/parser.go +++ b/tags/parser.go @@ -14,8 +14,8 @@ import ( "unicode" "xorm.io/xorm/caches" - "xorm.io/xorm/convert" "xorm.io/xorm/dialects" + "xorm.io/xorm/internal/convert" "xorm.io/xorm/names" "xorm.io/xorm/schemas" )