From 940afb2af3cd9c37e0897c0b4b23a1c14709f00d Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Mon, 28 Jun 2021 00:22:03 +0800 Subject: [PATCH 1/4] refactor splitTag function --- tags/parser.go | 317 ++++++++++++++++++++------------------------ tags/parser_test.go | 131 +++++++++++++++++- tags/tag.go | 144 +++++++++++++------- tags/tag_test.go | 79 +++++++++-- 4 files changed, 438 insertions(+), 233 deletions(-) diff --git a/tags/parser.go b/tags/parser.go index ff329daa..23d74aa9 100644 --- a/tags/parser.go +++ b/tags/parser.go @@ -7,7 +7,6 @@ package tags import ( "encoding/gob" "errors" - "fmt" "reflect" "strings" "sync" @@ -23,7 +22,7 @@ import ( var ( // ErrUnsupportedType represents an unsupported type error - ErrUnsupportedType = errors.New("Unsupported type") + ErrUnsupportedType = errors.New("unsupported type") ) // Parser represents a parser for xorm tag @@ -125,6 +124,141 @@ func addIndex(indexName string, table *schemas.Table, col *schemas.Column, index } } +var ErrIgnoreField = errors.New("field will be ignored") + +func (parser *Parser) parseFieldWithNoTag(field reflect.StructField, fieldValue reflect.Value) (*schemas.Column, error) { + var sqlType schemas.SQLType + if fieldValue.CanAddr() { + if _, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { + sqlType = schemas.SQLType{Name: schemas.Text} + } + } + if _, ok := fieldValue.Interface().(convert.Conversion); ok { + sqlType = schemas.SQLType{Name: schemas.Text} + } else { + sqlType = schemas.Type2SQLType(field.Type) + } + col := schemas.NewColumn(parser.columnMapper.Obj2Table(field.Name), + field.Name, sqlType, sqlType.DefaultLength, + sqlType.DefaultLength2, true) + + if field.Type.Kind() == reflect.Int64 && (strings.ToUpper(col.FieldName) == "ID" || strings.HasSuffix(strings.ToUpper(col.FieldName), ".ID")) { + col.IsAutoIncrement = true + col.IsPrimaryKey = true + col.Nullable = false + } + return col, nil +} + +func (parser *Parser) parseFieldWithTags(table *schemas.Table, field reflect.StructField, fieldValue reflect.Value, tags []tag) (*schemas.Column, error) { + var col = &schemas.Column{ + FieldName: field.Name, + Nullable: true, + IsPrimaryKey: false, + IsAutoIncrement: false, + MapType: schemas.TWOSIDES, + Indexes: make(map[string]int), + DefaultIsEmpty: true, + } + + var ctx = Context{ + table: table, + col: col, + fieldValue: fieldValue, + indexNames: make(map[string]int), + parser: parser, + } + + for j, tag := range tags { + if ctx.ignoreNext { + ctx.ignoreNext = false + continue + } + + ctx.tag = tag + ctx.tagUname = strings.ToUpper(tag.name) + + if j > 0 { + ctx.preTag = strings.ToUpper(tags[j-1].name) + } + if j < len(tags)-1 { + ctx.nextTag = tags[j+1].name + } else { + ctx.nextTag = "" + } + + if h, ok := parser.handlers[ctx.tagUname]; ok { + if err := h(&ctx); err != nil { + if err == ErrIgnoreField { + continue + } + return nil, err + } + } else { + if strings.HasPrefix(ctx.tag.name, "'") && strings.HasSuffix(ctx.tag.name, "'") { + col.Name = ctx.tag.name[1 : len(ctx.tag.name)-1] + } else { + col.Name = ctx.tag.name + } + } + + if ctx.hasCacheTag { + if parser.cacherMgr.GetDefaultCacher() != nil { + parser.cacherMgr.SetCacher(table.Name, parser.cacherMgr.GetDefaultCacher()) + } else { + parser.cacherMgr.SetCacher(table.Name, caches.NewLRUCacher2(caches.NewMemoryStore(), time.Hour, 10000)) + } + } + if ctx.hasNoCacheTag { + parser.cacherMgr.SetCacher(table.Name, nil) + } + } + + if col.SQLType.Name == "" { + col.SQLType = schemas.Type2SQLType(field.Type) + } + parser.dialect.SQLType(col) + if col.Length == 0 { + col.Length = col.SQLType.DefaultLength + } + if col.Length2 == 0 { + col.Length2 = col.SQLType.DefaultLength2 + } + if col.Name == "" { + col.Name = parser.columnMapper.Obj2Table(field.Name) + } + + if ctx.isUnique { + ctx.indexNames[col.Name] = schemas.UniqueType + } else if ctx.isIndex { + ctx.indexNames[col.Name] = schemas.IndexType + } + + for indexName, indexType := range ctx.indexNames { + addIndex(indexName, table, col, indexType) + } + + return col, nil +} + +func (parser *Parser) parseField(table *schemas.Table, field reflect.StructField, fieldValue reflect.Value) (*schemas.Column, error) { + var ( + tag = field.Tag + ormTagStr = strings.TrimSpace(tag.Get(parser.identifier)) + ) + if ormTagStr == "" { + return parser.parseFieldWithNoTag(field, fieldValue) + } + if ormTagStr == "-" { + return nil, ErrIgnoreField + } + tags, err := splitTag(ormTagStr) + if err != nil { + return nil, err + } + return parser.parseFieldWithTags(table, field, fieldValue, tags) +} + // Parse parses a struct as a table information func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) { t := v.Type() @@ -140,9 +274,6 @@ func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) { table.Type = t table.Name = names.GetTableName(parser.tableMapper, v) - var idFieldColName string - var hasCacheTag, hasNoCacheTag bool - for i := 0; i < t.NumField(); i++ { var isUnexportField bool for _, c := range t.Field(i).Name { @@ -155,178 +286,20 @@ func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) { continue } - tag := t.Field(i).Tag - ormTagStr := tag.Get(parser.identifier) - var col *schemas.Column - fieldValue := v.Field(i) - fieldType := fieldValue.Type() + var ( + field = t.Field(i) + fieldValue = v.Field(i) + ) - if ormTagStr != "" { - col = &schemas.Column{ - FieldName: t.Field(i).Name, - Nullable: true, - IsPrimaryKey: false, - IsAutoIncrement: false, - MapType: schemas.TWOSIDES, - Indexes: make(map[string]int), - DefaultIsEmpty: true, - } - tags := splitTag(ormTagStr) - - if len(tags) > 0 { - if tags[0] == "-" { - continue - } - - var ctx = Context{ - table: table, - col: col, - fieldValue: fieldValue, - indexNames: make(map[string]int), - parser: parser, - } - - if strings.HasPrefix(strings.ToUpper(tags[0]), "EXTENDS") { - pStart := strings.Index(tags[0], "(") - if pStart > -1 && strings.HasSuffix(tags[0], ")") { - var tagPrefix = strings.TrimFunc(tags[0][pStart+1:len(tags[0])-1], func(r rune) bool { - return r == '\'' || r == '"' - }) - - ctx.params = []string{tagPrefix} - } - - if err := ExtendsTagHandler(&ctx); err != nil { - return nil, err - } - continue - } - - for j, key := range tags { - if ctx.ignoreNext { - ctx.ignoreNext = false - continue - } - - k := strings.ToUpper(key) - ctx.tagName = k - ctx.params = []string{} - - pStart := strings.Index(k, "(") - if pStart == 0 { - return nil, errors.New("( could not be the first character") - } - if pStart > -1 { - if !strings.HasSuffix(k, ")") { - return nil, fmt.Errorf("field %s tag %s cannot match ) character", col.FieldName, key) - } - - ctx.tagName = k[:pStart] - ctx.params = strings.Split(key[pStart+1:len(k)-1], ",") - } - - if j > 0 { - ctx.preTag = strings.ToUpper(tags[j-1]) - } - if j < len(tags)-1 { - ctx.nextTag = tags[j+1] - } else { - ctx.nextTag = "" - } - - if h, ok := parser.handlers[ctx.tagName]; ok { - if err := h(&ctx); err != nil { - return nil, err - } - } else { - if strings.HasPrefix(key, "'") && strings.HasSuffix(key, "'") { - col.Name = key[1 : len(key)-1] - } else { - col.Name = key - } - } - - if ctx.hasCacheTag { - hasCacheTag = true - } - if ctx.hasNoCacheTag { - hasNoCacheTag = true - } - } - - if col.SQLType.Name == "" { - col.SQLType = schemas.Type2SQLType(fieldType) - } - parser.dialect.SQLType(col) - if col.Length == 0 { - col.Length = col.SQLType.DefaultLength - } - if col.Length2 == 0 { - col.Length2 = col.SQLType.DefaultLength2 - } - if col.Name == "" { - col.Name = parser.columnMapper.Obj2Table(t.Field(i).Name) - } - - if ctx.isUnique { - ctx.indexNames[col.Name] = schemas.UniqueType - } else if ctx.isIndex { - ctx.indexNames[col.Name] = schemas.IndexType - } - - for indexName, indexType := range ctx.indexNames { - addIndex(indexName, table, col, indexType) - } - } - } else { - var sqlType schemas.SQLType - if fieldValue.CanAddr() { - if _, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { - sqlType = schemas.SQLType{Name: schemas.Text} - } - } - if _, ok := fieldValue.Interface().(convert.Conversion); ok { - sqlType = schemas.SQLType{Name: schemas.Text} - } else { - sqlType = schemas.Type2SQLType(fieldType) - } - col = schemas.NewColumn(parser.columnMapper.Obj2Table(t.Field(i).Name), - t.Field(i).Name, sqlType, sqlType.DefaultLength, - sqlType.DefaultLength2, true) - - if fieldType.Kind() == reflect.Int64 && (strings.ToUpper(col.FieldName) == "ID" || strings.HasSuffix(strings.ToUpper(col.FieldName), ".ID")) { - idFieldColName = col.Name - } - } - if col.IsAutoIncrement { - col.Nullable = false + col, err := parser.parseField(table, field, fieldValue) + if err == ErrIgnoreField { + continue + } else if err != nil { + return nil, err } table.AddColumn(col) } // end for - if idFieldColName != "" && len(table.PrimaryKeys) == 0 { - col := table.GetColumn(idFieldColName) - col.IsPrimaryKey = true - col.IsAutoIncrement = true - col.Nullable = false - table.PrimaryKeys = append(table.PrimaryKeys, col.Name) - table.AutoIncrement = col.Name - } - - if hasCacheTag { - if parser.cacherMgr.GetDefaultCacher() != nil { // !nash! use engine's cacher if provided - //engine.logger.Info("enable cache on table:", table.Name) - parser.cacherMgr.SetCacher(table.Name, parser.cacherMgr.GetDefaultCacher()) - } else { - //engine.logger.Info("enable LRU cache on table:", table.Name) - parser.cacherMgr.SetCacher(table.Name, caches.NewLRUCacher2(caches.NewMemoryStore(), time.Hour, 10000)) - } - } - if hasNoCacheTag { - //engine.logger.Info("disable cache on table:", table.Name) - parser.cacherMgr.SetCacher(table.Name, nil) - } - return table, nil } diff --git a/tags/parser_test.go b/tags/parser_test.go index 5add1e13..2259ca55 100644 --- a/tags/parser_test.go +++ b/tags/parser_test.go @@ -7,11 +7,13 @@ package tags import ( "reflect" "testing" + "time" - "github.com/stretchr/testify/assert" "xorm.io/xorm/caches" "xorm.io/xorm/dialects" "xorm.io/xorm/names" + + "github.com/stretchr/testify/assert" ) type ParseTableName1 struct{} @@ -80,7 +82,7 @@ func TestParseWithOtherIdentifier(t *testing.T) { parser := NewParser( "xorm", dialects.QueryDialect("mysql"), - names.GonicMapper{}, + names.SameMapper{}, names.SnakeMapper{}, caches.NewManager(), ) @@ -88,13 +90,136 @@ func TestParseWithOtherIdentifier(t *testing.T) { type StructWithDBTag struct { FieldFoo string `db:"foo"` } + parser.SetIdentifier("db") table, err := parser.Parse(reflect.ValueOf(new(StructWithDBTag))) assert.NoError(t, err) - assert.EqualValues(t, "struct_with_db_tag", table.Name) + assert.EqualValues(t, "StructWithDBTag", table.Name) assert.EqualValues(t, 1, len(table.Columns())) for _, col := range table.Columns() { assert.EqualValues(t, "foo", col.Name) } } + +func TestParseWithIgnore(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SameMapper{}, + names.SnakeMapper{}, + caches.NewManager(), + ) + + type StructWithIgnoreTag struct { + FieldFoo string `db:"-"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithIgnoreTag))) + assert.NoError(t, err) + assert.EqualValues(t, "StructWithIgnoreTag", table.Name) + assert.EqualValues(t, 0, len(table.Columns())) +} + +func TestParseWithAutoincrement(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithAutoIncrement struct { + ID int64 + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithAutoIncrement))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_auto_increment", table.Name) + assert.EqualValues(t, 1, len(table.Columns())) + assert.EqualValues(t, "id", table.Columns()[0].Name) + assert.True(t, table.Columns()[0].IsAutoIncrement) + assert.True(t, table.Columns()[0].IsPrimaryKey) +} + +func TestParseWithAutoincrement2(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithAutoIncrement2 struct { + ID int64 `db:"pk autoincr"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithAutoIncrement2))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_auto_increment2", table.Name) + assert.EqualValues(t, 1, len(table.Columns())) + assert.EqualValues(t, "id", table.Columns()[0].Name) + assert.True(t, table.Columns()[0].IsAutoIncrement) + assert.True(t, table.Columns()[0].IsPrimaryKey) + assert.False(t, table.Columns()[0].Nullable) +} + +func TestParseWithNullable(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithNullable struct { + Name string `db:"notnull"` + FullName string `db:"null comment('column comment,字段注释')"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithNullable))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_nullable", table.Name) + assert.EqualValues(t, 2, len(table.Columns())) + assert.EqualValues(t, "name", table.Columns()[0].Name) + assert.EqualValues(t, "full_name", table.Columns()[1].Name) + assert.False(t, table.Columns()[0].Nullable) + assert.True(t, table.Columns()[1].Nullable) + assert.EqualValues(t, "column comment,字段注释", table.Columns()[1].Comment) +} + +func TestParseWithTimes(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithTimes struct { + Name string `db:"notnull"` + CreatedAt time.Time `db:"created"` + UpdatedAt time.Time `db:"updated"` + DeletedAt time.Time `db:"deleted"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithTimes))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_times", table.Name) + assert.EqualValues(t, 4, len(table.Columns())) + assert.EqualValues(t, "name", table.Columns()[0].Name) + assert.EqualValues(t, "created_at", table.Columns()[1].Name) + assert.EqualValues(t, "updated_at", table.Columns()[2].Name) + assert.EqualValues(t, "deleted_at", table.Columns()[3].Name) + assert.False(t, table.Columns()[0].Nullable) + assert.True(t, table.Columns()[1].Nullable) + assert.True(t, table.Columns()[1].IsCreated) + assert.True(t, table.Columns()[2].Nullable) + assert.True(t, table.Columns()[2].IsUpdated) + assert.True(t, table.Columns()[3].Nullable) + assert.True(t, table.Columns()[3].IsDeleted) +} diff --git a/tags/tag.go b/tags/tag.go index bb5b5838..7233c5ac 100644 --- a/tags/tag.go +++ b/tags/tag.go @@ -14,30 +14,74 @@ import ( "xorm.io/xorm/schemas" ) -func splitTag(tag string) (tags []string) { - tag = strings.TrimSpace(tag) - var hasQuote = false - var lastIdx = 0 - for i, t := range tag { - if t == '\'' { - hasQuote = !hasQuote - } else if t == ' ' { - if lastIdx < i && !hasQuote { - tags = append(tags, strings.TrimSpace(tag[lastIdx:i])) - lastIdx = i + 1 +type tag struct { + name string + params []string +} + +func splitTag(tagStr string) ([]tag, error) { + tagStr = strings.TrimSpace(tagStr) + var ( + inQuote bool + inBigQuote bool + lastIdx int + curTag tag + paramStart int + tags []tag + ) + for i, t := range tagStr { + switch t { + case '\'': + inQuote = !inQuote + case ' ': + if !inQuote && !inBigQuote { + if lastIdx < i { + if curTag.name == "" { + curTag.name = tagStr[lastIdx:i] + } + tags = append(tags, curTag) + lastIdx = i + 1 + curTag = tag{} + } else if lastIdx == i { + lastIdx = i + 1 + } + } else if inBigQuote && !inQuote { + paramStart = i + 1 + } + case ',': + if !inQuote && !inBigQuote { + return nil, fmt.Errorf("comma[%d] of %s should be in quote or big quote", i, tagStr) + } + if !inQuote && inBigQuote { + curTag.params = append(curTag.params, strings.TrimSpace(tagStr[paramStart:i])) + paramStart = i + 1 + } + case '(': + inBigQuote = true + if !inQuote { + curTag.name = tagStr[lastIdx:i] + paramStart = i + 1 + } + case ')': + inBigQuote = false + if !inQuote { + curTag.params = append(curTag.params, tagStr[paramStart:i]) } } } - if lastIdx < len(tag) { - tags = append(tags, strings.TrimSpace(tag[lastIdx:])) + if lastIdx < len(tagStr) { + if curTag.name == "" { + curTag.name = tagStr[lastIdx:] + } + tags = append(tags, curTag) } - return + return tags, nil } // Context represents a context for xorm tag parse. type Context struct { - tagName string - params []string + tag + tagUname string preTag, nextTag string table *schemas.Table col *schemas.Column @@ -124,6 +168,7 @@ func NotNullTagHandler(ctx *Context) error { // AutoIncrTagHandler describes autoincr tag handler func AutoIncrTagHandler(ctx *Context) error { ctx.col.IsAutoIncrement = true + ctx.col.Nullable = false /* if len(ctx.params) > 0 { autoStartInt, err := strconv.Atoi(ctx.params[0]) @@ -225,41 +270,44 @@ func CommentTagHandler(ctx *Context) error { // SQLTypeTagHandler describes SQL Type tag handler func SQLTypeTagHandler(ctx *Context) error { - ctx.col.SQLType = schemas.SQLType{Name: ctx.tagName} - if strings.EqualFold(ctx.tagName, "JSON") { + ctx.col.SQLType = schemas.SQLType{Name: ctx.tag.name} + if ctx.tagUname == "JSON" { ctx.col.IsJSON = true } - if len(ctx.params) > 0 { - if ctx.tagName == schemas.Enum { - ctx.col.EnumOptions = make(map[string]int) - for k, v := range ctx.params { - v = strings.TrimSpace(v) - v = strings.Trim(v, "'") - ctx.col.EnumOptions[v] = k + if len(ctx.params) == 0 { + return nil + } + + switch ctx.tagUname { + case schemas.Enum: + ctx.col.EnumOptions = make(map[string]int) + for k, v := range ctx.params { + v = strings.TrimSpace(v) + v = strings.Trim(v, "'") + ctx.col.EnumOptions[v] = k + } + case schemas.Set: + ctx.col.SetOptions = make(map[string]int) + for k, v := range ctx.params { + v = strings.TrimSpace(v) + v = strings.Trim(v, "'") + ctx.col.SetOptions[v] = k + } + default: + var err error + if len(ctx.params) == 2 { + ctx.col.Length, err = strconv.Atoi(ctx.params[0]) + if err != nil { + return err } - } else if ctx.tagName == schemas.Set { - ctx.col.SetOptions = make(map[string]int) - for k, v := range ctx.params { - v = strings.TrimSpace(v) - v = strings.Trim(v, "'") - ctx.col.SetOptions[v] = k + ctx.col.Length2, err = strconv.Atoi(ctx.params[1]) + if err != nil { + return err } - } else { - var err error - if len(ctx.params) == 2 { - ctx.col.Length, err = strconv.Atoi(ctx.params[0]) - if err != nil { - return err - } - ctx.col.Length2, err = strconv.Atoi(ctx.params[1]) - if err != nil { - return err - } - } else if len(ctx.params) == 1 { - ctx.col.Length, err = strconv.Atoi(ctx.params[0]) - if err != nil { - return err - } + } else if len(ctx.params) == 1 { + ctx.col.Length, err = strconv.Atoi(ctx.params[0]) + if err != nil { + return err } } } @@ -315,7 +363,7 @@ func ExtendsTagHandler(ctx *Context) error { default: //TODO: warning } - return nil + return ErrIgnoreField } // CacheTagHandler describes cache tag handler diff --git a/tags/tag_test.go b/tags/tag_test.go index 5775b40a..3ceeefd1 100644 --- a/tags/tag_test.go +++ b/tags/tag_test.go @@ -7,24 +7,83 @@ package tags import ( "testing" - "xorm.io/xorm/internal/utils" + "github.com/stretchr/testify/assert" ) func TestSplitTag(t *testing.T) { var cases = []struct { tag string - tags []string + tags []tag }{ - {"not null default '2000-01-01 00:00:00' TIMESTAMP", []string{"not", "null", "default", "'2000-01-01 00:00:00'", "TIMESTAMP"}}, - {"TEXT", []string{"TEXT"}}, - {"default('2000-01-01 00:00:00')", []string{"default('2000-01-01 00:00:00')"}}, - {"json binary", []string{"json", "binary"}}, + {"not null default '2000-01-01 00:00:00' TIMESTAMP", []tag{ + { + name: "not", + }, + { + name: "null", + }, + { + name: "default", + }, + { + name: "'2000-01-01 00:00:00'", + }, + { + name: "TIMESTAMP", + }, + }, + }, + {"TEXT", []tag{ + { + name: "TEXT", + }, + }, + }, + {"default('2000-01-01 00:00:00')", []tag{ + { + name: "default", + params: []string{ + "'2000-01-01 00:00:00'", + }, + }, + }, + }, + {"json binary", []tag{ + { + name: "json", + }, + { + name: "binary", + }, + }, + }, + {"numeric(10, 2)", []tag{ + { + name: "numeric", + params: []string{"10", "2"}, + }, + }, + }, + {"numeric(10, 2) notnull", []tag{ + { + name: "numeric", + params: []string{"10", "2"}, + }, + { + name: "notnull", + }, + }, + }, } for _, kase := range cases { - tags := splitTag(kase.tag) - if !utils.SliceEq(tags, kase.tags) { - t.Fatalf("[%d]%v is not equal [%d]%v", len(tags), tags, len(kase.tags), kase.tags) - } + t.Run(kase.tag, func(t *testing.T) { + tags, err := splitTag(kase.tag) + assert.NoError(t, err) + assert.EqualValues(t, len(tags), len(kase.tags)) + for i := 0; i < len(tags); i++ { + assert.Equal(t, tags[i], kase.tags[i]) + } + }) } } -- 2.40.1 From 4a4f4300fed6ebfccc478f9bc6c7f259fb8c1032 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Mon, 28 Jun 2021 00:26:29 +0800 Subject: [PATCH 2/4] improve code --- tags/parser.go | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tags/parser.go b/tags/parser.go index 23d74aa9..481b7f8e 100644 --- a/tags/parser.go +++ b/tags/parser.go @@ -246,12 +246,12 @@ func (parser *Parser) parseField(table *schemas.Table, field reflect.StructField tag = field.Tag ormTagStr = strings.TrimSpace(tag.Get(parser.identifier)) ) - if ormTagStr == "" { - return parser.parseFieldWithNoTag(field, fieldValue) - } if ormTagStr == "-" { return nil, ErrIgnoreField } + if ormTagStr == "" { + return parser.parseFieldWithNoTag(field, fieldValue) + } tags, err := splitTag(ormTagStr) if err != nil { return nil, err @@ -259,6 +259,13 @@ func (parser *Parser) parseField(table *schemas.Table, field reflect.StructField return parser.parseFieldWithTags(table, field, fieldValue, tags) } +func isNotTitle(n string) bool { + for _, c := range n { + return unicode.IsLower(c) + } + return true +} + // Parse parses a struct as a table information func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) { t := v.Type() @@ -275,14 +282,7 @@ func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) { table.Name = names.GetTableName(parser.tableMapper, v) for i := 0; i < t.NumField(); i++ { - var isUnexportField bool - for _, c := range t.Field(i).Name { - if unicode.IsLower(c) { - isUnexportField = true - } - break - } - if isUnexportField { + if isNotTitle(t.Field(i).Name) { continue } -- 2.40.1 From 0fece7421c23870f6bded183478464c77ffbec39 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Mon, 28 Jun 2021 20:55:06 +0800 Subject: [PATCH 3/4] Fix extends --- .gitignore | 3 +- tags/parser.go | 3 - tags/parser_test.go | 327 ++++++++++++++++++++++++++++++++++++++++++++ tags/tag.go | 5 +- 4 files changed, 332 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index a3fbadd4..a183a295 100644 --- a/.gitignore +++ b/.gitignore @@ -36,4 +36,5 @@ test.db.sql *coverage.out test.db integrations/*.sql -integrations/test_sqlite* \ No newline at end of file +integrations/test_sqlite* +cover.out \ No newline at end of file diff --git a/tags/parser.go b/tags/parser.go index 481b7f8e..599e9e0e 100644 --- a/tags/parser.go +++ b/tags/parser.go @@ -189,9 +189,6 @@ func (parser *Parser) parseFieldWithTags(table *schemas.Table, field reflect.Str if h, ok := parser.handlers[ctx.tagUname]; ok { if err := h(&ctx); err != nil { - if err == ErrIgnoreField { - continue - } return nil, err } } else { diff --git a/tags/parser_test.go b/tags/parser_test.go index 2259ca55..3578f78b 100644 --- a/tags/parser_test.go +++ b/tags/parser_test.go @@ -6,12 +6,14 @@ package tags import ( "reflect" + "strings" "testing" "time" "xorm.io/xorm/caches" "xorm.io/xorm/dialects" "xorm.io/xorm/names" + "xorm.io/xorm/schemas" "github.com/stretchr/testify/assert" ) @@ -223,3 +225,328 @@ func TestParseWithTimes(t *testing.T) { assert.True(t, table.Columns()[3].Nullable) assert.True(t, table.Columns()[3].IsDeleted) } + +func TestParseWithExtends(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithEmbed struct { + Name string + CreatedAt time.Time `db:"created"` + UpdatedAt time.Time `db:"updated"` + DeletedAt time.Time `db:"deleted"` + } + + type StructWithExtends struct { + SW StructWithEmbed `db:"extends"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithExtends))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_extends", table.Name) + assert.EqualValues(t, 4, len(table.Columns())) + assert.EqualValues(t, "name", table.Columns()[0].Name) + assert.EqualValues(t, "created_at", table.Columns()[1].Name) + assert.EqualValues(t, "updated_at", table.Columns()[2].Name) + assert.EqualValues(t, "deleted_at", table.Columns()[3].Name) + assert.True(t, table.Columns()[0].Nullable) + assert.True(t, table.Columns()[1].Nullable) + assert.True(t, table.Columns()[1].IsCreated) + assert.True(t, table.Columns()[2].Nullable) + assert.True(t, table.Columns()[2].IsUpdated) + assert.True(t, table.Columns()[3].Nullable) + assert.True(t, table.Columns()[3].IsDeleted) +} + +func TestParseWithCache(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithCache struct { + Name string `db:"cache"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithCache))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_cache", table.Name) + assert.EqualValues(t, 1, len(table.Columns())) + assert.EqualValues(t, "name", table.Columns()[0].Name) + assert.True(t, table.Columns()[0].Nullable) + cacher := parser.cacherMgr.GetCacher(table.Name) + assert.NotNil(t, cacher) +} + +func TestParseWithNoCache(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithNoCache struct { + Name string `db:"nocache"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithNoCache))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_no_cache", table.Name) + assert.EqualValues(t, 1, len(table.Columns())) + assert.EqualValues(t, "name", table.Columns()[0].Name) + assert.True(t, table.Columns()[0].Nullable) + cacher := parser.cacherMgr.GetCacher(table.Name) + assert.Nil(t, cacher) +} + +func TestParseWithEnum(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithEnum struct { + Name string `db:"enum('alice', 'bob')"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithEnum))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_enum", table.Name) + assert.EqualValues(t, 1, len(table.Columns())) + assert.EqualValues(t, "name", table.Columns()[0].Name) + assert.True(t, table.Columns()[0].Nullable) + assert.EqualValues(t, schemas.Enum, strings.ToUpper(table.Columns()[0].SQLType.Name)) + assert.EqualValues(t, map[string]int{ + "alice": 0, + "bob": 1, + }, table.Columns()[0].EnumOptions) +} + +func TestParseWithSet(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithSet struct { + Name string `db:"set('alice', 'bob')"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithSet))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_set", table.Name) + assert.EqualValues(t, 1, len(table.Columns())) + assert.EqualValues(t, "name", table.Columns()[0].Name) + assert.True(t, table.Columns()[0].Nullable) + assert.EqualValues(t, schemas.Set, strings.ToUpper(table.Columns()[0].SQLType.Name)) + assert.EqualValues(t, map[string]int{ + "alice": 0, + "bob": 1, + }, table.Columns()[0].SetOptions) +} + +func TestParseWithIndex(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithIndex struct { + Name string `db:"index"` + Name2 string `db:"index(s)"` + Name3 string `db:"unique"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithIndex))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_index", table.Name) + assert.EqualValues(t, 3, len(table.Columns())) + assert.EqualValues(t, "name", table.Columns()[0].Name) + assert.EqualValues(t, "name2", table.Columns()[1].Name) + assert.EqualValues(t, "name3", table.Columns()[2].Name) + assert.True(t, table.Columns()[0].Nullable) + assert.True(t, table.Columns()[1].Nullable) + assert.True(t, table.Columns()[2].Nullable) + assert.EqualValues(t, 1, len(table.Columns()[0].Indexes)) + assert.EqualValues(t, 1, len(table.Columns()[1].Indexes)) + assert.EqualValues(t, 1, len(table.Columns()[2].Indexes)) +} + +func TestParseWithVersion(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithVersion struct { + Name string + Version int `db:"version"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithVersion))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_version", table.Name) + assert.EqualValues(t, 2, len(table.Columns())) + assert.EqualValues(t, "name", table.Columns()[0].Name) + assert.EqualValues(t, "version", table.Columns()[1].Name) + assert.True(t, table.Columns()[0].Nullable) + assert.True(t, table.Columns()[1].Nullable) + assert.True(t, table.Columns()[1].IsVersion) +} + +func TestParseWithLocale(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithLocale struct { + UTCLocale time.Time `db:"utc"` + LocalLocale time.Time `db:"local"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithLocale))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_locale", table.Name) + assert.EqualValues(t, 2, len(table.Columns())) + assert.EqualValues(t, "utc_locale", table.Columns()[0].Name) + assert.EqualValues(t, "local_locale", table.Columns()[1].Name) + assert.EqualValues(t, time.UTC, table.Columns()[0].TimeZone) + assert.EqualValues(t, time.Local, table.Columns()[1].TimeZone) +} + +func TestParseWithDefault(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithDefault struct { + Default1 time.Time `db:"default '1970-01-01 00:00:00'"` + Default2 time.Time `db:"default(CURRENT_TIMESTAMP)"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithDefault))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_default", table.Name) + assert.EqualValues(t, 2, len(table.Columns())) + assert.EqualValues(t, "default1", table.Columns()[0].Name) + assert.EqualValues(t, "default2", table.Columns()[1].Name) + assert.EqualValues(t, "'1970-01-01 00:00:00'", table.Columns()[0].Default) + assert.EqualValues(t, "CURRENT_TIMESTAMP", table.Columns()[1].Default) +} + +func TestParseWithOnlyToDB(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.GonicMapper{ + "DB": true, + }, + names.SnakeMapper{}, + caches.NewManager(), + ) + + type StructWithOnlyToDB struct { + Default1 time.Time `db:"->"` + Default2 time.Time `db:"<-"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithOnlyToDB))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_only_to_db", table.Name) + assert.EqualValues(t, 2, len(table.Columns())) + assert.EqualValues(t, "default1", table.Columns()[0].Name) + assert.EqualValues(t, "default2", table.Columns()[1].Name) + assert.EqualValues(t, schemas.ONLYTODB, table.Columns()[0].MapType) + assert.EqualValues(t, schemas.ONLYFROMDB, table.Columns()[1].MapType) +} + +func TestParseWithJSON(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.GonicMapper{ + "JSON": true, + }, + names.SnakeMapper{}, + caches.NewManager(), + ) + + type StructWithJSON struct { + Default1 []string `db:"json"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithJSON))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_json", table.Name) + assert.EqualValues(t, 1, len(table.Columns())) + assert.EqualValues(t, "default1", table.Columns()[0].Name) + assert.True(t, table.Columns()[0].IsJSON) +} + +func TestParseWithSQLType(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.GonicMapper{ + "SQL": true, + }, + names.GonicMapper{ + "UUID": true, + }, + caches.NewManager(), + ) + + type StructWithSQLType struct { + Col1 string `db:"varchar(32)"` + Col2 string `db:"char(32)"` + Int int64 `db:"bigint"` + DateTime time.Time `db:"datetime"` + UUID string `db:"uuid"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithSQLType))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_sql_type", table.Name) + assert.EqualValues(t, 5, len(table.Columns())) + assert.EqualValues(t, "col1", table.Columns()[0].Name) + assert.EqualValues(t, "col2", table.Columns()[1].Name) + assert.EqualValues(t, "int", table.Columns()[2].Name) + assert.EqualValues(t, "date_time", table.Columns()[3].Name) + assert.EqualValues(t, "uuid", table.Columns()[4].Name) + + assert.EqualValues(t, "varchar", table.Columns()[0].SQLType.Name) + assert.EqualValues(t, "char", table.Columns()[1].SQLType.Name) + assert.EqualValues(t, "bigint", table.Columns()[2].SQLType.Name) + assert.EqualValues(t, "datetime", table.Columns()[3].SQLType.Name) + assert.EqualValues(t, "uuid", table.Columns()[4].SQLType.Name) +} diff --git a/tags/tag.go b/tags/tag.go index 7233c5ac..d8d9bb46 100644 --- a/tags/tag.go +++ b/tags/tag.go @@ -120,6 +120,7 @@ var ( "CACHE": CacheTagHandler, "NOCACHE": NoCacheTagHandler, "COMMENT": CommentTagHandler, + "EXTENDS": ExtendsTagHandler, } ) @@ -270,7 +271,7 @@ func CommentTagHandler(ctx *Context) error { // SQLTypeTagHandler describes SQL Type tag handler func SQLTypeTagHandler(ctx *Context) error { - ctx.col.SQLType = schemas.SQLType{Name: ctx.tag.name} + ctx.col.SQLType = schemas.SQLType{Name: ctx.tagUname} if ctx.tagUname == "JSON" { ctx.col.IsJSON = true } @@ -341,7 +342,7 @@ func ExtendsTagHandler(ctx *Context) error { var tagPrefix = ctx.col.FieldName if len(ctx.params) > 0 { col.Nullable = isPtr - tagPrefix = ctx.params[0] + tagPrefix = strings.Trim(ctx.params[0], "'") if col.IsPrimaryKey { col.Name = ctx.col.FieldName col.IsPrimaryKey = false -- 2.40.1 From f4377059715d70c537ab471ed5351d83abbbf591 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Mon, 28 Jun 2021 21:41:26 +0800 Subject: [PATCH 4/4] Fix test --- tags/parser_test.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tags/parser_test.go b/tags/parser_test.go index 3578f78b..70c57692 100644 --- a/tags/parser_test.go +++ b/tags/parser_test.go @@ -544,9 +544,9 @@ func TestParseWithSQLType(t *testing.T) { assert.EqualValues(t, "date_time", table.Columns()[3].Name) assert.EqualValues(t, "uuid", table.Columns()[4].Name) - assert.EqualValues(t, "varchar", table.Columns()[0].SQLType.Name) - assert.EqualValues(t, "char", table.Columns()[1].SQLType.Name) - assert.EqualValues(t, "bigint", table.Columns()[2].SQLType.Name) - assert.EqualValues(t, "datetime", table.Columns()[3].SQLType.Name) - assert.EqualValues(t, "uuid", table.Columns()[4].SQLType.Name) + assert.EqualValues(t, "VARCHAR", table.Columns()[0].SQLType.Name) + assert.EqualValues(t, "CHAR", table.Columns()[1].SQLType.Name) + assert.EqualValues(t, "BIGINT", table.Columns()[2].SQLType.Name) + assert.EqualValues(t, "DATETIME", table.Columns()[3].SQLType.Name) + assert.EqualValues(t, "UUID", table.Columns()[4].SQLType.Name) } -- 2.40.1