refactor splitTag function #1960

Merged
lunny merged 4 commits from lunny/improve_parse_tag into master 2021-06-28 14:41:55 +00:00
4 changed files with 438 additions and 233 deletions
Showing only changes of commit 940afb2af3

View File

@@ -7,7 +7,6 @@ package tags
import (
"encoding/gob"
"errors"
"fmt"
"reflect"
"strings"
"sync"
@@ -23,7 +22,7 @@ import (
var (
// ErrUnsupportedType represents an unsupported type error
ErrUnsupportedType = errors.New("Unsupported type")
ErrUnsupportedType = errors.New("unsupported type")
)
// Parser represents a parser for xorm tag
@@ -125,6 +124,141 @@ func addIndex(indexName string, table *schemas.Table, col *schemas.Column, index
}
}
var ErrIgnoreField = errors.New("field will be ignored")
func (parser *Parser) parseFieldWithNoTag(field reflect.StructField, fieldValue reflect.Value) (*schemas.Column, error) {
var sqlType schemas.SQLType
if fieldValue.CanAddr() {
if _, ok := fieldValue.Addr().Interface().(convert.Conversion); ok {
sqlType = schemas.SQLType{Name: schemas.Text}
}
}
if _, ok := fieldValue.Interface().(convert.Conversion); ok {
sqlType = schemas.SQLType{Name: schemas.Text}
} else {
sqlType = schemas.Type2SQLType(field.Type)
}
col := schemas.NewColumn(parser.columnMapper.Obj2Table(field.Name),
field.Name, sqlType, sqlType.DefaultLength,
sqlType.DefaultLength2, true)
if field.Type.Kind() == reflect.Int64 && (strings.ToUpper(col.FieldName) == "ID" || strings.HasSuffix(strings.ToUpper(col.FieldName), ".ID")) {
col.IsAutoIncrement = true
col.IsPrimaryKey = true
col.Nullable = false
}
return col, nil
}
func (parser *Parser) parseFieldWithTags(table *schemas.Table, field reflect.StructField, fieldValue reflect.Value, tags []tag) (*schemas.Column, error) {
var col = &schemas.Column{
FieldName: field.Name,
Nullable: true,
IsPrimaryKey: false,
IsAutoIncrement: false,
MapType: schemas.TWOSIDES,
Indexes: make(map[string]int),
DefaultIsEmpty: true,
}
var ctx = Context{
table: table,
col: col,
fieldValue: fieldValue,
indexNames: make(map[string]int),
parser: parser,
}
for j, tag := range tags {
if ctx.ignoreNext {
ctx.ignoreNext = false
continue
}
ctx.tag = tag
ctx.tagUname = strings.ToUpper(tag.name)
if j > 0 {
ctx.preTag = strings.ToUpper(tags[j-1].name)
}
if j < len(tags)-1 {
ctx.nextTag = tags[j+1].name
} else {
ctx.nextTag = ""
}
if h, ok := parser.handlers[ctx.tagUname]; ok {
if err := h(&ctx); err != nil {
if err == ErrIgnoreField {
continue
}
return nil, err
}
} else {
if strings.HasPrefix(ctx.tag.name, "'") && strings.HasSuffix(ctx.tag.name, "'") {
col.Name = ctx.tag.name[1 : len(ctx.tag.name)-1]
} else {
col.Name = ctx.tag.name
}
}
if ctx.hasCacheTag {
if parser.cacherMgr.GetDefaultCacher() != nil {
parser.cacherMgr.SetCacher(table.Name, parser.cacherMgr.GetDefaultCacher())
} else {
parser.cacherMgr.SetCacher(table.Name, caches.NewLRUCacher2(caches.NewMemoryStore(), time.Hour, 10000))
}
}
if ctx.hasNoCacheTag {
parser.cacherMgr.SetCacher(table.Name, nil)
}
}
if col.SQLType.Name == "" {
col.SQLType = schemas.Type2SQLType(field.Type)
}
parser.dialect.SQLType(col)
if col.Length == 0 {
col.Length = col.SQLType.DefaultLength
}
if col.Length2 == 0 {
col.Length2 = col.SQLType.DefaultLength2
}
if col.Name == "" {
col.Name = parser.columnMapper.Obj2Table(field.Name)
}
if ctx.isUnique {
ctx.indexNames[col.Name] = schemas.UniqueType
} else if ctx.isIndex {
ctx.indexNames[col.Name] = schemas.IndexType
}
for indexName, indexType := range ctx.indexNames {
addIndex(indexName, table, col, indexType)
}
return col, nil
}
func (parser *Parser) parseField(table *schemas.Table, field reflect.StructField, fieldValue reflect.Value) (*schemas.Column, error) {
var (
tag = field.Tag
ormTagStr = strings.TrimSpace(tag.Get(parser.identifier))
)
if ormTagStr == "" {
return parser.parseFieldWithNoTag(field, fieldValue)
}
if ormTagStr == "-" {
return nil, ErrIgnoreField
}
tags, err := splitTag(ormTagStr)
if err != nil {
return nil, err
}
return parser.parseFieldWithTags(table, field, fieldValue, tags)
}
// Parse parses a struct as a table information
func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) {
t := v.Type()
@@ -140,9 +274,6 @@ func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) {
table.Type = t
table.Name = names.GetTableName(parser.tableMapper, v)
var idFieldColName string
var hasCacheTag, hasNoCacheTag bool
for i := 0; i < t.NumField(); i++ {
var isUnexportField bool
for _, c := range t.Field(i).Name {
@@ -155,178 +286,20 @@ func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) {
continue
}
tag := t.Field(i).Tag
ormTagStr := tag.Get(parser.identifier)
var col *schemas.Column
fieldValue := v.Field(i)
fieldType := fieldValue.Type()
var (
field = t.Field(i)
fieldValue = v.Field(i)
)
if ormTagStr != "" {
col = &schemas.Column{
FieldName: t.Field(i).Name,
Nullable: true,
IsPrimaryKey: false,
IsAutoIncrement: false,
MapType: schemas.TWOSIDES,
Indexes: make(map[string]int),
DefaultIsEmpty: true,
}
tags := splitTag(ormTagStr)
if len(tags) > 0 {
if tags[0] == "-" {
continue
}
var ctx = Context{
table: table,
col: col,
fieldValue: fieldValue,
indexNames: make(map[string]int),
parser: parser,
}
if strings.HasPrefix(strings.ToUpper(tags[0]), "EXTENDS") {
pStart := strings.Index(tags[0], "(")
if pStart > -1 && strings.HasSuffix(tags[0], ")") {
var tagPrefix = strings.TrimFunc(tags[0][pStart+1:len(tags[0])-1], func(r rune) bool {
return r == '\'' || r == '"'
})
ctx.params = []string{tagPrefix}
}
if err := ExtendsTagHandler(&ctx); err != nil {
return nil, err
}
continue
}
for j, key := range tags {
if ctx.ignoreNext {
ctx.ignoreNext = false
continue
}
k := strings.ToUpper(key)
ctx.tagName = k
ctx.params = []string{}
pStart := strings.Index(k, "(")
if pStart == 0 {
return nil, errors.New("( could not be the first character")
}
if pStart > -1 {
if !strings.HasSuffix(k, ")") {
return nil, fmt.Errorf("field %s tag %s cannot match ) character", col.FieldName, key)
}
ctx.tagName = k[:pStart]
ctx.params = strings.Split(key[pStart+1:len(k)-1], ",")
}
if j > 0 {
ctx.preTag = strings.ToUpper(tags[j-1])
}
if j < len(tags)-1 {
ctx.nextTag = tags[j+1]
} else {
ctx.nextTag = ""
}
if h, ok := parser.handlers[ctx.tagName]; ok {
if err := h(&ctx); err != nil {
return nil, err
}
} else {
if strings.HasPrefix(key, "'") && strings.HasSuffix(key, "'") {
col.Name = key[1 : len(key)-1]
} else {
col.Name = key
}
}
if ctx.hasCacheTag {
hasCacheTag = true
}
if ctx.hasNoCacheTag {
hasNoCacheTag = true
}
}
if col.SQLType.Name == "" {
col.SQLType = schemas.Type2SQLType(fieldType)
}
parser.dialect.SQLType(col)
if col.Length == 0 {
col.Length = col.SQLType.DefaultLength
}
if col.Length2 == 0 {
col.Length2 = col.SQLType.DefaultLength2
}
if col.Name == "" {
col.Name = parser.columnMapper.Obj2Table(t.Field(i).Name)
}
if ctx.isUnique {
ctx.indexNames[col.Name] = schemas.UniqueType
} else if ctx.isIndex {
ctx.indexNames[col.Name] = schemas.IndexType
}
for indexName, indexType := range ctx.indexNames {
addIndex(indexName, table, col, indexType)
}
}
} else {
var sqlType schemas.SQLType
if fieldValue.CanAddr() {
if _, ok := fieldValue.Addr().Interface().(convert.Conversion); ok {
sqlType = schemas.SQLType{Name: schemas.Text}
}
}
if _, ok := fieldValue.Interface().(convert.Conversion); ok {
sqlType = schemas.SQLType{Name: schemas.Text}
} else {
sqlType = schemas.Type2SQLType(fieldType)
}
col = schemas.NewColumn(parser.columnMapper.Obj2Table(t.Field(i).Name),
t.Field(i).Name, sqlType, sqlType.DefaultLength,
sqlType.DefaultLength2, true)
if fieldType.Kind() == reflect.Int64 && (strings.ToUpper(col.FieldName) == "ID" || strings.HasSuffix(strings.ToUpper(col.FieldName), ".ID")) {
idFieldColName = col.Name
}
}
if col.IsAutoIncrement {
col.Nullable = false
col, err := parser.parseField(table, field, fieldValue)
if err == ErrIgnoreField {
continue
} else if err != nil {
return nil, err
}
table.AddColumn(col)
} // end for
if idFieldColName != "" && len(table.PrimaryKeys) == 0 {
col := table.GetColumn(idFieldColName)
col.IsPrimaryKey = true
col.IsAutoIncrement = true
col.Nullable = false
table.PrimaryKeys = append(table.PrimaryKeys, col.Name)
table.AutoIncrement = col.Name
}
if hasCacheTag {
if parser.cacherMgr.GetDefaultCacher() != nil { // !nash! use engine's cacher if provided
//engine.logger.Info("enable cache on table:", table.Name)
parser.cacherMgr.SetCacher(table.Name, parser.cacherMgr.GetDefaultCacher())
} else {
//engine.logger.Info("enable LRU cache on table:", table.Name)
parser.cacherMgr.SetCacher(table.Name, caches.NewLRUCacher2(caches.NewMemoryStore(), time.Hour, 10000))
}
}
if hasNoCacheTag {
//engine.logger.Info("disable cache on table:", table.Name)
parser.cacherMgr.SetCacher(table.Name, nil)
}
return table, nil
}
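
The net effect of the refactor above: parseField now dispatches to parseFieldWithNoTag when a field carries no tag, returns the new ErrIgnoreField sentinel for "-", and otherwise hands the split tags to parseFieldWithTags. A minimal sketch of those three paths follows, assuming xorm.io/xorm/tags is importable with the same NewParser/Parse wiring the tests below use; the Account struct and the printed fields are illustrative only, not part of the PR.

package main

import (
	"fmt"
	"reflect"

	"xorm.io/xorm/caches"
	"xorm.io/xorm/dialects"
	"xorm.io/xorm/names"
	"xorm.io/xorm/tags"
)

type Account struct {
	ID     int64  // no tag: an int64 "ID" field becomes pk + autoincr by default
	Name   string `db:"notnull"` // tag list: goes through the NOTNULL handler
	Secret string `db:"-"`       // ignore tag: parseField returns ErrIgnoreField, no column
}

func main() {
	parser := tags.NewParser(
		"db",
		dialects.QueryDialect("mysql"),
		names.SnakeMapper{},
		names.GonicMapper{},
		caches.NewManager(),
	)
	table, err := parser.Parse(reflect.ValueOf(new(Account)))
	if err != nil {
		panic(err)
	}
	for _, col := range table.Columns() {
		fmt.Println(col.Name, col.IsPrimaryKey, col.IsAutoIncrement, col.Nullable)
	}
	// Expected: two columns only, "id" (pk, autoincr, not nullable) and
	// "name" (not nullable); "secret" is skipped entirely.
}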

View File

@@ -7,11 +7,13 @@ package tags
import (
"reflect"
"testing"
"time"
"github.com/stretchr/testify/assert"
"xorm.io/xorm/caches"
"xorm.io/xorm/dialects"
"xorm.io/xorm/names"
"github.com/stretchr/testify/assert"
)
type ParseTableName1 struct{}
@@ -80,7 +82,7 @@ func TestParseWithOtherIdentifier(t *testing.T) {
parser := NewParser(
"xorm",
dialects.QueryDialect("mysql"),
names.GonicMapper{},
names.SameMapper{},
names.SnakeMapper{},
caches.NewManager(),
)
@@ -88,13 +90,136 @@ func TestParseWithOtherIdentifier(t *testing.T) {
type StructWithDBTag struct {
FieldFoo string `db:"foo"`
}
parser.SetIdentifier("db")
table, err := parser.Parse(reflect.ValueOf(new(StructWithDBTag)))
assert.NoError(t, err)
assert.EqualValues(t, "struct_with_db_tag", table.Name)
assert.EqualValues(t, "StructWithDBTag", table.Name)
assert.EqualValues(t, 1, len(table.Columns()))
for _, col := range table.Columns() {
assert.EqualValues(t, "foo", col.Name)
}
}
func TestParseWithIgnore(t *testing.T) {
parser := NewParser(
"db",
dialects.QueryDialect("mysql"),
names.SameMapper{},
names.SnakeMapper{},
caches.NewManager(),
)
type StructWithIgnoreTag struct {
FieldFoo string `db:"-"`
}
table, err := parser.Parse(reflect.ValueOf(new(StructWithIgnoreTag)))
assert.NoError(t, err)
assert.EqualValues(t, "StructWithIgnoreTag", table.Name)
assert.EqualValues(t, 0, len(table.Columns()))
}
func TestParseWithAutoincrement(t *testing.T) {
parser := NewParser(
"db",
dialects.QueryDialect("mysql"),
names.SnakeMapper{},
names.GonicMapper{},
caches.NewManager(),
)
type StructWithAutoIncrement struct {
ID int64
}
table, err := parser.Parse(reflect.ValueOf(new(StructWithAutoIncrement)))
assert.NoError(t, err)
assert.EqualValues(t, "struct_with_auto_increment", table.Name)
assert.EqualValues(t, 1, len(table.Columns()))
assert.EqualValues(t, "id", table.Columns()[0].Name)
assert.True(t, table.Columns()[0].IsAutoIncrement)
assert.True(t, table.Columns()[0].IsPrimaryKey)
}
func TestParseWithAutoincrement2(t *testing.T) {
parser := NewParser(
"db",
dialects.QueryDialect("mysql"),
names.SnakeMapper{},
names.GonicMapper{},
caches.NewManager(),
)
type StructWithAutoIncrement2 struct {
ID int64 `db:"pk autoincr"`
}
table, err := parser.Parse(reflect.ValueOf(new(StructWithAutoIncrement2)))
assert.NoError(t, err)
assert.EqualValues(t, "struct_with_auto_increment2", table.Name)
assert.EqualValues(t, 1, len(table.Columns()))
assert.EqualValues(t, "id", table.Columns()[0].Name)
assert.True(t, table.Columns()[0].IsAutoIncrement)
assert.True(t, table.Columns()[0].IsPrimaryKey)
assert.False(t, table.Columns()[0].Nullable)
}
func TestParseWithNullable(t *testing.T) {
parser := NewParser(
"db",
dialects.QueryDialect("mysql"),
names.SnakeMapper{},
names.GonicMapper{},
caches.NewManager(),
)
type StructWithNullable struct {
Name string `db:"notnull"`
FullName string `db:"null comment('column comment,字段注释')"`
}
table, err := parser.Parse(reflect.ValueOf(new(StructWithNullable)))
assert.NoError(t, err)
assert.EqualValues(t, "struct_with_nullable", table.Name)
assert.EqualValues(t, 2, len(table.Columns()))
assert.EqualValues(t, "name", table.Columns()[0].Name)
assert.EqualValues(t, "full_name", table.Columns()[1].Name)
assert.False(t, table.Columns()[0].Nullable)
assert.True(t, table.Columns()[1].Nullable)
assert.EqualValues(t, "column comment,字段注释", table.Columns()[1].Comment)
}
func TestParseWithTimes(t *testing.T) {
parser := NewParser(
"db",
dialects.QueryDialect("mysql"),
names.SnakeMapper{},
names.GonicMapper{},
caches.NewManager(),
)
type StructWithTimes struct {
Name string `db:"notnull"`
CreatedAt time.Time `db:"created"`
UpdatedAt time.Time `db:"updated"`
DeletedAt time.Time `db:"deleted"`
}
table, err := parser.Parse(reflect.ValueOf(new(StructWithTimes)))
assert.NoError(t, err)
assert.EqualValues(t, "struct_with_times", table.Name)
assert.EqualValues(t, 4, len(table.Columns()))
assert.EqualValues(t, "name", table.Columns()[0].Name)
assert.EqualValues(t, "created_at", table.Columns()[1].Name)
assert.EqualValues(t, "updated_at", table.Columns()[2].Name)
assert.EqualValues(t, "deleted_at", table.Columns()[3].Name)
assert.False(t, table.Columns()[0].Nullable)
assert.True(t, table.Columns()[1].Nullable)
assert.True(t, table.Columns()[1].IsCreated)
assert.True(t, table.Columns()[2].Nullable)
assert.True(t, table.Columns()[2].IsUpdated)
assert.True(t, table.Columns()[3].Nullable)
assert.True(t, table.Columns()[3].IsDeleted)
}

View File

@@ -14,30 +14,74 @@ import (
"xorm.io/xorm/schemas"
)
func splitTag(tag string) (tags []string) {
tag = strings.TrimSpace(tag)
var hasQuote = false
var lastIdx = 0
for i, t := range tag {
if t == '\'' {
hasQuote = !hasQuote
} else if t == ' ' {
if lastIdx < i && !hasQuote {
tags = append(tags, strings.TrimSpace(tag[lastIdx:i]))
lastIdx = i + 1
type tag struct {
name string
params []string
}
func splitTag(tagStr string) ([]tag, error) {
tagStr = strings.TrimSpace(tagStr)
var (
inQuote bool
inBigQuote bool
lastIdx int
curTag tag
paramStart int
tags []tag
)
for i, t := range tagStr {
switch t {
case '\'':
inQuote = !inQuote
case ' ':
if !inQuote && !inBigQuote {
if lastIdx < i {
if curTag.name == "" {
curTag.name = tagStr[lastIdx:i]
}
tags = append(tags, curTag)
lastIdx = i + 1
curTag = tag{}
} else if lastIdx == i {
lastIdx = i + 1
}
} else if inBigQuote && !inQuote {
paramStart = i + 1
}
case ',':
if !inQuote && !inBigQuote {
return nil, fmt.Errorf("comma[%d] of %s should be in quote or big quote", i, tagStr)
}
if !inQuote && inBigQuote {
curTag.params = append(curTag.params, strings.TrimSpace(tagStr[paramStart:i]))
paramStart = i + 1
}
case '(':
inBigQuote = true
if !inQuote {
curTag.name = tagStr[lastIdx:i]
paramStart = i + 1
}
case ')':
inBigQuote = false
if !inQuote {
curTag.params = append(curTag.params, tagStr[paramStart:i])
}
}
}
if lastIdx < len(tag) {
tags = append(tags, strings.TrimSpace(tag[lastIdx:]))
if lastIdx < len(tagStr) {
if curTag.name == "" {
curTag.name = tagStr[lastIdx:]
}
tags = append(tags, curTag)
}
return
return tags, nil
}
// Context represents a context for xorm tag parse.
type Context struct {
tagName string
params []string
tag
tagUname string
preTag, nextTag string
table *schemas.Table
col *schemas.Column
@@ -124,6 +168,7 @@ func NotNullTagHandler(ctx *Context) error {
// AutoIncrTagHandler describes autoincr tag handler
func AutoIncrTagHandler(ctx *Context) error {
ctx.col.IsAutoIncrement = true
ctx.col.Nullable = false
/*
if len(ctx.params) > 0 {
autoStartInt, err := strconv.Atoi(ctx.params[0])
@@ -225,41 +270,44 @@ func CommentTagHandler(ctx *Context) error {
// SQLTypeTagHandler describes SQL Type tag handler
func SQLTypeTagHandler(ctx *Context) error {
ctx.col.SQLType = schemas.SQLType{Name: ctx.tagName}
if strings.EqualFold(ctx.tagName, "JSON") {
ctx.col.SQLType = schemas.SQLType{Name: ctx.tag.name}
if ctx.tagUname == "JSON" {
ctx.col.IsJSON = true
}
if len(ctx.params) > 0 {
if ctx.tagName == schemas.Enum {
ctx.col.EnumOptions = make(map[string]int)
for k, v := range ctx.params {
v = strings.TrimSpace(v)
v = strings.Trim(v, "'")
ctx.col.EnumOptions[v] = k
if len(ctx.params) == 0 {
return nil
}
switch ctx.tagUname {
case schemas.Enum:
ctx.col.EnumOptions = make(map[string]int)
for k, v := range ctx.params {
v = strings.TrimSpace(v)
v = strings.Trim(v, "'")
ctx.col.EnumOptions[v] = k
}
case schemas.Set:
ctx.col.SetOptions = make(map[string]int)
for k, v := range ctx.params {
v = strings.TrimSpace(v)
v = strings.Trim(v, "'")
ctx.col.SetOptions[v] = k
}
default:
var err error
if len(ctx.params) == 2 {
ctx.col.Length, err = strconv.Atoi(ctx.params[0])
if err != nil {
return err
}
} else if ctx.tagName == schemas.Set {
ctx.col.SetOptions = make(map[string]int)
for k, v := range ctx.params {
v = strings.TrimSpace(v)
v = strings.Trim(v, "'")
ctx.col.SetOptions[v] = k
ctx.col.Length2, err = strconv.Atoi(ctx.params[1])
if err != nil {
return err
}
} else {
var err error
if len(ctx.params) == 2 {
ctx.col.Length, err = strconv.Atoi(ctx.params[0])
if err != nil {
return err
}
ctx.col.Length2, err = strconv.Atoi(ctx.params[1])
if err != nil {
return err
}
} else if len(ctx.params) == 1 {
ctx.col.Length, err = strconv.Atoi(ctx.params[0])
if err != nil {
return err
}
} else if len(ctx.params) == 1 {
ctx.col.Length, err = strconv.Atoi(ctx.params[0])
if err != nil {
return err
}
}
}
@@ -315,7 +363,7 @@ func ExtendsTagHandler(ctx *Context) error {
default:
//TODO: warning
}
return nil
return ErrIgnoreField
}
// CacheTagHandler describes cache tag handler
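
The core of the PR is the rewritten splitTag shown earlier in this file: instead of returning raw strings, it walks the tag once with quote and parenthesis state and returns structured tags (a name plus its params). Below is a minimal, self-contained sketch of the same idea; miniTag and miniSplit are invented names, the error return for a stray comma is omitted, and whitespace handling is simplified relative to the real code.

package main

import (
	"fmt"
	"strings"
)

type miniTag struct {
	name   string
	params []string
}

// miniSplit splits an xorm-style tag string on spaces, keeping single-quoted
// values intact and collecting comma-separated parameters inside parentheses.
func miniSplit(s string) []miniTag {
	s = strings.TrimSpace(s)
	var (
		result     []miniTag
		cur        miniTag
		inQuote    bool // inside '...'
		inParen    bool // inside (...)
		lastIdx    int  // start of the current tag name
		paramStart int  // start of the current parameter
	)
	for i, r := range s {
		switch r {
		case '\'':
			inQuote = !inQuote
		case ' ':
			if !inQuote && !inParen {
				if lastIdx < i {
					if cur.name == "" {
						cur.name = s[lastIdx:i]
					}
					result = append(result, cur)
					cur = miniTag{}
				}
				lastIdx = i + 1
			}
		case ',':
			if inParen && !inQuote {
				cur.params = append(cur.params, strings.TrimSpace(s[paramStart:i]))
				paramStart = i + 1
			}
		case '(':
			if !inQuote {
				inParen = true
				cur.name = s[lastIdx:i]
				paramStart = i + 1
			}
		case ')':
			if !inQuote {
				inParen = false
				cur.params = append(cur.params, strings.TrimSpace(s[paramStart:i]))
			}
		}
	}
	if lastIdx < len(s) {
		if cur.name == "" {
			cur.name = s[lastIdx:]
		}
		result = append(result, cur)
	}
	return result
}

func main() {
	// Two of the inputs from the new TestSplitTag table.
	fmt.Printf("%+v\n", miniSplit("numeric(10, 2) notnull"))
	// [{name:numeric params:[10 2]} {name:notnull params:[]}]
	fmt.Printf("%+v\n", miniSplit("default('2000-01-01 00:00:00')"))
	// [{name:default params:['2000-01-01 00:00:00']}]
}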

View File

@@ -7,24 +7,83 @@ package tags
import (
"testing"
"xorm.io/xorm/internal/utils"
"github.com/stretchr/testify/assert"
)
func TestSplitTag(t *testing.T) {
var cases = []struct {
tag string
tags []string
tags []tag
}{
{"not null default '2000-01-01 00:00:00' TIMESTAMP", []string{"not", "null", "default", "'2000-01-01 00:00:00'", "TIMESTAMP"}},
{"TEXT", []string{"TEXT"}},
{"default('2000-01-01 00:00:00')", []string{"default('2000-01-01 00:00:00')"}},
{"json binary", []string{"json", "binary"}},
{"not null default '2000-01-01 00:00:00' TIMESTAMP", []tag{
{
name: "not",
},
{
name: "null",
},
{
name: "default",
},
{
name: "'2000-01-01 00:00:00'",
},
{
name: "TIMESTAMP",
},
},
},
{"TEXT", []tag{
{
name: "TEXT",
},
},
},
{"default('2000-01-01 00:00:00')", []tag{
{
name: "default",
params: []string{
"'2000-01-01 00:00:00'",
},
},
},
},
{"json binary", []tag{
{
name: "json",
},
{
name: "binary",
},
},
},
{"numeric(10, 2)", []tag{
{
name: "numeric",
params: []string{"10", "2"},
},
},
},
{"numeric(10, 2) notnull", []tag{
{
name: "numeric",
params: []string{"10", "2"},
},
{
name: "notnull",
},
},
},
}
for _, kase := range cases {
tags := splitTag(kase.tag)
if !utils.SliceEq(tags, kase.tags) {
t.Fatalf("[%d]%v is not equal [%d]%v", len(tags), tags, len(kase.tags), kase.tags)
}
t.Run(kase.tag, func(t *testing.T) {
tags, err := splitTag(kase.tag)
assert.NoError(t, err)
assert.EqualValues(t, len(tags), len(kase.tags))
for i := 0; i < len(tags); i++ {
assert.Equal(t, tags[i], kase.tags[i])
}
})
}
}
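
To exercise the new table-driven cases locally, something along these lines should work from the repository root (the ./tags package path is an assumption based on the package name in the diff):

go test ./tags/ -run 'TestSplitTag|TestParseWith' -v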