Support big.Float #1973

Merged
lunny merged 2 commits from lunny/support_bigfloat into master 2021-07-07 06:00:17 +00:00
6 changed files with 102 additions and 45 deletions

View File

@ -9,6 +9,7 @@ import (
"database/sql/driver"
"errors"
"fmt"
"math/big"
"reflect"
"strconv"
"time"
@ -308,10 +309,12 @@ func convertAssign(dest, src interface{}, originalLocation *time.Location, conve
if s.Valid {
*d, _ = strconv.Atoi(s.String)
}
return nil
case *int64:
if s.Valid {
*d, _ = strconv.ParseInt(s.String, 10, 64)
}
return nil
case *string:
if s.Valid {
*d = s.String
@ -337,6 +340,15 @@ func convertAssign(dest, src interface{}, originalLocation *time.Location, conve
d.Valid = true
d.Time = *dt
}
return nil
case *big.Float:
if s.Valid {
if d == nil {
d = big.NewFloat(0)
}
d.SetString(s.String)
}
return nil
}
case *sql.NullInt32:
switch d := dest.(type) {

View File

@ -565,7 +565,7 @@ func (p *sqlite3Driver) GenScanResult(colType string) (interface{}, error) {
case "REAL":
var s sql.NullFloat64
return &s, nil
case "NUMERIC":
case "NUMERIC", "DECIMAL":
var s sql.NullString
return &s, nil
case "BLOB":

View File

@ -8,6 +8,7 @@ import (
"database/sql"
"errors"
"fmt"
"math/big"
"strconv"
"testing"
"time"
@ -766,3 +767,53 @@ func TestGetNil(t *testing.T) {
assert.True(t, errors.Is(err, xorm.ErrObjectIsNil))
assert.False(t, has)
}
func TestGetBigFloat(t *testing.T) {
type GetBigFloat struct {
Id int64
Money *big.Float `xorm:"numeric(22,2)"`
}
assert.NoError(t, PrepareEngine())
assertSync(t, new(GetBigFloat))
{
var gf = GetBigFloat{
Money: big.NewFloat(999999.99),
}
_, err := testEngine.Insert(&gf)
assert.NoError(t, err)
var m big.Float
has, err := testEngine.Table("get_big_float").Cols("money").Where("id=?", gf.Id).Get(&m)
assert.NoError(t, err)
assert.True(t, has)
assert.True(t, m.String() == gf.Money.String(), "%v != %v", m.String(), gf.Money.String())
//fmt.Println(m.Cmp(gf.Money))
//assert.True(t, m.Cmp(gf.Money) == 0, "%v != %v", m.String(), gf.Money.String())
}
type GetBigFloat2 struct {
Id int64
Money *big.Float `xorm:"decimal(22,2)"`
}
assert.NoError(t, PrepareEngine())
assertSync(t, new(GetBigFloat2))
{
var gf2 = GetBigFloat2{
Money: big.NewFloat(9999999.99),
}
_, err := testEngine.Insert(&gf2)
assert.NoError(t, err)
var m2 big.Float
has, err := testEngine.Table("get_big_float2").Cols("money").Where("id=?", gf2.Id).Get(&m2)
assert.NoError(t, err)
assert.True(t, has)
assert.True(t, m2.String() == gf2.Money.String(), "%v != %v", m2.String(), gf2.Money.String())
//fmt.Println(m.Cmp(gf.Money))
//assert.True(t, m.Cmp(gf.Money) == 0, "%v != %v", m.String(), gf.Money.String())
}
}

View File

@ -8,6 +8,7 @@ import (
"database/sql"
"database/sql/driver"
"fmt"
"math/big"
"reflect"
"time"
@ -19,6 +20,7 @@ import (
var (
nullFloatType = reflect.TypeOf(sql.NullFloat64{})
bigFloatType = reflect.TypeOf(big.Float{})
)
// Value2Interface convert a field value of a struct to interface for puting into database
@ -84,6 +86,9 @@ func (statement *Statement) Value2Interface(col *schemas.Column, fieldValue refl
return nil, nil
}
return t.Float64, nil
} else if fieldType.ConvertibleTo(bigFloatType) {
t := fieldValue.Convert(bigFloatType).Interface().(big.Float)
return t.String(), nil
}
if !col.IsJSON {

24
scan.go
View File

@ -7,6 +7,7 @@ package xorm
import (
"database/sql"
"fmt"
"math/big"
"reflect"
"time"
@ -182,13 +183,21 @@ func (engine *Engine) scan(rows *core.Rows, fields []string, types []*sql.Column
for _, v := range vv {
var replaced bool
var scanResult interface{}
if _, ok := v.(sql.Scanner); !ok {
switch t := v.(type) {
case sql.Scanner:
scanResult = t
case convert.Conversion:
scanResult = &sql.RawBytes{}
replaced = true
case *big.Float:
scanResult = &sql.NullString{}
replaced = true
default:
var useNullable = true
if engine.driver.Features().SupportNullable {
nullable, ok := types[0].Nullable()
useNullable = ok && nullable
}
if useNullable {
scanResult, replaced, err = genScanResultsByBeanNullable(v)
} else {
@ -197,25 +206,22 @@ func (engine *Engine) scan(rows *core.Rows, fields []string, types []*sql.Column
if err != nil {
return err
}
} else {
scanResult = v
}
scanResults = append(scanResults, scanResult)
replaces = append(replaces, replaced)
}
var scanCtx = dialects.ScanContext{
if err = engine.driver.Scan(&dialects.ScanContext{
DBLocation: engine.DatabaseTZ,
UserLocation: engine.TZLocation,
}
if err = engine.driver.Scan(&scanCtx, rows, types, scanResults...); err != nil {
}, rows, types, scanResults...); err != nil {
return err
}
for i, replaced := range replaces {
if replaced {
if err = convertAssign(vv[i], scanResults[i], scanCtx.DBLocation, engine.TZLocation); err != nil {
if err = convertAssign(vv[i], scanResults[i], engine.DatabaseTZ, engine.TZLocation); err != nil {
return err
}
}

View File

@ -9,6 +9,7 @@ import (
"database/sql/driver"
"errors"
"fmt"
"math/big"
"reflect"
"strconv"
"time"
@ -123,6 +124,20 @@ var (
conversionType = reflect.TypeOf(&conversionTypePlaceHolder).Elem()
)
func isScannableStruct(bean interface{}, typeLen int) bool {
switch bean.(type) {
case *time.Time:
return false
case sql.Scanner:
return false
case convert.Conversion:
return typeLen > 1
case *big.Float:
return false
}
return true
}
func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, bean interface{}, sqlStr string, args ...interface{}) (bool, error) {
rows, err := session.queryRows(sqlStr, args...)
if err != nil {
@ -148,13 +163,7 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table,
}
switch beanKind {
case reflect.Struct:
if _, ok := bean.(*time.Time); ok {
break
}
if _, ok := bean.(sql.Scanner); ok {
break
}
if _, ok := bean.(convert.Conversion); len(types) == 1 && ok {
if !isScannableStruct(bean, len(types)) {
break
}
return session.getStruct(rows, types, fields, table, bean)
@ -240,35 +249,9 @@ func (session *Session) getVars(rows *core.Rows, types []*sql.ColumnType, fields
if len(beans) != len(types) {
return false, fmt.Errorf("expected columns %d, but only %d variables", len(types), len(beans))
}
var scanResults = make([]interface{}, 0, len(types))
var replaceds = make([]bool, 0, len(types))
for _, bean := range beans {
switch t := bean.(type) {
case sql.Scanner:
scanResults = append(scanResults, t)
replaceds = append(replaceds, false)
case convert.Conversion:
scanResults = append(scanResults, &sql.RawBytes{})
replaceds = append(replaceds, true)
default:
scanResults = append(scanResults, bean)
replaceds = append(replaceds, false)
}
}
err := session.engine.scan(rows, fields, types, scanResults...)
if err != nil {
return true, err
}
for i, replaced := range replaceds {
if replaced {
err = convertAssign(beans[i], scanResults[i], session.engine.DatabaseTZ, session.engine.TZLocation)
if err != nil {
return true, err
}
}
}
return true, nil
err := session.engine.scan(rows, fields, types, beans...)
return true, err
}
func (session *Session) getStruct(rows *core.Rows, types []*sql.ColumnType, fields []string, table *schemas.Table, bean interface{}) (bool, error) {