Support big.Float #1973
12
convert.go
12
convert.go
|
@ -9,6 +9,7 @@ import (
|
|||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"time"
|
||||
|
@ -308,10 +309,12 @@ func convertAssign(dest, src interface{}, originalLocation *time.Location, conve
|
|||
if s.Valid {
|
||||
*d, _ = strconv.Atoi(s.String)
|
||||
}
|
||||
return nil
|
||||
case *int64:
|
||||
if s.Valid {
|
||||
*d, _ = strconv.ParseInt(s.String, 10, 64)
|
||||
}
|
||||
return nil
|
||||
case *string:
|
||||
if s.Valid {
|
||||
*d = s.String
|
||||
|
@ -337,6 +340,15 @@ func convertAssign(dest, src interface{}, originalLocation *time.Location, conve
|
|||
d.Valid = true
|
||||
d.Time = *dt
|
||||
}
|
||||
return nil
|
||||
case *big.Float:
|
||||
if s.Valid {
|
||||
if d == nil {
|
||||
d = big.NewFloat(0)
|
||||
}
|
||||
d.SetString(s.String)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
case *sql.NullInt32:
|
||||
switch d := dest.(type) {
|
||||
|
|
|
@ -565,7 +565,7 @@ func (p *sqlite3Driver) GenScanResult(colType string) (interface{}, error) {
|
|||
case "REAL":
|
||||
var s sql.NullFloat64
|
||||
return &s, nil
|
||||
case "NUMERIC":
|
||||
case "NUMERIC", "DECIMAL":
|
||||
var s sql.NullString
|
||||
return &s, nil
|
||||
case "BLOB":
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -766,3 +767,53 @@ func TestGetNil(t *testing.T) {
|
|||
assert.True(t, errors.Is(err, xorm.ErrObjectIsNil))
|
||||
assert.False(t, has)
|
||||
}
|
||||
|
||||
func TestGetBigFloat(t *testing.T) {
|
||||
type GetBigFloat struct {
|
||||
Id int64
|
||||
Money *big.Float `xorm:"numeric(22,2)"`
|
||||
}
|
||||
|
||||
assert.NoError(t, PrepareEngine())
|
||||
assertSync(t, new(GetBigFloat))
|
||||
|
||||
{
|
||||
var gf = GetBigFloat{
|
||||
Money: big.NewFloat(999999.99),
|
||||
}
|
||||
_, err := testEngine.Insert(&gf)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var m big.Float
|
||||
has, err := testEngine.Table("get_big_float").Cols("money").Where("id=?", gf.Id).Get(&m)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, has)
|
||||
assert.True(t, m.String() == gf.Money.String(), "%v != %v", m.String(), gf.Money.String())
|
||||
//fmt.Println(m.Cmp(gf.Money))
|
||||
//assert.True(t, m.Cmp(gf.Money) == 0, "%v != %v", m.String(), gf.Money.String())
|
||||
}
|
||||
|
||||
type GetBigFloat2 struct {
|
||||
Id int64
|
||||
Money *big.Float `xorm:"decimal(22,2)"`
|
||||
}
|
||||
|
||||
assert.NoError(t, PrepareEngine())
|
||||
assertSync(t, new(GetBigFloat2))
|
||||
|
||||
{
|
||||
var gf2 = GetBigFloat2{
|
||||
Money: big.NewFloat(9999999.99),
|
||||
}
|
||||
_, err := testEngine.Insert(&gf2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var m2 big.Float
|
||||
has, err := testEngine.Table("get_big_float2").Cols("money").Where("id=?", gf2.Id).Get(&m2)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, has)
|
||||
assert.True(t, m2.String() == gf2.Money.String(), "%v != %v", m2.String(), gf2.Money.String())
|
||||
//fmt.Println(m.Cmp(gf.Money))
|
||||
//assert.True(t, m.Cmp(gf.Money) == 0, "%v != %v", m.String(), gf.Money.String())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
|
@ -19,6 +20,7 @@ import (
|
|||
|
||||
var (
|
||||
nullFloatType = reflect.TypeOf(sql.NullFloat64{})
|
||||
bigFloatType = reflect.TypeOf(big.Float{})
|
||||
)
|
||||
|
||||
// Value2Interface convert a field value of a struct to interface for puting into database
|
||||
|
@ -84,6 +86,9 @@ func (statement *Statement) Value2Interface(col *schemas.Column, fieldValue refl
|
|||
return nil, nil
|
||||
}
|
||||
return t.Float64, nil
|
||||
} else if fieldType.ConvertibleTo(bigFloatType) {
|
||||
t := fieldValue.Convert(bigFloatType).Interface().(big.Float)
|
||||
return t.String(), nil
|
||||
}
|
||||
|
||||
if !col.IsJSON {
|
||||
|
|
24
scan.go
24
scan.go
|
@ -7,6 +7,7 @@ package xorm
|
|||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
|
@ -182,13 +183,21 @@ func (engine *Engine) scan(rows *core.Rows, fields []string, types []*sql.Column
|
|||
for _, v := range vv {
|
||||
var replaced bool
|
||||
var scanResult interface{}
|
||||
if _, ok := v.(sql.Scanner); !ok {
|
||||
switch t := v.(type) {
|
||||
case sql.Scanner:
|
||||
scanResult = t
|
||||
case convert.Conversion:
|
||||
scanResult = &sql.RawBytes{}
|
||||
replaced = true
|
||||
case *big.Float:
|
||||
scanResult = &sql.NullString{}
|
||||
replaced = true
|
||||
default:
|
||||
var useNullable = true
|
||||
if engine.driver.Features().SupportNullable {
|
||||
nullable, ok := types[0].Nullable()
|
||||
useNullable = ok && nullable
|
||||
}
|
||||
|
||||
if useNullable {
|
||||
scanResult, replaced, err = genScanResultsByBeanNullable(v)
|
||||
} else {
|
||||
|
@ -197,25 +206,22 @@ func (engine *Engine) scan(rows *core.Rows, fields []string, types []*sql.Column
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
scanResult = v
|
||||
}
|
||||
|
||||
scanResults = append(scanResults, scanResult)
|
||||
replaces = append(replaces, replaced)
|
||||
}
|
||||
|
||||
var scanCtx = dialects.ScanContext{
|
||||
if err = engine.driver.Scan(&dialects.ScanContext{
|
||||
DBLocation: engine.DatabaseTZ,
|
||||
UserLocation: engine.TZLocation,
|
||||
}
|
||||
|
||||
if err = engine.driver.Scan(&scanCtx, rows, types, scanResults...); err != nil {
|
||||
}, rows, types, scanResults...); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i, replaced := range replaces {
|
||||
if replaced {
|
||||
if err = convertAssign(vv[i], scanResults[i], scanCtx.DBLocation, engine.TZLocation); err != nil {
|
||||
if err = convertAssign(vv[i], scanResults[i], engine.DatabaseTZ, engine.TZLocation); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"time"
|
||||
|
@ -123,6 +124,20 @@ var (
|
|||
conversionType = reflect.TypeOf(&conversionTypePlaceHolder).Elem()
|
||||
)
|
||||
|
||||
func isScannableStruct(bean interface{}, typeLen int) bool {
|
||||
switch bean.(type) {
|
||||
case *time.Time:
|
||||
return false
|
||||
case sql.Scanner:
|
||||
return false
|
||||
case convert.Conversion:
|
||||
return typeLen > 1
|
||||
case *big.Float:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, bean interface{}, sqlStr string, args ...interface{}) (bool, error) {
|
||||
rows, err := session.queryRows(sqlStr, args...)
|
||||
if err != nil {
|
||||
|
@ -148,13 +163,7 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table,
|
|||
}
|
||||
switch beanKind {
|
||||
case reflect.Struct:
|
||||
if _, ok := bean.(*time.Time); ok {
|
||||
break
|
||||
}
|
||||
if _, ok := bean.(sql.Scanner); ok {
|
||||
break
|
||||
}
|
||||
if _, ok := bean.(convert.Conversion); len(types) == 1 && ok {
|
||||
if !isScannableStruct(bean, len(types)) {
|
||||
break
|
||||
}
|
||||
return session.getStruct(rows, types, fields, table, bean)
|
||||
|
@ -240,35 +249,9 @@ func (session *Session) getVars(rows *core.Rows, types []*sql.ColumnType, fields
|
|||
if len(beans) != len(types) {
|
||||
return false, fmt.Errorf("expected columns %d, but only %d variables", len(types), len(beans))
|
||||
}
|
||||
var scanResults = make([]interface{}, 0, len(types))
|
||||
var replaceds = make([]bool, 0, len(types))
|
||||
for _, bean := range beans {
|
||||
switch t := bean.(type) {
|
||||
case sql.Scanner:
|
||||
scanResults = append(scanResults, t)
|
||||
replaceds = append(replaceds, false)
|
||||
case convert.Conversion:
|
||||
scanResults = append(scanResults, &sql.RawBytes{})
|
||||
replaceds = append(replaceds, true)
|
||||
default:
|
||||
scanResults = append(scanResults, bean)
|
||||
replaceds = append(replaceds, false)
|
||||
}
|
||||
}
|
||||
|
||||
err := session.engine.scan(rows, fields, types, scanResults...)
|
||||
if err != nil {
|
||||
err := session.engine.scan(rows, fields, types, beans...)
|
||||
return true, err
|
||||
}
|
||||
for i, replaced := range replaceds {
|
||||
if replaced {
|
||||
err = convertAssign(beans[i], scanResults[i], session.engine.DatabaseTZ, session.engine.TZLocation)
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (session *Session) getStruct(rows *core.Rows, types []*sql.ColumnType, fields []string, table *schemas.Table, bean interface{}) (bool, error) {
|
||||
|
|
Loading…
Reference in New Issue
Block a user