From 4b3ca2dd42659803cfe0c26addaf7952a961c41e Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Wed, 7 Jul 2021 11:19:50 +0800 Subject: [PATCH 1/2] Support big.Float --- convert.go | 12 ++++++++ integrations/session_get_test.go | 25 +++++++++++++++ internal/statements/values.go | 5 +++ scan.go | 24 +++++++++------ session_get.go | 53 +++++++++++--------------------- 5 files changed, 75 insertions(+), 44 deletions(-) diff --git a/convert.go b/convert.go index f7d733ad..c61180d3 100644 --- a/convert.go +++ b/convert.go @@ -9,6 +9,7 @@ import ( "database/sql/driver" "errors" "fmt" + "math/big" "reflect" "strconv" "time" @@ -308,10 +309,12 @@ func convertAssign(dest, src interface{}, originalLocation *time.Location, conve if s.Valid { *d, _ = strconv.Atoi(s.String) } + return nil case *int64: if s.Valid { *d, _ = strconv.ParseInt(s.String, 10, 64) } + return nil case *string: if s.Valid { *d = s.String @@ -337,6 +340,15 @@ func convertAssign(dest, src interface{}, originalLocation *time.Location, conve d.Valid = true d.Time = *dt } + return nil + case *big.Float: + if s.Valid { + if d == nil { + d = big.NewFloat(0) + } + d.SetString(s.String) + } + return nil } case *sql.NullInt32: switch d := dest.(type) { diff --git a/integrations/session_get_test.go b/integrations/session_get_test.go index 99db98fc..47b38dfa 100644 --- a/integrations/session_get_test.go +++ b/integrations/session_get_test.go @@ -8,6 +8,7 @@ import ( "database/sql" "errors" "fmt" + "math/big" "strconv" "testing" "time" @@ -766,3 +767,27 @@ func TestGetNil(t *testing.T) { assert.True(t, errors.Is(err, xorm.ErrObjectIsNil)) assert.False(t, has) } + +func TestGetBigFloat(t *testing.T) { + type GetBigFloat struct { + Id int64 + Money *big.Float `xorm:"numeric"` + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(GetBigFloat)) + + var gf = GetBigFloat{ + Money: big.NewFloat(999999.99), + } + _, err := testEngine.Insert(&gf) + assert.NoError(t, err) + + var m big.Float + has, err := testEngine.Table("get_big_float").Cols("money").Where("id=?", gf.Id).Get(&m) + assert.NoError(t, err) + assert.True(t, has) + assert.True(t, m.String() == gf.Money.String()) + //fmt.Println(m.Cmp(gf.Money)) + //assert.True(t, m.Cmp(gf.Money) == 0, "%v != %v", m.String(), gf.Money.String()) +} diff --git a/internal/statements/values.go b/internal/statements/values.go index 71327c55..994070ac 100644 --- a/internal/statements/values.go +++ b/internal/statements/values.go @@ -8,6 +8,7 @@ import ( "database/sql" "database/sql/driver" "fmt" + "math/big" "reflect" "time" @@ -19,6 +20,7 @@ import ( var ( nullFloatType = reflect.TypeOf(sql.NullFloat64{}) + bigFloatType = reflect.TypeOf(big.Float{}) ) // Value2Interface convert a field value of a struct to interface for puting into database @@ -84,6 +86,9 @@ func (statement *Statement) Value2Interface(col *schemas.Column, fieldValue refl return nil, nil } return t.Float64, nil + } else if fieldType.ConvertibleTo(bigFloatType) { + t := fieldValue.Convert(bigFloatType).Interface().(big.Float) + return t.String(), nil } if !col.IsJSON { diff --git a/scan.go b/scan.go index c5cb77ff..6396b097 100644 --- a/scan.go +++ b/scan.go @@ -7,6 +7,7 @@ package xorm import ( "database/sql" "fmt" + "math/big" "reflect" "time" @@ -182,13 +183,21 @@ func (engine *Engine) scan(rows *core.Rows, fields []string, types []*sql.Column for _, v := range vv { var replaced bool var scanResult interface{} - if _, ok := v.(sql.Scanner); !ok { + switch t := v.(type) { + case sql.Scanner: + scanResult = t + case convert.Conversion: + scanResult = &sql.RawBytes{} + replaced = true + case *big.Float: + scanResult = &sql.NullString{} + replaced = true + default: var useNullable = true if engine.driver.Features().SupportNullable { nullable, ok := types[0].Nullable() useNullable = ok && nullable } - if useNullable { scanResult, replaced, err = genScanResultsByBeanNullable(v) } else { @@ -197,25 +206,22 @@ func (engine *Engine) scan(rows *core.Rows, fields []string, types []*sql.Column if err != nil { return err } - } else { - scanResult = v } + scanResults = append(scanResults, scanResult) replaces = append(replaces, replaced) } - var scanCtx = dialects.ScanContext{ + if err = engine.driver.Scan(&dialects.ScanContext{ DBLocation: engine.DatabaseTZ, UserLocation: engine.TZLocation, - } - - if err = engine.driver.Scan(&scanCtx, rows, types, scanResults...); err != nil { + }, rows, types, scanResults...); err != nil { return err } for i, replaced := range replaces { if replaced { - if err = convertAssign(vv[i], scanResults[i], scanCtx.DBLocation, engine.TZLocation); err != nil { + if err = convertAssign(vv[i], scanResults[i], engine.DatabaseTZ, engine.TZLocation); err != nil { return err } } diff --git a/session_get.go b/session_get.go index cb2bda75..58255033 100644 --- a/session_get.go +++ b/session_get.go @@ -9,6 +9,7 @@ import ( "database/sql/driver" "errors" "fmt" + "math/big" "reflect" "strconv" "time" @@ -123,6 +124,20 @@ var ( conversionType = reflect.TypeOf(&conversionTypePlaceHolder).Elem() ) +func isScannableStruct(bean interface{}, typeLen int) bool { + switch bean.(type) { + case *time.Time: + return false + case sql.Scanner: + return false + case convert.Conversion: + return typeLen > 1 + case *big.Float: + return false + } + return true +} + func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, bean interface{}, sqlStr string, args ...interface{}) (bool, error) { rows, err := session.queryRows(sqlStr, args...) if err != nil { @@ -148,13 +163,7 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, } switch beanKind { case reflect.Struct: - if _, ok := bean.(*time.Time); ok { - break - } - if _, ok := bean.(sql.Scanner); ok { - break - } - if _, ok := bean.(convert.Conversion); len(types) == 1 && ok { + if !isScannableStruct(bean, len(types)) { break } return session.getStruct(rows, types, fields, table, bean) @@ -240,35 +249,9 @@ func (session *Session) getVars(rows *core.Rows, types []*sql.ColumnType, fields if len(beans) != len(types) { return false, fmt.Errorf("expected columns %d, but only %d variables", len(types), len(beans)) } - var scanResults = make([]interface{}, 0, len(types)) - var replaceds = make([]bool, 0, len(types)) - for _, bean := range beans { - switch t := bean.(type) { - case sql.Scanner: - scanResults = append(scanResults, t) - replaceds = append(replaceds, false) - case convert.Conversion: - scanResults = append(scanResults, &sql.RawBytes{}) - replaceds = append(replaceds, true) - default: - scanResults = append(scanResults, bean) - replaceds = append(replaceds, false) - } - } - err := session.engine.scan(rows, fields, types, scanResults...) - if err != nil { - return true, err - } - for i, replaced := range replaceds { - if replaced { - err = convertAssign(beans[i], scanResults[i], session.engine.DatabaseTZ, session.engine.TZLocation) - if err != nil { - return true, err - } - } - } - return true, nil + err := session.engine.scan(rows, fields, types, beans...) + return true, err } func (session *Session) getStruct(rows *core.Rows, types []*sql.ColumnType, fields []string, table *schemas.Table, bean interface{}) (bool, error) { -- 2.40.1 From 6119c63a66a5fef41d41e0d489d6eeafd2062d4e Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Wed, 7 Jul 2021 13:01:59 +0800 Subject: [PATCH 2/2] improve tests --- dialects/sqlite3.go | 2 +- integrations/session_get_test.go | 52 ++++++++++++++++++++++++-------- 2 files changed, 40 insertions(+), 14 deletions(-) diff --git a/dialects/sqlite3.go b/dialects/sqlite3.go index 1bc0b218..04e5b457 100644 --- a/dialects/sqlite3.go +++ b/dialects/sqlite3.go @@ -565,7 +565,7 @@ func (p *sqlite3Driver) GenScanResult(colType string) (interface{}, error) { case "REAL": var s sql.NullFloat64 return &s, nil - case "NUMERIC": + case "NUMERIC", "DECIMAL": var s sql.NullString return &s, nil case "BLOB": diff --git a/integrations/session_get_test.go b/integrations/session_get_test.go index 47b38dfa..6fc202bc 100644 --- a/integrations/session_get_test.go +++ b/integrations/session_get_test.go @@ -771,23 +771,49 @@ func TestGetNil(t *testing.T) { func TestGetBigFloat(t *testing.T) { type GetBigFloat struct { Id int64 - Money *big.Float `xorm:"numeric"` + Money *big.Float `xorm:"numeric(22,2)"` } assert.NoError(t, PrepareEngine()) assertSync(t, new(GetBigFloat)) - var gf = GetBigFloat{ - Money: big.NewFloat(999999.99), - } - _, err := testEngine.Insert(&gf) - assert.NoError(t, err) + { + var gf = GetBigFloat{ + Money: big.NewFloat(999999.99), + } + _, err := testEngine.Insert(&gf) + assert.NoError(t, err) - var m big.Float - has, err := testEngine.Table("get_big_float").Cols("money").Where("id=?", gf.Id).Get(&m) - assert.NoError(t, err) - assert.True(t, has) - assert.True(t, m.String() == gf.Money.String()) - //fmt.Println(m.Cmp(gf.Money)) - //assert.True(t, m.Cmp(gf.Money) == 0, "%v != %v", m.String(), gf.Money.String()) + var m big.Float + has, err := testEngine.Table("get_big_float").Cols("money").Where("id=?", gf.Id).Get(&m) + assert.NoError(t, err) + assert.True(t, has) + assert.True(t, m.String() == gf.Money.String(), "%v != %v", m.String(), gf.Money.String()) + //fmt.Println(m.Cmp(gf.Money)) + //assert.True(t, m.Cmp(gf.Money) == 0, "%v != %v", m.String(), gf.Money.String()) + } + + type GetBigFloat2 struct { + Id int64 + Money *big.Float `xorm:"decimal(22,2)"` + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(GetBigFloat2)) + + { + var gf2 = GetBigFloat2{ + Money: big.NewFloat(9999999.99), + } + _, err := testEngine.Insert(&gf2) + assert.NoError(t, err) + + var m2 big.Float + has, err := testEngine.Table("get_big_float2").Cols("money").Where("id=?", gf2.Id).Get(&m2) + assert.NoError(t, err) + assert.True(t, has) + assert.True(t, m2.String() == gf2.Money.String(), "%v != %v", m2.String(), gf2.Money.String()) + //fmt.Println(m.Cmp(gf.Money)) + //assert.True(t, m.Cmp(gf.Money) == 0, "%v != %v", m.String(), gf.Money.String()) + } } -- 2.40.1