support mysql shadow database for full chain stress test in the production environment #2155

Closed
Thomas_An wants to merge 10 commits from master into master
19 changed files with 544 additions and 30 deletions

View File

@ -89,6 +89,8 @@ type Dialect interface {
Filters() []Filter
SetParams(params map[string]string)
IsShadow(ctx context.Context) bool
SetShadowable(s Shadowable)
}
// Base represents a basic dialect and all real dialects could embed this struct
@ -96,6 +98,7 @@ type Base struct {
dialect Dialect
uri *URI
quoter schemas.Quoter
shadow Shadowable
}
// Alias returned col itself
@ -254,6 +257,16 @@ func (db *Base) ForUpdateSQL(query string) string {
func (db *Base) SetParams(params map[string]string) {
}
func (db *Base) IsShadow(ctx context.Context) bool {
if db.shadow != nil {
return db.shadow.IsShadow(ctx)
}
return false
}
func (db *Base) SetShadowable(shadow Shadowable) {
db.shadow = shadow
}
var (
dialects = map[string]func() Dialect{}
)

23
dialects/shadow.go Normal file
View File

@ -0,0 +1,23 @@
package dialects
import "context"
type Shadowable interface {
IsShadow(ctx context.Context) bool
}
type TrueShadow struct{}
type FalseShadow struct{}
func NewTrueShadow() Shadowable {
return &TrueShadow{}
}
func NewFalseShadow() Shadowable {
return &FalseShadow{}
}
func (t *TrueShadow) IsShadow(ctx context.Context) bool {
return true
}
func (f *FalseShadow) IsShadow(ctx context.Context) bool {
return false
}

View File

@ -5,6 +5,7 @@
package dialects
import (
"context"
"fmt"
"reflect"
"strings"
@ -14,6 +15,8 @@ import (
"xorm.io/xorm/schemas"
)
const ShadowDBNamePrefix = "shadow_"
// TableNameWithSchema will add schema prefix on table name if possible
func TableNameWithSchema(dialect Dialect, tableName string) string {
// Add schema name as prefix of table name.
@ -24,6 +27,18 @@ func TableNameWithSchema(dialect Dialect, tableName string) string {
return tableName
}
// TableNameWithDBName will add database name prefix on table name if possible
func TableNameWithDBName(dialect Dialect, tableName string) string {
// Add schema name as prefix of table name.
// Only for postgres database.
if dialect.URI().DBName != "" &&
dialect.URI().DBType == schemas.MYSQL &&
strings.Index(tableName, ".") == -1 {
return fmt.Sprintf("%s.%s", dialect.URI().DBName, tableName)
}
return tableName
}
// TableNameNoSchema returns table name with given tableName
func TableNameNoSchema(dialect Dialect, mapper names.Mapper, tableName interface{}) string {
quote := dialect.Quoter().Quote
@ -84,10 +99,19 @@ func TableNameNoSchema(dialect Dialect, mapper names.Mapper, tableName interface
}
// FullTableName returns table name with quote and schema according parameter
func FullTableName(dialect Dialect, mapper names.Mapper, bean interface{}, includeSchema ...bool) string {
func FullTableName(ctx context.Context, dialect Dialect, mapper names.Mapper, bean interface{}, includeSchema ...bool) string {
tbName := TableNameNoSchema(dialect, mapper, bean)
if len(includeSchema) > 0 && includeSchema[0] && !utils.IsSubQuery(tbName) {
tbName = TableNameWithSchema(dialect, tbName)
}
if dialect.URI() != nil &&
dialect.URI().DBType == schemas.MYSQL &&
dialect.IsShadow(ctx) && !hasShadowPrefix(tbName) {
tbName = ShadowDBNamePrefix + TableNameWithDBName(dialect, tbName)
}
return tbName
}
func hasShadowPrefix(tableName string) bool {
return strings.HasPrefix(tableName, ShadowDBNamePrefix)
}

View File

@ -5,6 +5,7 @@
package dialects
import (
"context"
"testing"
"xorm.io/xorm/names"
@ -23,8 +24,14 @@ func (mcc *MCC) TableName() string {
}
func TestFullTableName(t *testing.T) {
dialect := QueryDialect("mysql")
assert.EqualValues(t, "mcc", FullTableName(dialect, names.SnakeMapper{}, &MCC{}))
assert.EqualValues(t, "mcc", FullTableName(dialect, names.SnakeMapper{}, "mcc"))
dialect, err := OpenDialect("mysql", "root:root@tcp(127.0.0.1:3306)/test?charset=utf8")
if err != nil {
panic("unknow dialect")
}
dialect.SetShadowable(NewTrueShadow())
assert.EqualValues(t, "shadow_test.mcc", FullTableName(context.Background(), dialect, names.SnakeMapper{}, &MCC{}))
assert.EqualValues(t, "shadow_test.mcc", FullTableName(context.Background(), dialect, names.SnakeMapper{}, "mcc"))
dialect.SetShadowable(NewFalseShadow())
assert.EqualValues(t, "mcc", FullTableName(context.Background(), dialect, names.SnakeMapper{}, &MCC{}))
assert.EqualValues(t, "mcc", FullTableName(context.Background(), dialect, names.SnakeMapper{}, "mcc"))
}

View File

@ -291,7 +291,11 @@ func (engine *Engine) NoCascade() *Session {
// MapCacher Set a table use a special cacher
func (engine *Engine) MapCacher(bean interface{}, cacher caches.Cacher) error {
engine.SetCacher(dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean, true), cacher)
for _, v := range []dialects.Shadowable{dialects.NewTrueShadow(), dialects.NewFalseShadow()} {
engine.dialect.SetShadowable(v)
engine.SetCacher(dialects.FullTableName(context.Background(), engine.dialect, engine.GetTableMapper(), bean, true), cacher)
engine.SetCacher(dialects.FullTableName(context.Background(), engine.dialect, engine.GetTableMapper(), bean, true), cacher)
}
return nil
}
@ -1067,7 +1071,12 @@ func (engine *Engine) IsTableExist(beanOrTableName interface{}) (bool, error) {
// TableName returns table name with schema prefix if has
func (engine *Engine) TableName(bean interface{}, includeSchema ...bool) string {
return dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean, includeSchema...)
return dialects.FullTableName(context.Background(), engine.dialect, engine.GetTableMapper(), bean, includeSchema...)
}
// ContextTableName returns table name with schema and database prefix if has
func (engine *Engine) ContextTableName(ctx context.Context, bean interface{}, includeSchema ...bool) string {
return dialects.FullTableName(ctx, engine.dialect, engine.GetTableMapper(), bean, includeSchema...)
}
// CreateIndexes create indexes
@ -1086,23 +1095,29 @@ func (engine *Engine) CreateUniques(bean interface{}) error {
// ClearCacheBean if enabled cache, clear the cache bean
func (engine *Engine) ClearCacheBean(bean interface{}, id string) error {
tableName := dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean)
cacher := engine.GetCacher(tableName)
if cacher != nil {
cacher.ClearIds(tableName)
cacher.DelBean(tableName, id)
for _, v := range []dialects.Shadowable{dialects.NewTrueShadow(), dialects.NewFalseShadow()} {
engine.dialect.SetShadowable(v)
tableName := dialects.FullTableName(context.Background(), engine.dialect, engine.GetTableMapper(), bean)
cacher := engine.GetCacher(tableName)
if cacher != nil {
cacher.ClearIds(tableName)
cacher.DelBean(tableName, id)
}
}
return nil
}
// ClearCache if enabled cache, clear some tables' cache
func (engine *Engine) ClearCache(beans ...interface{}) error {
for _, bean := range beans {
tableName := dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean)
cacher := engine.GetCacher(tableName)
if cacher != nil {
cacher.ClearIds(tableName)
cacher.ClearBeans(tableName)
for _, v := range []dialects.Shadowable{dialects.NewTrueShadow(), dialects.NewFalseShadow()} {
engine.dialect.SetShadowable(v)
for _, bean := range beans {
tableName := dialects.FullTableName(context.Background(), engine.dialect, engine.GetTableMapper(), bean)
cacher := engine.GetCacher(tableName)
if cacher != nil {
cacher.ClearIds(tableName)
cacher.ClearBeans(tableName)
}
}
}
return nil
@ -1431,3 +1446,8 @@ func (engine *Engine) Transaction(f func(*Session) (interface{}, error)) (interf
return result, nil
}
// SetShadow Set whether to use shadow database algorithm, should be called after modify the cache setting
func (engine *Engine) SetShadow(shadow dialects.Shadowable) {
engine.dialect.SetShadowable(shadow)
}

View File

@ -265,3 +265,16 @@ func (eg *EngineGroup) Rows(bean interface{}) (*Rows, error) {
sess.isAutoClose = true
return sess.Rows(bean)
}
// SetShadow Set whether to use shadow database algorithm, should be called after modify the cache setting
func (eg *EngineGroup) SetShadow(shadow dialects.Shadowable) {
eg.Engine.SetShadow(shadow)
for i := 0; i < len(eg.slaves); i++ {
eg.slaves[i].SetShadow(shadow)
}
}
// ContextTableName returns table name with schema and database prefix if has
func (engine *EngineGroup) ContextTableName(ctx context.Context, bean interface{}, includeSchema ...bool) string {
return dialects.FullTableName(ctx, engine.dialect, engine.GetTableMapper(), bean, includeSchema...)
}

2
go.mod
View File

@ -17,5 +17,5 @@ require (
github.com/syndtr/goleveldb v1.0.0
github.com/ziutek/mymysql v1.5.4
modernc.org/sqlite v1.14.2
xorm.io/builder v0.3.11-0.20220531020008-1bd24a7dc978
xorm.io/builder v0.3.11
)

4
go.sum
View File

@ -659,5 +659,5 @@ modernc.org/z v1.2.19 h1:BGyRFWhDVn5LFS5OcX4Yd/MlpRTOc7hOPTdcIpCiUao=
modernc.org/z v1.2.19/go.mod h1:+ZpP0pc4zz97eukOzW3xagV/lS82IpPN9NGG5pNF9vY=
sigs.k8s.io/yaml v1.1.0/go.mod h1:UJmg0vDUVViEyp3mgSv9WPwZCDxu4rQW1olrI1uml+o=
sourcegraph.com/sourcegraph/appdash v0.0.0-20190731080439-ebfcffb1b5c0/go.mod h1:hI742Nqp5OhwiqlzhgfbWU4mW4yO10fP+LoT9WOswdU=
xorm.io/builder v0.3.11-0.20220531020008-1bd24a7dc978 h1:bvLlAPW1ZMTWA32LuZMBEGHAUOcATZjzHcotf3SWweM=
xorm.io/builder v0.3.11-0.20220531020008-1bd24a7dc978/go.mod h1:aUW0S9eb9VCaPohFCH3j7czOx1PMW3i1HrSzbLYGBSE=
xorm.io/builder v0.3.11 h1:naLkJitGyYW7ZZdncsh/JW+HF4HshmvTHTyUyPwJS00=
xorm.io/builder v0.3.11/go.mod h1:aUW0S9eb9VCaPohFCH3j7czOx1PMW3i1HrSzbLYGBSE=

View File

@ -6,6 +6,7 @@ package integrations
import (
"testing"
"xorm.io/xorm/dialects"
"xorm.io/xorm"
"xorm.io/xorm/log"
@ -32,4 +33,5 @@ func TestEngineGroup(t *testing.T) {
eg.SetColumnMapper(master.GetColumnMapper())
eg.SetLogLevel(log.LOG_INFO)
eg.ShowSQL(true)
eg.SetShadow(dialects.NewFalseShadow())
}

View File

@ -5,8 +5,10 @@
package integrations
import (
"context"
"testing"
"time"
"xorm.io/xorm/dialects"
"xorm.io/xorm/caches"
"xorm.io/xorm/schemas"
@ -191,6 +193,41 @@ func TestCacheDelete(t *testing.T) {
testEngine.SetDefaultCacher(oldCacher)
}
func TestShadowCacheDelete(t *testing.T) {
if testEngine.Dialect().URI().DBType != schemas.MYSQL {
return
}
testEngine.SetShadow(dialects.NewFalseShadow())
oldCacher := testEngine.GetDefaultCacher()
cacher := caches.NewLRUCacher(caches.NewMemoryStore(), 1000)
testEngine.SetDefaultCacher(cacher)
type ShadowCacheDeleteStruct struct {
Id int64
}
assert.NoError(t, testEngine.Context(context.Background()).Sync(&ShadowCacheDeleteStruct{}))
err := testEngine.CreateTables(&ShadowCacheDeleteStruct{})
assert.NoError(t, err)
_, err = testEngine.Insert(&ShadowCacheDeleteStruct{})
assert.NoError(t, err)
aff, err := testEngine.Delete(&ShadowCacheDeleteStruct{
Id: 1,
})
assert.NoError(t, err)
assert.EqualValues(t, aff, 1)
aff, err = testEngine.Unscoped().Delete(&ShadowCacheDeleteStruct{
Id: 1,
})
assert.NoError(t, err)
assert.EqualValues(t, aff, 0)
testEngine.SetDefaultCacher(oldCacher)
}
func TestUnscopeDelete(t *testing.T) {
assert.NoError(t, PrepareEngine())

View File

@ -5,6 +5,7 @@
package integrations
import (
"context"
"database/sql"
"errors"
"fmt"
@ -22,6 +23,212 @@ import (
"github.com/stretchr/testify/assert"
)
func TestShadowGetVar(t *testing.T) {
if testEngine.Dialect().URI().DBType != schemas.MYSQL {
return
}
type ShadowGetVar struct {
Id int64 `xorm:"autoincr pk"`
Msg string `xorm:"varchar(255)"`
Age int
Money float32
Created time.Time `xorm:"created"`
}
testEngine.SetShadow(dialects.NewFalseShadow())
assert.NoError(t, testEngine.Context(context.Background()).Sync(new(ShadowGetVar)))
data := ShadowGetVar{
Msg: "hi",
Age: 28,
Money: 1.5,
}
_, err := testEngine.InsertOne(&data)
assert.NoError(t, err)
var msg string
has, err := testEngine.Table(testEngine.ContextTableName(context.Background(), "shadow_get_var")).Cols("msg").Get(&msg)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.Equal(t, "hi", msg)
var age int
has, err = testEngine.Table(testEngine.ContextTableName(context.Background(), "shadow_get_var")).Cols("age").Get(&age)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.Equal(t, 28, age)
var ageMax int
has, err = testEngine.SQL("SELECT max(`age`) FROM "+testEngine.Quote(testEngine.ContextTableName(context.Background(), "shadow_get_var"))+" WHERE `id` = ?", data.Id).Get(&ageMax)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.Equal(t, 28, ageMax)
var age2 int64
has, err = testEngine.Table(testEngine.ContextTableName(context.Background(), "shadow_get_var")).Cols("age").
Where("`age` > ?", 20).
And("`age` < ?", 30).
Get(&age2)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.EqualValues(t, 28, age2)
var age3 int8
has, err = testEngine.Table(testEngine.ContextTableName(context.Background(), "shadow_get_var")).Cols("age").Get(&age3)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.EqualValues(t, 28, age3)
var age4 int16
has, err = testEngine.Table(testEngine.ContextTableName(context.Background(), "shadow_get_var")).Cols("age").
Where("`age` > ?", 20).
And("`age` < ?", 30).
Get(&age4)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.EqualValues(t, 28, age4)
var age5 int32
has, err = testEngine.Table(testEngine.ContextTableName(context.Background(), "shadow_get_var")).Cols("age").
Where("`age` > ?", 20).
And("`age` < ?", 30).
Get(&age5)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.EqualValues(t, 28, age5)
var age6 int
has, err = testEngine.Table(testEngine.ContextTableName(context.Background(), "shadow_get_var")).Cols("age").Get(&age6)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.EqualValues(t, 28, age6)
var age7 int64
has, err = testEngine.Table(testEngine.ContextTableName(context.Background(), "shadow_get_var")).Cols("age").
Where("`age` > ?", 20).
And("`age` < ?", 30).
Get(&age7)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.EqualValues(t, 28, age7)
var age8 int8
has, err = testEngine.Table(testEngine.ContextTableName(context.Background(), "shadow_get_var")).Cols("age").Get(&age8)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.EqualValues(t, 28, age8)
var age9 int16
has, err = testEngine.Table(testEngine.ContextTableName(context.Background(), "shadow_get_var")).Cols("age").
Where("`age` > ?", 20).
And("`age` < ?", 30).
Get(&age9)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.EqualValues(t, 28, age9)
var age10 int32
has, err = testEngine.Table(testEngine.ContextTableName(context.Background(), "shadow_get_var")).Cols("age").
Where("`age` > ?", 20).
And("`age` < ?", 30).
Get(&age10)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.EqualValues(t, 28, age10)
var id sql.NullInt64
has, err = testEngine.Table(testEngine.ContextTableName(context.Background(), "shadow_get_var")).Cols("id").Get(&id)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.Equal(t, true, id.Valid)
var msgNull sql.NullString
has, err = testEngine.Table(testEngine.ContextTableName(context.Background(), "shadow_get_var")).Cols("msg").Get(&msgNull)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.Equal(t, true, msgNull.Valid)
assert.EqualValues(t, data.Msg, msgNull.String)
var nullMoney sql.NullFloat64
has, err = testEngine.Table(testEngine.ContextTableName(context.Background(), "shadow_get_var")).Cols("money").Get(&nullMoney)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.Equal(t, true, nullMoney.Valid)
assert.EqualValues(t, data.Money, nullMoney.Float64)
var money float64
has, err = testEngine.Table(testEngine.ContextTableName(context.Background(), "shadow_get_var")).Cols("money").Get(&money)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.Equal(t, "1.5", fmt.Sprintf("%.1f", money))
var money2 float64
if testEngine.Dialect().URI().DBType == schemas.MSSQL {
has, err = testEngine.SQL("SELECT TOP 1 `money` FROM " + testEngine.Quote(testEngine.ContextTableName(context.Background(), "shadow_get_var"))).Get(&money2)
} else {
has, err = testEngine.SQL("SELECT `money` FROM " + testEngine.Quote(testEngine.ContextTableName(context.Background(), "shadow_get_var")) + " LIMIT 1").Get(&money2)
}
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.Equal(t, "1.5", fmt.Sprintf("%.1f", money2))
var money3 float64
has, err = testEngine.SQL("SELECT `money` FROM " + testEngine.Quote(testEngine.ContextTableName(context.Background(), "shadow_get_var")) + " WHERE `money` > 20").Get(&money3)
assert.NoError(t, err)
assert.Equal(t, false, has)
valuesString := make(map[string]string)
has, err = testEngine.Table("shadow_get_var").Get(&valuesString)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.Equal(t, 5, len(valuesString))
assert.Equal(t, "1", valuesString["id"])
assert.Equal(t, "hi", valuesString["msg"])
assert.Equal(t, "28", valuesString["age"])
assert.Equal(t, "1.5", valuesString["money"])
// for mymysql driver, interface{} will be []byte, so ignore it currently
if testEngine.DriverName() != "mymysql" {
valuesInter := make(map[string]interface{})
has, err = testEngine.Table(testEngine.ContextTableName(context.Background(), "shadow_get_var")).Where("`id` = ?", 1).Select("*").Get(&valuesInter)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.Equal(t, 5, len(valuesInter))
assert.EqualValues(t, 1, valuesInter["id"])
assert.Equal(t, "hi", fmt.Sprintf("%s", valuesInter["msg"]))
assert.EqualValues(t, 28, valuesInter["age"])
assert.Equal(t, "1.5", fmt.Sprintf("%v", valuesInter["money"]))
}
valuesSliceString := make([]string, 5)
has, err = testEngine.Table(testEngine.ContextTableName(context.Background(), "shadow_get_var")).Get(&valuesSliceString)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.Equal(t, "1", valuesSliceString[0])
assert.Equal(t, "hi", valuesSliceString[1])
assert.Equal(t, "28", valuesSliceString[2])
assert.Equal(t, "1.5", valuesSliceString[3])
valuesSliceInter := make([]interface{}, 5)
has, err = testEngine.Table(testEngine.ContextTableName(context.Background(), "shadow_get_var")).Get(&valuesSliceInter)
assert.NoError(t, err)
assert.Equal(t, true, has)
v1, err := convert.AsInt64(valuesSliceInter[0])
assert.NoError(t, err)
assert.EqualValues(t, 1, v1)
assert.Equal(t, "hi", fmt.Sprintf("%s", valuesSliceInter[1]))
v3, err := convert.AsInt64(valuesSliceInter[2])
assert.NoError(t, err)
assert.EqualValues(t, 28, v3)
v4, err := convert.AsFloat64(valuesSliceInter[3])
assert.NoError(t, err)
assert.Equal(t, "1.5", fmt.Sprintf("%v", v4))
}
func TestGetVar(t *testing.T) {
assert.NoError(t, PrepareEngine())

View File

@ -5,13 +5,14 @@
package integrations
import (
"context"
"fmt"
"github.com/stretchr/testify/assert"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"xorm.io/xorm"
"xorm.io/xorm/dialects"
"xorm.io/xorm/internal/statements"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/names"
@ -300,6 +301,16 @@ type Article struct {
Status int8 `xorm:"TINYINT(4)"`
}
type ShadowArticle struct {
Id int32 `xorm:"pk INT autoincr"`
Name string `xorm:"VARCHAR(45)"`
Img string `xorm:"VARCHAR(100)"`
Aside string `xorm:"VARCHAR(200)"`
Desc string `xorm:"VARCHAR(200)"`
Content string `xorm:"TEXT"`
Status int8 `xorm:"TINYINT(4)"`
}
func TestUpdateMap2(t *testing.T) {
assert.NoError(t, PrepareEngine())
assertSync(t, new(UpdateMustCols))
@ -1470,3 +1481,149 @@ func TestNilFromDB(t *testing.T) {
assert.NotNil(t, tt4.Field1)
assert.NotNil(t, tt4.Field1.cb)
}
type ShadowUserinfo struct {
Uid int64 `xorm:"id pk not null autoincr"`
Username string `xorm:"unique"`
Departname string
Alias string `xorm:"-"`
Created time.Time
Detail Userdetail `xorm:"detail_id int(11)"`
Height float64
Avatar []byte
IsMan bool
}
func TestShadowMysqlUpdate1(t *testing.T) {
if testEngine.Dialect().URI().DBType != schemas.MYSQL {
return
}
testEngine.SetShadow(dialects.NewFalseShadow())
assert.NoError(t, testEngine.Context(context.Background()).Sync(&ShadowUserinfo{}))
_, err := testEngine.Insert(&ShadowUserinfo{
Username: "user1",
})
assert.NoError(t, err)
var ori ShadowUserinfo
has, err := testEngine.Get(&ori)
assert.NoError(t, err)
assert.True(t, has)
// update by id
user := ShadowUserinfo{Username: "xxx", Height: 1.2}
cnt, err := testEngine.ID(ori.Uid).Update(&user)
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
condi := Condi{"username": "zzz", "departname": ""}
cnt, err = testEngine.Table(&user).ID(ori.Uid).Update(&condi)
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
cnt, err = testEngine.Update(&ShadowUserinfo{Username: "yyy"}, &user)
assert.NoError(t, err)
total, err := testEngine.Count(&user)
assert.NoError(t, err)
assert.EqualValues(t, cnt, total)
// nullable update
{
user := &ShadowUserinfo{Username: "not null data", Height: 180.5}
_, err := testEngine.Insert(user)
assert.NoError(t, err)
userID := user.Uid
has, err := testEngine.ID(userID).
And("`username` = ?", user.Username).
And("`height` = ?", user.Height).
And("`departname` = ?", "").
And("`detail_id` = ?", 0).
And("`is_man` = ?", false).
Get(&ShadowUserinfo{})
assert.NoError(t, err)
assert.True(t, has, "cannot insert properly")
updatedUser := &ShadowUserinfo{Username: "null data"}
cnt, err = testEngine.ID(userID).
Nullable("height", "departname", "is_man", "created").
Update(updatedUser)
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt, "update not returned 1")
has, err = testEngine.ID(userID).
And("`username` = ?", updatedUser.Username).
And("`height` IS NULL").
And("`departname` IS NULL").
And("`is_man` IS NULL").
And("`created` IS NULL").
And("`detail_id` = ?", 0).
Get(&ShadowUserinfo{})
assert.NoError(t, err)
assert.True(t, has, "cannot update with null properly")
cnt, err = testEngine.ID(userID).Delete(&ShadowUserinfo{})
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt, "delete not returned 1")
}
err = testEngine.StoreEngine("Innodb").Sync(&ShadowArticle{})
assert.NoError(t, err)
defer func() {
err = testEngine.DropTables(&ShadowArticle{})
assert.NoError(t, err)
}()
a := &ShadowArticle{0, "1", "2", "3", "4", "5", 2}
cnt, err = testEngine.Insert(a)
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt, fmt.Sprintf("insert not returned 1 but %d", cnt))
assert.Greater(t, a.Id, int32(0), "insert returned id is 0")
cnt, err = testEngine.ID(a.Id).Update(&ShadowArticle{Name: "6"})
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
var s = "test"
col1 := &UpdateAllCols{Ptr: &s}
err = testEngine.Sync(col1)
assert.NoError(t, err)
_, err = testEngine.Insert(col1)
assert.NoError(t, err)
col2 := &UpdateAllCols{col1.Id, true, "", nil}
_, err = testEngine.ID(col2.Id).AllCols().Update(col2)
assert.NoError(t, err)
col3 := &UpdateAllCols{}
has, err = testEngine.ID(col2.Id).Get(col3)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, *col2, *col3)
{
col1 := &UpdateMustCols{}
err = testEngine.Sync(col1)
assert.NoError(t, err)
_, err = testEngine.Insert(col1)
assert.NoError(t, err)
col2 := &UpdateMustCols{col1.Id, true, ""}
boolStr := testEngine.GetColumnMapper().Obj2Table("Bool")
stringStr := testEngine.GetColumnMapper().Obj2Table("String")
_, err = testEngine.ID(col2.Id).MustCols(boolStr, stringStr).Update(col2)
assert.NoError(t, err)
col3 := &UpdateMustCols{}
has, err := testEngine.ID(col2.Id).Get(col3)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, *col2, *col3)
}
}

View File

@ -123,8 +123,10 @@ type EngineInterface interface {
StoreEngine(storeEngine string) *Session
TableInfo(bean interface{}) (*schemas.Table, error)
TableName(interface{}, ...bool) string
ContextTableName(context.Context, interface{}, ...bool) string
UnMapType(reflect.Type)
EnableSessionID(bool)
SetShadow(shadow dialects.Shadowable)
}
var (

View File

@ -51,7 +51,7 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), statement.quote(aliasName), statement.ReplaceQuote(condition))
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
default:
tbName := dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), tablename, true)
tbName := dialects.FullTableName(statement.ctx, statement.dialect, statement.tagParser.GetTableMapper(), tablename, true)
if !utils.IsSubQuery(tbName) {
var buf strings.Builder
_ = statement.dialect.Quoter().QuoteTo(&buf, tbName)

View File

@ -5,6 +5,7 @@
package statements
import (
"context"
"database/sql/driver"
"errors"
"fmt"
@ -36,6 +37,7 @@ var (
// Statement save all the sql info for executing SQL
type Statement struct {
ctx context.Context
RefTable *schemas.Table
dialect dialects.Dialect
defaultTimeZone *time.Location
@ -82,8 +84,9 @@ type Statement struct {
}
// NewStatement creates a new statement
func NewStatement(dialect dialects.Dialect, tagParser *tags.Parser, defaultTimeZone *time.Location) *Statement {
func NewStatement(ctx context.Context, dialect dialects.Dialect, tagParser *tags.Parser, defaultTimeZone *time.Location) *Statement {
statement := &Statement{
ctx: ctx,
dialect: dialect,
tagParser: tagParser,
defaultTimeZone: defaultTimeZone,
@ -186,7 +189,8 @@ func (statement *Statement) SetRefValue(v reflect.Value) error {
if err != nil {
return err
}
statement.tableName = dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), v, true)
statement.tableName = dialects.FullTableName(statement.ctx, statement.dialect,
statement.tagParser.GetTableMapper(), v, true)
return nil
}
@ -201,7 +205,8 @@ func (statement *Statement) SetRefBean(bean interface{}) error {
if err != nil {
return err
}
statement.tableName = dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), bean, true)
statement.tableName = dialects.FullTableName(statement.ctx, statement.dialect,
statement.tagParser.GetTableMapper(), bean, true)
return nil
}
@ -280,7 +285,8 @@ func (statement *Statement) SetTable(tableNameOrBean interface{}) error {
}
}
statement.AltTableName = dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), tableNameOrBean, true)
statement.AltTableName = dialects.FullTableName(statement.ctx, statement.dialect, statement.tagParser.GetTableMapper(),
tableNameOrBean, true)
return nil
}

View File

@ -5,6 +5,7 @@
package statements
import (
"context"
"os"
"reflect"
"strings"
@ -171,7 +172,7 @@ func (TestType) TableName() string {
}
func createTestStatement() (*Statement, error) {
statement := NewStatement(dialect, tagParser, time.Local)
statement := NewStatement(context.Background(), dialect, tagParser, time.Local)
if err := statement.SetRefValue(reflect.ValueOf(TestType{})); err != nil {
return nil, err
}

View File

@ -113,6 +113,7 @@ func newSession(engine *Engine) *Session {
engine: engine,
tx: nil,
statement: statements.NewStatement(
ctx,
engine.dialect,
engine.tagParser,
engine.DatabaseTZ,

View File

@ -392,6 +392,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
statement := session.statement
session.statement = statements.NewStatement(
session.ctx,
session.engine.dialect,
session.engine.tagParser,
session.engine.DatabaseTZ,

View File

@ -280,7 +280,7 @@ func (session *Session) Sync(beans ...interface{}) error {
if len(session.statement.AltTableName) > 0 {
tbName = session.statement.AltTableName
} else {
tbName = engine.TableName(bean)
tbName = engine.ContextTableName(session.ctx, bean)
}
tbNameWithSchema := engine.tbNameWithSchema(tbName)