Browse Source

Add SetSchema for engine (#876)

* add SetSchema for engine

* fix user

* fix postgres with schema

* fix test

* fix test

* fix test

* fix tablename

* refactor tableName

* fix schema support

* improve the interface of EngineInterface
tags/v0.6.5
Lunny Xiao GitHub 1 year ago
parent
commit
bd20c37bfb
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
29 changed files with 373 additions and 592 deletions
  1. +4
    -1
      circle.yml
  2. +9
    -2
      dialect_postgres.go
  3. +20
    -67
      engine.go
  4. +109
    -0
      engine_table.go
  5. +2
    -0
      interface.go
  6. +3
    -3
      rows.go
  7. +2
    -1
      rows_test.go
  8. +0
    -9
      session.go
  9. +22
    -90
      session_cond_test.go
  10. +1
    -1
      session_delete.go
  11. +1
    -1
      session_exist.go
  12. +2
    -2
      session_exist_test.go
  13. +13
    -42
      session_find_test.go
  14. +1
    -1
      session_get.go
  15. +1
    -1
      session_get_test.go
  16. +1
    -1
      session_insert.go
  17. +2
    -1
      session_insert_test.go
  18. +19
    -4
      session_pk_test.go
  19. +5
    -5
      session_query_test.go
  20. +2
    -2
      session_raw_test.go
  21. +30
    -51
      session_schema.go
  22. +1
    -1
      session_stats_test.go
  23. +19
    -81
      session_tx_test.go
  24. +1
    -1
      session_update.go
  25. +27
    -109
      session_update_test.go
  26. +42
    -76
      statement.go
  27. +27
    -36
      tag_extends_test.go
  28. +3
    -2
      types_test.go
  29. +4
    -1
      xorm_test.go

+ 4
- 1
circle.yml View File

@@ -17,6 +17,7 @@ database:
- createdb -p 5432 -e -U postgres xorm_test1
- createdb -p 5432 -e -U postgres xorm_test2
- createdb -p 5432 -e -U postgres xorm_test3
- psql xorm_test postgres -c "create schema xorm"

test:
override:
@@ -30,7 +31,9 @@ test:
- go test -v -race -db="mymysql" -conn_str="xorm_test/root/" -cache=true -coverprofile=coverage3-2.txt -covermode=atomic
- go test -v -race -db="postgres" -conn_str="dbname=xorm_test sslmode=disable" -coverprofile=coverage4-1.txt -covermode=atomic
- go test -v -race -db="postgres" -conn_str="dbname=xorm_test sslmode=disable" -cache=true -coverprofile=coverage4-2.txt -covermode=atomic
- gocovmerge coverage1-1.txt coverage1-2.txt coverage2-1.txt coverage2-2.txt coverage3-1.txt coverage3-2.txt coverage4-1.txt coverage4-2.txt > coverage.txt
- go test -v -race -db="postgres" -conn_str="dbname=xorm_test sslmode=disable" -schema=xorm -coverprofile=coverage5-1.txt -covermode=atomic
- go test -v -race -db="postgres" -conn_str="dbname=xorm_test sslmode=disable" -schema=xorm -cache=true -coverprofile=coverage5-2.txt -covermode=atomic
- gocovmerge coverage1-1.txt coverage1-2.txt coverage2-1.txt coverage2-2.txt coverage3-1.txt coverage3-2.txt coverage4-1.txt coverage4-2.txt coverage5-1.txt coverage5-2.txt > coverage.txt
- cd /home/ubuntu/.go_workspace/src/github.com/go-xorm/tests && ./sqlite3.sh
- cd /home/ubuntu/.go_workspace/src/github.com/go-xorm/tests && ./mysql.sh
- cd /home/ubuntu/.go_workspace/src/github.com/go-xorm/tests && ./postgres.sh


+ 9
- 2
dialect_postgres.go View File

@@ -895,6 +895,7 @@ func (db *postgres) TableCheckSql(tableName string) (string, []interface{}) {
args := []interface{}{tableName}
return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args
}

args := []interface{}{db.Schema, tableName}
return `SELECT tablename FROM pg_tables WHERE schemaname = ? AND tablename = ?`, args
}
@@ -912,6 +913,9 @@ func (db *postgres) DropIndexSql(tableName string, index *core.Index) string {
quote := db.Quote
idxName := index.Name

tableName = strings.Replace(tableName, `"`, "", -1)
tableName = strings.Replace(tableName, `.`, "_", -1)

if !strings.HasPrefix(idxName, "UQE_") &&
!strings.HasPrefix(idxName, "IDX_") {
if index.Type == core.UniqueType {
@@ -920,6 +924,9 @@ func (db *postgres) DropIndexSql(tableName string, index *core.Index) string {
idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name)
}
}
if db.Uri.Schema != "" {
idxName = db.Uri.Schema + "." + idxName
}
return fmt.Sprintf("DROP INDEX %v", quote(idxName))
}

@@ -960,7 +967,7 @@ WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.att
var f string
if len(db.Schema) != 0 {
args = append(args, db.Schema)
f = "AND s.table_schema = $2"
f = " AND s.table_schema = $2"
}
s = fmt.Sprintf(s, f)

@@ -1085,11 +1092,11 @@ func (db *postgres) GetTables() ([]*core.Table, error) {
func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) {
args := []interface{}{tableName}
s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1")
db.LogSQL(s, args)
if len(db.Schema) != 0 {
args = append(args, db.Schema)
s = s + " AND schemaname=$2"
}
db.LogSQL(s, args)

rows, err := db.DB().Query(s, args...)
if err != nil {


+ 20
- 67
engine.go View File

@@ -536,46 +536,6 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D
return nil
}

func (engine *Engine) tableName(beanOrTableName interface{}) (string, error) {
v := rValue(beanOrTableName)
if v.Type().Kind() == reflect.String {
return beanOrTableName.(string), nil
} else if v.Type().Kind() == reflect.Struct {
return engine.tbName(v), nil
}
return "", errors.New("bean should be a struct or struct's point")
}

func (engine *Engine) tbSchemaName(v string) string {
// Add schema name as prefix of table name.
// Only for postgres database.
if engine.dialect.DBType() == core.POSTGRES &&
engine.dialect.URI().Schema != "" &&
engine.dialect.URI().Schema != postgresPublicSchema &&
strings.Index(v, ".") == -1 {
return engine.dialect.URI().Schema + "." + v
}
return v
}

func (engine *Engine) tbName(v reflect.Value) string {
if tb, ok := v.Interface().(TableName); ok {
return engine.tbSchemaName(tb.TableName())

}

if v.Type().Kind() == reflect.Ptr {
if tb, ok := reflect.Indirect(v).Interface().(TableName); ok {
return engine.tbSchemaName(tb.TableName())
}
} else if v.CanAddr() {
if tb, ok := v.Addr().Interface().(TableName); ok {
return engine.tbSchemaName(tb.TableName())
}
}
return engine.tbSchemaName(engine.TableMapper.Obj2Table(reflect.Indirect(v).Type().Name()))
}

// Cascade use cascade or not
func (engine *Engine) Cascade(trueOrFalse ...bool) *Session {
session := engine.NewSession()
@@ -859,7 +819,7 @@ func (engine *Engine) TableInfo(bean interface{}) *Table {
if err != nil {
engine.logger.Error(err)
}
return &Table{tb, engine.tbName(v)}
return &Table{tb, engine.TableName(bean)}
}

func addIndex(indexName string, table *core.Table, col *core.Column, indexType int) {
@@ -895,20 +855,8 @@ var (
func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) {
t := v.Type()
table := engine.newTable()
if tb, ok := v.Interface().(TableName); ok {
table.Name = tb.TableName()
} else {
if v.CanAddr() {
if tb, ok = v.Addr().Interface().(TableName); ok {
table.Name = tb.TableName()
}
}
if table.Name == "" {
table.Name = engine.TableMapper.Obj2Table(t.Name())
}
}

table.Type = t
table.Name = engine.tbNameForMap(v)

var idFieldColName string
var hasCacheTag, hasNoCacheTag bool
@@ -1186,7 +1134,7 @@ func (engine *Engine) ClearCacheBean(bean interface{}, id string) error {
if t.Kind() != reflect.Struct {
return errors.New("error params")
}
tableName := engine.tbName(v)
tableName := engine.TableName(bean)
table, err := engine.autoMapType(v)
if err != nil {
return err
@@ -1210,7 +1158,7 @@ func (engine *Engine) ClearCache(beans ...interface{}) error {
if t.Kind() != reflect.Struct {
return errors.New("error params")
}
tableName := engine.tbName(v)
tableName := engine.TableName(bean)
table, err := engine.autoMapType(v)
if err != nil {
return err
@@ -1237,13 +1185,13 @@ func (engine *Engine) Sync(beans ...interface{}) error {

for _, bean := range beans {
v := rValue(bean)
tableName := engine.tbName(v)
tableNameNoSchema := engine.tbNameNoSchema(v.Interface())
table, err := engine.autoMapType(v)
if err != nil {
return err
}

isExist, err := session.Table(bean).isTableExist(tableName)
isExist, err := session.Table(bean).isTableExist(tableNameNoSchema)
if err != nil {
return err
}
@@ -1269,12 +1217,12 @@ func (engine *Engine) Sync(beans ...interface{}) error {
}
} else {
for _, col := range table.Columns() {
isExist, err := engine.dialect.IsColumnExist(tableName, col.Name)
isExist, err := engine.dialect.IsColumnExist(tableNameNoSchema, col.Name)
if err != nil {
return err
}
if !isExist {
if err := session.statement.setRefValue(v); err != nil {
if err := session.statement.setRefBean(bean); err != nil {
return err
}
err = session.addColumn(col.Name)
@@ -1285,35 +1233,35 @@ func (engine *Engine) Sync(beans ...interface{}) error {
}

for name, index := range table.Indexes {
if err := session.statement.setRefValue(v); err != nil {
if err := session.statement.setRefBean(bean); err != nil {
return err
}
if index.Type == core.UniqueType {
isExist, err := session.isIndexExist2(tableName, index.Cols, true)
isExist, err := session.isIndexExist2(tableNameNoSchema, index.Cols, true)
if err != nil {
return err
}
if !isExist {
if err := session.statement.setRefValue(v); err != nil {
if err := session.statement.setRefBean(bean); err != nil {
return err
}

err = session.addUnique(tableName, name)
err = session.addUnique(tableNameNoSchema, name)
if err != nil {
return err
}
}
} else if index.Type == core.IndexType {
isExist, err := session.isIndexExist2(tableName, index.Cols, false)
isExist, err := session.isIndexExist2(tableNameNoSchema, index.Cols, false)
if err != nil {
return err
}
if !isExist {
if err := session.statement.setRefValue(v); err != nil {
if err := session.statement.setRefBean(bean); err != nil {
return err
}

err = session.addIndex(tableName, name)
err = session.addIndex(tableNameNoSchema, name)
if err != nil {
return err
}
@@ -1649,6 +1597,11 @@ func (engine *Engine) SetTZDatabase(tz *time.Location) {
engine.DatabaseTZ = tz
}

// SetSchema sets the schema of database
func (engine *Engine) SetSchema(schema string) {
engine.dialect.URI().Schema = schema
}

// Unscoped always disable struct tag "deleted"
func (engine *Engine) Unscoped() *Session {
session := engine.NewSession()


+ 109
- 0
engine_table.go View File

@@ -0,0 +1,109 @@
// Copyright 2018 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package xorm

import (
"fmt"
"reflect"
"strings"

"github.com/go-xorm/core"
)

// TableNameWithSchema will automatically add schema prefix on table name
func (engine *Engine) tbNameWithSchema(v string) string {
// Add schema name as prefix of table name.
// Only for postgres database.
if engine.dialect.DBType() == core.POSTGRES &&
engine.dialect.URI().Schema != "" &&
engine.dialect.URI().Schema != postgresPublicSchema &&
strings.Index(v, ".") == -1 {
return engine.dialect.URI().Schema + "." + v
}
return v
}

// TableName returns table name with schema prefix if has
func (engine *Engine) TableName(bean interface{}, includeSchema ...bool) string {
tbName := engine.tbNameNoSchema(bean)
if len(includeSchema) > 0 && includeSchema[0] {
tbName = engine.tbNameWithSchema(tbName)
}

return tbName
}

// tbName get some table's table name
func (session *Session) tbNameNoSchema(table *core.Table) string {
if len(session.statement.AltTableName) > 0 {
return session.statement.AltTableName
}

return table.Name
}

func (engine *Engine) tbNameForMap(v reflect.Value) string {
t := v.Type()
if tb, ok := v.Interface().(TableName); ok {
return tb.TableName()
}
if v.CanAddr() {
if tb, ok := v.Addr().Interface().(TableName); ok {
return tb.TableName()
}
}
return engine.TableMapper.Obj2Table(t.Name())
}

func (engine *Engine) tbNameNoSchema(tablename interface{}) string {
switch tablename.(type) {
case []string:
t := tablename.([]string)
if len(t) > 1 {
return fmt.Sprintf("%v AS %v", engine.Quote(t[0]), engine.Quote(t[1]))
} else if len(t) == 1 {
return engine.Quote(t[0])
}
case []interface{}:
t := tablename.([]interface{})
l := len(t)
var table string
if l > 0 {
f := t[0]
switch f.(type) {
case string:
table = f.(string)
case TableName:
table = f.(TableName).TableName()
default:
v := rValue(f)
t := v.Type()
if t.Kind() == reflect.Struct {
table = engine.tbNameForMap(v)
} else {
table = engine.Quote(fmt.Sprintf("%v", f))
}
}
}
if l > 1 {
return fmt.Sprintf("%v AS %v", engine.Quote(table),
engine.Quote(fmt.Sprintf("%v", t[1])))
} else if l == 1 {
return engine.Quote(table)
}
case TableName:
return tablename.(TableName).TableName()
case string:
return tablename.(string)
default:
v := rValue(tablename)
t := v.Type()
if t.Kind() == reflect.Struct {
return engine.tbNameForMap(v)
}
return engine.Quote(fmt.Sprintf("%v", tablename))
}
return ""
}

+ 2
- 0
interface.go View File

@@ -87,6 +87,7 @@ type EngineInterface interface {
SetDefaultCacher(core.Cacher)
SetLogLevel(core.LogLevel)
SetMapper(core.IMapper)
SetSchema(string)
SetTZDatabase(tz *time.Location)
SetTZLocation(tz *time.Location)
ShowSQL(show ...bool)
@@ -94,6 +95,7 @@ type EngineInterface interface {
Sync2(...interface{}) error
StoreEngine(storeEngine string) *Session
TableInfo(bean interface{}) *Table
TableName(interface{}, ...bool) string
UnMapType(reflect.Type)
}



+ 3
- 3
rows.go View File

@@ -32,7 +32,7 @@ func newRows(session *Session, bean interface{}) (*Rows, error) {
var args []interface{}
var err error

if err = rows.session.statement.setRefValue(rValue(bean)); err != nil {
if err = rows.session.statement.setRefBean(bean); err != nil {
return nil, err
}

@@ -94,8 +94,7 @@ func (rows *Rows) Scan(bean interface{}) error {
return fmt.Errorf("scan arg is incompatible type to [%v]", rows.beanType)
}

dataStruct := rValue(bean)
if err := rows.session.statement.setRefValue(dataStruct); err != nil {
if err := rows.session.statement.setRefBean(bean); err != nil {
return err
}

@@ -104,6 +103,7 @@ func (rows *Rows) Scan(bean interface{}) error {
return err
}

dataStruct := rValue(bean)
_, err = rows.session.slice2Bean(scanResults, rows.fields, bean, &dataStruct, rows.session.statement.RefTable)
if err != nil {
return err


+ 2
- 1
rows_test.go View File

@@ -54,7 +54,8 @@ func TestRows(t *testing.T) {
}
assert.EqualValues(t, 1, cnt)

rows2, err := testEngine.SQL("SELECT * FROM user_rows").Rows(new(UserRows))
var tbName = testEngine.Quote(testEngine.TableName(user, true))
rows2, err := testEngine.SQL("SELECT * FROM " + tbName).Rows(new(UserRows))
assert.NoError(t, err)
defer rows2.Close()



+ 0
- 9
session.go View File

@@ -828,15 +828,6 @@ func (session *Session) LastSQL() (string, []interface{}) {
return session.lastSQL, session.lastSQLArgs
}

// tbName get some table's table name
func (session *Session) tbNameNoSchema(table *core.Table) string {
if len(session.statement.AltTableName) > 0 {
return session.statement.AltTableName
}

return table.Name
}

// Unscoped always disable struct tag "deleted"
func (session *Session) Unscoped() *Session {
session.statement.Unscoped()


+ 22
- 90
session_cond_test.go View File

@@ -122,18 +122,11 @@ func TestIn(t *testing.T) {
assert.NoError(t, err)
assert.EqualValues(t, 3, cnt)

department := "`" + testEngine.GetColumnMapper().Obj2Table("Departname") + "`"
var usrs []Userinfo
err = testEngine.Limit(3).Find(&usrs)
if err != nil {
t.Error(err)
panic(err)
}

if len(usrs) != 3 {
err = errors.New("there are not 3 records")
t.Error(err)
panic(err)
}
err = testEngine.Where(department+" = ?", "dev").Limit(3).Find(&usrs)
assert.NoError(t, err)
assert.EqualValues(t, 3, len(usrs))

var ids []int64
var idsStr string
@@ -145,35 +138,20 @@ func TestIn(t *testing.T) {

users := make([]Userinfo, 0)
err = testEngine.In("(id)", ids[0], ids[1], ids[2]).Find(&users)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
fmt.Println(users)
if len(users) != 3 {
err = errors.New("in uses should be " + idsStr + " total 3")
t.Error(err)
panic(err)
}
assert.EqualValues(t, 3, len(users))

users = make([]Userinfo, 0)
err = testEngine.In("(id)", ids).Find(&users)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
fmt.Println(users)
if len(users) != 3 {
err = errors.New("in uses should be " + idsStr + " total 3")
t.Error(err)
panic(err)
}
assert.EqualValues(t, 3, len(users))

for _, user := range users {
if user.Uid != ids[0] && user.Uid != ids[1] && user.Uid != ids[2] {
err = errors.New("in uses should be " + idsStr + " total 3")
t.Error(err)
panic(err)
assert.NoError(t, err)
}
}

@@ -183,87 +161,41 @@ func TestIn(t *testing.T) {
idsInterface = append(idsInterface, id)
}

department := "`" + testEngine.GetColumnMapper().Obj2Table("Departname") + "`"
err = testEngine.Where(department+" = ?", "dev").In("(id)", idsInterface...).Find(&users)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
fmt.Println(users)

if len(users) != 3 {
err = errors.New("in uses should be " + idsStr + " total 3")
t.Error(err)
panic(err)
}
assert.EqualValues(t, 3, len(users))

for _, user := range users {
if user.Uid != ids[0] && user.Uid != ids[1] && user.Uid != ids[2] {
err = errors.New("in uses should be " + idsStr + " total 3")
t.Error(err)
panic(err)
assert.NoError(t, err)
}
}

dev := testEngine.GetColumnMapper().Obj2Table("Dev")

err = testEngine.In("(id)", 1).In("(id)", 2).In(department, dev).Find(&users)

if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
fmt.Println(users)

cnt, err = testEngine.In("(id)", ids[0]).Update(&Userinfo{Departname: "dev-"})
if err != nil {
t.Error(err)
panic(err)
}
if cnt != 1 {
err = errors.New("update records not 1")
t.Error(err)
panic(err)
}
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)

user := new(Userinfo)
has, err := testEngine.ID(ids[0]).Get(user)
if err != nil {
t.Error(err)
panic(err)
}
if !has {
err = errors.New("get record not 1")
t.Error(err)
panic(err)
}
if user.Departname != "dev-" {
err = errors.New("update not success")
t.Error(err)
panic(err)
}
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, "dev-", user.Departname)

cnt, err = testEngine.In("(id)", ids[0]).Update(&Userinfo{Departname: "dev"})
if err != nil {
t.Error(err)
panic(err)
}
if cnt != 1 {
err = errors.New("update records not 1")
t.Error(err)
panic(err)
}
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)

cnt, err = testEngine.In("(id)", ids[1]).Delete(&Userinfo{})
if err != nil {
t.Error(err)
panic(err)
}
if cnt != 1 {
err = errors.New("deleted records not 1")
t.Error(err)
panic(err)
}
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
}

func TestFindAndCount(t *testing.T) {


+ 1
- 1
session_delete.go View File

@@ -79,7 +79,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
defer session.Close()
}

if err := session.statement.setRefValue(rValue(bean)); err != nil {
if err := session.statement.setRefBean(bean); err != nil {
return 0, err
}



+ 1
- 1
session_exist.go View File

@@ -57,7 +57,7 @@ func (session *Session) Exist(bean ...interface{}) (bool, error) {
}

if beanValue.Elem().Kind() == reflect.Struct {
if err := session.statement.setRefValue(beanValue.Elem()); err != nil {
if err := session.statement.setRefBean(bean[0]); err != nil {
return false, err
}
}


+ 2
- 2
session_exist_test.go View File

@@ -54,11 +54,11 @@ func TestExistStruct(t *testing.T) {
assert.NoError(t, err)
assert.False(t, has)

has, err = testEngine.SQL("select * from record_exist where name = ?", "test1").Exist()
has, err = testEngine.SQL("select * from "+testEngine.TableName("record_exist", true)+" where name = ?", "test1").Exist()
assert.NoError(t, err)
assert.True(t, has)

has, err = testEngine.SQL("select * from record_exist where name = ?", "test2").Exist()
has, err = testEngine.SQL("select * from "+testEngine.TableName("record_exist", true)+" where name = ?", "test2").Exist()
assert.NoError(t, err)
assert.False(t, has)



+ 13
- 42
session_find_test.go View File

@@ -96,21 +96,15 @@ func TestFind(t *testing.T) {
users := make([]Userinfo, 0)

err := testEngine.Find(&users)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
for _, user := range users {
fmt.Println(user)
}

users2 := make([]Userinfo, 0)
userinfo := testEngine.GetTableMapper().Obj2Table("Userinfo")
err = testEngine.SQL("select * from " + testEngine.Quote(userinfo)).Find(&users2)
if err != nil {
t.Error(err)
panic(err)
}
var tbName = testEngine.Quote(testEngine.TableName(new(Userinfo), true))
err = testEngine.SQL("select * from " + tbName).Find(&users2)
assert.NoError(t, err)
}

func TestFind2(t *testing.T) {
@@ -238,14 +232,8 @@ func TestDistinct(t *testing.T) {
users := make([]Userinfo, 0)
departname := testEngine.GetTableMapper().Obj2Table("Departname")
err = testEngine.Distinct(departname).Find(&users)
if err != nil {
t.Error(err)
panic(err)
}
if len(users) != 1 {
t.Error(err)
panic(errors.New("should be one record"))
}
assert.NoError(t, err)
assert.EqualValues(t, 1, len(users))

fmt.Println(users)

@@ -255,11 +243,9 @@ func TestDistinct(t *testing.T) {

users2 := make([]Depart, 0)
err = testEngine.Distinct(departname).Table(new(Userinfo)).Find(&users2)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
if len(users2) != 1 {
fmt.Println(len(users2))
t.Error(err)
panic(errors.New("should be one record"))
}
@@ -272,18 +258,12 @@ func TestOrder(t *testing.T) {

users := make([]Userinfo, 0)
err := testEngine.OrderBy("id desc").Find(&users)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
fmt.Println(users)

users2 := make([]Userinfo, 0)
err = testEngine.Asc("id", "username").Desc("height").Find(&users2)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
fmt.Println(users2)
}

@@ -293,10 +273,7 @@ func TestHaving(t *testing.T) {

users := make([]Userinfo, 0)
err := testEngine.GroupBy("username").Having("username='xlw'").Find(&users)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
fmt.Println(users)

/*users = make([]Userinfo, 0)
@@ -324,18 +301,12 @@ func TestOrderSameMapper(t *testing.T) {

users := make([]Userinfo, 0)
err := testEngine.OrderBy("(id) desc").Find(&users)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
fmt.Println(users)

users2 := make([]Userinfo, 0)
err = testEngine.Asc("(id)", "Username").Desc("Height").Find(&users2)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
fmt.Println(users2)
}



+ 1
- 1
session_get.go View File

@@ -31,7 +31,7 @@ func (session *Session) get(bean interface{}) (bool, error) {
}

if beanValue.Elem().Kind() == reflect.Struct {
if err := session.statement.setRefValue(beanValue.Elem()); err != nil {
if err := session.statement.setRefBean(bean); err != nil {
return false, err
}
}


+ 1
- 1
session_get_test.go View File

@@ -84,7 +84,7 @@ func TestGetVar(t *testing.T) {
assert.Equal(t, "1.5", fmt.Sprintf("%.1f", money))

var money2 float64
has, err = testEngine.SQL("SELECT money FROM get_var LIMIT 1").Get(&money2)
has, err = testEngine.SQL("SELECT money FROM " + testEngine.TableName("get_var", true) + " LIMIT 1").Get(&money2)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.Equal(t, "1.5", fmt.Sprintf("%.1f", money2))


+ 1
- 1
session_insert.go View File

@@ -298,7 +298,7 @@ func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) {
}

func (session *Session) innerInsert(bean interface{}) (int64, error) {
if err := session.statement.setRefValue(rValue(bean)); err != nil {
if err := session.statement.setRefBean(bean); err != nil {
return 0, err
}
if len(session.statement.TableName()) <= 0 {


+ 2
- 1
session_insert_test.go View File

@@ -716,8 +716,9 @@ func (MyUserinfo2) TableName() string {
func TestInsertMulti4(t *testing.T) {
assert.NoError(t, prepareEngine())

testEngine.ShowSQL(true)
testEngine.ShowSQL(false)
assertSync(t, new(MyUserinfo2))
testEngine.ShowSQL(true)

users := []MyUserinfo2{
{Username: "xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()},


+ 19
- 4
session_pk_test.go View File

@@ -1118,13 +1118,28 @@ func TestCompositePK(t *testing.T) {
}

assert.NoError(t, prepareEngine())
assertSync(t, new(TaskSolution))

tables1, err := testEngine.DBMetas()
assert.NoError(t, err)

assertSync(t, new(TaskSolution))
assert.NoError(t, testEngine.Sync2(new(TaskSolution)))
tables, err := testEngine.DBMetas()

tables2, err := testEngine.DBMetas()
assert.NoError(t, err)
assert.EqualValues(t, 1, len(tables))
pkCols := tables[0].PKColumns()
assert.EqualValues(t, 1+len(tables1), len(tables2))

var table *core.Table
for _, t := range tables2 {
if t.Name == testEngine.GetTableMapper().Obj2Table("TaskSolution") {
table = t
break
}
}

assert.NotEqual(t, nil, table)

pkCols := table.PKColumns()
assert.EqualValues(t, 2, len(pkCols))
assert.EqualValues(t, "uid", pkCols[0].Name)
assert.EqualValues(t, "tid", pkCols[1].Name)


+ 5
- 5
session_query_test.go View File

@@ -36,7 +36,7 @@ func TestQueryString(t *testing.T) {
_, err := testEngine.InsertOne(data)
assert.NoError(t, err)

records, err := testEngine.QueryString("select * from get_var2")
records, err := testEngine.QueryString("select * from " + testEngine.TableName("get_var2", true))
assert.NoError(t, err)
assert.Equal(t, 1, len(records))
assert.Equal(t, 5, len(records[0]))
@@ -62,7 +62,7 @@ func TestQueryString2(t *testing.T) {
_, err := testEngine.Insert(data)
assert.NoError(t, err)

records, err := testEngine.QueryString("select * from get_var3")
records, err := testEngine.QueryString("select * from " + testEngine.TableName("get_var3", true))
assert.NoError(t, err)
assert.Equal(t, 1, len(records))
assert.Equal(t, 2, len(records[0]))
@@ -127,7 +127,7 @@ func TestQueryInterface(t *testing.T) {
_, err := testEngine.InsertOne(data)
assert.NoError(t, err)

records, err := testEngine.QueryInterface("select * from get_var_interface")
records, err := testEngine.QueryInterface("select * from " + testEngine.TableName("get_var_interface", true))
assert.NoError(t, err)
assert.Equal(t, 1, len(records))
assert.Equal(t, 5, len(records[0]))
@@ -181,7 +181,7 @@ func TestQueryNoParams(t *testing.T) {
assert.NoError(t, err)
assertResult(t, results)

results, err = testEngine.SQL("select * from query_no_params").Query()
results, err = testEngine.SQL("select * from " + testEngine.TableName("query_no_params", true)).Query()
assert.NoError(t, err)
assertResult(t, results)
}
@@ -226,7 +226,7 @@ func TestQueryWithBuilder(t *testing.T) {
assert.EqualValues(t, 3000, money)
}

results, err := testEngine.Query(builder.Select("*").From("query_with_builder"))
results, err := testEngine.Query(builder.Select("*").From(testEngine.TableName("query_with_builder", true)))
assert.NoError(t, err)
assertResult(t, results)
}

+ 2
- 2
session_raw_test.go View File

@@ -21,13 +21,13 @@ func TestExecAndQuery(t *testing.T) {

assert.NoError(t, testEngine.Sync2(new(UserinfoQuery)))

res, err := testEngine.Exec("INSERT INTO `userinfo_query` (uid, name) VALUES (?, ?)", 1, "user")
res, err := testEngine.Exec("INSERT INTO "+testEngine.TableName("`userinfo_query`", true)+" (uid, name) VALUES (?, ?)", 1, "user")
assert.NoError(t, err)
cnt, err := res.RowsAffected()
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)

results, err := testEngine.Query("select * from userinfo_query")
results, err := testEngine.Query("select * from " + testEngine.TableName("userinfo_query", true))
assert.NoError(t, err)
assert.EqualValues(t, 1, len(results))
id, err := strconv.Atoi(string(results[0]["uid"]))


+ 30
- 51
session_schema.go View File

@@ -6,9 +6,7 @@ package xorm

import (
"database/sql"
"errors"
"fmt"
"reflect"
"strings"

"github.com/go-xorm/core"
@@ -34,8 +32,7 @@ func (session *Session) CreateTable(bean interface{}) error {
}

func (session *Session) createTable(bean interface{}) error {
v := rValue(bean)
if err := session.statement.setRefValue(v); err != nil {
if err := session.statement.setRefBean(bean); err != nil {
return err
}

@@ -54,8 +51,7 @@ func (session *Session) CreateIndexes(bean interface{}) error {
}

func (session *Session) createIndexes(bean interface{}) error {
v := rValue(bean)
if err := session.statement.setRefValue(v); err != nil {
if err := session.statement.setRefBean(bean); err != nil {
return err
}

@@ -78,8 +74,7 @@ func (session *Session) CreateUniques(bean interface{}) error {
}

func (session *Session) createUniques(bean interface{}) error {
v := rValue(bean)
if err := session.statement.setRefValue(v); err != nil {
if err := session.statement.setRefBean(bean); err != nil {
return err
}

@@ -103,8 +98,7 @@ func (session *Session) DropIndexes(bean interface{}) error {
}

func (session *Session) dropIndexes(bean interface{}) error {
v := rValue(bean)
if err := session.statement.setRefValue(v); err != nil {
if err := session.statement.setRefBean(bean); err != nil {
return err
}

@@ -128,11 +122,7 @@ func (session *Session) DropTable(beanOrTableName interface{}) error {
}

func (session *Session) dropTable(beanOrTableName interface{}) error {
tableName, err := session.engine.tableName(beanOrTableName)
if err != nil {
return err
}

tableName := session.engine.tbNameNoSchema(beanOrTableName)
var needDrop = true
if !session.engine.dialect.SupportDropIfExists() {
sqlStr, args := session.engine.dialect.TableCheckSql(tableName)
@@ -144,8 +134,8 @@ func (session *Session) dropTable(beanOrTableName interface{}) error {
}

if needDrop {
sqlStr := session.engine.Dialect().DropTableSql(tableName)
_, err = session.exec(sqlStr)
sqlStr := session.engine.Dialect().DropTableSql(session.engine.TableName(tableName, true))
_, err := session.exec(sqlStr)
return err
}
return nil
@@ -157,10 +147,7 @@ func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error)
defer session.Close()
}

tableName, err := session.engine.tableName(beanOrTableName)
if err != nil {
return false, err
}
tableName := session.engine.tbNameNoSchema(beanOrTableName)

return session.isTableExist(tableName)
}
@@ -173,24 +160,15 @@ func (session *Session) isTableExist(tableName string) (bool, error) {

// IsTableEmpty if table have any records
func (session *Session) IsTableEmpty(bean interface{}) (bool, error) {
v := rValue(bean)
t := v.Type()

if t.Kind() == reflect.String {
if session.isAutoClose {
defer session.Close()
}
return session.isTableEmpty(bean.(string))
} else if t.Kind() == reflect.Struct {
rows, err := session.Count(bean)
return rows == 0, err
if session.isAutoClose {
defer session.Close()
}
return false, errors.New("bean should be a struct or struct's point")
return session.isTableEmpty(session.engine.tbNameNoSchema(bean))
}

func (session *Session) isTableEmpty(tableName string) (bool, error) {
var total int64
sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.Quote(tableName))
sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.Quote(session.engine.TableName(tableName, true)))
err := session.queryRow(sqlStr).Scan(&total)
if err != nil {
if err == sql.ErrNoRows {
@@ -270,7 +248,8 @@ func (session *Session) Sync2(beans ...interface{}) error {
return err
}
structTables = append(structTables, table)
var tbName = session.tbNameNoSchema(table)
tbName := session.tbNameNoSchema(table)
tbNameWithSchema := engine.TableName(tbName, true)

var oriTable *core.Table
for _, tb := range tables {
@@ -315,32 +294,32 @@ func (session *Session) Sync2(beans ...interface{}) error {
if engine.dialect.DBType() == core.MYSQL ||
engine.dialect.DBType() == core.POSTGRES {
engine.logger.Infof("Table %s column %s change type from %s to %s\n",
tbName, col.Name, curType, expectedType)
_, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col))
tbNameWithSchema, col.Name, curType, expectedType)
_, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col))
} else {
engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n",
tbName, col.Name, curType, expectedType)
tbNameWithSchema, col.Name, curType, expectedType)
}
} else if strings.HasPrefix(curType, core.Varchar) && strings.HasPrefix(expectedType, core.Varchar) {
if engine.dialect.DBType() == core.MYSQL {
if oriCol.Length < col.Length {
engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
tbName, col.Name, oriCol.Length, col.Length)
_, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col))
tbNameWithSchema, col.Name, oriCol.Length, col.Length)
_, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col))
}
}
} else {
if !(strings.HasPrefix(curType, expectedType) && curType[len(expectedType)] == '(') {
engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s",
tbName, col.Name, curType, expectedType)
tbNameWithSchema, col.Name, curType, expectedType)
}
}
} else if expectedType == core.Varchar {
if engine.dialect.DBType() == core.MYSQL {
if oriCol.Length < col.Length {
engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
tbName, col.Name, oriCol.Length, col.Length)
_, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col))
tbNameWithSchema, col.Name, oriCol.Length, col.Length)
_, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col))
}
}
}
@@ -354,7 +333,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
}
} else {
session.statement.RefTable = table
session.statement.tableName = tbName
session.statement.tableName = tbNameWithSchema
err = session.addColumn(col.Name)
}
if err != nil {
@@ -377,7 +356,7 @@ func (session *Session) Sync2(beans ...interface{}) error {

if oriIndex != nil {
if oriIndex.Type != index.Type {
sql := engine.dialect.DropIndexSql(tbName, oriIndex)
sql := engine.dialect.DropIndexSql(tbNameWithSchema, oriIndex)
_, err = session.exec(sql)
if err != nil {
return err
@@ -393,7 +372,7 @@ func (session *Session) Sync2(beans ...interface{}) error {

for name2, index2 := range oriTable.Indexes {
if _, ok := foundIndexNames[name2]; !ok {
sql := engine.dialect.DropIndexSql(tbName, index2)
sql := engine.dialect.DropIndexSql(tbNameWithSchema, index2)
_, err = session.exec(sql)
if err != nil {
return err
@@ -404,12 +383,12 @@ func (session *Session) Sync2(beans ...interface{}) error {
for name, index := range addedNames {
if index.Type == core.UniqueType {
session.statement.RefTable = table
session.statement.tableName = tbName
err = session.addUnique(tbName, name)
session.statement.tableName = tbNameWithSchema
err = session.addUnique(tbNameWithSchema, name)
} else if index.Type == core.IndexType {
session.statement.RefTable = table
session.statement.tableName = tbName
err = session.addIndex(tbName, name)
session.statement.tableName = tbNameWithSchema
err = session.addIndex(tbNameWithSchema, name)
}
if err != nil {
return err
@@ -434,7 +413,7 @@ func (session *Session) Sync2(beans ...interface{}) error {

for _, colName := range table.ColumnsSeq() {
if oriTable.GetColumn(colName) == nil {
engine.logger.Warnf("Table %s has column %s but struct has not related field", table.Name, colName)
engine.logger.Warnf("Table %s has column %s but struct has not related field", engine.TableName(table.Name, true), colName)
}
}
}


+ 1
- 1
session_stats_test.go View File

@@ -153,7 +153,7 @@ func TestSQLCount(t *testing.T) {

assertSync(t, new(UserinfoCount2), new(UserinfoBooks))

total, err := testEngine.SQL("SELECT count(id) FROM userinfo_count2").
total, err := testEngine.SQL("SELECT count(id) FROM " + testEngine.TableName("userinfo_count2", true)).
Count()
assert.NoError(t, err)
assert.EqualValues(t, 0, total)


+ 19
- 81
session_tx_test.go View File

@@ -32,45 +32,21 @@ func TestTransaction(t *testing.T) {
defer session.Close()

err := session.Begin()
if err != nil {
t.Error(err)
panic(err)
return
}
assert.NoError(t, err)

user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()}
_, err = session.Insert(&user1)
if err != nil {
session.Rollback()
t.Error(err)
panic(err)
return
}
assert.NoError(t, err)

user2 := Userinfo{Username: "yyy"}
_, err = session.Where("(id) = ?", 0).Update(&user2)
if err != nil {
session.Rollback()
fmt.Println(err)
//t.Error(err)
return
}
assert.NoError(t, err)

_, err = session.Delete(&user2)
if err != nil {
session.Rollback()
t.Error(err)
panic(err)
return
}
assert.NoError(t, err)

err = session.Commit()
if err != nil {
t.Error(err)
panic(err)
return
}
// panic(err) !nashtsai! should remove this
assert.NoError(t, err)
}

func TestCombineTransaction(t *testing.T) {
@@ -91,38 +67,21 @@ func TestCombineTransaction(t *testing.T) {
defer session.Close()

err := session.Begin()
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)

user1 := Userinfo{Username: "xiaoxiao2", Departname: "dev", Alias: "lunny", Created: time.Now()}
_, err = session.Insert(&user1)
if err != nil {
session.Rollback()
t.Error(err)
panic(err)
}
assert.NoError(t, err)

user2 := Userinfo{Username: "zzz"}
_, err = session.Where("id = ?", 0).Update(&user2)
if err != nil {
session.Rollback()
t.Error(err)
panic(err)
}
assert.NoError(t, err)

_, err = session.Exec("delete from userinfo where username = ?", user2.Username)
if err != nil {
session.Rollback()
t.Error(err)
panic(err)
}
_, err = session.Exec("delete from "+testEngine.TableName("userinfo", true)+" where username = ?", user2.Username)
assert.NoError(t, err)

err = session.Commit()
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
}

func TestCombineTransactionSameMapper(t *testing.T) {
@@ -148,45 +107,24 @@ func TestCombineTransactionSameMapper(t *testing.T) {

counter()
defer counter()

session := testEngine.NewSession()
defer session.Close()

err := session.Begin()
if err != nil {
t.Error(err)
panic(err)
return
}
assert.NoError(t, err)

user1 := Userinfo{Username: "xiaoxiao2", Departname: "dev", Alias: "lunny", Created: time.Now()}
_, err = session.Insert(&user1)
if err != nil {
session.Rollback()
t.Error(err)
panic(err)
return
}
assert.NoError(t, err)

user2 := Userinfo{Username: "zzz"}
_, err = session.Where("(id) = ?", 0).Update(&user2)
if err != nil {
session.Rollback()
t.Error(err)
panic(err)
return
}
assert.NoError(t, err)

_, err = session.Exec("delete from `Userinfo` where `Username` = ?", user2.Username)
if err != nil {
session.Rollback()
t.Error(err)
panic(err)
return
}
_, err = session.Exec("delete from "+testEngine.TableName("`Userinfo`", true)+" where `Username` = ?", user2.Username)
assert.NoError(t, err)

err = session.Commit()
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
}

+ 1
- 1
session_update.go View File

@@ -167,7 +167,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
var isMap = t.Kind() == reflect.Map
var isStruct = t.Kind() == reflect.Struct
if isStruct {
if err := session.statement.setRefValue(v); err != nil {
if err := session.statement.setRefBean(bean); err != nil {
return 0, err
}



+ 27
- 109
session_update_test.go View File

@@ -462,30 +462,18 @@ func TestUpdate1(t *testing.T) {

col1 := &UpdateAllCols{Ptr: &s}
err = testEngine.Sync(col1)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)

_, err = testEngine.Insert(col1)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)

col2 := &UpdateAllCols{col1.Id, true, "", nil}
_, err = testEngine.ID(col2.Id).AllCols().Update(col2)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)

col3 := &UpdateAllCols{}
has, err = testEngine.ID(col2.Id).Get(col3)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)

if !has {
err = errors.New(fmt.Sprintf("cannot get id %d", col2.Id))
@@ -759,7 +747,7 @@ func TestUpdateUpdated(t *testing.T) {
func TestUpdateSameMapper(t *testing.T) {
assert.NoError(t, prepareEngine())

oldMapper := testEngine.GetColumnMapper()
oldMapper := testEngine.GetTableMapper()
testEngine.UnMapType(rValue(new(Userinfo)).Type())
testEngine.UnMapType(rValue(new(Condi)).Type())
testEngine.UnMapType(rValue(new(Article)).Type())
@@ -786,81 +774,38 @@ func TestUpdateSameMapper(t *testing.T) {

var ori Userinfo
has, err := testEngine.Get(&ori)
if err != nil {
t.Error(err)
panic(err)
}
if !has {
t.Error(errors.New("not exist"))
panic(errors.New("not exist"))
}
assert.NoError(t, err)
assert.True(t, has)

// update by id
user := Userinfo{Username: "xxx", Height: 1.2}
cnt, err := testEngine.ID(ori.Uid).Update(&user)
if err != nil {
t.Error(err)
panic(err)
}
if cnt != 1 {
err = errors.New("update not returned 1")
t.Error(err)
panic(err)
return
}
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)

condi := Condi{"Username": "zzz", "Departname": ""}
cnt, err = testEngine.Table(&user).ID(ori.Uid).Update(&condi)
if err != nil {
t.Error(err)
panic(err)
}

if cnt != 1 {
err = errors.New("update not returned 1")
t.Error(err)
panic(err)
return
}
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)

cnt, err = testEngine.Update(&Userinfo{Username: "yyy"}, &user)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)

total, err := testEngine.Count(&user)
if err != nil {
t.Error(err)
panic(err)
}

if cnt != total {
err = errors.New("insert not returned 1")
t.Error(err)
panic(err)
return
}
assert.NoError(t, err)
assert.EqualValues(t, cnt, total)

err = testEngine.Sync(&Article{})
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)

defer func() {
err = testEngine.DropTables(&Article{})
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
}()

a := &Article{0, "1", "2", "3", "4", "5", 2}
cnt, err = testEngine.Insert(a)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)

if cnt != 1 {
err = errors.New(fmt.Sprintf("insert not returned 1 but %d", cnt))
@@ -875,10 +820,7 @@ func TestUpdateSameMapper(t *testing.T) {
}

cnt, err = testEngine.ID(a.Id).Update(&Article{Name: "6"})
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)

if cnt != 1 {
err = errors.New(fmt.Sprintf("insert not returned 1 but %d", cnt))
@@ -889,30 +831,18 @@ func TestUpdateSameMapper(t *testing.T) {

col1 := &UpdateAllCols{}
err = testEngine.Sync(col1)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)

_, err = testEngine.Insert(col1)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)

col2 := &UpdateAllCols{col1.Id, true, "", nil}
_, err = testEngine.ID(col2.Id).AllCols().Update(col2)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)

col3 := &UpdateAllCols{}
has, err = testEngine.ID(col2.Id).Get(col3)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)

if !has {
err = errors.New(fmt.Sprintf("cannot get id %d", col2.Id))
@@ -931,32 +861,20 @@ func TestUpdateSameMapper(t *testing.T) {
{
col1 := &UpdateMustCols{}
err = testEngine.Sync(col1)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)

_, err = testEngine.Insert(col1)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)

col2 := &UpdateMustCols{col1.Id, true, ""}
boolStr := testEngine.GetColumnMapper().Obj2Table("Bool")
stringStr := testEngine.GetColumnMapper().Obj2Table("String")
_, err = testEngine.ID(col2.Id).MustCols(boolStr, stringStr).Update(col2)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)

col3 := &UpdateMustCols{}
has, err := testEngine.ID(col2.Id).Get(col3)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)

if !has {
err = errors.New(fmt.Sprintf("cannot get id %d", col2.Id))


+ 42
- 76
statement.go View File

@@ -221,26 +221,18 @@ func (statement *Statement) setRefValue(v reflect.Value) error {
if err != nil {
return err
}
statement.tableName = statement.Engine.tbName(v)
statement.tableName = statement.Engine.TableName(v.Interface(), true)
return nil
}

// Table tempororily set table name, the parameter could be a string or a pointer of struct
func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
v := rValue(tableNameOrBean)
t := v.Type()
if t.Kind() == reflect.String {
statement.AltTableName = tableNameOrBean.(string)
} else if t.Kind() == reflect.Struct {
var err error
statement.RefTable, err = statement.Engine.autoMapType(v)
if err != nil {
statement.Engine.logger.Error(err)
return statement
}
statement.AltTableName = statement.Engine.tbName(v)
func (statement *Statement) setRefBean(bean interface{}) error {
var err error
statement.RefTable, err = statement.Engine.autoMapType(rValue(bean))
if err != nil {
return err
}
return statement
statement.tableName = statement.Engine.TableName(bean, true)
return nil
}

// Auto generating update columnes and values according a struct
@@ -743,6 +735,23 @@ func (statement *Statement) Asc(colNames ...string) *Statement {
return statement
}

// Table tempororily set table name, the parameter could be a string or a pointer of struct
func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
v := rValue(tableNameOrBean)
t := v.Type()
if t.Kind() == reflect.Struct {
var err error
statement.RefTable, err = statement.Engine.autoMapType(v)
if err != nil {
statement.Engine.logger.Error(err)
return statement
}
}

statement.AltTableName = statement.Engine.TableName(tableNameOrBean, true)
return statement
}

// Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement {
var buf bytes.Buffer
@@ -752,56 +761,9 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
fmt.Fprintf(&buf, "%v JOIN ", joinOP)
}

switch tablename.(type) {
case []string:
t := tablename.([]string)
if len(t) > 1 {
fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(t[0]), statement.Engine.Quote(t[1]))
} else if len(t) == 1 {
fmt.Fprintf(&buf, statement.Engine.Quote(t[0]))
}
case []interface{}:
t := tablename.([]interface{})
l := len(t)
var table string
if l > 0 {
f := t[0]
switch f.(type) {
case string:
table = f.(string)
case TableName:
table = f.(TableName).TableName()
default:
v := rValue(f)
t := v.Type()
if t.Kind() == reflect.Struct {
fmt.Fprintf(&buf, statement.Engine.tbName(v))
} else {
fmt.Fprintf(&buf, statement.Engine.Quote(fmt.Sprintf("%v", f)))
}
}
}
if l > 1 {
fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(table),
statement.Engine.Quote(fmt.Sprintf("%v", t[1])))
} else if l == 1 {
fmt.Fprintf(&buf, statement.Engine.Quote(table))
}
case TableName:
fmt.Fprintf(&buf, tablename.(TableName).TableName())
case string:
fmt.Fprintf(&buf, tablename.(string))
default:
v := rValue(tablename)
t := v.Type()
if t.Kind() == reflect.Struct {
fmt.Fprintf(&buf, statement.Engine.tbName(v))
} else {
fmt.Fprintf(&buf, statement.Engine.Quote(fmt.Sprintf("%v", tablename)))
}
}
tbName := statement.Engine.TableName(tablename, true)

fmt.Fprintf(&buf, " ON %v", condition)
fmt.Fprintf(&buf, "%s ON %v", tbName, condition)
statement.JoinStr = buf.String()
statement.joinArgs = append(statement.joinArgs, args...)
return statement
@@ -876,11 +838,13 @@ func (statement *Statement) genCreateTableSQL() string {
func (statement *Statement) genIndexSQL() []string {
var sqls []string
tbName := statement.TableName()
quote := statement.Engine.Quote
for idxName, index := range statement.RefTable.Indexes {
for _, index := range statement.RefTable.Indexes {
if index.Type == core.IndexType {
sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(tbName, idxName)),
quote(tbName), quote(strings.Join(index.Cols, quote(","))))
sql := statement.Engine.dialect.CreateIndexSql(tbName, index)
/*idxTBName := strings.Replace(tbName, ".", "_", -1)
idxTBName = strings.Replace(idxTBName, `"`, "", -1)
sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(idxTBName, idxName)),
quote(tbName), quote(strings.Join(index.Cols, quote(","))))*/
sqls = append(sqls, sql)
}
}
@@ -906,16 +870,18 @@ func (statement *Statement) genUniqueSQL() []string {
func (statement *Statement) genDelIndexSQL() []string {
var sqls []string
tbName := statement.TableName()
idxPrefixName := strings.Replace(tbName, `"`, "", -1)
idxPrefixName = strings.Replace(idxPrefixName, `.`, "_", -1)
for idxName, index := range statement.RefTable.Indexes {
var rIdxName string
if index.Type == core.UniqueType {
rIdxName = uniqueName(tbName, idxName)
rIdxName = uniqueName(idxPrefixName, idxName)
} else if index.Type == core.IndexType {
rIdxName = indexName(tbName, idxName)
rIdxName = indexName(idxPrefixName, idxName)
}
sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(rIdxName))
sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(statement.Engine.TableName(rIdxName, true)))
if statement.Engine.dialect.IndexOnTable() {
sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(statement.TableName()))
sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(tbName))
}
sqls = append(sqls, sql)
}
@@ -966,7 +932,7 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{},
v := rValue(bean)
isStruct := v.Kind() == reflect.Struct
if isStruct {
statement.setRefValue(v)
statement.setRefBean(bean)
}

var columnStr = statement.ColumnStr
@@ -1018,7 +984,7 @@ func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interfa
var condArgs []interface{}
var err error
if len(beans) > 0 {
statement.setRefValue(rValue(beans[0]))
statement.setRefBean(beans[0])
condSQL, condArgs, err = statement.genConds(beans[0])
} else {
condSQL, condArgs, err = builder.ToSQL(statement.cond)
@@ -1044,7 +1010,7 @@ func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interfa
}

func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) {
statement.setRefValue(rValue(bean))
statement.setRefBean(bean)

var sumStrs = make([]string, 0, len(columns))
for _, colName := range columns {


+ 27
- 36
tag_extends_test.go View File

@@ -202,17 +202,14 @@ func TestExtends(t *testing.T) {

var info UserAndDetail
qt := testEngine.Quote
ui := testEngine.GetTableMapper().Obj2Table("Userinfo")
ud := testEngine.GetTableMapper().Obj2Table("Userdetail")
uiid := testEngine.GetTableMapper().Obj2Table("Id")
ui := testEngine.TableName(new(Userinfo), true)
ud := testEngine.TableName(&detail, true)
uiid := testEngine.GetColumnMapper().Obj2Table("Id")
udid := "detail_id"
sql := fmt.Sprintf("select * from %s, %s where %s.%s = %s.%s",
qt(ui), qt(ud), qt(ui), qt(udid), qt(ud), qt(uiid))
b, err := testEngine.SQL(sql).NoCascade().Get(&info)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
if !b {
err = errors.New("should has lest one record")
t.Error(err)
@@ -341,19 +338,17 @@ func TestExtends2(t *testing.T) {
}

var mapper = testEngine.GetTableMapper().Obj2Table
userTableName := mapper("MessageUser")
typeTableName := mapper("MessageType")
msgTableName := mapper("Message")
var quote = testEngine.Quote
userTableName := quote(testEngine.TableName(mapper("MessageUser"), true))
typeTableName := quote(testEngine.TableName(mapper("MessageType"), true))
msgTableName := quote(testEngine.TableName(mapper("Message"), true))

list := make([]Message, 0)
err = testEngine.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Uid")+"`").
Join("LEFT", []string{userTableName, "receiver"}, "`receiver`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("ToUid")+"`").
Join("LEFT", []string{typeTableName, "type"}, "`type`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Id")+"`").
err = testEngine.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`").
Join("LEFT", []string{userTableName, "receiver"}, "`receiver`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("ToUid")+"`").
Join("LEFT", []string{typeTableName, "type"}, "`type`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Id")+"`").
Find(&list)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)

if len(list) != 1 {
err = errors.New(fmt.Sprintln("should have 1 message, got", len(list)))
@@ -406,25 +401,20 @@ func TestExtends3(t *testing.T) {
assert.NoError(t, err)
}
_, err = testEngine.Insert(&msg)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)

var mapper = testEngine.GetTableMapper().Obj2Table
userTableName := mapper("MessageUser")
typeTableName := mapper("MessageType")
msgTableName := mapper("Message")
var quote = testEngine.Quote
userTableName := quote(testEngine.TableName(mapper("MessageUser"), true))
typeTableName := quote(testEngine.TableName(mapper("MessageType"), true))
msgTableName := quote(testEngine.TableName(mapper("Message"), true))

list := make([]MessageExtend3, 0)
err = testEngine.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Uid")+"`").
Join("LEFT", []string{userTableName, "receiver"}, "`receiver`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("ToUid")+"`").
Join("LEFT", []string{typeTableName, "type"}, "`type`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Id")+"`").
err = testEngine.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`").
Join("LEFT", []string{userTableName, "receiver"}, "`receiver`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("ToUid")+"`").
Join("LEFT", []string{typeTableName, "type"}, "`type`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Id")+"`").
Find(&list)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)

if len(list) != 1 {
err = errors.New(fmt.Sprintln("should have 1 message, got", len(list)))
@@ -499,13 +489,14 @@ func TestExtends4(t *testing.T) {
}

var mapper = testEngine.GetTableMapper().Obj2Table
userTableName := mapper("MessageUser")
typeTableName := mapper("MessageType")
msgTableName := mapper("Message")
var quote = testEngine.Quote
userTableName := quote(testEngine.TableName(mapper("MessageUser"), true))
typeTableName := quote(testEngine.TableName(mapper("MessageType"), true))
msgTableName := quote(testEngine.TableName(mapper("Message"), true))

list := make([]MessageExtend4, 0)
err = testEngine.Table(msgTableName).Join("LEFT", userTableName, "`"+userTableName+"`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Uid")+"`").
Join("LEFT", typeTableName, "`"+typeTableName+"`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Id")+"`").
err = testEngine.Table(msgTableName).Join("LEFT", userTableName, userTableName+".`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`").
Join("LEFT", typeTableName, typeTableName+".`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Id")+"`").
Find(&list)
if err != nil {
t.Error(err)


+ 3
- 2
types_test.go View File

@@ -301,10 +301,11 @@ type UserCus struct {
func TestCustomType2(t *testing.T) {
assert.NoError(t, prepareEngine())

err := testEngine.CreateTables(&UserCus{})
var uc UserCus
err := testEngine.CreateTables(&uc)
assert.NoError(t, err)

tableName := testEngine.GetTableMapper().Obj2Table("UserCus")
tableName := testEngine.TableName(&uc, true)
_, err = testEngine.Exec("delete from " + testEngine.Quote(tableName))
assert.NoError(t, err)



+ 4
- 1
xorm_test.go View File

@@ -27,6 +27,7 @@ var (
cache = flag.Bool("cache", false, "if enable cache")
cluster = flag.Bool("cluster", false, "if this is a cluster")
splitter = flag.String("splitter", ";", "the splitter on connstr for cluster")
schema = flag.String("schema", "", "specify the schema")
)

func createEngine(dbType, connStr string) error {
@@ -35,7 +36,6 @@ func createEngine(dbType, connStr string) error {

if !*cluster {
testEngine, err = NewEngine(dbType, connStr)

} else {
testEngine, err = NewEngineGroup(dbType, strings.Split(connStr, *splitter))
}
@@ -43,6 +43,9 @@ func createEngine(dbType, connStr string) error {
return err
}

if *schema != "" {
testEngine.SetSchema(*schema)
}
testEngine.ShowSQL(*showSQL)
testEngine.SetLogLevel(core.LOG_DEBUG)
if *cache {


Loading…
Cancel
Save