Improve get field value of bean #1961

Merged
lunny merged 4 commits from lunny/refactor_get_fields into master 2021-06-29 06:32:29 +00:00
13 changed files with 49 additions and 141 deletions

View File

@@ -249,11 +249,11 @@ volumes:
services:
- name: mssql
pull: always
image: microsoft/mssql-server-linux:latest
image: mcr.microsoft.com/mssql/server:latest
environment:
ACCEPT_EULA: Y
SA_PASSWORD: yourStrong(!)Password
MSSQL_PID: Developer
MSSQL_PID: Standard
---
kind: pipeline

View File

@@ -175,7 +175,10 @@ func convertAssign(dest, src interface{}) error {
return nil
}
dpv := reflect.ValueOf(dest)
return convertAssignV(reflect.ValueOf(dest), src)
}
func convertAssignV(dpv reflect.Value, src interface{}) error {
if dpv.Kind() != reflect.Ptr {
return errors.New("destination not a pointer")
}
@@ -183,9 +186,7 @@ func convertAssign(dest, src interface{}) error {
return errNilPtr
}
if !sv.IsValid() {
sv = reflect.ValueOf(src)
}
var sv = reflect.ValueOf(src)
dv := reflect.Indirect(dpv)
if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) {
@@ -244,7 +245,7 @@ func convertAssign(dest, src interface{}) error {
return nil
}
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest)
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dpv.Interface())
}
func asKind(vv reflect.Value, tp reflect.Type) (interface{}, error) {
@@ -375,44 +376,3 @@ func str2PK(s string, tp reflect.Type) (interface{}, error) {
}
return v.Interface(), nil
}
func int64ToIntValue(id int64, tp reflect.Type) reflect.Value {
var v interface{}
kind := tp.Kind()
if kind == reflect.Ptr {
kind = tp.Elem().Kind()
}
switch kind {
case reflect.Int16:
temp := int16(id)
v = &temp
case reflect.Int32:
temp := int32(id)
v = &temp
case reflect.Int:
temp := int(id)
v = &temp
case reflect.Int64:
temp := id
v = &temp
case reflect.Uint16:
temp := uint16(id)
v = &temp
case reflect.Uint32:
temp := uint32(id)
v = &temp
case reflect.Uint64:
temp := uint64(id)
v = &temp
case reflect.Uint:
temp := uint(id)
v = &temp
}
if tp.Kind() == reflect.Ptr {
return reflect.ValueOf(v).Convert(tp)
}
return reflect.ValueOf(v).Elem().Convert(tp)
}
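
The two hunks above drop the hand-rolled int64ToIntValue kind switch in favor of the generic convertAssignV helper. A minimal, self-contained sketch of the idea, with assignInt64 standing in for the integer branch of such a helper (the name and the User type are illustrative, not from the PR):

// Illustrative sketch: a generic pointer-based assignment makes the
// per-kind switch in int64ToIntValue unnecessary.
package main

import (
	"errors"
	"fmt"
	"reflect"
)

// assignInt64 mimics the integer branch of a convertAssign-style helper:
// dpv must be a pointer value; the pointed-to integer receives id.
func assignInt64(dpv reflect.Value, id int64) error {
	if dpv.Kind() != reflect.Ptr {
		return errors.New("destination not a pointer")
	}
	dv := reflect.Indirect(dpv)
	switch dv.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		dv.SetInt(id)
		return nil
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		dv.SetUint(uint64(id))
		return nil
	}
	return fmt.Errorf("unsupported destination kind %v", dv.Kind())
}

func main() {
	type User struct{ ID int32 }
	var u User
	aiValue := reflect.ValueOf(&u).Elem().FieldByName("ID")
	if err := assignInt64(aiValue.Addr(), 42); err != nil {
		panic(err)
	}
	fmt.Println(u.ID) // 42
}

This is the same shape of call the insert path makes further down with convertAssignV(aiValue.Addr(), id).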

View File

@@ -652,11 +652,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
return errors.New("unknown column error")
}
fields := strings.Split(col.FieldName, ".")
field := dataStruct
for _, fieldName := range fields {
field = field.FieldByName(fieldName)
}
field := dataStruct.FieldByIndex(col.FieldIndex)
temp += "," + formatColumnValue(dstDialect, field.Interface(), col)
}
_, err = io.WriteString(w, temp[1:]+");\n")

View File

@@ -734,6 +734,8 @@ func (statement *Statement) buildConds2(table *schemas.Table, bean interface{},
//engine.logger.Warn(err)
}
continue
} else if fieldValuePtr == nil {
continue
}
if col.IsDeleted && !unscoped { // tag "deleted" is enabled

View File

@@ -88,6 +88,9 @@ func (statement *Statement) BuildUpdates(tableValue reflect.Value,
if err != nil {
return nil, nil, err
}
if fieldValuePtr == nil {
continue
}
fieldValue := *fieldValuePtr
fieldType := reflect.TypeOf(fieldValue.Interface())

View File

@@ -6,10 +6,8 @@ package schemas
import (
"errors"
"fmt"
"reflect"
"strconv"
"strings"
"time"
)
@@ -25,6 +23,7 @@ type Column struct {
Name string
TableName string
FieldName string // Available only when parsed from a struct
FieldIndex []int // Available only when parsed from a struct
SQLType SQLType
IsJSON bool
Length int
@@ -83,41 +82,17 @@ func (col *Column) ValueOf(bean interface{}) (*reflect.Value, error) {
// ValueOfV returns the column's field of the struct's value, accepting a reflect value
func (col *Column) ValueOfV(dataStruct *reflect.Value) (*reflect.Value, error) {
var fieldValue reflect.Value
fieldPath := strings.Split(col.FieldName, ".")
if dataStruct.Type().Kind() == reflect.Map {
keyValue := reflect.ValueOf(fieldPath[len(fieldPath)-1])
fieldValue = dataStruct.MapIndex(keyValue)
return &fieldValue, nil
} else if dataStruct.Type().Kind() == reflect.Interface {
structValue := reflect.ValueOf(dataStruct.Interface())
dataStruct = &structValue
}
level := len(fieldPath)
fieldValue = dataStruct.FieldByName(fieldPath[0])
for i := 0; i < level-1; i++ {
if !fieldValue.IsValid() {
break
}
if fieldValue.Kind() == reflect.Struct {
fieldValue = fieldValue.FieldByName(fieldPath[i+1])
} else if fieldValue.Kind() == reflect.Ptr {
if fieldValue.IsNil() {
fieldValue.Set(reflect.New(fieldValue.Type().Elem()))
var v = *dataStruct
for _, i := range col.FieldIndex {
if v.Kind() == reflect.Ptr {
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
fieldValue = fieldValue.Elem().FieldByName(fieldPath[i+1])
} else {
return nil, fmt.Errorf("field %v is not valid", col.FieldName)
v = v.Elem()
}
v = v.FieldByIndex([]int{i})
}
if !fieldValue.IsValid() {
return nil, fmt.Errorf("field %v is not valid", col.FieldName)
}
return &fieldValue, nil
return &v, nil
}
// ConvertID converts id content to suitable type according column type
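
The new ValueOfV walks col.FieldIndex one element at a time instead of splitting col.FieldName on dots. A self-contained sketch of why the per-step loop is kept rather than a single reflect FieldByIndex call (Inner and Outer are invented example types): a nil embedded pointer has to be allocated before descending into it, which FieldByIndex would panic on.

// Standalone sketch: for an embedded *Inner the recorded FieldIndex would be
// []int{0, 0}; walking it step by step allows allocating nil pointers on the way.
package main

import (
	"fmt"
	"reflect"
)

type Inner struct{ Name string }

type Outer struct {
	*Inner
	Age int
}

// fieldByIndexAlloc mirrors the loop in the new ValueOfV: descend through the
// index, creating nil pointers as needed.
func fieldByIndexAlloc(v reflect.Value, index []int) reflect.Value {
	for _, i := range index {
		if v.Kind() == reflect.Ptr {
			if v.IsNil() {
				v.Set(reflect.New(v.Type().Elem()))
			}
			v = v.Elem()
		}
		v = v.FieldByIndex([]int{i})
	}
	return v
}

func main() {
	var o Outer // o.Inner is nil
	v := reflect.ValueOf(&o).Elem()
	f := fieldByIndexAlloc(v, []int{0, 0})
	f.SetString("xorm")
	fmt.Println(o.Inner.Name) // xorm
}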

View File

@@ -5,7 +5,6 @@
package schemas
import (
"fmt"
"reflect"
"strconv"
"strings"
@@ -159,24 +158,8 @@ func (table *Table) IDOfV(rv reflect.Value) (PK, error) {
for i, col := range table.PKColumns() {
var err error
fieldName := col.FieldName
for {
parts := strings.SplitN(fieldName, ".", 2)
if len(parts) == 1 {
break
}
pkField := v.FieldByIndex(col.FieldIndex)
v = v.FieldByName(parts[0])
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
if v.Kind() != reflect.Struct {
return nil, fmt.Errorf("Unsupported read value of column %s from field %s", col.Name, col.FieldName)
}
fieldName = parts[1]
}
pkField := v.FieldByName(fieldName)
switch pkField.Kind() {
case reflect.String:
pk[i], err = col.ConvertID(pkField.String())

View File

@@ -375,6 +375,9 @@ func (session *Session) getField(dataStruct *reflect.Value, key string, table *s
if err != nil {
return nil, err
}
if fieldValue == nil {
return nil, ErrFieldIsNotValid{key, table.Name}
}
if !fieldValue.IsValid() || !fieldValue.CanSet() {
return nil, ErrFieldIsNotValid{key, table.Name}

View File

@@ -35,27 +35,20 @@ func (session *Session) str2Time(col *schemas.Column, data string) (outTime time
sd, err := strconv.ParseInt(sdata, 10, 64)
if err == nil {
x = time.Unix(sd, 0)
//session.engine.logger.Debugf("time(0) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
} else {
//session.engine.logger.Debugf("time(0) err key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
}
} else if len(sdata) > 19 && strings.Contains(sdata, "-") {
x, err = time.ParseInLocation(time.RFC3339Nano, sdata, parseLoc)
session.engine.logger.Debugf("time(1) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
session.engine.logger.Debugf("time(1) key[%v]: %+v | sdata: [%v]\n", col.Name, x, sdata)
if err != nil {
x, err = time.ParseInLocation("2006-01-02 15:04:05.999999999", sdata, parseLoc)
//session.engine.logger.Debugf("time(2) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
}
if err != nil {
x, err = time.ParseInLocation("2006-01-02 15:04:05.9999999 Z07:00", sdata, parseLoc)
//session.engine.logger.Debugf("time(3) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
}
} else if len(sdata) == 19 && strings.Contains(sdata, "-") {
x, err = time.ParseInLocation("2006-01-02 15:04:05", sdata, parseLoc)
//session.engine.logger.Debugf("time(4) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
} else if len(sdata) == 10 && sdata[4] == '-' && sdata[7] == '-' {
x, err = time.ParseInLocation("2006-01-02", sdata, parseLoc)
//session.engine.logger.Debugf("time(5) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
} else if col.SQLType.Name == schemas.Time {
if strings.Contains(sdata, " ") {
ssd := strings.Split(sdata, " ")
@@ -69,7 +62,6 @@ func (session *Session) str2Time(col *schemas.Column, data string) (outTime time
st := fmt.Sprintf("2006-01-02 %v", sdata)
x, err = time.ParseInLocation("2006-01-02 15:04:05", st, parseLoc)
//session.engine.logger.Debugf("time(6) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
} else {
outErr = fmt.Errorf("unsupported time format %v", sdata)
return

View File

@@ -374,9 +374,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
return 1, nil
}
aiValue.Set(int64ToIntValue(id, aiValue.Type()))
return 1, nil
return 1, convertAssignV(aiValue.Addr(), id)
} else if len(table.AutoIncrement) > 0 && (session.engine.dialect.URI().DBType == schemas.POSTGRES ||
session.engine.dialect.URI().DBType == schemas.MSSQL) {
res, err := session.queryBytes(sqlStr, args...)
@@ -416,9 +414,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
return 1, nil
}
aiValue.Set(int64ToIntValue(id, aiValue.Type()))
return 1, nil
return 1, convertAssignV(aiValue.Addr(), id)
}
res, err := session.exec(sqlStr, args...)
@@ -458,7 +454,9 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
return res.RowsAffected()
}
aiValue.Set(int64ToIntValue(id, aiValue.Type()))
if err := convertAssignV(aiValue.Addr(), id); err != nil {
return 0, err
}
return res.RowsAffected()
}

View File

@@ -280,15 +280,12 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
k = ct.Elem().Kind()
}
if k == reflect.Struct {
var refTable = session.statement.RefTable
if refTable == nil {
refTable, err = session.engine.TableInfo(condiBean[0])
if err != nil {
return 0, err
}
condTable, err := session.engine.TableInfo(condiBean[0])
if err != nil {
return 0, err
}
var err error
autoCond, err = session.statement.BuildConds(refTable, condiBean[0], true, true, false, true, false)
autoCond, err = session.statement.BuildConds(condTable, condiBean[0], true, true, false, true, false)
if err != nil {
return 0, err
}

View File

@@ -126,7 +126,7 @@ func addIndex(indexName string, table *schemas.Table, col *schemas.Column, index
var ErrIgnoreField = errors.New("field will be ignored")
func (parser *Parser) parseFieldWithNoTag(field reflect.StructField, fieldValue reflect.Value) (*schemas.Column, error) {
func (parser *Parser) parseFieldWithNoTag(fieldIndex int, field reflect.StructField, fieldValue reflect.Value) (*schemas.Column, error) {
var sqlType schemas.SQLType
if fieldValue.CanAddr() {
if _, ok := fieldValue.Addr().Interface().(convert.Conversion); ok {
@@ -141,6 +141,7 @@ func (parser *Parser) parseFieldWithNoTag(field reflect.StructField, fieldValue
col := schemas.NewColumn(parser.columnMapper.Obj2Table(field.Name),
field.Name, sqlType, sqlType.DefaultLength,
sqlType.DefaultLength2, true)
col.FieldIndex = []int{fieldIndex}
if field.Type.Kind() == reflect.Int64 && (strings.ToUpper(col.FieldName) == "ID" || strings.HasSuffix(strings.ToUpper(col.FieldName), ".ID")) {
col.IsAutoIncrement = true
@@ -150,9 +151,10 @@ func (parser *Parser) parseFieldWithNoTag(field reflect.StructField, fieldValue
return col, nil
}
func (parser *Parser) parseFieldWithTags(table *schemas.Table, field reflect.StructField, fieldValue reflect.Value, tags []tag) (*schemas.Column, error) {
func (parser *Parser) parseFieldWithTags(table *schemas.Table, fieldIndex int, field reflect.StructField, fieldValue reflect.Value, tags []tag) (*schemas.Column, error) {
var col = &schemas.Column{
FieldName: field.Name,
FieldIndex: []int{fieldIndex},
Nullable: true,
IsPrimaryKey: false,
IsAutoIncrement: false,
@@ -238,7 +240,7 @@ func (parser *Parser) parseFieldWithTags(table *schemas.Table, field reflect.Str
return col, nil
}
func (parser *Parser) parseField(table *schemas.Table, field reflect.StructField, fieldValue reflect.Value) (*schemas.Column, error) {
func (parser *Parser) parseField(table *schemas.Table, fieldIndex int, field reflect.StructField, fieldValue reflect.Value) (*schemas.Column, error) {
var (
tag = field.Tag
ormTagStr = strings.TrimSpace(tag.Get(parser.identifier))
@@ -247,13 +249,13 @@ func (parser *Parser) parseField(table *schemas.Table, field reflect.StructField
return nil, ErrIgnoreField
}
if ormTagStr == "" {
return parser.parseFieldWithNoTag(field, fieldValue)
return parser.parseFieldWithNoTag(fieldIndex, field, fieldValue)
}
tags, err := splitTag(ormTagStr)
if err != nil {
return nil, err
}
return parser.parseFieldWithTags(table, field, fieldValue, tags)
return parser.parseFieldWithTags(table, fieldIndex, field, fieldValue, tags)
}
func isNotTitle(n string) bool {
@@ -279,16 +281,12 @@ func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) {
table.Name = names.GetTableName(parser.tableMapper, v)
for i := 0; i < t.NumField(); i++ {
if isNotTitle(t.Field(i).Name) {
var field = t.Field(i)
if isNotTitle(field.Name) {
continue
}
var (
field = t.Field(i)
fieldValue = v.Field(i)
)
col, err := parser.parseField(table, field, fieldValue)
col, err := parser.parseField(table, i, field, v.Field(i))
if err == ErrIgnoreField {
continue
} else if err != nil {

View File

@@ -338,6 +338,7 @@ func ExtendsTagHandler(ctx *Context) error {
}
for _, col := range parentTable.Columns() {
col.FieldName = fmt.Sprintf("%v.%v", ctx.col.FieldName, col.FieldName)
col.FieldIndex = append(ctx.col.FieldIndex, col.FieldIndex...)
var tagPrefix = ctx.col.FieldName
if len(ctx.params) > 0 {
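
With the parent's FieldIndex prepended as in the hunk above, a composed index lets callers such as dumpTables and IDOfV reach an embedded column's field in a single FieldByIndex call. A short illustrative sketch (Embed and Account are invented types; the literal indexes stand in for what the tag parser records):

// Illustrative sketch of a composed FieldIndex: the embedded struct's column
// carries the parent's index followed by its own, so FieldByIndex reaches the
// nested field directly.
package main

import (
	"fmt"
	"reflect"
)

type Embed struct {
	Created string
	Updated string
}

type Account struct {
	Embed        // struct field index 0
	Name  string // struct field index 1
}

func main() {
	a := Account{Embed: Embed{Updated: "2021-06-29"}, Name: "lunny"}

	parentIndex := []int{0}                          // stands in for ctx.col.FieldIndex of the embedded field
	childIndex := []int{1}                           // stands in for col.FieldIndex of Updated inside Embed
	fieldIndex := append(parentIndex, childIndex...) // {0, 1}, mirroring the append in ExtendsTagHandler

	field := reflect.ValueOf(a).FieldByIndex(fieldIndex)
	fmt.Println(field.Interface()) // 2021-06-29
}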