From 64af3fef001b9f9c0d88b8097533440d8e66dc5a Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Tue, 29 Jun 2021 11:04:46 +0800 Subject: [PATCH 1/4] improve get field value of bean --- convert.go | 52 ++++---------------------------- engine.go | 6 +--- internal/statements/statement.go | 2 ++ schemas/column.go | 43 ++++++-------------------- schemas/table.go | 19 +----------- session.go | 3 ++ session_convert.go | 8 ----- session_insert.go | 12 +++----- session_update.go | 13 +++----- tags/parser.go | 22 ++++++-------- tags/tag.go | 1 + 11 files changed, 43 insertions(+), 138 deletions(-) diff --git a/convert.go b/convert.go index ee5b6029..b7f30cad 100644 --- a/convert.go +++ b/convert.go @@ -175,7 +175,10 @@ func convertAssign(dest, src interface{}) error { return nil } - dpv := reflect.ValueOf(dest) + return convertAssignV(reflect.ValueOf(dest), src) +} + +func convertAssignV(dpv reflect.Value, src interface{}) error { if dpv.Kind() != reflect.Ptr { return errors.New("destination not a pointer") } @@ -183,9 +186,7 @@ func convertAssign(dest, src interface{}) error { return errNilPtr } - if !sv.IsValid() { - sv = reflect.ValueOf(src) - } + var sv = reflect.ValueOf(src) dv := reflect.Indirect(dpv) if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) { @@ -244,7 +245,7 @@ func convertAssign(dest, src interface{}) error { return nil } - return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest) + return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dpv.Interface()) } func asKind(vv reflect.Value, tp reflect.Type) (interface{}, error) { @@ -375,44 +376,3 @@ func str2PK(s string, tp reflect.Type) (interface{}, error) { } return v.Interface(), nil } - -func int64ToIntValue(id int64, tp reflect.Type) reflect.Value { - var v interface{} - kind := tp.Kind() - - if kind == reflect.Ptr { - kind = tp.Elem().Kind() - } - - switch kind { - case reflect.Int16: - temp := int16(id) - v = &temp - case reflect.Int32: - temp := int32(id) - v = &temp - case reflect.Int: - temp := int(id) - v = &temp - case reflect.Int64: - temp := id - v = &temp - case reflect.Uint16: - temp := uint16(id) - v = &temp - case reflect.Uint32: - temp := uint32(id) - v = &temp - case reflect.Uint64: - temp := uint64(id) - v = &temp - case reflect.Uint: - temp := uint(id) - v = &temp - } - - if tp.Kind() == reflect.Ptr { - return reflect.ValueOf(v).Convert(tp) - } - return reflect.ValueOf(v).Elem().Convert(tp) -} diff --git a/engine.go b/engine.go index 649ec1a2..76ce8f1a 100644 --- a/engine.go +++ b/engine.go @@ -652,11 +652,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch return errors.New("unknown column error") } - fields := strings.Split(col.FieldName, ".") - field := dataStruct - for _, fieldName := range fields { - field = field.FieldByName(fieldName) - } + field := dataStruct.FieldByIndex(col.FieldIndex) temp += "," + formatColumnValue(dstDialect, field.Interface(), col) } _, err = io.WriteString(w, temp[1:]+");\n") diff --git a/internal/statements/statement.go b/internal/statements/statement.go index ca59817b..b1a5ed3c 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -734,6 +734,8 @@ func (statement *Statement) buildConds2(table *schemas.Table, bean interface{}, //engine.logger.Warn(err) } continue + } else if fieldValuePtr == nil { + continue } if col.IsDeleted && !unscoped { // tag "deleted" is enabled diff --git a/schemas/column.go b/schemas/column.go index 24b53802..6dd3c1d7 100644 --- a/schemas/column.go +++ b/schemas/column.go @@ -6,10 +6,8 @@ package schemas import ( "errors" - "fmt" "reflect" "strconv" - "strings" "time" ) @@ -25,6 +23,7 @@ type Column struct { Name string TableName string FieldName string // Available only when parsed from a struct + FieldIndex []int // Available only when parsed from a struct SQLType SQLType IsJSON bool Length int @@ -83,41 +82,17 @@ func (col *Column) ValueOf(bean interface{}) (*reflect.Value, error) { // ValueOfV returns column's filed of struct's value accept reflevt value func (col *Column) ValueOfV(dataStruct *reflect.Value) (*reflect.Value, error) { - var fieldValue reflect.Value - fieldPath := strings.Split(col.FieldName, ".") - - if dataStruct.Type().Kind() == reflect.Map { - keyValue := reflect.ValueOf(fieldPath[len(fieldPath)-1]) - fieldValue = dataStruct.MapIndex(keyValue) - return &fieldValue, nil - } else if dataStruct.Type().Kind() == reflect.Interface { - structValue := reflect.ValueOf(dataStruct.Interface()) - dataStruct = &structValue - } - - level := len(fieldPath) - fieldValue = dataStruct.FieldByName(fieldPath[0]) - for i := 0; i < level-1; i++ { - if !fieldValue.IsValid() { - break - } - if fieldValue.Kind() == reflect.Struct { - fieldValue = fieldValue.FieldByName(fieldPath[i+1]) - } else if fieldValue.Kind() == reflect.Ptr { - if fieldValue.IsNil() { - fieldValue.Set(reflect.New(fieldValue.Type().Elem())) + var v = *dataStruct + for _, i := range col.FieldIndex { + if v.Kind() == reflect.Ptr { + if v.IsNil() { + return nil, nil } - fieldValue = fieldValue.Elem().FieldByName(fieldPath[i+1]) - } else { - return nil, fmt.Errorf("field %v is not valid", col.FieldName) + v = v.Elem() } + v = v.FieldByIndex([]int{i}) } - - if !fieldValue.IsValid() { - return nil, fmt.Errorf("field %v is not valid", col.FieldName) - } - - return &fieldValue, nil + return &v, nil } // ConvertID converts id content to suitable type according column type diff --git a/schemas/table.go b/schemas/table.go index bfa517aa..91b33e06 100644 --- a/schemas/table.go +++ b/schemas/table.go @@ -5,7 +5,6 @@ package schemas import ( - "fmt" "reflect" "strconv" "strings" @@ -159,24 +158,8 @@ func (table *Table) IDOfV(rv reflect.Value) (PK, error) { for i, col := range table.PKColumns() { var err error - fieldName := col.FieldName - for { - parts := strings.SplitN(fieldName, ".", 2) - if len(parts) == 1 { - break - } + pkField := v.FieldByIndex(col.FieldIndex) - v = v.FieldByName(parts[0]) - if v.Kind() == reflect.Ptr { - v = v.Elem() - } - if v.Kind() != reflect.Struct { - return nil, fmt.Errorf("Unsupported read value of column %s from field %s", col.Name, col.FieldName) - } - fieldName = parts[1] - } - - pkField := v.FieldByName(fieldName) switch pkField.Kind() { case reflect.String: pk[i], err = col.ConvertID(pkField.String()) diff --git a/session.go b/session.go index d5ccb6dc..6df9e20d 100644 --- a/session.go +++ b/session.go @@ -375,6 +375,9 @@ func (session *Session) getField(dataStruct *reflect.Value, key string, table *s if err != nil { return nil, err } + if fieldValue == nil { + return nil, ErrFieldIsNotValid{key, table.Name} + } if !fieldValue.IsValid() || !fieldValue.CanSet() { return nil, ErrFieldIsNotValid{key, table.Name} diff --git a/session_convert.go b/session_convert.go index a6839947..9951b300 100644 --- a/session_convert.go +++ b/session_convert.go @@ -35,27 +35,20 @@ func (session *Session) str2Time(col *schemas.Column, data string) (outTime time sd, err := strconv.ParseInt(sdata, 10, 64) if err == nil { x = time.Unix(sd, 0) - //session.engine.logger.Debugf("time(0) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) - } else { - //session.engine.logger.Debugf("time(0) err key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) } } else if len(sdata) > 19 && strings.Contains(sdata, "-") { x, err = time.ParseInLocation(time.RFC3339Nano, sdata, parseLoc) session.engine.logger.Debugf("time(1) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) if err != nil { x, err = time.ParseInLocation("2006-01-02 15:04:05.999999999", sdata, parseLoc) - //session.engine.logger.Debugf("time(2) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) } if err != nil { x, err = time.ParseInLocation("2006-01-02 15:04:05.9999999 Z07:00", sdata, parseLoc) - //session.engine.logger.Debugf("time(3) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) } } else if len(sdata) == 19 && strings.Contains(sdata, "-") { x, err = time.ParseInLocation("2006-01-02 15:04:05", sdata, parseLoc) - //session.engine.logger.Debugf("time(4) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) } else if len(sdata) == 10 && sdata[4] == '-' && sdata[7] == '-' { x, err = time.ParseInLocation("2006-01-02", sdata, parseLoc) - //session.engine.logger.Debugf("time(5) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) } else if col.SQLType.Name == schemas.Time { if strings.Contains(sdata, " ") { ssd := strings.Split(sdata, " ") @@ -69,7 +62,6 @@ func (session *Session) str2Time(col *schemas.Column, data string) (outTime time st := fmt.Sprintf("2006-01-02 %v", sdata) x, err = time.ParseInLocation("2006-01-02 15:04:05", st, parseLoc) - //session.engine.logger.Debugf("time(6) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) } else { outErr = fmt.Errorf("unsupported time format %v", sdata) return diff --git a/session_insert.go b/session_insert.go index 5f968151..82d91969 100644 --- a/session_insert.go +++ b/session_insert.go @@ -374,9 +374,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return 1, nil } - aiValue.Set(int64ToIntValue(id, aiValue.Type())) - - return 1, nil + return 1, convertAssignV(aiValue.Addr(), id) } else if len(table.AutoIncrement) > 0 && (session.engine.dialect.URI().DBType == schemas.POSTGRES || session.engine.dialect.URI().DBType == schemas.MSSQL) { res, err := session.queryBytes(sqlStr, args...) @@ -416,9 +414,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return 1, nil } - aiValue.Set(int64ToIntValue(id, aiValue.Type())) - - return 1, nil + return 1, convertAssignV(aiValue.Addr(), id) } res, err := session.exec(sqlStr, args...) @@ -458,7 +454,9 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return res.RowsAffected() } - aiValue.Set(int64ToIntValue(id, aiValue.Type())) + if err := convertAssignV(aiValue.Addr(), id); err != nil { + return 0, err + } return res.RowsAffected() } diff --git a/session_update.go b/session_update.go index d96226da..78907e43 100644 --- a/session_update.go +++ b/session_update.go @@ -280,15 +280,12 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 k = ct.Elem().Kind() } if k == reflect.Struct { - var refTable = session.statement.RefTable - if refTable == nil { - refTable, err = session.engine.TableInfo(condiBean[0]) - if err != nil { - return 0, err - } + condTable, err := session.engine.TableInfo(condiBean[0]) + if err != nil { + return 0, err } - var err error - autoCond, err = session.statement.BuildConds(refTable, condiBean[0], true, true, false, true, false) + + autoCond, err = session.statement.BuildConds(condTable, condiBean[0], true, true, false, true, false) if err != nil { return 0, err } diff --git a/tags/parser.go b/tags/parser.go index 599e9e0e..d701e316 100644 --- a/tags/parser.go +++ b/tags/parser.go @@ -126,7 +126,7 @@ func addIndex(indexName string, table *schemas.Table, col *schemas.Column, index var ErrIgnoreField = errors.New("field will be ignored") -func (parser *Parser) parseFieldWithNoTag(field reflect.StructField, fieldValue reflect.Value) (*schemas.Column, error) { +func (parser *Parser) parseFieldWithNoTag(fieldIndex int, field reflect.StructField, fieldValue reflect.Value) (*schemas.Column, error) { var sqlType schemas.SQLType if fieldValue.CanAddr() { if _, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { @@ -141,6 +141,7 @@ func (parser *Parser) parseFieldWithNoTag(field reflect.StructField, fieldValue col := schemas.NewColumn(parser.columnMapper.Obj2Table(field.Name), field.Name, sqlType, sqlType.DefaultLength, sqlType.DefaultLength2, true) + col.FieldIndex = []int{fieldIndex} if field.Type.Kind() == reflect.Int64 && (strings.ToUpper(col.FieldName) == "ID" || strings.HasSuffix(strings.ToUpper(col.FieldName), ".ID")) { col.IsAutoIncrement = true @@ -150,9 +151,10 @@ func (parser *Parser) parseFieldWithNoTag(field reflect.StructField, fieldValue return col, nil } -func (parser *Parser) parseFieldWithTags(table *schemas.Table, field reflect.StructField, fieldValue reflect.Value, tags []tag) (*schemas.Column, error) { +func (parser *Parser) parseFieldWithTags(table *schemas.Table, fieldIndex int, field reflect.StructField, fieldValue reflect.Value, tags []tag) (*schemas.Column, error) { var col = &schemas.Column{ FieldName: field.Name, + FieldIndex: []int{fieldIndex}, Nullable: true, IsPrimaryKey: false, IsAutoIncrement: false, @@ -238,7 +240,7 @@ func (parser *Parser) parseFieldWithTags(table *schemas.Table, field reflect.Str return col, nil } -func (parser *Parser) parseField(table *schemas.Table, field reflect.StructField, fieldValue reflect.Value) (*schemas.Column, error) { +func (parser *Parser) parseField(table *schemas.Table, fieldIndex int, field reflect.StructField, fieldValue reflect.Value) (*schemas.Column, error) { var ( tag = field.Tag ormTagStr = strings.TrimSpace(tag.Get(parser.identifier)) @@ -247,13 +249,13 @@ func (parser *Parser) parseField(table *schemas.Table, field reflect.StructField return nil, ErrIgnoreField } if ormTagStr == "" { - return parser.parseFieldWithNoTag(field, fieldValue) + return parser.parseFieldWithNoTag(fieldIndex, field, fieldValue) } tags, err := splitTag(ormTagStr) if err != nil { return nil, err } - return parser.parseFieldWithTags(table, field, fieldValue, tags) + return parser.parseFieldWithTags(table, fieldIndex, field, fieldValue, tags) } func isNotTitle(n string) bool { @@ -279,16 +281,12 @@ func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) { table.Name = names.GetTableName(parser.tableMapper, v) for i := 0; i < t.NumField(); i++ { - if isNotTitle(t.Field(i).Name) { + var field = t.Field(i) + if isNotTitle(field.Name) { continue } - var ( - field = t.Field(i) - fieldValue = v.Field(i) - ) - - col, err := parser.parseField(table, field, fieldValue) + col, err := parser.parseField(table, i, field, v.Field(i)) if err == ErrIgnoreField { continue } else if err != nil { diff --git a/tags/tag.go b/tags/tag.go index d8d9bb46..4a39ba54 100644 --- a/tags/tag.go +++ b/tags/tag.go @@ -338,6 +338,7 @@ func ExtendsTagHandler(ctx *Context) error { } for _, col := range parentTable.Columns() { col.FieldName = fmt.Sprintf("%v.%v", ctx.col.FieldName, col.FieldName) + col.FieldIndex = append(ctx.col.FieldIndex, col.FieldIndex...) var tagPrefix = ctx.col.FieldName if len(ctx.params) > 0 { -- 2.40.1 From 1c5b27a85b4d54fe658daa34e365c43118066619 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Tue, 29 Jun 2021 11:17:03 +0800 Subject: [PATCH 2/4] fix nil struct --- schemas/column.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/schemas/column.go b/schemas/column.go index 6dd3c1d7..4bbb6c2d 100644 --- a/schemas/column.go +++ b/schemas/column.go @@ -86,7 +86,7 @@ func (col *Column) ValueOfV(dataStruct *reflect.Value) (*reflect.Value, error) { for _, i := range col.FieldIndex { if v.Kind() == reflect.Ptr { if v.IsNil() { - return nil, nil + v.Set(reflect.New(v.Type().Elem())) } v = v.Elem() } -- 2.40.1 From c9d3db98be6dc0f6f9869b2db052a7014124c1f3 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Tue, 29 Jun 2021 11:24:50 +0800 Subject: [PATCH 3/4] Fix nil struct --- internal/statements/update.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/internal/statements/update.go b/internal/statements/update.go index 251880b2..06cf0689 100644 --- a/internal/statements/update.go +++ b/internal/statements/update.go @@ -88,6 +88,9 @@ func (statement *Statement) BuildUpdates(tableValue reflect.Value, if err != nil { return nil, nil, err } + if fieldValuePtr == nil { + continue + } fieldValue := *fieldValuePtr fieldType := reflect.TypeOf(fieldValue.Interface()) -- 2.40.1 From d8be6d76cac8efa86ccf4d908f4824665061de87 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Tue, 29 Jun 2021 13:18:44 +0800 Subject: [PATCH 4/4] upgrade mssql image --- .drone.yml | 4 ++-- session_convert.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.drone.yml b/.drone.yml index 9b4ffe9a..4f84d7fa 100644 --- a/.drone.yml +++ b/.drone.yml @@ -249,11 +249,11 @@ volumes: services: - name: mssql pull: always - image: microsoft/mssql-server-linux:latest + image: mcr.microsoft.com/mssql/server:latest environment: ACCEPT_EULA: Y SA_PASSWORD: yourStrong(!)Password - MSSQL_PID: Developer + MSSQL_PID: Standard --- kind: pipeline diff --git a/session_convert.go b/session_convert.go index 9951b300..b8218a77 100644 --- a/session_convert.go +++ b/session_convert.go @@ -38,7 +38,7 @@ func (session *Session) str2Time(col *schemas.Column, data string) (outTime time } } else if len(sdata) > 19 && strings.Contains(sdata, "-") { x, err = time.ParseInLocation(time.RFC3339Nano, sdata, parseLoc) - session.engine.logger.Debugf("time(1) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) + session.engine.logger.Debugf("time(1) key[%v]: %+v | sdata: [%v]\n", col.Name, x, sdata) if err != nil { x, err = time.ParseInLocation("2006-01-02 15:04:05.999999999", sdata, parseLoc) } -- 2.40.1