Browse Source

Move value2interface from session to statement package (#1587)

Fix zero

Fix tests

Move value2interface from session to statement package

Reviewed-on: #1587
pull/1588/head
Lunny Xiao 4 months ago
parent
commit
188da20272
6 changed files with 275 additions and 185 deletions
  1. +151
    -0
      internal/statements/values.go
  2. +43
    -7
      internal/utils/zero.go
  3. +73
    -0
      internal/utils/zero_test.go
  4. +0
    -136
      session_convert.go
  5. +5
    -22
      session_insert.go
  6. +3
    -20
      session_update.go

+ 151
- 0
internal/statements/values.go View File

@ -0,0 +1,151 @@
// Copyright 2017 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package statements
import (
"database/sql"
"database/sql/driver"
"fmt"
"reflect"
"time"
"xorm.io/xorm/convert"
"xorm.io/xorm/dialects"
"xorm.io/xorm/internal/json"
"xorm.io/xorm/schemas"
)
var (
nullFloatType = reflect.TypeOf(sql.NullFloat64{})
)
// Value2Interface convert a field value of a struct to interface for puting into database
func (statement *Statement) Value2Interface(col *schemas.Column, fieldValue reflect.Value) (interface{}, error) {
if fieldValue.CanAddr() {
if fieldConvert, ok := fieldValue.Addr().Interface().(convert.Conversion); ok {
data, err := fieldConvert.ToDB()
if err != nil {
return nil, err
}
if col.SQLType.IsBlob() {
return data, nil
}
return string(data), nil
}
}
if fieldConvert, ok := fieldValue.Interface().(convert.Conversion); ok {
data, err := fieldConvert.ToDB()
if err != nil {
return nil, err
}
if col.SQLType.IsBlob() {
return data, nil
}
return string(data), nil
}
fieldType := fieldValue.Type()
k := fieldType.Kind()
if k == reflect.Ptr {
if fieldValue.IsNil() {
return nil, nil
} else if !fieldValue.IsValid() {
return nil, nil
} else {
// !nashtsai! deference pointer type to instance type
fieldValue = fieldValue.Elem()
fieldType = fieldValue.Type()
k = fieldType.Kind()
}
}
switch k {
case reflect.Bool:
return fieldValue.Bool(), nil
case reflect.String:
return fieldValue.String(), nil
case reflect.Struct:
if fieldType.ConvertibleTo(schemas.TimeType) {
t := fieldValue.Convert(schemas.TimeType).Interface().(time.Time)
tf := dialects.FormatColumnTime(statement.dialect, statement.defaultTimeZone, col, t)
return tf, nil
} else if fieldType.ConvertibleTo(nullFloatType) {
t := fieldValue.Convert(nullFloatType).Interface().(sql.NullFloat64)
if !t.Valid {
return nil, nil
}
return t.Float64, nil
}
if !col.SQLType.IsJson() {
// !<winxxp>! 增加支持driver.Valuer接口的结构如sql.NullString
if v, ok := fieldValue.Interface().(driver.Valuer); ok {
return v.Value()
}
fieldTable, err := statement.tagParser.ParseWithCache(fieldValue)
if err != nil {
return nil, err
}
if len(fieldTable.PrimaryKeys) == 1 {
pkField := reflect.Indirect(fieldValue).FieldByName(fieldTable.PKColumns()[0].FieldName)
return pkField.Interface(), nil
}
return nil, fmt.Errorf("no primary key for col %v", col.Name)
}
if col.SQLType.IsText() {
bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil {
return nil, err
}
return string(bytes), nil
} else if col.SQLType.IsBlob() {
bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil {
return nil, err
}
return bytes, nil
}
return nil, fmt.Errorf("Unsupported type %v", fieldValue.Type())
case reflect.Complex64, reflect.Complex128:
bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil {
return nil, err
}
return string(bytes), nil
case reflect.Array, reflect.Slice, reflect.Map:
if !fieldValue.IsValid() {
return fieldValue.Interface(), nil
}
if col.SQLType.IsText() {
bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil {
return nil, err
}
return string(bytes), nil
} else if col.SQLType.IsBlob() {
var bytes []byte
var err error
if (k == reflect.Slice) &&
(fieldValue.Type().Elem().Kind() == reflect.Uint8) {
bytes = fieldValue.Bytes()
} else {
bytes, err = json.DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil {
return nil, err
}
}
return bytes, nil
}
return nil, ErrUnSupportedType
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
return int64(fieldValue.Uint()), nil
default:
return fieldValue.Interface(), nil
}
}

+ 43
- 7
internal/utils/zero.go View File

@ -13,7 +13,14 @@ type Zeroable interface {
IsZero() bool
}
var nilTime *time.Time
// IsZero returns false if k is nil or has a zero value
func IsZero(k interface{}) bool {
if k == nil {
return true
}
switch k.(type) {
case int:
return k.(int) == 0
@ -43,28 +50,57 @@ func IsZero(k interface{}) bool {
return k.(bool) == false
case string:
return k.(string) == ""
case *time.Time:
return k.(*time.Time) == nilTime || IsTimeZero(*k.(*time.Time))
case time.Time:
return IsTimeZero(k.(time.Time))
case Zeroable:
return k.(Zeroable).IsZero()
return k.(Zeroable) == nil || k.(Zeroable).IsZero()
case reflect.Value: // for go version less than 1.13 because reflect.Value has no method IsZero
return IsValueZero(k.(reflect.Value))
}
return false
return IsValueZero(reflect.ValueOf(k))
}
var zeroType = reflect.TypeOf((*Zeroable)(nil)).Elem()
func IsValueZero(v reflect.Value) bool {
if IsZero(v.Interface()) {
return true
}
switch v.Kind() {
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Slice:
return v.IsNil()
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64:
return v.Int() == 0
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint, reflect.Uint64:
return v.Uint() == 0
case reflect.String:
return v.Len() == 0
case reflect.Ptr:
if v.IsNil() {
return true
}
return IsValueZero(v.Elem())
case reflect.Struct:
return IsStructZero(v)
case reflect.Array:
return IsArrayZero(v)
}
return false
}
func IsStructZero(v reflect.Value) bool {
if !v.IsValid() {
if !v.IsValid() || v.NumField() == 0 {
return true
}
if v.Type().Implements(zeroType) {
f := v.MethodByName("IsZero")
if f.IsValid() {
res := f.Call(nil)
return len(res) == 1 && res[0].Bool()
}
}
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
switch field.Kind() {

+ 73
- 0
internal/utils/zero_test.go View File

@ -0,0 +1,73 @@
// Copyright 2020 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package utils
import (
"fmt"
"reflect"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
type MyInt int
type ZeroStruct struct{}
func TestZero(t *testing.T) {
var zeroValues = []interface{}{
int8(0),
int16(0),
int(0),
int32(0),
int64(0),
uint8(0),
uint16(0),
uint(0),
uint32(0),
uint64(0),
MyInt(0),
reflect.ValueOf(0),
nil,
time.Time{},
&time.Time{},
nilTime,
ZeroStruct{},
&ZeroStruct{},
}
for _, v := range zeroValues {
t.Run(fmt.Sprintf("%#v", v), func(t *testing.T) {
assert.True(t, IsZero(v))
})
}
}
func TestIsValueZero(t *testing.T) {
var zeroReflectValues = []reflect.Value{
reflect.ValueOf(int8(0)),
reflect.ValueOf(int16(0)),
reflect.ValueOf(int(0)),
reflect.ValueOf(int32(0)),
reflect.ValueOf(int64(0)),
reflect.ValueOf(uint8(0)),
reflect.ValueOf(uint16(0)),
reflect.ValueOf(uint(0)),
reflect.ValueOf(uint32(0)),
reflect.ValueOf(uint64(0)),
reflect.ValueOf(MyInt(0)),
reflect.ValueOf(time.Time{}),
reflect.ValueOf(&time.Time{}),
reflect.ValueOf(nilTime),
reflect.ValueOf(ZeroStruct{}),
reflect.ValueOf(&ZeroStruct{}),
}
for _, v := range zeroReflectValues {
t.Run(fmt.Sprintf("%#v", v), func(t *testing.T) {
assert.True(t, IsValueZero(v))
})
}
}

+ 0
- 136
session_convert.go View File

@ -6,7 +6,6 @@ package xorm
import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"reflect"
@ -15,7 +14,6 @@ import (
"time"
"xorm.io/xorm/convert"
"xorm.io/xorm/dialects"
"xorm.io/xorm/internal/json"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas"
@ -88,10 +86,6 @@ func (session *Session) byte2Time(col *schemas.Column, data []byte) (outTime tim
return session.str2Time(col, string(data))
}
var (
nullFloatType = reflect.TypeOf(sql.NullFloat64{})
)
// convert a db data([]byte) to a field value
func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Value, data []byte) error {
if structConvert, ok := fieldValue.Addr().Interface().(convert.Conversion); ok {
@ -533,133 +527,3 @@ func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Val
return nil
}
// convert a field value of a struct to interface for put into db
func (session *Session) value2Interface(col *schemas.Column, fieldValue reflect.Value) (interface{}, error) {
if fieldValue.CanAddr() {
if fieldConvert, ok := fieldValue.Addr().Interface().(convert.Conversion); ok {
data, err := fieldConvert.ToDB()
if err != nil {
return 0, err
}
if col.SQLType.IsBlob() {
return data, nil
}
return string(data), nil
}
}
if fieldConvert, ok := fieldValue.Interface().(convert.Conversion); ok {
data, err := fieldConvert.ToDB()
if err != nil {
return 0, err
}
if col.SQLType.IsBlob() {
return data, nil
}
return string(data), nil
}
fieldType := fieldValue.Type()
k := fieldType.Kind()
if k == reflect.Ptr {
if fieldValue.IsNil() {
return nil, nil
} else if !fieldValue.IsValid() {
session.engine.logger.Warnf("the field [%s] is invalid", col.FieldName)
return nil, nil
} else {
// !nashtsai! deference pointer type to instance type
fieldValue = fieldValue.Elem()
fieldType = fieldValue.Type()
k = fieldType.Kind()
}
}
switch k {
case reflect.Bool:
return fieldValue.Bool(), nil
case reflect.String:
return fieldValue.String(), nil
case reflect.Struct:
if fieldType.ConvertibleTo(schemas.TimeType) {
t := fieldValue.Convert(schemas.TimeType).Interface().(time.Time)
tf := dialects.FormatColumnTime(session.engine.dialect, session.engine.DatabaseTZ, col, t)
return tf, nil
} else if fieldType.ConvertibleTo(nullFloatType) {
t := fieldValue.Convert(nullFloatType).Interface().(sql.NullFloat64)
if !t.Valid {
return nil, nil
}
return t.Float64, nil
}
if !col.SQLType.IsJson() {
// !<winxxp>! 增加支持driver.Valuer接口的结构如sql.NullString
if v, ok := fieldValue.Interface().(driver.Valuer); ok {
return v.Value()
}
fieldTable, err := session.engine.tagParser.ParseWithCache(fieldValue)
if err != nil {
return nil, err
}
if len(fieldTable.PrimaryKeys) == 1 {
pkField := reflect.Indirect(fieldValue).FieldByName(fieldTable.PKColumns()[0].FieldName)
return pkField.Interface(), nil
}
return 0, fmt.Errorf("no primary key for col %v", col.Name)
}
if col.SQLType.IsText() {
bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil {
return 0, err
}
return string(bytes), nil
} else if col.SQLType.IsBlob() {
bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil {
return 0, err
}
return bytes, nil
}
return nil, fmt.Errorf("Unsupported type %v", fieldValue.Type())
case reflect.Complex64, reflect.Complex128:
bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil {
return 0, err
}
return string(bytes), nil
case reflect.Array, reflect.Slice, reflect.Map:
if !fieldValue.IsValid() {
return fieldValue.Interface(), nil
}
if col.SQLType.IsText() {
bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil {
return 0, err
}
return string(bytes), nil
} else if col.SQLType.IsBlob() {
var bytes []byte
var err error
if (k == reflect.Slice) &&
(fieldValue.Type().Elem().Kind() == reflect.Uint8) {
bytes = fieldValue.Bytes()
} else {
bytes, err = json.DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil {
return 0, err
}
}
return bytes, nil
}
return nil, ErrUnSupportedType
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
return int64(fieldValue.Uint()), nil
default:
return fieldValue.Interface(), nil
}
}

+ 5
- 22
session_insert.go View File

@ -176,7 +176,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
setColumnInt(bean, col, 1)
})
} else {
arg, err := session.value2Interface(col, fieldValue)
arg, err := session.statement.Value2Interface(col, fieldValue)
if err != nil {
return 0, err
}
@ -227,7 +227,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
setColumnInt(bean, col, 1)
})
} else {
arg, err := session.value2Interface(col, fieldValue)
arg, err := session.statement.Value2Interface(col, fieldValue)
if err != nil {
return 0, err
}
@ -567,25 +567,8 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac
}
fieldValue := *fieldValuePtr
if col.IsAutoIncrement {
switch fieldValue.Type().Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64:
if fieldValue.Int() == 0 {
continue
}
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint, reflect.Uint64:
if fieldValue.Uint() == 0 {
continue
}
case reflect.String:
if len(fieldValue.String()) == 0 {
continue
}
case reflect.Ptr:
if fieldValue.Pointer() == 0 {
continue
}
}
if col.IsAutoIncrement && utils.IsValueZero(fieldValue) {
continue
}
// !evalphobia! set fieldValue as nil when column is nullable and zero-value
@ -609,7 +592,7 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac
} else if col.IsVersion && session.statement.CheckVersion {
args = append(args, 1)
} else {
arg, err := session.value2Interface(col, fieldValue)
arg, err := session.statement.Value2Interface(col, fieldValue)
if err != nil {
return colNames, args, err
}

+ 3
- 20
session_update.go View File

@ -473,25 +473,8 @@ func (session *Session) genUpdateColumns(bean interface{}) ([]string, []interfac
}
fieldValue := *fieldValuePtr
if col.IsAutoIncrement {
switch fieldValue.Type().Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64:
if fieldValue.Int() == 0 {
continue
}
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint, reflect.Uint64:
if fieldValue.Uint() == 0 {
continue
}
case reflect.String:
if len(fieldValue.String()) == 0 {
continue
}
case reflect.Ptr:
if fieldValue.Pointer() == 0 {
continue
}
}
if col.IsAutoIncrement && utils.IsValueZero(fieldValue) {
continue
}
if (col.IsDeleted && !session.statement.GetUnscoped()) || col.IsCreated {
@ -532,7 +515,7 @@ func (session *Session) genUpdateColumns(bean interface{}) ([]string, []interfac
} else if col.IsVersion && session.statement.CheckVersion {
args = append(args, 1)
} else {
arg, err := session.value2Interface(col, fieldValue)
arg, err := session.statement.Value2Interface(col, fieldValue)
if err != nil {
return colNames, args, err
}

Loading…
Cancel
Save