Merge core package back into the main repository and split into serval sub packages. #1543

Merged
lunny merged 4 commits from lunny/merge_core2 into master 2020-02-24 08:53:22 +00:00
87 changed files with 4457 additions and 3199 deletions

2
.gitignore vendored
View File

@ -32,3 +32,5 @@ xorm.test
test.db.sql test.db.sql
.idea/ .idea/
*coverage.out

View File

@ -8,6 +8,8 @@ import (
"testing" "testing"
"time" "time"
"xorm.io/xorm/caches"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -21,7 +23,7 @@ func TestCacheFind(t *testing.T) {
} }
oldCacher := testEngine.GetDefaultCacher() oldCacher := testEngine.GetDefaultCacher()
cacher := NewLRUCacher2(NewMemoryStore(), time.Hour, 10000) cacher := caches.NewLRUCacher2(caches.NewMemoryStore(), time.Hour, 10000)
testEngine.SetDefaultCacher(cacher) testEngine.SetDefaultCacher(cacher)
assert.NoError(t, testEngine.Sync2(new(MailBox))) assert.NoError(t, testEngine.Sync2(new(MailBox)))
@ -96,7 +98,7 @@ func TestCacheFind2(t *testing.T) {
} }
oldCacher := testEngine.GetDefaultCacher() oldCacher := testEngine.GetDefaultCacher()
cacher := NewLRUCacher2(NewMemoryStore(), time.Hour, 10000) cacher := caches.NewLRUCacher2(caches.NewMemoryStore(), time.Hour, 10000)
testEngine.SetDefaultCacher(cacher) testEngine.SetDefaultCacher(cacher)
assert.NoError(t, testEngine.Sync2(new(MailBox2))) assert.NoError(t, testEngine.Sync2(new(MailBox2)))
@ -147,7 +149,7 @@ func TestCacheGet(t *testing.T) {
} }
oldCacher := testEngine.GetDefaultCacher() oldCacher := testEngine.GetDefaultCacher()
cacher := NewLRUCacher2(NewMemoryStore(), time.Hour, 10000) cacher := caches.NewLRUCacher2(caches.NewMemoryStore(), time.Hour, 10000)
testEngine.SetDefaultCacher(cacher) testEngine.SetDefaultCacher(cacher)
assert.NoError(t, testEngine.Sync2(new(MailBox3))) assert.NoError(t, testEngine.Sync2(new(MailBox3)))

99
caches/cache.go Normal file
View File

@ -0,0 +1,99 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package caches
import (
"bytes"
"encoding/gob"
"errors"
"fmt"
"strings"
"time"
"xorm.io/xorm/schemas"
)
const (
// CacheExpired is default cache expired time
CacheExpired = 60 * time.Minute
// CacheMaxMemory is not use now
CacheMaxMemory = 256
// CacheGcInterval represents interval time to clear all expired nodes
CacheGcInterval = 10 * time.Minute
// CacheGcMaxRemoved represents max nodes removed when gc
CacheGcMaxRemoved = 20
)
// list all the errors
var (
ErrCacheMiss = errors.New("xorm/cache: key not found")
ErrNotStored = errors.New("xorm/cache: not stored")
// ErrNotExist record does not exist error
ErrNotExist = errors.New("Record does not exist")
)
// CacheStore is a interface to store cache
type CacheStore interface {
// key is primary key or composite primary key
// value is struct's pointer
// key format : <tablename>-p-<pk1>-<pk2>...
Put(key string, value interface{}) error
Get(key string) (interface{}, error)
Del(key string) error
}
// Cacher is an interface to provide cache
// id format : u-<pk1>-<pk2>...
type Cacher interface {
GetIds(tableName, sql string) interface{}
GetBean(tableName string, id string) interface{}
PutIds(tableName, sql string, ids interface{})
PutBean(tableName string, id string, obj interface{})
DelIds(tableName, sql string)
DelBean(tableName string, id string)
ClearIds(tableName string)
ClearBeans(tableName string)
}
func encodeIds(ids []schemas.PK) (string, error) {
buf := new(bytes.Buffer)
enc := gob.NewEncoder(buf)
err := enc.Encode(ids)
return buf.String(), err
}
func decodeIds(s string) ([]schemas.PK, error) {
pks := make([]schemas.PK, 0)
dec := gob.NewDecoder(strings.NewReader(s))
err := dec.Decode(&pks)
return pks, err
}
// GetCacheSql returns cacher PKs via SQL
func GetCacheSql(m Cacher, tableName, sql string, args interface{}) ([]schemas.PK, error) {
bytes := m.GetIds(tableName, GenSqlKey(sql, args))
if bytes == nil {
return nil, errors.New("Not Exist")
}
return decodeIds(bytes.(string))
}
// PutCacheSql puts cacher SQL and PKs
func PutCacheSql(m Cacher, ids []schemas.PK, tableName, sql string, args interface{}) error {
bytes, err := encodeIds(ids)
if err != nil {
return err
}
m.PutIds(tableName, GenSqlKey(sql, args), bytes)
return nil
}
// GenSqlKey generates cache key
func GenSqlKey(sql string, args interface{}) string {
return fmt.Sprintf("%v-%v", sql, args)
}

View File

@ -2,15 +2,13 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package xorm package caches
import ( import (
"container/list" "container/list"
"fmt" "fmt"
"sync" "sync"
"time" "time"
"xorm.io/core"
) )
// LRUCacher implments cache object facilities // LRUCacher implments cache object facilities
@ -19,7 +17,7 @@ type LRUCacher struct {
sqlList *list.List sqlList *list.List
idIndex map[string]map[string]*list.Element idIndex map[string]map[string]*list.Element
sqlIndex map[string]map[string]*list.Element sqlIndex map[string]map[string]*list.Element
store core.CacheStore store CacheStore
mutex sync.Mutex mutex sync.Mutex
MaxElementSize int MaxElementSize int
Expired time.Duration Expired time.Duration
@ -27,15 +25,15 @@ type LRUCacher struct {
} }
// NewLRUCacher creates a cacher // NewLRUCacher creates a cacher
func NewLRUCacher(store core.CacheStore, maxElementSize int) *LRUCacher { func NewLRUCacher(store CacheStore, maxElementSize int) *LRUCacher {
return NewLRUCacher2(store, 3600*time.Second, maxElementSize) return NewLRUCacher2(store, 3600*time.Second, maxElementSize)
} }
// NewLRUCacher2 creates a cache include different params // NewLRUCacher2 creates a cache include different params
func NewLRUCacher2(store core.CacheStore, expired time.Duration, maxElementSize int) *LRUCacher { func NewLRUCacher2(store CacheStore, expired time.Duration, maxElementSize int) *LRUCacher {
cacher := &LRUCacher{store: store, idList: list.New(), cacher := &LRUCacher{store: store, idList: list.New(),
sqlList: list.New(), Expired: expired, sqlList: list.New(), Expired: expired,
GcInterval: core.CacheGcInterval, MaxElementSize: maxElementSize, GcInterval: CacheGcInterval, MaxElementSize: maxElementSize,
sqlIndex: make(map[string]map[string]*list.Element), sqlIndex: make(map[string]map[string]*list.Element),
idIndex: make(map[string]map[string]*list.Element), idIndex: make(map[string]map[string]*list.Element),
} }
@ -57,7 +55,7 @@ func (m *LRUCacher) GC() {
defer m.mutex.Unlock() defer m.mutex.Unlock()
var removedNum int var removedNum int
for e := m.idList.Front(); e != nil; { for e := m.idList.Front(); e != nil; {
if removedNum <= core.CacheGcMaxRemoved && if removedNum <= CacheGcMaxRemoved &&
time.Now().Sub(e.Value.(*idNode).lastVisit) > m.Expired { time.Now().Sub(e.Value.(*idNode).lastVisit) > m.Expired {
removedNum++ removedNum++
next := e.Next() next := e.Next()
@ -71,7 +69,7 @@ func (m *LRUCacher) GC() {
removedNum = 0 removedNum = 0
for e := m.sqlList.Front(); e != nil; { for e := m.sqlList.Front(); e != nil; {
if removedNum <= core.CacheGcMaxRemoved && if removedNum <= CacheGcMaxRemoved &&
time.Now().Sub(e.Value.(*sqlNode).lastVisit) > m.Expired { time.Now().Sub(e.Value.(*sqlNode).lastVisit) > m.Expired {
removedNum++ removedNum++
next := e.Next() next := e.Next()
@ -268,11 +266,11 @@ type sqlNode struct {
} }
func genSQLKey(sql string, args interface{}) string { func genSQLKey(sql string, args interface{}) string {
return fmt.Sprintf("%v-%v", sql, args) return fmt.Sprintf("%s-%v", sql, args)
} }
func genID(prefix string, id string) string { func genID(prefix string, id string) string {
return fmt.Sprintf("%v-%v", prefix, id) return fmt.Sprintf("%s-%s", prefix, id)
} }
func newIDNode(tbName string, id string) *idNode { func newIDNode(tbName string, id string) *idNode {

View File

@ -2,13 +2,13 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package xorm package caches
import ( import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"xorm.io/core" "xorm.io/xorm/schemas"
) )
func TestLRUCache(t *testing.T) { func TestLRUCache(t *testing.T) {
@ -20,7 +20,7 @@ func TestLRUCache(t *testing.T) {
cacher := NewLRUCacher(store, 10000) cacher := NewLRUCacher(store, 10000)
tableName := "cache_object1" tableName := "cache_object1"
pks := []core.PK{ pks := []schemas.PK{
{1}, {1},
{2}, {2},
} }

View File

@ -2,15 +2,13 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package xorm package caches
import ( import (
"sync" "sync"
"xorm.io/core"
) )
var _ core.CacheStore = NewMemoryStore() var _ CacheStore = NewMemoryStore()
// MemoryStore represents in-memory store // MemoryStore represents in-memory store
type MemoryStore struct { type MemoryStore struct {

View File

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package xorm package caches
import ( import (
"testing" "testing"

View File

@ -346,3 +346,10 @@ func asBool(bs []byte) (bool, error) {
} }
return strconv.ParseBool(string(bs)) return strconv.ParseBool(string(bs))
} }
// Conversion is an interface. A type implements Conversion will according
// the custom method to fill into database and retrieve from database.
type Conversion interface {
FromDB([]byte) error
ToDB() ([]byte, error)
}

229
core/db.go Normal file
View File

@ -0,0 +1,229 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package core
import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"reflect"
"regexp"
"sync"
"xorm.io/xorm/names"
)
var (
// DefaultCacheSize sets the default cache size
DefaultCacheSize = 200
)
func MapToSlice(query string, mp interface{}) (string, []interface{}, error) {
vv := reflect.ValueOf(mp)
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
return "", []interface{}{}, ErrNoMapPointer
}
args := make([]interface{}, 0, len(vv.Elem().MapKeys()))
var err error
query = re.ReplaceAllStringFunc(query, func(src string) string {
v := vv.Elem().MapIndex(reflect.ValueOf(src[1:]))
if !v.IsValid() {
err = fmt.Errorf("map key %s is missing", src[1:])
} else {
args = append(args, v.Interface())
}
return "?"
})
return query, args, err
}
func StructToSlice(query string, st interface{}) (string, []interface{}, error) {
vv := reflect.ValueOf(st)
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
return "", []interface{}{}, ErrNoStructPointer
}
args := make([]interface{}, 0)
var err error
query = re.ReplaceAllStringFunc(query, func(src string) string {
fv := vv.Elem().FieldByName(src[1:]).Interface()
if v, ok := fv.(driver.Valuer); ok {
var value driver.Value
value, err = v.Value()
if err != nil {
return "?"
}
args = append(args, value)
} else {
args = append(args, fv)
}
return "?"
})
if err != nil {
return "", []interface{}{}, err
}
return query, args, nil
}
type cacheStruct struct {
value reflect.Value
idx int
}
// DB is a wrap of sql.DB with extra contents
type DB struct {
*sql.DB
Mapper names.Mapper
reflectCache map[reflect.Type]*cacheStruct
reflectCacheMutex sync.RWMutex
}
// Open opens a database
func Open(driverName, dataSourceName string) (*DB, error) {
db, err := sql.Open(driverName, dataSourceName)
if err != nil {
return nil, err
}
return &DB{
DB: db,
Mapper: names.NewCacheMapper(&names.SnakeMapper{}),
reflectCache: make(map[reflect.Type]*cacheStruct),
}, nil
}
// FromDB creates a DB from a sql.DB
func FromDB(db *sql.DB) *DB {
return &DB{
DB: db,
Mapper: names.NewCacheMapper(&names.SnakeMapper{}),
reflectCache: make(map[reflect.Type]*cacheStruct),
}
}
func (db *DB) reflectNew(typ reflect.Type) reflect.Value {
db.reflectCacheMutex.Lock()
defer db.reflectCacheMutex.Unlock()
cs, ok := db.reflectCache[typ]
if !ok || cs.idx+1 > DefaultCacheSize-1 {
cs = &cacheStruct{reflect.MakeSlice(reflect.SliceOf(typ), DefaultCacheSize, DefaultCacheSize), 0}
db.reflectCache[typ] = cs
} else {
cs.idx = cs.idx + 1
}
return cs.value.Index(cs.idx).Addr()
}
// QueryContext overwrites sql.DB.QueryContext
func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
rows, err := db.DB.QueryContext(ctx, query, args...)
if err != nil {
if rows != nil {
rows.Close()
}
return nil, err
}
return &Rows{rows, db}, nil
}
// Query overwrites sql.DB.Query
func (db *DB) Query(query string, args ...interface{}) (*Rows, error) {
return db.QueryContext(context.Background(), query, args...)
}
// QueryMapContext executes query with parameters via map and context
func (db *DB) QueryMapContext(ctx context.Context, query string, mp interface{}) (*Rows, error) {
query, args, err := MapToSlice(query, mp)
if err != nil {
return nil, err
}
return db.QueryContext(ctx, query, args...)
}
// QueryMap executes query with parameters via map
func (db *DB) QueryMap(query string, mp interface{}) (*Rows, error) {
return db.QueryMapContext(context.Background(), query, mp)
}
func (db *DB) QueryStructContext(ctx context.Context, query string, st interface{}) (*Rows, error) {
query, args, err := StructToSlice(query, st)
if err != nil {
return nil, err
}
return db.QueryContext(ctx, query, args...)
}
func (db *DB) QueryStruct(query string, st interface{}) (*Rows, error) {
return db.QueryStructContext(context.Background(), query, st)
}
func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row {
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
return &Row{nil, err}
}
return &Row{rows, nil}
}
func (db *DB) QueryRow(query string, args ...interface{}) *Row {
return db.QueryRowContext(context.Background(), query, args...)
}
func (db *DB) QueryRowMapContext(ctx context.Context, query string, mp interface{}) *Row {
query, args, err := MapToSlice(query, mp)
if err != nil {
return &Row{nil, err}
}
return db.QueryRowContext(ctx, query, args...)
}
func (db *DB) QueryRowMap(query string, mp interface{}) *Row {
return db.QueryRowMapContext(context.Background(), query, mp)
}
func (db *DB) QueryRowStructContext(ctx context.Context, query string, st interface{}) *Row {
query, args, err := StructToSlice(query, st)
if err != nil {
return &Row{nil, err}
}
return db.QueryRowContext(ctx, query, args...)
}
func (db *DB) QueryRowStruct(query string, st interface{}) *Row {
return db.QueryRowStructContext(context.Background(), query, st)
}
var (
re = regexp.MustCompile(`[?](\w+)`)
)
// ExecMapContext exec map with context.Context
// insert into (name) values (?)
// insert into (name) values (?name)
func (db *DB) ExecMapContext(ctx context.Context, query string, mp interface{}) (sql.Result, error) {
query, args, err := MapToSlice(query, mp)
if err != nil {
return nil, err
}
return db.DB.ExecContext(ctx, query, args...)
}
func (db *DB) ExecMap(query string, mp interface{}) (sql.Result, error) {
return db.ExecMapContext(context.Background(), query, mp)
}
func (db *DB) ExecStructContext(ctx context.Context, query string, st interface{}) (sql.Result, error) {
query, args, err := StructToSlice(query, st)
if err != nil {
return nil, err
}
return db.DB.ExecContext(ctx, query, args...)
}
func (db *DB) ExecStruct(query string, st interface{}) (sql.Result, error) {
return db.ExecStructContext(context.Background(), query, st)
}

684
core/db_test.go Normal file
View File

@ -0,0 +1,684 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package core
import (
"errors"
"flag"
"os"
"testing"
"time"
_ "github.com/go-sql-driver/mysql"
_ "github.com/mattn/go-sqlite3"
"xorm.io/xorm/names"
)
var (
dbtype = flag.String("dbtype", "mysql", "database type")
dbConn = flag.String("dbConn", "root:@/core_test?charset=utf8", "database connect string")
createTableSql string
)
func TestMain(m *testing.M) {
flag.Parse()
switch *dbtype {
case "sqlite3":
createTableSql = "CREATE TABLE IF NOT EXISTS `user` (`id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, `name` TEXT NULL, " +
"`title` TEXT NULL, `age` FLOAT NULL, `alias` TEXT NULL, `nick_name` TEXT NULL, `created` datetime);"
case "mysql":
fallthrough
default:
createTableSql = "CREATE TABLE IF NOT EXISTS `user` (`id` INTEGER PRIMARY KEY AUTO_INCREMENT NOT NULL, `name` TEXT NULL, " +
"`title` TEXT NULL, `age` FLOAT NULL, `alias` TEXT NULL, `nick_name` TEXT NULL, `created` datetime);"
}
exitCode := m.Run()
os.Exit(exitCode)
}
func testOpen() (*DB, error) {
switch *dbtype {
case "sqlite3":
os.Remove("./test.db")
return Open("sqlite3", "./test.db")
case "mysql":
return Open("mysql", *dbConn)
default:
panic("no db type")
}
}
func BenchmarkOriQuery(b *testing.B) {
b.StopTimer()
db, err := testOpen()
if err != nil {
b.Error(err)
}
defer db.Close()
_, err = db.Exec(createTableSql)
if err != nil {
b.Error(err)
}
for i := 0; i < 50; i++ {
_, err = db.Exec("insert into user (`name`, title, age, alias, nick_name, created) values (?,?,?,?,?, ?)",
"xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now())
if err != nil {
b.Error(err)
}
}
b.StartTimer()
for i := 0; i < b.N; i++ {
rows, err := db.Query("select * from user")
if err != nil {
b.Error(err)
}
for rows.Next() {
var Id int64
var Name, Title, Alias, NickName string
var Age float32
var Created NullTime
err = rows.Scan(&Id, &Name, &Title, &Age, &Alias, &NickName, &Created)
if err != nil {
b.Error(err)
}
//fmt.Println(Id, Name, Title, Age, Alias, NickName)
}
rows.Close()
}
}
type User struct {
Id int64
Name string
Title string
Age float32
Alias string
NickName string
Created NullTime
}
func BenchmarkStructQuery(b *testing.B) {
b.StopTimer()
db, err := testOpen()
if err != nil {
b.Error(err)
}
defer db.Close()
_, err = db.Exec(createTableSql)
if err != nil {
b.Error(err)
}
for i := 0; i < 50; i++ {
_, err = db.Exec("insert into user (`name`, title, age, alias, nick_name, created) values (?,?,?,?,?, ?)",
"xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now())
if err != nil {
b.Error(err)
}
}
b.StartTimer()
for i := 0; i < b.N; i++ {
rows, err := db.Query("select * from user")
if err != nil {
b.Error(err)
}
for rows.Next() {
var user User
err = rows.ScanStructByIndex(&user)
if err != nil {
b.Error(err)
}
if user.Name != "xlw" {
b.Log(user)
b.Error(errors.New("name should be xlw"))
}
}
rows.Close()
}
}
func BenchmarkStruct2Query(b *testing.B) {
b.StopTimer()
db, err := testOpen()
if err != nil {
b.Error(err)
}
defer db.Close()
_, err = db.Exec(createTableSql)
if err != nil {
b.Error(err)
}
for i := 0; i < 50; i++ {
_, err = db.Exec("insert into user (`name`, title, age, alias, nick_name, created) values (?,?,?,?,?,?)",
"xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now())
if err != nil {
b.Error(err)
}
}
db.Mapper = names.NewCacheMapper(&names.SnakeMapper{})
b.StartTimer()
for i := 0; i < b.N; i++ {
rows, err := db.Query("select * from user")
if err != nil {
b.Error(err)
}
for rows.Next() {
var user User
err = rows.ScanStructByName(&user)
if err != nil {
b.Error(err)
}
if user.Name != "xlw" {
b.Log(user)
b.Error(errors.New("name should be xlw"))
}
}
rows.Close()
}
}
func BenchmarkSliceInterfaceQuery(b *testing.B) {
b.StopTimer()
db, err := testOpen()
if err != nil {
b.Error(err)
}
defer db.Close()
_, err = db.Exec(createTableSql)
if err != nil {
b.Error(err)
}
for i := 0; i < 50; i++ {
_, err = db.Exec("insert into user (`name`, title, age, alias, nick_name,created) values (?,?,?,?,?,?)",
"xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now())
if err != nil {
b.Error(err)
}
}
b.StartTimer()
for i := 0; i < b.N; i++ {
rows, err := db.Query("select * from user")
if err != nil {
b.Error(err)
}
cols, err := rows.Columns()
if err != nil {
b.Error(err)
}
for rows.Next() {
slice := make([]interface{}, len(cols))
err = rows.ScanSlice(&slice)
if err != nil {
b.Error(err)
}
b.Log(slice)
switch slice[1].(type) {
case *string:
if *slice[1].(*string) != "xlw" {
b.Error(errors.New("name should be xlw"))
}
case []byte:
if string(slice[1].([]byte)) != "xlw" {
b.Error(errors.New("name should be xlw"))
}
}
}
rows.Close()
}
}
/*func BenchmarkSliceBytesQuery(b *testing.B) {
b.StopTimer()
os.Remove("./test.db")
db, err := Open("sqlite3", "./test.db")
if err != nil {
b.Error(err)
}
defer db.Close()
_, err = db.Exec(createTableSql)
if err != nil {
b.Error(err)
}
for i := 0; i < 50; i++ {
_, err = db.Exec("insert into user (name, title, age, alias, nick_name,created) values (?,?,?,?,?,?)",
"xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now())
if err != nil {
b.Error(err)
}
}
b.StartTimer()
for i := 0; i < b.N; i++ {
rows, err := db.Query("select * from user")
if err != nil {
b.Error(err)
}
cols, err := rows.Columns()
if err != nil {
b.Error(err)
}
for rows.Next() {
slice := make([][]byte, len(cols))
err = rows.ScanSlice(&slice)
if err != nil {
b.Error(err)
}
if string(slice[1]) != "xlw" {
fmt.Println(slice)
b.Error(errors.New("name should be xlw"))
}
}
rows.Close()
}
}
*/
func BenchmarkSliceStringQuery(b *testing.B) {
b.StopTimer()
db, err := testOpen()
if err != nil {
b.Error(err)
}
defer db.Close()
_, err = db.Exec(createTableSql)
if err != nil {
b.Error(err)
}
for i := 0; i < 50; i++ {
_, err = db.Exec("insert into user (name, title, age, alias, nick_name, created) values (?,?,?,?,?,?)",
"xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now())
if err != nil {
b.Error(err)
}
}
b.StartTimer()
for i := 0; i < b.N; i++ {
rows, err := db.Query("select * from user")
if err != nil {
b.Error(err)
}
cols, err := rows.Columns()
if err != nil {
b.Error(err)
}
for rows.Next() {
slice := make([]*string, len(cols))
err = rows.ScanSlice(&slice)
if err != nil {
b.Error(err)
}
if (*slice[1]) != "xlw" {
b.Log(slice)
b.Error(errors.New("name should be xlw"))
}
}
rows.Close()
}
}
func BenchmarkMapInterfaceQuery(b *testing.B) {
b.StopTimer()
db, err := testOpen()
if err != nil {
b.Error(err)
}
defer db.Close()
_, err = db.Exec(createTableSql)
if err != nil {
b.Error(err)
}
for i := 0; i < 50; i++ {
_, err = db.Exec("insert into user (name, title, age, alias, nick_name,created) values (?,?,?,?,?,?)",
"xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now())
if err != nil {
b.Error(err)
}
}
b.StartTimer()
for i := 0; i < b.N; i++ {
rows, err := db.Query("select * from user")
if err != nil {
b.Error(err)
}
for rows.Next() {
m := make(map[string]interface{})
err = rows.ScanMap(&m)
if err != nil {
b.Error(err)
}
switch m["name"].(type) {
case string:
if m["name"].(string) != "xlw" {
b.Log(m)
b.Error(errors.New("name should be xlw"))
}
case []byte:
if string(m["name"].([]byte)) != "xlw" {
b.Log(m)
b.Error(errors.New("name should be xlw"))
}
}
}
rows.Close()
}
}
/*func BenchmarkMapBytesQuery(b *testing.B) {
b.StopTimer()
os.Remove("./test.db")
db, err := Open("sqlite3", "./test.db")
if err != nil {
b.Error(err)
}
defer db.Close()
_, err = db.Exec(createTableSql)
if err != nil {
b.Error(err)
}
for i := 0; i < 50; i++ {
_, err = db.Exec("insert into user (name, title, age, alias, nick_name,created) values (?,?,?,?,?,?)",
"xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now())
if err != nil {
b.Error(err)
}
}
b.StartTimer()
for i := 0; i < b.N; i++ {
rows, err := db.Query("select * from user")
if err != nil {
b.Error(err)
}
for rows.Next() {
m := make(map[string][]byte)
err = rows.ScanMap(&m)
if err != nil {
b.Error(err)
}
if string(m["name"]) != "xlw" {
fmt.Println(m)
b.Error(errors.New("name should be xlw"))
}
}
rows.Close()
}
}
*/
/*
func BenchmarkMapStringQuery(b *testing.B) {
b.StopTimer()
os.Remove("./test.db")
db, err := Open("sqlite3", "./test.db")
if err != nil {
b.Error(err)
}
defer db.Close()
_, err = db.Exec(createTableSql)
if err != nil {
b.Error(err)
}
for i := 0; i < 50; i++ {
_, err = db.Exec("insert into user (name, title, age, alias, nick_name,created) values (?,?,?,?,?,?)",
"xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now())
if err != nil {
b.Error(err)
}
}
b.StartTimer()
for i := 0; i < b.N; i++ {
rows, err := db.Query("select * from user")
if err != nil {
b.Error(err)
}
for rows.Next() {
m := make(map[string]string)
err = rows.ScanMap(&m)
if err != nil {
b.Error(err)
}
if m["name"] != "xlw" {
fmt.Println(m)
b.Error(errors.New("name should be xlw"))
}
}
rows.Close()
}
}*/
func BenchmarkExec(b *testing.B) {
b.StopTimer()
db, err := testOpen()
if err != nil {
b.Error(err)
}
defer db.Close()
_, err = db.Exec(createTableSql)
if err != nil {
b.Error(err)
}
b.StartTimer()
for i := 0; i < b.N; i++ {
_, err = db.Exec("insert into user (`name`, title, age, alias, nick_name,created) values (?,?,?,?,?,?)",
"xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now())
if err != nil {
b.Error(err)
}
}
}
func BenchmarkExecMap(b *testing.B) {
b.StopTimer()
db, err := testOpen()
if err != nil {
b.Error(err)
}
defer db.Close()
_, err = db.Exec(createTableSql)
if err != nil {
b.Error(err)
}
b.StartTimer()
mp := map[string]interface{}{
"name": "xlw",
"title": "tester",
"age": 1.2,
"alias": "lunny",
"nick_name": "lunny xiao",
"created": time.Now(),
}
for i := 0; i < b.N; i++ {
_, err = db.ExecMap("insert into user (`name`, title, age, alias, nick_name, created) "+
"values (?name,?title,?age,?alias,?nick_name,?created)",
&mp)
if err != nil {
b.Error(err)
}
}
}
func TestExecMap(t *testing.T) {
db, err := testOpen()
if err != nil {
t.Error(err)
}
defer db.Close()
_, err = db.Exec(createTableSql)
if err != nil {
t.Error(err)
}
mp := map[string]interface{}{
"name": "xlw",
"title": "tester",
"age": 1.2,
"alias": "lunny",
"nick_name": "lunny xiao",
"created": time.Now(),
}
_, err = db.ExecMap("insert into user (`name`, title, age, alias, nick_name,created) "+
"values (?name,?title,?age,?alias,?nick_name,?created)",
&mp)
if err != nil {
t.Error(err)
}
rows, err := db.Query("select * from user")
if err != nil {
t.Error(err)
}
for rows.Next() {
var user User
err = rows.ScanStructByName(&user)
if err != nil {
t.Error(err)
}
t.Log("--", user)
}
}
func TestExecStruct(t *testing.T) {
db, err := testOpen()
if err != nil {
t.Error(err)
}
defer db.Close()
_, err = db.Exec(createTableSql)
if err != nil {
t.Error(err)
}
user := User{Name: "xlw",
Title: "tester",
Age: 1.2,
Alias: "lunny",
NickName: "lunny xiao",
Created: NullTime(time.Now()),
}
_, err = db.ExecStruct("insert into user (`name`, title, age, alias, nick_name,created) "+
"values (?Name,?Title,?Age,?Alias,?NickName,?Created)",
&user)
if err != nil {
t.Error(err)
}
rows, err := db.QueryStruct("select * from user where `name` = ?Name", &user)
if err != nil {
t.Error(err)
}
for rows.Next() {
var user User
err = rows.ScanStructByName(&user)
if err != nil {
t.Error(err)
}
t.Log("1--", user)
}
}
func BenchmarkExecStruct(b *testing.B) {
b.StopTimer()
db, err := testOpen()
if err != nil {
b.Error(err)
}
defer db.Close()
_, err = db.Exec(createTableSql)
if err != nil {
b.Error(err)
}
b.StartTimer()
user := User{Name: "xlw",
Title: "tester",
Age: 1.2,
Alias: "lunny",
NickName: "lunny xiao",
Created: NullTime(time.Now()),
}
for i := 0; i < b.N; i++ {
_, err = db.ExecStruct("insert into user (`name`, title, age, alias, nick_name,created) "+
"values (?Name,?Title,?Age,?Alias,?NickName,?Created)",
&user)
if err != nil {
b.Error(err)
}
}
}

14
core/error.go Normal file
View File

@ -0,0 +1,14 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package core
import "errors"
var (
// ErrNoMapPointer represents error when no map pointer
ErrNoMapPointer = errors.New("mp should be a map's pointer")
// ErrNoStructPointer represents error when no struct pointer
ErrNoStructPointer = errors.New("mp should be a struct's pointer")
)

338
core/rows.go Normal file
View File

@ -0,0 +1,338 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package core
import (
"database/sql"
"errors"
"reflect"
"sync"
)
type Rows struct {
*sql.Rows
db *DB
}
func (rs *Rows) ToMapString() ([]map[string]string, error) {
cols, err := rs.Columns()
if err != nil {
return nil, err
}
var results = make([]map[string]string, 0, 10)
for rs.Next() {
var record = make(map[string]string, len(cols))
err = rs.ScanMap(&record)
if err != nil {
return nil, err
}
results = append(results, record)
}
return results, nil
}
// scan data to a struct's pointer according field index
func (rs *Rows) ScanStructByIndex(dest ...interface{}) error {
if len(dest) == 0 {
return errors.New("at least one struct")
}
vvvs := make([]reflect.Value, len(dest))
for i, s := range dest {
vv := reflect.ValueOf(s)
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
return errors.New("dest should be a struct's pointer")
}
vvvs[i] = vv.Elem()
}
cols, err := rs.Columns()
if err != nil {
return err
}
newDest := make([]interface{}, len(cols))
var i = 0
for _, vvv := range vvvs {
for j := 0; j < vvv.NumField(); j++ {
newDest[i] = vvv.Field(j).Addr().Interface()
i = i + 1
}
}
return rs.Rows.Scan(newDest...)
}
var (
fieldCache = make(map[reflect.Type]map[string]int)
fieldCacheMutex sync.RWMutex
)
func fieldByName(v reflect.Value, name string) reflect.Value {
t := v.Type()
fieldCacheMutex.RLock()
cache, ok := fieldCache[t]
fieldCacheMutex.RUnlock()
if !ok {
cache = make(map[string]int)
for i := 0; i < v.NumField(); i++ {
cache[t.Field(i).Name] = i
}
fieldCacheMutex.Lock()
fieldCache[t] = cache
fieldCacheMutex.Unlock()
}
if i, ok := cache[name]; ok {
return v.Field(i)
}
return reflect.Zero(t)
}
// scan data to a struct's pointer according field name
func (rs *Rows) ScanStructByName(dest interface{}) error {
vv := reflect.ValueOf(dest)
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
return errors.New("dest should be a struct's pointer")
}
cols, err := rs.Columns()
if err != nil {
return err
}
newDest := make([]interface{}, len(cols))
var v EmptyScanner
for j, name := range cols {
f := fieldByName(vv.Elem(), rs.db.Mapper.Table2Obj(name))
if f.IsValid() {
newDest[j] = f.Addr().Interface()
} else {
newDest[j] = &v
}
}
return rs.Rows.Scan(newDest...)
}
// scan data to a slice's pointer, slice's length should equal to columns' number
func (rs *Rows) ScanSlice(dest interface{}) error {
vv := reflect.ValueOf(dest)
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Slice {
return errors.New("dest should be a slice's pointer")
}
vvv := vv.Elem()
cols, err := rs.Columns()
if err != nil {
return err
}
newDest := make([]interface{}, len(cols))
for j := 0; j < len(cols); j++ {
if j >= vvv.Len() {
newDest[j] = reflect.New(vvv.Type().Elem()).Interface()
} else {
newDest[j] = vvv.Index(j).Addr().Interface()
}
}
err = rs.Rows.Scan(newDest...)
if err != nil {
return err
}
srcLen := vvv.Len()
for i := srcLen; i < len(cols); i++ {
vvv = reflect.Append(vvv, reflect.ValueOf(newDest[i]).Elem())
}
return nil
}
// scan data to a map's pointer
func (rs *Rows) ScanMap(dest interface{}) error {
vv := reflect.ValueOf(dest)
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
return errors.New("dest should be a map's pointer")
}
cols, err := rs.Columns()
if err != nil {
return err
}
newDest := make([]interface{}, len(cols))
vvv := vv.Elem()
for i := range cols {
newDest[i] = rs.db.reflectNew(vvv.Type().Elem()).Interface()
}
err = rs.Rows.Scan(newDest...)
if err != nil {
return err
}
for i, name := range cols {
vname := reflect.ValueOf(name)
vvv.SetMapIndex(vname, reflect.ValueOf(newDest[i]).Elem())
}
return nil
}
type Row struct {
rows *Rows
// One of these two will be non-nil:
err error // deferred error for easy chaining
}
// ErrorRow return an error row
func ErrorRow(err error) *Row {
return &Row{
err: err,
}
}
// NewRow from rows
func NewRow(rows *Rows, err error) *Row {
return &Row{rows, err}
}
func (row *Row) Columns() ([]string, error) {
if row.err != nil {
return nil, row.err
}
return row.rows.Columns()
}
func (row *Row) Scan(dest ...interface{}) error {
if row.err != nil {
return row.err
}
defer row.rows.Close()
for _, dp := range dest {
if _, ok := dp.(*sql.RawBytes); ok {
return errors.New("sql: RawBytes isn't allowed on Row.Scan")
}
}
if !row.rows.Next() {
if err := row.rows.Err(); err != nil {
return err
}
return sql.ErrNoRows
}
err := row.rows.Scan(dest...)
if err != nil {
return err
}
// Make sure the query can be processed to completion with no errors.
return row.rows.Close()
}
func (row *Row) ScanStructByName(dest interface{}) error {
if row.err != nil {
return row.err
}
defer row.rows.Close()
if !row.rows.Next() {
if err := row.rows.Err(); err != nil {
return err
}
return sql.ErrNoRows
}
err := row.rows.ScanStructByName(dest)
if err != nil {
return err
}
// Make sure the query can be processed to completion with no errors.
return row.rows.Close()
}
func (row *Row) ScanStructByIndex(dest interface{}) error {
if row.err != nil {
return row.err
}
defer row.rows.Close()
if !row.rows.Next() {
if err := row.rows.Err(); err != nil {
return err
}
return sql.ErrNoRows
}
err := row.rows.ScanStructByIndex(dest)
if err != nil {
return err
}
// Make sure the query can be processed to completion with no errors.
return row.rows.Close()
}
// scan data to a slice's pointer, slice's length should equal to columns' number
func (row *Row) ScanSlice(dest interface{}) error {
if row.err != nil {
return row.err
}
defer row.rows.Close()
if !row.rows.Next() {
if err := row.rows.Err(); err != nil {
return err
}
return sql.ErrNoRows
}
err := row.rows.ScanSlice(dest)
if err != nil {
return err
}
// Make sure the query can be processed to completion with no errors.
return row.rows.Close()
}
// scan data to a map's pointer
func (row *Row) ScanMap(dest interface{}) error {
if row.err != nil {
return row.err
}
defer row.rows.Close()
if !row.rows.Next() {
if err := row.rows.Err(); err != nil {
return err
}
return sql.ErrNoRows
}
err := row.rows.ScanMap(dest)
if err != nil {
return err
}
// Make sure the query can be processed to completion with no errors.
return row.rows.Close()
}
func (row *Row) ToMapString() (map[string]string, error) {
cols, err := row.Columns()
if err != nil {
return nil, err
}
var record = make(map[string]string, len(cols))
err = row.ScanMap(&record)
if err != nil {
return nil, err
}
return record, nil
}

66
core/scan.go Normal file
View File

@ -0,0 +1,66 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package core
import (
"database/sql/driver"
"fmt"
"time"
)
type NullTime time.Time
var (
_ driver.Valuer = NullTime{}
)
func (ns *NullTime) Scan(value interface{}) error {
if value == nil {
return nil
}
return convertTime(ns, value)
}
// Value implements the driver Valuer interface.
func (ns NullTime) Value() (driver.Value, error) {
if (time.Time)(ns).IsZero() {
return nil, nil
}
return (time.Time)(ns).Format("2006-01-02 15:04:05"), nil
}
func convertTime(dest *NullTime, src interface{}) error {
// Common cases, without reflect.
switch s := src.(type) {
case string:
t, err := time.Parse("2006-01-02 15:04:05", s)
if err != nil {
return err
}
*dest = NullTime(t)
return nil
case []uint8:
t, err := time.Parse("2006-01-02 15:04:05", string(s))
if err != nil {
return err
}
*dest = NullTime(t)
return nil
case time.Time:
*dest = NullTime(s)
return nil
case nil:
default:
return fmt.Errorf("unsupported driver -> Scan pair: %T -> %T", src, dest)
}
return nil
}
type EmptyScanner struct {
}
func (EmptyScanner) Scan(src interface{}) error {
return nil
}

166
core/stmt.go Normal file
View File

@ -0,0 +1,166 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package core
import (
"context"
"database/sql"
"errors"
"reflect"
)
// Stmt reprents a stmt objects
type Stmt struct {
*sql.Stmt
db *DB
names map[string]int
}
func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
names := make(map[string]int)
var i int
query = re.ReplaceAllStringFunc(query, func(src string) string {
names[src[1:]] = i
i += 1
return "?"
})
stmt, err := db.DB.PrepareContext(ctx, query)
if err != nil {
return nil, err
}
return &Stmt{stmt, db, names}, nil
}
func (db *DB) Prepare(query string) (*Stmt, error) {
return db.PrepareContext(context.Background(), query)
}
func (s *Stmt) ExecMapContext(ctx context.Context, mp interface{}) (sql.Result, error) {
vv := reflect.ValueOf(mp)
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
return nil, errors.New("mp should be a map's pointer")
}
args := make([]interface{}, len(s.names))
for k, i := range s.names {
args[i] = vv.Elem().MapIndex(reflect.ValueOf(k)).Interface()
}
return s.Stmt.ExecContext(ctx, args...)
}
func (s *Stmt) ExecMap(mp interface{}) (sql.Result, error) {
return s.ExecMapContext(context.Background(), mp)
}
func (s *Stmt) ExecStructContext(ctx context.Context, st interface{}) (sql.Result, error) {
vv := reflect.ValueOf(st)
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
return nil, errors.New("mp should be a map's pointer")
}
args := make([]interface{}, len(s.names))
for k, i := range s.names {
args[i] = vv.Elem().FieldByName(k).Interface()
}
return s.Stmt.ExecContext(ctx, args...)
}
func (s *Stmt) ExecStruct(st interface{}) (sql.Result, error) {
return s.ExecStructContext(context.Background(), st)
}
func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, error) {
rows, err := s.Stmt.QueryContext(ctx, args...)
if err != nil {
return nil, err
}
return &Rows{rows, s.db}, nil
}
func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
return s.QueryContext(context.Background(), args...)
}
func (s *Stmt) QueryMapContext(ctx context.Context, mp interface{}) (*Rows, error) {
vv := reflect.ValueOf(mp)
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
return nil, errors.New("mp should be a map's pointer")
}
args := make([]interface{}, len(s.names))
for k, i := range s.names {
args[i] = vv.Elem().MapIndex(reflect.ValueOf(k)).Interface()
}
return s.QueryContext(ctx, args...)
}
func (s *Stmt) QueryMap(mp interface{}) (*Rows, error) {
return s.QueryMapContext(context.Background(), mp)
}
func (s *Stmt) QueryStructContext(ctx context.Context, st interface{}) (*Rows, error) {
vv := reflect.ValueOf(st)
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
return nil, errors.New("mp should be a map's pointer")
}
args := make([]interface{}, len(s.names))
for k, i := range s.names {
args[i] = vv.Elem().FieldByName(k).Interface()
}
return s.Query(args...)
}
func (s *Stmt) QueryStruct(st interface{}) (*Rows, error) {
return s.QueryStructContext(context.Background(), st)
}
func (s *Stmt) QueryRowContext(ctx context.Context, args ...interface{}) *Row {
rows, err := s.QueryContext(ctx, args...)
return &Row{rows, err}
}
func (s *Stmt) QueryRow(args ...interface{}) *Row {
return s.QueryRowContext(context.Background(), args...)
}
func (s *Stmt) QueryRowMapContext(ctx context.Context, mp interface{}) *Row {
vv := reflect.ValueOf(mp)
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
return &Row{nil, errors.New("mp should be a map's pointer")}
}
args := make([]interface{}, len(s.names))
for k, i := range s.names {
args[i] = vv.Elem().MapIndex(reflect.ValueOf(k)).Interface()
}
return s.QueryRowContext(ctx, args...)
}
func (s *Stmt) QueryRowMap(mp interface{}) *Row {
return s.QueryRowMapContext(context.Background(), mp)
}
func (s *Stmt) QueryRowStructContext(ctx context.Context, st interface{}) *Row {
vv := reflect.ValueOf(st)
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
return &Row{nil, errors.New("st should be a struct's pointer")}
}
args := make([]interface{}, len(s.names))
for k, i := range s.names {
args[i] = vv.Elem().FieldByName(k).Interface()
}
return s.QueryRowContext(ctx, args...)
}
func (s *Stmt) QueryRowStruct(st interface{}) *Row {
return s.QueryRowStructContext(context.Background(), st)
}

153
core/tx.go Normal file
View File

@ -0,0 +1,153 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package core
import (
"context"
"database/sql"
)
type Tx struct {
*sql.Tx
db *DB
}
func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) {
tx, err := db.DB.BeginTx(ctx, opts)
if err != nil {
return nil, err
}
return &Tx{tx, db}, nil
}
func (db *DB) Begin() (*Tx, error) {
tx, err := db.DB.Begin()
if err != nil {
return nil, err
}
return &Tx{tx, db}, nil
}
func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
names := make(map[string]int)
var i int
query = re.ReplaceAllStringFunc(query, func(src string) string {
names[src[1:]] = i
i += 1
return "?"
})
stmt, err := tx.Tx.PrepareContext(ctx, query)
if err != nil {
return nil, err
}
return &Stmt{stmt, tx.db, names}, nil
}
func (tx *Tx) Prepare(query string) (*Stmt, error) {
return tx.PrepareContext(context.Background(), query)
}
func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
stmt.Stmt = tx.Tx.StmtContext(ctx, stmt.Stmt)
return stmt
}
func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
return tx.StmtContext(context.Background(), stmt)
}
func (tx *Tx) ExecMapContext(ctx context.Context, query string, mp interface{}) (sql.Result, error) {
query, args, err := MapToSlice(query, mp)
if err != nil {
return nil, err
}
return tx.Tx.ExecContext(ctx, query, args...)
}
func (tx *Tx) ExecMap(query string, mp interface{}) (sql.Result, error) {
return tx.ExecMapContext(context.Background(), query, mp)
}
func (tx *Tx) ExecStructContext(ctx context.Context, query string, st interface{}) (sql.Result, error) {
query, args, err := StructToSlice(query, st)
if err != nil {
return nil, err
}
return tx.Tx.ExecContext(ctx, query, args...)
}
func (tx *Tx) ExecStruct(query string, st interface{}) (sql.Result, error) {
return tx.ExecStructContext(context.Background(), query, st)
}
func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
rows, err := tx.Tx.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
return &Rows{rows, tx.db}, nil
}
func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) {
return tx.QueryContext(context.Background(), query, args...)
}
func (tx *Tx) QueryMapContext(ctx context.Context, query string, mp interface{}) (*Rows, error) {
query, args, err := MapToSlice(query, mp)
if err != nil {
return nil, err
}
return tx.QueryContext(ctx, query, args...)
}
func (tx *Tx) QueryMap(query string, mp interface{}) (*Rows, error) {
return tx.QueryMapContext(context.Background(), query, mp)
}
func (tx *Tx) QueryStructContext(ctx context.Context, query string, st interface{}) (*Rows, error) {
query, args, err := StructToSlice(query, st)
if err != nil {
return nil, err
}
return tx.QueryContext(ctx, query, args...)
}
func (tx *Tx) QueryStruct(query string, st interface{}) (*Rows, error) {
return tx.QueryStructContext(context.Background(), query, st)
}
func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row {
rows, err := tx.QueryContext(ctx, query, args...)
return &Row{rows, err}
}
func (tx *Tx) QueryRow(query string, args ...interface{}) *Row {
return tx.QueryRowContext(context.Background(), query, args...)
}
func (tx *Tx) QueryRowMapContext(ctx context.Context, query string, mp interface{}) *Row {
query, args, err := MapToSlice(query, mp)
if err != nil {
return &Row{nil, err}
}
return tx.QueryRowContext(ctx, query, args...)
}
func (tx *Tx) QueryRowMap(query string, mp interface{}) *Row {
return tx.QueryRowMapContext(context.Background(), query, mp)
}
func (tx *Tx) QueryRowStructContext(ctx context.Context, query string, st interface{}) *Row {
query, args, err := StructToSlice(query, st)
if err != nil {
return &Row{nil, err}
}
return tx.QueryRowContext(ctx, query, args...)
}
func (tx *Tx) QueryRowStruct(query string, st interface{}) *Row {
return tx.QueryRowStructContext(context.Background(), query, st)
}

410
dialects/dialect.go Normal file
View File

@ -0,0 +1,410 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package dialects
import (
"fmt"
"strings"
"time"
"xorm.io/xorm/core"
"xorm.io/xorm/log"
"xorm.io/xorm/schemas"
)
type DBType string
type URI struct {
DBType DBType
Proto string
Host string
Port string
DBName string
User string
Passwd string
Charset string
Laddr string
Raddr string
Timeout time.Duration
Schema string
}
// a dialect is a driver's wrapper
type Dialect interface {
SetLogger(logger log.Logger)
Init(*core.DB, *URI, string, string) error
URI() *URI
DB() *core.DB
DBType() DBType
SQLType(*schemas.Column) string
FormatBytes(b []byte) string
DriverName() string
DataSourceName() string
IsReserved(string) bool
Quote(string) string
AndStr() string
OrStr() string
EqStr() string
RollBackStr() string
AutoIncrStr() string
SupportInsertMany() bool
SupportEngine() bool
SupportCharset() bool
SupportDropIfExists() bool
IndexOnTable() bool
ShowCreateNull() bool
IndexCheckSQL(tableName, idxName string) (string, []interface{})
TableCheckSQL(tableName string) (string, []interface{})
IsColumnExist(tableName string, colName string) (bool, error)
CreateTableSQL(table *schemas.Table, tableName, storeEngine, charset string) string
DropTableSQL(tableName string) string
CreateIndexSQL(tableName string, index *schemas.Index) string
DropIndexSQL(tableName string, index *schemas.Index) string
ModifyColumnSQL(tableName string, col *schemas.Column) string
ForUpdateSQL(query string) string
// CreateTableIfNotExists(table *Table, tableName, storeEngine, charset string) error
// MustDropTable(tableName string) error
GetColumns(tableName string) ([]string, map[string]*schemas.Column, error)
GetTables() ([]*schemas.Table, error)
GetIndexes(tableName string) (map[string]*schemas.Index, error)
Filters() []Filter
SetParams(params map[string]string)
}
func OpenDialect(dialect Dialect) (*core.DB, error) {
return core.Open(dialect.DriverName(), dialect.DataSourceName())
}
// Base represents a basic dialect and all real dialects could embed this struct
type Base struct {
db *core.DB
dialect Dialect
driverName string
dataSourceName string
logger log.Logger
uri *URI
}
// String generate column description string according dialect
func String(d Dialect, col *schemas.Column) string {
sql := d.Quote(col.Name) + " "
sql += d.SQLType(col) + " "
if col.IsPrimaryKey {
sql += "PRIMARY KEY "
if col.IsAutoIncrement {
sql += d.AutoIncrStr() + " "
}
}
if col.Default != "" {
sql += "DEFAULT " + col.Default + " "
}
if d.ShowCreateNull() {
if col.Nullable {
sql += "NULL "
} else {
sql += "NOT NULL "
}
}
return sql
}
// StringNoPk generate column description string according dialect without primary keys
func StringNoPk(d Dialect, col *schemas.Column) string {
sql := d.Quote(col.Name) + " "
sql += d.SQLType(col) + " "
if col.Default != "" {
sql += "DEFAULT " + col.Default + " "
}
if d.ShowCreateNull() {
if col.Nullable {
sql += "NULL "
} else {
sql += "NOT NULL "
}
}
return sql
}
func (b *Base) DB() *core.DB {
return b.db
}
func (b *Base) SetLogger(logger log.Logger) {
b.logger = logger
}
func (b *Base) Init(db *core.DB, dialect Dialect, uri *URI, drivername, dataSourceName string) error {
b.db, b.dialect, b.uri = db, dialect, uri
b.driverName, b.dataSourceName = drivername, dataSourceName
return nil
}
func (b *Base) URI() *URI {
return b.uri
}
func (b *Base) DBType() DBType {
return b.uri.DBType
}
func (b *Base) FormatBytes(bs []byte) string {
return fmt.Sprintf("0x%x", bs)
}
func (b *Base) DriverName() string {
return b.driverName
}
func (b *Base) ShowCreateNull() bool {
return true
}
func (b *Base) DataSourceName() string {
return b.dataSourceName
}
func (b *Base) AndStr() string {
return "AND"
}
func (b *Base) OrStr() string {
return "OR"
}
func (b *Base) EqStr() string {
return "="
}
func (db *Base) RollBackStr() string {
return "ROLL BACK"
}
func (db *Base) SupportDropIfExists() bool {
return true
}
func (db *Base) DropTableSQL(tableName string) string {
quote := db.dialect.Quote
return fmt.Sprintf("DROP TABLE IF EXISTS %s", quote(tableName))
}
func (db *Base) HasRecords(query string, args ...interface{}) (bool, error) {
db.LogSQL(query, args)
rows, err := db.DB().Query(query, args...)
if err != nil {
return false, err
}
defer rows.Close()
if rows.Next() {
return true, nil
}
return false, nil
}
func (db *Base) IsColumnExist(tableName, colName string) (bool, error) {
query := fmt.Sprintf(
"SELECT %v FROM %v.%v WHERE %v = ? AND %v = ? AND %v = ?",
db.dialect.Quote("COLUMN_NAME"),
db.dialect.Quote("INFORMATION_SCHEMA"),
db.dialect.Quote("COLUMNS"),
db.dialect.Quote("TABLE_SCHEMA"),
db.dialect.Quote("TABLE_NAME"),
db.dialect.Quote("COLUMN_NAME"),
)
return db.HasRecords(query, db.uri.DBName, tableName, colName)
}
/*
func (db *Base) CreateTableIfNotExists(table *Table, tableName, storeEngine, charset string) error {
sql, args := db.dialect.TableCheckSQL(tableName)
rows, err := db.DB().Query(sql, args...)
if db.Logger != nil {
db.Logger.Info("[sql]", sql, args)
}
if err != nil {
return err
}
defer rows.Close()
if rows.Next() {
return nil
}
sql = db.dialect.CreateTableSQL(table, tableName, storeEngine, charset)
_, err = db.DB().Exec(sql)
if db.Logger != nil {
db.Logger.Info("[sql]", sql)
}
return err
}*/
func (db *Base) CreateIndexSQL(tableName string, index *schemas.Index) string {
quote := db.dialect.Quote
var unique string
var idxName string
if index.Type == schemas.UniqueType {
unique = " UNIQUE"
}
idxName = index.XName(tableName)
return fmt.Sprintf("CREATE%s INDEX %v ON %v (%v)", unique,
quote(idxName), quote(tableName),
quote(strings.Join(index.Cols, quote(","))))
}
func (db *Base) DropIndexSQL(tableName string, index *schemas.Index) string {
quote := db.dialect.Quote
var name string
if index.IsRegular {
name = index.XName(tableName)
} else {
name = index.Name
}
return fmt.Sprintf("DROP INDEX %v ON %s", quote(name), quote(tableName))
}
func (db *Base) ModifyColumnSQL(tableName string, col *schemas.Column) string {
return fmt.Sprintf("alter table %s MODIFY COLUMN %s", tableName, StringNoPk(db.dialect, col))
}
func (b *Base) CreateTableSQL(table *schemas.Table, tableName, storeEngine, charset string) string {
var sql string
sql = "CREATE TABLE IF NOT EXISTS "
if tableName == "" {
tableName = table.Name
}
sql += b.dialect.Quote(tableName)
sql += " ("
if len(table.ColumnsSeq()) > 0 {
pkList := table.PrimaryKeys
for _, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName)
if col.IsPrimaryKey && len(pkList) == 1 {
sql += String(b.dialect, col)
} else {
sql += StringNoPk(b.dialect, col)
}
sql = strings.TrimSpace(sql)
if b.DriverName() == schemas.MYSQL && len(col.Comment) > 0 {
sql += " COMMENT '" + col.Comment + "'"
}
sql += ", "
}
if len(pkList) > 1 {
sql += "PRIMARY KEY ( "
sql += b.dialect.Quote(strings.Join(pkList, b.dialect.Quote(",")))
sql += " ), "
}
sql = sql[:len(sql)-2]
}
sql += ")"
if b.dialect.SupportEngine() && storeEngine != "" {
sql += " ENGINE=" + storeEngine
}
if b.dialect.SupportCharset() {
if len(charset) == 0 {
charset = b.dialect.URI().Charset
}
if len(charset) > 0 {
sql += " DEFAULT CHARSET " + charset
}
}
return sql
}
func (b *Base) ForUpdateSQL(query string) string {
return query + " FOR UPDATE"
}
func (b *Base) LogSQL(sql string, args []interface{}) {
if b.logger != nil && b.logger.IsShowSQL() {
if len(args) > 0 {
b.logger.Infof("[SQL] %v %v", sql, args)
} else {
b.logger.Infof("[SQL] %v", sql)
}
}
}
func (b *Base) SetParams(params map[string]string) {
}
var (
dialects = map[string]func() Dialect{}
)
// RegisterDialect register database dialect
func RegisterDialect(dbName DBType, dialectFunc func() Dialect) {
if dialectFunc == nil {
panic("core: Register dialect is nil")
}
dialects[strings.ToLower(string(dbName))] = dialectFunc // !nashtsai! allow override dialect
}
// QueryDialect query if registered database dialect
func QueryDialect(dbName DBType) Dialect {
if d, ok := dialects[strings.ToLower(string(dbName))]; ok {
return d()
}
return nil
}
func regDrvsNDialects() bool {
providedDrvsNDialects := map[string]struct {
dbType DBType
getDriver func() Driver
getDialect func() Dialect
}{
"mssql": {"mssql", func() Driver { return &odbcDriver{} }, func() Dialect { return &mssql{} }},
"odbc": {"mssql", func() Driver { return &odbcDriver{} }, func() Dialect { return &mssql{} }}, // !nashtsai! TODO change this when supporting MS Access
"mysql": {"mysql", func() Driver { return &mysqlDriver{} }, func() Dialect { return &mysql{} }},
"mymysql": {"mysql", func() Driver { return &mymysqlDriver{} }, func() Dialect { return &mysql{} }},
"postgres": {"postgres", func() Driver { return &pqDriver{} }, func() Dialect { return &postgres{} }},
"pgx": {"postgres", func() Driver { return &pqDriverPgx{} }, func() Dialect { return &postgres{} }},
"sqlite3": {"sqlite3", func() Driver { return &sqlite3Driver{} }, func() Dialect { return &sqlite3{} }},
"oci8": {"oracle", func() Driver { return &oci8Driver{} }, func() Dialect { return &oracle{} }},
"goracle": {"oracle", func() Driver { return &goracleDriver{} }, func() Dialect { return &oracle{} }},
}
for driverName, v := range providedDrvsNDialects {
if driver := QueryDriver(driverName); driver == nil {
RegisterDriver(driverName, v.getDriver())
RegisterDialect(v.dbType, v.getDialect)
}
}
return true
}
func init() {
regDrvsNDialects()
}

31
dialects/driver.go Normal file
View File

@ -0,0 +1,31 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package dialects
type Driver interface {
Parse(string, string) (*URI, error)
}
var (
drivers = map[string]Driver{}
)
func RegisterDriver(driverName string, driver Driver) {
if driver == nil {
panic("core: Register driver is nil")
}
if _, dup := drivers[driverName]; dup {
panic("core: Register called twice for driver " + driverName)
}
drivers[driverName] = driver
}
func QueryDriver(driverName string) Driver {
return drivers[driverName]
}
func RegisteredDriverSize() int {
return len(drivers)
}

95
dialects/filter.go Normal file
View File

@ -0,0 +1,95 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package dialects
import (
"fmt"
"strings"
"xorm.io/xorm/schemas"
)
// Filter is an interface to filter SQL
type Filter interface {
Do(sql string, dialect Dialect, table *schemas.Table) string
}
// QuoteFilter filter SQL replace ` to database's own quote character
type QuoteFilter struct {
}
func (s *QuoteFilter) Do(sql string, dialect Dialect, table *schemas.Table) string {
dummy := dialect.Quote("")
if len(dummy) != 2 {
return sql
}
prefix, suffix := dummy[0], dummy[1]
raw := []byte(sql)
for i, cnt := 0, 0; i < len(raw); i = i + 1 {
if raw[i] == '`' {
if cnt%2 == 0 {
raw[i] = prefix
} else {
raw[i] = suffix
}
cnt++
}
}
return string(raw)
}
// IdFilter filter SQL replace (id) to primary key column name
type IdFilter struct {
}
type Quoter struct {
dialect Dialect
}
func NewQuoter(dialect Dialect) *Quoter {
return &Quoter{dialect}
}
func (q *Quoter) Quote(content string) string {
return q.dialect.Quote(content)
}
func (i *IdFilter) Do(sql string, dialect Dialect, table *schemas.Table) string {
quoter := NewQuoter(dialect)
if table != nil && len(table.PrimaryKeys) == 1 {
sql = strings.Replace(sql, " `(id)` ", " "+quoter.Quote(table.PrimaryKeys[0])+" ", -1)
sql = strings.Replace(sql, " "+quoter.Quote("(id)")+" ", " "+quoter.Quote(table.PrimaryKeys[0])+" ", -1)
return strings.Replace(sql, " (id) ", " "+quoter.Quote(table.PrimaryKeys[0])+" ", -1)
}
return sql
}
// SeqFilter filter SQL replace ?, ? ... to $1, $2 ...
type SeqFilter struct {
Prefix string
Start int
}
func convertQuestionMark(sql, prefix string, start int) string {
var buf strings.Builder
var beginSingleQuote bool
var index = start
for _, c := range sql {
if !beginSingleQuote && c == '?' {
buf.WriteString(fmt.Sprintf("%s%v", prefix, index))
index++
} else {
if c == '\'' {
beginSingleQuote = !beginSingleQuote
}
buf.WriteRune(c)
}
}
return buf.String()
}
func (s *SeqFilter) Do(sql string, dialect Dialect, table *schemas.Table) string {
return convertQuestionMark(sql, s.Prefix, s.Start)
}

39
dialects/filter_test.go Normal file
View File

@ -0,0 +1,39 @@
package dialects
import (
"testing"
"github.com/stretchr/testify/assert"
)
type quoterOnly struct {
Dialect
}
func (q *quoterOnly) Quote(item string) string {
return "[" + item + "]"
}
func TestQuoteFilter_Do(t *testing.T) {
f := QuoteFilter{}
sql := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?"
res := f.Do(sql, new(quoterOnly), nil)
assert.EqualValues(t,
"SELECT [COLUMN_NAME] FROM [INFORMATION_SCHEMA].[COLUMNS] WHERE [TABLE_SCHEMA] = ? AND [TABLE_NAME] = ? AND [COLUMN_NAME] = ?",
res,
)
}
func TestSeqFilter(t *testing.T) {
var kases = map[string]string{
"SELECT * FROM TABLE1 WHERE a=? AND b=?": "SELECT * FROM TABLE1 WHERE a=$1 AND b=$2",
"SELECT 1, '???', '2006-01-02 15:04:05' FROM TABLE1 WHERE a=? AND b=?": "SELECT 1, '???', '2006-01-02 15:04:05' FROM TABLE1 WHERE a=$1 AND b=$2",
"select '1''?' from issue": "select '1''?' from issue",
"select '1\\??' from issue": "select '1\\??' from issue",
"select '1\\\\',? from issue": "select '1\\\\',$1 from issue",
"select '1\\''?',? from issue": "select '1\\''?',$1 from issue",
}
for sql, result := range kases {
assert.EqualValues(t, result, convertQuestionMark(sql, "$", 1))
}
}

View File

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package xorm package dialects
import ( import (
"errors" "errors"
@ -11,7 +11,8 @@ import (
"strconv" "strconv"
"strings" "strings"
"xorm.io/core" "xorm.io/xorm/core"
"xorm.io/xorm/schemas"
) )
var ( var (
@ -205,64 +206,64 @@ var (
) )
type mssql struct { type mssql struct {
core.Base Base
} }
func (db *mssql) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error { func (db *mssql) Init(d *core.DB, uri *URI, drivername, dataSourceName string) error {
return db.Base.Init(d, db, uri, drivername, dataSourceName) return db.Base.Init(d, db, uri, drivername, dataSourceName)
} }
func (db *mssql) SqlType(c *core.Column) string { func (db *mssql) SQLType(c *schemas.Column) string {
var res string var res string
switch t := c.SQLType.Name; t { switch t := c.SQLType.Name; t {
case core.Bool: case schemas.Bool:
res = core.Bit res = schemas.Bit
if strings.EqualFold(c.Default, "true") { if strings.EqualFold(c.Default, "true") {
c.Default = "1" c.Default = "1"
} else if strings.EqualFold(c.Default, "false") { } else if strings.EqualFold(c.Default, "false") {
c.Default = "0" c.Default = "0"
} }
case core.Serial: case schemas.Serial:
c.IsAutoIncrement = true c.IsAutoIncrement = true
c.IsPrimaryKey = true c.IsPrimaryKey = true
c.Nullable = false c.Nullable = false
res = core.Int res = schemas.Int
case core.BigSerial: case schemas.BigSerial:
c.IsAutoIncrement = true c.IsAutoIncrement = true
c.IsPrimaryKey = true c.IsPrimaryKey = true
c.Nullable = false c.Nullable = false
res = core.BigInt res = schemas.BigInt
case core.Bytea, core.Blob, core.Binary, core.TinyBlob, core.MediumBlob, core.LongBlob: case schemas.Bytea, schemas.Blob, schemas.Binary, schemas.TinyBlob, schemas.MediumBlob, schemas.LongBlob:
res = core.VarBinary res = schemas.VarBinary
if c.Length == 0 { if c.Length == 0 {
c.Length = 50 c.Length = 50
} }
case core.TimeStamp: case schemas.TimeStamp:
res = core.DateTime res = schemas.DateTime
case core.TimeStampz: case schemas.TimeStampz:
res = "DATETIMEOFFSET" res = "DATETIMEOFFSET"
c.Length = 7 c.Length = 7
case core.MediumInt: case schemas.MediumInt:
res = core.Int res = schemas.Int
case core.Text, core.MediumText, core.TinyText, core.LongText, core.Json: case schemas.Text, schemas.MediumText, schemas.TinyText, schemas.LongText, schemas.Json:
res = core.Varchar + "(MAX)" res = schemas.Varchar + "(MAX)"
case core.Double: case schemas.Double:
res = core.Real res = schemas.Real
case core.Uuid: case schemas.Uuid:
res = core.Varchar res = schemas.Varchar
c.Length = 40 c.Length = 40
case core.TinyInt: case schemas.TinyInt:
res = core.TinyInt res = schemas.TinyInt
c.Length = 0 c.Length = 0
case core.BigInt: case schemas.BigInt:
res = core.BigInt res = schemas.BigInt
c.Length = 0 c.Length = 0
default: default:
res = t res = t
} }
if res == core.Int { if res == schemas.Int {
return core.Int return schemas.Int
} }
hasLen1 := (c.Length > 0) hasLen1 := (c.Length > 0)
@ -297,7 +298,7 @@ func (db *mssql) AutoIncrStr() string {
return "IDENTITY" return "IDENTITY"
} }
func (db *mssql) DropTableSql(tableName string) string { func (db *mssql) DropTableSQL(tableName string) string {
return fmt.Sprintf("IF EXISTS (SELECT * FROM sysobjects WHERE id = "+ return fmt.Sprintf("IF EXISTS (SELECT * FROM sysobjects WHERE id = "+
"object_id(N'%s') and OBJECTPROPERTY(id, N'IsUserTable') = 1) "+ "object_id(N'%s') and OBJECTPROPERTY(id, N'IsUserTable') = 1) "+
"DROP TABLE \"%s\"", tableName, tableName) "DROP TABLE \"%s\"", tableName, tableName)
@ -311,7 +312,7 @@ func (db *mssql) IndexOnTable() bool {
return true return true
} }
func (db *mssql) IndexCheckSql(tableName, idxName string) (string, []interface{}) { func (db *mssql) IndexCheckSQL(tableName, idxName string) (string, []interface{}) {
args := []interface{}{idxName} args := []interface{}{idxName}
sql := "select name from sysindexes where id=object_id('" + tableName + "') and name=?" sql := "select name from sysindexes where id=object_id('" + tableName + "') and name=?"
return sql, args return sql, args
@ -329,13 +330,13 @@ func (db *mssql) IsColumnExist(tableName, colName string) (bool, error) {
return db.HasRecords(query, tableName, colName) return db.HasRecords(query, tableName, colName)
} }
func (db *mssql) TableCheckSql(tableName string) (string, []interface{}) { func (db *mssql) TableCheckSQL(tableName string) (string, []interface{}) {
args := []interface{}{} args := []interface{}{}
sql := "select * from sysobjects where id = object_id(N'" + tableName + "') and OBJECTPROPERTY(id, N'IsUserTable') = 1" sql := "select * from sysobjects where id = object_id(N'" + tableName + "') and OBJECTPROPERTY(id, N'IsUserTable') = 1"
return sql, args return sql, args
} }
func (db *mssql) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { func (db *mssql) GetColumns(tableName string) ([]string, map[string]*schemas.Column, error) {
args := []interface{}{} args := []interface{}{}
s := `select a.name as name, b.name as ctype,a.max_length,a.precision,a.scale,a.is_nullable as nullable, s := `select a.name as name, b.name as ctype,a.max_length,a.precision,a.scale,a.is_nullable as nullable,
"default_is_null" = (CASE WHEN c.text is null THEN 1 ELSE 0 END), "default_is_null" = (CASE WHEN c.text is null THEN 1 ELSE 0 END),
@ -357,7 +358,7 @@ func (db *mssql) GetColumns(tableName string) ([]string, map[string]*core.Column
} }
defer rows.Close() defer rows.Close()
cols := make(map[string]*core.Column) cols := make(map[string]*schemas.Column)
colSeq := make([]string, 0) colSeq := make([]string, 0)
for rows.Next() { for rows.Next() {
var name, ctype, vdefault string var name, ctype, vdefault string
@ -368,7 +369,7 @@ func (db *mssql) GetColumns(tableName string) ([]string, map[string]*core.Column
return nil, nil, err return nil, nil, err
} }
col := new(core.Column) col := new(schemas.Column)
col.Indexes = make(map[string]int) col.Indexes = make(map[string]int)
col.Name = strings.Trim(name, "` ") col.Name = strings.Trim(name, "` ")
col.Nullable = nullable col.Nullable = nullable
@ -387,14 +388,14 @@ func (db *mssql) GetColumns(tableName string) ([]string, map[string]*core.Column
} }
switch ct { switch ct {
case "DATETIMEOFFSET": case "DATETIMEOFFSET":
col.SQLType = core.SQLType{Name: core.TimeStampz, DefaultLength: 0, DefaultLength2: 0} col.SQLType = schemas.SQLType{Name: schemas.TimeStampz, DefaultLength: 0, DefaultLength2: 0}
case "NVARCHAR": case "NVARCHAR":
col.SQLType = core.SQLType{Name: core.NVarchar, DefaultLength: 0, DefaultLength2: 0} col.SQLType = schemas.SQLType{Name: schemas.NVarchar, DefaultLength: 0, DefaultLength2: 0}
case "IMAGE": case "IMAGE":
col.SQLType = core.SQLType{Name: core.VarBinary, DefaultLength: 0, DefaultLength2: 0} col.SQLType = schemas.SQLType{Name: schemas.VarBinary, DefaultLength: 0, DefaultLength2: 0}
default: default:
if _, ok := core.SqlTypes[ct]; ok { if _, ok := schemas.SqlTypes[ct]; ok {
col.SQLType = core.SQLType{Name: ct, DefaultLength: 0, DefaultLength2: 0} col.SQLType = schemas.SQLType{Name: ct, DefaultLength: 0, DefaultLength2: 0}
} else { } else {
return nil, nil, fmt.Errorf("Unknown colType %v for %v - %v", ct, tableName, col.Name) return nil, nil, fmt.Errorf("Unknown colType %v for %v - %v", ct, tableName, col.Name)
} }
@ -406,7 +407,7 @@ func (db *mssql) GetColumns(tableName string) ([]string, map[string]*core.Column
return colSeq, cols, nil return colSeq, cols, nil
} }
func (db *mssql) GetTables() ([]*core.Table, error) { func (db *mssql) GetTables() ([]*schemas.Table, error) {
args := []interface{}{} args := []interface{}{}
s := `select name from sysobjects where xtype ='U'` s := `select name from sysobjects where xtype ='U'`
db.LogSQL(s, args) db.LogSQL(s, args)
@ -417,9 +418,9 @@ func (db *mssql) GetTables() ([]*core.Table, error) {
} }
defer rows.Close() defer rows.Close()
tables := make([]*core.Table, 0) tables := make([]*schemas.Table, 0)
for rows.Next() { for rows.Next() {
table := core.NewEmptyTable() table := schemas.NewEmptyTable()
var name string var name string
err = rows.Scan(&name) err = rows.Scan(&name)
if err != nil { if err != nil {
@ -431,7 +432,7 @@ func (db *mssql) GetTables() ([]*core.Table, error) {
return tables, nil return tables, nil
} }
func (db *mssql) GetIndexes(tableName string) (map[string]*core.Index, error) { func (db *mssql) GetIndexes(tableName string) (map[string]*schemas.Index, error) {
args := []interface{}{tableName} args := []interface{}{tableName}
s := `SELECT s := `SELECT
IXS.NAME AS [INDEX_NAME], IXS.NAME AS [INDEX_NAME],
@ -452,7 +453,7 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =?
} }
defer rows.Close() defer rows.Close()
indexes := make(map[string]*core.Index, 0) indexes := make(map[string]*schemas.Index, 0)
for rows.Next() { for rows.Next() {
var indexType int var indexType int
var indexName, colName, isUnique string var indexName, colName, isUnique string
@ -468,9 +469,9 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =?
} }
if i { if i {
indexType = core.UniqueType indexType = schemas.UniqueType
} else { } else {
indexType = core.IndexType indexType = schemas.IndexType
} }
colName = strings.Trim(colName, "` ") colName = strings.Trim(colName, "` ")
@ -480,10 +481,10 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =?
isRegular = true isRegular = true
} }
var index *core.Index var index *schemas.Index
var ok bool var ok bool
if index, ok = indexes[indexName]; !ok { if index, ok = indexes[indexName]; !ok {
index = new(core.Index) index = new(schemas.Index)
index.Type = indexType index.Type = indexType
index.Name = indexName index.Name = indexName
index.IsRegular = isRegular index.IsRegular = isRegular
@ -494,7 +495,7 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =?
return indexes, nil return indexes, nil
} }
func (db *mssql) CreateTableSql(table *core.Table, tableName, storeEngine, charset string) string { func (db *mssql) CreateTableSQL(table *schemas.Table, tableName, storeEngine, charset string) string {
var sql string var sql string
if tableName == "" { if tableName == "" {
tableName = table.Name tableName = table.Name
@ -509,9 +510,9 @@ func (db *mssql) CreateTableSql(table *core.Table, tableName, storeEngine, chars
for _, colName := range table.ColumnsSeq() { for _, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName) col := table.GetColumn(colName)
if col.IsPrimaryKey && len(pkList) == 1 { if col.IsPrimaryKey && len(pkList) == 1 {
sql += col.String(db) sql += String(db, col)
} else { } else {
sql += col.StringNoPk(db) sql += StringNoPk(db, col)
} }
sql = strings.TrimSpace(sql) sql = strings.TrimSpace(sql)
sql += ", " sql += ", "
@ -528,18 +529,18 @@ func (db *mssql) CreateTableSql(table *core.Table, tableName, storeEngine, chars
return sql return sql
} }
func (db *mssql) ForUpdateSql(query string) string { func (db *mssql) ForUpdateSQL(query string) string {
return query return query
} }
func (db *mssql) Filters() []core.Filter { func (db *mssql) Filters() []Filter {
return []core.Filter{&core.IdFilter{}, &core.QuoteFilter{}} return []Filter{&IdFilter{}, &QuoteFilter{}}
} }
type odbcDriver struct { type odbcDriver struct {
} }
func (p *odbcDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) { func (p *odbcDriver) Parse(driverName, dataSourceName string) (*URI, error) {
var dbName string var dbName string
if strings.HasPrefix(dataSourceName, "sqlserver://") { if strings.HasPrefix(dataSourceName, "sqlserver://") {
@ -563,5 +564,5 @@ func (p *odbcDriver) Parse(driverName, dataSourceName string) (*core.Uri, error)
if dbName == "" { if dbName == "" {
return nil, errors.New("no db name provided") return nil, errors.New("no db name provided")
} }
return &core.Uri{DbName: dbName, DbType: core.MSSQL}, nil return &URI{DBName: dbName, DBType: schemas.MSSQL}, nil
} }

View File

@ -2,13 +2,11 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package xorm package dialects
import ( import (
"reflect" "reflect"
"testing" "testing"
"xorm.io/core"
) )
func TestParseMSSQL(t *testing.T) { func TestParseMSSQL(t *testing.T) {
@ -21,15 +19,15 @@ func TestParseMSSQL(t *testing.T) {
{"server=localhost;user id=sa;password=yourStrong(!)Password;database=db", "db", true}, {"server=localhost;user id=sa;password=yourStrong(!)Password;database=db", "db", true},
} }
driver := core.QueryDriver("mssql") driver := QueryDriver("mssql")
for _, test := range tests { for _, test := range tests {
uri, err := driver.Parse("mssql", test.in) uri, err := driver.Parse("mssql", test.in)
if err != nil && test.valid { if err != nil && test.valid {
t.Errorf("%q got unexpected error: %s", test.in, err) t.Errorf("%q got unexpected error: %s", test.in, err)
} else if err == nil && !reflect.DeepEqual(test.expected, uri.DbName) { } else if err == nil && !reflect.DeepEqual(test.expected, uri.DBName) {
t.Errorf("%q got: %#v want: %#v", test.in, uri.DbName, test.expected) t.Errorf("%q got: %#v want: %#v", test.in, uri.DBName, test.expected)
} }
} }
} }

View File

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package xorm package dialects
import ( import (
"crypto/tls" "crypto/tls"
@ -13,7 +13,8 @@ import (
"strings" "strings"
"time" "time"
"xorm.io/core" "xorm.io/xorm/core"
"xorm.io/xorm/schemas"
) )
var ( var (
@ -162,7 +163,7 @@ var (
) )
type mysql struct { type mysql struct {
core.Base Base
net string net string
addr string addr string
params map[string]string params map[string]string
@ -175,7 +176,7 @@ type mysql struct {
rowFormat string rowFormat string
} }
func (db *mysql) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error { func (db *mysql) Init(d *core.DB, uri *URI, drivername, dataSourceName string) error {
return db.Base.Init(d, db, uri, drivername, dataSourceName) return db.Base.Init(d, db, uri, drivername, dataSourceName)
} }
@ -199,29 +200,29 @@ func (db *mysql) SetParams(params map[string]string) {
} }
} }
func (db *mysql) SqlType(c *core.Column) string { func (db *mysql) SQLType(c *schemas.Column) string {
var res string var res string
switch t := c.SQLType.Name; t { switch t := c.SQLType.Name; t {
case core.Bool: case schemas.Bool:
res = core.TinyInt res = schemas.TinyInt
c.Length = 1 c.Length = 1
case core.Serial: case schemas.Serial:
c.IsAutoIncrement = true c.IsAutoIncrement = true
c.IsPrimaryKey = true c.IsPrimaryKey = true
c.Nullable = false c.Nullable = false
res = core.Int res = schemas.Int
case core.BigSerial: case schemas.BigSerial:
c.IsAutoIncrement = true c.IsAutoIncrement = true
c.IsPrimaryKey = true c.IsPrimaryKey = true
c.Nullable = false c.Nullable = false
res = core.BigInt res = schemas.BigInt
case core.Bytea: case schemas.Bytea:
res = core.Blob res = schemas.Blob
case core.TimeStampz: case schemas.TimeStampz:
res = core.Char res = schemas.Char
c.Length = 64 c.Length = 64
case core.Enum: // mysql enum case schemas.Enum: // mysql enum
res = core.Enum res = schemas.Enum
res += "(" res += "("
opts := "" opts := ""
for v := range c.EnumOptions { for v := range c.EnumOptions {
@ -229,8 +230,8 @@ func (db *mysql) SqlType(c *core.Column) string {
} }
res += strings.TrimLeft(opts, ",") res += strings.TrimLeft(opts, ",")
res += ")" res += ")"
case core.Set: // mysql set case schemas.Set: // mysql set
res = core.Set res = schemas.Set
res += "(" res += "("
opts := "" opts := ""
for v := range c.SetOptions { for v := range c.SetOptions {
@ -238,13 +239,13 @@ func (db *mysql) SqlType(c *core.Column) string {
} }
res += strings.TrimLeft(opts, ",") res += strings.TrimLeft(opts, ",")
res += ")" res += ")"
case core.NVarchar: case schemas.NVarchar:
res = core.Varchar res = schemas.Varchar
case core.Uuid: case schemas.Uuid:
res = core.Varchar res = schemas.Varchar
c.Length = 40 c.Length = 40
case core.Json: case schemas.Json:
res = core.Text res = schemas.Text
default: default:
res = t res = t
} }
@ -252,7 +253,7 @@ func (db *mysql) SqlType(c *core.Column) string {
hasLen1 := (c.Length > 0) hasLen1 := (c.Length > 0)
hasLen2 := (c.Length2 > 0) hasLen2 := (c.Length2 > 0)
if res == core.BigInt && !hasLen1 && !hasLen2 { if res == schemas.BigInt && !hasLen1 && !hasLen2 {
c.Length = 20 c.Length = 20
hasLen1 = true hasLen1 = true
} }
@ -294,8 +295,8 @@ func (db *mysql) IndexOnTable() bool {
return true return true
} }
func (db *mysql) IndexCheckSql(tableName, idxName string) (string, []interface{}) { func (db *mysql) IndexCheckSQL(tableName, idxName string) (string, []interface{}) {
args := []interface{}{db.DbName, tableName, idxName} args := []interface{}{db.uri.DBName, tableName, idxName}
sql := "SELECT `INDEX_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS`" sql := "SELECT `INDEX_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS`"
sql += " WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `INDEX_NAME`=?" sql += " WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `INDEX_NAME`=?"
return sql, args return sql, args
@ -307,14 +308,14 @@ func (db *mysql) IndexCheckSql(tableName, idxName string) (string, []interface{}
return sql, args return sql, args
}*/ }*/
func (db *mysql) TableCheckSql(tableName string) (string, []interface{}) { func (db *mysql) TableCheckSQL(tableName string) (string, []interface{}) {
args := []interface{}{db.DbName, tableName} args := []interface{}{db.uri.DBName, tableName}
sql := "SELECT `TABLE_NAME` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? and `TABLE_NAME`=?" sql := "SELECT `TABLE_NAME` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? and `TABLE_NAME`=?"
return sql, args return sql, args
} }
func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { func (db *mysql) GetColumns(tableName string) ([]string, map[string]*schemas.Column, error) {
args := []interface{}{db.DbName, tableName} args := []interface{}{db.uri.DBName, tableName}
s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," + s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," +
" `COLUMN_KEY`, `EXTRA`,`COLUMN_COMMENT` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" " `COLUMN_KEY`, `EXTRA`,`COLUMN_COMMENT` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?"
db.LogSQL(s, args) db.LogSQL(s, args)
@ -325,10 +326,10 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column
} }
defer rows.Close() defer rows.Close()
cols := make(map[string]*core.Column) cols := make(map[string]*schemas.Column)
colSeq := make([]string, 0) colSeq := make([]string, 0)
for rows.Next() { for rows.Next() {
col := new(core.Column) col := new(schemas.Column)
col.Indexes = make(map[string]int) col.Indexes = make(map[string]int)
var columnName, isNullable, colType, colKey, extra, comment string var columnName, isNullable, colType, colKey, extra, comment string
@ -356,7 +357,7 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column
var len1, len2 int var len1, len2 int
if len(cts) == 2 { if len(cts) == 2 {
idx := strings.Index(cts[1], ")") idx := strings.Index(cts[1], ")")
if colType == core.Enum && cts[1][0] == '\'' { // enum if colType == schemas.Enum && cts[1][0] == '\'' { // enum
options := strings.Split(cts[1][0:idx], ",") options := strings.Split(cts[1][0:idx], ",")
col.EnumOptions = make(map[string]int) col.EnumOptions = make(map[string]int)
for k, v := range options { for k, v := range options {
@ -364,7 +365,7 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column
v = strings.Trim(v, "'") v = strings.Trim(v, "'")
col.EnumOptions[v] = k col.EnumOptions[v] = k
} }
} else if colType == core.Set && cts[1][0] == '\'' { } else if colType == schemas.Set && cts[1][0] == '\'' {
options := strings.Split(cts[1][0:idx], ",") options := strings.Split(cts[1][0:idx], ",")
col.SetOptions = make(map[string]int) col.SetOptions = make(map[string]int)
for k, v := range options { for k, v := range options {
@ -394,8 +395,8 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column
} }
col.Length = len1 col.Length = len1
col.Length2 = len2 col.Length2 = len2
if _, ok := core.SqlTypes[colType]; ok { if _, ok := schemas.SqlTypes[colType]; ok {
col.SQLType = core.SQLType{Name: colType, DefaultLength: len1, DefaultLength2: len2} col.SQLType = schemas.SQLType{Name: colType, DefaultLength: len1, DefaultLength2: len2}
} else { } else {
return nil, nil, fmt.Errorf("Unknown colType %v", colType) return nil, nil, fmt.Errorf("Unknown colType %v", colType)
} }
@ -424,8 +425,8 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column
return colSeq, cols, nil return colSeq, cols, nil
} }
func (db *mysql) GetTables() ([]*core.Table, error) { func (db *mysql) GetTables() ([]*schemas.Table, error) {
args := []interface{}{db.DbName} args := []interface{}{db.uri.DBName}
s := "SELECT `TABLE_NAME`, `ENGINE`, `TABLE_ROWS`, `AUTO_INCREMENT`, `TABLE_COMMENT` from " + s := "SELECT `TABLE_NAME`, `ENGINE`, `TABLE_ROWS`, `AUTO_INCREMENT`, `TABLE_COMMENT` from " +
"`INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? AND (`ENGINE`='MyISAM' OR `ENGINE` = 'InnoDB' OR `ENGINE` = 'TokuDB')" "`INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? AND (`ENGINE`='MyISAM' OR `ENGINE` = 'InnoDB' OR `ENGINE` = 'TokuDB')"
db.LogSQL(s, args) db.LogSQL(s, args)
@ -436,9 +437,9 @@ func (db *mysql) GetTables() ([]*core.Table, error) {
} }
defer rows.Close() defer rows.Close()
tables := make([]*core.Table, 0) tables := make([]*schemas.Table, 0)
for rows.Next() { for rows.Next() {
table := core.NewEmptyTable() table := schemas.NewEmptyTable()
var name, engine, tableRows, comment string var name, engine, tableRows, comment string
var autoIncr *string var autoIncr *string
err = rows.Scan(&name, &engine, &tableRows, &autoIncr, &comment) err = rows.Scan(&name, &engine, &tableRows, &autoIncr, &comment)
@ -454,8 +455,8 @@ func (db *mysql) GetTables() ([]*core.Table, error) {
return tables, nil return tables, nil
} }
func (db *mysql) GetIndexes(tableName string) (map[string]*core.Index, error) { func (db *mysql) GetIndexes(tableName string) (map[string]*schemas.Index, error) {
args := []interface{}{db.DbName, tableName} args := []interface{}{db.uri.DBName, tableName}
s := "SELECT `INDEX_NAME`, `NON_UNIQUE`, `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" s := "SELECT `INDEX_NAME`, `NON_UNIQUE`, `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?"
db.LogSQL(s, args) db.LogSQL(s, args)
@ -465,7 +466,7 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*core.Index, error) {
} }
defer rows.Close() defer rows.Close()
indexes := make(map[string]*core.Index, 0) indexes := make(map[string]*schemas.Index, 0)
for rows.Next() { for rows.Next() {
var indexType int var indexType int
var indexName, colName, nonUnique string var indexName, colName, nonUnique string
@ -479,9 +480,9 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*core.Index, error) {
} }
if "YES" == nonUnique || nonUnique == "1" { if "YES" == nonUnique || nonUnique == "1" {
indexType = core.IndexType indexType = schemas.IndexType
} else { } else {
indexType = core.UniqueType indexType = schemas.UniqueType
} }
colName = strings.Trim(colName, "` ") colName = strings.Trim(colName, "` ")
@ -491,10 +492,10 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*core.Index, error) {
isRegular = true isRegular = true
} }
var index *core.Index var index *schemas.Index
var ok bool var ok bool
if index, ok = indexes[indexName]; !ok { if index, ok = indexes[indexName]; !ok {
index = new(core.Index) index = new(schemas.Index)
index.IsRegular = isRegular index.IsRegular = isRegular
index.Type = indexType index.Type = indexType
index.Name = indexName index.Name = indexName
@ -505,7 +506,7 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*core.Index, error) {
return indexes, nil return indexes, nil
} }
func (db *mysql) CreateTableSql(table *core.Table, tableName, storeEngine, charset string) string { func (db *mysql) CreateTableSQL(table *schemas.Table, tableName, storeEngine, charset string) string {
var sql string var sql string
sql = "CREATE TABLE IF NOT EXISTS " sql = "CREATE TABLE IF NOT EXISTS "
if tableName == "" { if tableName == "" {
@ -521,9 +522,9 @@ func (db *mysql) CreateTableSql(table *core.Table, tableName, storeEngine, chars
for _, colName := range table.ColumnsSeq() { for _, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName) col := table.GetColumn(colName)
if col.IsPrimaryKey && len(pkList) == 1 { if col.IsPrimaryKey && len(pkList) == 1 {
sql += col.String(db) sql += String(db, col)
} else { } else {
sql += col.StringNoPk(db) sql += StringNoPk(db, col)
} }
sql = strings.TrimSpace(sql) sql = strings.TrimSpace(sql)
if len(col.Comment) > 0 { if len(col.Comment) > 0 {
@ -559,15 +560,15 @@ func (db *mysql) CreateTableSql(table *core.Table, tableName, storeEngine, chars
return sql return sql
} }
func (db *mysql) Filters() []core.Filter { func (db *mysql) Filters() []Filter {
return []core.Filter{&core.IdFilter{}} return []Filter{&IdFilter{}}
} }
type mymysqlDriver struct { type mymysqlDriver struct {
} }
func (p *mymysqlDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) { func (p *mymysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) {
db := &core.Uri{DbType: core.MYSQL} uri := &URI{DBType: schemas.MYSQL}
pd := strings.SplitN(dataSourceName, "*", 2) pd := strings.SplitN(dataSourceName, "*", 2)
if len(pd) == 2 { if len(pd) == 2 {
@ -576,9 +577,9 @@ func (p *mymysqlDriver) Parse(driverName, dataSourceName string) (*core.Uri, err
if len(p) != 2 { if len(p) != 2 {
return nil, errors.New("Wrong protocol part of URI") return nil, errors.New("Wrong protocol part of URI")
} }
db.Proto = p[0] uri.Proto = p[0]
options := strings.Split(p[1], ",") options := strings.Split(p[1], ",")
db.Raddr = options[0] uri.Raddr = options[0]
for _, o := range options[1:] { for _, o := range options[1:] {
kv := strings.SplitN(o, "=", 2) kv := strings.SplitN(o, "=", 2)
var k, v string var k, v string
@ -589,13 +590,13 @@ func (p *mymysqlDriver) Parse(driverName, dataSourceName string) (*core.Uri, err
} }
switch k { switch k {
case "laddr": case "laddr":
db.Laddr = v uri.Laddr = v
case "timeout": case "timeout":
to, err := time.ParseDuration(v) to, err := time.ParseDuration(v)
if err != nil { if err != nil {
return nil, err return nil, err
} }
db.Timeout = to uri.Timeout = to
default: default:
return nil, errors.New("Unknown option: " + k) return nil, errors.New("Unknown option: " + k)
} }
@ -608,17 +609,17 @@ func (p *mymysqlDriver) Parse(driverName, dataSourceName string) (*core.Uri, err
if len(dup) != 3 { if len(dup) != 3 {
return nil, errors.New("Wrong database part of URI") return nil, errors.New("Wrong database part of URI")
} }
db.DbName = dup[0] uri.DBName = dup[0]
db.User = dup[1] uri.User = dup[1]
db.Passwd = dup[2] uri.Passwd = dup[2]
return db, nil return uri, nil
} }
type mysqlDriver struct { type mysqlDriver struct {
} }
func (p *mysqlDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) { func (p *mysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) {
dsnPattern := regexp.MustCompile( dsnPattern := regexp.MustCompile(
`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@] `^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
`(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]] `(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
@ -628,12 +629,12 @@ func (p *mysqlDriver) Parse(driverName, dataSourceName string) (*core.Uri, error
// tlsConfigRegister := make(map[string]*tls.Config) // tlsConfigRegister := make(map[string]*tls.Config)
names := dsnPattern.SubexpNames() names := dsnPattern.SubexpNames()
uri := &core.Uri{DbType: core.MYSQL} uri := &URI{DBType: schemas.MYSQL}
for i, match := range matches { for i, match := range matches {
switch names[i] { switch names[i] {
case "dbname": case "dbname":
uri.DbName = match uri.DBName = match
case "params": case "params":
if len(match) > 0 { if len(match) > 0 {
kvs := strings.Split(match, "&") kvs := strings.Split(match, "&")

View File

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package xorm package dialects
import ( import (
"errors" "errors"
@ -11,7 +11,8 @@ import (
"strconv" "strconv"
"strings" "strings"
"xorm.io/core" "xorm.io/xorm/core"
"xorm.io/xorm/schemas"
) )
var ( var (
@ -499,29 +500,29 @@ var (
) )
type oracle struct { type oracle struct {
core.Base Base
} }
func (db *oracle) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error { func (db *oracle) Init(d *core.DB, uri *URI, drivername, dataSourceName string) error {
return db.Base.Init(d, db, uri, drivername, dataSourceName) return db.Base.Init(d, db, uri, drivername, dataSourceName)
} }
func (db *oracle) SqlType(c *core.Column) string { func (db *oracle) SQLType(c *schemas.Column) string {
var res string var res string
switch t := c.SQLType.Name; t { switch t := c.SQLType.Name; t {
case core.Bit, core.TinyInt, core.SmallInt, core.MediumInt, core.Int, core.Integer, core.BigInt, core.Bool, core.Serial, core.BigSerial: case schemas.Bit, schemas.TinyInt, schemas.SmallInt, schemas.MediumInt, schemas.Int, schemas.Integer, schemas.BigInt, schemas.Bool, schemas.Serial, schemas.BigSerial:
res = "NUMBER" res = "NUMBER"
case core.Binary, core.VarBinary, core.Blob, core.TinyBlob, core.MediumBlob, core.LongBlob, core.Bytea: case schemas.Binary, schemas.VarBinary, schemas.Blob, schemas.TinyBlob, schemas.MediumBlob, schemas.LongBlob, schemas.Bytea:
return core.Blob return schemas.Blob
case core.Time, core.DateTime, core.TimeStamp: case schemas.Time, schemas.DateTime, schemas.TimeStamp:
res = core.TimeStamp res = schemas.TimeStamp
case core.TimeStampz: case schemas.TimeStampz:
res = "TIMESTAMP WITH TIME ZONE" res = "TIMESTAMP WITH TIME ZONE"
case core.Float, core.Double, core.Numeric, core.Decimal: case schemas.Float, schemas.Double, schemas.Numeric, schemas.Decimal:
res = "NUMBER" res = "NUMBER"
case core.Text, core.MediumText, core.LongText, core.Json: case schemas.Text, schemas.MediumText, schemas.LongText, schemas.Json:
res = "CLOB" res = "CLOB"
case core.Char, core.Varchar, core.TinyText: case schemas.Char, schemas.Varchar, schemas.TinyText:
res = "VARCHAR2" res = "VARCHAR2"
default: default:
res = t res = t
@ -571,11 +572,11 @@ func (db *oracle) IndexOnTable() bool {
return false return false
} }
func (db *oracle) DropTableSql(tableName string) string { func (db *oracle) DropTableSQL(tableName string) string {
return fmt.Sprintf("DROP TABLE `%s`", tableName) return fmt.Sprintf("DROP TABLE `%s`", tableName)
} }
func (db *oracle) CreateTableSql(table *core.Table, tableName, storeEngine, charset string) string { func (db *oracle) CreateTableSQL(table *schemas.Table, tableName, storeEngine, charset string) string {
var sql string var sql string
sql = "CREATE TABLE " sql = "CREATE TABLE "
if tableName == "" { if tableName == "" {
@ -591,7 +592,7 @@ func (db *oracle) CreateTableSql(table *core.Table, tableName, storeEngine, char
/*if col.IsPrimaryKey && len(pkList) == 1 { /*if col.IsPrimaryKey && len(pkList) == 1 {
sql += col.String(b.dialect) sql += col.String(b.dialect)
} else {*/ } else {*/
sql += col.StringNoPk(db) sql += StringNoPk(db, col)
// } // }
sql = strings.TrimSpace(sql) sql = strings.TrimSpace(sql)
sql += ", " sql += ", "
@ -618,19 +619,19 @@ func (db *oracle) CreateTableSql(table *core.Table, tableName, storeEngine, char
return sql return sql
} }
func (db *oracle) IndexCheckSql(tableName, idxName string) (string, []interface{}) { func (db *oracle) IndexCheckSQL(tableName, idxName string) (string, []interface{}) {
args := []interface{}{tableName, idxName} args := []interface{}{tableName, idxName}
return `SELECT INDEX_NAME FROM USER_INDEXES ` + return `SELECT INDEX_NAME FROM USER_INDEXES ` +
`WHERE TABLE_NAME = :1 AND INDEX_NAME = :2`, args `WHERE TABLE_NAME = :1 AND INDEX_NAME = :2`, args
} }
func (db *oracle) TableCheckSql(tableName string) (string, []interface{}) { func (db *oracle) TableCheckSQL(tableName string) (string, []interface{}) {
args := []interface{}{tableName} args := []interface{}{tableName}
return `SELECT table_name FROM user_tables WHERE table_name = :1`, args return `SELECT table_name FROM user_tables WHERE table_name = :1`, args
} }
func (db *oracle) MustDropTable(tableName string) error { func (db *oracle) MustDropTable(tableName string) error {
sql, args := db.TableCheckSql(tableName) sql, args := db.TableCheckSQL(tableName)
db.LogSQL(sql, args) db.LogSQL(sql, args)
rows, err := db.DB().Query(sql, args...) rows, err := db.DB().Query(sql, args...)
@ -674,7 +675,7 @@ func (db *oracle) IsColumnExist(tableName, colName string) (bool, error) {
return false, nil return false, nil
} }
func (db *oracle) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { func (db *oracle) GetColumns(tableName string) ([]string, map[string]*schemas.Column, error) {
args := []interface{}{tableName} args := []interface{}{tableName}
s := "SELECT column_name,data_default,data_type,data_length,data_precision,data_scale," + s := "SELECT column_name,data_default,data_type,data_length,data_precision,data_scale," +
"nullable FROM USER_TAB_COLUMNS WHERE table_name = :1" "nullable FROM USER_TAB_COLUMNS WHERE table_name = :1"
@ -686,10 +687,10 @@ func (db *oracle) GetColumns(tableName string) ([]string, map[string]*core.Colum
} }
defer rows.Close() defer rows.Close()
cols := make(map[string]*core.Column) cols := make(map[string]*schemas.Column)
colSeq := make([]string, 0) colSeq := make([]string, 0)
for rows.Next() { for rows.Next() {
col := new(core.Column) col := new(schemas.Column)
col.Indexes = make(map[string]int) col.Indexes = make(map[string]int)
var colName, colDefault, nullable, dataType, dataPrecision, dataScale *string var colName, colDefault, nullable, dataType, dataPrecision, dataScale *string
@ -731,30 +732,30 @@ func (db *oracle) GetColumns(tableName string) ([]string, map[string]*core.Colum
switch dt { switch dt {
case "VARCHAR2": case "VARCHAR2":
col.SQLType = core.SQLType{Name: core.Varchar, DefaultLength: len1, DefaultLength2: len2} col.SQLType = schemas.SQLType{Name: schemas.Varchar, DefaultLength: len1, DefaultLength2: len2}
case "NVARCHAR2": case "NVARCHAR2":
col.SQLType = core.SQLType{Name: core.NVarchar, DefaultLength: len1, DefaultLength2: len2} col.SQLType = schemas.SQLType{Name: schemas.NVarchar, DefaultLength: len1, DefaultLength2: len2}
case "TIMESTAMP WITH TIME ZONE": case "TIMESTAMP WITH TIME ZONE":
col.SQLType = core.SQLType{Name: core.TimeStampz, DefaultLength: 0, DefaultLength2: 0} col.SQLType = schemas.SQLType{Name: schemas.TimeStampz, DefaultLength: 0, DefaultLength2: 0}
case "NUMBER": case "NUMBER":
col.SQLType = core.SQLType{Name: core.Double, DefaultLength: len1, DefaultLength2: len2} col.SQLType = schemas.SQLType{Name: schemas.Double, DefaultLength: len1, DefaultLength2: len2}
case "LONG", "LONG RAW": case "LONG", "LONG RAW":
col.SQLType = core.SQLType{Name: core.Text, DefaultLength: 0, DefaultLength2: 0} col.SQLType = schemas.SQLType{Name: schemas.Text, DefaultLength: 0, DefaultLength2: 0}
case "RAW": case "RAW":
col.SQLType = core.SQLType{Name: core.Binary, DefaultLength: 0, DefaultLength2: 0} col.SQLType = schemas.SQLType{Name: schemas.Binary, DefaultLength: 0, DefaultLength2: 0}
case "ROWID": case "ROWID":
col.SQLType = core.SQLType{Name: core.Varchar, DefaultLength: 18, DefaultLength2: 0} col.SQLType = schemas.SQLType{Name: schemas.Varchar, DefaultLength: 18, DefaultLength2: 0}
case "AQ$_SUBSCRIBERS": case "AQ$_SUBSCRIBERS":
ignore = true ignore = true
default: default:
col.SQLType = core.SQLType{Name: strings.ToUpper(dt), DefaultLength: len1, DefaultLength2: len2} col.SQLType = schemas.SQLType{Name: strings.ToUpper(dt), DefaultLength: len1, DefaultLength2: len2}
} }
if ignore { if ignore {
continue continue
} }
if _, ok := core.SqlTypes[col.SQLType.Name]; !ok { if _, ok := schemas.SqlTypes[col.SQLType.Name]; !ok {
return nil, nil, fmt.Errorf("Unknown colType %v %v", *dataType, col.SQLType) return nil, nil, fmt.Errorf("Unknown colType %v %v", *dataType, col.SQLType)
} }
@ -772,7 +773,7 @@ func (db *oracle) GetColumns(tableName string) ([]string, map[string]*core.Colum
return colSeq, cols, nil return colSeq, cols, nil
} }
func (db *oracle) GetTables() ([]*core.Table, error) { func (db *oracle) GetTables() ([]*schemas.Table, error) {
args := []interface{}{} args := []interface{}{}
s := "SELECT table_name FROM user_tables" s := "SELECT table_name FROM user_tables"
db.LogSQL(s, args) db.LogSQL(s, args)
@ -783,9 +784,9 @@ func (db *oracle) GetTables() ([]*core.Table, error) {
} }
defer rows.Close() defer rows.Close()
tables := make([]*core.Table, 0) tables := make([]*schemas.Table, 0)
for rows.Next() { for rows.Next() {
table := core.NewEmptyTable() table := schemas.NewEmptyTable()
err = rows.Scan(&table.Name) err = rows.Scan(&table.Name)
if err != nil { if err != nil {
return nil, err return nil, err
@ -796,7 +797,7 @@ func (db *oracle) GetTables() ([]*core.Table, error) {
return tables, nil return tables, nil
} }
func (db *oracle) GetIndexes(tableName string) (map[string]*core.Index, error) { func (db *oracle) GetIndexes(tableName string) (map[string]*schemas.Index, error) {
args := []interface{}{tableName} args := []interface{}{tableName}
s := "SELECT t.column_name,i.uniqueness,i.index_name FROM user_ind_columns t,user_indexes i " + s := "SELECT t.column_name,i.uniqueness,i.index_name FROM user_ind_columns t,user_indexes i " +
"WHERE t.index_name = i.index_name and t.table_name = i.table_name and t.table_name =:1" "WHERE t.index_name = i.index_name and t.table_name = i.table_name and t.table_name =:1"
@ -808,7 +809,7 @@ func (db *oracle) GetIndexes(tableName string) (map[string]*core.Index, error) {
} }
defer rows.Close() defer rows.Close()
indexes := make(map[string]*core.Index, 0) indexes := make(map[string]*schemas.Index, 0)
for rows.Next() { for rows.Next() {
var indexType int var indexType int
var indexName, colName, uniqueness string var indexName, colName, uniqueness string
@ -827,15 +828,15 @@ func (db *oracle) GetIndexes(tableName string) (map[string]*core.Index, error) {
} }
if uniqueness == "UNIQUE" { if uniqueness == "UNIQUE" {
indexType = core.UniqueType indexType = schemas.UniqueType
} else { } else {
indexType = core.IndexType indexType = schemas.IndexType
} }
var index *core.Index var index *schemas.Index
var ok bool var ok bool
if index, ok = indexes[indexName]; !ok { if index, ok = indexes[indexName]; !ok {
index = new(core.Index) index = new(schemas.Index)
index.Type = indexType index.Type = indexType
index.Name = indexName index.Name = indexName
index.IsRegular = isRegular index.IsRegular = isRegular
@ -846,15 +847,15 @@ func (db *oracle) GetIndexes(tableName string) (map[string]*core.Index, error) {
return indexes, nil return indexes, nil
} }
func (db *oracle) Filters() []core.Filter { func (db *oracle) Filters() []Filter {
return []core.Filter{&core.QuoteFilter{}, &core.SeqFilter{Prefix: ":", Start: 1}, &core.IdFilter{}} return []Filter{&QuoteFilter{}, &SeqFilter{Prefix: ":", Start: 1}, &IdFilter{}}
} }
type goracleDriver struct { type goracleDriver struct {
} }
func (cfg *goracleDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) { func (cfg *goracleDriver) Parse(driverName, dataSourceName string) (*URI, error) {
db := &core.Uri{DbType: core.ORACLE} db := &URI{DBType: schemas.ORACLE}
dsnPattern := regexp.MustCompile( dsnPattern := regexp.MustCompile(
`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@] `^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
`(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]] `(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
@ -867,10 +868,10 @@ func (cfg *goracleDriver) Parse(driverName, dataSourceName string) (*core.Uri, e
for i, match := range matches { for i, match := range matches {
switch names[i] { switch names[i] {
case "dbname": case "dbname":
db.DbName = match db.DBName = match
} }
} }
if db.DbName == "" { if db.DBName == "" {
return nil, errors.New("dbname is empty") return nil, errors.New("dbname is empty")
} }
return db, nil return db, nil
@ -881,8 +882,8 @@ type oci8Driver struct {
// dataSourceName=user/password@ipv4:port/dbname // dataSourceName=user/password@ipv4:port/dbname
// dataSourceName=user/password@[ipv6]:port/dbname // dataSourceName=user/password@[ipv6]:port/dbname
func (p *oci8Driver) Parse(driverName, dataSourceName string) (*core.Uri, error) { func (p *oci8Driver) Parse(driverName, dataSourceName string) (*URI, error) {
db := &core.Uri{DbType: core.ORACLE} db := &URI{DBType: schemas.ORACLE}
dsnPattern := regexp.MustCompile( dsnPattern := regexp.MustCompile(
`^(?P<user>.*)\/(?P<password>.*)@` + // user:password@ `^(?P<user>.*)\/(?P<password>.*)@` + // user:password@
`(?P<net>.*)` + // ip:port `(?P<net>.*)` + // ip:port
@ -892,10 +893,10 @@ func (p *oci8Driver) Parse(driverName, dataSourceName string) (*core.Uri, error)
for i, match := range matches { for i, match := range matches {
switch names[i] { switch names[i] {
case "dbname": case "dbname":
db.DbName = match db.DBName = match
} }
} }
if db.DbName == "" { if db.DBName == "" {
return nil, errors.New("dbname is empty") return nil, errors.New("dbname is empty")
} }
return db, nil return db, nil

View File

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package xorm package dialects
import ( import (
"errors" "errors"
@ -11,7 +11,8 @@ import (
"strconv" "strconv"
"strings" "strings"
"xorm.io/core" "xorm.io/xorm/core"
"xorm.io/xorm/schemas"
) )
// from http://www.postgresql.org/docs/current/static/sql-keywords-appendix.html // from http://www.postgresql.org/docs/current/static/sql-keywords-appendix.html
@ -769,67 +770,67 @@ var (
DefaultPostgresSchema = "public" DefaultPostgresSchema = "public"
) )
const postgresPublicSchema = "public" const PostgresPublicSchema = "public"
type postgres struct { type postgres struct {
core.Base Base
} }
func (db *postgres) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error { func (db *postgres) Init(d *core.DB, uri *URI, drivername, dataSourceName string) error {
err := db.Base.Init(d, db, uri, drivername, dataSourceName) err := db.Base.Init(d, db, uri, drivername, dataSourceName)
if err != nil { if err != nil {
return err return err
} }
if db.Schema == "" { if db.uri.Schema == "" {
db.Schema = DefaultPostgresSchema db.uri.Schema = DefaultPostgresSchema
} }
return nil return nil
} }
func (db *postgres) SqlType(c *core.Column) string { func (db *postgres) SQLType(c *schemas.Column) string {
var res string var res string
switch t := c.SQLType.Name; t { switch t := c.SQLType.Name; t {
case core.TinyInt: case schemas.TinyInt:
res = core.SmallInt res = schemas.SmallInt
return res return res
case core.Bit: case schemas.Bit:
res = core.Boolean res = schemas.Boolean
return res return res
case core.MediumInt, core.Int, core.Integer: case schemas.MediumInt, schemas.Int, schemas.Integer:
if c.IsAutoIncrement { if c.IsAutoIncrement {
return core.Serial return schemas.Serial
} }
return core.Integer return schemas.Integer
case core.BigInt: case schemas.BigInt:
if c.IsAutoIncrement { if c.IsAutoIncrement {
return core.BigSerial return schemas.BigSerial
} }
return core.BigInt return schemas.BigInt
case core.Serial, core.BigSerial: case schemas.Serial, schemas.BigSerial:
c.IsAutoIncrement = true c.IsAutoIncrement = true
c.Nullable = false c.Nullable = false
res = t res = t
case core.Binary, core.VarBinary: case schemas.Binary, schemas.VarBinary:
return core.Bytea return schemas.Bytea
case core.DateTime: case schemas.DateTime:
res = core.TimeStamp res = schemas.TimeStamp
case core.TimeStampz: case schemas.TimeStampz:
return "timestamp with time zone" return "timestamp with time zone"
case core.Float: case schemas.Float:
res = core.Real res = schemas.Real
case core.TinyText, core.MediumText, core.LongText: case schemas.TinyText, schemas.MediumText, schemas.LongText:
res = core.Text res = schemas.Text
case core.NVarchar: case schemas.NVarchar:
res = core.Varchar res = schemas.Varchar
case core.Uuid: case schemas.Uuid:
return core.Uuid return schemas.Uuid
case core.Blob, core.TinyBlob, core.MediumBlob, core.LongBlob: case schemas.Blob, schemas.TinyBlob, schemas.MediumBlob, schemas.LongBlob:
return core.Bytea return schemas.Bytea
case core.Double: case schemas.Double:
return "DOUBLE PRECISION" return "DOUBLE PRECISION"
default: default:
if c.IsAutoIncrement { if c.IsAutoIncrement {
return core.Serial return schemas.Serial
} }
res = t res = t
} }
@ -879,37 +880,37 @@ func (db *postgres) IndexOnTable() bool {
return false return false
} }
func (db *postgres) IndexCheckSql(tableName, idxName string) (string, []interface{}) { func (db *postgres) IndexCheckSQL(tableName, idxName string) (string, []interface{}) {
if len(db.Schema) == 0 { if len(db.uri.Schema) == 0 {
args := []interface{}{tableName, idxName} args := []interface{}{tableName, idxName}
return `SELECT indexname FROM pg_indexes WHERE tablename = ? AND indexname = ?`, args return `SELECT indexname FROM pg_indexes WHERE tablename = ? AND indexname = ?`, args
} }
args := []interface{}{db.Schema, tableName, idxName} args := []interface{}{db.uri.Schema, tableName, idxName}
return `SELECT indexname FROM pg_indexes ` + return `SELECT indexname FROM pg_indexes ` +
`WHERE schemaname = ? AND tablename = ? AND indexname = ?`, args `WHERE schemaname = ? AND tablename = ? AND indexname = ?`, args
} }
func (db *postgres) TableCheckSql(tableName string) (string, []interface{}) { func (db *postgres) TableCheckSQL(tableName string) (string, []interface{}) {
if len(db.Schema) == 0 { if len(db.uri.Schema) == 0 {
args := []interface{}{tableName} args := []interface{}{tableName}
return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args
} }
args := []interface{}{db.Schema, tableName} args := []interface{}{db.uri.Schema, tableName}
return `SELECT tablename FROM pg_tables WHERE schemaname = ? AND tablename = ?`, args return `SELECT tablename FROM pg_tables WHERE schemaname = ? AND tablename = ?`, args
} }
func (db *postgres) ModifyColumnSql(tableName string, col *core.Column) string { func (db *postgres) ModifyColumnSQL(tableName string, col *schemas.Column) string {
if len(db.Schema) == 0 || strings.Contains(tableName, ".") { if len(db.uri.Schema) == 0 || strings.Contains(tableName, ".") {
return fmt.Sprintf("alter table %s ALTER COLUMN %s TYPE %s", return fmt.Sprintf("alter table %s ALTER COLUMN %s TYPE %s",
tableName, col.Name, db.SqlType(col)) tableName, col.Name, db.SQLType(col))
} }
return fmt.Sprintf("alter table %s.%s ALTER COLUMN %s TYPE %s", return fmt.Sprintf("alter table %s.%s ALTER COLUMN %s TYPE %s",
db.Schema, tableName, col.Name, db.SqlType(col)) db.uri.Schema, tableName, col.Name, db.SQLType(col))
} }
func (db *postgres) DropIndexSql(tableName string, index *core.Index) string { func (db *postgres) DropIndexSQL(tableName string, index *schemas.Index) string {
quote := db.Quote quote := db.Quote
idxName := index.Name idxName := index.Name
@ -918,23 +919,23 @@ func (db *postgres) DropIndexSql(tableName string, index *core.Index) string {
if !strings.HasPrefix(idxName, "UQE_") && if !strings.HasPrefix(idxName, "UQE_") &&
!strings.HasPrefix(idxName, "IDX_") { !strings.HasPrefix(idxName, "IDX_") {
if index.Type == core.UniqueType { if index.Type == schemas.UniqueType {
idxName = fmt.Sprintf("UQE_%v_%v", tableName, index.Name) idxName = fmt.Sprintf("UQE_%v_%v", tableName, index.Name)
} else { } else {
idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name) idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name)
} }
} }
if db.Uri.Schema != "" { if db.uri.Schema != "" {
idxName = db.Uri.Schema + "." + idxName idxName = db.uri.Schema + "." + idxName
} }
return fmt.Sprintf("DROP INDEX %v", quote(idxName)) return fmt.Sprintf("DROP INDEX %v", quote(idxName))
} }
func (db *postgres) IsColumnExist(tableName, colName string) (bool, error) { func (db *postgres) IsColumnExist(tableName, colName string) (bool, error) {
args := []interface{}{db.Schema, tableName, colName} args := []interface{}{db.uri.Schema, tableName, colName}
query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = $1 AND table_name = $2" + query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = $1 AND table_name = $2" +
" AND column_name = $3" " AND column_name = $3"
if len(db.Schema) == 0 { if len(db.uri.Schema) == 0 {
args = []interface{}{tableName, colName} args = []interface{}{tableName, colName}
query = "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" + query = "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" +
" AND column_name = $2" " AND column_name = $2"
@ -950,7 +951,7 @@ func (db *postgres) IsColumnExist(tableName, colName string) (bool, error) {
return rows.Next(), nil return rows.Next(), nil
} }
func (db *postgres) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { func (db *postgres) GetColumns(tableName string) ([]string, map[string]*schemas.Column, error) {
args := []interface{}{tableName} args := []interface{}{tableName}
s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length,
CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey, CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey,
@ -965,8 +966,8 @@ FROM pg_attribute f
WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.attnum;` WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.attnum;`
var f string var f string
if len(db.Schema) != 0 { if len(db.uri.Schema) != 0 {
args = append(args, db.Schema) args = append(args, db.uri.Schema)
f = " AND s.table_schema = $2" f = " AND s.table_schema = $2"
} }
s = fmt.Sprintf(s, f) s = fmt.Sprintf(s, f)
@ -979,11 +980,11 @@ WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.att
} }
defer rows.Close() defer rows.Close()
cols := make(map[string]*core.Column) cols := make(map[string]*schemas.Column)
colSeq := make([]string, 0) colSeq := make([]string, 0)
for rows.Next() { for rows.Next() {
col := new(core.Column) col := new(schemas.Column)
col.Indexes = make(map[string]int) col.Indexes = make(map[string]int)
var colName, isNullable, dataType string var colName, isNullable, dataType string
@ -1023,23 +1024,23 @@ WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.att
switch dataType { switch dataType {
case "character varying", "character": case "character varying", "character":
col.SQLType = core.SQLType{Name: core.Varchar, DefaultLength: 0, DefaultLength2: 0} col.SQLType = schemas.SQLType{Name: schemas.Varchar, DefaultLength: 0, DefaultLength2: 0}
case "timestamp without time zone": case "timestamp without time zone":
col.SQLType = core.SQLType{Name: core.DateTime, DefaultLength: 0, DefaultLength2: 0} col.SQLType = schemas.SQLType{Name: schemas.DateTime, DefaultLength: 0, DefaultLength2: 0}
case "timestamp with time zone": case "timestamp with time zone":
col.SQLType = core.SQLType{Name: core.TimeStampz, DefaultLength: 0, DefaultLength2: 0} col.SQLType = schemas.SQLType{Name: schemas.TimeStampz, DefaultLength: 0, DefaultLength2: 0}
case "double precision": case "double precision":
col.SQLType = core.SQLType{Name: core.Double, DefaultLength: 0, DefaultLength2: 0} col.SQLType = schemas.SQLType{Name: schemas.Double, DefaultLength: 0, DefaultLength2: 0}
case "boolean": case "boolean":
col.SQLType = core.SQLType{Name: core.Bool, DefaultLength: 0, DefaultLength2: 0} col.SQLType = schemas.SQLType{Name: schemas.Bool, DefaultLength: 0, DefaultLength2: 0}
case "time without time zone": case "time without time zone":
col.SQLType = core.SQLType{Name: core.Time, DefaultLength: 0, DefaultLength2: 0} col.SQLType = schemas.SQLType{Name: schemas.Time, DefaultLength: 0, DefaultLength2: 0}
case "oid": case "oid":
col.SQLType = core.SQLType{Name: core.BigInt, DefaultLength: 0, DefaultLength2: 0} col.SQLType = schemas.SQLType{Name: schemas.BigInt, DefaultLength: 0, DefaultLength2: 0}
default: default:
col.SQLType = core.SQLType{Name: strings.ToUpper(dataType), DefaultLength: 0, DefaultLength2: 0} col.SQLType = schemas.SQLType{Name: strings.ToUpper(dataType), DefaultLength: 0, DefaultLength2: 0}
} }
if _, ok := core.SqlTypes[col.SQLType.Name]; !ok { if _, ok := schemas.SqlTypes[col.SQLType.Name]; !ok {
return nil, nil, fmt.Errorf("Unknown colType: %v", dataType) return nil, nil, fmt.Errorf("Unknown colType: %v", dataType)
} }
@ -1065,11 +1066,11 @@ WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.att
return colSeq, cols, nil return colSeq, cols, nil
} }
func (db *postgres) GetTables() ([]*core.Table, error) { func (db *postgres) GetTables() ([]*schemas.Table, error) {
args := []interface{}{} args := []interface{}{}
s := "SELECT tablename FROM pg_tables" s := "SELECT tablename FROM pg_tables"
if len(db.Schema) != 0 { if len(db.uri.Schema) != 0 {
args = append(args, db.Schema) args = append(args, db.uri.Schema)
s = s + " WHERE schemaname = $1" s = s + " WHERE schemaname = $1"
} }
@ -1081,9 +1082,9 @@ func (db *postgres) GetTables() ([]*core.Table, error) {
} }
defer rows.Close() defer rows.Close()
tables := make([]*core.Table, 0) tables := make([]*schemas.Table, 0)
for rows.Next() { for rows.Next() {
table := core.NewEmptyTable() table := schemas.NewEmptyTable()
var name string var name string
err = rows.Scan(&name) err = rows.Scan(&name)
if err != nil { if err != nil {
@ -1106,11 +1107,11 @@ func getIndexColName(indexdef string) []string {
return colNames return colNames
} }
func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) { func (db *postgres) GetIndexes(tableName string) (map[string]*schemas.Index, error) {
args := []interface{}{tableName} args := []interface{}{tableName}
s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1") s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1")
if len(db.Schema) != 0 { if len(db.uri.Schema) != 0 {
args = append(args, db.Schema) args = append(args, db.uri.Schema)
s = s + " AND schemaname=$2" s = s + " AND schemaname=$2"
} }
db.LogSQL(s, args) db.LogSQL(s, args)
@ -1121,7 +1122,7 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error)
} }
defer rows.Close() defer rows.Close()
indexes := make(map[string]*core.Index, 0) indexes := make(map[string]*schemas.Index, 0)
for rows.Next() { for rows.Next() {
var indexType int var indexType int
var indexName, indexdef string var indexName, indexdef string
@ -1135,9 +1136,9 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error)
continue continue
} }
if strings.HasPrefix(indexdef, "CREATE UNIQUE INDEX") { if strings.HasPrefix(indexdef, "CREATE UNIQUE INDEX") {
indexType = core.UniqueType indexType = schemas.UniqueType
} else { } else {
indexType = core.IndexType indexType = schemas.IndexType
} }
colNames = getIndexColName(indexdef) colNames = getIndexColName(indexdef)
var isRegular bool var isRegular bool
@ -1149,7 +1150,7 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error)
} }
} }
index := &core.Index{Name: indexName, Type: indexType, Cols: make([]string, 0)} index := &schemas.Index{Name: indexName, Type: indexType, Cols: make([]string, 0)}
for _, colName := range colNames { for _, colName := range colNames {
index.Cols = append(index.Cols, strings.Trim(colName, `" `)) index.Cols = append(index.Cols, strings.Trim(colName, `" `))
} }
@ -1159,8 +1160,8 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error)
return indexes, nil return indexes, nil
} }
func (db *postgres) Filters() []core.Filter { func (db *postgres) Filters() []Filter {
return []core.Filter{&core.IdFilter{}, &core.QuoteFilter{}, &core.SeqFilter{Prefix: "$", Start: 1}} return []Filter{&IdFilter{}, &QuoteFilter{}, &SeqFilter{Prefix: "$", Start: 1}}
} }
type pqDriver struct { type pqDriver struct {
@ -1214,12 +1215,12 @@ func parseOpts(name string, o values) error {
return nil return nil
} }
func (p *pqDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) { func (p *pqDriver) Parse(driverName, dataSourceName string) (*URI, error) {
db := &core.Uri{DbType: core.POSTGRES} db := &URI{DBType: schemas.POSTGRES}
var err error var err error
if strings.HasPrefix(dataSourceName, "postgresql://") || strings.HasPrefix(dataSourceName, "postgres://") { if strings.HasPrefix(dataSourceName, "postgresql://") || strings.HasPrefix(dataSourceName, "postgres://") {
db.DbName, err = parseURL(dataSourceName) db.DBName, err = parseURL(dataSourceName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -1230,10 +1231,10 @@ func (p *pqDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) {
return nil, err return nil, err
} }
db.DbName = o.Get("dbname") db.DBName = o.Get("dbname")
} }
if db.DbName == "" { if db.DBName == "" {
return nil, errors.New("dbname is empty") return nil, errors.New("dbname is empty")
} }
@ -1244,7 +1245,7 @@ type pqDriverPgx struct {
pqDriver pqDriver
} }
func (pgx *pqDriverPgx) Parse(driverName, dataSourceName string) (*core.Uri, error) { func (pgx *pqDriverPgx) Parse(driverName, dataSourceName string) (*URI, error) {
// Remove the leading characters for driver to work // Remove the leading characters for driver to work
if len(dataSourceName) >= 9 && dataSourceName[0] == 0 { if len(dataSourceName) >= 9 && dataSourceName[0] == 0 {
dataSourceName = dataSourceName[9:] dataSourceName = dataSourceName[9:]

View File

@ -1,11 +1,10 @@
package xorm package dialects
import ( import (
"reflect" "reflect"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"xorm.io/core"
) )
func TestParsePostgres(t *testing.T) { func TestParsePostgres(t *testing.T) {
@ -27,15 +26,15 @@ func TestParsePostgres(t *testing.T) {
{"dbname=db =disable", "db", false}, {"dbname=db =disable", "db", false},
} }
driver := core.QueryDriver("postgres") driver := QueryDriver("postgres")
for _, test := range tests { for _, test := range tests {
uri, err := driver.Parse("postgres", test.in) uri, err := driver.Parse("postgres", test.in)
if err != nil && test.valid { if err != nil && test.valid {
t.Errorf("%q got unexpected error: %s", test.in, err) t.Errorf("%q got unexpected error: %s", test.in, err)
} else if err == nil && !reflect.DeepEqual(test.expected, uri.DbName) { } else if err == nil && !reflect.DeepEqual(test.expected, uri.DBName) {
t.Errorf("%q got: %#v want: %#v", test.in, uri.DbName, test.expected) t.Errorf("%q got: %#v want: %#v", test.in, uri.DBName, test.expected)
} }
} }
} }
@ -59,23 +58,23 @@ func TestParsePgx(t *testing.T) {
{"dbname=db =disable", "db", false}, {"dbname=db =disable", "db", false},
} }
driver := core.QueryDriver("pgx") driver := QueryDriver("pgx")
for _, test := range tests { for _, test := range tests {
uri, err := driver.Parse("pgx", test.in) uri, err := driver.Parse("pgx", test.in)
if err != nil && test.valid { if err != nil && test.valid {
t.Errorf("%q got unexpected error: %s", test.in, err) t.Errorf("%q got unexpected error: %s", test.in, err)
} else if err == nil && !reflect.DeepEqual(test.expected, uri.DbName) { } else if err == nil && !reflect.DeepEqual(test.expected, uri.DBName) {
t.Errorf("%q got: %#v want: %#v", test.in, uri.DbName, test.expected) t.Errorf("%q got: %#v want: %#v", test.in, uri.DBName, test.expected)
} }
// Register DriverConfig // Register DriverConfig
uri, err = driver.Parse("pgx", test.in) uri, err = driver.Parse("pgx", test.in)
if err != nil && test.valid { if err != nil && test.valid {
t.Errorf("%q got unexpected error: %s", test.in, err) t.Errorf("%q got unexpected error: %s", test.in, err)
} else if err == nil && !reflect.DeepEqual(test.expected, uri.DbName) { } else if err == nil && !reflect.DeepEqual(test.expected, uri.DBName) {
t.Errorf("%q got: %#v want: %#v", test.in, uri.DbName, test.expected) t.Errorf("%q got: %#v want: %#v", test.in, uri.DBName, test.expected)
} }
} }

View File

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package xorm package dialects
import ( import (
"database/sql" "database/sql"
@ -11,7 +11,8 @@ import (
"regexp" "regexp"
"strings" "strings"
"xorm.io/core" "xorm.io/xorm/core"
"xorm.io/xorm/schemas"
) )
var ( var (
@ -144,42 +145,42 @@ var (
) )
type sqlite3 struct { type sqlite3 struct {
core.Base Base
} }
func (db *sqlite3) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error { func (db *sqlite3) Init(d *core.DB, uri *URI, drivername, dataSourceName string) error {
return db.Base.Init(d, db, uri, drivername, dataSourceName) return db.Base.Init(d, db, uri, drivername, dataSourceName)
} }
func (db *sqlite3) SqlType(c *core.Column) string { func (db *sqlite3) SQLType(c *schemas.Column) string {
switch t := c.SQLType.Name; t { switch t := c.SQLType.Name; t {
case core.Bool: case schemas.Bool:
if c.Default == "true" { if c.Default == "true" {
c.Default = "1" c.Default = "1"
} else if c.Default == "false" { } else if c.Default == "false" {
c.Default = "0" c.Default = "0"
} }
return core.Integer return schemas.Integer
case core.Date, core.DateTime, core.TimeStamp, core.Time: case schemas.Date, schemas.DateTime, schemas.TimeStamp, schemas.Time:
return core.DateTime return schemas.DateTime
case core.TimeStampz: case schemas.TimeStampz:
return core.Text return schemas.Text
case core.Char, core.Varchar, core.NVarchar, core.TinyText, case schemas.Char, schemas.Varchar, schemas.NVarchar, schemas.TinyText,
core.Text, core.MediumText, core.LongText, core.Json: schemas.Text, schemas.MediumText, schemas.LongText, schemas.Json:
return core.Text return schemas.Text
case core.Bit, core.TinyInt, core.SmallInt, core.MediumInt, core.Int, core.Integer, core.BigInt: case schemas.Bit, schemas.TinyInt, schemas.SmallInt, schemas.MediumInt, schemas.Int, schemas.Integer, schemas.BigInt:
return core.Integer return schemas.Integer
case core.Float, core.Double, core.Real: case schemas.Float, schemas.Double, schemas.Real:
return core.Real return schemas.Real
case core.Decimal, core.Numeric: case schemas.Decimal, schemas.Numeric:
return core.Numeric return schemas.Numeric
case core.TinyBlob, core.Blob, core.MediumBlob, core.LongBlob, core.Bytea, core.Binary, core.VarBinary: case schemas.TinyBlob, schemas.Blob, schemas.MediumBlob, schemas.LongBlob, schemas.Bytea, schemas.Binary, schemas.VarBinary:
return core.Blob return schemas.Blob
case core.Serial, core.BigSerial: case schemas.Serial, schemas.BigSerial:
c.IsPrimaryKey = true c.IsPrimaryKey = true
c.IsAutoIncrement = true c.IsAutoIncrement = true
c.Nullable = false c.Nullable = false
return core.Integer return schemas.Integer
default: default:
return t return t
} }
@ -218,24 +219,24 @@ func (db *sqlite3) IndexOnTable() bool {
return false return false
} }
func (db *sqlite3) IndexCheckSql(tableName, idxName string) (string, []interface{}) { func (db *sqlite3) IndexCheckSQL(tableName, idxName string) (string, []interface{}) {
args := []interface{}{idxName} args := []interface{}{idxName}
return "SELECT name FROM sqlite_master WHERE type='index' and name = ?", args return "SELECT name FROM sqlite_master WHERE type='index' and name = ?", args
} }
func (db *sqlite3) TableCheckSql(tableName string) (string, []interface{}) { func (db *sqlite3) TableCheckSQL(tableName string) (string, []interface{}) {
args := []interface{}{tableName} args := []interface{}{tableName}
return "SELECT name FROM sqlite_master WHERE type='table' and name = ?", args return "SELECT name FROM sqlite_master WHERE type='table' and name = ?", args
} }
func (db *sqlite3) DropIndexSql(tableName string, index *core.Index) string { func (db *sqlite3) DropIndexSQL(tableName string, index *schemas.Index) string {
// var unique string // var unique string
quote := db.Quote quote := db.Quote
idxName := index.Name idxName := index.Name
if !strings.HasPrefix(idxName, "UQE_") && if !strings.HasPrefix(idxName, "UQE_") &&
!strings.HasPrefix(idxName, "IDX_") { !strings.HasPrefix(idxName, "IDX_") {
if index.Type == core.UniqueType { if index.Type == schemas.UniqueType {
idxName = fmt.Sprintf("UQE_%v_%v", tableName, index.Name) idxName = fmt.Sprintf("UQE_%v_%v", tableName, index.Name)
} else { } else {
idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name) idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name)
@ -244,7 +245,7 @@ func (db *sqlite3) DropIndexSql(tableName string, index *core.Index) string {
return fmt.Sprintf("DROP INDEX %v", quote(idxName)) return fmt.Sprintf("DROP INDEX %v", quote(idxName))
} }
func (db *sqlite3) ForUpdateSql(query string) string { func (db *sqlite3) ForUpdateSQL(query string) string {
return query return query
} }
@ -298,9 +299,9 @@ func splitColStr(colStr string) []string {
return results return results
} }
func parseString(colStr string) (*core.Column, error) { func parseString(colStr string) (*schemas.Column, error) {
fields := splitColStr(colStr) fields := splitColStr(colStr)
col := new(core.Column) col := new(schemas.Column)
col.Indexes = make(map[string]int) col.Indexes = make(map[string]int)
col.Nullable = true col.Nullable = true
col.DefaultIsEmpty = true col.DefaultIsEmpty = true
@ -310,7 +311,7 @@ func parseString(colStr string) (*core.Column, error) {
col.Name = strings.Trim(strings.Trim(field, "`[] "), `"`) col.Name = strings.Trim(strings.Trim(field, "`[] "), `"`)
continue continue
} else if idx == 1 { } else if idx == 1 {
col.SQLType = core.SQLType{Name: field, DefaultLength: 0, DefaultLength2: 0} col.SQLType = schemas.SQLType{Name: field, DefaultLength: 0, DefaultLength2: 0}
continue continue
} }
switch field { switch field {
@ -332,7 +333,7 @@ func parseString(colStr string) (*core.Column, error) {
return col, nil return col, nil
} }
func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*schemas.Column, error) {
args := []interface{}{tableName} args := []interface{}{tableName}
s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?" s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?"
db.LogSQL(s, args) db.LogSQL(s, args)
@ -359,7 +360,7 @@ func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*core.Colu
nEnd := strings.LastIndex(name, ")") nEnd := strings.LastIndex(name, ")")
reg := regexp.MustCompile(`[^\(,\)]*(\([^\(]*\))?`) reg := regexp.MustCompile(`[^\(,\)]*(\([^\(]*\))?`)
colCreates := reg.FindAllString(name[nStart+1:nEnd], -1) colCreates := reg.FindAllString(name[nStart+1:nEnd], -1)
cols := make(map[string]*core.Column) cols := make(map[string]*schemas.Column)
colSeq := make([]string, 0) colSeq := make([]string, 0)
for _, colStr := range colCreates { for _, colStr := range colCreates {
@ -389,7 +390,7 @@ func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*core.Colu
return colSeq, cols, nil return colSeq, cols, nil
} }
func (db *sqlite3) GetTables() ([]*core.Table, error) { func (db *sqlite3) GetTables() ([]*schemas.Table, error) {
args := []interface{}{} args := []interface{}{}
s := "SELECT name FROM sqlite_master WHERE type='table'" s := "SELECT name FROM sqlite_master WHERE type='table'"
db.LogSQL(s, args) db.LogSQL(s, args)
@ -400,9 +401,9 @@ func (db *sqlite3) GetTables() ([]*core.Table, error) {
} }
defer rows.Close() defer rows.Close()
tables := make([]*core.Table, 0) tables := make([]*schemas.Table, 0)
for rows.Next() { for rows.Next() {
table := core.NewEmptyTable() table := schemas.NewEmptyTable()
err = rows.Scan(&table.Name) err = rows.Scan(&table.Name)
if err != nil { if err != nil {
return nil, err return nil, err
@ -415,7 +416,7 @@ func (db *sqlite3) GetTables() ([]*core.Table, error) {
return tables, nil return tables, nil
} }
func (db *sqlite3) GetIndexes(tableName string) (map[string]*core.Index, error) { func (db *sqlite3) GetIndexes(tableName string) (map[string]*schemas.Index, error) {
args := []interface{}{tableName} args := []interface{}{tableName}
s := "SELECT sql FROM sqlite_master WHERE type='index' and tbl_name = ?" s := "SELECT sql FROM sqlite_master WHERE type='index' and tbl_name = ?"
db.LogSQL(s, args) db.LogSQL(s, args)
@ -426,7 +427,7 @@ func (db *sqlite3) GetIndexes(tableName string) (map[string]*core.Index, error)
} }
defer rows.Close() defer rows.Close()
indexes := make(map[string]*core.Index, 0) indexes := make(map[string]*schemas.Index, 0)
for rows.Next() { for rows.Next() {
var tmpSQL sql.NullString var tmpSQL sql.NullString
err = rows.Scan(&tmpSQL) err = rows.Scan(&tmpSQL)
@ -439,7 +440,7 @@ func (db *sqlite3) GetIndexes(tableName string) (map[string]*core.Index, error)
} }
sql := tmpSQL.String sql := tmpSQL.String
index := new(core.Index) index := new(schemas.Index)
nNStart := strings.Index(sql, "INDEX") nNStart := strings.Index(sql, "INDEX")
nNEnd := strings.Index(sql, "ON") nNEnd := strings.Index(sql, "ON")
if nNStart == -1 || nNEnd == -1 { if nNStart == -1 || nNEnd == -1 {
@ -456,9 +457,9 @@ func (db *sqlite3) GetIndexes(tableName string) (map[string]*core.Index, error)
} }
if strings.HasPrefix(sql, "CREATE UNIQUE INDEX") { if strings.HasPrefix(sql, "CREATE UNIQUE INDEX") {
index.Type = core.UniqueType index.Type = schemas.UniqueType
} else { } else {
index.Type = core.IndexType index.Type = schemas.IndexType
} }
nStart := strings.Index(sql, "(") nStart := strings.Index(sql, "(")
@ -476,17 +477,17 @@ func (db *sqlite3) GetIndexes(tableName string) (map[string]*core.Index, error)
return indexes, nil return indexes, nil
} }
func (db *sqlite3) Filters() []core.Filter { func (db *sqlite3) Filters() []Filter {
return []core.Filter{&core.IdFilter{}} return []Filter{&IdFilter{}}
} }
type sqlite3Driver struct { type sqlite3Driver struct {
} }
func (p *sqlite3Driver) Parse(driverName, dataSourceName string) (*core.Uri, error) { func (p *sqlite3Driver) Parse(driverName, dataSourceName string) (*URI, error) {
if strings.Contains(dataSourceName, "?") { if strings.Contains(dataSourceName, "?") {
dataSourceName = dataSourceName[:strings.Index(dataSourceName, "?")] dataSourceName = dataSourceName[:strings.Index(dataSourceName, "?")]
} }
return &core.Uri{DbType: core.SQLITE, DbName: dataSourceName}, nil return &URI{DBType: schemas.SQLITE, DBName: dataSourceName}, nil
} }

View File

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package xorm package dialects
import ( import (
"testing" "testing"

2
doc.go
View File

@ -126,7 +126,7 @@ Attention: the above 8 methods should be the last chainable method.
engine.ID(1).Get(&user) // for single primary key engine.ID(1).Get(&user) // for single primary key
// SELECT * FROM user WHERE id = 1 // SELECT * FROM user WHERE id = 1
engine.ID(core.PK{1, 2}).Get(&user) // for composite primary keys engine.ID(schemas.PK{1, 2}).Get(&user) // for composite primary keys
// SELECT * FROM user WHERE id1 = 1 AND id2 = 2 // SELECT * FROM user WHERE id1 = 1 AND id2 = 2
engine.In("id", 1, 2, 3).Find(&users) engine.In("id", 1, 2, 3).Find(&users)
// SELECT * FROM user WHERE id IN (1, 2, 3) // SELECT * FROM user WHERE id IN (1, 2, 3)

File diff suppressed because it is too large Load Diff

Binary file not shown.

Before

Width:  |  Height:  |  Size: 238 KiB

175
engine.go
View File

@ -21,27 +21,32 @@ import (
"time" "time"
"xorm.io/builder" "xorm.io/builder"
"xorm.io/core" "xorm.io/xorm/caches"
"xorm.io/xorm/core"
"xorm.io/xorm/dialects"
"xorm.io/xorm/log"
"xorm.io/xorm/names"
"xorm.io/xorm/schemas"
) )
// Engine is the major struct of xorm, it means a database manager. // Engine is the major struct of xorm, it means a database manager.
// Commonly, an application only need one engine // Commonly, an application only need one engine
type Engine struct { type Engine struct {
db *core.DB db *core.DB
dialect core.Dialect dialect dialects.Dialect
ColumnMapper core.IMapper ColumnMapper names.Mapper
TableMapper core.IMapper TableMapper names.Mapper
TagIdentifier string TagIdentifier string
Tables map[reflect.Type]*core.Table Tables map[reflect.Type]*schemas.Table
mutex *sync.RWMutex mutex *sync.RWMutex
Cacher core.Cacher Cacher caches.Cacher
showSQL bool showSQL bool
showExecTime bool showExecTime bool
logger core.ILogger logger log.Logger
TZLocation *time.Location // The timezone of the application TZLocation *time.Location // The timezone of the application
DatabaseTZ *time.Location // The timezone of the database DatabaseTZ *time.Location // The timezone of the database
@ -51,24 +56,24 @@ type Engine struct {
engineGroup *EngineGroup engineGroup *EngineGroup
cachers map[string]core.Cacher cachers map[string]caches.Cacher
cacherLock sync.RWMutex cacherLock sync.RWMutex
defaultContext context.Context defaultContext context.Context
} }
func (engine *Engine) setCacher(tableName string, cacher core.Cacher) { func (engine *Engine) setCacher(tableName string, cacher caches.Cacher) {
engine.cacherLock.Lock() engine.cacherLock.Lock()
engine.cachers[tableName] = cacher engine.cachers[tableName] = cacher
engine.cacherLock.Unlock() engine.cacherLock.Unlock()
} }
func (engine *Engine) SetCacher(tableName string, cacher core.Cacher) { func (engine *Engine) SetCacher(tableName string, cacher caches.Cacher) {
engine.setCacher(tableName, cacher) engine.setCacher(tableName, cacher)
} }
func (engine *Engine) getCacher(tableName string) core.Cacher { func (engine *Engine) getCacher(tableName string) caches.Cacher {
var cacher core.Cacher var cacher caches.Cacher
var ok bool var ok bool
engine.cacherLock.RLock() engine.cacherLock.RLock()
cacher, ok = engine.cachers[tableName] cacher, ok = engine.cachers[tableName]
@ -79,7 +84,7 @@ func (engine *Engine) getCacher(tableName string) core.Cacher {
return cacher return cacher
} }
func (engine *Engine) GetCacher(tableName string) core.Cacher { func (engine *Engine) GetCacher(tableName string) caches.Cacher {
return engine.getCacher(tableName) return engine.getCacher(tableName)
} }
@ -91,13 +96,13 @@ func (engine *Engine) BufferSize(size int) *Session {
} }
// CondDeleted returns the conditions whether a record is soft deleted. // CondDeleted returns the conditions whether a record is soft deleted.
func (engine *Engine) CondDeleted(col *core.Column) builder.Cond { func (engine *Engine) CondDeleted(col *schemas.Column) builder.Cond {
var cond = builder.NewCond() var cond = builder.NewCond()
if col.SQLType.IsNumeric() { if col.SQLType.IsNumeric() {
cond = builder.Eq{col.Name: 0} cond = builder.Eq{col.Name: 0}
} else { } else {
// FIXME: mssql: The conversion of a nvarchar data type to a datetime data type resulted in an out-of-range value. // FIXME: mssql: The conversion of a nvarchar data type to a datetime data type resulted in an out-of-range value.
if engine.dialect.DBType() != core.MSSQL { if engine.dialect.DBType() != schemas.MSSQL {
cond = builder.Eq{col.Name: zeroTime1} cond = builder.Eq{col.Name: zeroTime1}
} }
} }
@ -129,19 +134,19 @@ func (engine *Engine) ShowExecTime(show ...bool) {
} }
// Logger return the logger interface // Logger return the logger interface
func (engine *Engine) Logger() core.ILogger { func (engine *Engine) Logger() log.Logger {
return engine.logger return engine.logger
} }
// SetLogger set the new logger // SetLogger set the new logger
func (engine *Engine) SetLogger(logger core.ILogger) { func (engine *Engine) SetLogger(logger log.Logger) {
engine.logger = logger engine.logger = logger
engine.showSQL = logger.IsShowSQL() engine.showSQL = logger.IsShowSQL()
engine.dialect.SetLogger(logger) engine.dialect.SetLogger(logger)
} }
// SetLogLevel sets the logger level // SetLogLevel sets the logger level
func (engine *Engine) SetLogLevel(level core.LogLevel) { func (engine *Engine) SetLogLevel(level log.LogLevel) {
engine.logger.SetLevel(level) engine.logger.SetLevel(level)
} }
@ -163,18 +168,18 @@ func (engine *Engine) DataSourceName() string {
} }
// SetMapper set the name mapping rules // SetMapper set the name mapping rules
func (engine *Engine) SetMapper(mapper core.IMapper) { func (engine *Engine) SetMapper(mapper names.Mapper) {
engine.SetTableMapper(mapper) engine.SetTableMapper(mapper)
engine.SetColumnMapper(mapper) engine.SetColumnMapper(mapper)
} }
// SetTableMapper set the table name mapping rule // SetTableMapper set the table name mapping rule
func (engine *Engine) SetTableMapper(mapper core.IMapper) { func (engine *Engine) SetTableMapper(mapper names.Mapper) {
engine.TableMapper = mapper engine.TableMapper = mapper
} }
// SetColumnMapper set the column name mapping rule // SetColumnMapper set the column name mapping rule
func (engine *Engine) SetColumnMapper(mapper core.IMapper) { func (engine *Engine) SetColumnMapper(mapper names.Mapper) {
engine.ColumnMapper = mapper engine.ColumnMapper = mapper
} }
@ -268,13 +273,13 @@ func (engine *Engine) quote(sql string) string {
// SqlType will be deprecated, please use SQLType instead // SqlType will be deprecated, please use SQLType instead
// //
// Deprecated: use SQLType instead // Deprecated: use SQLType instead
func (engine *Engine) SqlType(c *core.Column) string { func (engine *Engine) SqlType(c *schemas.Column) string {
return engine.SQLType(c) return engine.SQLType(c)
} }
// SQLType A simple wrapper to dialect's core.SqlType method // SQLType A simple wrapper to dialect's core.SqlType method
func (engine *Engine) SQLType(c *core.Column) string { func (engine *Engine) SQLType(c *schemas.Column) string {
return engine.dialect.SqlType(c) return engine.dialect.SQLType(c)
} }
// AutoIncrStr Database's autoincrement statement // AutoIncrStr Database's autoincrement statement
@ -298,12 +303,12 @@ func (engine *Engine) SetMaxIdleConns(conns int) {
} }
// SetDefaultCacher set the default cacher. Xorm's default not enable cacher. // SetDefaultCacher set the default cacher. Xorm's default not enable cacher.
func (engine *Engine) SetDefaultCacher(cacher core.Cacher) { func (engine *Engine) SetDefaultCacher(cacher caches.Cacher) {
engine.Cacher = cacher engine.Cacher = cacher
} }
// GetDefaultCacher returns the default cacher // GetDefaultCacher returns the default cacher
func (engine *Engine) GetDefaultCacher() core.Cacher { func (engine *Engine) GetDefaultCacher() caches.Cacher {
return engine.Cacher return engine.Cacher
} }
@ -323,14 +328,14 @@ func (engine *Engine) NoCascade() *Session {
} }
// MapCacher Set a table use a special cacher // MapCacher Set a table use a special cacher
func (engine *Engine) MapCacher(bean interface{}, cacher core.Cacher) error { func (engine *Engine) MapCacher(bean interface{}, cacher caches.Cacher) error {
engine.setCacher(engine.TableName(bean, true), cacher) engine.setCacher(engine.TableName(bean, true), cacher)
return nil return nil
} }
// NewDB provides an interface to operate database directly // NewDB provides an interface to operate database directly
func (engine *Engine) NewDB() (*core.DB, error) { func (engine *Engine) NewDB() (*core.DB, error) {
return core.OpenDialect(engine.dialect) return dialects.OpenDialect(engine.dialect)
} }
// DB return the wrapper of sql.DB // DB return the wrapper of sql.DB
@ -339,7 +344,7 @@ func (engine *Engine) DB() *core.DB {
} }
// Dialect return database dialect // Dialect return database dialect
func (engine *Engine) Dialect() core.Dialect { func (engine *Engine) Dialect() dialects.Dialect {
return engine.dialect return engine.dialect
} }
@ -409,7 +414,7 @@ func (engine *Engine) NoAutoCondition(no ...bool) *Session {
return session.NoAutoCondition(no...) return session.NoAutoCondition(no...)
} }
func (engine *Engine) loadTableInfo(table *core.Table) error { func (engine *Engine) loadTableInfo(table *schemas.Table) error {
colSeq, cols, err := engine.dialect.GetColumns(table.Name) colSeq, cols, err := engine.dialect.GetColumns(table.Name)
if err != nil { if err != nil {
return err return err
@ -436,7 +441,7 @@ func (engine *Engine) loadTableInfo(table *core.Table) error {
} }
// DBMetas Retrieve all tables, columns, indexes' informations from database. // DBMetas Retrieve all tables, columns, indexes' informations from database.
func (engine *Engine) DBMetas() ([]*core.Table, error) { func (engine *Engine) DBMetas() ([]*schemas.Table, error) {
tables, err := engine.dialect.GetTables() tables, err := engine.dialect.GetTables()
if err != nil { if err != nil {
return nil, err return nil, err
@ -451,7 +456,7 @@ func (engine *Engine) DBMetas() ([]*core.Table, error) {
} }
// DumpAllToFile dump database all table structs and data to a file // DumpAllToFile dump database all table structs and data to a file
func (engine *Engine) DumpAllToFile(fp string, tp ...core.DbType) error { func (engine *Engine) DumpAllToFile(fp string, tp ...dialects.DBType) error {
f, err := os.Create(fp) f, err := os.Create(fp)
if err != nil { if err != nil {
return err return err
@ -461,7 +466,7 @@ func (engine *Engine) DumpAllToFile(fp string, tp ...core.DbType) error {
} }
// DumpAll dump database all table structs and data to w // DumpAll dump database all table structs and data to w
func (engine *Engine) DumpAll(w io.Writer, tp ...core.DbType) error { func (engine *Engine) DumpAll(w io.Writer, tp ...dialects.DBType) error {
tables, err := engine.DBMetas() tables, err := engine.DBMetas()
if err != nil { if err != nil {
return err return err
@ -470,7 +475,7 @@ func (engine *Engine) DumpAll(w io.Writer, tp ...core.DbType) error {
} }
// DumpTablesToFile dump specified tables to SQL file. // DumpTablesToFile dump specified tables to SQL file.
func (engine *Engine) DumpTablesToFile(tables []*core.Table, fp string, tp ...core.DbType) error { func (engine *Engine) DumpTablesToFile(tables []*schemas.Table, fp string, tp ...dialects.DBType) error {
f, err := os.Create(fp) f, err := os.Create(fp)
if err != nil { if err != nil {
return err return err
@ -480,19 +485,19 @@ func (engine *Engine) DumpTablesToFile(tables []*core.Table, fp string, tp ...co
} }
// DumpTables dump specify tables to io.Writer // DumpTables dump specify tables to io.Writer
func (engine *Engine) DumpTables(tables []*core.Table, w io.Writer, tp ...core.DbType) error { func (engine *Engine) DumpTables(tables []*schemas.Table, w io.Writer, tp ...dialects.DBType) error {
return engine.dumpTables(tables, w, tp...) return engine.dumpTables(tables, w, tp...)
} }
// dumpTables dump database all table structs and data to w with specify db type // dumpTables dump database all table structs and data to w with specify db type
func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.DbType) error { func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...dialects.DBType) error {
var dialect core.Dialect var dialect dialects.Dialect
var distDBName string var distDBName string
if len(tp) == 0 { if len(tp) == 0 {
dialect = engine.dialect dialect = engine.dialect
distDBName = string(engine.dialect.DBType()) distDBName = string(engine.dialect.DBType())
} else { } else {
dialect = core.QueryDialect(tp[0]) dialect = dialects.QueryDialect(tp[0])
if dialect == nil { if dialect == nil {
return errors.New("Unsupported database type") return errors.New("Unsupported database type")
} }
@ -513,12 +518,12 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D
return err return err
} }
} }
_, err = io.WriteString(w, dialect.CreateTableSql(table, "", table.StoreEngine, "")+";\n") _, err = io.WriteString(w, dialect.CreateTableSQL(table, "", table.StoreEngine, "")+";\n")
if err != nil { if err != nil {
return err return err
} }
for _, index := range table.Indexes { for _, index := range table.Indexes {
_, err = io.WriteString(w, dialect.CreateIndexSql(table.Name, index)+";\n") _, err = io.WriteString(w, dialect.CreateIndexSQL(table.Name, index)+";\n")
if err != nil { if err != nil {
return err return err
} }
@ -571,19 +576,19 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D
} else if col.SQLType.IsNumeric() { } else if col.SQLType.IsNumeric() {
switch reflect.TypeOf(d).Kind() { switch reflect.TypeOf(d).Kind() {
case reflect.Slice: case reflect.Slice:
if col.SQLType.Name == core.Bool { if col.SQLType.Name == schemas.Bool {
temp += fmt.Sprintf(", %v", strconv.FormatBool(d.([]byte)[0] != byte('0'))) temp += fmt.Sprintf(", %v", strconv.FormatBool(d.([]byte)[0] != byte('0')))
} else { } else {
temp += fmt.Sprintf(", %s", string(d.([]byte))) temp += fmt.Sprintf(", %s", string(d.([]byte)))
} }
case reflect.Int16, reflect.Int8, reflect.Int32, reflect.Int64, reflect.Int: case reflect.Int16, reflect.Int8, reflect.Int32, reflect.Int64, reflect.Int:
if col.SQLType.Name == core.Bool { if col.SQLType.Name == schemas.Bool {
temp += fmt.Sprintf(", %v", strconv.FormatBool(reflect.ValueOf(d).Int() > 0)) temp += fmt.Sprintf(", %v", strconv.FormatBool(reflect.ValueOf(d).Int() > 0))
} else { } else {
temp += fmt.Sprintf(", %v", d) temp += fmt.Sprintf(", %v", d)
} }
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
if col.SQLType.Name == core.Bool { if col.SQLType.Name == schemas.Bool {
temp += fmt.Sprintf(", %v", strconv.FormatBool(reflect.ValueOf(d).Uint() > 0)) temp += fmt.Sprintf(", %v", strconv.FormatBool(reflect.ValueOf(d).Uint() > 0))
} else { } else {
temp += fmt.Sprintf(", %v", d) temp += fmt.Sprintf(", %v", d)
@ -611,7 +616,7 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D
} }
// FIXME: Hack for postgres // FIXME: Hack for postgres
if string(dialect.DBType()) == core.POSTGRES && table.AutoIncrColumn() != nil { if string(dialect.DBType()) == schemas.POSTGRES && table.AutoIncrColumn() != nil {
_, err = io.WriteString(w, "SELECT setval('"+table.Name+"_id_seq', COALESCE((SELECT MAX("+table.AutoIncrColumn().Name+") + 1 FROM "+dialect.Quote(table.Name)+"), 1), false);\n") _, err = io.WriteString(w, "SELECT setval('"+table.Name+"_id_seq', COALESCE((SELECT MAX("+table.AutoIncrColumn().Name+") + 1 FROM "+dialect.Quote(table.Name)+"), 1), false);\n")
if err != nil { if err != nil {
return err return err
@ -856,7 +861,7 @@ func (engine *Engine) UnMapType(t reflect.Type) {
delete(engine.Tables, t) delete(engine.Tables, t)
} }
func (engine *Engine) autoMapType(v reflect.Value) (*core.Table, error) { func (engine *Engine) autoMapType(v reflect.Value) (*schemas.Table, error) {
t := v.Type() t := v.Type()
engine.mutex.Lock() engine.mutex.Lock()
defer engine.mutex.Unlock() defer engine.mutex.Unlock()
@ -888,7 +893,7 @@ func (engine *Engine) GobRegister(v interface{}) *Engine {
// Table table struct // Table table struct
type Table struct { type Table struct {
*core.Table *schemas.Table
Name string Name string
} }
@ -907,12 +912,12 @@ func (engine *Engine) TableInfo(bean interface{}) *Table {
return &Table{tb, engine.TableName(bean)} return &Table{tb, engine.TableName(bean)}
} }
func addIndex(indexName string, table *core.Table, col *core.Column, indexType int) { func addIndex(indexName string, table *schemas.Table, col *schemas.Column, indexType int) {
if index, ok := table.Indexes[indexName]; ok { if index, ok := table.Indexes[indexName]; ok {
index.AddColumn(col.Name) index.AddColumn(col.Name)
col.Indexes[index.Name] = indexType col.Indexes[index.Name] = indexType
} else { } else {
index := core.NewIndex(indexName, indexType) index := schemas.NewIndex(indexName, indexType)
index.AddColumn(col.Name) index.AddColumn(col.Name)
table.AddIndex(index) table.AddIndex(index)
col.Indexes[index.Name] = indexType col.Indexes[index.Name] = indexType
@ -928,11 +933,11 @@ var (
tpTableName = reflect.TypeOf((*TableName)(nil)).Elem() tpTableName = reflect.TypeOf((*TableName)(nil)).Elem()
) )
func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) { func (engine *Engine) mapType(v reflect.Value) (*schemas.Table, error) {
t := v.Type() t := v.Type()
table := core.NewEmptyTable() table := schemas.NewEmptyTable()
table.Type = t table.Type = t
table.Name = getTableName(engine.TableMapper, v) table.Name = names.GetTableName(engine.TableMapper, v)
var idFieldColName string var idFieldColName string
var hasCacheTag, hasNoCacheTag bool var hasCacheTag, hasNoCacheTag bool
@ -941,17 +946,17 @@ func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) {
tag := t.Field(i).Tag tag := t.Field(i).Tag
ormTagStr := tag.Get(engine.TagIdentifier) ormTagStr := tag.Get(engine.TagIdentifier)
var col *core.Column var col *schemas.Column
fieldValue := v.Field(i) fieldValue := v.Field(i)
fieldType := fieldValue.Type() fieldType := fieldValue.Type()
if ormTagStr != "" { if ormTagStr != "" {
col = &core.Column{ col = &schemas.Column{
FieldName: t.Field(i).Name, FieldName: t.Field(i).Name,
Nullable: true, Nullable: true,
IsPrimaryKey: false, IsPrimaryKey: false,
IsAutoIncrement: false, IsAutoIncrement: false,
MapType: core.TWOSIDES, MapType: schemas.TWOSIDES,
Indexes: make(map[string]int), Indexes: make(map[string]int),
DefaultIsEmpty: true, DefaultIsEmpty: true,
} }
@ -1039,9 +1044,9 @@ func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) {
} }
if col.SQLType.Name == "" { if col.SQLType.Name == "" {
col.SQLType = core.Type2SQLType(fieldType) col.SQLType = schemas.Type2SQLType(fieldType)
} }
engine.dialect.SqlType(col) engine.dialect.SQLType(col)
if col.Length == 0 { if col.Length == 0 {
col.Length = col.SQLType.DefaultLength col.Length = col.SQLType.DefaultLength
} }
@ -1053,9 +1058,9 @@ func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) {
} }
if ctx.isUnique { if ctx.isUnique {
ctx.indexNames[col.Name] = core.UniqueType ctx.indexNames[col.Name] = schemas.UniqueType
} else if ctx.isIndex { } else if ctx.isIndex {
ctx.indexNames[col.Name] = core.IndexType ctx.indexNames[col.Name] = schemas.IndexType
} }
for indexName, indexType := range ctx.indexNames { for indexName, indexType := range ctx.indexNames {
@ -1063,18 +1068,18 @@ func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) {
} }
} }
} else { } else {
var sqlType core.SQLType var sqlType schemas.SQLType
if fieldValue.CanAddr() { if fieldValue.CanAddr() {
if _, ok := fieldValue.Addr().Interface().(core.Conversion); ok { if _, ok := fieldValue.Addr().Interface().(Conversion); ok {
sqlType = core.SQLType{Name: core.Text} sqlType = schemas.SQLType{Name: schemas.Text}
} }
} }
if _, ok := fieldValue.Interface().(core.Conversion); ok { if _, ok := fieldValue.Interface().(Conversion); ok {
sqlType = core.SQLType{Name: core.Text} sqlType = schemas.SQLType{Name: schemas.Text}
} else { } else {
sqlType = core.Type2SQLType(fieldType) sqlType = schemas.Type2SQLType(fieldType)
} }
col = core.NewColumn(engine.ColumnMapper.Obj2Table(t.Field(i).Name), col = schemas.NewColumn(engine.ColumnMapper.Obj2Table(t.Field(i).Name),
t.Field(i).Name, sqlType, sqlType.DefaultLength, t.Field(i).Name, sqlType, sqlType.DefaultLength,
sqlType.DefaultLength2, true) sqlType.DefaultLength2, true)
@ -1105,7 +1110,7 @@ func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) {
engine.setCacher(table.Name, engine.Cacher) engine.setCacher(table.Name, engine.Cacher)
} else { } else {
engine.logger.Info("enable LRU cache on table:", table.Name) engine.logger.Info("enable LRU cache on table:", table.Name)
engine.setCacher(table.Name, NewLRUCacher2(NewMemoryStore(), time.Hour, 10000)) engine.setCacher(table.Name, caches.NewLRUCacher2(caches.NewMemoryStore(), time.Hour, 10000))
} }
} }
if hasNoCacheTag { if hasNoCacheTag {
@ -1133,24 +1138,24 @@ func (engine *Engine) IsTableExist(beanOrTableName interface{}) (bool, error) {
// IdOf get id from one struct // IdOf get id from one struct
// //
// Deprecated: use IDOf instead. // Deprecated: use IDOf instead.
func (engine *Engine) IdOf(bean interface{}) core.PK { func (engine *Engine) IdOf(bean interface{}) schemas.PK {
return engine.IDOf(bean) return engine.IDOf(bean)
} }
// IDOf get id from one struct // IDOf get id from one struct
func (engine *Engine) IDOf(bean interface{}) core.PK { func (engine *Engine) IDOf(bean interface{}) schemas.PK {
return engine.IdOfV(reflect.ValueOf(bean)) return engine.IdOfV(reflect.ValueOf(bean))
} }
// IdOfV get id from one value of struct // IdOfV get id from one value of struct
// //
// Deprecated: use IDOfV instead. // Deprecated: use IDOfV instead.
func (engine *Engine) IdOfV(rv reflect.Value) core.PK { func (engine *Engine) IdOfV(rv reflect.Value) schemas.PK {
return engine.IDOfV(rv) return engine.IDOfV(rv)
} }
// IDOfV get id from one value of struct // IDOfV get id from one value of struct
func (engine *Engine) IDOfV(rv reflect.Value) core.PK { func (engine *Engine) IDOfV(rv reflect.Value) schemas.PK {
pk, err := engine.idOfV(rv) pk, err := engine.idOfV(rv)
if err != nil { if err != nil {
engine.logger.Error(err) engine.logger.Error(err)
@ -1159,7 +1164,7 @@ func (engine *Engine) IDOfV(rv reflect.Value) core.PK {
return pk return pk
} }
func (engine *Engine) idOfV(rv reflect.Value) (core.PK, error) { func (engine *Engine) idOfV(rv reflect.Value) (schemas.PK, error) {
v := reflect.Indirect(rv) v := reflect.Indirect(rv)
table, err := engine.autoMapType(v) table, err := engine.autoMapType(v)
if err != nil { if err != nil {
@ -1202,10 +1207,10 @@ func (engine *Engine) idOfV(rv reflect.Value) (core.PK, error) {
return nil, err return nil, err
} }
} }
return core.PK(pk), nil return schemas.PK(pk), nil
} }
func (engine *Engine) idTypeAssertion(col *core.Column, sid string) (interface{}, error) { func (engine *Engine) idTypeAssertion(col *schemas.Column, sid string) (interface{}, error) {
if col.SQLType.IsNumeric() { if col.SQLType.IsNumeric() {
n, err := strconv.ParseInt(sid, 10, 64) n, err := strconv.ParseInt(sid, 10, 64)
if err != nil { if err != nil {
@ -1317,7 +1322,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
if err := session.statement.setRefBean(bean); err != nil { if err := session.statement.setRefBean(bean); err != nil {
return err return err
} }
if index.Type == core.UniqueType { if index.Type == schemas.UniqueType {
isExist, err := session.isIndexExist2(tableNameNoSchema, index.Cols, true) isExist, err := session.isIndexExist2(tableNameNoSchema, index.Cols, true)
if err != nil { if err != nil {
return err return err
@ -1332,7 +1337,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
return err return err
} }
} }
} else if index.Type == core.IndexType { } else if index.Type == schemas.IndexType {
isExist, err := session.isIndexExist2(tableNameNoSchema, index.Cols, false) isExist, err := session.isIndexExist2(tableNameNoSchema, index.Cols, false)
if err != nil { if err != nil {
return err return err
@ -1601,7 +1606,7 @@ func (engine *Engine) Import(r io.Reader) ([]sql.Result, error) {
} }
// nowTime return current time // nowTime return current time
func (engine *Engine) nowTime(col *core.Column) (interface{}, time.Time) { func (engine *Engine) nowTime(col *schemas.Column) (interface{}, time.Time) {
t := time.Now() t := time.Now()
var tz = engine.DatabaseTZ var tz = engine.DatabaseTZ
if !col.DisableTimeZone && col.TimeZone != nil { if !col.DisableTimeZone && col.TimeZone != nil {
@ -1610,7 +1615,7 @@ func (engine *Engine) nowTime(col *core.Column) (interface{}, time.Time) {
return engine.formatTime(col.SQLType.Name, t.In(tz)), t.In(engine.TZLocation) return engine.formatTime(col.SQLType.Name, t.In(tz)), t.In(engine.TZLocation)
} }
func (engine *Engine) formatColTime(col *core.Column, t time.Time) (v interface{}) { func (engine *Engine) formatColTime(col *schemas.Column, t time.Time) (v interface{}) {
if t.IsZero() { if t.IsZero() {
if col.Nullable { if col.Nullable {
return nil return nil
@ -1627,20 +1632,20 @@ func (engine *Engine) formatColTime(col *core.Column, t time.Time) (v interface{
// formatTime format time as column type // formatTime format time as column type
func (engine *Engine) formatTime(sqlTypeName string, t time.Time) (v interface{}) { func (engine *Engine) formatTime(sqlTypeName string, t time.Time) (v interface{}) {
switch sqlTypeName { switch sqlTypeName {
case core.Time: case schemas.Time:
s := t.Format("2006-01-02 15:04:05") // time.RFC3339 s := t.Format("2006-01-02 15:04:05") // time.RFC3339
v = s[11:19] v = s[11:19]
case core.Date: case schemas.Date:
v = t.Format("2006-01-02") v = t.Format("2006-01-02")
case core.DateTime, core.TimeStamp, core.Varchar: // !DarthPestilane! format time when sqlTypeName is core.Varchar. case schemas.DateTime, schemas.TimeStamp, schemas.Varchar: // !DarthPestilane! format time when sqlTypeName is schemas.Varchar.
v = t.Format("2006-01-02 15:04:05") v = t.Format("2006-01-02 15:04:05")
case core.TimeStampz: case schemas.TimeStampz:
if engine.dialect.DBType() == core.MSSQL { if engine.dialect.DBType() == schemas.MSSQL {
v = t.Format("2006-01-02T15:04:05.9999999Z07:00") v = t.Format("2006-01-02T15:04:05.9999999Z07:00")
} else { } else {
v = t.Format(time.RFC3339Nano) v = t.Format(time.RFC3339Nano)
} }
case core.BigInt, core.Int: case schemas.BigInt, schemas.Int:
v = t.Unix() v = t.Unix()
default: default:
v = t v = t
@ -1649,12 +1654,12 @@ func (engine *Engine) formatTime(sqlTypeName string, t time.Time) (v interface{}
} }
// GetColumnMapper returns the column name mapper // GetColumnMapper returns the column name mapper
func (engine *Engine) GetColumnMapper() core.IMapper { func (engine *Engine) GetColumnMapper() names.Mapper {
return engine.ColumnMapper return engine.ColumnMapper
} }
// GetTableMapper returns the table name mapper // GetTableMapper returns the table name mapper
func (engine *Engine) GetTableMapper() core.IMapper { func (engine *Engine) GetTableMapper() names.Mapper {
return engine.TableMapper return engine.TableMapper
} }

View File

@ -12,10 +12,10 @@ import (
"time" "time"
"xorm.io/builder" "xorm.io/builder"
"xorm.io/core" "xorm.io/xorm/schemas"
) )
func (engine *Engine) buildConds(table *core.Table, bean interface{}, func (engine *Engine) buildConds(table *schemas.Table, bean interface{},
includeVersion bool, includeUpdated bool, includeNil bool, includeVersion bool, includeUpdated bool, includeNil bool,
includeAutoIncr bool, allUseBool bool, useAllCols bool, unscoped bool, includeAutoIncr bool, allUseBool bool, useAllCols bool, unscoped bool,
mustColumnMap map[string]bool, tableName, aliasName string, addedTableName bool) (builder.Cond, error) { mustColumnMap map[string]bool, tableName, aliasName string, addedTableName bool) (builder.Cond, error) {
@ -31,7 +31,7 @@ func (engine *Engine) buildConds(table *core.Table, bean interface{},
continue continue
} }
if engine.dialect.DBType() == core.MSSQL && (col.SQLType.Name == core.Text || col.SQLType.IsBlob() || col.SQLType.Name == core.TimeStampz) { if engine.dialect.DBType() == schemas.MSSQL && (col.SQLType.Name == schemas.Text || col.SQLType.IsBlob() || col.SQLType.Name == schemas.TimeStampz) {
continue continue
} }
if col.SQLType.IsJson() { if col.SQLType.IsJson() {
@ -130,13 +130,13 @@ func (engine *Engine) buildConds(table *core.Table, bean interface{},
t := int64(fieldValue.Uint()) t := int64(fieldValue.Uint())
val = reflect.ValueOf(&t).Interface() val = reflect.ValueOf(&t).Interface()
case reflect.Struct: case reflect.Struct:
if fieldType.ConvertibleTo(core.TimeType) { if fieldType.ConvertibleTo(schemas.TimeType) {
t := fieldValue.Convert(core.TimeType).Interface().(time.Time) t := fieldValue.Convert(schemas.TimeType).Interface().(time.Time)
if !requiredField && (t.IsZero() || !fieldValue.IsValid()) { if !requiredField && (t.IsZero() || !fieldValue.IsValid()) {
continue continue
} }
val = engine.formatColTime(col, t) val = engine.formatColTime(col, t)
} else if _, ok := reflect.New(fieldType).Interface().(core.Conversion); ok { } else if _, ok := reflect.New(fieldType).Interface().(Conversion); ok {
continue continue
} else if valNul, ok := fieldValue.Interface().(driver.Valuer); ok { } else if valNul, ok := fieldValue.Interface().(driver.Valuer); ok {
val, _ = valNul.Value() val, _ = valNul.Value()

View File

@ -8,7 +8,9 @@ import (
"context" "context"
"time" "time"
"xorm.io/core" "xorm.io/xorm/caches"
"xorm.io/xorm/log"
"xorm.io/xorm/names"
) )
// EngineGroup defines an engine group // EngineGroup defines an engine group
@ -109,7 +111,7 @@ func (eg *EngineGroup) Ping() error {
} }
// SetColumnMapper set the column name mapping rule // SetColumnMapper set the column name mapping rule
func (eg *EngineGroup) SetColumnMapper(mapper core.IMapper) { func (eg *EngineGroup) SetColumnMapper(mapper names.Mapper) {
eg.Engine.ColumnMapper = mapper eg.Engine.ColumnMapper = mapper
for i := 0; i < len(eg.slaves); i++ { for i := 0; i < len(eg.slaves); i++ {
eg.slaves[i].ColumnMapper = mapper eg.slaves[i].ColumnMapper = mapper
@ -125,7 +127,7 @@ func (eg *EngineGroup) SetConnMaxLifetime(d time.Duration) {
} }
// SetDefaultCacher set the default cacher // SetDefaultCacher set the default cacher
func (eg *EngineGroup) SetDefaultCacher(cacher core.Cacher) { func (eg *EngineGroup) SetDefaultCacher(cacher caches.Cacher) {
eg.Engine.SetDefaultCacher(cacher) eg.Engine.SetDefaultCacher(cacher)
for i := 0; i < len(eg.slaves); i++ { for i := 0; i < len(eg.slaves); i++ {
eg.slaves[i].SetDefaultCacher(cacher) eg.slaves[i].SetDefaultCacher(cacher)
@ -133,7 +135,7 @@ func (eg *EngineGroup) SetDefaultCacher(cacher core.Cacher) {
} }
// SetLogger set the new logger // SetLogger set the new logger
func (eg *EngineGroup) SetLogger(logger core.ILogger) { func (eg *EngineGroup) SetLogger(logger log.Logger) {
eg.Engine.SetLogger(logger) eg.Engine.SetLogger(logger)
for i := 0; i < len(eg.slaves); i++ { for i := 0; i < len(eg.slaves); i++ {
eg.slaves[i].SetLogger(logger) eg.slaves[i].SetLogger(logger)
@ -141,7 +143,7 @@ func (eg *EngineGroup) SetLogger(logger core.ILogger) {
} }
// SetLogLevel sets the logger level // SetLogLevel sets the logger level
func (eg *EngineGroup) SetLogLevel(level core.LogLevel) { func (eg *EngineGroup) SetLogLevel(level log.LogLevel) {
eg.Engine.SetLogLevel(level) eg.Engine.SetLogLevel(level)
for i := 0; i < len(eg.slaves); i++ { for i := 0; i < len(eg.slaves); i++ {
eg.slaves[i].SetLogLevel(level) eg.slaves[i].SetLogLevel(level)
@ -149,7 +151,7 @@ func (eg *EngineGroup) SetLogLevel(level core.LogLevel) {
} }
// SetMapper set the name mapping rules // SetMapper set the name mapping rules
func (eg *EngineGroup) SetMapper(mapper core.IMapper) { func (eg *EngineGroup) SetMapper(mapper names.Mapper) {
eg.Engine.SetMapper(mapper) eg.Engine.SetMapper(mapper)
for i := 0; i < len(eg.slaves); i++ { for i := 0; i < len(eg.slaves); i++ {
eg.slaves[i].SetMapper(mapper) eg.slaves[i].SetMapper(mapper)
@ -179,7 +181,7 @@ func (eg *EngineGroup) SetPolicy(policy GroupPolicy) *EngineGroup {
} }
// SetTableMapper set the table name mapping rule // SetTableMapper set the table name mapping rule
func (eg *EngineGroup) SetTableMapper(mapper core.IMapper) { func (eg *EngineGroup) SetTableMapper(mapper names.Mapper) {
eg.Engine.TableMapper = mapper eg.Engine.TableMapper = mapper
for i := 0; i < len(eg.slaves); i++ { for i := 0; i < len(eg.slaves); i++ {
eg.slaves[i].TableMapper = mapper eg.slaves[i].TableMapper = mapper

View File

@ -9,16 +9,18 @@ import (
"reflect" "reflect"
"strings" "strings"
"xorm.io/core" "xorm.io/xorm/dialects"
"xorm.io/xorm/names"
"xorm.io/xorm/schemas"
) )
// tbNameWithSchema will automatically add schema prefix on table name // tbNameWithSchema will automatically add schema prefix on table name
func (engine *Engine) tbNameWithSchema(v string) string { func (engine *Engine) tbNameWithSchema(v string) string {
// Add schema name as prefix of table name. // Add schema name as prefix of table name.
// Only for postgres database. // Only for postgres database.
if engine.dialect.DBType() == core.POSTGRES && if engine.dialect.DBType() == schemas.POSTGRES &&
engine.dialect.URI().Schema != "" && engine.dialect.URI().Schema != "" &&
engine.dialect.URI().Schema != postgresPublicSchema && engine.dialect.URI().Schema != dialects.PostgresPublicSchema &&
strings.Index(v, ".") == -1 { strings.Index(v, ".") == -1 {
return engine.dialect.URI().Schema + "." + v return engine.dialect.URI().Schema + "." + v
} }
@ -44,7 +46,7 @@ func (engine *Engine) TableName(bean interface{}, includeSchema ...bool) string
} }
// tbName get some table's table name // tbName get some table's table name
func (session *Session) tbNameNoSchema(table *core.Table) string { func (session *Session) tbNameNoSchema(table *schemas.Table) string {
if len(session.statement.AltTableName) > 0 { if len(session.statement.AltTableName) > 0 {
return session.statement.AltTableName return session.statement.AltTableName
} }
@ -76,7 +78,7 @@ func (engine *Engine) tbNameNoSchema(tablename interface{}) string {
v := rValue(f) v := rValue(f)
t := v.Type() t := v.Type()
if t.Kind() == reflect.Struct { if t.Kind() == reflect.Struct {
table = getTableName(engine.TableMapper, v) table = names.GetTableName(engine.TableMapper, v)
} else { } else {
table = engine.Quote(fmt.Sprintf("%v", f)) table = engine.Quote(fmt.Sprintf("%v", f))
} }
@ -94,12 +96,12 @@ func (engine *Engine) tbNameNoSchema(tablename interface{}) string {
return tablename.(string) return tablename.(string)
case reflect.Value: case reflect.Value:
v := tablename.(reflect.Value) v := tablename.(reflect.Value)
return getTableName(engine.TableMapper, v) return names.GetTableName(engine.TableMapper, v)
default: default:
v := rValue(tablename) v := rValue(tablename)
t := v.Type() t := v.Type()
if t.Kind() == reflect.Struct { if t.Kind() == reflect.Struct {
return getTableName(engine.TableMapper, v) return names.GetTableName(engine.TableMapper, v)
} }
return engine.Quote(fmt.Sprintf("%v", tablename)) return engine.Quote(fmt.Sprintf("%v", tablename))
} }

View File

@ -6,6 +6,7 @@ import (
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"xorm.io/xorm" "xorm.io/xorm"
"xorm.io/xorm/caches"
) )
// User describes a user // User describes a user
@ -15,7 +16,7 @@ type User struct {
} }
func main() { func main() {
f := "cache.db" f := "caches.db"
os.Remove(f) os.Remove(f)
Orm, err := xorm.NewEngine("sqlite3", f) Orm, err := xorm.NewEngine("sqlite3", f)
@ -24,7 +25,7 @@ func main() {
return return
} }
Orm.ShowSQL(true) Orm.ShowSQL(true)
cacher := xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000) cacher := caches.NewLRUCacher(caches.NewMemoryStore(), 1000)
Orm.SetDefaultCacher(cacher) Orm.SetDefaultCacher(cacher)
err = Orm.CreateTables(&User{}) err = Orm.CreateTables(&User{})

View File

@ -8,6 +8,7 @@ import (
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"xorm.io/xorm" "xorm.io/xorm"
"xorm.io/xorm/caches"
) )
// User describes a user // User describes a user
@ -87,7 +88,7 @@ func main() {
return return
} }
engine.ShowSQL(true) engine.ShowSQL(true)
cacher := xorm.NewLRUCacher2(xorm.NewMemoryStore(), time.Hour, 1000) cacher := caches.NewLRUCacher2(caches.NewMemoryStore(), time.Hour, 1000)
engine.SetDefaultCacher(cacher) engine.SetDefaultCacher(cacher)
fmt.Println(engine) fmt.Println(engine)
test(engine) test(engine)
@ -97,7 +98,7 @@ func main() {
fmt.Println("-----start mysql go routines-----") fmt.Println("-----start mysql go routines-----")
engine, err = mysqlEngine() engine, err = mysqlEngine()
engine.ShowSQL(true) engine.ShowSQL(true)
cacher = xorm.NewLRUCacher2(xorm.NewMemoryStore(), time.Hour, 1000) cacher = caches.NewLRUCacher2(caches.NewMemoryStore(), time.Hour, 1000)
engine.SetDefaultCacher(cacher) engine.SetDefaultCacher(cacher)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)

2
go.mod
View File

@ -10,6 +10,6 @@ require (
github.com/mattn/go-sqlite3 v1.10.0 github.com/mattn/go-sqlite3 v1.10.0
github.com/stretchr/testify v1.4.0 github.com/stretchr/testify v1.4.0
github.com/ziutek/mymysql v1.5.4 github.com/ziutek/mymysql v1.5.4
google.golang.org/appengine v1.6.0 // indirect
xorm.io/builder v0.3.6 xorm.io/builder v0.3.6
xorm.io/core v0.7.2
) )

2
go.sum
View File

@ -145,5 +145,3 @@ honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWh
honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
xorm.io/builder v0.3.6 h1:ha28mQ2M+TFx96Hxo+iq6tQgnkC9IZkM6D8w9sKHHF8= xorm.io/builder v0.3.6 h1:ha28mQ2M+TFx96Hxo+iq6tQgnkC9IZkM6D8w9sKHHF8=
xorm.io/builder v0.3.6/go.mod h1:LEFAPISnRzG+zxaxj2vPicRwz67BdhFreKg8yv8/TgU= xorm.io/builder v0.3.6/go.mod h1:LEFAPISnRzG+zxaxj2vPicRwz67BdhFreKg8yv8/TgU=
xorm.io/core v0.7.2 h1:mEO22A2Z7a3fPaZMk6gKL/jMD80iiyNwRrX5HOv3XLw=
xorm.io/core v0.7.2/go.mod h1:jJfd0UAEzZ4t87nbQYtVjmqpIODugN6PD2D9E+dJvdM=

View File

@ -12,7 +12,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"xorm.io/core" "xorm.io/xorm/schemas"
) )
// str2PK convert string value to primary key value according to tp // str2PK convert string value to primary key value according to tp
@ -95,26 +95,6 @@ func str2PK(s string, tp reflect.Type) (interface{}, error) {
return v.Interface(), nil return v.Interface(), nil
} }
func splitTag(tag string) (tags []string) {
tag = strings.TrimSpace(tag)
var hasQuote = false
var lastIdx = 0
for i, t := range tag {
if t == '\'' {
hasQuote = !hasQuote
} else if t == ' ' {
if lastIdx < i && !hasQuote {
tags = append(tags, strings.TrimSpace(tag[lastIdx:i]))
lastIdx = i + 1
}
}
}
if lastIdx < len(tag) {
tags = append(tags, strings.TrimSpace(tag[lastIdx:]))
}
return
}
type zeroable interface { type zeroable interface {
IsZero() bool IsZero() bool
} }
@ -249,7 +229,7 @@ func int64ToInt(id int64, tp reflect.Type) interface{} {
return int64ToIntValue(id, tp).Interface() return int64ToIntValue(id, tp).Interface()
} }
func isPKZero(pk core.PK) bool { func isPKZero(pk schemas.PK) bool {
for _, k := range pk { for _, k := range pk {
if isZero(k) { if isZero(k) {
return true return true

View File

@ -10,7 +10,11 @@ import (
"reflect" "reflect"
"time" "time"
"xorm.io/core" "xorm.io/xorm/caches"
"xorm.io/xorm/dialects"
"xorm.io/xorm/log"
"xorm.io/xorm/names"
"xorm.io/xorm/schemas"
) )
// Interface defines the interface which Engine, EngineGroup and Session will implementate. // Interface defines the interface which Engine, EngineGroup and Session will implementate.
@ -76,31 +80,31 @@ type EngineInterface interface {
ClearCache(...interface{}) error ClearCache(...interface{}) error
Context(context.Context) *Session Context(context.Context) *Session
CreateTables(...interface{}) error CreateTables(...interface{}) error
DBMetas() ([]*core.Table, error) DBMetas() ([]*schemas.Table, error)
Dialect() core.Dialect Dialect() dialects.Dialect
DropTables(...interface{}) error DropTables(...interface{}) error
DumpAllToFile(fp string, tp ...core.DbType) error DumpAllToFile(fp string, tp ...dialects.DBType) error
GetCacher(string) core.Cacher GetCacher(string) caches.Cacher
GetColumnMapper() core.IMapper GetColumnMapper() names.Mapper
GetDefaultCacher() core.Cacher GetDefaultCacher() caches.Cacher
GetTableMapper() core.IMapper GetTableMapper() names.Mapper
GetTZDatabase() *time.Location GetTZDatabase() *time.Location
GetTZLocation() *time.Location GetTZLocation() *time.Location
MapCacher(interface{}, core.Cacher) error MapCacher(interface{}, caches.Cacher) error
NewSession() *Session NewSession() *Session
NoAutoTime() *Session NoAutoTime() *Session
Quote(string) string Quote(string) string
SetCacher(string, core.Cacher) SetCacher(string, caches.Cacher)
SetConnMaxLifetime(time.Duration) SetConnMaxLifetime(time.Duration)
SetColumnMapper(core.IMapper) SetColumnMapper(names.Mapper)
SetDefaultCacher(core.Cacher) SetDefaultCacher(caches.Cacher)
SetLogger(logger core.ILogger) SetLogger(logger log.Logger)
SetLogLevel(core.LogLevel) SetLogLevel(log.LogLevel)
SetMapper(core.IMapper) SetMapper(names.Mapper)
SetMaxOpenConns(int) SetMaxOpenConns(int)
SetMaxIdleConns(int) SetMaxIdleConns(int)
SetSchema(string) SetSchema(string)
SetTableMapper(core.IMapper) SetTableMapper(names.Mapper)
SetTZDatabase(tz *time.Location) SetTZDatabase(tz *time.Location)
SetTZLocation(tz *time.Location) SetTZLocation(tz *time.Location)
ShowExecTime(...bool) ShowExecTime(...bool)

View File

@ -2,26 +2,56 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package xorm package log
import ( import (
"fmt" "fmt"
"io" "io"
"log" "log"
)
"xorm.io/core" // LogLevel defines a log level
type LogLevel int
// enumerate all LogLevels
const (
// !nashtsai! following level also match syslog.Priority value
LOG_DEBUG LogLevel = iota
LOG_INFO
LOG_WARNING
LOG_ERR
LOG_OFF
LOG_UNKNOWN
) )
// default log options // default log options
const ( const (
DEFAULT_LOG_PREFIX = "[xorm]" DEFAULT_LOG_PREFIX = "[xorm]"
DEFAULT_LOG_FLAG = log.Ldate | log.Lmicroseconds DEFAULT_LOG_FLAG = log.Ldate | log.Lmicroseconds
DEFAULT_LOG_LEVEL = core.LOG_DEBUG DEFAULT_LOG_LEVEL = LOG_DEBUG
) )
var _ core.ILogger = DiscardLogger{} // Logger is a logger interface
type Logger interface {
Debug(v ...interface{})
Debugf(format string, v ...interface{})
Error(v ...interface{})
Errorf(format string, v ...interface{})
Info(v ...interface{})
Infof(format string, v ...interface{})
Warn(v ...interface{})
Warnf(format string, v ...interface{})
// DiscardLogger don't log implementation for core.ILogger Level() LogLevel
SetLevel(l LogLevel)
ShowSQL(show ...bool)
IsShowSQL() bool
}
var _ Logger = DiscardLogger{}
// DiscardLogger don't log implementation for ILogger
type DiscardLogger struct{} type DiscardLogger struct{}
// Debug empty implementation // Debug empty implementation
@ -49,12 +79,12 @@ func (DiscardLogger) Warn(v ...interface{}) {}
func (DiscardLogger) Warnf(format string, v ...interface{}) {} func (DiscardLogger) Warnf(format string, v ...interface{}) {}
// Level empty implementation // Level empty implementation
func (DiscardLogger) Level() core.LogLevel { func (DiscardLogger) Level() LogLevel {
return core.LOG_UNKNOWN return LOG_UNKNOWN
} }
// SetLevel empty implementation // SetLevel empty implementation
func (DiscardLogger) SetLevel(l core.LogLevel) {} func (DiscardLogger) SetLevel(l LogLevel) {}
// ShowSQL empty implementation // ShowSQL empty implementation
func (DiscardLogger) ShowSQL(show ...bool) {} func (DiscardLogger) ShowSQL(show ...bool) {}
@ -64,17 +94,17 @@ func (DiscardLogger) IsShowSQL() bool {
return false return false
} }
// SimpleLogger is the default implment of core.ILogger // SimpleLogger is the default implment of ILogger
type SimpleLogger struct { type SimpleLogger struct {
DEBUG *log.Logger DEBUG *log.Logger
ERR *log.Logger ERR *log.Logger
INFO *log.Logger INFO *log.Logger
WARN *log.Logger WARN *log.Logger
level core.LogLevel level LogLevel
showSQL bool showSQL bool
} }
var _ core.ILogger = &SimpleLogger{} var _ Logger = &SimpleLogger{}
// NewSimpleLogger use a special io.Writer as logger output // NewSimpleLogger use a special io.Writer as logger output
func NewSimpleLogger(out io.Writer) *SimpleLogger { func NewSimpleLogger(out io.Writer) *SimpleLogger {
@ -87,7 +117,7 @@ func NewSimpleLogger2(out io.Writer, prefix string, flag int) *SimpleLogger {
} }
// NewSimpleLogger3 let you customrize your logger prefix and flag and logLevel // NewSimpleLogger3 let you customrize your logger prefix and flag and logLevel
func NewSimpleLogger3(out io.Writer, prefix string, flag int, l core.LogLevel) *SimpleLogger { func NewSimpleLogger3(out io.Writer, prefix string, flag int, l LogLevel) *SimpleLogger {
return &SimpleLogger{ return &SimpleLogger{
DEBUG: log.New(out, fmt.Sprintf("%s [debug] ", prefix), flag), DEBUG: log.New(out, fmt.Sprintf("%s [debug] ", prefix), flag),
ERR: log.New(out, fmt.Sprintf("%s [error] ", prefix), flag), ERR: log.New(out, fmt.Sprintf("%s [error] ", prefix), flag),
@ -97,82 +127,82 @@ func NewSimpleLogger3(out io.Writer, prefix string, flag int, l core.LogLevel) *
} }
} }
// Error implement core.ILogger // Error implement ILogger
func (s *SimpleLogger) Error(v ...interface{}) { func (s *SimpleLogger) Error(v ...interface{}) {
if s.level <= core.LOG_ERR { if s.level <= LOG_ERR {
s.ERR.Output(2, fmt.Sprint(v...)) s.ERR.Output(2, fmt.Sprint(v...))
} }
return return
} }
// Errorf implement core.ILogger // Errorf implement ILogger
func (s *SimpleLogger) Errorf(format string, v ...interface{}) { func (s *SimpleLogger) Errorf(format string, v ...interface{}) {
if s.level <= core.LOG_ERR { if s.level <= LOG_ERR {
s.ERR.Output(2, fmt.Sprintf(format, v...)) s.ERR.Output(2, fmt.Sprintf(format, v...))
} }
return return
} }
// Debug implement core.ILogger // Debug implement ILogger
func (s *SimpleLogger) Debug(v ...interface{}) { func (s *SimpleLogger) Debug(v ...interface{}) {
if s.level <= core.LOG_DEBUG { if s.level <= LOG_DEBUG {
s.DEBUG.Output(2, fmt.Sprint(v...)) s.DEBUG.Output(2, fmt.Sprint(v...))
} }
return return
} }
// Debugf implement core.ILogger // Debugf implement ILogger
func (s *SimpleLogger) Debugf(format string, v ...interface{}) { func (s *SimpleLogger) Debugf(format string, v ...interface{}) {
if s.level <= core.LOG_DEBUG { if s.level <= LOG_DEBUG {
s.DEBUG.Output(2, fmt.Sprintf(format, v...)) s.DEBUG.Output(2, fmt.Sprintf(format, v...))
} }
return return
} }
// Info implement core.ILogger // Info implement ILogger
func (s *SimpleLogger) Info(v ...interface{}) { func (s *SimpleLogger) Info(v ...interface{}) {
if s.level <= core.LOG_INFO { if s.level <= LOG_INFO {
s.INFO.Output(2, fmt.Sprint(v...)) s.INFO.Output(2, fmt.Sprint(v...))
} }
return return
} }
// Infof implement core.ILogger // Infof implement ILogger
func (s *SimpleLogger) Infof(format string, v ...interface{}) { func (s *SimpleLogger) Infof(format string, v ...interface{}) {
if s.level <= core.LOG_INFO { if s.level <= LOG_INFO {
s.INFO.Output(2, fmt.Sprintf(format, v...)) s.INFO.Output(2, fmt.Sprintf(format, v...))
} }
return return
} }
// Warn implement core.ILogger // Warn implement ILogger
func (s *SimpleLogger) Warn(v ...interface{}) { func (s *SimpleLogger) Warn(v ...interface{}) {
if s.level <= core.LOG_WARNING { if s.level <= LOG_WARNING {
s.WARN.Output(2, fmt.Sprint(v...)) s.WARN.Output(2, fmt.Sprint(v...))
} }
return return
} }
// Warnf implement core.ILogger // Warnf implement ILogger
func (s *SimpleLogger) Warnf(format string, v ...interface{}) { func (s *SimpleLogger) Warnf(format string, v ...interface{}) {
if s.level <= core.LOG_WARNING { if s.level <= LOG_WARNING {
s.WARN.Output(2, fmt.Sprintf(format, v...)) s.WARN.Output(2, fmt.Sprintf(format, v...))
} }
return return
} }
// Level implement core.ILogger // Level implement ILogger
func (s *SimpleLogger) Level() core.LogLevel { func (s *SimpleLogger) Level() LogLevel {
return s.level return s.level
} }
// SetLevel implement core.ILogger // SetLevel implement ILogger
func (s *SimpleLogger) SetLevel(l core.LogLevel) { func (s *SimpleLogger) SetLevel(l LogLevel) {
s.level = l s.level = l
return return
} }
// ShowSQL implement core.ILogger // ShowSQL implement ILogger
func (s *SimpleLogger) ShowSQL(show ...bool) { func (s *SimpleLogger) ShowSQL(show ...bool) {
if len(show) == 0 { if len(show) == 0 {
s.showSQL = true s.showSQL = true
@ -181,7 +211,7 @@ func (s *SimpleLogger) ShowSQL(show ...bool) {
s.showSQL = show[0] s.showSQL = show[0]
} }
// IsShowSQL implement core.ILogger // IsShowSQL implement ILogger
func (s *SimpleLogger) IsShowSQL() bool { func (s *SimpleLogger) IsShowSQL() bool {
return s.showSQL return s.showSQL
} }

View File

@ -4,16 +4,14 @@
// +build !windows,!nacl,!plan9 // +build !windows,!nacl,!plan9
package xorm package log
import ( import (
"fmt" "fmt"
"log/syslog" "log/syslog"
"xorm.io/core"
) )
var _ core.ILogger = &SyslogLogger{} var _ Logger = &SyslogLogger{}
// SyslogLogger will be depricated // SyslogLogger will be depricated
type SyslogLogger struct { type SyslogLogger struct {
@ -21,7 +19,7 @@ type SyslogLogger struct {
showSQL bool showSQL bool
} }
// NewSyslogLogger implements core.ILogger // NewSyslogLogger implements Logger
func NewSyslogLogger(w *syslog.Writer) *SyslogLogger { func NewSyslogLogger(w *syslog.Writer) *SyslogLogger {
return &SyslogLogger{w: w} return &SyslogLogger{w: w}
} }
@ -67,12 +65,12 @@ func (s *SyslogLogger) Warnf(format string, v ...interface{}) {
} }
// Level shows log level // Level shows log level
func (s *SyslogLogger) Level() core.LogLevel { func (s *SyslogLogger) Level() LogLevel {
return core.LOG_UNKNOWN return LOG_UNKNOWN
} }
// SetLevel always return error, as current log/syslog package doesn't allow to set priority level after syslog.Writer created // SetLevel always return error, as current log/syslog package doesn't allow to set priority level after syslog.Writer created
func (s *SyslogLogger) SetLevel(l core.LogLevel) {} func (s *SyslogLogger) SetLevel(l LogLevel) {}
// ShowSQL set if logging SQL // ShowSQL set if logging SQL
func (s *SyslogLogger) ShowSQL(show ...bool) { func (s *SyslogLogger) ShowSQL(show ...bool) {

View File

@ -13,7 +13,7 @@ type MigrateFunc func(*xorm.Engine) error
// RollbackFunc is the func signature for rollbacking. // RollbackFunc is the func signature for rollbacking.
type RollbackFunc func(*xorm.Engine) error type RollbackFunc func(*xorm.Engine) error
// InitSchemaFunc is the func signature for initializing the schema. // InitSchemaFunc is the func signature for initializing the schemas.
type InitSchemaFunc func(*xorm.Engine) error type InitSchemaFunc func(*xorm.Engine) error
// Options define options for all migrations. // Options define options for all migrations.
@ -34,7 +34,7 @@ type Migration struct {
Rollback RollbackFunc Rollback RollbackFunc
} }
// Migrate represents a collection of all migrations of a database schema. // Migrate represents a collection of all migrations of a database schemas.
type Migrate struct { type Migrate struct {
db *xorm.Engine db *xorm.Engine
options *Options options *Options

258
names/mapper.go Normal file
View File

@ -0,0 +1,258 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package names
import (
"strings"
"sync"
)
// Mapper represents a name convertation between struct's fields name and table's column name
type Mapper interface {
Obj2Table(string) string
Table2Obj(string) string
}
type CacheMapper struct {
oriMapper Mapper
obj2tableCache map[string]string
obj2tableMutex sync.RWMutex
table2objCache map[string]string
table2objMutex sync.RWMutex
}
func NewCacheMapper(mapper Mapper) *CacheMapper {
return &CacheMapper{oriMapper: mapper, obj2tableCache: make(map[string]string),
table2objCache: make(map[string]string),
}
}
func (m *CacheMapper) Obj2Table(o string) string {
m.obj2tableMutex.RLock()
t, ok := m.obj2tableCache[o]
m.obj2tableMutex.RUnlock()
if ok {
return t
}
t = m.oriMapper.Obj2Table(o)
m.obj2tableMutex.Lock()
m.obj2tableCache[o] = t
m.obj2tableMutex.Unlock()
return t
}
func (m *CacheMapper) Table2Obj(t string) string {
m.table2objMutex.RLock()
o, ok := m.table2objCache[t]
m.table2objMutex.RUnlock()
if ok {
return o
}
o = m.oriMapper.Table2Obj(t)
m.table2objMutex.Lock()
m.table2objCache[t] = o
m.table2objMutex.Unlock()
return o
}
// SameMapper implements IMapper and provides same name between struct and
// database table
type SameMapper struct {
}
func (m SameMapper) Obj2Table(o string) string {
return o
}
func (m SameMapper) Table2Obj(t string) string {
return t
}
// SnakeMapper implements IMapper and provides name transaltion between
// struct and database table
type SnakeMapper struct {
}
func snakeCasedName(name string) string {
newstr := make([]rune, 0)
for idx, chr := range name {
if isUpper := 'A' <= chr && chr <= 'Z'; isUpper {
if idx > 0 {
newstr = append(newstr, '_')
}
chr -= ('A' - 'a')
}
newstr = append(newstr, chr)
}
return string(newstr)
}
func (mapper SnakeMapper) Obj2Table(name string) string {
return snakeCasedName(name)
}
func titleCasedName(name string) string {
newstr := make([]rune, 0)
upNextChar := true
name = strings.ToLower(name)
for _, chr := range name {
switch {
case upNextChar:
upNextChar = false
if 'a' <= chr && chr <= 'z' {
chr -= ('a' - 'A')
}
case chr == '_':
upNextChar = true
continue
}
newstr = append(newstr, chr)
}
return string(newstr)
}
func (mapper SnakeMapper) Table2Obj(name string) string {
return titleCasedName(name)
}
// GonicMapper implements IMapper. It will consider initialisms when mapping names.
// E.g. id -> ID, user -> User and to table names: UserID -> user_id, MyUID -> my_uid
type GonicMapper map[string]bool
func isASCIIUpper(r rune) bool {
return 'A' <= r && r <= 'Z'
}
func toASCIIUpper(r rune) rune {
if 'a' <= r && r <= 'z' {
r -= ('a' - 'A')
}
return r
}
func gonicCasedName(name string) string {
newstr := make([]rune, 0, len(name)+3)
for idx, chr := range name {
if isASCIIUpper(chr) && idx > 0 {
if !isASCIIUpper(newstr[len(newstr)-1]) {
newstr = append(newstr, '_')
}
}
if !isASCIIUpper(chr) && idx > 1 {
l := len(newstr)
if isASCIIUpper(newstr[l-1]) && isASCIIUpper(newstr[l-2]) {
newstr = append(newstr, newstr[l-1])
newstr[l-1] = '_'
}
}
newstr = append(newstr, chr)
}
return strings.ToLower(string(newstr))
}
func (mapper GonicMapper) Obj2Table(name string) string {
return gonicCasedName(name)
}
func (mapper GonicMapper) Table2Obj(name string) string {
newstr := make([]rune, 0)
name = strings.ToLower(name)
parts := strings.Split(name, "_")
for _, p := range parts {
_, isInitialism := mapper[strings.ToUpper(p)]
for i, r := range p {
if i == 0 || isInitialism {
r = toASCIIUpper(r)
}
newstr = append(newstr, r)
}
}
return string(newstr)
}
// LintGonicMapper is A GonicMapper that contains a list of common initialisms taken from golang/lint
var LintGonicMapper = GonicMapper{
"API": true,
"ASCII": true,
"CPU": true,
"CSS": true,
"DNS": true,
"EOF": true,
"GUID": true,
"HTML": true,
"HTTP": true,
"HTTPS": true,
"ID": true,
"IP": true,
"JSON": true,
"LHS": true,
"QPS": true,
"RAM": true,
"RHS": true,
"RPC": true,
"SLA": true,
"SMTP": true,
"SSH": true,
"TLS": true,
"TTL": true,
"UI": true,
"UID": true,
"UUID": true,
"URI": true,
"URL": true,
"UTF8": true,
"VM": true,
"XML": true,
"XSRF": true,
"XSS": true,
}
// PrefixMapper provides prefix table name support
type PrefixMapper struct {
Mapper Mapper
Prefix string
}
func (mapper PrefixMapper) Obj2Table(name string) string {
return mapper.Prefix + mapper.Mapper.Obj2Table(name)
}
func (mapper PrefixMapper) Table2Obj(name string) string {
return mapper.Mapper.Table2Obj(name[len(mapper.Prefix):])
}
func NewPrefixMapper(mapper Mapper, prefix string) PrefixMapper {
return PrefixMapper{mapper, prefix}
}
// SuffixMapper provides suffix table name support
type SuffixMapper struct {
Mapper Mapper
Suffix string
}
func (mapper SuffixMapper) Obj2Table(name string) string {
return mapper.Mapper.Obj2Table(name) + mapper.Suffix
}
func (mapper SuffixMapper) Table2Obj(name string) string {
return mapper.Mapper.Table2Obj(name[:len(name)-len(mapper.Suffix)])
}
func NewSuffixMapper(mapper Mapper, suffix string) SuffixMapper {
return SuffixMapper{mapper, suffix}
}

49
names/mapper_test.go Normal file
View File

@ -0,0 +1,49 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package names
import (
"testing"
)
func TestGonicMapperFromObj(t *testing.T) {
testCases := map[string]string{
"HTTPLib": "http_lib",
"id": "id",
"ID": "id",
"IDa": "i_da",
"iDa": "i_da",
"IDAa": "id_aa",
"aID": "a_id",
"aaID": "aa_id",
"aaaID": "aaa_id",
"MyREalFunkYLONgNAME": "my_r_eal_funk_ylo_ng_name",
}
for in, expected := range testCases {
out := gonicCasedName(in)
if out != expected {
t.Errorf("Given %s, expected %s but got %s", in, expected, out)
}
}
}
func TestGonicMapperToObj(t *testing.T) {
testCases := map[string]string{
"http_lib": "HTTPLib",
"id": "ID",
"ida": "Ida",
"id_aa": "IDAa",
"aa_id": "AaID",
"my_r_eal_funk_ylo_ng_name": "MyREalFunkYloNgName",
}
for in, expected := range testCases {
out := LintGonicMapper.Table2Obj(in)
if out != expected {
t.Errorf("Given %s, expected %s but got %s", in, expected, out)
}
}
}

View File

@ -2,15 +2,22 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package xorm package names
import ( import (
"reflect" "reflect"
"xorm.io/core"
) )
func getTableName(mapper core.IMapper, v reflect.Value) string { // TableName table name interface to define customerize table name
type TableName interface {
TableName() string
}
var (
tpTableName = reflect.TypeOf((*TableName)(nil)).Elem()
)
func GetTableName(mapper Mapper, v reflect.Value) string {
if t, ok := v.Interface().(TableName); ok { if t, ok := v.Interface().(TableName); ok {
return t.TableName() return t.TableName()
} }

View File

@ -2,17 +2,45 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package xorm package names
import ( import (
"reflect" "reflect"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"xorm.io/core"
) )
type Userinfo struct {
Uid int64 `xorm:"id pk not null autoincr"`
Username string `xorm:"unique"`
Departname string
Alias string `xorm:"-"`
Created time.Time
Detail Userdetail `xorm:"detail_id int(11)"`
Height float64
Avatar []byte
IsMan bool
}
type Userdetail struct {
Id int64
Intro string `xorm:"text"`
Profile string `xorm:"varchar(2000)"`
}
type MyGetCustomTableImpletation struct {
Id int64 `json:"id"`
Name string `json:"name"`
}
const getCustomTableName = "GetCustomTableInterface"
func (MyGetCustomTableImpletation) TableName() string {
return getCustomTableName
}
type TestTableNameStruct struct{} type TestTableNameStruct struct{}
func (t *TestTableNameStruct) TableName() string { func (t *TestTableNameStruct) TableName() string {
@ -21,53 +49,53 @@ func (t *TestTableNameStruct) TableName() string {
func TestGetTableName(t *testing.T) { func TestGetTableName(t *testing.T) {
var kases = []struct { var kases = []struct {
mapper core.IMapper mapper Mapper
v reflect.Value v reflect.Value
expectedTableName string expectedTableName string
}{ }{
{ {
core.SnakeMapper{}, SnakeMapper{},
reflect.ValueOf(new(Userinfo)), reflect.ValueOf(new(Userinfo)),
"userinfo", "userinfo",
}, },
{ {
core.SnakeMapper{}, SnakeMapper{},
reflect.ValueOf(Userinfo{}), reflect.ValueOf(Userinfo{}),
"userinfo", "userinfo",
}, },
{ {
core.SameMapper{}, SameMapper{},
reflect.ValueOf(new(Userinfo)), reflect.ValueOf(new(Userinfo)),
"Userinfo", "Userinfo",
}, },
{ {
core.SameMapper{}, SameMapper{},
reflect.ValueOf(Userinfo{}), reflect.ValueOf(Userinfo{}),
"Userinfo", "Userinfo",
}, },
{ {
core.SnakeMapper{}, SnakeMapper{},
reflect.ValueOf(new(MyGetCustomTableImpletation)), reflect.ValueOf(new(MyGetCustomTableImpletation)),
getCustomTableName, getCustomTableName,
}, },
{ {
core.SnakeMapper{}, SnakeMapper{},
reflect.ValueOf(MyGetCustomTableImpletation{}), reflect.ValueOf(MyGetCustomTableImpletation{}),
getCustomTableName, getCustomTableName,
}, },
{ {
core.SnakeMapper{}, SnakeMapper{},
reflect.ValueOf(MyGetCustomTableImpletation{}), reflect.ValueOf(MyGetCustomTableImpletation{}),
getCustomTableName, getCustomTableName,
}, },
{ {
core.SnakeMapper{}, SnakeMapper{},
reflect.ValueOf(new(TestTableNameStruct)), reflect.ValueOf(new(TestTableNameStruct)),
new(TestTableNameStruct).TableName(), new(TestTableNameStruct).TableName(),
}, },
} }
for _, kase := range kases { for _, kase := range kases {
assert.EqualValues(t, kase.expectedTableName, getTableName(kase.mapper, kase.v)) assert.EqualValues(t, kase.expectedTableName, GetTableName(kase.mapper, kase.v))
} }
} }

View File

@ -9,7 +9,7 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"xorm.io/core" "xorm.io/xorm/core"
) )
// Rows rows wrapper a rows to // Rows rows wrapper a rows to

117
schemas/column.go Normal file
View File

@ -0,0 +1,117 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package schemas
import (
"fmt"
"reflect"
"strings"
"time"
)
const (
TWOSIDES = iota + 1
ONLYTODB
ONLYFROMDB
)
// Column defines database column
type Column struct {
Name string
TableName string
FieldName string
SQLType SQLType
IsJSON bool
Length int
Length2 int
Nullable bool
Default string
Indexes map[string]int
IsPrimaryKey bool
IsAutoIncrement bool
MapType int
IsCreated bool
IsUpdated bool
IsDeleted bool
IsCascade bool
IsVersion bool
DefaultIsEmpty bool // false means column has no default set, but not default value is empty
EnumOptions map[string]int
SetOptions map[string]int
DisableTimeZone bool
TimeZone *time.Location // column specified time zone
Comment string
}
// NewColumn creates a new column
func NewColumn(name, fieldName string, sqlType SQLType, len1, len2 int, nullable bool) *Column {
return &Column{
Name: name,
TableName: "",
FieldName: fieldName,
SQLType: sqlType,
Length: len1,
Length2: len2,
Nullable: nullable,
Default: "",
Indexes: make(map[string]int),
IsPrimaryKey: false,
IsAutoIncrement: false,
MapType: TWOSIDES,
IsCreated: false,
IsUpdated: false,
IsDeleted: false,
IsCascade: false,
IsVersion: false,
DefaultIsEmpty: true, // default should be no default
EnumOptions: make(map[string]int),
Comment: "",
}
}
// ValueOf returns column's filed of struct's value
func (col *Column) ValueOf(bean interface{}) (*reflect.Value, error) {
dataStruct := reflect.Indirect(reflect.ValueOf(bean))
return col.ValueOfV(&dataStruct)
}
// ValueOfV returns column's filed of struct's value accept reflevt value
func (col *Column) ValueOfV(dataStruct *reflect.Value) (*reflect.Value, error) {
var fieldValue reflect.Value
fieldPath := strings.Split(col.FieldName, ".")
if dataStruct.Type().Kind() == reflect.Map {
keyValue := reflect.ValueOf(fieldPath[len(fieldPath)-1])
fieldValue = dataStruct.MapIndex(keyValue)
return &fieldValue, nil
} else if dataStruct.Type().Kind() == reflect.Interface {
structValue := reflect.ValueOf(dataStruct.Interface())
dataStruct = &structValue
}
level := len(fieldPath)
fieldValue = dataStruct.FieldByName(fieldPath[0])
for i := 0; i < level-1; i++ {
if !fieldValue.IsValid() {
break
}
if fieldValue.Kind() == reflect.Struct {
fieldValue = fieldValue.FieldByName(fieldPath[i+1])
} else if fieldValue.Kind() == reflect.Ptr {
if fieldValue.IsNil() {
fieldValue.Set(reflect.New(fieldValue.Type().Elem()))
}
fieldValue = fieldValue.Elem().FieldByName(fieldPath[i+1])
} else {
return nil, fmt.Errorf("field %v is not valid", col.FieldName)
}
}
if !fieldValue.IsValid() {
return nil, fmt.Errorf("field %v is not valid", col.FieldName)
}
return &fieldValue, nil
}

72
schemas/index.go Normal file
View File

@ -0,0 +1,72 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package schemas
import (
"fmt"
"strings"
)
// enumerate all index types
const (
IndexType = iota + 1
UniqueType
)
// Index represents a database index
type Index struct {
IsRegular bool
Name string
Type int
Cols []string
}
func (index *Index) XName(tableName string) string {
if !strings.HasPrefix(index.Name, "UQE_") &&
!strings.HasPrefix(index.Name, "IDX_") {
tableParts := strings.Split(strings.Replace(tableName, `"`, "", -1), ".")
tableName = tableParts[len(tableParts)-1]
if index.Type == UniqueType {
return fmt.Sprintf("UQE_%v_%v", tableName, index.Name)
}
return fmt.Sprintf("IDX_%v_%v", tableName, index.Name)
}
return index.Name
}
// AddColumn add columns which will be composite index
func (index *Index) AddColumn(cols ...string) {
for _, col := range cols {
index.Cols = append(index.Cols, col)
}
}
func (index *Index) Equal(dst *Index) bool {
if index.Type != dst.Type {
return false
}
if len(index.Cols) != len(dst.Cols) {
return false
}
for i := 0; i < len(index.Cols); i++ {
var found bool
for j := 0; j < len(dst.Cols); j++ {
if index.Cols[i] == dst.Cols[j] {
found = true
break
}
}
if !found {
return false
}
}
return true
}
// NewIndex new an index object
func NewIndex(name string, indexType int) *Index {
return &Index{true, name, indexType, make([]string, 0)}
}

30
schemas/pk.go Normal file
View File

@ -0,0 +1,30 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package schemas
import (
"bytes"
"encoding/gob"
)
type PK []interface{}
func NewPK(pks ...interface{}) *PK {
p := PK(pks)
return &p
}
func (p *PK) ToString() (string, error) {
buf := new(bytes.Buffer)
enc := gob.NewEncoder(buf)
err := enc.Encode(*p)
return buf.String(), err
}
func (p *PK) FromString(content string) error {
dec := gob.NewDecoder(bytes.NewBufferString(content))
err := dec.Decode(p)
return err
}

36
schemas/pk_test.go Normal file
View File

@ -0,0 +1,36 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package schemas
import (
"reflect"
"testing"
)
func TestPK(t *testing.T) {
p := NewPK(1, 3, "string")
str, err := p.ToString()
if err != nil {
t.Error(err)
}
t.Log(str)
s := &PK{}
err = s.FromString(str)
if err != nil {
t.Error(err)
}
t.Log(s)
if len(*p) != len(*s) {
t.Fatal("p", *p, "should be equal", *s)
}
for i, ori := range *p {
if ori != (*s)[i] {
t.Fatal("ori", ori, reflect.ValueOf(ori), "should be equal", (*s)[i], reflect.ValueOf((*s)[i]))
}
}
}

156
schemas/table.go Normal file
View File

@ -0,0 +1,156 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package schemas
import (
"reflect"
"strings"
//"xorm.io/xorm/cache"
)
// Table represents a database table
type Table struct {
Name string
Type reflect.Type
columnsSeq []string
columnsMap map[string][]*Column
columns []*Column
Indexes map[string]*Index
PrimaryKeys []string
AutoIncrement string
Created map[string]bool
Updated string
Deleted string
Version string
//Cacher caches.Cacher
StoreEngine string
Charset string
Comment string
}
func (table *Table) Columns() []*Column {
return table.columns
}
func (table *Table) ColumnsSeq() []string {
return table.columnsSeq
}
func NewEmptyTable() *Table {
return NewTable("", nil)
}
// NewTable creates a new Table object
func NewTable(name string, t reflect.Type) *Table {
return &Table{Name: name, Type: t,
columnsSeq: make([]string, 0),
columns: make([]*Column, 0),
columnsMap: make(map[string][]*Column),
Indexes: make(map[string]*Index),
Created: make(map[string]bool),
PrimaryKeys: make([]string, 0),
}
}
func (table *Table) columnsByName(name string) []*Column {
n := len(name)
for k := range table.columnsMap {
if len(k) != n {
continue
}
if strings.EqualFold(k, name) {
return table.columnsMap[k]
}
}
return nil
}
func (table *Table) GetColumn(name string) *Column {
cols := table.columnsByName(name)
if cols != nil {
return cols[0]
}
return nil
}
func (table *Table) GetColumnIdx(name string, idx int) *Column {
cols := table.columnsByName(name)
if cols != nil && idx < len(cols) {
return cols[idx]
}
return nil
}
// PKColumns reprents all primary key columns
func (table *Table) PKColumns() []*Column {
columns := make([]*Column, len(table.PrimaryKeys))
for i, name := range table.PrimaryKeys {
columns[i] = table.GetColumn(name)
}
return columns
}
func (table *Table) ColumnType(name string) reflect.Type {
t, _ := table.Type.FieldByName(name)
return t.Type
}
func (table *Table) AutoIncrColumn() *Column {
return table.GetColumn(table.AutoIncrement)
}
func (table *Table) VersionColumn() *Column {
return table.GetColumn(table.Version)
}
func (table *Table) UpdatedColumn() *Column {
return table.GetColumn(table.Updated)
}
func (table *Table) DeletedColumn() *Column {
return table.GetColumn(table.Deleted)
}
// AddColumn adds a column to table
func (table *Table) AddColumn(col *Column) {
table.columnsSeq = append(table.columnsSeq, col.Name)
table.columns = append(table.columns, col)
colName := strings.ToLower(col.Name)
if c, ok := table.columnsMap[colName]; ok {
table.columnsMap[colName] = append(c, col)
} else {
table.columnsMap[colName] = []*Column{col}
}
if col.IsPrimaryKey {
table.PrimaryKeys = append(table.PrimaryKeys, col.Name)
}
if col.IsAutoIncrement {
table.AutoIncrement = col.Name
}
if col.IsCreated {
table.Created[col.Name] = true
}
if col.IsUpdated {
table.Updated = col.Name
}
if col.IsDeleted {
table.Deleted = col.Name
}
if col.IsVersion {
table.Version = col.Name
}
}
// AddIndex adds an index or an unique to table
func (table *Table) AddIndex(index *Index) {
table.Indexes[index.Name] = index
}

111
schemas/table_test.go Normal file
View File

@ -0,0 +1,111 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package schemas
import (
"strings"
"testing"
)
var testsGetColumn = []struct {
name string
idx int
}{
{"Id", 0},
{"Deleted", 0},
{"Caption", 0},
{"Code_1", 0},
{"Code_2", 0},
{"Code_3", 0},
{"Parent_Id", 0},
{"Latitude", 0},
{"Longitude", 0},
}
var table *Table
func init() {
table = NewEmptyTable()
var name string
for i := 0; i < len(testsGetColumn); i++ {
// as in Table.AddColumn func
name = strings.ToLower(testsGetColumn[i].name)
table.columnsMap[name] = append(table.columnsMap[name], &Column{})
}
}
func TestGetColumn(t *testing.T) {
for _, test := range testsGetColumn {
if table.GetColumn(test.name) == nil {
t.Error("Column not found!")
}
}
}
func TestGetColumnIdx(t *testing.T) {
for _, test := range testsGetColumn {
if table.GetColumnIdx(test.name, test.idx) == nil {
t.Errorf("Column %s with idx %d not found!", test.name, test.idx)
}
}
}
func BenchmarkGetColumnWithToLower(b *testing.B) {
for i := 0; i < b.N; i++ {
for _, test := range testsGetColumn {
if _, ok := table.columnsMap[strings.ToLower(test.name)]; !ok {
b.Errorf("Column not found:%s", test.name)
}
}
}
}
func BenchmarkGetColumnIdxWithToLower(b *testing.B) {
for i := 0; i < b.N; i++ {
for _, test := range testsGetColumn {
if c, ok := table.columnsMap[strings.ToLower(test.name)]; ok {
if test.idx < len(c) {
continue
} else {
b.Errorf("Bad idx in: %s, %d", test.name, test.idx)
}
} else {
b.Errorf("Column not found: %s, %d", test.name, test.idx)
}
}
}
}
func BenchmarkGetColumn(b *testing.B) {
for i := 0; i < b.N; i++ {
for _, test := range testsGetColumn {
if table.GetColumn(test.name) == nil {
b.Errorf("Column not found:%s", test.name)
}
}
}
}
func BenchmarkGetColumnIdx(b *testing.B) {
for i := 0; i < b.N; i++ {
for _, test := range testsGetColumn {
if table.GetColumnIdx(test.name, test.idx) == nil {
b.Errorf("Column not found:%s, %d", test.name, test.idx)
}
}
}
}

325
schemas/type.go Normal file
View File

@ -0,0 +1,325 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package schemas
import (
"reflect"
"sort"
"strings"
"time"
)
const (
POSTGRES = "postgres"
SQLITE = "sqlite3"
MYSQL = "mysql"
MSSQL = "mssql"
ORACLE = "oracle"
)
// SQLType represents SQL types
type SQLType struct {
Name string
DefaultLength int
DefaultLength2 int
}
const (
UNKNOW_TYPE = iota
TEXT_TYPE
BLOB_TYPE
TIME_TYPE
NUMERIC_TYPE
)
func (s *SQLType) IsType(st int) bool {
if t, ok := SqlTypes[s.Name]; ok && t == st {
return true
}
return false
}
func (s *SQLType) IsText() bool {
return s.IsType(TEXT_TYPE)
}
func (s *SQLType) IsBlob() bool {
return s.IsType(BLOB_TYPE)
}
func (s *SQLType) IsTime() bool {
return s.IsType(TIME_TYPE)
}
func (s *SQLType) IsNumeric() bool {
return s.IsType(NUMERIC_TYPE)
}
func (s *SQLType) IsJson() bool {
return s.Name == Json || s.Name == Jsonb
}
var (
Bit = "BIT"
TinyInt = "TINYINT"
SmallInt = "SMALLINT"
MediumInt = "MEDIUMINT"
Int = "INT"
Integer = "INTEGER"
BigInt = "BIGINT"
Enum = "ENUM"
Set = "SET"
Char = "CHAR"
Varchar = "VARCHAR"
NChar = "NCHAR"
NVarchar = "NVARCHAR"
TinyText = "TINYTEXT"
Text = "TEXT"
NText = "NTEXT"
Clob = "CLOB"
MediumText = "MEDIUMTEXT"
LongText = "LONGTEXT"
Uuid = "UUID"
UniqueIdentifier = "UNIQUEIDENTIFIER"
SysName = "SYSNAME"
Date = "DATE"
DateTime = "DATETIME"
SmallDateTime = "SMALLDATETIME"
Time = "TIME"
TimeStamp = "TIMESTAMP"
TimeStampz = "TIMESTAMPZ"
Year = "YEAR"
Decimal = "DECIMAL"
Numeric = "NUMERIC"
Money = "MONEY"
SmallMoney = "SMALLMONEY"
Real = "REAL"
Float = "FLOAT"
Double = "DOUBLE"
Binary = "BINARY"
VarBinary = "VARBINARY"
TinyBlob = "TINYBLOB"
Blob = "BLOB"
MediumBlob = "MEDIUMBLOB"
LongBlob = "LONGBLOB"
Bytea = "BYTEA"
Bool = "BOOL"
Boolean = "BOOLEAN"
Serial = "SERIAL"
BigSerial = "BIGSERIAL"
Json = "JSON"
Jsonb = "JSONB"
SqlTypes = map[string]int{
Bit: NUMERIC_TYPE,
TinyInt: NUMERIC_TYPE,
SmallInt: NUMERIC_TYPE,
MediumInt: NUMERIC_TYPE,
Int: NUMERIC_TYPE,
Integer: NUMERIC_TYPE,
BigInt: NUMERIC_TYPE,
Enum: TEXT_TYPE,
Set: TEXT_TYPE,
Json: TEXT_TYPE,
Jsonb: TEXT_TYPE,
Char: TEXT_TYPE,
NChar: TEXT_TYPE,
Varchar: TEXT_TYPE,
NVarchar: TEXT_TYPE,
TinyText: TEXT_TYPE,
Text: TEXT_TYPE,
NText: TEXT_TYPE,
MediumText: TEXT_TYPE,
LongText: TEXT_TYPE,
Uuid: TEXT_TYPE,
Clob: TEXT_TYPE,
SysName: TEXT_TYPE,
Date: TIME_TYPE,
DateTime: TIME_TYPE,
Time: TIME_TYPE,
TimeStamp: TIME_TYPE,
TimeStampz: TIME_TYPE,
SmallDateTime: TIME_TYPE,
Year: TIME_TYPE,
Decimal: NUMERIC_TYPE,
Numeric: NUMERIC_TYPE,
Real: NUMERIC_TYPE,
Float: NUMERIC_TYPE,
Double: NUMERIC_TYPE,
Money: NUMERIC_TYPE,
SmallMoney: NUMERIC_TYPE,
Binary: BLOB_TYPE,
VarBinary: BLOB_TYPE,
TinyBlob: BLOB_TYPE,
Blob: BLOB_TYPE,
MediumBlob: BLOB_TYPE,
LongBlob: BLOB_TYPE,
Bytea: BLOB_TYPE,
UniqueIdentifier: BLOB_TYPE,
Bool: NUMERIC_TYPE,
Serial: NUMERIC_TYPE,
BigSerial: NUMERIC_TYPE,
}
intTypes = sort.StringSlice{"*int", "*int16", "*int32", "*int8"}
uintTypes = sort.StringSlice{"*uint", "*uint16", "*uint32", "*uint8"}
)
// !nashtsai! treat following var as interal const values, these are used for reflect.TypeOf comparison
var (
c_EMPTY_STRING string
c_BOOL_DEFAULT bool
c_BYTE_DEFAULT byte
c_COMPLEX64_DEFAULT complex64
c_COMPLEX128_DEFAULT complex128
c_FLOAT32_DEFAULT float32
c_FLOAT64_DEFAULT float64
c_INT64_DEFAULT int64
c_UINT64_DEFAULT uint64
c_INT32_DEFAULT int32
c_UINT32_DEFAULT uint32
c_INT16_DEFAULT int16
c_UINT16_DEFAULT uint16
c_INT8_DEFAULT int8
c_UINT8_DEFAULT uint8
c_INT_DEFAULT int
c_UINT_DEFAULT uint
c_TIME_DEFAULT time.Time
)
var (
IntType = reflect.TypeOf(c_INT_DEFAULT)
Int8Type = reflect.TypeOf(c_INT8_DEFAULT)
Int16Type = reflect.TypeOf(c_INT16_DEFAULT)
Int32Type = reflect.TypeOf(c_INT32_DEFAULT)
Int64Type = reflect.TypeOf(c_INT64_DEFAULT)
UintType = reflect.TypeOf(c_UINT_DEFAULT)
Uint8Type = reflect.TypeOf(c_UINT8_DEFAULT)
Uint16Type = reflect.TypeOf(c_UINT16_DEFAULT)
Uint32Type = reflect.TypeOf(c_UINT32_DEFAULT)
Uint64Type = reflect.TypeOf(c_UINT64_DEFAULT)
Float32Type = reflect.TypeOf(c_FLOAT32_DEFAULT)
Float64Type = reflect.TypeOf(c_FLOAT64_DEFAULT)
Complex64Type = reflect.TypeOf(c_COMPLEX64_DEFAULT)
Complex128Type = reflect.TypeOf(c_COMPLEX128_DEFAULT)
StringType = reflect.TypeOf(c_EMPTY_STRING)
BoolType = reflect.TypeOf(c_BOOL_DEFAULT)
ByteType = reflect.TypeOf(c_BYTE_DEFAULT)
BytesType = reflect.SliceOf(ByteType)
TimeType = reflect.TypeOf(c_TIME_DEFAULT)
)
var (
PtrIntType = reflect.PtrTo(IntType)
PtrInt8Type = reflect.PtrTo(Int8Type)
PtrInt16Type = reflect.PtrTo(Int16Type)
PtrInt32Type = reflect.PtrTo(Int32Type)
PtrInt64Type = reflect.PtrTo(Int64Type)
PtrUintType = reflect.PtrTo(UintType)
PtrUint8Type = reflect.PtrTo(Uint8Type)
PtrUint16Type = reflect.PtrTo(Uint16Type)
PtrUint32Type = reflect.PtrTo(Uint32Type)
PtrUint64Type = reflect.PtrTo(Uint64Type)
PtrFloat32Type = reflect.PtrTo(Float32Type)
PtrFloat64Type = reflect.PtrTo(Float64Type)
PtrComplex64Type = reflect.PtrTo(Complex64Type)
PtrComplex128Type = reflect.PtrTo(Complex128Type)
PtrStringType = reflect.PtrTo(StringType)
PtrBoolType = reflect.PtrTo(BoolType)
PtrByteType = reflect.PtrTo(ByteType)
PtrTimeType = reflect.PtrTo(TimeType)
)
// Type2SQLType generate SQLType acorrding Go's type
func Type2SQLType(t reflect.Type) (st SQLType) {
switch k := t.Kind(); k {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
st = SQLType{Int, 0, 0}
case reflect.Int64, reflect.Uint64:
st = SQLType{BigInt, 0, 0}
case reflect.Float32:
st = SQLType{Float, 0, 0}
case reflect.Float64:
st = SQLType{Double, 0, 0}
case reflect.Complex64, reflect.Complex128:
st = SQLType{Varchar, 64, 0}
case reflect.Array, reflect.Slice, reflect.Map:
if t.Elem() == reflect.TypeOf(c_BYTE_DEFAULT) {
st = SQLType{Blob, 0, 0}
} else {
st = SQLType{Text, 0, 0}
}
case reflect.Bool:
st = SQLType{Bool, 0, 0}
case reflect.String:
st = SQLType{Varchar, 255, 0}
case reflect.Struct:
if t.ConvertibleTo(TimeType) {
st = SQLType{DateTime, 0, 0}
} else {
// TODO need to handle association struct
st = SQLType{Text, 0, 0}
}
case reflect.Ptr:
st = Type2SQLType(t.Elem())
default:
st = SQLType{Text, 0, 0}
}
return
}
// default sql type change to go types
func SQLType2Type(st SQLType) reflect.Type {
name := strings.ToUpper(st.Name)
switch name {
case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, Serial:
return reflect.TypeOf(1)
case BigInt, BigSerial:
return reflect.TypeOf(int64(1))
case Float, Real:
return reflect.TypeOf(float32(1))
case Double:
return reflect.TypeOf(float64(1))
case Char, NChar, Varchar, NVarchar, TinyText, Text, NText, MediumText, LongText, Enum, Set, Uuid, Clob, SysName:
return reflect.TypeOf("")
case TinyBlob, Blob, LongBlob, Bytea, Binary, MediumBlob, VarBinary, UniqueIdentifier:
return reflect.TypeOf([]byte{})
case Bool:
return reflect.TypeOf(true)
case DateTime, Date, Time, TimeStamp, TimeStampz, SmallDateTime, Year:
return reflect.TypeOf(c_TIME_DEFAULT)
case Decimal, Numeric, Money, SmallMoney:
return reflect.TypeOf("")
default:
return reflect.TypeOf("")
}
}

View File

@ -14,7 +14,8 @@ import (
"strings" "strings"
"time" "time"
"xorm.io/core" "xorm.io/xorm/core"
"xorm.io/xorm/schemas"
) )
type sessionType int type sessionType int
@ -306,8 +307,8 @@ func (session *Session) doPrepare(db *core.DB, sqlStr string) (stmt *core.Stmt,
return return
} }
func (session *Session) getField(dataStruct *reflect.Value, key string, table *core.Table, idx int) (*reflect.Value, error) { func (session *Session) getField(dataStruct *reflect.Value, key string, table *schemas.Table, idx int) (*reflect.Value, error) {
var col *core.Column var col *schemas.Column
if col = table.GetColumnIdx(key, idx); col == nil { if col = table.GetColumnIdx(key, idx); col == nil {
return nil, ErrFieldIsNotExist{key, table.Name} return nil, ErrFieldIsNotExist{key, table.Name}
} }
@ -328,8 +329,8 @@ func (session *Session) getField(dataStruct *reflect.Value, key string, table *c
type Cell *interface{} type Cell *interface{}
func (session *Session) rows2Beans(rows *core.Rows, fields []string, func (session *Session) rows2Beans(rows *core.Rows, fields []string,
table *core.Table, newElemFunc func([]string) reflect.Value, table *schemas.Table, newElemFunc func([]string) reflect.Value,
sliceValueSetFunc func(*reflect.Value, core.PK) error) error { sliceValueSetFunc func(*reflect.Value, schemas.PK) error) error {
for rows.Next() { for rows.Next() {
var newValue = newElemFunc(fields) var newValue = newElemFunc(fields)
bean := newValue.Interface() bean := newValue.Interface()
@ -377,7 +378,7 @@ func (session *Session) row2Slice(rows *core.Rows, fields []string, bean interfa
return scanResults, nil return scanResults, nil
} }
func (session *Session) slice2Bean(scanResults []interface{}, fields []string, bean interface{}, dataStruct *reflect.Value, table *core.Table) (core.PK, error) { func (session *Session) slice2Bean(scanResults []interface{}, fields []string, bean interface{}, dataStruct *reflect.Value, table *schemas.Table) (schemas.PK, error) {
defer func() { defer func() {
if b, hasAfterSet := bean.(AfterSetProcessor); hasAfterSet { if b, hasAfterSet := bean.(AfterSetProcessor); hasAfterSet {
for ii, key := range fields { for ii, key := range fields {
@ -421,7 +422,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
} }
var tempMap = make(map[string]int) var tempMap = make(map[string]int)
var pk core.PK var pk schemas.PK
for ii, key := range fields { for ii, key := range fields {
var idx int var idx int
var ok bool var ok bool
@ -451,7 +452,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
} }
if fieldValue.CanAddr() { if fieldValue.CanAddr() {
if structConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok { if structConvert, ok := fieldValue.Addr().Interface().(Conversion); ok {
if data, err := value2Bytes(&rawValue); err == nil { if data, err := value2Bytes(&rawValue); err == nil {
if err := structConvert.FromDB(data); err != nil { if err := structConvert.FromDB(data); err != nil {
return nil, err return nil, err
@ -463,12 +464,12 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
} }
} }
if _, ok := fieldValue.Interface().(core.Conversion); ok { if _, ok := fieldValue.Interface().(Conversion); ok {
if data, err := value2Bytes(&rawValue); err == nil { if data, err := value2Bytes(&rawValue); err == nil {
if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() {
fieldValue.Set(reflect.New(fieldValue.Type().Elem())) fieldValue.Set(reflect.New(fieldValue.Type().Elem()))
} }
fieldValue.Interface().(core.Conversion).FromDB(data) fieldValue.Interface().(Conversion).FromDB(data)
} else { } else {
return nil, err return nil, err
} }
@ -488,7 +489,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
var bs []byte var bs []byte
if rawValueType.Kind() == reflect.String { if rawValueType.Kind() == reflect.String {
bs = []byte(vv.String()) bs = []byte(vv.String())
} else if rawValueType.ConvertibleTo(core.BytesType) { } else if rawValueType.ConvertibleTo(schemas.BytesType) {
bs = vv.Bytes() bs = vv.Bytes()
} else { } else {
return nil, fmt.Errorf("unsupported database data type: %s %v", key, rawValueType.Kind()) return nil, fmt.Errorf("unsupported database data type: %s %v", key, rawValueType.Kind())
@ -525,7 +526,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
var bs []byte var bs []byte
if rawValueType.Kind() == reflect.String { if rawValueType.Kind() == reflect.String {
bs = []byte(vv.String()) bs = []byte(vv.String())
} else if rawValueType.ConvertibleTo(core.BytesType) { } else if rawValueType.ConvertibleTo(schemas.BytesType) {
bs = vv.Bytes() bs = vv.Bytes()
} }
@ -607,16 +608,16 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
fieldValue.SetUint(uint64(vv.Int())) fieldValue.SetUint(uint64(vv.Int()))
} }
case reflect.Struct: case reflect.Struct:
if fieldType.ConvertibleTo(core.TimeType) { if fieldType.ConvertibleTo(schemas.TimeType) {
dbTZ := session.engine.DatabaseTZ dbTZ := session.engine.DatabaseTZ
if col.TimeZone != nil { if col.TimeZone != nil {
dbTZ = col.TimeZone dbTZ = col.TimeZone
} }
if rawValueType == core.TimeType { if rawValueType == schemas.TimeType {
hasAssigned = true hasAssigned = true
t := vv.Convert(core.TimeType).Interface().(time.Time) t := vv.Convert(schemas.TimeType).Interface().(time.Time)
z, _ := t.Zone() z, _ := t.Zone()
// set new location if database don't save timezone or give an incorrect timezone // set new location if database don't save timezone or give an incorrect timezone
@ -628,8 +629,8 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
t = t.In(session.engine.TZLocation) t = t.In(session.engine.TZLocation)
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
} else if rawValueType == core.IntType || rawValueType == core.Int64Type || } else if rawValueType == schemas.IntType || rawValueType == schemas.Int64Type ||
rawValueType == core.Int32Type { rawValueType == schemas.Int32Type {
hasAssigned = true hasAssigned = true
t := time.Unix(vv.Int(), 0).In(session.engine.TZLocation) t := time.Unix(vv.Int(), 0).In(session.engine.TZLocation)
@ -696,7 +697,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
if len(table.PrimaryKeys) != 1 { if len(table.PrimaryKeys) != 1 {
return nil, errors.New("unsupported non or composited primary key cascade") return nil, errors.New("unsupported non or composited primary key cascade")
} }
var pk = make(core.PK, len(table.PrimaryKeys)) var pk = make(schemas.PK, len(table.PrimaryKeys))
pk[0], err = asKind(vv, rawValueType) pk[0], err = asKind(vv, rawValueType)
if err != nil { if err != nil {
return nil, err return nil, err
@ -722,97 +723,97 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
// !nashtsai! TODO merge duplicated codes above // !nashtsai! TODO merge duplicated codes above
switch fieldType { switch fieldType {
// following types case matching ptr's native type, therefore assign ptr directly // following types case matching ptr's native type, therefore assign ptr directly
case core.PtrStringType: case schemas.PtrStringType:
if rawValueType.Kind() == reflect.String { if rawValueType.Kind() == reflect.String {
x := vv.String() x := vv.String()
hasAssigned = true hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x)) fieldValue.Set(reflect.ValueOf(&x))
} }
case core.PtrBoolType: case schemas.PtrBoolType:
if rawValueType.Kind() == reflect.Bool { if rawValueType.Kind() == reflect.Bool {
x := vv.Bool() x := vv.Bool()
hasAssigned = true hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x)) fieldValue.Set(reflect.ValueOf(&x))
} }
case core.PtrTimeType: case schemas.PtrTimeType:
if rawValueType == core.PtrTimeType { if rawValueType == schemas.PtrTimeType {
hasAssigned = true hasAssigned = true
var x = rawValue.Interface().(time.Time) var x = rawValue.Interface().(time.Time)
fieldValue.Set(reflect.ValueOf(&x)) fieldValue.Set(reflect.ValueOf(&x))
} }
case core.PtrFloat64Type: case schemas.PtrFloat64Type:
if rawValueType.Kind() == reflect.Float64 { if rawValueType.Kind() == reflect.Float64 {
x := vv.Float() x := vv.Float()
hasAssigned = true hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x)) fieldValue.Set(reflect.ValueOf(&x))
} }
case core.PtrUint64Type: case schemas.PtrUint64Type:
if rawValueType.Kind() == reflect.Int64 { if rawValueType.Kind() == reflect.Int64 {
var x = uint64(vv.Int()) var x = uint64(vv.Int())
hasAssigned = true hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x)) fieldValue.Set(reflect.ValueOf(&x))
} }
case core.PtrInt64Type: case schemas.PtrInt64Type:
if rawValueType.Kind() == reflect.Int64 { if rawValueType.Kind() == reflect.Int64 {
x := vv.Int() x := vv.Int()
hasAssigned = true hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x)) fieldValue.Set(reflect.ValueOf(&x))
} }
case core.PtrFloat32Type: case schemas.PtrFloat32Type:
if rawValueType.Kind() == reflect.Float64 { if rawValueType.Kind() == reflect.Float64 {
var x = float32(vv.Float()) var x = float32(vv.Float())
hasAssigned = true hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x)) fieldValue.Set(reflect.ValueOf(&x))
} }
case core.PtrIntType: case schemas.PtrIntType:
if rawValueType.Kind() == reflect.Int64 { if rawValueType.Kind() == reflect.Int64 {
var x = int(vv.Int()) var x = int(vv.Int())
hasAssigned = true hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x)) fieldValue.Set(reflect.ValueOf(&x))
} }
case core.PtrInt32Type: case schemas.PtrInt32Type:
if rawValueType.Kind() == reflect.Int64 { if rawValueType.Kind() == reflect.Int64 {
var x = int32(vv.Int()) var x = int32(vv.Int())
hasAssigned = true hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x)) fieldValue.Set(reflect.ValueOf(&x))
} }
case core.PtrInt8Type: case schemas.PtrInt8Type:
if rawValueType.Kind() == reflect.Int64 { if rawValueType.Kind() == reflect.Int64 {
var x = int8(vv.Int()) var x = int8(vv.Int())
hasAssigned = true hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x)) fieldValue.Set(reflect.ValueOf(&x))
} }
case core.PtrInt16Type: case schemas.PtrInt16Type:
if rawValueType.Kind() == reflect.Int64 { if rawValueType.Kind() == reflect.Int64 {
var x = int16(vv.Int()) var x = int16(vv.Int())
hasAssigned = true hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x)) fieldValue.Set(reflect.ValueOf(&x))
} }
case core.PtrUintType: case schemas.PtrUintType:
if rawValueType.Kind() == reflect.Int64 { if rawValueType.Kind() == reflect.Int64 {
var x = uint(vv.Int()) var x = uint(vv.Int())
hasAssigned = true hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x)) fieldValue.Set(reflect.ValueOf(&x))
} }
case core.PtrUint32Type: case schemas.PtrUint32Type:
if rawValueType.Kind() == reflect.Int64 { if rawValueType.Kind() == reflect.Int64 {
var x = uint32(vv.Int()) var x = uint32(vv.Int())
hasAssigned = true hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x)) fieldValue.Set(reflect.ValueOf(&x))
} }
case core.Uint8Type: case schemas.Uint8Type:
if rawValueType.Kind() == reflect.Int64 { if rawValueType.Kind() == reflect.Int64 {
var x = uint8(vv.Int()) var x = uint8(vv.Int())
hasAssigned = true hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x)) fieldValue.Set(reflect.ValueOf(&x))
} }
case core.Uint16Type: case schemas.Uint16Type:
if rawValueType.Kind() == reflect.Int64 { if rawValueType.Kind() == reflect.Int64 {
var x = uint16(vv.Int()) var x = uint16(vv.Int())
hasAssigned = true hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x)) fieldValue.Set(reflect.ValueOf(&x))
} }
case core.Complex64Type: case schemas.Complex64Type:
var x complex64 var x complex64
if len([]byte(vv.String())) > 0 { if len([]byte(vv.String())) > 0 {
err := DefaultJSONHandler.Unmarshal([]byte(vv.String()), &x) err := DefaultJSONHandler.Unmarshal([]byte(vv.String()), &x)
@ -822,7 +823,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
fieldValue.Set(reflect.ValueOf(&x)) fieldValue.Set(reflect.ValueOf(&x))
} }
hasAssigned = true hasAssigned = true
case core.Complex128Type: case schemas.Complex128Type:
var x complex128 var x complex128
if len([]byte(vv.String())) > 0 { if len([]byte(vv.String())) > 0 {
err := DefaultJSONHandler.Unmarshal([]byte(vv.String()), &x) err := DefaultJSONHandler.Unmarshal([]byte(vv.String()), &x)

View File

@ -9,10 +9,10 @@ import (
"strings" "strings"
"time" "time"
"xorm.io/core" "xorm.io/xorm/schemas"
) )
func setColumnInt(bean interface{}, col *core.Column, t int64) { func setColumnInt(bean interface{}, col *schemas.Column, t int64) {
v, err := col.ValueOf(bean) v, err := col.ValueOf(bean)
if err != nil { if err != nil {
return return
@ -27,7 +27,7 @@ func setColumnInt(bean interface{}, col *core.Column, t int64) {
} }
} }
func setColumnTime(bean interface{}, col *core.Column, t time.Time) { func setColumnTime(bean interface{}, col *schemas.Column, t time.Time) {
v, err := col.ValueOf(bean) v, err := col.ValueOf(bean)
if err != nil { if err != nil {
return return
@ -44,7 +44,7 @@ func setColumnTime(bean interface{}, col *core.Column, t time.Time) {
} }
} }
func getFlagForColumn(m map[string]bool, col *core.Column) (val bool, has bool) { func getFlagForColumn(m map[string]bool, col *schemas.Column) (val bool, has bool) {
if len(m) == 0 { if len(m) == 0 {
return false, false return false, false
} }

View File

@ -9,7 +9,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"xorm.io/builder" "xorm.io/builder"
"xorm.io/core" "xorm.io/xorm/schemas"
) )
func TestSetExpr(t *testing.T) { func TestSetExpr(t *testing.T) {
@ -45,7 +45,7 @@ func TestSetExpr(t *testing.T) {
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
var not = "NOT" var not = "NOT"
if testEngine.Dialect().DBType() == core.MSSQL { if testEngine.Dialect().DBType() == schemas.MSSQL {
not = "~" not = "~"
} }
cnt, err = testEngine.SetExpr("show", not+" `show`").ID(1).Update(new(UserExpr)) cnt, err = testEngine.SetExpr("show", not+" `show`").ID(1).Update(new(UserExpr))

View File

@ -14,10 +14,10 @@ import (
"strings" "strings"
"time" "time"
"xorm.io/core" "xorm.io/xorm/schemas"
) )
func (session *Session) str2Time(col *core.Column, data string) (outTime time.Time, outErr error) { func (session *Session) str2Time(col *schemas.Column, data string) (outTime time.Time, outErr error) {
sdata := strings.TrimSpace(data) sdata := strings.TrimSpace(data)
var x time.Time var x time.Time
var err error var err error
@ -54,14 +54,14 @@ func (session *Session) str2Time(col *core.Column, data string) (outTime time.Ti
} else if len(sdata) == 10 && sdata[4] == '-' && sdata[7] == '-' { } else if len(sdata) == 10 && sdata[4] == '-' && sdata[7] == '-' {
x, err = time.ParseInLocation("2006-01-02", sdata, parseLoc) x, err = time.ParseInLocation("2006-01-02", sdata, parseLoc)
//session.engine.logger.Debugf("time(5) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) //session.engine.logger.Debugf("time(5) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
} else if col.SQLType.Name == core.Time { } else if col.SQLType.Name == schemas.Time {
if strings.Contains(sdata, " ") { if strings.Contains(sdata, " ") {
ssd := strings.Split(sdata, " ") ssd := strings.Split(sdata, " ")
sdata = ssd[1] sdata = ssd[1]
} }
sdata = strings.TrimSpace(sdata) sdata = strings.TrimSpace(sdata)
if session.engine.dialect.DBType() == core.MYSQL && len(sdata) > 8 { if session.engine.dialect.DBType() == schemas.MYSQL && len(sdata) > 8 {
sdata = sdata[len(sdata)-8:] sdata = sdata[len(sdata)-8:]
} }
@ -80,7 +80,7 @@ func (session *Session) str2Time(col *core.Column, data string) (outTime time.Ti
return return
} }
func (session *Session) byte2Time(col *core.Column, data []byte) (outTime time.Time, outErr error) { func (session *Session) byte2Time(col *schemas.Column, data []byte) (outTime time.Time, outErr error) {
return session.str2Time(col, string(data)) return session.str2Time(col, string(data))
} }
@ -89,12 +89,12 @@ var (
) )
// convert a db data([]byte) to a field value // convert a db data([]byte) to a field value
func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, data []byte) error { func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Value, data []byte) error {
if structConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok { if structConvert, ok := fieldValue.Addr().Interface().(Conversion); ok {
return structConvert.FromDB(data) return structConvert.FromDB(data)
} }
if structConvert, ok := fieldValue.Interface().(core.Conversion); ok { if structConvert, ok := fieldValue.Interface().(Conversion); ok {
return structConvert.FromDB(data) return structConvert.FromDB(data)
} }
@ -157,8 +157,8 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
var x int64 var x int64
var err error var err error
// for mysql, when use bit, it returned \x01 // for mysql, when use bit, it returned \x01
if col.SQLType.Name == core.Bit && if col.SQLType.Name == schemas.Bit &&
session.engine.dialect.DBType() == core.MYSQL { // !nashtsai! TODO dialect needs to provide conversion interface API session.engine.dialect.DBType() == schemas.MYSQL { // !nashtsai! TODO dialect needs to provide conversion interface API
if len(data) == 1 { if len(data) == 1 {
x = int64(data[0]) x = int64(data[0])
} else { } else {
@ -199,7 +199,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
return fmt.Errorf("sql.Scan(%v) failed: %s ", data, err.Error()) return fmt.Errorf("sql.Scan(%v) failed: %s ", data, err.Error())
} }
} else { } else {
if fieldType.ConvertibleTo(core.TimeType) { if fieldType.ConvertibleTo(schemas.TimeType) {
x, err := session.byte2Time(col, data) x, err := session.byte2Time(col, data)
if err != nil { if err != nil {
return err return err
@ -217,7 +217,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
return errors.New("unsupported composited primary key cascade") return errors.New("unsupported composited primary key cascade")
} }
var pk = make(core.PK, len(table.PrimaryKeys)) var pk = make(schemas.PK, len(table.PrimaryKeys))
rawValueType := table.ColumnType(table.PKColumns()[0].FieldName) rawValueType := table.ColumnType(table.PKColumns()[0].FieldName)
pk[0], err = str2PK(string(data), rawValueType) pk[0], err = str2PK(string(data), rawValueType)
if err != nil { if err != nil {
@ -247,11 +247,11 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
//typeStr := fieldType.String() //typeStr := fieldType.String()
switch fieldType.Elem().Kind() { switch fieldType.Elem().Kind() {
// case "*string": // case "*string":
case core.StringType.Kind(): case schemas.StringType.Kind():
x := string(data) x := string(data)
fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType))
// case "*bool": // case "*bool":
case core.BoolType.Kind(): case schemas.BoolType.Kind():
d := string(data) d := string(data)
v, err := strconv.ParseBool(d) v, err := strconv.ParseBool(d)
if err != nil { if err != nil {
@ -259,7 +259,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
} }
fieldValue.Set(reflect.ValueOf(&v).Convert(fieldType)) fieldValue.Set(reflect.ValueOf(&v).Convert(fieldType))
// case "*complex64": // case "*complex64":
case core.Complex64Type.Kind(): case schemas.Complex64Type.Kind():
var x complex64 var x complex64
if len(data) > 0 { if len(data) > 0 {
err := DefaultJSONHandler.Unmarshal(data, &x) err := DefaultJSONHandler.Unmarshal(data, &x)
@ -270,7 +270,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType))
} }
// case "*complex128": // case "*complex128":
case core.Complex128Type.Kind(): case schemas.Complex128Type.Kind():
var x complex128 var x complex128
if len(data) > 0 { if len(data) > 0 {
err := DefaultJSONHandler.Unmarshal(data, &x) err := DefaultJSONHandler.Unmarshal(data, &x)
@ -281,14 +281,14 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType))
} }
// case "*float64": // case "*float64":
case core.Float64Type.Kind(): case schemas.Float64Type.Kind():
x, err := strconv.ParseFloat(string(data), 64) x, err := strconv.ParseFloat(string(data), 64)
if err != nil { if err != nil {
return fmt.Errorf("arg %v as float64: %s", key, err.Error()) return fmt.Errorf("arg %v as float64: %s", key, err.Error())
} }
fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType))
// case "*float32": // case "*float32":
case core.Float32Type.Kind(): case schemas.Float32Type.Kind():
var x float32 var x float32
x1, err := strconv.ParseFloat(string(data), 32) x1, err := strconv.ParseFloat(string(data), 32)
if err != nil { if err != nil {
@ -297,7 +297,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
x = float32(x1) x = float32(x1)
fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType))
// case "*uint64": // case "*uint64":
case core.Uint64Type.Kind(): case schemas.Uint64Type.Kind():
var x uint64 var x uint64
x, err := strconv.ParseUint(string(data), 10, 64) x, err := strconv.ParseUint(string(data), 10, 64)
if err != nil { if err != nil {
@ -305,7 +305,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
} }
fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType))
// case "*uint": // case "*uint":
case core.UintType.Kind(): case schemas.UintType.Kind():
var x uint var x uint
x1, err := strconv.ParseUint(string(data), 10, 64) x1, err := strconv.ParseUint(string(data), 10, 64)
if err != nil { if err != nil {
@ -314,7 +314,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
x = uint(x1) x = uint(x1)
fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType))
// case "*uint32": // case "*uint32":
case core.Uint32Type.Kind(): case schemas.Uint32Type.Kind():
var x uint32 var x uint32
x1, err := strconv.ParseUint(string(data), 10, 64) x1, err := strconv.ParseUint(string(data), 10, 64)
if err != nil { if err != nil {
@ -323,7 +323,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
x = uint32(x1) x = uint32(x1)
fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType))
// case "*uint8": // case "*uint8":
case core.Uint8Type.Kind(): case schemas.Uint8Type.Kind():
var x uint8 var x uint8
x1, err := strconv.ParseUint(string(data), 10, 64) x1, err := strconv.ParseUint(string(data), 10, 64)
if err != nil { if err != nil {
@ -332,7 +332,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
x = uint8(x1) x = uint8(x1)
fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType))
// case "*uint16": // case "*uint16":
case core.Uint16Type.Kind(): case schemas.Uint16Type.Kind():
var x uint16 var x uint16
x1, err := strconv.ParseUint(string(data), 10, 64) x1, err := strconv.ParseUint(string(data), 10, 64)
if err != nil { if err != nil {
@ -341,12 +341,12 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
x = uint16(x1) x = uint16(x1)
fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType))
// case "*int64": // case "*int64":
case core.Int64Type.Kind(): case schemas.Int64Type.Kind():
sdata := string(data) sdata := string(data)
var x int64 var x int64
var err error var err error
// for mysql, when use bit, it returned \x01 // for mysql, when use bit, it returned \x01
if col.SQLType.Name == core.Bit && if col.SQLType.Name == schemas.Bit &&
strings.Contains(session.engine.DriverName(), "mysql") { strings.Contains(session.engine.DriverName(), "mysql") {
if len(data) == 1 { if len(data) == 1 {
x = int64(data[0]) x = int64(data[0])
@ -365,13 +365,13 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
} }
fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType))
// case "*int": // case "*int":
case core.IntType.Kind(): case schemas.IntType.Kind():
sdata := string(data) sdata := string(data)
var x int var x int
var x1 int64 var x1 int64
var err error var err error
// for mysql, when use bit, it returned \x01 // for mysql, when use bit, it returned \x01
if col.SQLType.Name == core.Bit && if col.SQLType.Name == schemas.Bit &&
strings.Contains(session.engine.DriverName(), "mysql") { strings.Contains(session.engine.DriverName(), "mysql") {
if len(data) == 1 { if len(data) == 1 {
x = int(data[0]) x = int(data[0])
@ -393,14 +393,14 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
} }
fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType))
// case "*int32": // case "*int32":
case core.Int32Type.Kind(): case schemas.Int32Type.Kind():
sdata := string(data) sdata := string(data)
var x int32 var x int32
var x1 int64 var x1 int64
var err error var err error
// for mysql, when use bit, it returned \x01 // for mysql, when use bit, it returned \x01
if col.SQLType.Name == core.Bit && if col.SQLType.Name == schemas.Bit &&
session.engine.dialect.DBType() == core.MYSQL { session.engine.dialect.DBType() == schemas.MYSQL {
if len(data) == 1 { if len(data) == 1 {
x = int32(data[0]) x = int32(data[0])
} else { } else {
@ -421,13 +421,13 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
} }
fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType))
// case "*int8": // case "*int8":
case core.Int8Type.Kind(): case schemas.Int8Type.Kind():
sdata := string(data) sdata := string(data)
var x int8 var x int8
var x1 int64 var x1 int64
var err error var err error
// for mysql, when use bit, it returned \x01 // for mysql, when use bit, it returned \x01
if col.SQLType.Name == core.Bit && if col.SQLType.Name == schemas.Bit &&
strings.Contains(session.engine.DriverName(), "mysql") { strings.Contains(session.engine.DriverName(), "mysql") {
if len(data) == 1 { if len(data) == 1 {
x = int8(data[0]) x = int8(data[0])
@ -449,13 +449,13 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
} }
fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType))
// case "*int16": // case "*int16":
case core.Int16Type.Kind(): case schemas.Int16Type.Kind():
sdata := string(data) sdata := string(data)
var x int16 var x int16
var x1 int64 var x1 int64
var err error var err error
// for mysql, when use bit, it returned \x01 // for mysql, when use bit, it returned \x01
if col.SQLType.Name == core.Bit && if col.SQLType.Name == schemas.Bit &&
strings.Contains(session.engine.DriverName(), "mysql") { strings.Contains(session.engine.DriverName(), "mysql") {
if len(data) == 1 { if len(data) == 1 {
x = int16(data[0]) x = int16(data[0])
@ -480,7 +480,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
case reflect.Struct: case reflect.Struct:
switch fieldType { switch fieldType {
// case "*.time.Time": // case "*.time.Time":
case core.PtrTimeType: case schemas.PtrTimeType:
x, err := session.byte2Time(col, data) x, err := session.byte2Time(col, data)
if err != nil { if err != nil {
return err return err
@ -499,7 +499,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
return errors.New("unsupported composited primary key cascade") return errors.New("unsupported composited primary key cascade")
} }
var pk = make(core.PK, len(table.PrimaryKeys)) var pk = make(schemas.PK, len(table.PrimaryKeys))
rawValueType := table.ColumnType(table.PKColumns()[0].FieldName) rawValueType := table.ColumnType(table.PKColumns()[0].FieldName)
pk[0], err = str2PK(string(data), rawValueType) pk[0], err = str2PK(string(data), rawValueType)
if err != nil { if err != nil {
@ -536,9 +536,9 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
} }
// convert a field value of a struct to interface for put into db // convert a field value of a struct to interface for put into db
func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Value) (interface{}, error) { func (session *Session) value2Interface(col *schemas.Column, fieldValue reflect.Value) (interface{}, error) {
if fieldValue.CanAddr() { if fieldValue.CanAddr() {
if fieldConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok { if fieldConvert, ok := fieldValue.Addr().Interface().(Conversion); ok {
data, err := fieldConvert.ToDB() data, err := fieldConvert.ToDB()
if err != nil { if err != nil {
return 0, err return 0, err
@ -550,7 +550,7 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val
} }
} }
if fieldConvert, ok := fieldValue.Interface().(core.Conversion); ok { if fieldConvert, ok := fieldValue.Interface().(Conversion); ok {
data, err := fieldConvert.ToDB() data, err := fieldConvert.ToDB()
if err != nil { if err != nil {
return 0, err return 0, err
@ -583,8 +583,8 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val
case reflect.String: case reflect.String:
return fieldValue.String(), nil return fieldValue.String(), nil
case reflect.Struct: case reflect.Struct:
if fieldType.ConvertibleTo(core.TimeType) { if fieldType.ConvertibleTo(schemas.TimeType) {
t := fieldValue.Convert(core.TimeType).Interface().(time.Time) t := fieldValue.Convert(schemas.TimeType).Interface().(time.Time)
tf := session.engine.formatColTime(col, t) tf := session.engine.formatColTime(col, t)
return tf, nil return tf, nil
} else if fieldType.ConvertibleTo(nullFloatType) { } else if fieldType.ConvertibleTo(nullFloatType) {

View File

@ -9,10 +9,11 @@ import (
"fmt" "fmt"
"strconv" "strconv"
"xorm.io/core" "xorm.io/xorm/caches"
"xorm.io/xorm/schemas"
) )
func (session *Session) cacheDelete(table *core.Table, tableName, sqlStr string, args ...interface{}) error { func (session *Session) cacheDelete(table *schemas.Table, tableName, sqlStr string, args ...interface{}) error {
if table == nil || if table == nil ||
session.tx != nil { session.tx != nil {
return ErrCacheFailed return ErrCacheFailed
@ -29,17 +30,17 @@ func (session *Session) cacheDelete(table *core.Table, tableName, sqlStr string,
cacher := session.engine.getCacher(tableName) cacher := session.engine.getCacher(tableName)
pkColumns := table.PKColumns() pkColumns := table.PKColumns()
ids, err := core.GetCacheSql(cacher, tableName, newsql, args) ids, err := caches.GetCacheSql(cacher, tableName, newsql, args)
if err != nil { if err != nil {
resultsSlice, err := session.queryBytes(newsql, args...) resultsSlice, err := session.queryBytes(newsql, args...)
if err != nil { if err != nil {
return err return err
} }
ids = make([]core.PK, 0) ids = make([]schemas.PK, 0)
if len(resultsSlice) > 0 { if len(resultsSlice) > 0 {
for _, data := range resultsSlice { for _, data := range resultsSlice {
var id int64 var id int64
var pk core.PK = make([]interface{}, 0) var pk schemas.PK = make([]interface{}, 0)
for _, col := range pkColumns { for _, col := range pkColumns {
if v, ok := data[col.Name]; !ok { if v, ok := data[col.Name]; !ok {
return errors.New("no id") return errors.New("no id")
@ -127,14 +128,14 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
if len(orderSQL) > 0 { if len(orderSQL) > 0 {
switch session.engine.dialect.DBType() { switch session.engine.dialect.DBType() {
case core.POSTGRES: case schemas.POSTGRES:
inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL) inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL)
if len(condSQL) > 0 { if len(condSQL) > 0 {
deleteSQL += " AND " + inSQL deleteSQL += " AND " + inSQL
} else { } else {
deleteSQL += " WHERE " + inSQL deleteSQL += " WHERE " + inSQL
} }
case core.SQLITE: case schemas.SQLITE:
inSQL := fmt.Sprintf("rowid IN (SELECT rowid FROM %s%s)", tableName, orderSQL) inSQL := fmt.Sprintf("rowid IN (SELECT rowid FROM %s%s)", tableName, orderSQL)
if len(condSQL) > 0 { if len(condSQL) > 0 {
deleteSQL += " AND " + inSQL deleteSQL += " AND " + inSQL
@ -142,7 +143,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
deleteSQL += " WHERE " + inSQL deleteSQL += " WHERE " + inSQL
} }
// TODO: how to handle delete limit on mssql? // TODO: how to handle delete limit on mssql?
case core.MSSQL: case schemas.MSSQL:
return 0, ErrNotImplemented return 0, ErrNotImplemented
default: default:
deleteSQL += orderSQL deleteSQL += orderSQL
@ -156,7 +157,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
copy(argsForCache, condArgs) copy(argsForCache, condArgs)
argsForCache = append(condArgs, argsForCache...) argsForCache = append(condArgs, argsForCache...)
} else { } else {
// !oinume! sqlStrForCache and argsForCache is needed to behave as executing "DELETE FROM ..." for cache. // !oinume! sqlStrForCache and argsForCache is needed to behave as executing "DELETE FROM ..." for caches.
copy(argsForCache, condArgs) copy(argsForCache, condArgs)
argsForCache = append(condArgs, argsForCache...) argsForCache = append(condArgs, argsForCache...)
@ -168,14 +169,14 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
if len(orderSQL) > 0 { if len(orderSQL) > 0 {
switch session.engine.dialect.DBType() { switch session.engine.dialect.DBType() {
case core.POSTGRES: case schemas.POSTGRES:
inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL) inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL)
if len(condSQL) > 0 { if len(condSQL) > 0 {
realSQL += " AND " + inSQL realSQL += " AND " + inSQL
} else { } else {
realSQL += " WHERE " + inSQL realSQL += " WHERE " + inSQL
} }
case core.SQLITE: case schemas.SQLITE:
inSQL := fmt.Sprintf("rowid IN (SELECT rowid FROM %s%s)", tableName, orderSQL) inSQL := fmt.Sprintf("rowid IN (SELECT rowid FROM %s%s)", tableName, orderSQL)
if len(condSQL) > 0 { if len(condSQL) > 0 {
realSQL += " AND " + inSQL realSQL += " AND " + inSQL
@ -183,7 +184,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
realSQL += " WHERE " + inSQL realSQL += " WHERE " + inSQL
} }
// TODO: how to handle delete limit on mssql? // TODO: how to handle delete limit on mssql?
case core.MSSQL: case schemas.MSSQL:
return 0, ErrNotImplemented return 0, ErrNotImplemented
default: default:
realSQL += orderSQL realSQL += orderSQL

View File

@ -9,7 +9,8 @@ import (
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"xorm.io/core" "xorm.io/xorm/caches"
"xorm.io/xorm/schemas"
) )
func TestDelete(t *testing.T) { func TestDelete(t *testing.T) {
@ -26,7 +27,7 @@ func TestDelete(t *testing.T) {
defer session.Close() defer session.Close()
var err error var err error
if testEngine.Dialect().DBType() == core.MSSQL { if testEngine.Dialect().DBType() == schemas.MSSQL {
err = session.Begin() err = session.Begin()
assert.NoError(t, err) assert.NoError(t, err)
_, err = session.Exec("SET IDENTITY_INSERT userinfo_delete ON") _, err = session.Exec("SET IDENTITY_INSERT userinfo_delete ON")
@ -38,7 +39,7 @@ func TestDelete(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
if testEngine.Dialect().DBType() == core.MSSQL { if testEngine.Dialect().DBType() == schemas.MSSQL {
err = session.Commit() err = session.Commit()
assert.NoError(t, err) assert.NoError(t, err)
} }
@ -159,7 +160,7 @@ func TestCacheDelete(t *testing.T) {
assert.NoError(t, prepareEngine()) assert.NoError(t, prepareEngine())
oldCacher := testEngine.GetDefaultCacher() oldCacher := testEngine.GetDefaultCacher()
cacher := NewLRUCacher(NewMemoryStore(), 1000) cacher := caches.NewLRUCacher(caches.NewMemoryStore(), 1000)
testEngine.SetDefaultCacher(cacher) testEngine.SetDefaultCacher(cacher)
type CacheDeleteStruct struct { type CacheDeleteStruct struct {

View File

@ -10,7 +10,7 @@ import (
"reflect" "reflect"
"xorm.io/builder" "xorm.io/builder"
"xorm.io/core" "xorm.io/xorm/schemas"
) )
// Exist returns true if the record exist otherwise return false // Exist returns true if the record exist otherwise return false
@ -45,18 +45,18 @@ func (session *Session) Exist(bean ...interface{}) (bool, error) {
return false, err return false, err
} }
if session.engine.dialect.DBType() == core.MSSQL { if session.engine.dialect.DBType() == schemas.MSSQL {
sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s %s WHERE %s", tableName, joinStr, condSQL) sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s %s WHERE %s", tableName, joinStr, condSQL)
} else if session.engine.dialect.DBType() == core.ORACLE { } else if session.engine.dialect.DBType() == schemas.ORACLE {
sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE (%s) %s AND ROWNUM=1", tableName, joinStr, condSQL) sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE (%s) %s AND ROWNUM=1", tableName, joinStr, condSQL)
} else { } else {
sqlStr = fmt.Sprintf("SELECT * FROM %s %s WHERE %s LIMIT 1", tableName, joinStr, condSQL) sqlStr = fmt.Sprintf("SELECT * FROM %s %s WHERE %s LIMIT 1", tableName, joinStr, condSQL)
} }
args = condArgs args = condArgs
} else { } else {
if session.engine.dialect.DBType() == core.MSSQL { if session.engine.dialect.DBType() == schemas.MSSQL {
sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s %s", tableName, joinStr) sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s %s", tableName, joinStr)
} else if session.engine.dialect.DBType() == core.ORACLE { } else if session.engine.dialect.DBType() == schemas.ORACLE {
sqlStr = fmt.Sprintf("SELECT * FROM %s %s WHERE ROWNUM=1", tableName, joinStr) sqlStr = fmt.Sprintf("SELECT * FROM %s %s WHERE ROWNUM=1", tableName, joinStr)
} else { } else {
sqlStr = fmt.Sprintf("SELECT * FROM %s %s LIMIT 1", tableName, joinStr) sqlStr = fmt.Sprintf("SELECT * FROM %s %s LIMIT 1", tableName, joinStr)

View File

@ -11,7 +11,8 @@ import (
"strings" "strings"
"xorm.io/builder" "xorm.io/builder"
"xorm.io/core" "xorm.io/xorm/caches"
"xorm.io/xorm/schemas"
) )
const ( const (
@ -197,7 +198,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
return session.noCacheFind(table, sliceValue, sqlStr, args...) return session.noCacheFind(table, sliceValue, sqlStr, args...)
} }
func (session *Session) noCacheFind(table *core.Table, containerValue reflect.Value, sqlStr string, args ...interface{}) error { func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect.Value, sqlStr string, args ...interface{}) error {
rows, err := session.queryRows(sqlStr, args...) rows, err := session.queryRows(sqlStr, args...)
if err != nil { if err != nil {
return err return err
@ -236,10 +237,10 @@ func (session *Session) noCacheFind(table *core.Table, containerValue reflect.Va
return reflect.New(elemType) return reflect.New(elemType)
} }
var containerValueSetFunc func(*reflect.Value, core.PK) error var containerValueSetFunc func(*reflect.Value, schemas.PK) error
if containerValue.Kind() == reflect.Slice { if containerValue.Kind() == reflect.Slice {
containerValueSetFunc = func(newValue *reflect.Value, pk core.PK) error { containerValueSetFunc = func(newValue *reflect.Value, pk schemas.PK) error {
if isPointer { if isPointer {
containerValue.Set(reflect.Append(containerValue, newValue.Elem().Addr())) containerValue.Set(reflect.Append(containerValue, newValue.Elem().Addr()))
} else { } else {
@ -256,7 +257,7 @@ func (session *Session) noCacheFind(table *core.Table, containerValue reflect.Va
return errors.New("don't support multiple primary key's map has non-slice key type") return errors.New("don't support multiple primary key's map has non-slice key type")
} }
containerValueSetFunc = func(newValue *reflect.Value, pk core.PK) error { containerValueSetFunc = func(newValue *reflect.Value, pk schemas.PK) error {
keyValue := reflect.New(keyType) keyValue := reflect.New(keyType)
err := convertPKToValue(table, keyValue.Interface(), pk) err := convertPKToValue(table, keyValue.Interface(), pk)
if err != nil { if err != nil {
@ -310,7 +311,7 @@ func (session *Session) noCacheFind(table *core.Table, containerValue reflect.Va
return nil return nil
} }
func convertPKToValue(table *core.Table, dst interface{}, pk core.PK) error { func convertPKToValue(table *schemas.Table, dst interface{}, pk schemas.PK) error {
cols := table.PKColumns() cols := table.PKColumns()
if len(cols) == 1 { if len(cols) == 1 {
return convertAssign(dst, pk[0]) return convertAssign(dst, pk[0])
@ -343,7 +344,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
} }
table := session.statement.RefTable table := session.statement.RefTable
ids, err := core.GetCacheSql(cacher, tableName, newsql, args) ids, err := caches.GetCacheSql(cacher, tableName, newsql, args)
if err != nil { if err != nil {
rows, err := session.queryRows(newsql, args...) rows, err := session.queryRows(newsql, args...)
if err != nil { if err != nil {
@ -352,7 +353,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
defer rows.Close() defer rows.Close()
var i int var i int
ids = make([]core.PK, 0) ids = make([]schemas.PK, 0)
for rows.Next() { for rows.Next() {
i++ i++
if i > 500 { if i > 500 {
@ -364,7 +365,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
if err != nil { if err != nil {
return err return err
} }
var pk core.PK = make([]interface{}, len(table.PrimaryKeys)) var pk schemas.PK = make([]interface{}, len(table.PrimaryKeys))
for i, col := range table.PKColumns() { for i, col := range table.PKColumns() {
pk[i], err = session.engine.idTypeAssertion(col, res[i]) pk[i], err = session.engine.idTypeAssertion(col, res[i])
if err != nil { if err != nil {
@ -376,7 +377,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
} }
session.engine.logger.Debug("[cacheFind] cache sql:", ids, tableName, sqlStr, newsql, args) session.engine.logger.Debug("[cacheFind] cache sql:", ids, tableName, sqlStr, newsql, args)
err = core.PutCacheSql(cacher, ids, tableName, newsql, args) err = caches.PutCacheSql(cacher, ids, tableName, newsql, args)
if err != nil { if err != nil {
return err return err
} }
@ -387,7 +388,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
ididxes := make(map[string]int) ididxes := make(map[string]int)
var ides []core.PK var ides []schemas.PK
var temps = make([]interface{}, len(ids)) var temps = make([]interface{}, len(ids))
for idx, id := range ids { for idx, id := range ids {
@ -502,7 +503,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
} }
} else { } else {
if keyType.Kind() != reflect.Slice { if keyType.Kind() != reflect.Slice {
return errors.New("table have multiple primary keys, key is not core.PK or slice") return errors.New("table have multiple primary keys, key is not schemas.PK or slice")
} }
ikey = key ikey = key
} }

View File

@ -11,7 +11,7 @@ import (
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"xorm.io/core" "xorm.io/xorm/names"
) )
func TestJoinLimit(t *testing.T) { func TestJoinLimit(t *testing.T) {
@ -300,7 +300,7 @@ func TestOrderSameMapper(t *testing.T) {
testEngine.UnMapType(rValue(new(Userinfo)).Type()) testEngine.UnMapType(rValue(new(Userinfo)).Type())
mapper := testEngine.GetTableMapper() mapper := testEngine.GetTableMapper()
testEngine.SetMapper(core.SameMapper{}) testEngine.SetMapper(names.SameMapper{})
defer func() { defer func() {
testEngine.UnMapType(rValue(new(Userinfo)).Type()) testEngine.UnMapType(rValue(new(Userinfo)).Type())
@ -325,7 +325,7 @@ func TestHavingSameMapper(t *testing.T) {
testEngine.UnMapType(rValue(new(Userinfo)).Type()) testEngine.UnMapType(rValue(new(Userinfo)).Type())
mapper := testEngine.GetTableMapper() mapper := testEngine.GetTableMapper()
testEngine.SetMapper(core.SameMapper{}) testEngine.SetMapper(names.SameMapper{})
defer func() { defer func() {
testEngine.UnMapType(rValue(new(Userinfo)).Type()) testEngine.UnMapType(rValue(new(Userinfo)).Type())
testEngine.SetMapper(mapper) testEngine.SetMapper(mapper)

View File

@ -11,7 +11,8 @@ import (
"reflect" "reflect"
"strconv" "strconv"
"xorm.io/core" "xorm.io/xorm/caches"
"xorm.io/xorm/schemas"
) )
// Get retrieve one record from database, bean's non-empty fields // Get retrieve one record from database, bean's non-empty fields
@ -99,7 +100,7 @@ func (session *Session) get(bean interface{}) (bool, error) {
return true, nil return true, nil
} }
func (session *Session) nocacheGet(beanKind reflect.Kind, table *core.Table, bean interface{}, sqlStr string, args ...interface{}) (bool, error) { func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, bean interface{}, sqlStr string, args ...interface{}) (bool, error) {
rows, err := session.queryRows(sqlStr, args...) rows, err := session.queryRows(sqlStr, args...)
if err != nil { if err != nil {
return false, err return false, err
@ -283,7 +284,7 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
session.engine.logger.Debug("[cacheGet] find sql:", newsql, args) session.engine.logger.Debug("[cacheGet] find sql:", newsql, args)
table := session.statement.RefTable table := session.statement.RefTable
ids, err := core.GetCacheSql(cacher, tableName, newsql, args) ids, err := caches.GetCacheSql(cacher, tableName, newsql, args)
if err != nil { if err != nil {
var res = make([]string, len(table.PrimaryKeys)) var res = make([]string, len(table.PrimaryKeys))
rows, err := session.NoCache().queryRows(newsql, args...) rows, err := session.NoCache().queryRows(newsql, args...)
@ -301,7 +302,7 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
return false, ErrCacheFailed return false, ErrCacheFailed
} }
var pk core.PK = make([]interface{}, len(table.PrimaryKeys)) var pk schemas.PK = make([]interface{}, len(table.PrimaryKeys))
for i, col := range table.PKColumns() { for i, col := range table.PKColumns() {
if col.SQLType.IsText() { if col.SQLType.IsText() {
pk[i] = res[i] pk[i] = res[i]
@ -316,9 +317,9 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
} }
} }
ids = []core.PK{pk} ids = []schemas.PK{pk}
session.engine.logger.Debug("[cacheGet] cache ids:", newsql, ids) session.engine.logger.Debug("[cacheGet] cache ids:", newsql, ids)
err = core.PutCacheSql(cacher, ids, tableName, newsql, args) err = caches.PutCacheSql(cacher, ids, tableName, newsql, args)
if err != nil { if err != nil {
return false, err return false, err
} }

View File

@ -11,7 +11,7 @@ import (
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"xorm.io/core" "xorm.io/xorm/schemas"
) )
func TestGetVar(t *testing.T) { func TestGetVar(t *testing.T) {
@ -153,7 +153,7 @@ func TestGetVar(t *testing.T) {
assert.Equal(t, "1.5", fmt.Sprintf("%.1f", money)) assert.Equal(t, "1.5", fmt.Sprintf("%.1f", money))
var money2 float64 var money2 float64
if testEngine.Dialect().DBType() == core.MSSQL { if testEngine.Dialect().DBType() == schemas.MSSQL {
has, err = testEngine.SQL("SELECT TOP 1 money FROM " + testEngine.TableName("get_var", true)).Get(&money2) has, err = testEngine.SQL("SELECT TOP 1 money FROM " + testEngine.TableName("get_var", true)).Get(&money2)
} else { } else {
has, err = testEngine.SQL("SELECT money FROM " + testEngine.TableName("get_var", true) + " LIMIT 1").Get(&money2) has, err = testEngine.SQL("SELECT money FROM " + testEngine.TableName("get_var", true) + " LIMIT 1").Get(&money2)
@ -233,7 +233,7 @@ func TestGetStruct(t *testing.T) {
defer session.Close() defer session.Close()
var err error var err error
if testEngine.Dialect().DBType() == core.MSSQL { if testEngine.Dialect().DBType() == schemas.MSSQL {
err = session.Begin() err = session.Begin()
assert.NoError(t, err) assert.NoError(t, err)
_, err = session.Exec("SET IDENTITY_INSERT userinfo_get ON") _, err = session.Exec("SET IDENTITY_INSERT userinfo_get ON")
@ -242,7 +242,7 @@ func TestGetStruct(t *testing.T) {
cnt, err := session.Insert(&UserinfoGet{Uid: 2}) cnt, err := session.Insert(&UserinfoGet{Uid: 2})
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
if testEngine.Dialect().DBType() == core.MSSQL { if testEngine.Dialect().DBType() == schemas.MSSQL {
err = session.Commit() err = session.Commit()
assert.NoError(t, err) assert.NoError(t, err)
} }

View File

@ -13,7 +13,7 @@ import (
"strings" "strings"
"xorm.io/builder" "xorm.io/builder"
"xorm.io/core" "xorm.io/xorm/schemas"
) )
// ErrNoElementsOnSlice represents an error there is no element when insert // ErrNoElementsOnSlice represents an error there is no element when insert
@ -127,7 +127,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
var colNames []string var colNames []string
var colMultiPlaces []string var colMultiPlaces []string
var args []interface{} var args []interface{}
var cols []*core.Column var cols []*schemas.Column
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
v := sliceValue.Index(i) v := sliceValue.Index(i)
@ -156,7 +156,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
if col.IsAutoIncrement && isZero(fieldValue.Interface()) { if col.IsAutoIncrement && isZero(fieldValue.Interface()) {
continue continue
} }
if col.MapType == core.ONLYFROMDB { if col.MapType == schemas.ONLYFROMDB {
continue continue
} }
if col.IsDeleted { if col.IsDeleted {
@ -207,7 +207,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
if col.IsAutoIncrement && isZero(fieldValue.Interface()) { if col.IsAutoIncrement && isZero(fieldValue.Interface()) {
continue continue
} }
if col.MapType == core.ONLYFROMDB { if col.MapType == schemas.ONLYFROMDB {
continue continue
} }
if col.IsDeleted { if col.IsDeleted {
@ -251,7 +251,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
cleanupProcessorsClosures(&session.beforeClosures) cleanupProcessorsClosures(&session.beforeClosures)
var sql string var sql string
if session.engine.dialect.DBType() == core.ORACLE { if session.engine.dialect.DBType() == schemas.ORACLE {
temp := fmt.Sprintf(") INTO %s (%v) VALUES (", temp := fmt.Sprintf(") INTO %s (%v) VALUES (",
session.engine.Quote(tableName), session.engine.Quote(tableName),
quoteColumns(colNames, session.engine.Quote, ",")) quoteColumns(colNames, session.engine.Quote, ","))
@ -358,7 +358,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
var tableName = session.statement.TableName() var tableName = session.statement.TableName()
var output string var output string
if session.engine.dialect.DBType() == core.MSSQL && len(table.AutoIncrement) > 0 { if session.engine.dialect.DBType() == schemas.MSSQL && len(table.AutoIncrement) > 0 {
output = fmt.Sprintf(" OUTPUT Inserted.%s", table.AutoIncrement) output = fmt.Sprintf(" OUTPUT Inserted.%s", table.AutoIncrement)
} }
@ -368,7 +368,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
} }
if len(colPlaces) <= 0 { if len(colPlaces) <= 0 {
if session.engine.dialect.DBType() == core.MYSQL { if session.engine.dialect.DBType() == schemas.MYSQL {
if _, err := buf.WriteString(" VALUES ()"); err != nil { if _, err := buf.WriteString(" VALUES ()"); err != nil {
return 0, err return 0, err
} }
@ -430,7 +430,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
} }
} }
if len(table.AutoIncrement) > 0 && session.engine.dialect.DBType() == core.POSTGRES { if len(table.AutoIncrement) > 0 && session.engine.dialect.DBType() == schemas.POSTGRES {
if _, err := buf.WriteString(" RETURNING " + session.engine.Quote(table.AutoIncrement)); err != nil { if _, err := buf.WriteString(" RETURNING " + session.engine.Quote(table.AutoIncrement)); err != nil {
return 0, err return 0, err
} }
@ -469,7 +469,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
// for postgres, many of them didn't implement lastInsertId, so we should // for postgres, many of them didn't implement lastInsertId, so we should
// implemented it ourself. // implemented it ourself.
if session.engine.dialect.DBType() == core.ORACLE && len(table.AutoIncrement) > 0 { if session.engine.dialect.DBType() == schemas.ORACLE && len(table.AutoIncrement) > 0 {
res, err := session.queryBytes("select seq_atable.currval from dual", args...) res, err := session.queryBytes("select seq_atable.currval from dual", args...)
if err != nil { if err != nil {
return 0, err return 0, err
@ -510,7 +510,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
aiValue.Set(int64ToIntValue(id, aiValue.Type())) aiValue.Set(int64ToIntValue(id, aiValue.Type()))
return 1, nil return 1, nil
} else if len(table.AutoIncrement) > 0 && (session.engine.dialect.DBType() == core.POSTGRES || session.engine.dialect.DBType() == core.MSSQL) { } else if len(table.AutoIncrement) > 0 && (session.engine.dialect.DBType() == schemas.POSTGRES || session.engine.dialect.DBType() == schemas.MSSQL) {
res, err := session.queryBytes(sqlStr, args...) res, err := session.queryBytes(sqlStr, args...)
if err != nil { if err != nil {
@ -626,7 +626,7 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac
args := make([]interface{}, 0, len(table.ColumnsSeq())) args := make([]interface{}, 0, len(table.ColumnsSeq()))
for _, col := range table.Columns() { for _, col := range table.Columns() {
if col.MapType == core.ONLYFROMDB { if col.MapType == schemas.ONLYFROMDB {
continue continue
} }

View File

@ -11,7 +11,7 @@ import (
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"xorm.io/core" "xorm.io/xorm/schemas"
) )
type IntId struct { type IntId struct {
@ -726,7 +726,7 @@ func TestCompositeKey(t *testing.T) {
} }
var compositeKeyVal CompositeKey var compositeKeyVal CompositeKey
has, err := testEngine.ID(core.PK{11, 22}).Get(&compositeKeyVal) has, err := testEngine.ID(schemas.PK{11, 22}).Get(&compositeKeyVal)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} else if !has { } else if !has {
@ -735,7 +735,7 @@ func TestCompositeKey(t *testing.T) {
var compositeKeyVal2 CompositeKey var compositeKeyVal2 CompositeKey
// test passing PK ptr, this test seem failed withCache // test passing PK ptr, this test seem failed withCache
has, err = testEngine.ID(&core.PK{11, 22}).Get(&compositeKeyVal2) has, err = testEngine.ID(&schemas.PK{11, 22}).Get(&compositeKeyVal2)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} else if !has { } else if !has {
@ -772,14 +772,14 @@ func TestCompositeKey(t *testing.T) {
assert.EqualValues(t, compositeKeyVal, cps[0], "should be equeal") assert.EqualValues(t, compositeKeyVal, cps[0], "should be equeal")
compositeKeyVal = CompositeKey{UpdateStr: "test1"} compositeKeyVal = CompositeKey{UpdateStr: "test1"}
cnt, err = testEngine.ID(core.PK{11, 22}).Update(&compositeKeyVal) cnt, err = testEngine.ID(schemas.PK{11, 22}).Update(&compositeKeyVal)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} else if cnt != 1 { } else if cnt != 1 {
t.Error(errors.New("can't update CompositeKey{11, 22}")) t.Error(errors.New("can't update CompositeKey{11, 22}"))
} }
cnt, err = testEngine.ID(core.PK{11, 22}).Delete(&CompositeKey{}) cnt, err = testEngine.ID(schemas.PK{11, 22}).Delete(&CompositeKey{})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} else if cnt != 1 { } else if cnt != 1 {
@ -823,7 +823,7 @@ func TestCompositeKey2(t *testing.T) {
} }
var user User var user User
has, err := testEngine.ID(core.PK{"11", 22}).Get(&user) has, err := testEngine.ID(schemas.PK{"11", 22}).Get(&user)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} else if !has { } else if !has {
@ -831,7 +831,7 @@ func TestCompositeKey2(t *testing.T) {
} }
// test passing PK ptr, this test seem failed withCache // test passing PK ptr, this test seem failed withCache
has, err = testEngine.ID(&core.PK{"11", 22}).Get(&user) has, err = testEngine.ID(&schemas.PK{"11", 22}).Get(&user)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} else if !has { } else if !has {
@ -839,14 +839,14 @@ func TestCompositeKey2(t *testing.T) {
} }
user = User{NickName: "test1"} user = User{NickName: "test1"}
cnt, err = testEngine.ID(core.PK{"11", 22}).Update(&user) cnt, err = testEngine.ID(schemas.PK{"11", 22}).Update(&user)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} else if cnt != 1 { } else if cnt != 1 {
t.Error(errors.New("can't update User{11, 22}")) t.Error(errors.New("can't update User{11, 22}"))
} }
cnt, err = testEngine.ID(core.PK{"11", 22}).Delete(&User{}) cnt, err = testEngine.ID(schemas.PK{"11", 22}).Delete(&User{})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} else if cnt != 1 { } else if cnt != 1 {
@ -891,7 +891,7 @@ func TestCompositeKey3(t *testing.T) {
} }
var user UserPK2 var user UserPK2
has, err := testEngine.ID(core.PK{"11", 22}).Get(&user) has, err := testEngine.ID(schemas.PK{"11", 22}).Get(&user)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} else if !has { } else if !has {
@ -899,7 +899,7 @@ func TestCompositeKey3(t *testing.T) {
} }
// test passing PK ptr, this test seem failed withCache // test passing PK ptr, this test seem failed withCache
has, err = testEngine.ID(&core.PK{"11", 22}).Get(&user) has, err = testEngine.ID(&schemas.PK{"11", 22}).Get(&user)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} else if !has { } else if !has {
@ -907,14 +907,14 @@ func TestCompositeKey3(t *testing.T) {
} }
user = UserPK2{NickName: "test1"} user = UserPK2{NickName: "test1"}
cnt, err = testEngine.ID(core.PK{"11", 22}).Update(&user) cnt, err = testEngine.ID(schemas.PK{"11", 22}).Update(&user)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} else if cnt != 1 { } else if cnt != 1 {
t.Error(errors.New("can't update User{11, 22}")) t.Error(errors.New("can't update User{11, 22}"))
} }
cnt, err = testEngine.ID(core.PK{"11", 22}).Delete(&UserPK2{}) cnt, err = testEngine.ID(schemas.PK{"11", 22}).Delete(&UserPK2{})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} else if cnt != 1 { } else if cnt != 1 {
@ -1130,7 +1130,7 @@ func TestCompositePK(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1+len(tables1), len(tables2)) assert.EqualValues(t, 1+len(tables1), len(tables2))
var table *core.Table var table *schemas.Table
for _, t := range tables2 { for _, t := range tables2 {
if t.Name == testEngine.GetTableMapper().Obj2Table("TaskSolution") { if t.Name == testEngine.GetTableMapper().Obj2Table("TaskSolution") {
table = t table = t

View File

@ -12,7 +12,8 @@ import (
"time" "time"
"xorm.io/builder" "xorm.io/builder"
"xorm.io/core" "xorm.io/xorm/core"
"xorm.io/xorm/schemas"
) )
func (session *Session) genQuerySQL(sqlOrArgs ...interface{}) (string, []interface{}, error) { func (session *Session) genQuerySQL(sqlOrArgs ...interface{}) (string, []interface{}, error) {
@ -116,8 +117,8 @@ func value2String(rawValue *reflect.Value) (str string, err error) {
} }
// time type // time type
case reflect.Struct: case reflect.Struct:
if aa.ConvertibleTo(core.TimeType) { if aa.ConvertibleTo(schemas.TimeType) {
str = vv.Convert(core.TimeType).Interface().(time.Time).Format(time.RFC3339Nano) str = vv.Convert(schemas.TimeType).Interface().(time.Time).Format(time.RFC3339Nano)
} else { } else {
err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name())
} }

View File

@ -11,7 +11,7 @@ import (
"time" "time"
"xorm.io/builder" "xorm.io/builder"
"xorm.io/core" "xorm.io/xorm/schemas"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -207,7 +207,7 @@ func TestQueryStringNoParam(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, len(records)) assert.EqualValues(t, 1, len(records))
assert.EqualValues(t, "1", records[0]["id"]) assert.EqualValues(t, "1", records[0]["id"])
if testEngine.Dialect().DBType() == core.POSTGRES || testEngine.Dialect().DBType() == core.MSSQL { if testEngine.Dialect().DBType() == schemas.POSTGRES || testEngine.Dialect().DBType() == schemas.MSSQL {
assert.EqualValues(t, "false", records[0]["msg"]) assert.EqualValues(t, "false", records[0]["msg"])
} else { } else {
assert.EqualValues(t, "0", records[0]["msg"]) assert.EqualValues(t, "0", records[0]["msg"])
@ -217,7 +217,7 @@ func TestQueryStringNoParam(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, len(records)) assert.EqualValues(t, 1, len(records))
assert.EqualValues(t, "1", records[0]["id"]) assert.EqualValues(t, "1", records[0]["id"])
if testEngine.Dialect().DBType() == core.POSTGRES || testEngine.Dialect().DBType() == core.MSSQL { if testEngine.Dialect().DBType() == schemas.POSTGRES || testEngine.Dialect().DBType() == schemas.MSSQL {
assert.EqualValues(t, "false", records[0]["msg"]) assert.EqualValues(t, "false", records[0]["msg"])
} else { } else {
assert.EqualValues(t, "0", records[0]["msg"]) assert.EqualValues(t, "0", records[0]["msg"])
@ -244,7 +244,7 @@ func TestQuerySliceStringNoParam(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, len(records)) assert.EqualValues(t, 1, len(records))
assert.EqualValues(t, "1", records[0][0]) assert.EqualValues(t, "1", records[0][0])
if testEngine.Dialect().DBType() == core.POSTGRES || testEngine.Dialect().DBType() == core.MSSQL { if testEngine.Dialect().DBType() == schemas.POSTGRES || testEngine.Dialect().DBType() == schemas.MSSQL {
assert.EqualValues(t, "false", records[0][1]) assert.EqualValues(t, "false", records[0][1])
} else { } else {
assert.EqualValues(t, "0", records[0][1]) assert.EqualValues(t, "0", records[0][1])
@ -254,7 +254,7 @@ func TestQuerySliceStringNoParam(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, len(records)) assert.EqualValues(t, 1, len(records))
assert.EqualValues(t, "1", records[0][0]) assert.EqualValues(t, "1", records[0][0])
if testEngine.Dialect().DBType() == core.POSTGRES || testEngine.Dialect().DBType() == core.MSSQL { if testEngine.Dialect().DBType() == schemas.POSTGRES || testEngine.Dialect().DBType() == schemas.MSSQL {
assert.EqualValues(t, "false", records[0][1]) assert.EqualValues(t, "false", records[0][1])
} else { } else {
assert.EqualValues(t, "0", records[0][1]) assert.EqualValues(t, "0", records[0][1])

View File

@ -10,7 +10,7 @@ import (
"time" "time"
"xorm.io/builder" "xorm.io/builder"
"xorm.io/core" "xorm.io/xorm/core"
) )
func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) { func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) {

View File

@ -9,7 +9,7 @@ import (
"fmt" "fmt"
"strings" "strings"
"xorm.io/core" "xorm.io/xorm/schemas"
) )
// Ping test if database is ok // Ping test if database is ok
@ -125,7 +125,7 @@ func (session *Session) dropTable(beanOrTableName interface{}) error {
tableName := session.engine.TableName(beanOrTableName) tableName := session.engine.TableName(beanOrTableName)
var needDrop = true var needDrop = true
if !session.engine.dialect.SupportDropIfExists() { if !session.engine.dialect.SupportDropIfExists() {
sqlStr, args := session.engine.dialect.TableCheckSql(tableName) sqlStr, args := session.engine.dialect.TableCheckSQL(tableName)
results, err := session.queryBytes(sqlStr, args...) results, err := session.queryBytes(sqlStr, args...)
if err != nil { if err != nil {
return err return err
@ -134,7 +134,7 @@ func (session *Session) dropTable(beanOrTableName interface{}) error {
} }
if needDrop { if needDrop {
sqlStr := session.engine.Dialect().DropTableSql(session.engine.TableName(tableName, true)) sqlStr := session.engine.Dialect().DropTableSQL(session.engine.TableName(tableName, true))
_, err := session.exec(sqlStr) _, err := session.exec(sqlStr)
return err return err
} }
@ -153,7 +153,7 @@ func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error)
} }
func (session *Session) isTableExist(tableName string) (bool, error) { func (session *Session) isTableExist(tableName string) (bool, error) {
sqlStr, args := session.engine.dialect.TableCheckSql(tableName) sqlStr, args := session.engine.dialect.TableCheckSQL(tableName)
results, err := session.queryBytes(sqlStr, args...) results, err := session.queryBytes(sqlStr, args...)
return len(results) > 0, err return len(results) > 0, err
} }
@ -190,9 +190,9 @@ func (session *Session) isIndexExist2(tableName string, cols []string, unique bo
for _, index := range indexes { for _, index := range indexes {
if sliceEq(index.Cols, cols) { if sliceEq(index.Cols, cols) {
if unique { if unique {
return index.Type == core.UniqueType, nil return index.Type == schemas.UniqueType, nil
} }
return index.Type == core.IndexType, nil return index.Type == schemas.IndexType, nil
} }
} }
return false, nil return false, nil
@ -207,14 +207,14 @@ func (session *Session) addColumn(colName string) error {
func (session *Session) addIndex(tableName, idxName string) error { func (session *Session) addIndex(tableName, idxName string) error {
index := session.statement.RefTable.Indexes[idxName] index := session.statement.RefTable.Indexes[idxName]
sqlStr := session.engine.dialect.CreateIndexSql(tableName, index) sqlStr := session.engine.dialect.CreateIndexSQL(tableName, index)
_, err := session.exec(sqlStr) _, err := session.exec(sqlStr)
return err return err
} }
func (session *Session) addUnique(tableName, uqeName string) error { func (session *Session) addUnique(tableName, uqeName string) error {
index := session.statement.RefTable.Indexes[uqeName] index := session.statement.RefTable.Indexes[uqeName]
sqlStr := session.engine.dialect.CreateIndexSql(tableName, index) sqlStr := session.engine.dialect.CreateIndexSQL(tableName, index)
_, err := session.exec(sqlStr) _, err := session.exec(sqlStr)
return err return err
} }
@ -253,7 +253,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
} }
tbNameWithSchema := engine.tbNameWithSchema(tbName) tbNameWithSchema := engine.tbNameWithSchema(tbName)
var oriTable *core.Table var oriTable *schemas.Table
for _, tb := range tables { for _, tb := range tables {
if strings.EqualFold(engine.tbNameWithSchema(tb.Name), engine.tbNameWithSchema(tbName)) { if strings.EqualFold(engine.tbNameWithSchema(tb.Name), engine.tbNameWithSchema(tbName)) {
oriTable = tb oriTable = tb
@ -287,7 +287,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
// check columns // check columns
for _, col := range table.Columns() { for _, col := range table.Columns() {
var oriCol *core.Column var oriCol *schemas.Column
for _, col2 := range oriTable.Columns() { for _, col2 := range oriTable.Columns() {
if strings.EqualFold(col.Name, col2.Name) { if strings.EqualFold(col.Name, col2.Name) {
oriCol = col2 oriCol = col2
@ -306,27 +306,27 @@ func (session *Session) Sync2(beans ...interface{}) error {
} }
err = nil err = nil
expectedType := engine.dialect.SqlType(col) expectedType := engine.dialect.SQLType(col)
curType := engine.dialect.SqlType(oriCol) curType := engine.dialect.SQLType(oriCol)
if expectedType != curType { if expectedType != curType {
if expectedType == core.Text && if expectedType == schemas.Text &&
strings.HasPrefix(curType, core.Varchar) { strings.HasPrefix(curType, schemas.Varchar) {
// currently only support mysql & postgres // currently only support mysql & postgres
if engine.dialect.DBType() == core.MYSQL || if engine.dialect.DBType() == schemas.MYSQL ||
engine.dialect.DBType() == core.POSTGRES { engine.dialect.DBType() == schemas.POSTGRES {
engine.logger.Infof("Table %s column %s change type from %s to %s\n", engine.logger.Infof("Table %s column %s change type from %s to %s\n",
tbNameWithSchema, col.Name, curType, expectedType) tbNameWithSchema, col.Name, curType, expectedType)
_, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col)) _, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col))
} else { } else {
engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n", engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n",
tbNameWithSchema, col.Name, curType, expectedType) tbNameWithSchema, col.Name, curType, expectedType)
} }
} else if strings.HasPrefix(curType, core.Varchar) && strings.HasPrefix(expectedType, core.Varchar) { } else if strings.HasPrefix(curType, schemas.Varchar) && strings.HasPrefix(expectedType, schemas.Varchar) {
if engine.dialect.DBType() == core.MYSQL { if engine.dialect.DBType() == schemas.MYSQL {
if oriCol.Length < col.Length { if oriCol.Length < col.Length {
engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
tbNameWithSchema, col.Name, oriCol.Length, col.Length) tbNameWithSchema, col.Name, oriCol.Length, col.Length)
_, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col)) _, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col))
} }
} }
} else { } else {
@ -335,12 +335,12 @@ func (session *Session) Sync2(beans ...interface{}) error {
tbNameWithSchema, col.Name, curType, expectedType) tbNameWithSchema, col.Name, curType, expectedType)
} }
} }
} else if expectedType == core.Varchar { } else if expectedType == schemas.Varchar {
if engine.dialect.DBType() == core.MYSQL { if engine.dialect.DBType() == schemas.MYSQL {
if oriCol.Length < col.Length { if oriCol.Length < col.Length {
engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
tbNameWithSchema, col.Name, oriCol.Length, col.Length) tbNameWithSchema, col.Name, oriCol.Length, col.Length)
_, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col)) _, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col))
} }
} }
} }
@ -348,7 +348,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
if col.Default != oriCol.Default { if col.Default != oriCol.Default {
switch { switch {
case col.IsAutoIncrement: // For autoincrement column, don't check default case col.IsAutoIncrement: // For autoincrement column, don't check default
case (col.SQLType.Name == core.Bool || col.SQLType.Name == core.Boolean) && case (col.SQLType.Name == schemas.Bool || col.SQLType.Name == schemas.Boolean) &&
((strings.EqualFold(col.Default, "true") && oriCol.Default == "1") || ((strings.EqualFold(col.Default, "true") && oriCol.Default == "1") ||
(strings.EqualFold(col.Default, "false") && oriCol.Default == "0")): (strings.EqualFold(col.Default, "false") && oriCol.Default == "0")):
default: default:
@ -367,10 +367,10 @@ func (session *Session) Sync2(beans ...interface{}) error {
} }
var foundIndexNames = make(map[string]bool) var foundIndexNames = make(map[string]bool)
var addedNames = make(map[string]*core.Index) var addedNames = make(map[string]*schemas.Index)
for name, index := range table.Indexes { for name, index := range table.Indexes {
var oriIndex *core.Index var oriIndex *schemas.Index
for name2, index2 := range oriTable.Indexes { for name2, index2 := range oriTable.Indexes {
if index.Equal(index2) { if index.Equal(index2) {
oriIndex = index2 oriIndex = index2
@ -381,7 +381,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
if oriIndex != nil { if oriIndex != nil {
if oriIndex.Type != index.Type { if oriIndex.Type != index.Type {
sql := engine.dialect.DropIndexSql(tbNameWithSchema, oriIndex) sql := engine.dialect.DropIndexSQL(tbNameWithSchema, oriIndex)
_, err = session.exec(sql) _, err = session.exec(sql)
if err != nil { if err != nil {
return err return err
@ -397,7 +397,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
for name2, index2 := range oriTable.Indexes { for name2, index2 := range oriTable.Indexes {
if _, ok := foundIndexNames[name2]; !ok { if _, ok := foundIndexNames[name2]; !ok {
sql := engine.dialect.DropIndexSql(tbNameWithSchema, index2) sql := engine.dialect.DropIndexSQL(tbNameWithSchema, index2)
_, err = session.exec(sql) _, err = session.exec(sql)
if err != nil { if err != nil {
return err return err
@ -406,11 +406,11 @@ func (session *Session) Sync2(beans ...interface{}) error {
} }
for name, index := range addedNames { for name, index := range addedNames {
if index.Type == core.UniqueType { if index.Type == schemas.UniqueType {
session.statement.RefTable = table session.statement.RefTable = table
session.statement.tableName = tbNameWithSchema session.statement.tableName = tbNameWithSchema
err = session.addUnique(tbNameWithSchema, name) err = session.addUnique(tbNameWithSchema, name)
} else if index.Type == core.IndexType { } else if index.Type == schemas.IndexType {
session.statement.RefTable = table session.statement.RefTable = table
session.statement.tableName = tbNameWithSchema session.statement.tableName = tbNameWithSchema
err = session.addIndex(tbNameWithSchema, name) err = session.addIndex(tbNameWithSchema, name)

View File

@ -213,7 +213,7 @@ func TestCustomTableName(t *testing.T) {
func TestDump(t *testing.T) { func TestDump(t *testing.T) {
assert.NoError(t, prepareEngine()) assert.NoError(t, prepareEngine())
fp := testEngine.Dialect().URI().DbName + ".sql" fp := testEngine.Dialect().URI().DBName + ".sql"
os.Remove(fp) os.Remove(fp)
assert.NoError(t, testEngine.DumpAllToFile(fp)) assert.NoError(t, testEngine.DumpAllToFile(fp))
} }

View File

@ -10,7 +10,7 @@ import (
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"xorm.io/core" "xorm.io/xorm/names"
) )
func TestTransaction(t *testing.T) { func TestTransaction(t *testing.T) {
@ -89,7 +89,7 @@ func TestCombineTransactionSameMapper(t *testing.T) {
oldMapper := testEngine.GetColumnMapper() oldMapper := testEngine.GetColumnMapper()
testEngine.UnMapType(rValue(new(Userinfo)).Type()) testEngine.UnMapType(rValue(new(Userinfo)).Type())
testEngine.SetMapper(core.SameMapper{}) testEngine.SetMapper(names.SameMapper{})
defer func() { defer func() {
testEngine.UnMapType(rValue(new(Userinfo)).Type()) testEngine.UnMapType(rValue(new(Userinfo)).Type())
testEngine.SetMapper(oldMapper) testEngine.SetMapper(oldMapper)

View File

@ -12,10 +12,11 @@ import (
"strings" "strings"
"xorm.io/builder" "xorm.io/builder"
"xorm.io/core" "xorm.io/xorm/caches"
"xorm.io/xorm/schemas"
) )
func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string, args ...interface{}) error { func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr string, args ...interface{}) error {
if table == nil || if table == nil ||
session.tx != nil { session.tx != nil {
return ErrCacheFailed return ErrCacheFailed
@ -42,7 +43,7 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string,
cacher := session.engine.getCacher(tableName) cacher := session.engine.getCacher(tableName)
session.engine.logger.Debug("[cacheUpdate] get cache sql", newsql, args[nStart:]) session.engine.logger.Debug("[cacheUpdate] get cache sql", newsql, args[nStart:])
ids, err := core.GetCacheSql(cacher, tableName, newsql, args[nStart:]) ids, err := caches.GetCacheSql(cacher, tableName, newsql, args[nStart:])
if err != nil { if err != nil {
rows, err := session.NoCache().queryRows(newsql, args[nStart:]...) rows, err := session.NoCache().queryRows(newsql, args[nStart:]...)
if err != nil { if err != nil {
@ -50,14 +51,14 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string,
} }
defer rows.Close() defer rows.Close()
ids = make([]core.PK, 0) ids = make([]schemas.PK, 0)
for rows.Next() { for rows.Next() {
var res = make([]string, len(table.PrimaryKeys)) var res = make([]string, len(table.PrimaryKeys))
err = rows.ScanSlice(&res) err = rows.ScanSlice(&res)
if err != nil { if err != nil {
return err return err
} }
var pk core.PK = make([]interface{}, len(table.PrimaryKeys)) var pk schemas.PK = make([]interface{}, len(table.PrimaryKeys))
for i, col := range table.PKColumns() { for i, col := range table.PKColumns() {
if col.SQLType.IsNumeric() { if col.SQLType.IsNumeric() {
n, err := strconv.ParseInt(res[i], 10, 64) n, err := strconv.ParseInt(res[i], 10, 64)
@ -339,9 +340,9 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
var top string var top string
if st.LimitN != nil { if st.LimitN != nil {
limitValue := *st.LimitN limitValue := *st.LimitN
if st.Engine.dialect.DBType() == core.MYSQL { if st.Engine.dialect.DBType() == schemas.MYSQL {
condSQL = condSQL + fmt.Sprintf(" LIMIT %d", limitValue) condSQL = condSQL + fmt.Sprintf(" LIMIT %d", limitValue)
} else if st.Engine.dialect.DBType() == core.SQLITE { } else if st.Engine.dialect.DBType() == schemas.SQLITE {
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue) tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue)
cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)", cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)",
session.engine.Quote(tableName), tempCondSQL), condArgs...)) session.engine.Quote(tableName), tempCondSQL), condArgs...))
@ -352,7 +353,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
if len(condSQL) > 0 { if len(condSQL) > 0 {
condSQL = "WHERE " + condSQL condSQL = "WHERE " + condSQL
} }
} else if st.Engine.dialect.DBType() == core.POSTGRES { } else if st.Engine.dialect.DBType() == schemas.POSTGRES {
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue) tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue)
cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)", cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)",
session.engine.Quote(tableName), tempCondSQL), condArgs...)) session.engine.Quote(tableName), tempCondSQL), condArgs...))
@ -364,8 +365,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
if len(condSQL) > 0 { if len(condSQL) > 0 {
condSQL = "WHERE " + condSQL condSQL = "WHERE " + condSQL
} }
} else if st.Engine.dialect.DBType() == core.MSSQL { } else if st.Engine.dialect.DBType() == schemas.MSSQL {
if st.OrderStr != "" && st.Engine.dialect.DBType() == core.MSSQL && if st.OrderStr != "" && st.Engine.dialect.DBType() == schemas.MSSQL &&
table != nil && len(table.PrimaryKeys) == 1 { table != nil && len(table.PrimaryKeys) == 1 {
cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)", cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)",
table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0], table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0],
@ -392,7 +393,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
var fromSQL string var fromSQL string
if session.statement.TableAlias != "" { if session.statement.TableAlias != "" {
switch session.engine.dialect.DBType() { switch session.engine.dialect.DBType() {
case core.MSSQL: case schemas.MSSQL:
fromSQL = fmt.Sprintf("FROM %s %s ", tableAlias, session.statement.TableAlias) fromSQL = fmt.Sprintf("FROM %s %s ", tableAlias, session.statement.TableAlias)
tableAlias = session.statement.TableAlias tableAlias = session.statement.TableAlias
default: default:
@ -467,7 +468,7 @@ func (session *Session) genUpdateColumns(bean interface{}) ([]string, []interfac
continue continue
} }
} }
if col.MapType == core.ONLYFROMDB { if col.MapType == schemas.ONLYFROMDB {
continue continue
} }

View File

@ -12,7 +12,7 @@ import (
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"xorm.io/core" "xorm.io/xorm/names"
) )
func TestUpdateMap(t *testing.T) { func TestUpdateMap(t *testing.T) {
@ -691,7 +691,7 @@ func TestUpdateSameMapper(t *testing.T) {
testEngine.UnMapType(rValue(new(UpdateAllCols)).Type()) testEngine.UnMapType(rValue(new(UpdateAllCols)).Type())
testEngine.UnMapType(rValue(new(UpdateMustCols)).Type()) testEngine.UnMapType(rValue(new(UpdateMustCols)).Type())
testEngine.UnMapType(rValue(new(UpdateIncr)).Type()) testEngine.UnMapType(rValue(new(UpdateIncr)).Type())
testEngine.SetMapper(core.SameMapper{}) testEngine.SetMapper(names.SameMapper{})
defer func() { defer func() {
testEngine.UnMapType(rValue(new(Userinfo)).Type()) testEngine.UnMapType(rValue(new(Userinfo)).Type())
testEngine.UnMapType(rValue(new(Condi)).Type()) testEngine.UnMapType(rValue(new(Condi)).Type())
@ -1419,7 +1419,7 @@ func TestUpdateMap3(t *testing.T) {
testEngine.SetColumnMapper(oldMapper) testEngine.SetColumnMapper(oldMapper)
}() }()
mapper := core.NewPrefixMapper(core.SnakeMapper{}, "F") mapper := names.NewPrefixMapper(names.SnakeMapper{}, "F")
testEngine.SetColumnMapper(mapper) testEngine.SetColumnMapper(mapper)
assertSync(t, new(UpdateMapUser)) assertSync(t, new(UpdateMapUser))

View File

@ -12,16 +12,17 @@ import (
"time" "time"
"xorm.io/builder" "xorm.io/builder"
"xorm.io/core" "xorm.io/xorm/dialects"
"xorm.io/xorm/schemas"
) )
// Statement save all the sql info for executing SQL // Statement save all the sql info for executing SQL
type Statement struct { type Statement struct {
RefTable *core.Table RefTable *schemas.Table
Engine *Engine Engine *Engine
Start int Start int
LimitN *int LimitN *int
idParam *core.PK idParam *schemas.PK
OrderStr string OrderStr string
JoinStr string JoinStr string
joinArgs []interface{} joinArgs []interface{}
@ -266,7 +267,7 @@ func (statement *Statement) buildUpdates(bean interface{},
continue continue
} }
if col.MapType == core.ONLYFROMDB { if col.MapType == schemas.ONLYFROMDB {
continue continue
} }
@ -314,7 +315,7 @@ func (statement *Statement) buildUpdates(bean interface{},
var val interface{} var val interface{}
if fieldValue.CanAddr() { if fieldValue.CanAddr() {
if structConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok { if structConvert, ok := fieldValue.Addr().Interface().(Conversion); ok {
data, err := structConvert.ToDB() data, err := structConvert.ToDB()
if err != nil { if err != nil {
engine.logger.Error(err) engine.logger.Error(err)
@ -325,7 +326,7 @@ func (statement *Statement) buildUpdates(bean interface{},
} }
} }
if structConvert, ok := fieldValue.Interface().(core.Conversion); ok { if structConvert, ok := fieldValue.Interface().(Conversion); ok {
data, err := structConvert.ToDB() data, err := structConvert.ToDB()
if err != nil { if err != nil {
engine.logger.Error(err) engine.logger.Error(err)
@ -388,8 +389,8 @@ func (statement *Statement) buildUpdates(bean interface{},
t := int64(fieldValue.Uint()) t := int64(fieldValue.Uint())
val = reflect.ValueOf(&t).Interface() val = reflect.ValueOf(&t).Interface()
case reflect.Struct: case reflect.Struct:
if fieldType.ConvertibleTo(core.TimeType) { if fieldType.ConvertibleTo(schemas.TimeType) {
t := fieldValue.Convert(core.TimeType).Interface().(time.Time) t := fieldValue.Convert(schemas.TimeType).Interface().(time.Time)
if !requiredField && (t.IsZero() || !fieldValue.IsValid()) { if !requiredField && (t.IsZero() || !fieldValue.IsValid()) {
continue continue
} }
@ -496,7 +497,7 @@ func (statement *Statement) needTableName() bool {
return len(statement.JoinStr) > 0 return len(statement.JoinStr) > 0
} }
func (statement *Statement) colName(col *core.Column, tableName string) string { func (statement *Statement) colName(col *schemas.Column, tableName string) string {
if statement.needTableName() { if statement.needTableName() {
var nm = tableName var nm = tableName
if len(statement.TableAlias) > 0 { if len(statement.TableAlias) > 0 {
@ -523,12 +524,12 @@ func (statement *Statement) ID(id interface{}) *Statement {
switch idType { switch idType {
case ptrPkType: case ptrPkType:
if pkPtr, ok := (id).(*core.PK); ok { if pkPtr, ok := (id).(*schemas.PK); ok {
statement.idParam = pkPtr statement.idParam = pkPtr
return statement return statement
} }
case pkType: case pkType:
if pk, ok := (id).(core.PK); ok { if pk, ok := (id).(schemas.PK); ok {
statement.idParam = &pk statement.idParam = &pk
return statement return statement
} }
@ -536,11 +537,11 @@ func (statement *Statement) ID(id interface{}) *Statement {
switch idType.Kind() { switch idType.Kind() {
case reflect.String: case reflect.String:
statement.idParam = &core.PK{idValue.Convert(reflect.TypeOf("")).Interface()} statement.idParam = &schemas.PK{idValue.Convert(reflect.TypeOf("")).Interface()}
return statement return statement
} }
statement.idParam = &core.PK{id} statement.idParam = &schemas.PK{id}
return statement return statement
} }
@ -807,7 +808,7 @@ func (statement *Statement) genColumnStr() string {
continue continue
} }
if col.MapType == core.ONLYTODB { if col.MapType == schemas.ONLYTODB {
continue continue
} }
@ -832,7 +833,7 @@ func (statement *Statement) genColumnStr() string {
} }
func (statement *Statement) genCreateTableSQL() string { func (statement *Statement) genCreateTableSQL() string {
return statement.Engine.dialect.CreateTableSql(statement.RefTable, statement.TableName(), return statement.Engine.dialect.CreateTableSQL(statement.RefTable, statement.TableName(),
statement.StoreEngine, statement.Charset) statement.StoreEngine, statement.Charset)
} }
@ -840,8 +841,8 @@ func (statement *Statement) genIndexSQL() []string {
var sqls []string var sqls []string
tbName := statement.TableName() tbName := statement.TableName()
for _, index := range statement.RefTable.Indexes { for _, index := range statement.RefTable.Indexes {
if index.Type == core.IndexType { if index.Type == schemas.IndexType {
sql := statement.Engine.dialect.CreateIndexSql(tbName, index) sql := statement.Engine.dialect.CreateIndexSQL(tbName, index)
/*idxTBName := strings.Replace(tbName, ".", "_", -1) /*idxTBName := strings.Replace(tbName, ".", "_", -1)
idxTBName = strings.Replace(idxTBName, `"`, "", -1) idxTBName = strings.Replace(idxTBName, `"`, "", -1)
sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(idxTBName, idxName)), sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(idxTBName, idxName)),
@ -860,8 +861,8 @@ func (statement *Statement) genUniqueSQL() []string {
var sqls []string var sqls []string
tbName := statement.TableName() tbName := statement.TableName()
for _, index := range statement.RefTable.Indexes { for _, index := range statement.RefTable.Indexes {
if index.Type == core.UniqueType { if index.Type == schemas.UniqueType {
sql := statement.Engine.dialect.CreateIndexSql(tbName, index) sql := statement.Engine.dialect.CreateIndexSQL(tbName, index)
sqls = append(sqls, sql) sqls = append(sqls, sql)
} }
} }
@ -871,13 +872,17 @@ func (statement *Statement) genUniqueSQL() []string {
func (statement *Statement) genDelIndexSQL() []string { func (statement *Statement) genDelIndexSQL() []string {
var sqls []string var sqls []string
tbName := statement.TableName() tbName := statement.TableName()
idx := strings.Index(tbName, ".")
if idx > -1 {
tbName = tbName[idx+1:]
}
idxPrefixName := strings.Replace(tbName, `"`, "", -1) idxPrefixName := strings.Replace(tbName, `"`, "", -1)
idxPrefixName = strings.Replace(idxPrefixName, `.`, "_", -1) idxPrefixName = strings.Replace(idxPrefixName, `.`, "_", -1)
for idxName, index := range statement.RefTable.Indexes { for idxName, index := range statement.RefTable.Indexes {
var rIdxName string var rIdxName string
if index.Type == core.UniqueType { if index.Type == schemas.UniqueType {
rIdxName = uniqueName(idxPrefixName, idxName) rIdxName = uniqueName(idxPrefixName, idxName)
} else if index.Type == core.IndexType { } else if index.Type == schemas.IndexType {
rIdxName = indexName(idxPrefixName, idxName) rIdxName = indexName(idxPrefixName, idxName)
} }
sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(statement.Engine.TableName(rIdxName, true))) sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(statement.Engine.TableName(rIdxName, true)))
@ -889,18 +894,18 @@ func (statement *Statement) genDelIndexSQL() []string {
return sqls return sqls
} }
func (statement *Statement) genAddColumnStr(col *core.Column) (string, []interface{}) { func (statement *Statement) genAddColumnStr(col *schemas.Column) (string, []interface{}) {
quote := statement.Engine.Quote quote := statement.Engine.Quote
sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quote(statement.TableName()), sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quote(statement.TableName()),
col.String(statement.Engine.dialect)) dialects.String(statement.Engine.dialect, col))
if statement.Engine.dialect.DBType() == core.MYSQL && len(col.Comment) > 0 { if statement.Engine.dialect.DBType() == schemas.MYSQL && len(col.Comment) > 0 {
sql += " COMMENT '" + col.Comment + "'" sql += " COMMENT '" + col.Comment + "'"
} }
sql += ";" sql += ";"
return sql, []interface{}{} return sql, []interface{}{}
} }
func (statement *Statement) buildConds(table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) { func (statement *Statement) buildConds(table *schemas.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) {
return statement.Engine.buildConds(table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols, return statement.Engine.buildConds(table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols,
statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, addedTableName) statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, addedTableName)
} }
@ -1054,14 +1059,14 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, n
whereStr = " WHERE " + condSQL whereStr = " WHERE " + condSQL
} }
if dialect.DBType() == core.MSSQL && strings.Contains(statement.TableName(), "..") { if dialect.DBType() == schemas.MSSQL && strings.Contains(statement.TableName(), "..") {
fromStr += statement.TableName() fromStr += statement.TableName()
} else { } else {
fromStr += quote(statement.TableName()) fromStr += quote(statement.TableName())
} }
if statement.TableAlias != "" { if statement.TableAlias != "" {
if dialect.DBType() == core.ORACLE { if dialect.DBType() == schemas.ORACLE {
fromStr += " " + quote(statement.TableAlias) fromStr += " " + quote(statement.TableAlias)
} else { } else {
fromStr += " AS " + quote(statement.TableAlias) fromStr += " AS " + quote(statement.TableAlias)
@ -1072,7 +1077,7 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, n
} }
pLimitN := statement.LimitN pLimitN := statement.LimitN
if dialect.DBType() == core.MSSQL { if dialect.DBType() == schemas.MSSQL {
if pLimitN != nil { if pLimitN != nil {
LimitNValue := *pLimitN LimitNValue := *pLimitN
top = fmt.Sprintf("TOP %d ", LimitNValue) top = fmt.Sprintf("TOP %d ", LimitNValue)
@ -1134,7 +1139,7 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, n
fmt.Fprint(&buf, " ORDER BY ", statement.OrderStr) fmt.Fprint(&buf, " ORDER BY ", statement.OrderStr)
} }
if needLimit { if needLimit {
if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE { if dialect.DBType() != schemas.MSSQL && dialect.DBType() != schemas.ORACLE {
if statement.Start > 0 { if statement.Start > 0 {
if pLimitN != nil { if pLimitN != nil {
fmt.Fprintf(&buf, " LIMIT %v OFFSET %v", *pLimitN, statement.Start) fmt.Fprintf(&buf, " LIMIT %v OFFSET %v", *pLimitN, statement.Start)
@ -1144,7 +1149,7 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, n
} else if pLimitN != nil { } else if pLimitN != nil {
fmt.Fprint(&buf, " LIMIT ", *pLimitN) fmt.Fprint(&buf, " LIMIT ", *pLimitN)
} }
} else if dialect.DBType() == core.ORACLE { } else if dialect.DBType() == schemas.ORACLE {
if statement.Start != 0 || pLimitN != nil { if statement.Start != 0 || pLimitN != nil {
oldString := buf.String() oldString := buf.String()
buf.Reset() buf.Reset()
@ -1158,7 +1163,7 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, n
} }
} }
if statement.IsForUpdate { if statement.IsForUpdate {
return dialect.ForUpdateSql(buf.String()), nil return dialect.ForUpdateSQL(buf.String()), nil
} }
return buf.String(), nil return buf.String(), nil
@ -1183,7 +1188,7 @@ func (statement *Statement) processIDParam() error {
return nil return nil
} }
func (statement *Statement) joinColumns(cols []*core.Column, includeTableName bool) string { func (statement *Statement) joinColumns(cols []*schemas.Column, includeTableName bool) string {
var colnames = make([]string, len(cols)) var colnames = make([]string, len(cols))
for i, col := range cols { for i, col := range cols {
if includeTableName { if includeTableName {
@ -1211,7 +1216,7 @@ func (statement *Statement) convertIDSQL(sqlStr string) string {
var top string var top string
pLimitN := statement.LimitN pLimitN := statement.LimitN
if pLimitN != nil && statement.Engine.dialect.DBType() == core.MSSQL { if pLimitN != nil && statement.Engine.dialect.DBType() == schemas.MSSQL {
top = fmt.Sprintf("TOP %d ", *pLimitN) top = fmt.Sprintf("TOP %d ", *pLimitN)
} }
@ -1240,9 +1245,9 @@ func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) {
// TODO: for postgres only, if any other database? // TODO: for postgres only, if any other database?
var paraStr string var paraStr string
if statement.Engine.dialect.DBType() == core.POSTGRES { if statement.Engine.dialect.DBType() == schemas.POSTGRES {
paraStr = "$" paraStr = "$"
} else if statement.Engine.dialect.DBType() == core.MSSQL { } else if statement.Engine.dialect.DBType() == schemas.MSSQL {
paraStr = ":" paraStr = ":"
} }

View File

@ -11,7 +11,7 @@ import (
"time" "time"
"xorm.io/builder" "xorm.io/builder"
"xorm.io/core" "xorm.io/xorm/schemas"
) )
func quoteNeeded(a interface{}) bool { func quoteNeeded(a interface{}) bool {
@ -80,7 +80,7 @@ const insertSelectPlaceHolder = true
func (statement *Statement) writeArg(w *builder.BytesWriter, arg interface{}) error { func (statement *Statement) writeArg(w *builder.BytesWriter, arg interface{}) error {
switch argv := arg.(type) { switch argv := arg.(type) {
case bool: case bool:
if statement.Engine.dialect.DBType() == core.MSSQL { if statement.Engine.dialect.DBType() == schemas.MSSQL {
if argv { if argv {
if _, err := w.WriteString("1"); err != nil { if _, err := w.WriteString("1"); err != nil {
return err return err
@ -119,7 +119,7 @@ func (statement *Statement) writeArg(w *builder.BytesWriter, arg interface{}) er
w.Append(arg) w.Append(arg)
} else { } else {
var convertFunc = convertStringSingleQuote var convertFunc = convertStringSingleQuote
if statement.Engine.dialect.DBType() == core.MYSQL { if statement.Engine.dialect.DBType() == schemas.MYSQL {
convertFunc = convertString convertFunc = convertString
} }
if _, err := w.WriteString(convertArg(arg, convertFunc)); err != nil { if _, err := w.WriteString(convertArg(arg, convertFunc)); err != nil {

View File

@ -10,7 +10,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"xorm.io/core" "xorm.io/xorm/schemas"
) )
var colStrTests = []struct { var colStrTests = []struct {
@ -42,7 +42,7 @@ func TestColumnsStringGeneration(t *testing.T) {
columns := statement.RefTable.Columns() columns := statement.RefTable.Columns()
if testCase.onlyToDBColumnNdx >= 0 { if testCase.onlyToDBColumnNdx >= 0 {
columns[testCase.onlyToDBColumnNdx].MapType = core.ONLYTODB columns[testCase.onlyToDBColumnNdx].MapType = schemas.ONLYTODB
} }
actual := statement.genColumnStr() actual := statement.genColumnStr()
@ -51,7 +51,7 @@ func TestColumnsStringGeneration(t *testing.T) {
t.Errorf("[test #%d] Unexpected columns string:\nwant:\t%s\nhave:\t%s", ndx, testCase.expected, actual) t.Errorf("[test #%d] Unexpected columns string:\nwant:\t%s\nhave:\t%s", ndx, testCase.expected, actual)
} }
if testCase.onlyToDBColumnNdx >= 0 { if testCase.onlyToDBColumnNdx >= 0 {
columns[testCase.onlyToDBColumnNdx].MapType = core.TWOSIDES columns[testCase.onlyToDBColumnNdx].MapType = schemas.TWOSIDES
} }
} }
} }
@ -69,7 +69,7 @@ func BenchmarkColumnsStringGeneration(b *testing.B) {
if testCase.onlyToDBColumnNdx >= 0 { if testCase.onlyToDBColumnNdx >= 0 {
columns := statement.RefTable.Columns() columns := statement.RefTable.Columns()
columns[testCase.onlyToDBColumnNdx].MapType = core.ONLYTODB // !nemec784! Column must be skipped columns[testCase.onlyToDBColumnNdx].MapType = schemas.ONLYTODB // !nemec784! Column must be skipped
} }
b.StartTimer() b.StartTimer()
@ -88,7 +88,7 @@ func BenchmarkGetFlagForColumnWithICKey_ContainsKey(b *testing.B) {
b.StopTimer() b.StopTimer()
mapCols := make(map[string]bool) mapCols := make(map[string]bool)
cols := []*core.Column{ cols := []*schemas.Column{
{Name: `ID`}, {Name: `ID`},
{Name: `IsDeleted`}, {Name: `IsDeleted`},
{Name: `Caption`}, {Name: `Caption`},
@ -122,7 +122,7 @@ func BenchmarkGetFlagForColumnWithICKey_EmptyMap(b *testing.B) {
b.StopTimer() b.StopTimer()
mapCols := make(map[string]bool) mapCols := make(map[string]bool)
cols := []*core.Column{ cols := []*schemas.Column{
{Name: `ID`}, {Name: `ID`},
{Name: `IsDeleted`}, {Name: `IsDeleted`},
{Name: `Caption`}, {Name: `Caption`},

43
tag.go
View File

@ -11,15 +11,36 @@ import (
"strings" "strings"
"time" "time"
"xorm.io/core" "xorm.io/xorm/schemas"
) )
func splitTag(tag string) (tags []string) {
tag = strings.TrimSpace(tag)
var hasQuote = false
var lastIdx = 0
for i, t := range tag {
if t == '\'' {
hasQuote = !hasQuote
} else if t == ' ' {
if lastIdx < i && !hasQuote {
tags = append(tags, strings.TrimSpace(tag[lastIdx:i]))
lastIdx = i + 1
}
}
}
if lastIdx < len(tag) {
tags = append(tags, strings.TrimSpace(tag[lastIdx:]))
}
return
}
// tagContext represents a context for xorm tag parse.
type tagContext struct { type tagContext struct {
tagName string tagName string
params []string params []string
preTag, nextTag string preTag, nextTag string
table *core.Table table *schemas.Table
col *core.Column col *schemas.Column
fieldValue reflect.Value fieldValue reflect.Value
isIndex bool isIndex bool
isUnique bool isUnique bool
@ -59,7 +80,7 @@ var (
) )
func init() { func init() {
for k := range core.SqlTypes { for k := range schemas.SqlTypes {
defaultTagHandlers[k] = SQLTypeTagHandler defaultTagHandlers[k] = SQLTypeTagHandler
} }
} }
@ -71,13 +92,13 @@ func IgnoreTagHandler(ctx *tagContext) error {
// OnlyFromDBTagHandler describes mapping direction tag handler // OnlyFromDBTagHandler describes mapping direction tag handler
func OnlyFromDBTagHandler(ctx *tagContext) error { func OnlyFromDBTagHandler(ctx *tagContext) error {
ctx.col.MapType = core.ONLYFROMDB ctx.col.MapType = schemas.ONLYFROMDB
return nil return nil
} }
// OnlyToDBTagHandler describes mapping direction tag handler // OnlyToDBTagHandler describes mapping direction tag handler
func OnlyToDBTagHandler(ctx *tagContext) error { func OnlyToDBTagHandler(ctx *tagContext) error {
ctx.col.MapType = core.ONLYTODB ctx.col.MapType = schemas.ONLYTODB
return nil return nil
} }
@ -177,7 +198,7 @@ func DeletedTagHandler(ctx *tagContext) error {
// IndexTagHandler describes index tag handler // IndexTagHandler describes index tag handler
func IndexTagHandler(ctx *tagContext) error { func IndexTagHandler(ctx *tagContext) error {
if len(ctx.params) > 0 { if len(ctx.params) > 0 {
ctx.indexNames[ctx.params[0]] = core.IndexType ctx.indexNames[ctx.params[0]] = schemas.IndexType
} else { } else {
ctx.isIndex = true ctx.isIndex = true
} }
@ -187,7 +208,7 @@ func IndexTagHandler(ctx *tagContext) error {
// UniqueTagHandler describes unique tag handler // UniqueTagHandler describes unique tag handler
func UniqueTagHandler(ctx *tagContext) error { func UniqueTagHandler(ctx *tagContext) error {
if len(ctx.params) > 0 { if len(ctx.params) > 0 {
ctx.indexNames[ctx.params[0]] = core.UniqueType ctx.indexNames[ctx.params[0]] = schemas.UniqueType
} else { } else {
ctx.isUnique = true ctx.isUnique = true
} }
@ -204,16 +225,16 @@ func CommentTagHandler(ctx *tagContext) error {
// SQLTypeTagHandler describes SQL Type tag handler // SQLTypeTagHandler describes SQL Type tag handler
func SQLTypeTagHandler(ctx *tagContext) error { func SQLTypeTagHandler(ctx *tagContext) error {
ctx.col.SQLType = core.SQLType{Name: ctx.tagName} ctx.col.SQLType = schemas.SQLType{Name: ctx.tagName}
if len(ctx.params) > 0 { if len(ctx.params) > 0 {
if ctx.tagName == core.Enum { if ctx.tagName == schemas.Enum {
ctx.col.EnumOptions = make(map[string]int) ctx.col.EnumOptions = make(map[string]int)
for k, v := range ctx.params { for k, v := range ctx.params {
v = strings.TrimSpace(v) v = strings.TrimSpace(v)
v = strings.Trim(v, "'") v = strings.Trim(v, "'")
ctx.col.EnumOptions[v] = k ctx.col.EnumOptions[v] = k
} }
} else if ctx.tagName == core.Set { } else if ctx.tagName == schemas.Set {
ctx.col.SetOptions = make(map[string]int) ctx.col.SetOptions = make(map[string]int)
for k, v := range ctx.params { for k, v := range ctx.params {
v = strings.TrimSpace(v) v = strings.TrimSpace(v)

View File

@ -11,7 +11,7 @@ import (
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"xorm.io/core" "xorm.io/xorm/schemas"
) )
type tempUser struct { type tempUser struct {
@ -269,7 +269,7 @@ func TestExtends2(t *testing.T) {
defer session.Close() defer session.Close()
// MSSQL deny insert identity column excep declare as below // MSSQL deny insert identity column excep declare as below
if testEngine.Dialect().DBType() == core.MSSQL { if testEngine.Dialect().DBType() == schemas.MSSQL {
err = session.Begin() err = session.Begin()
assert.NoError(t, err) assert.NoError(t, err)
_, err = session.Exec("SET IDENTITY_INSERT message ON") _, err = session.Exec("SET IDENTITY_INSERT message ON")
@ -279,7 +279,7 @@ func TestExtends2(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
if testEngine.Dialect().DBType() == core.MSSQL { if testEngine.Dialect().DBType() == schemas.MSSQL {
err = session.Commit() err = session.Commit()
assert.NoError(t, err) assert.NoError(t, err)
} }
@ -339,7 +339,7 @@ func TestExtends3(t *testing.T) {
defer session.Close() defer session.Close()
// MSSQL deny insert identity column excep declare as below // MSSQL deny insert identity column excep declare as below
if testEngine.Dialect().DBType() == core.MSSQL { if testEngine.Dialect().DBType() == schemas.MSSQL {
err = session.Begin() err = session.Begin()
assert.NoError(t, err) assert.NoError(t, err)
_, err = session.Exec("SET IDENTITY_INSERT message ON") _, err = session.Exec("SET IDENTITY_INSERT message ON")
@ -348,7 +348,7 @@ func TestExtends3(t *testing.T) {
_, err = session.Insert(&msg) _, err = session.Insert(&msg)
assert.NoError(t, err) assert.NoError(t, err)
if testEngine.Dialect().DBType() == core.MSSQL { if testEngine.Dialect().DBType() == schemas.MSSQL {
err = session.Commit() err = session.Commit()
assert.NoError(t, err) assert.NoError(t, err)
} }
@ -433,7 +433,7 @@ func TestExtends4(t *testing.T) {
defer session.Close() defer session.Close()
// MSSQL deny insert identity column excep declare as below // MSSQL deny insert identity column excep declare as below
if testEngine.Dialect().DBType() == core.MSSQL { if testEngine.Dialect().DBType() == schemas.MSSQL {
err = session.Begin() err = session.Begin()
assert.NoError(t, err) assert.NoError(t, err)
_, err = session.Exec("SET IDENTITY_INSERT message ON") _, err = session.Exec("SET IDENTITY_INSERT message ON")
@ -442,7 +442,7 @@ func TestExtends4(t *testing.T) {
_, err = session.Insert(&msg) _, err = session.Insert(&msg)
assert.NoError(t, err) assert.NoError(t, err)
if testEngine.Dialect().DBType() == core.MSSQL { if testEngine.Dialect().DBType() == schemas.MSSQL {
err = session.Commit() err = session.Commit()
assert.NoError(t, err) assert.NoError(t, err)
} }

View File

@ -8,7 +8,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"xorm.io/core" "xorm.io/xorm/names"
) )
type IDGonicMapper struct { type IDGonicMapper struct {
@ -20,7 +20,7 @@ func TestGonicMapperID(t *testing.T) {
oldMapper := testEngine.GetColumnMapper() oldMapper := testEngine.GetColumnMapper()
testEngine.UnMapType(rValue(new(IDGonicMapper)).Type()) testEngine.UnMapType(rValue(new(IDGonicMapper)).Type())
testEngine.SetMapper(core.LintGonicMapper) testEngine.SetMapper(names.LintGonicMapper)
defer func() { defer func() {
testEngine.UnMapType(rValue(new(IDGonicMapper)).Type()) testEngine.UnMapType(rValue(new(IDGonicMapper)).Type())
testEngine.SetMapper(oldMapper) testEngine.SetMapper(oldMapper)
@ -57,7 +57,7 @@ func TestSameMapperID(t *testing.T) {
oldMapper := testEngine.GetColumnMapper() oldMapper := testEngine.GetColumnMapper()
testEngine.UnMapType(rValue(new(IDSameMapper)).Type()) testEngine.UnMapType(rValue(new(IDSameMapper)).Type())
testEngine.SetMapper(core.SameMapper{}) testEngine.SetMapper(names.SameMapper{})
defer func() { defer func() {
testEngine.UnMapType(rValue(new(IDSameMapper)).Type()) testEngine.UnMapType(rValue(new(IDSameMapper)).Type())
testEngine.SetMapper(oldMapper) testEngine.SetMapper(oldMapper)

View File

@ -11,7 +11,7 @@ import (
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"xorm.io/core" "xorm.io/xorm/schemas"
) )
type UserCU struct { type UserCU struct {
@ -203,7 +203,7 @@ func TestAutoIncrTag(t *testing.T) {
func TestTagComment(t *testing.T) { func TestTagComment(t *testing.T) {
assert.NoError(t, prepareEngine()) assert.NoError(t, prepareEngine())
// FIXME: only support mysql // FIXME: only support mysql
if testEngine.Dialect().DriverName() != core.MYSQL { if testEngine.Dialect().DriverName() != schemas.MYSQL {
return return
} }

View File

@ -7,10 +7,10 @@ package xorm
import ( import (
"reflect" "reflect"
"xorm.io/core" "xorm.io/xorm/schemas"
) )
var ( var (
ptrPkType = reflect.TypeOf(&core.PK{}) ptrPkType = reflect.TypeOf(&schemas.PK{})
pkType = reflect.TypeOf(core.PK{}) pkType = reflect.TypeOf(schemas.PK{})
) )

View File

@ -10,7 +10,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"xorm.io/core" "xorm.io/xorm/schemas"
) )
func TestArrayField(t *testing.T) { func TestArrayField(t *testing.T) {
@ -137,8 +137,8 @@ type ConvStruct struct {
Conv ConvString Conv ConvString
Conv2 *ConvString Conv2 *ConvString
Cfg1 ConvConfig Cfg1 ConvConfig
Cfg2 *ConvConfig `xorm:"TEXT"` Cfg2 *ConvConfig `xorm:"TEXT"`
Cfg3 core.Conversion `xorm:"BLOB"` Cfg3 Conversion `xorm:"BLOB"`
Slice SliceType Slice SliceType
} }
@ -267,7 +267,7 @@ type Status struct {
} }
var ( var (
_ core.Conversion = &Status{} _ Conversion = &Status{}
Registered Status = Status{"Registered", "white"} Registered Status = Status{"Registered", "white"}
Approved Status = Status{"Approved", "green"} Approved Status = Status{"Approved", "green"}
Removed Status = Status{"Removed", "red"} Removed Status = Status{"Removed", "red"}
@ -311,7 +311,7 @@ func TestCustomType2(t *testing.T) {
session := testEngine.NewSession() session := testEngine.NewSession()
defer session.Close() defer session.Close()
if testEngine.Dialect().DBType() == core.MSSQL { if testEngine.Dialect().DBType() == schemas.MSSQL {
err = session.Begin() err = session.Begin()
assert.NoError(t, err) assert.NoError(t, err)
_, err = session.Exec("set IDENTITY_INSERT " + tableName + " on") _, err = session.Exec("set IDENTITY_INSERT " + tableName + " on")
@ -322,7 +322,7 @@ func TestCustomType2(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
if testEngine.Dialect().DBType() == core.MSSQL { if testEngine.Dialect().DBType() == schemas.MSSQL {
err = session.Commit() err = session.Commit()
assert.NoError(t, err) assert.NoError(t, err)
} }

57
xorm.go
View File

@ -15,7 +15,12 @@ import (
"sync" "sync"
"time" "time"
"xorm.io/core" "xorm.io/xorm/caches"
"xorm.io/xorm/core"
"xorm.io/xorm/dialects"
"xorm.io/xorm/log"
"xorm.io/xorm/names"
"xorm.io/xorm/schemas"
) )
const ( const (
@ -23,44 +28,14 @@ const (
Version string = "0.8.0.1015" Version string = "0.8.0.1015"
) )
func regDrvsNDialects() bool {
providedDrvsNDialects := map[string]struct {
dbType core.DbType
getDriver func() core.Driver
getDialect func() core.Dialect
}{
"mssql": {"mssql", func() core.Driver { return &odbcDriver{} }, func() core.Dialect { return &mssql{} }},
"odbc": {"mssql", func() core.Driver { return &odbcDriver{} }, func() core.Dialect { return &mssql{} }}, // !nashtsai! TODO change this when supporting MS Access
"mysql": {"mysql", func() core.Driver { return &mysqlDriver{} }, func() core.Dialect { return &mysql{} }},
"mymysql": {"mysql", func() core.Driver { return &mymysqlDriver{} }, func() core.Dialect { return &mysql{} }},
"postgres": {"postgres", func() core.Driver { return &pqDriver{} }, func() core.Dialect { return &postgres{} }},
"pgx": {"postgres", func() core.Driver { return &pqDriverPgx{} }, func() core.Dialect { return &postgres{} }},
"sqlite3": {"sqlite3", func() core.Driver { return &sqlite3Driver{} }, func() core.Dialect { return &sqlite3{} }},
"oci8": {"oracle", func() core.Driver { return &oci8Driver{} }, func() core.Dialect { return &oracle{} }},
"goracle": {"oracle", func() core.Driver { return &goracleDriver{} }, func() core.Dialect { return &oracle{} }},
}
for driverName, v := range providedDrvsNDialects {
if driver := core.QueryDriver(driverName); driver == nil {
core.RegisterDriver(driverName, v.getDriver())
core.RegisterDialect(v.dbType, v.getDialect)
}
}
return true
}
func close(engine *Engine) { func close(engine *Engine) {
engine.Close() engine.Close()
} }
func init() {
regDrvsNDialects()
}
// NewEngine new a db manager according to the parameter. Currently support four // NewEngine new a db manager according to the parameter. Currently support four
// drivers // drivers
func NewEngine(driverName string, dataSourceName string) (*Engine, error) { func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
driver := core.QueryDriver(driverName) driver := dialects.QueryDriver(driverName)
if driver == nil { if driver == nil {
return nil, fmt.Errorf("Unsupported driver name: %v", driverName) return nil, fmt.Errorf("Unsupported driver name: %v", driverName)
} }
@ -70,9 +45,9 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
return nil, err return nil, err
} }
dialect := core.QueryDialect(uri.DbType) dialect := dialects.QueryDialect(uri.DBType)
if dialect == nil { if dialect == nil {
return nil, fmt.Errorf("Unsupported dialect type: %v", uri.DbType) return nil, fmt.Errorf("Unsupported dialect type: %v", uri.DBType)
} }
db, err := core.Open(driverName, dataSourceName) db, err := core.Open(driverName, dataSourceName)
@ -88,32 +63,32 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
engine := &Engine{ engine := &Engine{
db: db, db: db,
dialect: dialect, dialect: dialect,
Tables: make(map[reflect.Type]*core.Table), Tables: make(map[reflect.Type]*schemas.Table),
mutex: &sync.RWMutex{}, mutex: &sync.RWMutex{},
TagIdentifier: "xorm", TagIdentifier: "xorm",
TZLocation: time.Local, TZLocation: time.Local,
tagHandlers: defaultTagHandlers, tagHandlers: defaultTagHandlers,
cachers: make(map[string]core.Cacher), cachers: make(map[string]caches.Cacher),
defaultContext: context.Background(), defaultContext: context.Background(),
} }
if uri.DbType == core.SQLITE { if uri.DBType == schemas.SQLITE {
engine.DatabaseTZ = time.UTC engine.DatabaseTZ = time.UTC
} else { } else {
engine.DatabaseTZ = time.Local engine.DatabaseTZ = time.Local
} }
logger := NewSimpleLogger(os.Stdout) logger := log.NewSimpleLogger(os.Stdout)
logger.SetLevel(core.LOG_INFO) logger.SetLevel(log.LOG_INFO)
engine.SetLogger(logger) engine.SetLogger(logger)
engine.SetMapper(core.NewCacheMapper(new(core.SnakeMapper))) engine.SetMapper(names.NewCacheMapper(new(names.SnakeMapper)))
runtime.SetFinalizer(engine, close) runtime.SetFinalizer(engine, close)
return engine, nil return engine, nil
} }
// NewEngineWithParams new a db manager with params. The params will be passed to dialect. // NewEngineWithParams new a db manager with params. The params will be passed to dialects.
func NewEngineWithParams(driverName string, dataSourceName string, params map[string]string) (*Engine, error) { func NewEngineWithParams(driverName string, dataSourceName string, params map[string]string) (*Engine, error) {
engine, err := NewEngine(driverName, dataSourceName) engine, err := NewEngine(driverName, dataSourceName)
engine.dialect.SetParams(params) engine.dialect.SetParams(params)

View File

@ -8,7 +8,6 @@ import (
"database/sql" "database/sql"
"flag" "flag"
"fmt" "fmt"
"log"
"os" "os"
"strings" "strings"
"testing" "testing"
@ -18,7 +17,10 @@ import (
_ "github.com/lib/pq" _ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
_ "github.com/ziutek/mymysql/godrv" _ "github.com/ziutek/mymysql/godrv"
"xorm.io/core" "xorm.io/xorm/caches"
"xorm.io/xorm/log"
"xorm.io/xorm/names"
"xorm.io/xorm/schemas"
) )
var ( var (
@ -30,14 +32,14 @@ var (
showSQL = flag.Bool("show_sql", true, "show generated SQLs") showSQL = flag.Bool("show_sql", true, "show generated SQLs")
ptrConnStr = flag.String("conn_str", "./test.db?cache=shared&mode=rwc", "test database connection string") ptrConnStr = flag.String("conn_str", "./test.db?cache=shared&mode=rwc", "test database connection string")
mapType = flag.String("map_type", "snake", "indicate the name mapping") mapType = flag.String("map_type", "snake", "indicate the name mapping")
cache = flag.Bool("cache", false, "if enable cache") cacheFlag = flag.Bool("cache", false, "if enable cache")
cluster = flag.Bool("cluster", false, "if this is a cluster") cluster = flag.Bool("cluster", false, "if this is a cluster")
splitter = flag.String("splitter", ";", "the splitter on connstr for cluster") splitter = flag.String("splitter", ";", "the splitter on connstr for cluster")
schema = flag.String("schema", "", "specify the schema") schema = flag.String("schema", "", "specify the schema")
ignoreSelectUpdate = flag.Bool("ignore_select_update", false, "ignore select update if implementation difference, only for tidb") ignoreSelectUpdate = flag.Bool("ignore_select_update", false, "ignore select update if implementation difference, only for tidb")
tableMapper core.IMapper tableMapper names.Mapper
colMapper core.IMapper colMapper names.Mapper
) )
func createEngine(dbType, connStr string) error { func createEngine(dbType, connStr string) error {
@ -46,7 +48,7 @@ func createEngine(dbType, connStr string) error {
if !*cluster { if !*cluster {
switch strings.ToLower(dbType) { switch strings.ToLower(dbType) {
case core.MSSQL: case schemas.MSSQL:
db, err := sql.Open(dbType, strings.Replace(connStr, "xorm_test", "master", -1)) db, err := sql.Open(dbType, strings.Replace(connStr, "xorm_test", "master", -1))
if err != nil { if err != nil {
return err return err
@ -56,7 +58,7 @@ func createEngine(dbType, connStr string) error {
} }
db.Close() db.Close()
*ignoreSelectUpdate = true *ignoreSelectUpdate = true
case core.POSTGRES: case schemas.POSTGRES:
db, err := sql.Open(dbType, connStr) db, err := sql.Open(dbType, connStr)
if err != nil { if err != nil {
return err return err
@ -79,7 +81,7 @@ func createEngine(dbType, connStr string) error {
} }
db.Close() db.Close()
*ignoreSelectUpdate = true *ignoreSelectUpdate = true
case core.MYSQL: case schemas.MYSQL:
db, err := sql.Open(dbType, strings.Replace(connStr, "xorm_test", "mysql", -1)) db, err := sql.Open(dbType, strings.Replace(connStr, "xorm_test", "mysql", -1))
if err != nil { if err != nil {
return err return err
@ -107,20 +109,20 @@ func createEngine(dbType, connStr string) error {
testEngine.SetSchema(*schema) testEngine.SetSchema(*schema)
} }
testEngine.ShowSQL(*showSQL) testEngine.ShowSQL(*showSQL)
testEngine.SetLogLevel(core.LOG_DEBUG) testEngine.SetLogLevel(log.LOG_DEBUG)
if *cache { if *cacheFlag {
cacher := NewLRUCacher(NewMemoryStore(), 100000) cacher := caches.NewLRUCacher(caches.NewMemoryStore(), 100000)
testEngine.SetDefaultCacher(cacher) testEngine.SetDefaultCacher(cacher)
} }
if len(*mapType) > 0 { if len(*mapType) > 0 {
switch *mapType { switch *mapType {
case "snake": case "snake":
testEngine.SetMapper(core.SnakeMapper{}) testEngine.SetMapper(names.SnakeMapper{})
case "same": case "same":
testEngine.SetMapper(core.SameMapper{}) testEngine.SetMapper(names.SameMapper{})
case "gonic": case "gonic":
testEngine.SetMapper(core.LintGonicMapper) testEngine.SetMapper(names.LintGonicMapper)
} }
} }
} }
@ -158,7 +160,7 @@ func TestMain(m *testing.M) {
} }
} else { } else {
if ptrConnStr == nil { if ptrConnStr == nil {
log.Fatal("you should indicate conn string") fmt.Println("you should indicate conn string")
return return
} }
connString = *ptrConnStr connString = *ptrConnStr
@ -175,7 +177,7 @@ func TestMain(m *testing.M) {
fmt.Println("testing", dbType, connString) fmt.Println("testing", dbType, connString)
if err := prepareEngine(); err != nil { if err := prepareEngine(); err != nil {
log.Fatal(err) fmt.Println(err)
return return
} }