Fix table name #1590

Merged
lunny merged 2 commits from lunny/fix_tableName into master 2020-03-10 08:15:16 +00:00
2 changed files with 67 additions and 10 deletions

View File

@ -6,6 +6,7 @@ package names
import ( import (
"reflect" "reflect"
"sync"
) )
// TableName table name interface to define customerize table name // TableName table name interface to define customerize table name
@ -15,23 +16,40 @@ type TableName interface {
var ( var (
tpTableName = reflect.TypeOf((*TableName)(nil)).Elem() tpTableName = reflect.TypeOf((*TableName)(nil)).Elem()
tvCache sync.Map
) )
func GetTableName(mapper Mapper, v reflect.Value) string { func GetTableName(mapper Mapper, v reflect.Value) string {
if t, ok := v.Interface().(TableName); ok {
return t.TableName()
}
if v.Type().Implements(tpTableName) { if v.Type().Implements(tpTableName) {
return v.Interface().(TableName).TableName() return v.Interface().(TableName).TableName()
} }
if v.Kind() == reflect.Ptr { if v.Kind() == reflect.Ptr {
v = v.Elem() v = v.Elem()
if t, ok := v.Interface().(TableName); ok {
return t.TableName()
}
if v.Type().Implements(tpTableName) { if v.Type().Implements(tpTableName) {
return v.Interface().(TableName).TableName() return v.Interface().(TableName).TableName()
} }
} else if v.CanAddr() {
v1 := v.Addr()
if v1.Type().Implements(tpTableName) {
return v1.Interface().(TableName).TableName()
}
} else {
name, ok := tvCache.Load(v.Type())
if ok {
if name.(string) != "" {
return name.(string)
}
} else {
v2 := reflect.New(v.Type())
if v2.Type().Implements(tpTableName) {
tableName := v2.Interface().(TableName).TableName()
tvCache.Store(v.Type(), tableName)
return tableName
}
tvCache.Store(v.Type(), "")
}
} }
return mapper.Obj2Table(v.Type().Name()) return mapper.Obj2Table(v.Type().Name())

View File

@ -5,6 +5,7 @@
package names package names
import ( import (
"fmt"
"reflect" "reflect"
"testing" "testing"
"time" "time"
@ -43,8 +44,10 @@ func (MyGetCustomTableImpletation) TableName() string {
type TestTableNameStruct struct{} type TestTableNameStruct struct{}
const getTestTableName = "my_test_table_name_struct"
func (t *TestTableNameStruct) TableName() string { func (t *TestTableNameStruct) TableName() string {
return "my_test_table_name_struct" return getTestTableName
} }
func TestGetTableName(t *testing.T) { func TestGetTableName(t *testing.T) {
@ -85,13 +88,18 @@ func TestGetTableName(t *testing.T) {
}, },
{ {
SnakeMapper{}, SnakeMapper{},
reflect.ValueOf(MyGetCustomTableImpletation{}), reflect.ValueOf(new(TestTableNameStruct)),
getCustomTableName, new(TestTableNameStruct).TableName(),
}, },
{ {
SnakeMapper{}, SnakeMapper{},
reflect.ValueOf(new(TestTableNameStruct)), reflect.ValueOf(new(TestTableNameStruct)),
new(TestTableNameStruct).TableName(), getTestTableName,
},
{
SnakeMapper{},
reflect.ValueOf(TestTableNameStruct{}),
getTestTableName,
}, },
} }
@ -99,3 +107,34 @@ func TestGetTableName(t *testing.T) {
assert.EqualValues(t, kase.expectedTableName, GetTableName(kase.mapper, kase.v)) assert.EqualValues(t, kase.expectedTableName, GetTableName(kase.mapper, kase.v))
} }
} }
type OAuth2Application struct {
}
// TableName sets the table name to `oauth2_application`
func (app *OAuth2Application) TableName() string {
return "oauth2_application"
}
func TestGonicMapperCustomTable(t *testing.T) {
assert.EqualValues(t, "oauth2_application",
GetTableName(LintGonicMapper, reflect.ValueOf(new(OAuth2Application))))
assert.EqualValues(t, "oauth2_application",
GetTableName(LintGonicMapper, reflect.ValueOf(OAuth2Application{})))
}
type MyTable struct {
Idx int
}
func (t *MyTable) TableName() string {
return fmt.Sprintf("mytable_%d", t.Idx)
}
func TestMyTable(t *testing.T) {
var table MyTable
for i := 0; i < 10; i++ {
table.Idx = i
assert.EqualValues(t, fmt.Sprintf("mytable_%d", i), GetTableName(SameMapper{}, reflect.ValueOf(&table)))
}
}