Merge core package back into the main repository and split into serval sub packages. #1543

Merged
lunny merged 4 commits from lunny/merge_core2 into master 2020-02-24 08:53:22 +00:00
87 changed files with 4457 additions and 3199 deletions

2
.gitignore vendored
View File

@ -32,3 +32,5 @@ xorm.test
test.db.sql
.idea/
*coverage.out

View File

@ -8,6 +8,8 @@ import (
"testing"
"time"
"xorm.io/xorm/caches"
"github.com/stretchr/testify/assert"
)
@ -21,7 +23,7 @@ func TestCacheFind(t *testing.T) {
}
oldCacher := testEngine.GetDefaultCacher()
cacher := NewLRUCacher2(NewMemoryStore(), time.Hour, 10000)
cacher := caches.NewLRUCacher2(caches.NewMemoryStore(), time.Hour, 10000)
testEngine.SetDefaultCacher(cacher)
assert.NoError(t, testEngine.Sync2(new(MailBox)))
@ -96,7 +98,7 @@ func TestCacheFind2(t *testing.T) {
}
oldCacher := testEngine.GetDefaultCacher()
cacher := NewLRUCacher2(NewMemoryStore(), time.Hour, 10000)
cacher := caches.NewLRUCacher2(caches.NewMemoryStore(), time.Hour, 10000)
testEngine.SetDefaultCacher(cacher)
assert.NoError(t, testEngine.Sync2(new(MailBox2)))
@ -147,7 +149,7 @@ func TestCacheGet(t *testing.T) {
}
oldCacher := testEngine.GetDefaultCacher()
cacher := NewLRUCacher2(NewMemoryStore(), time.Hour, 10000)
cacher := caches.NewLRUCacher2(caches.NewMemoryStore(), time.Hour, 10000)
testEngine.SetDefaultCacher(cacher)
assert.NoError(t, testEngine.Sync2(new(MailBox3)))

99
caches/cache.go Normal file
View File

@ -0,0 +1,99 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package caches
import (
"bytes"
"encoding/gob"
"errors"
"fmt"
"strings"
"time"
"xorm.io/xorm/schemas"
)
const (
// CacheExpired is default cache expired time
CacheExpired = 60 * time.Minute
// CacheMaxMemory is not use now
CacheMaxMemory = 256
// CacheGcInterval represents interval time to clear all expired nodes
CacheGcInterval = 10 * time.Minute
// CacheGcMaxRemoved represents max nodes removed when gc
CacheGcMaxRemoved = 20
)
// list all the errors
var (
ErrCacheMiss = errors.New("xorm/cache: key not found")
ErrNotStored = errors.New("xorm/cache: not stored")
// ErrNotExist record does not exist error
ErrNotExist = errors.New("Record does not exist")
)
// CacheStore is a interface to store cache
type CacheStore interface {
// key is primary key or composite primary key
// value is struct's pointer
// key format : <tablename>-p-<pk1>-<pk2>...
Put(key string, value interface{}) error
Get(key string) (interface{}, error)
Del(key string) error
}
// Cacher is an interface to provide cache
// id format : u-<pk1>-<pk2>...
type Cacher interface {
GetIds(tableName, sql string) interface{}
GetBean(tableName string, id string) interface{}
PutIds(tableName, sql string, ids interface{})
PutBean(tableName string, id string, obj interface{})
DelIds(tableName, sql string)
DelBean(tableName string, id string)
ClearIds(tableName string)
ClearBeans(tableName string)
}
func encodeIds(ids []schemas.PK) (string, error) {
buf := new(bytes.Buffer)
enc := gob.NewEncoder(buf)
err := enc.Encode(ids)
return buf.String(), err
}
func decodeIds(s string) ([]schemas.PK, error) {
pks := make([]schemas.PK, 0)
dec := gob.NewDecoder(strings.NewReader(s))
err := dec.Decode(&pks)
return pks, err
}
// GetCacheSql returns cacher PKs via SQL
func GetCacheSql(m Cacher, tableName, sql string, args interface{}) ([]schemas.PK, error) {
bytes := m.GetIds(tableName, GenSqlKey(sql, args))
if bytes == nil {
return nil, errors.New("Not Exist")
}
return decodeIds(bytes.(string))
}
// PutCacheSql puts cacher SQL and PKs
func PutCacheSql(m Cacher, ids []schemas.PK, tableName, sql string, args interface{}) error {
bytes, err := encodeIds(ids)
if err != nil {
return err
}
m.PutIds(tableName, GenSqlKey(sql, args), bytes)
return nil
}
// GenSqlKey generates cache key
func GenSqlKey(sql string, args interface{}) string {
return fmt.Sprintf("%v-%v", sql, args)
}

View File

@ -2,15 +2,13 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
package caches
import (
"container/list"
"fmt"
"sync"
"time"
"xorm.io/core"
)
// LRUCacher implments cache object facilities
@ -19,7 +17,7 @@ type LRUCacher struct {
sqlList *list.List
idIndex map[string]map[string]*list.Element
sqlIndex map[string]map[string]*list.Element
store core.CacheStore
store CacheStore
mutex sync.Mutex
MaxElementSize int
Expired time.Duration
@ -27,15 +25,15 @@ type LRUCacher struct {
}
// NewLRUCacher creates a cacher
func NewLRUCacher(store core.CacheStore, maxElementSize int) *LRUCacher {
func NewLRUCacher(store CacheStore, maxElementSize int) *LRUCacher {
return NewLRUCacher2(store, 3600*time.Second, maxElementSize)
}
// NewLRUCacher2 creates a cache include different params
func NewLRUCacher2(store core.CacheStore, expired time.Duration, maxElementSize int) *LRUCacher {
func NewLRUCacher2(store CacheStore, expired time.Duration, maxElementSize int) *LRUCacher {
cacher := &LRUCacher{store: store, idList: list.New(),
sqlList: list.New(), Expired: expired,
GcInterval: core.CacheGcInterval, MaxElementSize: maxElementSize,
GcInterval: CacheGcInterval, MaxElementSize: maxElementSize,
sqlIndex: make(map[string]map[string]*list.Element),
idIndex: make(map[string]map[string]*list.Element),
}
@ -57,7 +55,7 @@ func (m *LRUCacher) GC() {
defer m.mutex.Unlock()
var removedNum int
for e := m.idList.Front(); e != nil; {
if removedNum <= core.CacheGcMaxRemoved &&
if removedNum <= CacheGcMaxRemoved &&
time.Now().Sub(e.Value.(*idNode).lastVisit) > m.Expired {
removedNum++
next := e.Next()
@ -71,7 +69,7 @@ func (m *LRUCacher) GC() {
removedNum = 0
for e := m.sqlList.Front(); e != nil; {
if removedNum <= core.CacheGcMaxRemoved &&
if removedNum <= CacheGcMaxRemoved &&
time.Now().Sub(e.Value.(*sqlNode).lastVisit) > m.Expired {
removedNum++
next := e.Next()
@ -268,11 +266,11 @@ type sqlNode struct {
}
func genSQLKey(sql string, args interface{}) string {
return fmt.Sprintf("%v-%v", sql, args)
return fmt.Sprintf("%s-%v", sql, args)
}
func genID(prefix string, id string) string {
return fmt.Sprintf("%v-%v", prefix, id)
return fmt.Sprintf("%s-%s", prefix, id)
}
func newIDNode(tbName string, id string) *idNode {

View File

@ -2,13 +2,13 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
package caches
import (
"testing"
"github.com/stretchr/testify/assert"
"xorm.io/core"
"xorm.io/xorm/schemas"
)
func TestLRUCache(t *testing.T) {
@ -20,7 +20,7 @@ func TestLRUCache(t *testing.T) {
cacher := NewLRUCacher(store, 10000)
tableName := "cache_object1"
pks := []core.PK{
pks := []schemas.PK{
{1},
{2},
}

View File

@ -2,15 +2,13 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
package caches
import (
"sync"
"xorm.io/core"
)
var _ core.CacheStore = NewMemoryStore()
var _ CacheStore = NewMemoryStore()
// MemoryStore represents in-memory store
type MemoryStore struct {

View File

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
package caches
import (
"testing"

View File

@ -346,3 +346,10 @@ func asBool(bs []byte) (bool, error) {
}
return strconv.ParseBool(string(bs))
}
// Conversion is an interface. A type implements Conversion will according
// the custom method to fill into database and retrieve from database.
type Conversion interface {
FromDB([]byte) error
ToDB() ([]byte, error)
}

229
core/db.go Normal file
View File

@ -0,0 +1,229 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package core
import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"reflect"
"regexp"
"sync"
"xorm.io/xorm/names"
)
var (
// DefaultCacheSize sets the default cache size
DefaultCacheSize = 200
)
func MapToSlice(query string, mp interface{}) (string, []interface{}, error) {
vv := reflect.ValueOf(mp)
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
return "", []interface{}{}, ErrNoMapPointer
}
args := make([]interface{}, 0, len(vv.Elem().MapKeys()))
var err error
query = re.ReplaceAllStringFunc(query, func(src string) string {
v := vv.Elem().MapIndex(reflect.ValueOf(src[1:]))
if !v.IsValid() {
err = fmt.Errorf("map key %s is missing", src[1:])
} else {
args = append(args, v.Interface())
}
return "?"
})
return query, args, err
}
func StructToSlice(query string, st interface{}) (string, []interface{}, error) {
vv := reflect.ValueOf(st)
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
return "", []interface{}{}, ErrNoStructPointer
}
args := make([]interface{}, 0)
var err error
query = re.ReplaceAllStringFunc(query, func(src string) string {
fv := vv.Elem().FieldByName(src[1:]).Interface()
if v, ok := fv.(driver.Valuer); ok {
var value driver.Value
value, err = v.Value()
if err != nil {
return "?"
}
args = append(args, value)
} else {
args = append(args, fv)
}
return "?"
})
if err != nil {
return "", []interface{}{}, err
}
return query, args, nil
}
type cacheStruct struct {
value reflect.Value
idx int
}
// DB is a wrap of sql.DB with extra contents
type DB struct {
*sql.DB
Mapper names.Mapper
reflectCache map[reflect.Type]*cacheStruct
reflectCacheMutex sync.RWMutex
}
// Open opens a database
func Open(driverName, dataSourceName string) (*DB, error) {
db, err := sql.Open(driverName, dataSourceName)
if err != nil {
return nil, err
}
return &DB{
DB: db,
Mapper: names.NewCacheMapper(&names.SnakeMapper{}),
reflectCache: make(map[reflect.Type]*cacheStruct),
}, nil
}
// FromDB creates a DB from a sql.DB
func FromDB(db *sql.DB) *DB {
return &DB{
DB: db,
Mapper: names.NewCacheMapper(&names.SnakeMapper{}),
reflectCache: make(map[reflect.Type]*cacheStruct),
}
}
func (db *DB) reflectNew(typ reflect.Type) reflect.Value {
db.reflectCacheMutex.Lock()
defer db.reflectCacheMutex.Unlock()
cs, ok := db.reflectCache[typ]
if !ok || cs.idx+1 > DefaultCacheSize-1 {
cs = &cacheStruct{reflect.MakeSlice(reflect.SliceOf(typ), DefaultCacheSize, DefaultCacheSize), 0}
db.reflectCache[typ] = cs
} else {
cs.idx = cs.idx + 1
}
return cs.value.Index(cs.idx).Addr()
}
// QueryContext overwrites sql.DB.QueryContext
func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
rows, err := db.DB.QueryContext(ctx, query, args...)
if err != nil {
if rows != nil {
rows.Close()
}
return nil, err
}
return &Rows{rows, db}, nil
}
// Query overwrites sql.DB.Query
func (db *DB) Query(query string, args ...interface{}) (*Rows, error) {
return db.QueryContext(context.Background(), query, args...)
}
// QueryMapContext executes query with parameters via map and context
func (db *DB) QueryMapContext(ctx context.Context, query string, mp interface{}) (*Rows, error) {
query, args, err := MapToSlice(query, mp)
if err != nil {
return nil, err
}
return db.QueryContext(ctx, query, args...)
}
// QueryMap executes query with parameters via map
func (db *DB) QueryMap(query string, mp interface{}) (*Rows, error) {
return db.QueryMapContext(context.Background(), query, mp)
}
func (db *DB) QueryStructContext(ctx context.Context, query string, st interface{}) (*Rows, error) {
query, args, err := StructToSlice(query, st)
if err != nil {
return nil, err
}
return db.QueryContext(ctx, query, args...)
}
func (db *DB) QueryStruct(query string, st interface{}) (*Rows, error) {
return db.QueryStructContext(context.Background(), query, st)
}
func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row {
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
return &Row{nil, err}
}
return &Row{rows, nil}
}
func (db *DB) QueryRow(query string, args ...interface{}) *Row {
return db.QueryRowContext(context.Background(), query, args...)
}
func (db *DB) QueryRowMapContext(ctx context.Context, query string, mp interface{}) *Row {
query, args, err := MapToSlice(query, mp)
if err != nil {
return &Row{nil, err}
}
return db.QueryRowContext(ctx, query, args...)
}
func (db *DB) QueryRowMap(query string, mp interface{}) *Row {
return db.QueryRowMapContext(context.Background(), query, mp)
}
func (db *DB) QueryRowStructContext(ctx context.Context, query string, st interface{}) *Row {
query, args, err := StructToSlice(query, st)
if err != nil {
return &Row{nil, err}
}
return db.QueryRowContext(ctx, query, args...)
}
func (db *DB) QueryRowStruct(query string, st interface{}) *Row {
return db.QueryRowStructContext(context.Background(), query, st)
}
var (
re = regexp.MustCompile(`[?](\w+)`)
)
// ExecMapContext exec map with context.Context
// insert into (name) values (?)
// insert into (name) values (?name)
func (db *DB) ExecMapContext(ctx context.Context, query string, mp interface{}) (sql.Result, error) {
query, args, err := MapToSlice(query, mp)
if err != nil {
return nil, err
}
return db.DB.ExecContext(ctx, query, args...)
}
func (db *DB) ExecMap(query string, mp interface{}) (sql.Result, error) {
return db.ExecMapContext(context.Background(), query, mp)
}
func (db *DB) ExecStructContext(ctx context.Context, query string, st interface{}) (sql.Result, error) {
query, args, err := StructToSlice(query, st)
if err != nil {
return nil, err
}
return db.DB.ExecContext(ctx, query, args...)
}
func (db *DB) ExecStruct(query string, st interface{}) (sql.Result, error) {
return db.ExecStructContext(context.Background(), query, st)
}

684
core/db_test.go Normal file
View File

@ -0,0 +1,684 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package core
import (
"errors"
"flag"
"os"
"testing"
"time"
_ "github.com/go-sql-driver/mysql"
_ "github.com/mattn/go-sqlite3"
"xorm.io/xorm/names"
)
var (
dbtype = flag.String("dbtype", "mysql", "database type")
dbConn = flag.String("dbConn", "root:@/core_test?charset=utf8", "database connect string")
createTableSql string
)
func TestMain(m *testing.M) {
flag.Parse()
switch *dbtype {
case "sqlite3":
createTableSql = "CREATE TABLE IF NOT EXISTS `user` (`id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, `name` TEXT NULL, " +
"`title` TEXT NULL, `age` FLOAT NULL, `alias` TEXT NULL, `nick_name` TEXT NULL, `created` datetime);"
case "mysql":
fallthrough
default:
createTableSql = "CREATE TABLE IF NOT EXISTS `user` (`id` INTEGER PRIMARY KEY AUTO_INCREMENT NOT NULL, `name` TEXT NULL, " +
"`title` TEXT NULL, `age` FLOAT NULL, `alias` TEXT NULL, `nick_name` TEXT NULL, `created` datetime);"
}
exitCode := m.Run()
os.Exit(exitCode)
}
func testOpen() (*DB, error) {
switch *dbtype {
case "sqlite3":
os.Remove("./test.db")
return Open("sqlite3", "./test.db")
case "mysql":
return Open("mysql", *dbConn)
default:
panic("no db type")
}
}
func BenchmarkOriQuery(b *testing.B) {
b.StopTimer()
db, err := testOpen()
if err != nil {
b.Error(err)
}
defer db.Close()
_, err = db.Exec(createTableSql)
if err != nil {
b.Error(err)
}
for i := 0; i < 50; i++ {
_, err = db.Exec("insert into user (`name`, title, age, alias, nick_name, created) values (?,?,?,?,?, ?)",
"xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now())
if err != nil {
b.Error(err)
}
}
b.StartTimer()
for i := 0; i < b.N; i++ {
rows, err := db.Query("select * from user")
if err != nil {
b.Error(err)
}
for rows.Next() {
var Id int64
var Name, Title, Alias, NickName string
var Age float32
var Created NullTime
err = rows.Scan(&Id, &Name, &Title, &Age, &Alias, &NickName, &Created)
if err != nil {
b.Error(err)
}
//fmt.Println(Id, Name, Title, Age, Alias, NickName)
}
rows.Close()
}
}
type User struct {
Id int64
Name string
Title string
Age float32
Alias string
NickName string
Created NullTime
}
func BenchmarkStructQuery(b *testing.B) {
b.StopTimer()
db, err := testOpen()
if err != nil {
b.Error(err)
}
defer db.Close()
_, err = db.Exec(createTableSql)
if err != nil {
b.Error(err)
}
for i := 0; i < 50; i++ {
_, err = db.Exec("insert into user (`name`, title, age, alias, nick_name, created) values (?,?,?,?,?, ?)",
"xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now())
if err != nil {
b.Error(err)
}
}
b.StartTimer()
for i := 0; i < b.N; i++ {
rows, err := db.Query("select * from user")
if err != nil {
b.Error(err)
}
for rows.Next() {
var user User
err = rows.ScanStructByIndex(&user)
if err != nil {
b.Error(err)
}
if user.Name != "xlw" {
b.Log(user)
b.Error(errors.New("name should be xlw"))
}
}
rows.Close()
}
}
func BenchmarkStruct2Query(b *testing.B) {
b.StopTimer()
db, err := testOpen()
if err != nil {
b.Error(err)
}
defer db.Close()
_, err = db.Exec(createTableSql)
if err != nil {
b.Error(err)
}
for i := 0; i < 50; i++ {
_, err = db.Exec("insert into user (`name`, title, age, alias, nick_name, created) values (?,?,?,?,?,?)",
"xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now())
if err != nil {
b.Error(err)
}
}
db.Mapper = names.NewCacheMapper(&names.SnakeMapper{})
b.StartTimer()
for i := 0; i < b.N; i++ {
rows, err := db.Query("select * from user")
if err != nil {
b.Error(err)
}
for rows.Next() {
var user User
err = rows.ScanStructByName(&user)
if err != nil {
b.Error(err)
}
if user.Name != "xlw" {
b.Log(user)
b.Error(errors.New("name should be xlw"))
}
}
rows.Close()
}
}
func BenchmarkSliceInterfaceQuery(b *testing.B) {
b.StopTimer()
db, err := testOpen()
if err != nil {
b.Error(err)
}
defer db.Close()
_, err = db.Exec(createTableSql)
if err != nil {
b.Error(err)
}
for i := 0; i < 50; i++ {
_, err = db.Exec("insert into user (`name`, title, age, alias, nick_name,created) values (?,?,?,?,?,?)",
"xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now())
if err != nil {
b.Error(err)
}
}
b.StartTimer()
for i := 0; i < b.N; i++ {
rows, err := db.Query("select * from user")
if err != nil {
b.Error(err)
}
cols, err := rows.Columns()
if err != nil {
b.Error(err)
}
for rows.Next() {
slice := make([]interface{}, len(cols))
err = rows.ScanSlice(&slice)
if err != nil {
b.Error(err)
}
b.Log(slice)
switch slice[1].(type) {
case *string:
if *slice[1].(*string) != "xlw" {
b.Error(errors.New("name should be xlw"))
}
case []byte:
if string(slice[1].([]byte)) != "xlw" {
b.Error(errors.New("name should be xlw"))
}
}
}
rows.Close()
}
}
/*func BenchmarkSliceBytesQuery(b *testing.B) {
b.StopTimer()
os.Remove("./test.db")
db, err := Open("sqlite3", "./test.db")
if err != nil {
b.Error(err)
}
defer db.Close()
_, err = db.Exec(createTableSql)
if err != nil {
b.Error(err)
}
for i := 0; i < 50; i++ {
_, err = db.Exec("insert into user (name, title, age, alias, nick_name,created) values (?,?,?,?,?,?)",
"xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now())
if err != nil {
b.Error(err)
}
}
b.StartTimer()
for i := 0; i < b.N; i++ {
rows, err := db.Query("select * from user")
if err != nil {
b.Error(err)
}
cols, err := rows.Columns()
if err != nil {
b.Error(err)
}
for rows.Next() {
slice := make([][]byte, len(cols))
err = rows.ScanSlice(&slice)
if err != nil {
b.Error(err)
}
if string(slice[1]) != "xlw" {
fmt.Println(slice)
b.Error(errors.New("name should be xlw"))
}
}
rows.Close()
}
}
*/
func BenchmarkSliceStringQuery(b *testing.B) {
b.StopTimer()
db, err := testOpen()
if err != nil {
b.Error(err)
}
defer db.Close()
_, err = db.Exec(createTableSql)
if err != nil {
b.Error(err)
}
for i := 0; i < 50; i++ {
_, err = db.Exec("insert into user (name, title, age, alias, nick_name, created) values (?,?,?,?,?,?)",
"xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now())
if err != nil {
b.Error(err)
}
}
b.StartTimer()
for i := 0; i < b.N; i++ {
rows, err := db.Query("select * from user")
if err != nil {
b.Error(err)
}
cols, err := rows.Columns()
if err != nil {
b.Error(err)
}
for rows.Next() {
slice := make([]*string, len(cols))
err = rows.ScanSlice(&slice)
if err != nil {
b.Error(err)
}
if (*slice[1]) != "xlw" {
b.Log(slice)
b.Error(errors.New("name should be xlw"))
}
}
rows.Close()
}
}
func BenchmarkMapInterfaceQuery(b *testing.B) {
b.StopTimer()
db, err := testOpen()
if err != nil {
b.Error(err)
}
defer db.Close()
_, err = db.Exec(createTableSql)
if err != nil {
b.Error(err)
}
for i := 0; i < 50; i++ {
_, err = db.Exec("insert into user (name, title, age, alias, nick_name,created) values (?,?,?,?,?,?)",
"xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now())
if err != nil {
b.Error(err)
}
}
b.StartTimer()
for i := 0; i < b.N; i++ {
rows, err := db.Query("select * from user")
if err != nil {
b.Error(err)
}
for rows.Next() {
m := make(map[string]interface{})
err = rows.ScanMap(&m)
if err != nil {
b.Error(err)
}
switch m["name"].(type) {
case string:
if m["name"].(string) != "xlw" {
b.Log(m)
b.Error(errors.New("name should be xlw"))
}
case []byte:
if string(m["name"].([]byte)) != "xlw" {
b.Log(m)
b.Error(errors.New("name should be xlw"))
}
}
}
rows.Close()
}
}
/*func BenchmarkMapBytesQuery(b *testing.B) {
b.StopTimer()
os.Remove("./test.db")
db, err := Open("sqlite3", "./test.db")
if err != nil {
b.Error(err)
}
defer db.Close()
_, err = db.Exec(createTableSql)
if err != nil {
b.Error(err)
}
for i := 0; i < 50; i++ {
_, err = db.Exec("insert into user (name, title, age, alias, nick_name,created) values (?,?,?,?,?,?)",
"xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now())
if err != nil {
b.Error(err)
}
}
b.StartTimer()
for i := 0; i < b.N; i++ {
rows, err := db.Query("select * from user")
if err != nil {
b.Error(err)
}
for rows.Next() {
m := make(map[string][]byte)
err = rows.ScanMap(&m)
if err != nil {
b.Error(err)
}
if string(m["name"]) != "xlw" {
fmt.Println(m)
b.Error(errors.New("name should be xlw"))
}
}
rows.Close()
}
}
*/
/*
func BenchmarkMapStringQuery(b *testing.B) {
b.StopTimer()
os.Remove("./test.db")
db, err := Open("sqlite3", "./test.db")
if err != nil {
b.Error(err)
}
defer db.Close()
_, err = db.Exec(createTableSql)
if err != nil {
b.Error(err)
}
for i := 0; i < 50; i++ {
_, err = db.Exec("insert into user (name, title, age, alias, nick_name,created) values (?,?,?,?,?,?)",
"xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now())
if err != nil {
b.Error(err)
}
}
b.StartTimer()
for i := 0; i < b.N; i++ {
rows, err := db.Query("select * from user")
if err != nil {
b.Error(err)
}
for rows.Next() {
m := make(map[string]string)
err = rows.ScanMap(&m)
if err != nil {
b.Error(err)
}
if m["name"] != "xlw" {
fmt.Println(m)
b.Error(errors.New("name should be xlw"))
}
}
rows.Close()
}
}*/
func BenchmarkExec(b *testing.B) {
b.StopTimer()
db, err := testOpen()
if err != nil {
b.Error(err)
}
defer db.Close()
_, err = db.Exec(createTableSql)
if err != nil {
b.Error(err)
}
b.StartTimer()
for i := 0; i < b.N; i++ {
_, err = db.Exec("insert into user (`name`, title, age, alias, nick_name,created) values (?,?,?,?,?,?)",
"xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now())
if err != nil {
b.Error(err)
}
}
}
func BenchmarkExecMap(b *testing.B) {
b.StopTimer()
db, err := testOpen()
if err != nil {
b.Error(err)
}
defer db.Close()
_, err = db.Exec(createTableSql)
if err != nil {
b.Error(err)
}
b.StartTimer()
mp := map[string]interface{}{
"name": "xlw",
"title": "tester",
"age": 1.2,
"alias": "lunny",
"nick_name": "lunny xiao",
"created": time.Now(),
}
for i := 0; i < b.N; i++ {
_, err = db.ExecMap("insert into user (`name`, title, age, alias, nick_name, created) "+
"values (?name,?title,?age,?alias,?nick_name,?created)",
&mp)
if err != nil {
b.Error(err)
}
}
}
func TestExecMap(t *testing.T) {
db, err := testOpen()
if err != nil {
t.Error(err)
}
defer db.Close()
_, err = db.Exec(createTableSql)
if err != nil {
t.Error(err)
}
mp := map[string]interface{}{
"name": "xlw",
"title": "tester",
"age": 1.2,
"alias": "lunny",
"nick_name": "lunny xiao",
"created": time.Now(),
}
_, err = db.ExecMap("insert into user (`name`, title, age, alias, nick_name,created) "+
"values (?name,?title,?age,?alias,?nick_name,?created)",
&mp)
if err != nil {
t.Error(err)
}
rows, err := db.Query("select * from user")
if err != nil {
t.Error(err)
}
for rows.Next() {
var user User
err = rows.ScanStructByName(&user)
if err != nil {
t.Error(err)
}
t.Log("--", user)
}
}
func TestExecStruct(t *testing.T) {
db, err := testOpen()
if err != nil {
t.Error(err)
}
defer db.Close()
_, err = db.Exec(createTableSql)
if err != nil {
t.Error(err)
}
user := User{Name: "xlw",
Title: "tester",
Age: 1.2,
Alias: "lunny",
NickName: "lunny xiao",
Created: NullTime(time.Now()),
}
_, err = db.ExecStruct("insert into user (`name`, title, age, alias, nick_name,created) "+
"values (?Name,?Title,?Age,?Alias,?NickName,?Created)",
&user)
if err != nil {
t.Error(err)
}
rows, err := db.QueryStruct("select * from user where `name` = ?Name", &user)
if err != nil {
t.Error(err)
}
for rows.Next() {
var user User
err = rows.ScanStructByName(&user)
if err != nil {
t.Error(err)
}
t.Log("1--", user)
}
}
func BenchmarkExecStruct(b *testing.B) {
b.StopTimer()
db, err := testOpen()
if err != nil {
b.Error(err)
}
defer db.Close()
_, err = db.Exec(createTableSql)
if err != nil {
b.Error(err)
}
b.StartTimer()
user := User{Name: "xlw",
Title: "tester",
Age: 1.2,
Alias: "lunny",
NickName: "lunny xiao",
Created: NullTime(time.Now()),
}
for i := 0; i < b.N; i++ {
_, err = db.ExecStruct("insert into user (`name`, title, age, alias, nick_name,created) "+
"values (?Name,?Title,?Age,?Alias,?NickName,?Created)",
&user)
if err != nil {
b.Error(err)
}
}
}

14
core/error.go Normal file
View File

@ -0,0 +1,14 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package core
import "errors"
var (
// ErrNoMapPointer represents error when no map pointer
ErrNoMapPointer = errors.New("mp should be a map's pointer")
// ErrNoStructPointer represents error when no struct pointer
ErrNoStructPointer = errors.New("mp should be a struct's pointer")
)

338
core/rows.go Normal file
View File

@ -0,0 +1,338 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package core
import (
"database/sql"
"errors"
"reflect"
"sync"
)
type Rows struct {
*sql.Rows
db *DB
}
func (rs *Rows) ToMapString() ([]map[string]string, error) {
cols, err := rs.Columns()
if err != nil {
return nil, err
}
var results = make([]map[string]string, 0, 10)
for rs.Next() {
var record = make(map[string]string, len(cols))
err = rs.ScanMap(&record)
if err != nil {
return nil, err
}
results = append(results, record)
}
return results, nil
}
// scan data to a struct's pointer according field index
func (rs *Rows) ScanStructByIndex(dest ...interface{}) error {
if len(dest) == 0 {
return errors.New("at least one struct")
}
vvvs := make([]reflect.Value, len(dest))
for i, s := range dest {
vv := reflect.ValueOf(s)
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
return errors.New("dest should be a struct's pointer")
}
vvvs[i] = vv.Elem()
}
cols, err := rs.Columns()
if err != nil {
return err
}
newDest := make([]interface{}, len(cols))
var i = 0
for _, vvv := range vvvs {
for j := 0; j < vvv.NumField(); j++ {
newDest[i] = vvv.Field(j).Addr().Interface()
i = i + 1
}
}
return rs.Rows.Scan(newDest...)
}
var (
fieldCache = make(map[reflect.Type]map[string]int)
fieldCacheMutex sync.RWMutex
)
func fieldByName(v reflect.Value, name string) reflect.Value {
t := v.Type()
fieldCacheMutex.RLock()
cache, ok := fieldCache[t]
fieldCacheMutex.RUnlock()
if !ok {
cache = make(map[string]int)
for i := 0; i < v.NumField(); i++ {
cache[t.Field(i).Name] = i
}
fieldCacheMutex.Lock()
fieldCache[t] = cache
fieldCacheMutex.Unlock()
}
if i, ok := cache[name]; ok {
return v.Field(i)
}
return reflect.Zero(t)
}
// scan data to a struct's pointer according field name
func (rs *Rows) ScanStructByName(dest interface{}) error {
vv := reflect.ValueOf(dest)
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
return errors.New("dest should be a struct's pointer")
}
cols, err := rs.Columns()
if err != nil {
return err
}
newDest := make([]interface{}, len(cols))
var v EmptyScanner
for j, name := range cols {
f := fieldByName(vv.Elem(), rs.db.Mapper.Table2Obj(name))
if f.IsValid() {
newDest[j] = f.Addr().Interface()
} else {
newDest[j] = &v
}
}
return rs.Rows.Scan(newDest...)
}
// scan data to a slice's pointer, slice's length should equal to columns' number
func (rs *Rows) ScanSlice(dest interface{}) error {
vv := reflect.ValueOf(dest)
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Slice {
return errors.New("dest should be a slice's pointer")
}
vvv := vv.Elem()
cols, err := rs.Columns()
if err != nil {
return err
}
newDest := make([]interface{}, len(cols))
for j := 0; j < len(cols); j++ {
if j >= vvv.Len() {
newDest[j] = reflect.New(vvv.Type().Elem()).Interface()
} else {
newDest[j] = vvv.Index(j).Addr().Interface()
}
}
err = rs.Rows.Scan(newDest...)
if err != nil {
return err
}
srcLen := vvv.Len()
for i := srcLen; i < len(cols); i++ {
vvv = reflect.Append(vvv, reflect.ValueOf(newDest[i]).Elem())
}
return nil
}
// scan data to a map's pointer
func (rs *Rows) ScanMap(dest interface{}) error {
vv := reflect.ValueOf(dest)
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
return errors.New("dest should be a map's pointer")
}
cols, err := rs.Columns()
if err != nil {
return err
}
newDest := make([]interface{}, len(cols))
vvv := vv.Elem()
for i := range cols {
newDest[i] = rs.db.reflectNew(vvv.Type().Elem()).Interface()
}
err = rs.Rows.Scan(newDest...)
if err != nil {
return err
}
for i, name := range cols {
vname := reflect.ValueOf(name)
vvv.SetMapIndex(vname, reflect.ValueOf(newDest[i]).Elem())
}
return nil
}
type Row struct {
rows *Rows
// One of these two will be non-nil:
err error // deferred error for easy chaining
}
// ErrorRow return an error row
func ErrorRow(err error) *Row {
return &Row{
err: err,
}
}
// NewRow from rows
func NewRow(rows *Rows, err error) *Row {
return &Row{rows, err}
}
func (row *Row) Columns() ([]string, error) {
if row.err != nil {
return nil, row.err
}
return row.rows.Columns()
}
func (row *Row) Scan(dest ...interface{}) error {
if row.err != nil {
return row.err
}
defer row.rows.Close()
for _, dp := range dest {
if _, ok := dp.(*sql.RawBytes); ok {
return errors.New("sql: RawBytes isn't allowed on Row.Scan")
}
}
if !row.rows.Next() {
if err := row.rows.Err(); err != nil {
return err
}
return sql.ErrNoRows
}
err := row.rows.Scan(dest...)
if err != nil {
return err
}
// Make sure the query can be processed to completion with no errors.
return row.rows.Close()
}
func (row *Row) ScanStructByName(dest interface{}) error {
if row.err != nil {
return row.err
}
defer row.rows.Close()
if !row.rows.Next() {
if err := row.rows.Err(); err != nil {
return err
}
return sql.ErrNoRows
}
err := row.rows.ScanStructByName(dest)
if err != nil {
return err
}
// Make sure the query can be processed to completion with no errors.
return row.rows.Close()
}
func (row *Row) ScanStructByIndex(dest interface{}) error {
if row.err != nil {
return row.err
}
defer row.rows.Close()
if !row.rows.Next() {
if err := row.rows.Err(); err != nil {
return err
}
return sql.ErrNoRows
}
err := row.rows.ScanStructByIndex(dest)
if err != nil {
return err
}
// Make sure the query can be processed to completion with no errors.
return row.rows.Close()
}
// scan data to a slice's pointer, slice's length should equal to columns' number
func (row *Row) ScanSlice(dest interface{}) error {
if row.err != nil {
return row.err
}
defer row.rows.Close()
if !row.rows.Next() {
if err := row.rows.Err(); err != nil {
return err
}
return sql.ErrNoRows
}
err := row.rows.ScanSlice(dest)
if err != nil {
return err
}
// Make sure the query can be processed to completion with no errors.
return row.rows.Close()
}
// scan data to a map's pointer
func (row *Row) ScanMap(dest interface{}) error {
if row.err != nil {
return row.err
}
defer row.rows.Close()
if !row.rows.Next() {
if err := row.rows.Err(); err != nil {
return err
}
return sql.ErrNoRows
}
err := row.rows.ScanMap(dest)
if err != nil {
return err
}
// Make sure the query can be processed to completion with no errors.
return row.rows.Close()
}
func (row *Row) ToMapString() (map[string]string, error) {
cols, err := row.Columns()
if err != nil {
return nil, err
}
var record = make(map[string]string, len(cols))
err = row.ScanMap(&record)
if err != nil {
return nil, err
}
return record, nil
}

66
core/scan.go Normal file
View File

@ -0,0 +1,66 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package core
import (
"database/sql/driver"
"fmt"
"time"
)
type NullTime time.Time
var (
_ driver.Valuer = NullTime{}
)
func (ns *NullTime) Scan(value interface{}) error {
if value == nil {
return nil
}
return convertTime(ns, value)
}
// Value implements the driver Valuer interface.
func (ns NullTime) Value() (driver.Value, error) {
if (time.Time)(ns).IsZero() {
return nil, nil
}
return (time.Time)(ns).Format("2006-01-02 15:04:05"), nil
}
func convertTime(dest *NullTime, src interface{}) error {
// Common cases, without reflect.
switch s := src.(type) {
case string:
t, err := time.Parse("2006-01-02 15:04:05", s)
if err != nil {
return err
}
*dest = NullTime(t)
return nil
case []uint8:
t, err := time.Parse("2006-01-02 15:04:05", string(s))
if err != nil {
return err
}
*dest = NullTime(t)
return nil
case time.Time:
*dest = NullTime(s)
return nil
case nil:
default:
return fmt.Errorf("unsupported driver -> Scan pair: %T -> %T", src, dest)
}
return nil
}
type EmptyScanner struct {
}
func (EmptyScanner) Scan(src interface{}) error {
return nil
}

166
core/stmt.go Normal file
View File

@ -0,0 +1,166 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package core
import (
"context"
"database/sql"
"errors"
"reflect"
)
// Stmt reprents a stmt objects
type Stmt struct {
*sql.Stmt
db *DB
names map[string]int
}
func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
names := make(map[string]int)
var i int
query = re.ReplaceAllStringFunc(query, func(src string) string {
names[src[1:]] = i
i += 1
return "?"
})
stmt, err := db.DB.PrepareContext(ctx, query)
if err != nil {
return nil, err
}
return &Stmt{stmt, db, names}, nil
}
func (db *DB) Prepare(query string) (*Stmt, error) {
return db.PrepareContext(context.Background(), query)
}
func (s *Stmt) ExecMapContext(ctx context.Context, mp interface{}) (sql.Result, error) {
vv := reflect.ValueOf(mp)
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
return nil, errors.New("mp should be a map's pointer")
}
args := make([]interface{}, len(s.names))
for k, i := range s.names {
args[i] = vv.Elem().MapIndex(reflect.ValueOf(k)).Interface()
}
return s.Stmt.ExecContext(ctx, args...)
}
func (s *Stmt) ExecMap(mp interface{}) (sql.Result, error) {
return s.ExecMapContext(context.Background(), mp)
}
func (s *Stmt) ExecStructContext(ctx context.Context, st interface{}) (sql.Result, error) {
vv := reflect.ValueOf(st)
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
return nil, errors.New("mp should be a map's pointer")
}
args := make([]interface{}, len(s.names))
for k, i := range s.names {
args[i] = vv.Elem().FieldByName(k).Interface()
}
return s.Stmt.ExecContext(ctx, args...)
}
func (s *Stmt) ExecStruct(st interface{}) (sql.Result, error) {
return s.ExecStructContext(context.Background(), st)
}
func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, error) {
rows, err := s.Stmt.QueryContext(ctx, args...)
if err != nil {
return nil, err
}
return &Rows{rows, s.db}, nil
}
func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
return s.QueryContext(context.Background(), args...)
}
func (s *Stmt) QueryMapContext(ctx context.Context, mp interface{}) (*Rows, error) {
vv := reflect.ValueOf(mp)
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
return nil, errors.New("mp should be a map's pointer")
}
args := make([]interface{}, len(s.names))
for k, i := range s.names {
args[i] = vv.Elem().MapIndex(reflect.ValueOf(k)).Interface()
}
return s.QueryContext(ctx, args...)
}
func (s *Stmt) QueryMap(mp interface{}) (*Rows, error) {
return s.QueryMapContext(context.Background(), mp)
}
func (s *Stmt) QueryStructContext(ctx context.Context, st interface{}) (*Rows, error) {
vv := reflect.ValueOf(st)
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
return nil, errors.New("mp should be a map's pointer")
}
args := make([]interface{}, len(s.names))
for k, i := range s.names {
args[i] = vv.Elem().FieldByName(k).Interface()
}
return s.Query(args...)
}
func (s *Stmt) QueryStruct(st interface{}) (*Rows, error) {
return s.QueryStructContext(context.Background(), st)
}
func (s *Stmt) QueryRowContext(ctx context.Context, args ...interface{}) *Row {
rows, err := s.QueryContext(ctx, args...)
return &Row{rows, err}
}
func (s *Stmt) QueryRow(args ...interface{}) *Row {
return s.QueryRowContext(context.Background(), args...)
}
func (s *Stmt) QueryRowMapContext(ctx context.Context, mp interface{}) *Row {
vv := reflect.ValueOf(mp)
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
return &Row{nil, errors.New("mp should be a map's pointer")}
}
args := make([]interface{}, len(s.names))
for k, i := range s.names {
args[i] = vv.Elem().MapIndex(reflect.ValueOf(k)).Interface()
}
return s.QueryRowContext(ctx, args...)
}
func (s *Stmt) QueryRowMap(mp interface{}) *Row {
return s.QueryRowMapContext(context.Background(), mp)
}
func (s *Stmt) QueryRowStructContext(ctx context.Context, st interface{}) *Row {
vv := reflect.ValueOf(st)
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
return &Row{nil, errors.New("st should be a struct's pointer")}
}
args := make([]interface{}, len(s.names))
for k, i := range s.names {
args[i] = vv.Elem().FieldByName(k).Interface()
}
return s.QueryRowContext(ctx, args...)
}
func (s *Stmt) QueryRowStruct(st interface{}) *Row {
return s.QueryRowStructContext(context.Background(), st)
}

153
core/tx.go Normal file
View File

@ -0,0 +1,153 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package core
import (
"context"
"database/sql"
)
type Tx struct {
*sql.Tx
db *DB
}
func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) {
tx, err := db.DB.BeginTx(ctx, opts)
if err != nil {
return nil, err
}
return &Tx{tx, db}, nil
}
func (db *DB) Begin() (*Tx, error) {
tx, err := db.DB.Begin()
if err != nil {
return nil, err
}
return &Tx{tx, db}, nil
}
func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
names := make(map[string]int)
var i int
query = re.ReplaceAllStringFunc(query, func(src string) string {
names[src[1:]] = i
i += 1
return "?"
})
stmt, err := tx.Tx.PrepareContext(ctx, query)
if err != nil {
return nil, err
}
return &Stmt{stmt, tx.db, names}, nil
}
func (tx *Tx) Prepare(query string) (*Stmt, error) {
return tx.PrepareContext(context.Background(), query)
}
func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
stmt.Stmt = tx.Tx.StmtContext(ctx, stmt.Stmt)
return stmt
}
func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
return tx.StmtContext(context.Background(), stmt)
}
func (tx *Tx) ExecMapContext(ctx context.Context, query string, mp interface{}) (sql.Result, error) {
query, args, err := MapToSlice(query, mp)
if err != nil {
return nil, err
}
return tx.Tx.ExecContext(ctx, query, args...)
}
func (tx *Tx) ExecMap(query string, mp interface{}) (sql.Result, error) {
return tx.ExecMapContext(context.Background(), query, mp)
}
func (tx *Tx) ExecStructContext(ctx context.Context, query string, st interface{}) (sql.Result, error) {
query, args, err := StructToSlice(query, st)
if err != nil {
return nil, err
}
return tx.Tx.ExecContext(ctx, query, args...)
}
func (tx *Tx) ExecStruct(query string, st interface{}) (sql.Result, error) {
return tx.ExecStructContext(context.Background(), query, st)
}
func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
rows, err := tx.Tx.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
return &Rows{rows, tx.db}, nil
}
func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) {
return tx.QueryContext(context.Background(), query, args...)
}
func (tx *Tx) QueryMapContext(ctx context.Context, query string, mp interface{}) (*Rows, error) {
query, args, err := MapToSlice(query, mp)
if err != nil {
return nil, err
}
return tx.QueryContext(ctx, query, args...)
}
func (tx *Tx) QueryMap(query string, mp interface{}) (*Rows, error) {
return tx.QueryMapContext(context.Background(), query, mp)
}
func (tx *Tx) QueryStructContext(ctx context.Context, query string, st interface{}) (*Rows, error) {
query, args, err := StructToSlice(query, st)
if err != nil {
return nil, err
}
return tx.QueryContext(ctx, query, args...)
}
func (tx *Tx) QueryStruct(query string, st interface{}) (*Rows, error) {
return tx.QueryStructContext(context.Background(), query, st)
}
func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row {
rows, err := tx.QueryContext(ctx, query, args...)
return &Row{rows, err}
}
func (tx *Tx) QueryRow(query string, args ...interface{}) *Row {
return tx.QueryRowContext(context.Background(), query, args...)
}
func (tx *Tx) QueryRowMapContext(ctx context.Context, query string, mp interface{}) *Row {
query, args, err := MapToSlice(query, mp)
if err != nil {
return &Row{nil, err}
}
return tx.QueryRowContext(ctx, query, args...)
}
func (tx *Tx) QueryRowMap(query string, mp interface{}) *Row {
return tx.QueryRowMapContext(context.Background(), query, mp)
}
func (tx *Tx) QueryRowStructContext(ctx context.Context, query string, st interface{}) *Row {
query, args, err := StructToSlice(query, st)
if err != nil {
return &Row{nil, err}
}
return tx.QueryRowContext(ctx, query, args...)
}
func (tx *Tx) QueryRowStruct(query string, st interface{}) *Row {
return tx.QueryRowStructContext(context.Background(), query, st)
}

410
dialects/dialect.go Normal file
View File

@ -0,0 +1,410 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package dialects
import (
"fmt"
"strings"
"time"
"xorm.io/xorm/core"
"xorm.io/xorm/log"
"xorm.io/xorm/schemas"
)
type DBType string
type URI struct {
DBType DBType
Proto string
Host string
Port string
DBName string
User string
Passwd string
Charset string
Laddr string
Raddr string
Timeout time.Duration
Schema string
}
// a dialect is a driver's wrapper
type Dialect interface {
SetLogger(logger log.Logger)
Init(*core.DB, *URI, string, string) error
URI() *URI
DB() *core.DB
DBType() DBType
SQLType(*schemas.Column) string
FormatBytes(b []byte) string
DriverName() string
DataSourceName() string
IsReserved(string) bool
Quote(string) string
AndStr() string
OrStr() string
EqStr() string
RollBackStr() string
AutoIncrStr() string
SupportInsertMany() bool
SupportEngine() bool
SupportCharset() bool
SupportDropIfExists() bool
IndexOnTable() bool
ShowCreateNull() bool
IndexCheckSQL(tableName, idxName string) (string, []interface{})
TableCheckSQL(tableName string) (string, []interface{})
IsColumnExist(tableName string, colName string) (bool, error)
CreateTableSQL(table *schemas.Table, tableName, storeEngine, charset string) string
DropTableSQL(tableName string) string
CreateIndexSQL(tableName string, index *schemas.Index) string
DropIndexSQL(tableName string, index *schemas.Index) string
ModifyColumnSQL(tableName string, col *schemas.Column) string
ForUpdateSQL(query string) string
// CreateTableIfNotExists(table *Table, tableName, storeEngine, charset string) error
// MustDropTable(tableName string) error
GetColumns(tableName string) ([]string, map[string]*schemas.Column, error)
GetTables() ([]*schemas.Table, error)
GetIndexes(tableName string) (map[string]*schemas.Index, error)
Filters() []Filter
SetParams(params map[string]string)
}
func OpenDialect(dialect Dialect) (*core.DB, error) {
return core.Open(dialect.DriverName(), dialect.DataSourceName())
}
// Base represents a basic dialect and all real dialects could embed this struct
type Base struct {
db *core.DB
dialect Dialect
driverName string
dataSourceName string
logger log.Logger
uri *URI
}
// String generate column description string according dialect
func String(d Dialect, col *schemas.Column) string {
sql := d.Quote(col.Name) + " "
sql += d.SQLType(col) + " "
if col.IsPrimaryKey {
sql += "PRIMARY KEY "
if col.IsAutoIncrement {
sql += d.AutoIncrStr() + " "
}
}
if col.Default != "" {
sql += "DEFAULT " + col.Default + " "
}
if d.ShowCreateNull() {
if col.Nullable {
sql += "NULL "
} else {
sql += "NOT NULL "
}
}
return sql
}
// StringNoPk generate column description string according dialect without primary keys
func StringNoPk(d Dialect, col *schemas.Column) string {
sql := d.Quote(col.Name) + " "
sql += d.SQLType(col) + " "
if col.Default != "" {
sql += "DEFAULT " + col.Default + " "
}
if d.ShowCreateNull() {
if col.Nullable {
sql += "NULL "
} else {
sql += "NOT NULL "
}
}
return sql
}
func (b *Base) DB() *core.DB {
return b.db
}
func (b *Base) SetLogger(logger log.Logger) {
b.logger = logger
}
func (b *Base) Init(db *core.DB, dialect Dialect, uri *URI, drivername, dataSourceName string) error {
b.db, b.dialect, b.uri = db, dialect, uri
b.driverName, b.dataSourceName = drivername, dataSourceName
return nil
}
func (b *Base) URI() *URI {
return b.uri
}
func (b *Base) DBType() DBType {
return b.uri.DBType
}
func (b *Base) FormatBytes(bs []byte) string {
return fmt.Sprintf("0x%x", bs)
}
func (b *Base) DriverName() string {
return b.driverName
}
func (b *Base) ShowCreateNull() bool {
return true
}
func (b *Base) DataSourceName() string {
return b.dataSourceName
}
func (b *Base) AndStr() string {
return "AND"
}
func (b *Base) OrStr() string {
return "OR"
}
func (b *Base) EqStr() string {
return "="
}
func (db *Base) RollBackStr() string {
return "ROLL BACK"
}
func (db *Base) SupportDropIfExists() bool {
return true
}
func (db *Base) DropTableSQL(tableName string) string {
quote := db.dialect.Quote
return fmt.Sprintf("DROP TABLE IF EXISTS %s", quote(tableName))
}
func (db *Base) HasRecords(query string, args ...interface{}) (bool, error) {
db.LogSQL(query, args)
rows, err := db.DB().Query(query, args...)
if err != nil {
return false, err
}
defer rows.Close()
if rows.Next() {
return true, nil
}
return false, nil
}
func (db *Base) IsColumnExist(tableName, colName string) (bool, error) {
query := fmt.Sprintf(
"SELECT %v FROM %v.%v WHERE %v = ? AND %v = ? AND %v = ?",
db.dialect.Quote("COLUMN_NAME"),
db.dialect.Quote("INFORMATION_SCHEMA"),
db.dialect.Quote("COLUMNS"),
db.dialect.Quote("TABLE_SCHEMA"),
db.dialect.Quote("TABLE_NAME"),
db.dialect.Quote("COLUMN_NAME"),
)
return db.HasRecords(query, db.uri.DBName, tableName, colName)
}
/*
func (db *Base) CreateTableIfNotExists(table *Table, tableName, storeEngine, charset string) error {
sql, args := db.dialect.TableCheckSQL(tableName)
rows, err := db.DB().Query(sql, args...)
if db.Logger != nil {
db.Logger.Info("[sql]", sql, args)
}
if err != nil {
return err
}
defer rows.Close()
if rows.Next() {
return nil
}
sql = db.dialect.CreateTableSQL(table, tableName, storeEngine, charset)
_, err = db.DB().Exec(sql)
if db.Logger != nil {
db.Logger.Info("[sql]", sql)
}
return err
}*/
func (db *Base) CreateIndexSQL(tableName string, index *schemas.Index) string {
quote := db.dialect.Quote
var unique string
var idxName string
if index.Type == schemas.UniqueType {
unique = " UNIQUE"
}
idxName = index.XName(tableName)
return fmt.Sprintf("CREATE%s INDEX %v ON %v (%v)", unique,
quote(idxName), quote(tableName),
quote(strings.Join(index.Cols, quote(","))))
}
func (db *Base) DropIndexSQL(tableName string, index *schemas.Index) string {
quote := db.dialect.Quote
var name string
if index.IsRegular {
name = index.XName(tableName)
} else {
name = index.Name
}
return fmt.Sprintf("DROP INDEX %v ON %s", quote(name), quote(tableName))
}
func (db *Base) ModifyColumnSQL(tableName string, col *schemas.Column) string {
return fmt.Sprintf("alter table %s MODIFY COLUMN %s", tableName, StringNoPk(db.dialect, col))
}
func (b *Base) CreateTableSQL(table *schemas.Table, tableName, storeEngine, charset string) string {
var sql string
sql = "CREATE TABLE IF NOT EXISTS "
if tableName == "" {
tableName = table.Name
}
sql += b.dialect.Quote(tableName)
sql += " ("
if len(table.ColumnsSeq()) > 0 {
pkList := table.PrimaryKeys
for _, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName)
if col.IsPrimaryKey && len(pkList) == 1 {
sql += String(b.dialect, col)
} else {
sql += StringNoPk(b.dialect, col)
}
sql = strings.TrimSpace(sql)
if b.DriverName() == schemas.MYSQL && len(col.Comment) > 0 {
sql += " COMMENT '" + col.Comment + "'"
}
sql += ", "
}
if len(pkList) > 1 {
sql += "PRIMARY KEY ( "
sql += b.dialect.Quote(strings.Join(pkList, b.dialect.Quote(",")))
sql += " ), "
}
sql = sql[:len(sql)-2]
}
sql += ")"
if b.dialect.SupportEngine() && storeEngine != "" {
sql += " ENGINE=" + storeEngine
}
if b.dialect.SupportCharset() {
if len(charset) == 0 {
charset = b.dialect.URI().Charset
}
if len(charset) > 0 {
sql += " DEFAULT CHARSET " + charset
}
}
return sql
}
func (b *Base) ForUpdateSQL(query string) string {
return query + " FOR UPDATE"
}
func (b *Base) LogSQL(sql string, args []interface{}) {
if b.logger != nil && b.logger.IsShowSQL() {
if len(args) > 0 {
b.logger.Infof("[SQL] %v %v", sql, args)
} else {
b.logger.Infof("[SQL] %v", sql)
}
}
}
func (b *Base) SetParams(params map[string]string) {
}
var (
dialects = map[string]func() Dialect{}
)
// RegisterDialect register database dialect
func RegisterDialect(dbName DBType, dialectFunc func() Dialect) {
if dialectFunc == nil {
panic("core: Register dialect is nil")
}
dialects[strings.ToLower(string(dbName))] = dialectFunc // !nashtsai! allow override dialect
}
// QueryDialect query if registered database dialect
func QueryDialect(dbName DBType) Dialect {
if d, ok := dialects[strings.ToLower(string(dbName))]; ok {
return d()
}
return nil
}
func regDrvsNDialects() bool {
providedDrvsNDialects := map[string]struct {
dbType DBType
getDriver func() Driver
getDialect func() Dialect
}{
"mssql": {"mssql", func() Driver { return &odbcDriver{} }, func() Dialect { return &mssql{} }},
"odbc": {"mssql", func() Driver { return &odbcDriver{} }, func() Dialect { return &mssql{} }}, // !nashtsai! TODO change this when supporting MS Access
"mysql": {"mysql", func() Driver { return &mysqlDriver{} }, func() Dialect { return &mysql{} }},
"mymysql": {"mysql", func() Driver { return &mymysqlDriver{} }, func() Dialect { return &mysql{} }},
"postgres": {"postgres", func() Driver { return &pqDriver{} }, func() Dialect { return &postgres{} }},
"pgx": {"postgres", func() Driver { return &pqDriverPgx{} }, func() Dialect { return &postgres{} }},
"sqlite3": {"sqlite3", func() Driver { return &sqlite3Driver{} }, func() Dialect { return &sqlite3{} }},
"oci8": {"oracle", func() Driver { return &oci8Driver{} }, func() Dialect { return &oracle{} }},
"goracle": {"oracle", func() Driver { return &goracleDriver{} }, func() Dialect { return &oracle{} }},
}
for driverName, v := range providedDrvsNDialects {
if driver := QueryDriver(driverName); driver == nil {
RegisterDriver(driverName, v.getDriver())
RegisterDialect(v.dbType, v.getDialect)
}
}
return true
}
func init() {
regDrvsNDialects()
}

31
dialects/driver.go Normal file
View File

@ -0,0 +1,31 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package dialects
type Driver interface {
Parse(string, string) (*URI, error)
}
var (
drivers = map[string]Driver{}
)
func RegisterDriver(driverName string, driver Driver) {
if driver == nil {
panic("core: Register driver is nil")
}
if _, dup := drivers[driverName]; dup {
panic("core: Register called twice for driver " + driverName)
}
drivers[driverName] = driver
}
func QueryDriver(driverName string) Driver {
return drivers[driverName]
}
func RegisteredDriverSize() int {
return len(drivers)
}

95
dialects/filter.go Normal file
View File

@ -0,0 +1,95 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package dialects
import (
"fmt"
"strings"
"xorm.io/xorm/schemas"
)
// Filter is an interface to filter SQL
type Filter interface {
Do(sql string, dialect Dialect, table *schemas.Table) string
}
// QuoteFilter filter SQL replace ` to database's own quote character
type QuoteFilter struct {
}
func (s *QuoteFilter) Do(sql string, dialect Dialect, table *schemas.Table) string {
dummy := dialect.Quote("")
if len(dummy) != 2 {
return sql
}
prefix, suffix := dummy[0], dummy[1]
raw := []byte(sql)
for i, cnt := 0, 0; i < len(raw); i = i + 1 {
if raw[i] == '`' {
if cnt%2 == 0 {
raw[i] = prefix
} else {
raw[i] = suffix
}
cnt++
}
}
return string(raw)
}
// IdFilter filter SQL replace (id) to primary key column name
type IdFilter struct {
}
type Quoter struct {
dialect Dialect
}
func NewQuoter(dialect Dialect) *Quoter {
return &Quoter{dialect}
}
func (q *Quoter) Quote(content string) string {
return q.dialect.Quote(content)
}
func (i *IdFilter) Do(sql string, dialect Dialect, table *schemas.Table) string {
quoter := NewQuoter(dialect)
if table != nil && len(table.PrimaryKeys) == 1 {
sql = strings.Replace(sql, " `(id)` ", " "+quoter.Quote(table.PrimaryKeys[0])+" ", -1)
sql = strings.Replace(sql, " "+quoter.Quote("(id)")+" ", " "+quoter.Quote(table.PrimaryKeys[0])+" ", -1)
return strings.Replace(sql, " (id) ", " "+quoter.Quote(table.PrimaryKeys[0])+" ", -1)
}
return sql
}
// SeqFilter filter SQL replace ?, ? ... to $1, $2 ...
type SeqFilter struct {
Prefix string
Start int
}
func convertQuestionMark(sql, prefix string, start int) string {
var buf strings.Builder
var beginSingleQuote bool
var index = start
for _, c := range sql {
if !beginSingleQuote && c == '?' {
buf.WriteString(fmt.Sprintf("%s%v", prefix, index))
index++
} else {
if c == '\'' {
beginSingleQuote = !beginSingleQuote
}
buf.WriteRune(c)
}
}
return buf.String()
}
func (s *SeqFilter) Do(sql string, dialect Dialect, table *schemas.Table) string {
return convertQuestionMark(sql, s.Prefix, s.Start)
}

39
dialects/filter_test.go Normal file
View File

@ -0,0 +1,39 @@
package dialects
import (
"testing"
"github.com/stretchr/testify/assert"
)
type quoterOnly struct {
Dialect
}
func (q *quoterOnly) Quote(item string) string {
return "[" + item + "]"
}
func TestQuoteFilter_Do(t *testing.T) {
f := QuoteFilter{}
sql := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?"
res := f.Do(sql, new(quoterOnly), nil)
assert.EqualValues(t,
"SELECT [COLUMN_NAME] FROM [INFORMATION_SCHEMA].[COLUMNS] WHERE [TABLE_SCHEMA] = ? AND [TABLE_NAME] = ? AND [COLUMN_NAME] = ?",
res,
)
}
func TestSeqFilter(t *testing.T) {
var kases = map[string]string{
"SELECT * FROM TABLE1 WHERE a=? AND b=?": "SELECT * FROM TABLE1 WHERE a=$1 AND b=$2",
"SELECT 1, '???', '2006-01-02 15:04:05' FROM TABLE1 WHERE a=? AND b=?": "SELECT 1, '???', '2006-01-02 15:04:05' FROM TABLE1 WHERE a=$1 AND b=$2",
"select '1''?' from issue": "select '1''?' from issue",
"select '1\\??' from issue": "select '1\\??' from issue",
"select '1\\\\',? from issue": "select '1\\\\',$1 from issue",
"select '1\\''?',? from issue": "select '1\\''?',$1 from issue",
}
for sql, result := range kases {
assert.EqualValues(t, result, convertQuestionMark(sql, "$", 1))
}
}

View File

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
package dialects
import (
"errors"
@ -11,7 +11,8 @@ import (
"strconv"
"strings"
"xorm.io/core"
"xorm.io/xorm/core"
"xorm.io/xorm/schemas"
)
var (
@ -205,64 +206,64 @@ var (
)
type mssql struct {
core.Base
Base
}
func (db *mssql) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error {
func (db *mssql) Init(d *core.DB, uri *URI, drivername, dataSourceName string) error {
return db.Base.Init(d, db, uri, drivername, dataSourceName)
}
func (db *mssql) SqlType(c *core.Column) string {
func (db *mssql) SQLType(c *schemas.Column) string {
var res string
switch t := c.SQLType.Name; t {
case core.Bool:
res = core.Bit
case schemas.Bool:
res = schemas.Bit
if strings.EqualFold(c.Default, "true") {
c.Default = "1"
} else if strings.EqualFold(c.Default, "false") {
c.Default = "0"
}
case core.Serial:
case schemas.Serial:
c.IsAutoIncrement = true
c.IsPrimaryKey = true
c.Nullable = false
res = core.Int
case core.BigSerial:
res = schemas.Int
case schemas.BigSerial:
c.IsAutoIncrement = true
c.IsPrimaryKey = true
c.Nullable = false
res = core.BigInt
case core.Bytea, core.Blob, core.Binary, core.TinyBlob, core.MediumBlob, core.LongBlob:
res = core.VarBinary
res = schemas.BigInt
case schemas.Bytea, schemas.Blob, schemas.Binary, schemas.TinyBlob, schemas.MediumBlob, schemas.LongBlob:
res = schemas.VarBinary
if c.Length == 0 {
c.Length = 50
}
case core.TimeStamp:
res = core.DateTime
case core.TimeStampz:
case schemas.TimeStamp:
res = schemas.DateTime
case schemas.TimeStampz:
res = "DATETIMEOFFSET"
c.Length = 7
case core.MediumInt:
res = core.Int
case core.Text, core.MediumText, core.TinyText, core.LongText, core.Json:
res = core.Varchar + "(MAX)"
case core.Double:
res = core.Real
case core.Uuid:
res = core.Varchar
case schemas.MediumInt:
res = schemas.Int
case schemas.Text, schemas.MediumText, schemas.TinyText, schemas.LongText, schemas.Json:
res = schemas.Varchar + "(MAX)"
case schemas.Double:
res = schemas.Real
case schemas.Uuid:
res = schemas.Varchar
c.Length = 40
case core.TinyInt:
res = core.TinyInt
case schemas.TinyInt:
res = schemas.TinyInt
c.Length = 0
case core.BigInt:
res = core.BigInt
case schemas.BigInt:
res = schemas.BigInt
c.Length = 0
default:
res = t
}
if res == core.Int {
return core.Int
if res == schemas.Int {
return schemas.Int
}
hasLen1 := (c.Length > 0)
@ -297,7 +298,7 @@ func (db *mssql) AutoIncrStr() string {
return "IDENTITY"
}
func (db *mssql) DropTableSql(tableName string) string {
func (db *mssql) DropTableSQL(tableName string) string {
return fmt.Sprintf("IF EXISTS (SELECT * FROM sysobjects WHERE id = "+
"object_id(N'%s') and OBJECTPROPERTY(id, N'IsUserTable') = 1) "+
"DROP TABLE \"%s\"", tableName, tableName)
@ -311,7 +312,7 @@ func (db *mssql) IndexOnTable() bool {
return true
}
func (db *mssql) IndexCheckSql(tableName, idxName string) (string, []interface{}) {
func (db *mssql) IndexCheckSQL(tableName, idxName string) (string, []interface{}) {
args := []interface{}{idxName}
sql := "select name from sysindexes where id=object_id('" + tableName + "') and name=?"
return sql, args
@ -329,13 +330,13 @@ func (db *mssql) IsColumnExist(tableName, colName string) (bool, error) {
return db.HasRecords(query, tableName, colName)
}
func (db *mssql) TableCheckSql(tableName string) (string, []interface{}) {
func (db *mssql) TableCheckSQL(tableName string) (string, []interface{}) {
args := []interface{}{}
sql := "select * from sysobjects where id = object_id(N'" + tableName + "') and OBJECTPROPERTY(id, N'IsUserTable') = 1"
return sql, args
}
func (db *mssql) GetColumns(tableName string) ([]string, map[string]*core.Column, error) {
func (db *mssql) GetColumns(tableName string) ([]string, map[string]*schemas.Column, error) {
args := []interface{}{}
s := `select a.name as name, b.name as ctype,a.max_length,a.precision,a.scale,a.is_nullable as nullable,
"default_is_null" = (CASE WHEN c.text is null THEN 1 ELSE 0 END),
@ -357,7 +358,7 @@ func (db *mssql) GetColumns(tableName string) ([]string, map[string]*core.Column
}
defer rows.Close()
cols := make(map[string]*core.Column)
cols := make(map[string]*schemas.Column)
colSeq := make([]string, 0)
for rows.Next() {
var name, ctype, vdefault string
@ -368,7 +369,7 @@ func (db *mssql) GetColumns(tableName string) ([]string, map[string]*core.Column
return nil, nil, err
}
col := new(core.Column)
col := new(schemas.Column)
col.Indexes = make(map[string]int)
col.Name = strings.Trim(name, "` ")
col.Nullable = nullable
@ -387,14 +388,14 @@ func (db *mssql) GetColumns(tableName string) ([]string, map[string]*core.Column
}
switch ct {
case "DATETIMEOFFSET":
col.SQLType = core.SQLType{Name: core.TimeStampz, DefaultLength: 0, DefaultLength2: 0}
col.SQLType = schemas.SQLType{Name: schemas.TimeStampz, DefaultLength: 0, DefaultLength2: 0}
case "NVARCHAR":
col.SQLType = core.SQLType{Name: core.NVarchar, DefaultLength: 0, DefaultLength2: 0}
col.SQLType = schemas.SQLType{Name: schemas.NVarchar, DefaultLength: 0, DefaultLength2: 0}
case "IMAGE":
col.SQLType = core.SQLType{Name: core.VarBinary, DefaultLength: 0, DefaultLength2: 0}
col.SQLType = schemas.SQLType{Name: schemas.VarBinary, DefaultLength: 0, DefaultLength2: 0}
default:
if _, ok := core.SqlTypes[ct]; ok {
col.SQLType = core.SQLType{Name: ct, DefaultLength: 0, DefaultLength2: 0}
if _, ok := schemas.SqlTypes[ct]; ok {
col.SQLType = schemas.SQLType{Name: ct, DefaultLength: 0, DefaultLength2: 0}
} else {
return nil, nil, fmt.Errorf("Unknown colType %v for %v - %v", ct, tableName, col.Name)
}
@ -406,7 +407,7 @@ func (db *mssql) GetColumns(tableName string) ([]string, map[string]*core.Column
return colSeq, cols, nil
}
func (db *mssql) GetTables() ([]*core.Table, error) {
func (db *mssql) GetTables() ([]*schemas.Table, error) {
args := []interface{}{}
s := `select name from sysobjects where xtype ='U'`
db.LogSQL(s, args)
@ -417,9 +418,9 @@ func (db *mssql) GetTables() ([]*core.Table, error) {
}
defer rows.Close()
tables := make([]*core.Table, 0)
tables := make([]*schemas.Table, 0)
for rows.Next() {
table := core.NewEmptyTable()
table := schemas.NewEmptyTable()
var name string
err = rows.Scan(&name)
if err != nil {
@ -431,7 +432,7 @@ func (db *mssql) GetTables() ([]*core.Table, error) {
return tables, nil
}
func (db *mssql) GetIndexes(tableName string) (map[string]*core.Index, error) {
func (db *mssql) GetIndexes(tableName string) (map[string]*schemas.Index, error) {
args := []interface{}{tableName}
s := `SELECT
IXS.NAME AS [INDEX_NAME],
@ -452,7 +453,7 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =?
}
defer rows.Close()
indexes := make(map[string]*core.Index, 0)
indexes := make(map[string]*schemas.Index, 0)
for rows.Next() {
var indexType int
var indexName, colName, isUnique string
@ -468,9 +469,9 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =?
}
if i {
indexType = core.UniqueType
indexType = schemas.UniqueType
} else {
indexType = core.IndexType
indexType = schemas.IndexType
}
colName = strings.Trim(colName, "` ")
@ -480,10 +481,10 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =?
isRegular = true
}
var index *core.Index
var index *schemas.Index
var ok bool
if index, ok = indexes[indexName]; !ok {
index = new(core.Index)
index = new(schemas.Index)
index.Type = indexType
index.Name = indexName
index.IsRegular = isRegular
@ -494,7 +495,7 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =?
return indexes, nil
}
func (db *mssql) CreateTableSql(table *core.Table, tableName, storeEngine, charset string) string {
func (db *mssql) CreateTableSQL(table *schemas.Table, tableName, storeEngine, charset string) string {
var sql string
if tableName == "" {
tableName = table.Name
@ -509,9 +510,9 @@ func (db *mssql) CreateTableSql(table *core.Table, tableName, storeEngine, chars
for _, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName)
if col.IsPrimaryKey && len(pkList) == 1 {
sql += col.String(db)
sql += String(db, col)
} else {
sql += col.StringNoPk(db)
sql += StringNoPk(db, col)
}
sql = strings.TrimSpace(sql)
sql += ", "
@ -528,18 +529,18 @@ func (db *mssql) CreateTableSql(table *core.Table, tableName, storeEngine, chars
return sql
}
func (db *mssql) ForUpdateSql(query string) string {
func (db *mssql) ForUpdateSQL(query string) string {
return query
}
func (db *mssql) Filters() []core.Filter {
return []core.Filter{&core.IdFilter{}, &core.QuoteFilter{}}
func (db *mssql) Filters() []Filter {
return []Filter{&IdFilter{}, &QuoteFilter{}}
}
type odbcDriver struct {
}
func (p *odbcDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) {
func (p *odbcDriver) Parse(driverName, dataSourceName string) (*URI, error) {
var dbName string
if strings.HasPrefix(dataSourceName, "sqlserver://") {
@ -563,5 +564,5 @@ func (p *odbcDriver) Parse(driverName, dataSourceName string) (*core.Uri, error)
if dbName == "" {
return nil, errors.New("no db name provided")
}
return &core.Uri{DbName: dbName, DbType: core.MSSQL}, nil
return &URI{DBName: dbName, DBType: schemas.MSSQL}, nil
}

View File

@ -2,13 +2,11 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
package dialects
import (
"reflect"
"testing"
"xorm.io/core"
)
func TestParseMSSQL(t *testing.T) {
@ -21,15 +19,15 @@ func TestParseMSSQL(t *testing.T) {
{"server=localhost;user id=sa;password=yourStrong(!)Password;database=db", "db", true},
}
driver := core.QueryDriver("mssql")
driver := QueryDriver("mssql")
for _, test := range tests {
uri, err := driver.Parse("mssql", test.in)
if err != nil && test.valid {
t.Errorf("%q got unexpected error: %s", test.in, err)
} else if err == nil && !reflect.DeepEqual(test.expected, uri.DbName) {
t.Errorf("%q got: %#v want: %#v", test.in, uri.DbName, test.expected)
} else if err == nil && !reflect.DeepEqual(test.expected, uri.DBName) {
t.Errorf("%q got: %#v want: %#v", test.in, uri.DBName, test.expected)
}
}
}

View File

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
package dialects
import (
"crypto/tls"
@ -13,7 +13,8 @@ import (
"strings"
"time"
"xorm.io/core"
"xorm.io/xorm/core"
"xorm.io/xorm/schemas"
)
var (
@ -162,7 +163,7 @@ var (
)
type mysql struct {
core.Base
Base
net string
addr string
params map[string]string
@ -175,7 +176,7 @@ type mysql struct {
rowFormat string
}
func (db *mysql) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error {
func (db *mysql) Init(d *core.DB, uri *URI, drivername, dataSourceName string) error {
return db.Base.Init(d, db, uri, drivername, dataSourceName)
}
@ -199,29 +200,29 @@ func (db *mysql) SetParams(params map[string]string) {
}
}
func (db *mysql) SqlType(c *core.Column) string {
func (db *mysql) SQLType(c *schemas.Column) string {
var res string
switch t := c.SQLType.Name; t {
case core.Bool:
res = core.TinyInt
case schemas.Bool:
res = schemas.TinyInt
c.Length = 1
case core.Serial:
case schemas.Serial:
c.IsAutoIncrement = true
c.IsPrimaryKey = true
c.Nullable = false
res = core.Int
case core.BigSerial:
res = schemas.Int
case schemas.BigSerial:
c.IsAutoIncrement = true
c.IsPrimaryKey = true
c.Nullable = false
res = core.BigInt
case core.Bytea:
res = core.Blob
case core.TimeStampz:
res = core.Char
res = schemas.BigInt
case schemas.Bytea:
res = schemas.Blob
case schemas.TimeStampz:
res = schemas.Char
c.Length = 64
case core.Enum: // mysql enum
res = core.Enum
case schemas.Enum: // mysql enum
res = schemas.Enum
res += "("
opts := ""
for v := range c.EnumOptions {
@ -229,8 +230,8 @@ func (db *mysql) SqlType(c *core.Column) string {
}
res += strings.TrimLeft(opts, ",")
res += ")"
case core.Set: // mysql set
res = core.Set
case schemas.Set: // mysql set
res = schemas.Set
res += "("
opts := ""
for v := range c.SetOptions {
@ -238,13 +239,13 @@ func (db *mysql) SqlType(c *core.Column) string {
}
res += strings.TrimLeft(opts, ",")
res += ")"
case core.NVarchar:
res = core.Varchar
case core.Uuid:
res = core.Varchar
case schemas.NVarchar:
res = schemas.Varchar
case schemas.Uuid:
res = schemas.Varchar
c.Length = 40
case core.Json:
res = core.Text
case schemas.Json:
res = schemas.Text
default:
res = t
}
@ -252,7 +253,7 @@ func (db *mysql) SqlType(c *core.Column) string {
hasLen1 := (c.Length > 0)
hasLen2 := (c.Length2 > 0)
if res == core.BigInt && !hasLen1 && !hasLen2 {
if res == schemas.BigInt && !hasLen1 && !hasLen2 {
c.Length = 20
hasLen1 = true
}
@ -294,8 +295,8 @@ func (db *mysql) IndexOnTable() bool {
return true
}
func (db *mysql) IndexCheckSql(tableName, idxName string) (string, []interface{}) {
args := []interface{}{db.DbName, tableName, idxName}
func (db *mysql) IndexCheckSQL(tableName, idxName string) (string, []interface{}) {
args := []interface{}{db.uri.DBName, tableName, idxName}
sql := "SELECT `INDEX_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS`"
sql += " WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `INDEX_NAME`=?"
return sql, args
@ -307,14 +308,14 @@ func (db *mysql) IndexCheckSql(tableName, idxName string) (string, []interface{}
return sql, args
}*/
func (db *mysql) TableCheckSql(tableName string) (string, []interface{}) {
args := []interface{}{db.DbName, tableName}
func (db *mysql) TableCheckSQL(tableName string) (string, []interface{}) {
args := []interface{}{db.uri.DBName, tableName}
sql := "SELECT `TABLE_NAME` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? and `TABLE_NAME`=?"
return sql, args
}
func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column, error) {
args := []interface{}{db.DbName, tableName}
func (db *mysql) GetColumns(tableName string) ([]string, map[string]*schemas.Column, error) {
args := []interface{}{db.uri.DBName, tableName}
s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," +
" `COLUMN_KEY`, `EXTRA`,`COLUMN_COMMENT` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?"
db.LogSQL(s, args)
@ -325,10 +326,10 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column
}
defer rows.Close()
cols := make(map[string]*core.Column)
cols := make(map[string]*schemas.Column)
colSeq := make([]string, 0)
for rows.Next() {
col := new(core.Column)
col := new(schemas.Column)
col.Indexes = make(map[string]int)
var columnName, isNullable, colType, colKey, extra, comment string
@ -356,7 +357,7 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column
var len1, len2 int
if len(cts) == 2 {
idx := strings.Index(cts[1], ")")
if colType == core.Enum && cts[1][0] == '\'' { // enum
if colType == schemas.Enum && cts[1][0] == '\'' { // enum
options := strings.Split(cts[1][0:idx], ",")
col.EnumOptions = make(map[string]int)
for k, v := range options {
@ -364,7 +365,7 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column
v = strings.Trim(v, "'")
col.EnumOptions[v] = k
}
} else if colType == core.Set && cts[1][0] == '\'' {
} else if colType == schemas.Set && cts[1][0] == '\'' {
options := strings.Split(cts[1][0:idx], ",")
col.SetOptions = make(map[string]int)
for k, v := range options {
@ -394,8 +395,8 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column
}
col.Length = len1
col.Length2 = len2
if _, ok := core.SqlTypes[colType]; ok {
col.SQLType = core.SQLType{Name: colType, DefaultLength: len1, DefaultLength2: len2}
if _, ok := schemas.SqlTypes[colType]; ok {
col.SQLType = schemas.SQLType{Name: colType, DefaultLength: len1, DefaultLength2: len2}
} else {
return nil, nil, fmt.Errorf("Unknown colType %v", colType)
}
@ -424,8 +425,8 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column
return colSeq, cols, nil
}
func (db *mysql) GetTables() ([]*core.Table, error) {
args := []interface{}{db.DbName}
func (db *mysql) GetTables() ([]*schemas.Table, error) {
args := []interface{}{db.uri.DBName}
s := "SELECT `TABLE_NAME`, `ENGINE`, `TABLE_ROWS`, `AUTO_INCREMENT`, `TABLE_COMMENT` from " +
"`INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? AND (`ENGINE`='MyISAM' OR `ENGINE` = 'InnoDB' OR `ENGINE` = 'TokuDB')"
db.LogSQL(s, args)
@ -436,9 +437,9 @@ func (db *mysql) GetTables() ([]*core.Table, error) {
}
defer rows.Close()
tables := make([]*core.Table, 0)
tables := make([]*schemas.Table, 0)
for rows.Next() {
table := core.NewEmptyTable()
table := schemas.NewEmptyTable()
var name, engine, tableRows, comment string
var autoIncr *string
err = rows.Scan(&name, &engine, &tableRows, &autoIncr, &comment)
@ -454,8 +455,8 @@ func (db *mysql) GetTables() ([]*core.Table, error) {
return tables, nil
}
func (db *mysql) GetIndexes(tableName string) (map[string]*core.Index, error) {
args := []interface{}{db.DbName, tableName}
func (db *mysql) GetIndexes(tableName string) (map[string]*schemas.Index, error) {
args := []interface{}{db.uri.DBName, tableName}
s := "SELECT `INDEX_NAME`, `NON_UNIQUE`, `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?"
db.LogSQL(s, args)
@ -465,7 +466,7 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*core.Index, error) {
}
defer rows.Close()
indexes := make(map[string]*core.Index, 0)
indexes := make(map[string]*schemas.Index, 0)
for rows.Next() {
var indexType int
var indexName, colName, nonUnique string
@ -479,9 +480,9 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*core.Index, error) {
}
if "YES" == nonUnique || nonUnique == "1" {
indexType = core.IndexType
indexType = schemas.IndexType
} else {
indexType = core.UniqueType
indexType = schemas.UniqueType
}
colName = strings.Trim(colName, "` ")
@ -491,10 +492,10 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*core.Index, error) {
isRegular = true
}
var index *core.Index
var index *schemas.Index
var ok bool
if index, ok = indexes[indexName]; !ok {
index = new(core.Index)
index = new(schemas.Index)
index.IsRegular = isRegular
index.Type = indexType
index.Name = indexName
@ -505,7 +506,7 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*core.Index, error) {
return indexes, nil
}
func (db *mysql) CreateTableSql(table *core.Table, tableName, storeEngine, charset string) string {
func (db *mysql) CreateTableSQL(table *schemas.Table, tableName, storeEngine, charset string) string {
var sql string
sql = "CREATE TABLE IF NOT EXISTS "
if tableName == "" {
@ -521,9 +522,9 @@ func (db *mysql) CreateTableSql(table *core.Table, tableName, storeEngine, chars
for _, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName)
if col.IsPrimaryKey && len(pkList) == 1 {
sql += col.String(db)
sql += String(db, col)
} else {
sql += col.StringNoPk(db)
sql += StringNoPk(db, col)
}
sql = strings.TrimSpace(sql)
if len(col.Comment) > 0 {
@ -559,15 +560,15 @@ func (db *mysql) CreateTableSql(table *core.Table, tableName, storeEngine, chars
return sql
}
func (db *mysql) Filters() []core.Filter {
return []core.Filter{&core.IdFilter{}}
func (db *mysql) Filters() []Filter {
return []Filter{&IdFilter{}}
}
type mymysqlDriver struct {
}
func (p *mymysqlDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) {
db := &core.Uri{DbType: core.MYSQL}
func (p *mymysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) {
uri := &URI{DBType: schemas.MYSQL}
pd := strings.SplitN(dataSourceName, "*", 2)
if len(pd) == 2 {
@ -576,9 +577,9 @@ func (p *mymysqlDriver) Parse(driverName, dataSourceName string) (*core.Uri, err
if len(p) != 2 {
return nil, errors.New("Wrong protocol part of URI")
}
db.Proto = p[0]
uri.Proto = p[0]
options := strings.Split(p[1], ",")
db.Raddr = options[0]
uri.Raddr = options[0]
for _, o := range options[1:] {
kv := strings.SplitN(o, "=", 2)
var k, v string
@ -589,13 +590,13 @@ func (p *mymysqlDriver) Parse(driverName, dataSourceName string) (*core.Uri, err
}
switch k {
case "laddr":
db.Laddr = v
uri.Laddr = v
case "timeout":
to, err := time.ParseDuration(v)
if err != nil {
return nil, err
}
db.Timeout = to
uri.Timeout = to
default:
return nil, errors.New("Unknown option: " + k)
}
@ -608,17 +609,17 @@ func (p *mymysqlDriver) Parse(driverName, dataSourceName string) (*core.Uri, err
if len(dup) != 3 {
return nil, errors.New("Wrong database part of URI")
}
db.DbName = dup[0]
db.User = dup[1]
db.Passwd = dup[2]
uri.DBName = dup[0]
uri.User = dup[1]
uri.Passwd = dup[2]
return db, nil
return uri, nil
}
type mysqlDriver struct {
}
func (p *mysqlDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) {
func (p *mysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) {
dsnPattern := regexp.MustCompile(
`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
`(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
@ -628,12 +629,12 @@ func (p *mysqlDriver) Parse(driverName, dataSourceName string) (*core.Uri, error
// tlsConfigRegister := make(map[string]*tls.Config)
names := dsnPattern.SubexpNames()
uri := &core.Uri{DbType: core.MYSQL}
uri := &URI{DBType: schemas.MYSQL}
for i, match := range matches {
switch names[i] {
case "dbname":
uri.DbName = match
uri.DBName = match
case "params":
if len(match) > 0 {
kvs := strings.Split(match, "&")

View File

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
package dialects
import (
"errors"
@ -11,7 +11,8 @@ import (
"strconv"
"strings"
"xorm.io/core"
"xorm.io/xorm/core"
"xorm.io/xorm/schemas"
)
var (
@ -499,29 +500,29 @@ var (
)
type oracle struct {
core.Base
Base
}
func (db *oracle) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error {
func (db *oracle) Init(d *core.DB, uri *URI, drivername, dataSourceName string) error {
return db.Base.Init(d, db, uri, drivername, dataSourceName)
}
func (db *oracle) SqlType(c *core.Column) string {
func (db *oracle) SQLType(c *schemas.Column) string {
var res string
switch t := c.SQLType.Name; t {
case core.Bit, core.TinyInt, core.SmallInt, core.MediumInt, core.Int, core.Integer, core.BigInt, core.Bool, core.Serial, core.BigSerial:
case schemas.Bit, schemas.TinyInt, schemas.SmallInt, schemas.MediumInt, schemas.Int, schemas.Integer, schemas.BigInt, schemas.Bool, schemas.Serial, schemas.BigSerial:
res = "NUMBER"
case core.Binary, core.VarBinary, core.Blob, core.TinyBlob, core.MediumBlob, core.LongBlob, core.Bytea:
return core.Blob
case core.Time, core.DateTime, core.TimeStamp:
res = core.TimeStamp
case core.TimeStampz:
case schemas.Binary, schemas.VarBinary, schemas.Blob, schemas.TinyBlob, schemas.MediumBlob, schemas.LongBlob, schemas.Bytea:
return schemas.Blob
case schemas.Time, schemas.DateTime, schemas.TimeStamp:
res = schemas.TimeStamp
case schemas.TimeStampz:
res = "TIMESTAMP WITH TIME ZONE"
case core.Float, core.Double, core.Numeric, core.Decimal:
case schemas.Float, schemas.Double, schemas.Numeric, schemas.Decimal:
res = "NUMBER"
case core.Text, core.MediumText, core.LongText, core.Json:
case schemas.Text, schemas.MediumText, schemas.LongText, schemas.Json:
res = "CLOB"
case core.Char, core.Varchar, core.TinyText:
case schemas.Char, schemas.Varchar, schemas.TinyText:
res = "VARCHAR2"
default:
res = t
@ -571,11 +572,11 @@ func (db *oracle) IndexOnTable() bool {
return false
}
func (db *oracle) DropTableSql(tableName string) string {
func (db *oracle) DropTableSQL(tableName string) string {
return fmt.Sprintf("DROP TABLE `%s`", tableName)
}
func (db *oracle) CreateTableSql(table *core.Table, tableName, storeEngine, charset string) string {
func (db *oracle) CreateTableSQL(table *schemas.Table, tableName, storeEngine, charset string) string {
var sql string
sql = "CREATE TABLE "
if tableName == "" {
@ -591,7 +592,7 @@ func (db *oracle) CreateTableSql(table *core.Table, tableName, storeEngine, char
/*if col.IsPrimaryKey && len(pkList) == 1 {
sql += col.String(b.dialect)
} else {*/
sql += col.StringNoPk(db)
sql += StringNoPk(db, col)
// }
sql = strings.TrimSpace(sql)
sql += ", "
@ -618,19 +619,19 @@ func (db *oracle) CreateTableSql(table *core.Table, tableName, storeEngine, char
return sql
}
func (db *oracle) IndexCheckSql(tableName, idxName string) (string, []interface{}) {
func (db *oracle) IndexCheckSQL(tableName, idxName string) (string, []interface{}) {
args := []interface{}{tableName, idxName}
return `SELECT INDEX_NAME FROM USER_INDEXES ` +
`WHERE TABLE_NAME = :1 AND INDEX_NAME = :2`, args
}
func (db *oracle) TableCheckSql(tableName string) (string, []interface{}) {
func (db *oracle) TableCheckSQL(tableName string) (string, []interface{}) {
args := []interface{}{tableName}
return `SELECT table_name FROM user_tables WHERE table_name = :1`, args
}
func (db *oracle) MustDropTable(tableName string) error {
sql, args := db.TableCheckSql(tableName)
sql, args := db.TableCheckSQL(tableName)
db.LogSQL(sql, args)
rows, err := db.DB().Query(sql, args...)
@ -674,7 +675,7 @@ func (db *oracle) IsColumnExist(tableName, colName string) (bool, error) {
return false, nil
}
func (db *oracle) GetColumns(tableName string) ([]string, map[string]*core.Column, error) {
func (db *oracle) GetColumns(tableName string) ([]string, map[string]*schemas.Column, error) {
args := []interface{}{tableName}
s := "SELECT column_name,data_default,data_type,data_length,data_precision,data_scale," +
"nullable FROM USER_TAB_COLUMNS WHERE table_name = :1"
@ -686,10 +687,10 @@ func (db *oracle) GetColumns(tableName string) ([]string, map[string]*core.Colum
}
defer rows.Close()
cols := make(map[string]*core.Column)
cols := make(map[string]*schemas.Column)
colSeq := make([]string, 0)
for rows.Next() {
col := new(core.Column)
col := new(schemas.Column)
col.Indexes = make(map[string]int)
var colName, colDefault, nullable, dataType, dataPrecision, dataScale *string
@ -731,30 +732,30 @@ func (db *oracle) GetColumns(tableName string) ([]string, map[string]*core.Colum
switch dt {
case "VARCHAR2":
col.SQLType = core.SQLType{Name: core.Varchar, DefaultLength: len1, DefaultLength2: len2}
col.SQLType = schemas.SQLType{Name: schemas.Varchar, DefaultLength: len1, DefaultLength2: len2}
case "NVARCHAR2":
col.SQLType = core.SQLType{Name: core.NVarchar, DefaultLength: len1, DefaultLength2: len2}
col.SQLType = schemas.SQLType{Name: schemas.NVarchar, DefaultLength: len1, DefaultLength2: len2}
case "TIMESTAMP WITH TIME ZONE":
col.SQLType = core.SQLType{Name: core.TimeStampz, DefaultLength: 0, DefaultLength2: 0}
col.SQLType = schemas.SQLType{Name: schemas.TimeStampz, DefaultLength: 0, DefaultLength2: 0}
case "NUMBER":
col.SQLType = core.SQLType{Name: core.Double, DefaultLength: len1, DefaultLength2: len2}
col.SQLType = schemas.SQLType{Name: schemas.Double, DefaultLength: len1, DefaultLength2: len2}
case "LONG", "LONG RAW":
col.SQLType = core.SQLType{Name: core.Text, DefaultLength: 0, DefaultLength2: 0}
col.SQLType = schemas.SQLType{Name: schemas.Text, DefaultLength: 0, DefaultLength2: 0}
case "RAW":
col.SQLType = core.SQLType{Name: core.Binary, DefaultLength: 0, DefaultLength2: 0}
col.SQLType = schemas.SQLType{Name: schemas.Binary, DefaultLength: 0, DefaultLength2: 0}
case "ROWID":
col.SQLType = core.SQLType{Name: core.Varchar, DefaultLength: 18, DefaultLength2: 0}
col.SQLType = schemas.SQLType{Name: schemas.Varchar, DefaultLength: 18, DefaultLength2: 0}
case "AQ$_SUBSCRIBERS":
ignore = true
default:
col.SQLType = core.SQLType{Name: strings.ToUpper(dt), DefaultLength: len1, DefaultLength2: len2}
col.SQLType = schemas.SQLType{Name: strings.ToUpper(dt), DefaultLength: len1, DefaultLength2: len2}
}
if ignore {
continue
}
if _, ok := core.SqlTypes[col.SQLType.Name]; !ok {
if _, ok := schemas.SqlTypes[col.SQLType.Name]; !ok {
return nil, nil, fmt.Errorf("Unknown colType %v %v", *dataType, col.SQLType)
}
@ -772,7 +773,7 @@ func (db *oracle) GetColumns(tableName string) ([]string, map[string]*core.Colum
return colSeq, cols, nil
}
func (db *oracle) GetTables() ([]*core.Table, error) {
func (db *oracle) GetTables() ([]*schemas.Table, error) {
args := []interface{}{}
s := "SELECT table_name FROM user_tables"
db.LogSQL(s, args)
@ -783,9 +784,9 @@ func (db *oracle) GetTables() ([]*core.Table, error) {
}
defer rows.Close()
tables := make([]*core.Table, 0)
tables := make([]*schemas.Table, 0)
for rows.Next() {
table := core.NewEmptyTable()
table := schemas.NewEmptyTable()
err = rows.Scan(&table.Name)
if err != nil {
return nil, err
@ -796,7 +797,7 @@ func (db *oracle) GetTables() ([]*core.Table, error) {
return tables, nil
}
func (db *oracle) GetIndexes(tableName string) (map[string]*core.Index, error) {
func (db *oracle) GetIndexes(tableName string) (map[string]*schemas.Index, error) {
args := []interface{}{tableName}
s := "SELECT t.column_name,i.uniqueness,i.index_name FROM user_ind_columns t,user_indexes i " +
"WHERE t.index_name = i.index_name and t.table_name = i.table_name and t.table_name =:1"
@ -808,7 +809,7 @@ func (db *oracle) GetIndexes(tableName string) (map[string]*core.Index, error) {
}
defer rows.Close()
indexes := make(map[string]*core.Index, 0)
indexes := make(map[string]*schemas.Index, 0)
for rows.Next() {
var indexType int
var indexName, colName, uniqueness string
@ -827,15 +828,15 @@ func (db *oracle) GetIndexes(tableName string) (map[string]*core.Index, error) {
}
if uniqueness == "UNIQUE" {
indexType = core.UniqueType
indexType = schemas.UniqueType
} else {
indexType = core.IndexType
indexType = schemas.IndexType
}
var index *core.Index
var index *schemas.Index
var ok bool
if index, ok = indexes[indexName]; !ok {
index = new(core.Index)
index = new(schemas.Index)
index.Type = indexType
index.Name = indexName
index.IsRegular = isRegular
@ -846,15 +847,15 @@ func (db *oracle) GetIndexes(tableName string) (map[string]*core.Index, error) {
return indexes, nil
}
func (db *oracle) Filters() []core.Filter {
return []core.Filter{&core.QuoteFilter{}, &core.SeqFilter{Prefix: ":", Start: 1}, &core.IdFilter{}}
func (db *oracle) Filters() []Filter {
return []Filter{&QuoteFilter{}, &SeqFilter{Prefix: ":", Start: 1}, &IdFilter{}}
}
type goracleDriver struct {
}
func (cfg *goracleDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) {
db := &core.Uri{DbType: core.ORACLE}
func (cfg *goracleDriver) Parse(driverName, dataSourceName string) (*URI, error) {
db := &URI{DBType: schemas.ORACLE}
dsnPattern := regexp.MustCompile(
`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
`(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
@ -867,10 +868,10 @@ func (cfg *goracleDriver) Parse(driverName, dataSourceName string) (*core.Uri, e
for i, match := range matches {
switch names[i] {
case "dbname":
db.DbName = match
db.DBName = match
}
}
if db.DbName == "" {
if db.DBName == "" {
return nil, errors.New("dbname is empty")
}
return db, nil
@ -881,8 +882,8 @@ type oci8Driver struct {
// dataSourceName=user/password@ipv4:port/dbname
// dataSourceName=user/password@[ipv6]:port/dbname
func (p *oci8Driver) Parse(driverName, dataSourceName string) (*core.Uri, error) {
db := &core.Uri{DbType: core.ORACLE}
func (p *oci8Driver) Parse(driverName, dataSourceName string) (*URI, error) {
db := &URI{DBType: schemas.ORACLE}
dsnPattern := regexp.MustCompile(
`^(?P<user>.*)\/(?P<password>.*)@` + // user:password@
`(?P<net>.*)` + // ip:port
@ -892,10 +893,10 @@ func (p *oci8Driver) Parse(driverName, dataSourceName string) (*core.Uri, error)
for i, match := range matches {
switch names[i] {
case "dbname":
db.DbName = match
db.DBName = match
}
}
if db.DbName == "" {
if db.DBName == "" {
return nil, errors.New("dbname is empty")
}
return db, nil

View File

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
package dialects
import (
"errors"
@ -11,7 +11,8 @@ import (
"strconv"
"strings"
"xorm.io/core"
"xorm.io/xorm/core"
"xorm.io/xorm/schemas"
)
// from http://www.postgresql.org/docs/current/static/sql-keywords-appendix.html
@ -769,67 +770,67 @@ var (
DefaultPostgresSchema = "public"
)
const postgresPublicSchema = "public"
const PostgresPublicSchema = "public"
type postgres struct {
core.Base
Base
}
func (db *postgres) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error {
func (db *postgres) Init(d *core.DB, uri *URI, drivername, dataSourceName string) error {
err := db.Base.Init(d, db, uri, drivername, dataSourceName)
if err != nil {
return err
}
if db.Schema == "" {
db.Schema = DefaultPostgresSchema
if db.uri.Schema == "" {
db.uri.Schema = DefaultPostgresSchema
}
return nil
}
func (db *postgres) SqlType(c *core.Column) string {
func (db *postgres) SQLType(c *schemas.Column) string {
var res string
switch t := c.SQLType.Name; t {
case core.TinyInt:
res = core.SmallInt
case schemas.TinyInt:
res = schemas.SmallInt
return res
case core.Bit:
res = core.Boolean
case schemas.Bit:
res = schemas.Boolean
return res
case core.MediumInt, core.Int, core.Integer:
case schemas.MediumInt, schemas.Int, schemas.Integer:
if c.IsAutoIncrement {
return core.Serial
return schemas.Serial
}
return core.Integer
case core.BigInt:
return schemas.Integer
case schemas.BigInt:
if c.IsAutoIncrement {
return core.BigSerial
return schemas.BigSerial
}
return core.BigInt
case core.Serial, core.BigSerial:
return schemas.BigInt
case schemas.Serial, schemas.BigSerial:
c.IsAutoIncrement = true
c.Nullable = false
res = t
case core.Binary, core.VarBinary:
return core.Bytea
case core.DateTime:
res = core.TimeStamp
case core.TimeStampz:
case schemas.Binary, schemas.VarBinary:
return schemas.Bytea
case schemas.DateTime:
res = schemas.TimeStamp
case schemas.TimeStampz:
return "timestamp with time zone"
case core.Float:
res = core.Real
case core.TinyText, core.MediumText, core.LongText:
res = core.Text
case core.NVarchar:
res = core.Varchar
case core.Uuid:
return core.Uuid
case core.Blob, core.TinyBlob, core.MediumBlob, core.LongBlob:
return core.Bytea
case core.Double:
case schemas.Float:
res = schemas.Real
case schemas.TinyText, schemas.MediumText, schemas.LongText:
res = schemas.Text
case schemas.NVarchar:
res = schemas.Varchar
case schemas.Uuid:
return schemas.Uuid
case schemas.Blob, schemas.TinyBlob, schemas.MediumBlob, schemas.LongBlob:
return schemas.Bytea
case schemas.Double:
return "DOUBLE PRECISION"
default:
if c.IsAutoIncrement {
return core.Serial
return schemas.Serial
}
res = t
}
@ -879,37 +880,37 @@ func (db *postgres) IndexOnTable() bool {
return false
}
func (db *postgres) IndexCheckSql(tableName, idxName string) (string, []interface{}) {
if len(db.Schema) == 0 {
func (db *postgres) IndexCheckSQL(tableName, idxName string) (string, []interface{}) {
if len(db.uri.Schema) == 0 {
args := []interface{}{tableName, idxName}
return `SELECT indexname FROM pg_indexes WHERE tablename = ? AND indexname = ?`, args
}
args := []interface{}{db.Schema, tableName, idxName}
args := []interface{}{db.uri.Schema, tableName, idxName}
return `SELECT indexname FROM pg_indexes ` +
`WHERE schemaname = ? AND tablename = ? AND indexname = ?`, args
}
func (db *postgres) TableCheckSql(tableName string) (string, []interface{}) {
if len(db.Schema) == 0 {
func (db *postgres) TableCheckSQL(tableName string) (string, []interface{}) {
if len(db.uri.Schema) == 0 {
args := []interface{}{tableName}
return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args
}
args := []interface{}{db.Schema, tableName}
args := []interface{}{db.uri.Schema, tableName}
return `SELECT tablename FROM pg_tables WHERE schemaname = ? AND tablename = ?`, args
}
func (db *postgres) ModifyColumnSql(tableName string, col *core.Column) string {
if len(db.Schema) == 0 || strings.Contains(tableName, ".") {
func (db *postgres) ModifyColumnSQL(tableName string, col *schemas.Column) string {
if len(db.uri.Schema) == 0 || strings.Contains(tableName, ".") {
return fmt.Sprintf("alter table %s ALTER COLUMN %s TYPE %s",
tableName, col.Name, db.SqlType(col))
tableName, col.Name, db.SQLType(col))
}
return fmt.Sprintf("alter table %s.%s ALTER COLUMN %s TYPE %s",
db.Schema, tableName, col.Name, db.SqlType(col))
db.uri.Schema, tableName, col.Name, db.SQLType(col))
}
func (db *postgres) DropIndexSql(tableName string, index *core.Index) string {
func (db *postgres) DropIndexSQL(tableName string, index *schemas.Index) string {
quote := db.Quote
idxName := index.Name
@ -918,23 +919,23 @@ func (db *postgres) DropIndexSql(tableName string, index *core.Index) string {
if !strings.HasPrefix(idxName, "UQE_") &&
!strings.HasPrefix(idxName, "IDX_") {
if index.Type == core.UniqueType {
if index.Type == schemas.UniqueType {
idxName = fmt.Sprintf("UQE_%v_%v", tableName, index.Name)
} else {
idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name)
}
}
if db.Uri.Schema != "" {
idxName = db.Uri.Schema + "." + idxName
if db.uri.Schema != "" {
idxName = db.uri.Schema + "." + idxName
}
return fmt.Sprintf("DROP INDEX %v", quote(idxName))
}
func (db *postgres) IsColumnExist(tableName, colName string) (bool, error) {
args := []interface{}{db.Schema, tableName, colName}
args := []interface{}{db.uri.Schema, tableName, colName}
query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = $1 AND table_name = $2" +
" AND column_name = $3"
if len(db.Schema) == 0 {
if len(db.uri.Schema) == 0 {
args = []interface{}{tableName, colName}
query = "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" +
" AND column_name = $2"
@ -950,7 +951,7 @@ func (db *postgres) IsColumnExist(tableName, colName string) (bool, error) {
return rows.Next(), nil
}
func (db *postgres) GetColumns(tableName string) ([]string, map[string]*core.Column, error) {
func (db *postgres) GetColumns(tableName string) ([]string, map[string]*schemas.Column, error) {
args := []interface{}{tableName}
s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length,
CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey,
@ -965,8 +966,8 @@ FROM pg_attribute f
WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.attnum;`
var f string
if len(db.Schema) != 0 {
args = append(args, db.Schema)
if len(db.uri.Schema) != 0 {
args = append(args, db.uri.Schema)
f = " AND s.table_schema = $2"
}
s = fmt.Sprintf(s, f)
@ -979,11 +980,11 @@ WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.att
}
defer rows.Close()
cols := make(map[string]*core.Column)
cols := make(map[string]*schemas.Column)
colSeq := make([]string, 0)
for rows.Next() {
col := new(core.Column)
col := new(schemas.Column)
col.Indexes = make(map[string]int)
var colName, isNullable, dataType string
@ -1023,23 +1024,23 @@ WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.att
switch dataType {
case "character varying", "character":
col.SQLType = core.SQLType{Name: core.Varchar, DefaultLength: 0, DefaultLength2: 0}
col.SQLType = schemas.SQLType{Name: schemas.Varchar, DefaultLength: 0, DefaultLength2: 0}
case "timestamp without time zone":
col.SQLType = core.SQLType{Name: core.DateTime, DefaultLength: 0, DefaultLength2: 0}
col.SQLType = schemas.SQLType{Name: schemas.DateTime, DefaultLength: 0, DefaultLength2: 0}
case "timestamp with time zone":
col.SQLType = core.SQLType{Name: core.TimeStampz, DefaultLength: 0, DefaultLength2: 0}
col.SQLType = schemas.SQLType{Name: schemas.TimeStampz, DefaultLength: 0, DefaultLength2: 0}
case "double precision":
col.SQLType = core.SQLType{Name: core.Double, DefaultLength: 0, DefaultLength2: 0}
col.SQLType = schemas.SQLType{Name: schemas.Double, DefaultLength: 0, DefaultLength2: 0}
case "boolean":
col.SQLType = core.SQLType{Name: core.Bool, DefaultLength: 0, DefaultLength2: 0}
col.SQLType = schemas.SQLType{Name: schemas.Bool, DefaultLength: 0, DefaultLength2: 0}
case "time without time zone":
col.SQLType = core.SQLType{Name: core.Time, DefaultLength: 0, DefaultLength2: 0}
col.SQLType = schemas.SQLType{Name: schemas.Time, DefaultLength: 0, DefaultLength2: 0}
case "oid":
col.SQLType = core.SQLType{Name: core.BigInt, DefaultLength: 0, DefaultLength2: 0}
col.SQLType = schemas.SQLType{Name: schemas.BigInt, DefaultLength: 0, DefaultLength2: 0}
default:
col.SQLType = core.SQLType{Name: strings.ToUpper(dataType), DefaultLength: 0, DefaultLength2: 0}
col.SQLType = schemas.SQLType{Name: strings.ToUpper(dataType), DefaultLength: 0, DefaultLength2: 0}
}
if _, ok := core.SqlTypes[col.SQLType.Name]; !ok {
if _, ok := schemas.SqlTypes[col.SQLType.Name]; !ok {
return nil, nil, fmt.Errorf("Unknown colType: %v", dataType)
}
@ -1065,11 +1066,11 @@ WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.att
return colSeq, cols, nil
}
func (db *postgres) GetTables() ([]*core.Table, error) {
func (db *postgres) GetTables() ([]*schemas.Table, error) {
args := []interface{}{}
s := "SELECT tablename FROM pg_tables"
if len(db.Schema) != 0 {
args = append(args, db.Schema)
if len(db.uri.Schema) != 0 {
args = append(args, db.uri.Schema)
s = s + " WHERE schemaname = $1"
}
@ -1081,9 +1082,9 @@ func (db *postgres) GetTables() ([]*core.Table, error) {
}
defer rows.Close()
tables := make([]*core.Table, 0)
tables := make([]*schemas.Table, 0)
for rows.Next() {
table := core.NewEmptyTable()
table := schemas.NewEmptyTable()
var name string
err = rows.Scan(&name)
if err != nil {
@ -1106,11 +1107,11 @@ func getIndexColName(indexdef string) []string {
return colNames
}
func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) {
func (db *postgres) GetIndexes(tableName string) (map[string]*schemas.Index, error) {
args := []interface{}{tableName}
s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1")
if len(db.Schema) != 0 {
args = append(args, db.Schema)
if len(db.uri.Schema) != 0 {
args = append(args, db.uri.Schema)
s = s + " AND schemaname=$2"
}
db.LogSQL(s, args)
@ -1121,7 +1122,7 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error)
}
defer rows.Close()
indexes := make(map[string]*core.Index, 0)
indexes := make(map[string]*schemas.Index, 0)
for rows.Next() {
var indexType int
var indexName, indexdef string
@ -1135,9 +1136,9 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error)
continue
}
if strings.HasPrefix(indexdef, "CREATE UNIQUE INDEX") {
indexType = core.UniqueType
indexType = schemas.UniqueType
} else {
indexType = core.IndexType
indexType = schemas.IndexType
}
colNames = getIndexColName(indexdef)
var isRegular bool
@ -1149,7 +1150,7 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error)
}
}
index := &core.Index{Name: indexName, Type: indexType, Cols: make([]string, 0)}
index := &schemas.Index{Name: indexName, Type: indexType, Cols: make([]string, 0)}
for _, colName := range colNames {
index.Cols = append(index.Cols, strings.Trim(colName, `" `))
}
@ -1159,8 +1160,8 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error)
return indexes, nil
}
func (db *postgres) Filters() []core.Filter {
return []core.Filter{&core.IdFilter{}, &core.QuoteFilter{}, &core.SeqFilter{Prefix: "$", Start: 1}}
func (db *postgres) Filters() []Filter {
return []Filter{&IdFilter{}, &QuoteFilter{}, &SeqFilter{Prefix: "$", Start: 1}}
}
type pqDriver struct {
@ -1214,12 +1215,12 @@ func parseOpts(name string, o values) error {
return nil
}
func (p *pqDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) {
db := &core.Uri{DbType: core.POSTGRES}
func (p *pqDriver) Parse(driverName, dataSourceName string) (*URI, error) {
db := &URI{DBType: schemas.POSTGRES}
var err error
if strings.HasPrefix(dataSourceName, "postgresql://") || strings.HasPrefix(dataSourceName, "postgres://") {
db.DbName, err = parseURL(dataSourceName)
db.DBName, err = parseURL(dataSourceName)
if err != nil {
return nil, err
}
@ -1230,10 +1231,10 @@ func (p *pqDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) {
return nil, err
}
db.DbName = o.Get("dbname")
db.DBName = o.Get("dbname")
}
if db.DbName == "" {
if db.DBName == "" {
return nil, errors.New("dbname is empty")
}
@ -1244,7 +1245,7 @@ type pqDriverPgx struct {
pqDriver
}
func (pgx *pqDriverPgx) Parse(driverName, dataSourceName string) (*core.Uri, error) {
func (pgx *pqDriverPgx) Parse(driverName, dataSourceName string) (*URI, error) {
// Remove the leading characters for driver to work
if len(dataSourceName) >= 9 && dataSourceName[0] == 0 {
dataSourceName = dataSourceName[9:]

View File

@ -1,11 +1,10 @@
package xorm
package dialects
import (
"reflect"
"testing"
"github.com/stretchr/testify/assert"
"xorm.io/core"
)
func TestParsePostgres(t *testing.T) {
@ -27,15 +26,15 @@ func TestParsePostgres(t *testing.T) {
{"dbname=db =disable", "db", false},
}
driver := core.QueryDriver("postgres")
driver := QueryDriver("postgres")
for _, test := range tests {
uri, err := driver.Parse("postgres", test.in)
if err != nil && test.valid {
t.Errorf("%q got unexpected error: %s", test.in, err)
} else if err == nil && !reflect.DeepEqual(test.expected, uri.DbName) {
t.Errorf("%q got: %#v want: %#v", test.in, uri.DbName, test.expected)
} else if err == nil && !reflect.DeepEqual(test.expected, uri.DBName) {
t.Errorf("%q got: %#v want: %#v", test.in, uri.DBName, test.expected)
}
}
}
@ -59,23 +58,23 @@ func TestParsePgx(t *testing.T) {
{"dbname=db =disable", "db", false},
}
driver := core.QueryDriver("pgx")
driver := QueryDriver("pgx")
for _, test := range tests {
uri, err := driver.Parse("pgx", test.in)
if err != nil && test.valid {
t.Errorf("%q got unexpected error: %s", test.in, err)
} else if err == nil && !reflect.DeepEqual(test.expected, uri.DbName) {
t.Errorf("%q got: %#v want: %#v", test.in, uri.DbName, test.expected)
} else if err == nil && !reflect.DeepEqual(test.expected, uri.DBName) {
t.Errorf("%q got: %#v want: %#v", test.in, uri.DBName, test.expected)
}
// Register DriverConfig
uri, err = driver.Parse("pgx", test.in)
if err != nil && test.valid {
t.Errorf("%q got unexpected error: %s", test.in, err)
} else if err == nil && !reflect.DeepEqual(test.expected, uri.DbName) {
t.Errorf("%q got: %#v want: %#v", test.in, uri.DbName, test.expected)
} else if err == nil && !reflect.DeepEqual(test.expected, uri.DBName) {
t.Errorf("%q got: %#v want: %#v", test.in, uri.DBName, test.expected)
}
}

View File

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
package dialects
import (
"database/sql"
@ -11,7 +11,8 @@ import (
"regexp"
"strings"
"xorm.io/core"
"xorm.io/xorm/core"
"xorm.io/xorm/schemas"
)
var (
@ -144,42 +145,42 @@ var (
)
type sqlite3 struct {
core.Base
Base
}
func (db *sqlite3) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error {
func (db *sqlite3) Init(d *core.DB, uri *URI, drivername, dataSourceName string) error {
return db.Base.Init(d, db, uri, drivername, dataSourceName)
}
func (db *sqlite3) SqlType(c *core.Column) string {
func (db *sqlite3) SQLType(c *schemas.Column) string {
switch t := c.SQLType.Name; t {
case core.Bool:
case schemas.Bool:
if c.Default == "true" {
c.Default = "1"
} else if c.Default == "false" {
c.Default = "0"
}
return core.Integer
case core.Date, core.DateTime, core.TimeStamp, core.Time:
return core.DateTime
case core.TimeStampz:
return core.Text
case core.Char, core.Varchar, core.NVarchar, core.TinyText,
core.Text, core.MediumText, core.LongText, core.Json:
return core.Text
case core.Bit, core.TinyInt, core.SmallInt, core.MediumInt, core.Int, core.Integer, core.BigInt:
return core.Integer
case core.Float, core.Double, core.Real:
return core.Real
case core.Decimal, core.Numeric:
return core.Numeric
case core.TinyBlob, core.Blob, core.MediumBlob, core.LongBlob, core.Bytea, core.Binary, core.VarBinary:
return core.Blob
case core.Serial, core.BigSerial:
return schemas.Integer
case schemas.Date, schemas.DateTime, schemas.TimeStamp, schemas.Time:
return schemas.DateTime
case schemas.TimeStampz:
return schemas.Text
case schemas.Char, schemas.Varchar, schemas.NVarchar, schemas.TinyText,
schemas.Text, schemas.MediumText, schemas.LongText, schemas.Json:
return schemas.Text
case schemas.Bit, schemas.TinyInt, schemas.SmallInt, schemas.MediumInt, schemas.Int, schemas.Integer, schemas.BigInt:
return schemas.Integer
case schemas.Float, schemas.Double, schemas.Real:
return schemas.Real
case schemas.Decimal, schemas.Numeric:
return schemas.Numeric
case schemas.TinyBlob, schemas.Blob, schemas.MediumBlob, schemas.LongBlob, schemas.Bytea, schemas.Binary, schemas.VarBinary:
return schemas.Blob
case schemas.Serial, schemas.BigSerial:
c.IsPrimaryKey = true
c.IsAutoIncrement = true
c.Nullable = false
return core.Integer
return schemas.Integer
default:
return t
}
@ -218,24 +219,24 @@ func (db *sqlite3) IndexOnTable() bool {
return false
}
func (db *sqlite3) IndexCheckSql(tableName, idxName string) (string, []interface{}) {
func (db *sqlite3) IndexCheckSQL(tableName, idxName string) (string, []interface{}) {
args := []interface{}{idxName}
return "SELECT name FROM sqlite_master WHERE type='index' and name = ?", args
}
func (db *sqlite3) TableCheckSql(tableName string) (string, []interface{}) {
func (db *sqlite3) TableCheckSQL(tableName string) (string, []interface{}) {
args := []interface{}{tableName}
return "SELECT name FROM sqlite_master WHERE type='table' and name = ?", args
}
func (db *sqlite3) DropIndexSql(tableName string, index *core.Index) string {
func (db *sqlite3) DropIndexSQL(tableName string, index *schemas.Index) string {
// var unique string
quote := db.Quote
idxName := index.Name
if !strings.HasPrefix(idxName, "UQE_") &&
!strings.HasPrefix(idxName, "IDX_") {
if index.Type == core.UniqueType {
if index.Type == schemas.UniqueType {
idxName = fmt.Sprintf("UQE_%v_%v", tableName, index.Name)
} else {
idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name)
@ -244,7 +245,7 @@ func (db *sqlite3) DropIndexSql(tableName string, index *core.Index) string {
return fmt.Sprintf("DROP INDEX %v", quote(idxName))
}
func (db *sqlite3) ForUpdateSql(query string) string {
func (db *sqlite3) ForUpdateSQL(query string) string {
return query
}
@ -298,9 +299,9 @@ func splitColStr(colStr string) []string {
return results
}
func parseString(colStr string) (*core.Column, error) {
func parseString(colStr string) (*schemas.Column, error) {
fields := splitColStr(colStr)
col := new(core.Column)
col := new(schemas.Column)
col.Indexes = make(map[string]int)
col.Nullable = true
col.DefaultIsEmpty = true
@ -310,7 +311,7 @@ func parseString(colStr string) (*core.Column, error) {
col.Name = strings.Trim(strings.Trim(field, "`[] "), `"`)
continue
} else if idx == 1 {
col.SQLType = core.SQLType{Name: field, DefaultLength: 0, DefaultLength2: 0}
col.SQLType = schemas.SQLType{Name: field, DefaultLength: 0, DefaultLength2: 0}
continue
}
switch field {
@ -332,7 +333,7 @@ func parseString(colStr string) (*core.Column, error) {
return col, nil
}
func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*core.Column, error) {
func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*schemas.Column, error) {
args := []interface{}{tableName}
s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?"
db.LogSQL(s, args)
@ -359,7 +360,7 @@ func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*core.Colu
nEnd := strings.LastIndex(name, ")")
reg := regexp.MustCompile(`[^\(,\)]*(\([^\(]*\))?`)
colCreates := reg.FindAllString(name[nStart+1:nEnd], -1)
cols := make(map[string]*core.Column)
cols := make(map[string]*schemas.Column)
colSeq := make([]string, 0)
for _, colStr := range colCreates {
@ -389,7 +390,7 @@ func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*core.Colu
return colSeq, cols, nil
}
func (db *sqlite3) GetTables() ([]*core.Table, error) {
func (db *sqlite3) GetTables() ([]*schemas.Table, error) {
args := []interface{}{}
s := "SELECT name FROM sqlite_master WHERE type='table'"
db.LogSQL(s, args)
@ -400,9 +401,9 @@ func (db *sqlite3) GetTables() ([]*core.Table, error) {
}
defer rows.Close()
tables := make([]*core.Table, 0)
tables := make([]*schemas.Table, 0)
for rows.Next() {
table := core.NewEmptyTable()
table := schemas.NewEmptyTable()
err = rows.Scan(&table.Name)
if err != nil {
return nil, err
@ -415,7 +416,7 @@ func (db *sqlite3) GetTables() ([]*core.Table, error) {
return tables, nil
}
func (db *sqlite3) GetIndexes(tableName string) (map[string]*core.Index, error) {
func (db *sqlite3) GetIndexes(tableName string) (map[string]*schemas.Index, error) {
args := []interface{}{tableName}
s := "SELECT sql FROM sqlite_master WHERE type='index' and tbl_name = ?"
db.LogSQL(s, args)
@ -426,7 +427,7 @@ func (db *sqlite3) GetIndexes(tableName string) (map[string]*core.Index, error)
}
defer rows.Close()
indexes := make(map[string]*core.Index, 0)
indexes := make(map[string]*schemas.Index, 0)
for rows.Next() {
var tmpSQL sql.NullString
err = rows.Scan(&tmpSQL)
@ -439,7 +440,7 @@ func (db *sqlite3) GetIndexes(tableName string) (map[string]*core.Index, error)
}
sql := tmpSQL.String
index := new(core.Index)
index := new(schemas.Index)
nNStart := strings.Index(sql, "INDEX")
nNEnd := strings.Index(sql, "ON")
if nNStart == -1 || nNEnd == -1 {
@ -456,9 +457,9 @@ func (db *sqlite3) GetIndexes(tableName string) (map[string]*core.Index, error)
}
if strings.HasPrefix(sql, "CREATE UNIQUE INDEX") {
index.Type = core.UniqueType
index.Type = schemas.UniqueType
} else {
index.Type = core.IndexType
index.Type = schemas.IndexType
}
nStart := strings.Index(sql, "(")
@ -476,17 +477,17 @@ func (db *sqlite3) GetIndexes(tableName string) (map[string]*core.Index, error)
return indexes, nil
}
func (db *sqlite3) Filters() []core.Filter {
return []core.Filter{&core.IdFilter{}}
func (db *sqlite3) Filters() []Filter {
return []Filter{&IdFilter{}}
}
type sqlite3Driver struct {
}
func (p *sqlite3Driver) Parse(driverName, dataSourceName string) (*core.Uri, error) {
func (p *sqlite3Driver) Parse(driverName, dataSourceName string) (*URI, error) {
if strings.Contains(dataSourceName, "?") {
dataSourceName = dataSourceName[:strings.Index(dataSourceName, "?")]
}
return &core.Uri{DbType: core.SQLITE, DbName: dataSourceName}, nil
return &URI{DBType: schemas.SQLITE, DBName: dataSourceName}, nil
}

View File

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
package dialects
import (
"testing"

2
doc.go
View File

@ -126,7 +126,7 @@ Attention: the above 8 methods should be the last chainable method.
engine.ID(1).Get(&user) // for single primary key
// SELECT * FROM user WHERE id = 1
engine.ID(core.PK{1, 2}).Get(&user) // for composite primary keys
engine.ID(schemas.PK{1, 2}).Get(&user) // for composite primary keys
// SELECT * FROM user WHERE id1 = 1 AND id2 = 2
engine.In("id", 1, 2, 3).Find(&users)
// SELECT * FROM user WHERE id IN (1, 2, 3)

File diff suppressed because it is too large Load Diff

Binary file not shown.

Before

Width:  |  Height:  |  Size: 238 KiB

175
engine.go
View File

@ -21,27 +21,32 @@ import (
"time"
"xorm.io/builder"
"xorm.io/core"
"xorm.io/xorm/caches"
"xorm.io/xorm/core"
"xorm.io/xorm/dialects"
"xorm.io/xorm/log"
"xorm.io/xorm/names"
"xorm.io/xorm/schemas"
)
// Engine is the major struct of xorm, it means a database manager.
// Commonly, an application only need one engine
type Engine struct {
db *core.DB
dialect core.Dialect
dialect dialects.Dialect
ColumnMapper core.IMapper
TableMapper core.IMapper
ColumnMapper names.Mapper
TableMapper names.Mapper
TagIdentifier string
Tables map[reflect.Type]*core.Table
Tables map[reflect.Type]*schemas.Table
mutex *sync.RWMutex
Cacher core.Cacher
Cacher caches.Cacher
showSQL bool
showExecTime bool
logger core.ILogger
logger log.Logger
TZLocation *time.Location // The timezone of the application
DatabaseTZ *time.Location // The timezone of the database
@ -51,24 +56,24 @@ type Engine struct {
engineGroup *EngineGroup
cachers map[string]core.Cacher
cachers map[string]caches.Cacher
cacherLock sync.RWMutex
defaultContext context.Context
}
func (engine *Engine) setCacher(tableName string, cacher core.Cacher) {
func (engine *Engine) setCacher(tableName string, cacher caches.Cacher) {
engine.cacherLock.Lock()
engine.cachers[tableName] = cacher
engine.cacherLock.Unlock()
}
func (engine *Engine) SetCacher(tableName string, cacher core.Cacher) {
func (engine *Engine) SetCacher(tableName string, cacher caches.Cacher) {
engine.setCacher(tableName, cacher)
}
func (engine *Engine) getCacher(tableName string) core.Cacher {
var cacher core.Cacher
func (engine *Engine) getCacher(tableName string) caches.Cacher {
var cacher caches.Cacher
var ok bool
engine.cacherLock.RLock()
cacher, ok = engine.cachers[tableName]
@ -79,7 +84,7 @@ func (engine *Engine) getCacher(tableName string) core.Cacher {
return cacher
}
func (engine *Engine) GetCacher(tableName string) core.Cacher {
func (engine *Engine) GetCacher(tableName string) caches.Cacher {
return engine.getCacher(tableName)
}
@ -91,13 +96,13 @@ func (engine *Engine) BufferSize(size int) *Session {
}
// CondDeleted returns the conditions whether a record is soft deleted.
func (engine *Engine) CondDeleted(col *core.Column) builder.Cond {
func (engine *Engine) CondDeleted(col *schemas.Column) builder.Cond {
var cond = builder.NewCond()
if col.SQLType.IsNumeric() {
cond = builder.Eq{col.Name: 0}
} else {
// FIXME: mssql: The conversion of a nvarchar data type to a datetime data type resulted in an out-of-range value.
if engine.dialect.DBType() != core.MSSQL {
if engine.dialect.DBType() != schemas.MSSQL {
cond = builder.Eq{col.Name: zeroTime1}
}
}
@ -129,19 +134,19 @@ func (engine *Engine) ShowExecTime(show ...bool) {
}
// Logger return the logger interface
func (engine *Engine) Logger() core.ILogger {
func (engine *Engine) Logger() log.Logger {
return engine.logger
}
// SetLogger set the new logger
func (engine *Engine) SetLogger(logger core.ILogger) {
func (engine *Engine) SetLogger(logger log.Logger) {
engine.logger = logger
engine.showSQL = logger.IsShowSQL()
engine.dialect.SetLogger(logger)
}
// SetLogLevel sets the logger level
func (engine *Engine) SetLogLevel(level core.LogLevel) {
func (engine *Engine) SetLogLevel(level log.LogLevel) {
engine.logger.SetLevel(level)
}
@ -163,18 +168,18 @@ func (engine *Engine) DataSourceName() string {
}
// SetMapper set the name mapping rules
func (engine *Engine) SetMapper(mapper core.IMapper) {
func (engine *Engine) SetMapper(mapper names.Mapper) {
engine.SetTableMapper(mapper)
engine.SetColumnMapper(mapper)
}
// SetTableMapper set the table name mapping rule
func (engine *Engine) SetTableMapper(mapper core.IMapper) {
func (engine *Engine) SetTableMapper(mapper names.Mapper) {
engine.TableMapper = mapper
}
// SetColumnMapper set the column name mapping rule
func (engine *Engine) SetColumnMapper(mapper core.IMapper) {
func (engine *Engine) SetColumnMapper(mapper names.Mapper) {
engine.ColumnMapper = mapper
}
@ -268,13 +273,13 @@ func (engine *Engine) quote(sql string) string {
// SqlType will be deprecated, please use SQLType instead
//
// Deprecated: use SQLType instead
func (engine *Engine) SqlType(c *core.Column) string {
func (engine *Engine) SqlType(c *schemas.Column) string {
return engine.SQLType(c)
}
// SQLType A simple wrapper to dialect's core.SqlType method
func (engine *Engine) SQLType(c *core.Column) string {
return engine.dialect.SqlType(c)
func (engine *Engine) SQLType(c *schemas.Column) string {
return engine.dialect.SQLType(c)
}
// AutoIncrStr Database's autoincrement statement
@ -298,12 +303,12 @@ func (engine *Engine) SetMaxIdleConns(conns int) {
}
// SetDefaultCacher set the default cacher. Xorm's default not enable cacher.
func (engine *Engine) SetDefaultCacher(cacher core.Cacher) {
func (engine *Engine) SetDefaultCacher(cacher caches.Cacher) {
engine.Cacher = cacher
}
// GetDefaultCacher returns the default cacher
func (engine *Engine) GetDefaultCacher() core.Cacher {
func (engine *Engine) GetDefaultCacher() caches.Cacher {
return engine.Cacher
}
@ -323,14 +328,14 @@ func (engine *Engine) NoCascade() *Session {
}
// MapCacher Set a table use a special cacher
func (engine *Engine) MapCacher(bean interface{}, cacher core.Cacher) error {
func (engine *Engine) MapCacher(bean interface{}, cacher caches.Cacher) error {
engine.setCacher(engine.TableName(bean, true), cacher)
return nil
}
// NewDB provides an interface to operate database directly
func (engine *Engine) NewDB() (*core.DB, error) {
return core.OpenDialect(engine.dialect)
return dialects.OpenDialect(engine.dialect)
}
// DB return the wrapper of sql.DB
@ -339,7 +344,7 @@ func (engine *Engine) DB() *core.DB {
}
// Dialect return database dialect
func (engine *Engine) Dialect() core.Dialect {
func (engine *Engine) Dialect() dialects.Dialect {
return engine.dialect
}
@ -409,7 +414,7 @@ func (engine *Engine) NoAutoCondition(no ...bool) *Session {
return session.NoAutoCondition(no...)
}
func (engine *Engine) loadTableInfo(table *core.Table) error {
func (engine *Engine) loadTableInfo(table *schemas.Table) error {
colSeq, cols, err := engine.dialect.GetColumns(table.Name)
if err != nil {
return err
@ -436,7 +441,7 @@ func (engine *Engine) loadTableInfo(table *core.Table) error {
}
// DBMetas Retrieve all tables, columns, indexes' informations from database.
func (engine *Engine) DBMetas() ([]*core.Table, error) {
func (engine *Engine) DBMetas() ([]*schemas.Table, error) {
tables, err := engine.dialect.GetTables()
if err != nil {
return nil, err
@ -451,7 +456,7 @@ func (engine *Engine) DBMetas() ([]*core.Table, error) {
}
// DumpAllToFile dump database all table structs and data to a file
func (engine *Engine) DumpAllToFile(fp string, tp ...core.DbType) error {
func (engine *Engine) DumpAllToFile(fp string, tp ...dialects.DBType) error {
f, err := os.Create(fp)
if err != nil {
return err
@ -461,7 +466,7 @@ func (engine *Engine) DumpAllToFile(fp string, tp ...core.DbType) error {
}
// DumpAll dump database all table structs and data to w
func (engine *Engine) DumpAll(w io.Writer, tp ...core.DbType) error {
func (engine *Engine) DumpAll(w io.Writer, tp ...dialects.DBType) error {
tables, err := engine.DBMetas()
if err != nil {
return err
@ -470,7 +475,7 @@ func (engine *Engine) DumpAll(w io.Writer, tp ...core.DbType) error {
}
// DumpTablesToFile dump specified tables to SQL file.
func (engine *Engine) DumpTablesToFile(tables []*core.Table, fp string, tp ...core.DbType) error {
func (engine *Engine) DumpTablesToFile(tables []*schemas.Table, fp string, tp ...dialects.DBType) error {
f, err := os.Create(fp)
if err != nil {
return err
@ -480,19 +485,19 @@ func (engine *Engine) DumpTablesToFile(tables []*core.Table, fp string, tp ...co
}
// DumpTables dump specify tables to io.Writer
func (engine *Engine) DumpTables(tables []*core.Table, w io.Writer, tp ...core.DbType) error {
func (engine *Engine) DumpTables(tables []*schemas.Table, w io.Writer, tp ...dialects.DBType) error {
return engine.dumpTables(tables, w, tp...)
}
// dumpTables dump database all table structs and data to w with specify db type
func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.DbType) error {
var dialect core.Dialect
func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...dialects.DBType) error {
var dialect dialects.Dialect
var distDBName string
if len(tp) == 0 {
dialect = engine.dialect
distDBName = string(engine.dialect.DBType())
} else {
dialect = core.QueryDialect(tp[0])
dialect = dialects.QueryDialect(tp[0])
if dialect == nil {
return errors.New("Unsupported database type")
}
@ -513,12 +518,12 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D
return err
}
}
_, err = io.WriteString(w, dialect.CreateTableSql(table, "", table.StoreEngine, "")+";\n")
_, err = io.WriteString(w, dialect.CreateTableSQL(table, "", table.StoreEngine, "")+";\n")
if err != nil {
return err
}
for _, index := range table.Indexes {
_, err = io.WriteString(w, dialect.CreateIndexSql(table.Name, index)+";\n")
_, err = io.WriteString(w, dialect.CreateIndexSQL(table.Name, index)+";\n")
if err != nil {
return err
}
@ -571,19 +576,19 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D
} else if col.SQLType.IsNumeric() {
switch reflect.TypeOf(d).Kind() {
case reflect.Slice:
if col.SQLType.Name == core.Bool {
if col.SQLType.Name == schemas.Bool {
temp += fmt.Sprintf(", %v", strconv.FormatBool(d.([]byte)[0] != byte('0')))
} else {
temp += fmt.Sprintf(", %s", string(d.([]byte)))
}
case reflect.Int16, reflect.Int8, reflect.Int32, reflect.Int64, reflect.Int:
if col.SQLType.Name == core.Bool {
if col.SQLType.Name == schemas.Bool {
temp += fmt.Sprintf(", %v", strconv.FormatBool(reflect.ValueOf(d).Int() > 0))
} else {
temp += fmt.Sprintf(", %v", d)
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
if col.SQLType.Name == core.Bool {
if col.SQLType.Name == schemas.Bool {
temp += fmt.Sprintf(", %v", strconv.FormatBool(reflect.ValueOf(d).Uint() > 0))
} else {
temp += fmt.Sprintf(", %v", d)
@ -611,7 +616,7 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D
}
// FIXME: Hack for postgres
if string(dialect.DBType()) == core.POSTGRES && table.AutoIncrColumn() != nil {
if string(dialect.DBType()) == schemas.POSTGRES && table.AutoIncrColumn() != nil {
_, err = io.WriteString(w, "SELECT setval('"+table.Name+"_id_seq', COALESCE((SELECT MAX("+table.AutoIncrColumn().Name+") + 1 FROM "+dialect.Quote(table.Name)+"), 1), false);\n")
if err != nil {
return err
@ -856,7 +861,7 @@ func (engine *Engine) UnMapType(t reflect.Type) {
delete(engine.Tables, t)
}
func (engine *Engine) autoMapType(v reflect.Value) (*core.Table, error) {
func (engine *Engine) autoMapType(v reflect.Value) (*schemas.Table, error) {
t := v.Type()
engine.mutex.Lock()
defer engine.mutex.Unlock()
@ -888,7 +893,7 @@ func (engine *Engine) GobRegister(v interface{}) *Engine {
// Table table struct
type Table struct {
*core.Table
*schemas.Table
Name string
}
@ -907,12 +912,12 @@ func (engine *Engine) TableInfo(bean interface{}) *Table {
return &Table{tb, engine.TableName(bean)}
}
func addIndex(indexName string, table *core.Table, col *core.Column, indexType int) {
func addIndex(indexName string, table *schemas.Table, col *schemas.Column, indexType int) {
if index, ok := table.Indexes[indexName]; ok {
index.AddColumn(col.Name)
col.Indexes[index.Name] = indexType
} else {
index := core.NewIndex(indexName, indexType)
index := schemas.NewIndex(indexName, indexType)
index.AddColumn(col.Name)
table.AddIndex(index)
col.Indexes[index.Name] = indexType
@ -928,11 +933,11 @@ var (
tpTableName = reflect.TypeOf((*TableName)(nil)).Elem()
)
func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) {
func (engine *Engine) mapType(v reflect.Value) (*schemas.Table, error) {
t := v.Type()
table := core.NewEmptyTable()
table := schemas.NewEmptyTable()
table.Type = t
table.Name = getTableName(engine.TableMapper, v)
table.Name = names.GetTableName(engine.TableMapper, v)
var idFieldColName string
var hasCacheTag, hasNoCacheTag bool
@ -941,17 +946,17 @@ func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) {
tag := t.Field(i).Tag
ormTagStr := tag.Get(engine.TagIdentifier)
var col *core.Column
var col *schemas.Column
fieldValue := v.Field(i)
fieldType := fieldValue.Type()
if ormTagStr != "" {
col = &core.Column{
col = &schemas.Column{
FieldName: t.Field(i).Name,
Nullable: true,
IsPrimaryKey: false,
IsAutoIncrement: false,
MapType: core.TWOSIDES,
MapType: schemas.TWOSIDES,
Indexes: make(map[string]int),
DefaultIsEmpty: true,
}
@ -1039,9 +1044,9 @@ func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) {
}
if col.SQLType.Name == "" {
col.SQLType = core.Type2SQLType(fieldType)
col.SQLType = schemas.Type2SQLType(fieldType)
}
engine.dialect.SqlType(col)
engine.dialect.SQLType(col)
if col.Length == 0 {
col.Length = col.SQLType.DefaultLength
}
@ -1053,9 +1058,9 @@ func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) {
}
if ctx.isUnique {
ctx.indexNames[col.Name] = core.UniqueType
ctx.indexNames[col.Name] = schemas.UniqueType
} else if ctx.isIndex {
ctx.indexNames[col.Name] = core.IndexType
ctx.indexNames[col.Name] = schemas.IndexType
}
for indexName, indexType := range ctx.indexNames {
@ -1063,18 +1068,18 @@ func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) {
}
}
} else {
var sqlType core.SQLType
var sqlType schemas.SQLType
if fieldValue.CanAddr() {
if _, ok := fieldValue.Addr().Interface().(core.Conversion); ok {
sqlType = core.SQLType{Name: core.Text}
if _, ok := fieldValue.Addr().Interface().(Conversion); ok {
sqlType = schemas.SQLType{Name: schemas.Text}
}
}
if _, ok := fieldValue.Interface().(core.Conversion); ok {
sqlType = core.SQLType{Name: core.Text}
if _, ok := fieldValue.Interface().(Conversion); ok {
sqlType = schemas.SQLType{Name: schemas.Text}
} else {
sqlType = core.Type2SQLType(fieldType)
sqlType = schemas.Type2SQLType(fieldType)
}
col = core.NewColumn(engine.ColumnMapper.Obj2Table(t.Field(i).Name),
col = schemas.NewColumn(engine.ColumnMapper.Obj2Table(t.Field(i).Name),
t.Field(i).Name, sqlType, sqlType.DefaultLength,
sqlType.DefaultLength2, true)
@ -1105,7 +1110,7 @@ func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) {
engine.setCacher(table.Name, engine.Cacher)
} else {
engine.logger.Info("enable LRU cache on table:", table.Name)
engine.setCacher(table.Name, NewLRUCacher2(NewMemoryStore(), time.Hour, 10000))
engine.setCacher(table.Name, caches.NewLRUCacher2(caches.NewMemoryStore(), time.Hour, 10000))
}
}
if hasNoCacheTag {
@ -1133,24 +1138,24 @@ func (engine *Engine) IsTableExist(beanOrTableName interface{}) (bool, error) {
// IdOf get id from one struct
//
// Deprecated: use IDOf instead.
func (engine *Engine) IdOf(bean interface{}) core.PK {
func (engine *Engine) IdOf(bean interface{}) schemas.PK {
return engine.IDOf(bean)
}
// IDOf get id from one struct
func (engine *Engine) IDOf(bean interface{}) core.PK {
func (engine *Engine) IDOf(bean interface{}) schemas.PK {
return engine.IdOfV(reflect.ValueOf(bean))
}
// IdOfV get id from one value of struct
//
// Deprecated: use IDOfV instead.
func (engine *Engine) IdOfV(rv reflect.Value) core.PK {
func (engine *Engine) IdOfV(rv reflect.Value) schemas.PK {
return engine.IDOfV(rv)
}
// IDOfV get id from one value of struct
func (engine *Engine) IDOfV(rv reflect.Value) core.PK {
func (engine *Engine) IDOfV(rv reflect.Value) schemas.PK {
pk, err := engine.idOfV(rv)
if err != nil {
engine.logger.Error(err)
@ -1159,7 +1164,7 @@ func (engine *Engine) IDOfV(rv reflect.Value) core.PK {
return pk
}
func (engine *Engine) idOfV(rv reflect.Value) (core.PK, error) {
func (engine *Engine) idOfV(rv reflect.Value) (schemas.PK, error) {
v := reflect.Indirect(rv)
table, err := engine.autoMapType(v)
if err != nil {
@ -1202,10 +1207,10 @@ func (engine *Engine) idOfV(rv reflect.Value) (core.PK, error) {
return nil, err
}
}
return core.PK(pk), nil
return schemas.PK(pk), nil
}
func (engine *Engine) idTypeAssertion(col *core.Column, sid string) (interface{}, error) {
func (engine *Engine) idTypeAssertion(col *schemas.Column, sid string) (interface{}, error) {
if col.SQLType.IsNumeric() {
n, err := strconv.ParseInt(sid, 10, 64)
if err != nil {
@ -1317,7 +1322,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
if err := session.statement.setRefBean(bean); err != nil {
return err
}
if index.Type == core.UniqueType {
if index.Type == schemas.UniqueType {
isExist, err := session.isIndexExist2(tableNameNoSchema, index.Cols, true)
if err != nil {
return err
@ -1332,7 +1337,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
return err
}
}
} else if index.Type == core.IndexType {
} else if index.Type == schemas.IndexType {
isExist, err := session.isIndexExist2(tableNameNoSchema, index.Cols, false)
if err != nil {
return err
@ -1601,7 +1606,7 @@ func (engine *Engine) Import(r io.Reader) ([]sql.Result, error) {
}
// nowTime return current time
func (engine *Engine) nowTime(col *core.Column) (interface{}, time.Time) {
func (engine *Engine) nowTime(col *schemas.Column) (interface{}, time.Time) {
t := time.Now()
var tz = engine.DatabaseTZ
if !col.DisableTimeZone && col.TimeZone != nil {
@ -1610,7 +1615,7 @@ func (engine *Engine) nowTime(col *core.Column) (interface{}, time.Time) {
return engine.formatTime(col.SQLType.Name, t.In(tz)), t.In(engine.TZLocation)
}
func (engine *Engine) formatColTime(col *core.Column, t time.Time) (v interface{}) {
func (engine *Engine) formatColTime(col *schemas.Column, t time.Time) (v interface{}) {
if t.IsZero() {
if col.Nullable {
return nil
@ -1627,20 +1632,20 @@ func (engine *Engine) formatColTime(col *core.Column, t time.Time) (v interface{
// formatTime format time as column type
func (engine *Engine) formatTime(sqlTypeName string, t time.Time) (v interface{}) {
switch sqlTypeName {
case core.Time:
case schemas.Time:
s := t.Format("2006-01-02 15:04:05") // time.RFC3339
v = s[11:19]
case core.Date:
case schemas.Date:
v = t.Format("2006-01-02")
case core.DateTime, core.TimeStamp, core.Varchar: // !DarthPestilane! format time when sqlTypeName is core.Varchar.
case schemas.DateTime, schemas.TimeStamp, schemas.Varchar: // !DarthPestilane! format time when sqlTypeName is schemas.Varchar.
v = t.Format("2006-01-02 15:04:05")
case core.TimeStampz:
if engine.dialect.DBType() == core.MSSQL {
case schemas.TimeStampz:
if engine.dialect.DBType() == schemas.MSSQL {
v = t.Format("2006-01-02T15:04:05.9999999Z07:00")
} else {
v = t.Format(time.RFC3339Nano)
}
case core.BigInt, core.Int:
case schemas.BigInt, schemas.Int:
v = t.Unix()
default:
v = t
@ -1649,12 +1654,12 @@ func (engine *Engine) formatTime(sqlTypeName string, t time.Time) (v interface{}
}
// GetColumnMapper returns the column name mapper
func (engine *Engine) GetColumnMapper() core.IMapper {
func (engine *Engine) GetColumnMapper() names.Mapper {
return engine.ColumnMapper
}
// GetTableMapper returns the table name mapper
func (engine *Engine) GetTableMapper() core.IMapper {
func (engine *Engine) GetTableMapper() names.Mapper {
return engine.TableMapper
}

View File

@ -12,10 +12,10 @@ import (
"time"
"xorm.io/builder"
"xorm.io/core"
"xorm.io/xorm/schemas"
)
func (engine *Engine) buildConds(table *core.Table, bean interface{},
func (engine *Engine) buildConds(table *schemas.Table, bean interface{},
includeVersion bool, includeUpdated bool, includeNil bool,
includeAutoIncr bool, allUseBool bool, useAllCols bool, unscoped bool,
mustColumnMap map[string]bool, tableName, aliasName string, addedTableName bool) (builder.Cond, error) {
@ -31,7 +31,7 @@ func (engine *Engine) buildConds(table *core.Table, bean interface{},
continue
}
if engine.dialect.DBType() == core.MSSQL && (col.SQLType.Name == core.Text || col.SQLType.IsBlob() || col.SQLType.Name == core.TimeStampz) {
if engine.dialect.DBType() == schemas.MSSQL && (col.SQLType.Name == schemas.Text || col.SQLType.IsBlob() || col.SQLType.Name == schemas.TimeStampz) {
continue
}
if col.SQLType.IsJson() {
@ -130,13 +130,13 @@ func (engine *Engine) buildConds(table *core.Table, bean interface{},
t := int64(fieldValue.Uint())
val = reflect.ValueOf(&t).Interface()
case reflect.Struct:
if fieldType.ConvertibleTo(core.TimeType) {
t := fieldValue.Convert(core.TimeType).Interface().(time.Time)
if fieldType.ConvertibleTo(schemas.TimeType) {
t := fieldValue.Convert(schemas.TimeType).Interface().(time.Time)
if !requiredField && (t.IsZero() || !fieldValue.IsValid()) {
continue
}
val = engine.formatColTime(col, t)
} else if _, ok := reflect.New(fieldType).Interface().(core.Conversion); ok {
} else if _, ok := reflect.New(fieldType).Interface().(Conversion); ok {
continue
} else if valNul, ok := fieldValue.Interface().(driver.Valuer); ok {
val, _ = valNul.Value()

View File

@ -8,7 +8,9 @@ import (
"context"
"time"
"xorm.io/core"
"xorm.io/xorm/caches"
"xorm.io/xorm/log"
"xorm.io/xorm/names"
)
// EngineGroup defines an engine group
@ -109,7 +111,7 @@ func (eg *EngineGroup) Ping() error {
}
// SetColumnMapper set the column name mapping rule
func (eg *EngineGroup) SetColumnMapper(mapper core.IMapper) {
func (eg *EngineGroup) SetColumnMapper(mapper names.Mapper) {
eg.Engine.ColumnMapper = mapper
for i := 0; i < len(eg.slaves); i++ {
eg.slaves[i].ColumnMapper = mapper
@ -125,7 +127,7 @@ func (eg *EngineGroup) SetConnMaxLifetime(d time.Duration) {
}
// SetDefaultCacher set the default cacher
func (eg *EngineGroup) SetDefaultCacher(cacher core.Cacher) {
func (eg *EngineGroup) SetDefaultCacher(cacher caches.Cacher) {
eg.Engine.SetDefaultCacher(cacher)
for i := 0; i < len(eg.slaves); i++ {
eg.slaves[i].SetDefaultCacher(cacher)
@ -133,7 +135,7 @@ func (eg *EngineGroup) SetDefaultCacher(cacher core.Cacher) {
}
// SetLogger set the new logger
func (eg *EngineGroup) SetLogger(logger core.ILogger) {
func (eg *EngineGroup) SetLogger(logger log.Logger) {
eg.Engine.SetLogger(logger)
for i := 0; i < len(eg.slaves); i++ {
eg.slaves[i].SetLogger(logger)
@ -141,7 +143,7 @@ func (eg *EngineGroup) SetLogger(logger core.ILogger) {
}
// SetLogLevel sets the logger level
func (eg *EngineGroup) SetLogLevel(level core.LogLevel) {
func (eg *EngineGroup) SetLogLevel(level log.LogLevel) {
eg.Engine.SetLogLevel(level)
for i := 0; i < len(eg.slaves); i++ {
eg.slaves[i].SetLogLevel(level)
@ -149,7 +151,7 @@ func (eg *EngineGroup) SetLogLevel(level core.LogLevel) {
}
// SetMapper set the name mapping rules
func (eg *EngineGroup) SetMapper(mapper core.IMapper) {
func (eg *EngineGroup) SetMapper(mapper names.Mapper) {
eg.Engine.SetMapper(mapper)
for i := 0; i < len(eg.slaves); i++ {
eg.slaves[i].SetMapper(mapper)
@ -179,7 +181,7 @@ func (eg *EngineGroup) SetPolicy(policy GroupPolicy) *EngineGroup {
}
// SetTableMapper set the table name mapping rule
func (eg *EngineGroup) SetTableMapper(mapper core.IMapper) {
func (eg *EngineGroup) SetTableMapper(mapper names.Mapper) {
eg.Engine.TableMapper = mapper
for i := 0; i < len(eg.slaves); i++ {
eg.slaves[i].TableMapper = mapper

View File

@ -9,16 +9,18 @@ import (
"reflect"
"strings"
"xorm.io/core"
"xorm.io/xorm/dialects"
"xorm.io/xorm/names"
"xorm.io/xorm/schemas"
)
// tbNameWithSchema will automatically add schema prefix on table name
func (engine *Engine) tbNameWithSchema(v string) string {
// Add schema name as prefix of table name.
// Only for postgres database.
if engine.dialect.DBType() == core.POSTGRES &&
if engine.dialect.DBType() == schemas.POSTGRES &&
engine.dialect.URI().Schema != "" &&
engine.dialect.URI().Schema != postgresPublicSchema &&
engine.dialect.URI().Schema != dialects.PostgresPublicSchema &&
strings.Index(v, ".") == -1 {
return engine.dialect.URI().Schema + "." + v
}
@ -44,7 +46,7 @@ func (engine *Engine) TableName(bean interface{}, includeSchema ...bool) string
}
// tbName get some table's table name
func (session *Session) tbNameNoSchema(table *core.Table) string {
func (session *Session) tbNameNoSchema(table *schemas.Table) string {
if len(session.statement.AltTableName) > 0 {
return session.statement.AltTableName
}
@ -76,7 +78,7 @@ func (engine *Engine) tbNameNoSchema(tablename interface{}) string {
v := rValue(f)
t := v.Type()
if t.Kind() == reflect.Struct {
table = getTableName(engine.TableMapper, v)
table = names.GetTableName(engine.TableMapper, v)
} else {
table = engine.Quote(fmt.Sprintf("%v", f))
}
@ -94,12 +96,12 @@ func (engine *Engine) tbNameNoSchema(tablename interface{}) string {
return tablename.(string)
case reflect.Value:
v := tablename.(reflect.Value)
return getTableName(engine.TableMapper, v)
return names.GetTableName(engine.TableMapper, v)
default:
v := rValue(tablename)
t := v.Type()
if t.Kind() == reflect.Struct {
return getTableName(engine.TableMapper, v)
return names.GetTableName(engine.TableMapper, v)
}
return engine.Quote(fmt.Sprintf("%v", tablename))
}

View File

@ -6,6 +6,7 @@ import (
_ "github.com/mattn/go-sqlite3"
"xorm.io/xorm"
"xorm.io/xorm/caches"
)
// User describes a user
@ -15,7 +16,7 @@ type User struct {
}
func main() {
f := "cache.db"
f := "caches.db"
os.Remove(f)
Orm, err := xorm.NewEngine("sqlite3", f)
@ -24,7 +25,7 @@ func main() {
return
}
Orm.ShowSQL(true)
cacher := xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)
cacher := caches.NewLRUCacher(caches.NewMemoryStore(), 1000)
Orm.SetDefaultCacher(cacher)
err = Orm.CreateTables(&User{})

View File

@ -8,6 +8,7 @@ import (
_ "github.com/go-sql-driver/mysql"
_ "github.com/mattn/go-sqlite3"
"xorm.io/xorm"
"xorm.io/xorm/caches"
)
// User describes a user
@ -87,7 +88,7 @@ func main() {
return
}
engine.ShowSQL(true)
cacher := xorm.NewLRUCacher2(xorm.NewMemoryStore(), time.Hour, 1000)
cacher := caches.NewLRUCacher2(caches.NewMemoryStore(), time.Hour, 1000)
engine.SetDefaultCacher(cacher)
fmt.Println(engine)
test(engine)
@ -97,7 +98,7 @@ func main() {
fmt.Println("-----start mysql go routines-----")
engine, err = mysqlEngine()
engine.ShowSQL(true)
cacher = xorm.NewLRUCacher2(xorm.NewMemoryStore(), time.Hour, 1000)
cacher = caches.NewLRUCacher2(caches.NewMemoryStore(), time.Hour, 1000)
engine.SetDefaultCacher(cacher)
if err != nil {
fmt.Println(err)

2
go.mod
View File

@ -10,6 +10,6 @@ require (
github.com/mattn/go-sqlite3 v1.10.0
github.com/stretchr/testify v1.4.0
github.com/ziutek/mymysql v1.5.4
google.golang.org/appengine v1.6.0 // indirect
xorm.io/builder v0.3.6
xorm.io/core v0.7.2
)

2
go.sum
View File

@ -145,5 +145,3 @@ honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWh
honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
xorm.io/builder v0.3.6 h1:ha28mQ2M+TFx96Hxo+iq6tQgnkC9IZkM6D8w9sKHHF8=
xorm.io/builder v0.3.6/go.mod h1:LEFAPISnRzG+zxaxj2vPicRwz67BdhFreKg8yv8/TgU=
xorm.io/core v0.7.2 h1:mEO22A2Z7a3fPaZMk6gKL/jMD80iiyNwRrX5HOv3XLw=
xorm.io/core v0.7.2/go.mod h1:jJfd0UAEzZ4t87nbQYtVjmqpIODugN6PD2D9E+dJvdM=

View File

@ -12,7 +12,7 @@ import (
"strconv"
"strings"
"xorm.io/core"
"xorm.io/xorm/schemas"
)
// str2PK convert string value to primary key value according to tp
@ -95,26 +95,6 @@ func str2PK(s string, tp reflect.Type) (interface{}, error) {
return v.Interface(), nil
}
func splitTag(tag string) (tags []string) {
tag = strings.TrimSpace(tag)
var hasQuote = false
var lastIdx = 0
for i, t := range tag {
if t == '\'' {
hasQuote = !hasQuote
} else if t == ' ' {
if lastIdx < i && !hasQuote {
tags = append(tags, strings.TrimSpace(tag[lastIdx:i]))
lastIdx = i + 1
}
}
}
if lastIdx < len(tag) {
tags = append(tags, strings.TrimSpace(tag[lastIdx:]))
}
return
}
type zeroable interface {
IsZero() bool
}
@ -249,7 +229,7 @@ func int64ToInt(id int64, tp reflect.Type) interface{} {
return int64ToIntValue(id, tp).Interface()
}
func isPKZero(pk core.PK) bool {
func isPKZero(pk schemas.PK) bool {
for _, k := range pk {
if isZero(k) {
return true

View File

@ -10,7 +10,11 @@ import (
"reflect"
"time"
"xorm.io/core"
"xorm.io/xorm/caches"
"xorm.io/xorm/dialects"
"xorm.io/xorm/log"
"xorm.io/xorm/names"
"xorm.io/xorm/schemas"
)
// Interface defines the interface which Engine, EngineGroup and Session will implementate.
@ -76,31 +80,31 @@ type EngineInterface interface {
ClearCache(...interface{}) error
Context(context.Context) *Session
CreateTables(...interface{}) error
DBMetas() ([]*core.Table, error)
Dialect() core.Dialect
DBMetas() ([]*schemas.Table, error)
Dialect() dialects.Dialect
DropTables(...interface{}) error
DumpAllToFile(fp string, tp ...core.DbType) error
GetCacher(string) core.Cacher
GetColumnMapper() core.IMapper
GetDefaultCacher() core.Cacher
GetTableMapper() core.IMapper
DumpAllToFile(fp string, tp ...dialects.DBType) error
GetCacher(string) caches.Cacher
GetColumnMapper() names.Mapper
GetDefaultCacher() caches.Cacher
GetTableMapper() names.Mapper
GetTZDatabase() *time.Location
GetTZLocation() *time.Location
MapCacher(interface{}, core.Cacher) error
MapCacher(interface{}, caches.Cacher) error
NewSession() *Session
NoAutoTime() *Session
Quote(string) string
SetCacher(string, core.Cacher)
SetCacher(string, caches.Cacher)
SetConnMaxLifetime(time.Duration)
SetColumnMapper(core.IMapper)
SetDefaultCacher(core.Cacher)
SetLogger(logger core.ILogger)
SetLogLevel(core.LogLevel)
SetMapper(core.IMapper)
SetColumnMapper(names.Mapper)
SetDefaultCacher(caches.Cacher)
SetLogger(logger log.Logger)
SetLogLevel(log.LogLevel)
SetMapper(names.Mapper)
SetMaxOpenConns(int)
SetMaxIdleConns(int)
SetSchema(string)
SetTableMapper(core.IMapper)
SetTableMapper(names.Mapper)
SetTZDatabase(tz *time.Location)
SetTZLocation(tz *time.Location)
ShowExecTime(...bool)

View File

@ -2,26 +2,56 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
package log
import (
"fmt"
"io"
"log"
)
"xorm.io/core"
// LogLevel defines a log level
type LogLevel int
// enumerate all LogLevels
const (
// !nashtsai! following level also match syslog.Priority value
LOG_DEBUG LogLevel = iota
LOG_INFO
LOG_WARNING
LOG_ERR
LOG_OFF
LOG_UNKNOWN
)
// default log options
const (
DEFAULT_LOG_PREFIX = "[xorm]"
DEFAULT_LOG_FLAG = log.Ldate | log.Lmicroseconds
DEFAULT_LOG_LEVEL = core.LOG_DEBUG
DEFAULT_LOG_LEVEL = LOG_DEBUG
)
var _ core.ILogger = DiscardLogger{}
// Logger is a logger interface
type Logger interface {
Debug(v ...interface{})
Debugf(format string, v ...interface{})
Error(v ...interface{})
Errorf(format string, v ...interface{})
Info(v ...interface{})
Infof(format string, v ...interface{})
Warn(v ...interface{})
Warnf(format string, v ...interface{})
// DiscardLogger don't log implementation for core.ILogger
Level() LogLevel
SetLevel(l LogLevel)
ShowSQL(show ...bool)
IsShowSQL() bool
}
var _ Logger = DiscardLogger{}
// DiscardLogger don't log implementation for ILogger
type DiscardLogger struct{}
// Debug empty implementation
@ -49,12 +79,12 @@ func (DiscardLogger) Warn(v ...interface{}) {}
func (DiscardLogger) Warnf(format string, v ...interface{}) {}
// Level empty implementation
func (DiscardLogger) Level() core.LogLevel {
return core.LOG_UNKNOWN
func (DiscardLogger) Level() LogLevel {
return LOG_UNKNOWN
}
// SetLevel empty implementation
func (DiscardLogger) SetLevel(l core.LogLevel) {}
func (DiscardLogger) SetLevel(l LogLevel) {}
// ShowSQL empty implementation
func (DiscardLogger) ShowSQL(show ...bool) {}
@ -64,17 +94,17 @@ func (DiscardLogger) IsShowSQL() bool {
return false
}
// SimpleLogger is the default implment of core.ILogger
// SimpleLogger is the default implment of ILogger
type SimpleLogger struct {
DEBUG *log.Logger
ERR *log.Logger
INFO *log.Logger
WARN *log.Logger
level core.LogLevel
level LogLevel
showSQL bool
}
var _ core.ILogger = &SimpleLogger{}
var _ Logger = &SimpleLogger{}
// NewSimpleLogger use a special io.Writer as logger output
func NewSimpleLogger(out io.Writer) *SimpleLogger {
@ -87,7 +117,7 @@ func NewSimpleLogger2(out io.Writer, prefix string, flag int) *SimpleLogger {
}
// NewSimpleLogger3 let you customrize your logger prefix and flag and logLevel
func NewSimpleLogger3(out io.Writer, prefix string, flag int, l core.LogLevel) *SimpleLogger {
func NewSimpleLogger3(out io.Writer, prefix string, flag int, l LogLevel) *SimpleLogger {
return &SimpleLogger{
DEBUG: log.New(out, fmt.Sprintf("%s [debug] ", prefix), flag),
ERR: log.New(out, fmt.Sprintf("%s [error] ", prefix), flag),
@ -97,82 +127,82 @@ func NewSimpleLogger3(out io.Writer, prefix string, flag int, l core.LogLevel) *
}
}
// Error implement core.ILogger
// Error implement ILogger
func (s *SimpleLogger) Error(v ...interface{}) {
if s.level <= core.LOG_ERR {
if s.level <= LOG_ERR {
s.ERR.Output(2, fmt.Sprint(v...))
}
return
}
// Errorf implement core.ILogger
// Errorf implement ILogger
func (s *SimpleLogger) Errorf(format string, v ...interface{}) {
if s.level <= core.LOG_ERR {
if s.level <= LOG_ERR {
s.ERR.Output(2, fmt.Sprintf(format, v...))
}
return
}
// Debug implement core.ILogger
// Debug implement ILogger
func (s *SimpleLogger) Debug(v ...interface{}) {
if s.level <= core.LOG_DEBUG {
if s.level <= LOG_DEBUG {
s.DEBUG.Output(2, fmt.Sprint(v...))
}
return
}
// Debugf implement core.ILogger
// Debugf implement ILogger
func (s *SimpleLogger) Debugf(format string, v ...interface{}) {
if s.level <= core.LOG_DEBUG {
if s.level <= LOG_DEBUG {
s.DEBUG.Output(2, fmt.Sprintf(format, v...))
}
return
}
// Info implement core.ILogger
// Info implement ILogger
func (s *SimpleLogger) Info(v ...interface{}) {
if s.level <= core.LOG_INFO {
if s.level <= LOG_INFO {
s.INFO.Output(2, fmt.Sprint(v...))
}
return
}
// Infof implement core.ILogger
// Infof implement ILogger
func (s *SimpleLogger) Infof(format string, v ...interface{}) {
if s.level <= core.LOG_INFO {
if s.level <= LOG_INFO {
s.INFO.Output(2, fmt.Sprintf(format, v...))
}
return
}
// Warn implement core.ILogger
// Warn implement ILogger
func (s *SimpleLogger) Warn(v ...interface{}) {
if s.level <= core.LOG_WARNING {
if s.level <= LOG_WARNING {
s.WARN.Output(2, fmt.Sprint(v...))
}
return
}
// Warnf implement core.ILogger
// Warnf implement ILogger
func (s *SimpleLogger) Warnf(format string, v ...interface{}) {
if s.level <= core.LOG_WARNING {
if s.level <= LOG_WARNING {
s.WARN.Output(2, fmt.Sprintf(format, v...))
}
return
}
// Level implement core.ILogger
func (s *SimpleLogger) Level() core.LogLevel {
// Level implement ILogger
func (s *SimpleLogger) Level() LogLevel {
return s.level
}
// SetLevel implement core.ILogger
func (s *SimpleLogger) SetLevel(l core.LogLevel) {
// SetLevel implement ILogger
func (s *SimpleLogger) SetLevel(l LogLevel) {
s.level = l
return
}
// ShowSQL implement core.ILogger
// ShowSQL implement ILogger
func (s *SimpleLogger) ShowSQL(show ...bool) {
if len(show) == 0 {
s.showSQL = true
@ -181,7 +211,7 @@ func (s *SimpleLogger) ShowSQL(show ...bool) {
s.showSQL = show[0]
}
// IsShowSQL implement core.ILogger
// IsShowSQL implement ILogger
func (s *SimpleLogger) IsShowSQL() bool {
return s.showSQL
}

View File

@ -4,16 +4,14 @@
// +build !windows,!nacl,!plan9
package xorm
package log
import (
"fmt"
"log/syslog"
"xorm.io/core"
)
var _ core.ILogger = &SyslogLogger{}
var _ Logger = &SyslogLogger{}
// SyslogLogger will be depricated
type SyslogLogger struct {
@ -21,7 +19,7 @@ type SyslogLogger struct {
showSQL bool
}
// NewSyslogLogger implements core.ILogger
// NewSyslogLogger implements Logger
func NewSyslogLogger(w *syslog.Writer) *SyslogLogger {
return &SyslogLogger{w: w}
}
@ -67,12 +65,12 @@ func (s *SyslogLogger) Warnf(format string, v ...interface{}) {
}
// Level shows log level
func (s *SyslogLogger) Level() core.LogLevel {
return core.LOG_UNKNOWN
func (s *SyslogLogger) Level() LogLevel {
return LOG_UNKNOWN
}
// SetLevel always return error, as current log/syslog package doesn't allow to set priority level after syslog.Writer created
func (s *SyslogLogger) SetLevel(l core.LogLevel) {}
func (s *SyslogLogger) SetLevel(l LogLevel) {}
// ShowSQL set if logging SQL
func (s *SyslogLogger) ShowSQL(show ...bool) {

View File

@ -13,7 +13,7 @@ type MigrateFunc func(*xorm.Engine) error
// RollbackFunc is the func signature for rollbacking.
type RollbackFunc func(*xorm.Engine) error
// InitSchemaFunc is the func signature for initializing the schema.
// InitSchemaFunc is the func signature for initializing the schemas.
type InitSchemaFunc func(*xorm.Engine) error
// Options define options for all migrations.
@ -34,7 +34,7 @@ type Migration struct {
Rollback RollbackFunc
}
// Migrate represents a collection of all migrations of a database schema.
// Migrate represents a collection of all migrations of a database schemas.
type Migrate struct {
db *xorm.Engine
options *Options

258
names/mapper.go Normal file
View File

@ -0,0 +1,258 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package names
import (
"strings"
"sync"
)
// Mapper represents a name convertation between struct's fields name and table's column name
type Mapper interface {
Obj2Table(string) string
Table2Obj(string) string
}
type CacheMapper struct {
oriMapper Mapper
obj2tableCache map[string]string
obj2tableMutex sync.RWMutex
table2objCache map[string]string
table2objMutex sync.RWMutex
}
func NewCacheMapper(mapper Mapper) *CacheMapper {
return &CacheMapper{oriMapper: mapper, obj2tableCache: make(map[string]string),
table2objCache: make(map[string]string),
}
}
func (m *CacheMapper) Obj2Table(o string) string {
m.obj2tableMutex.RLock()
t, ok := m.obj2tableCache[o]
m.obj2tableMutex.RUnlock()
if ok {
return t
}
t = m.oriMapper.Obj2Table(o)
m.obj2tableMutex.Lock()
m.obj2tableCache[o] = t
m.obj2tableMutex.Unlock()
return t
}
func (m *CacheMapper) Table2Obj(t string) string {
m.table2objMutex.RLock()
o, ok := m.table2objCache[t]
m.table2objMutex.RUnlock()
if ok {
return o
}
o = m.oriMapper.Table2Obj(t)
m.table2objMutex.Lock()
m.table2objCache[t] = o
m.table2objMutex.Unlock()
return o
}
// SameMapper implements IMapper and provides same name between struct and
// database table
type SameMapper struct {
}
func (m SameMapper) Obj2Table(o string) string {
return o
}
func (m SameMapper) Table2Obj(t string) string {
return t
}
// SnakeMapper implements IMapper and provides name transaltion between
// struct and database table
type SnakeMapper struct {
}
func snakeCasedName(name string) string {
newstr := make([]rune, 0)
for idx, chr := range name {
if isUpper := 'A' <= chr && chr <= 'Z'; isUpper {
if idx > 0 {
newstr = append(newstr, '_')
}
chr -= ('A' - 'a')
}
newstr = append(newstr, chr)
}
return string(newstr)
}
func (mapper SnakeMapper) Obj2Table(name string) string {
return snakeCasedName(name)
}
func titleCasedName(name string) string {
newstr := make([]rune, 0)
upNextChar := true
name = strings.ToLower(name)
for _, chr := range name {
switch {
case upNextChar:
upNextChar = false
if 'a' <= chr && chr <= 'z' {
chr -= ('a' - 'A')
}
case chr == '_':
upNextChar = true
continue
}
newstr = append(newstr, chr)
}
return string(newstr)
}
func (mapper SnakeMapper) Table2Obj(name string) string {
return titleCasedName(name)
}
// GonicMapper implements IMapper. It will consider initialisms when mapping names.
// E.g. id -> ID, user -> User and to table names: UserID -> user_id, MyUID -> my_uid
type GonicMapper map[string]bool
func isASCIIUpper(r rune) bool {
return 'A' <= r && r <= 'Z'
}
func toASCIIUpper(r rune) rune {
if 'a' <= r && r <= 'z' {
r -= ('a' - 'A')
}
return r
}
func gonicCasedName(name string) string {
newstr := make([]rune, 0, len(name)+3)
for idx, chr := range name {
if isASCIIUpper(chr) && idx > 0 {
if !isASCIIUpper(newstr[len(newstr)-1]) {
newstr = append(newstr, '_')
}
}
if !isASCIIUpper(chr) && idx > 1 {
l := len(newstr)
if isASCIIUpper(newstr[l-1]) && isASCIIUpper(newstr[l-2]) {
newstr = append(newstr, newstr[l-1])
newstr[l-1] = '_'
}
}
newstr = append(newstr, chr)
}
return strings.ToLower(string(newstr))
}
func (mapper GonicMapper) Obj2Table(name string) string {
return gonicCasedName(name)
}
func (mapper GonicMapper) Table2Obj(name string) string {
newstr := make([]rune, 0)
name = strings.ToLower(name)
parts := strings.Split(name, "_")
for _, p := range parts {
_, isInitialism := mapper[strings.ToUpper(p)]
for i, r := range p {
if i == 0 || isInitialism {
r = toASCIIUpper(r)
}
newstr = append(newstr, r)
}
}
return string(newstr)
}
// LintGonicMapper is A GonicMapper that contains a list of common initialisms taken from golang/lint
var LintGonicMapper = GonicMapper{
"API": true,
"ASCII": true,
"CPU": true,
"CSS": true,
"DNS": true,
"EOF": true,
"GUID": true,
"HTML": true,
"HTTP": true,
"HTTPS": true,
"ID": true,
"IP": true,
"JSON": true,
"LHS": true,
"QPS": true,
"RAM": true,
"RHS": true,
"RPC": true,
"SLA": true,
"SMTP": true,
"SSH": true,
"TLS": true,
"TTL": true,
"UI": true,
"UID": true,
"UUID": true,
"URI": true,
"URL": true,
"UTF8": true,
"VM": true,
"XML": true,
"XSRF": true,
"XSS": true,
}
// PrefixMapper provides prefix table name support
type PrefixMapper struct {
Mapper Mapper
Prefix string
}
func (mapper PrefixMapper) Obj2Table(name string) string {
return mapper.Prefix + mapper.Mapper.Obj2Table(name)
}
func (mapper PrefixMapper) Table2Obj(name string) string {
return mapper.Mapper.Table2Obj(name[len(mapper.Prefix):])
}
func NewPrefixMapper(mapper Mapper, prefix string) PrefixMapper {
return PrefixMapper{mapper, prefix}
}
// SuffixMapper provides suffix table name support
type SuffixMapper struct {
Mapper Mapper
Suffix string
}
func (mapper SuffixMapper) Obj2Table(name string) string {
return mapper.Mapper.Obj2Table(name) + mapper.Suffix
}
func (mapper SuffixMapper) Table2Obj(name string) string {
return mapper.Mapper.Table2Obj(name[:len(name)-len(mapper.Suffix)])
}
func NewSuffixMapper(mapper Mapper, suffix string) SuffixMapper {
return SuffixMapper{mapper, suffix}
}

49
names/mapper_test.go Normal file
View File

@ -0,0 +1,49 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package names
import (
"testing"
)
func TestGonicMapperFromObj(t *testing.T) {
testCases := map[string]string{
"HTTPLib": "http_lib",
"id": "id",
"ID": "id",
"IDa": "i_da",
"iDa": "i_da",
"IDAa": "id_aa",
"aID": "a_id",
"aaID": "aa_id",
"aaaID": "aaa_id",
"MyREalFunkYLONgNAME": "my_r_eal_funk_ylo_ng_name",
}
for in, expected := range testCases {
out := gonicCasedName(in)
if out != expected {
t.Errorf("Given %s, expected %s but got %s", in, expected, out)
}
}
}
func TestGonicMapperToObj(t *testing.T) {
testCases := map[string]string{
"http_lib": "HTTPLib",
"id": "ID",
"ida": "Ida",
"id_aa": "IDAa",
"aa_id": "AaID",
"my_r_eal_funk_ylo_ng_name": "MyREalFunkYloNgName",
}
for in, expected := range testCases {
out := LintGonicMapper.Table2Obj(in)
if out != expected {
t.Errorf("Given %s, expected %s but got %s", in, expected, out)
}
}
}

View File

@ -2,15 +2,22 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
package names
import (
"reflect"
"xorm.io/core"
)
func getTableName(mapper core.IMapper, v reflect.Value) string {
// TableName table name interface to define customerize table name
type TableName interface {
TableName() string
}
var (
tpTableName = reflect.TypeOf((*TableName)(nil)).Elem()
)
func GetTableName(mapper Mapper, v reflect.Value) string {
if t, ok := v.Interface().(TableName); ok {
return t.TableName()
}

View File

@ -2,17 +2,45 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
package names
import (
"reflect"
"testing"
"time"
"github.com/stretchr/testify/assert"
"xorm.io/core"
)
type Userinfo struct {
Uid int64 `xorm:"id pk not null autoincr"`
Username string `xorm:"unique"`
Departname string
Alias string `xorm:"-"`
Created time.Time
Detail Userdetail `xorm:"detail_id int(11)"`
Height float64
Avatar []byte
IsMan bool
}
type Userdetail struct {
Id int64
Intro string `xorm:"text"`
Profile string `xorm:"varchar(2000)"`
}
type MyGetCustomTableImpletation struct {
Id int64 `json:"id"`
Name string `json:"name"`
}
const getCustomTableName = "GetCustomTableInterface"
func (MyGetCustomTableImpletation) TableName() string {
return getCustomTableName
}
type TestTableNameStruct struct{}
func (t *TestTableNameStruct) TableName() string {
@ -21,53 +49,53 @@ func (t *TestTableNameStruct) TableName() string {
func TestGetTableName(t *testing.T) {
var kases = []struct {
mapper core.IMapper
mapper Mapper
v reflect.Value
expectedTableName string
}{
{
core.SnakeMapper{},
SnakeMapper{},
reflect.ValueOf(new(Userinfo)),
"userinfo",
},
{
core.SnakeMapper{},
SnakeMapper{},
reflect.ValueOf(Userinfo{}),
"userinfo",
},
{
core.SameMapper{},
SameMapper{},
reflect.ValueOf(new(Userinfo)),
"Userinfo",
},
{
core.SameMapper{},
SameMapper{},
reflect.ValueOf(Userinfo{}),
"Userinfo",
},
{
core.SnakeMapper{},
SnakeMapper{},
reflect.ValueOf(new(MyGetCustomTableImpletation)),
getCustomTableName,
},
{
core.SnakeMapper{},
SnakeMapper{},
reflect.ValueOf(MyGetCustomTableImpletation{}),
getCustomTableName,
},
{
core.SnakeMapper{},
SnakeMapper{},
reflect.ValueOf(MyGetCustomTableImpletation{}),
getCustomTableName,
},
{
core.SnakeMapper{},
SnakeMapper{},
reflect.ValueOf(new(TestTableNameStruct)),
new(TestTableNameStruct).TableName(),
},
}
for _, kase := range kases {
assert.EqualValues(t, kase.expectedTableName, getTableName(kase.mapper, kase.v))
assert.EqualValues(t, kase.expectedTableName, GetTableName(kase.mapper, kase.v))
}
}

View File

@ -9,7 +9,7 @@ import (
"fmt"
"reflect"
"xorm.io/core"
"xorm.io/xorm/core"
)
// Rows rows wrapper a rows to

117
schemas/column.go Normal file
View File

@ -0,0 +1,117 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package schemas
import (
"fmt"
"reflect"
"strings"
"time"
)
const (
TWOSIDES = iota + 1
ONLYTODB
ONLYFROMDB
)
// Column defines database column
type Column struct {
Name string
TableName string
FieldName string
SQLType SQLType
IsJSON bool
Length int
Length2 int
Nullable bool
Default string
Indexes map[string]int
IsPrimaryKey bool
IsAutoIncrement bool
MapType int
IsCreated bool
IsUpdated bool
IsDeleted bool
IsCascade bool
IsVersion bool
DefaultIsEmpty bool // false means column has no default set, but not default value is empty
EnumOptions map[string]int
SetOptions map[string]int
DisableTimeZone bool
TimeZone *time.Location // column specified time zone
Comment string
}
// NewColumn creates a new column
func NewColumn(name, fieldName string, sqlType SQLType, len1, len2 int, nullable bool) *Column {
return &Column{
Name: name,
TableName: "",
FieldName: fieldName,
SQLType: sqlType,
Length: len1,
Length2: len2,
Nullable: nullable,
Default: "",
Indexes: make(map[string]int),
IsPrimaryKey: false,
IsAutoIncrement: false,
MapType: TWOSIDES,
IsCreated: false,
IsUpdated: false,
IsDeleted: false,
IsCascade: false,
IsVersion: false,
DefaultIsEmpty: true, // default should be no default
EnumOptions: make(map[string]int),
Comment: "",
}
}
// ValueOf returns column's filed of struct's value
func (col *Column) ValueOf(bean interface{}) (*reflect.Value, error) {
dataStruct := reflect.Indirect(reflect.ValueOf(bean))
return col.ValueOfV(&dataStruct)
}
// ValueOfV returns column's filed of struct's value accept reflevt value
func (col *Column) ValueOfV(dataStruct *reflect.Value) (*reflect.Value, error) {
var fieldValue reflect.Value
fieldPath := strings.Split(col.FieldName, ".")
if dataStruct.Type().Kind() == reflect.Map {
keyValue := reflect.ValueOf(fieldPath[len(fieldPath)-1])
fieldValue = dataStruct.MapIndex(keyValue)
return &fieldValue, nil
} else if dataStruct.Type().Kind() == reflect.Interface {
structValue := reflect.ValueOf(dataStruct.Interface())
dataStruct = &structValue
}
level := len(fieldPath)
fieldValue = dataStruct.FieldByName(fieldPath[0])
for i := 0; i < level-1; i++ {
if !fieldValue.IsValid() {
break
}
if fieldValue.Kind() == reflect.Struct {
fieldValue = fieldValue.FieldByName(fieldPath[i+1])
} else if fieldValue.Kind() == reflect.Ptr {
if fieldValue.IsNil() {
fieldValue.Set(reflect.New(fieldValue.Type().Elem()))
}
fieldValue = fieldValue.Elem().FieldByName(fieldPath[i+1])
} else {
return nil, fmt.Errorf("field %v is not valid", col.FieldName)
}
}
if !fieldValue.IsValid() {
return nil, fmt.Errorf("field %v is not valid", col.FieldName)
}
return &fieldValue, nil
}

72
schemas/index.go Normal file
View File

@ -0,0 +1,72 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package schemas
import (
"fmt"
"strings"
)
// enumerate all index types
const (
IndexType = iota + 1
UniqueType
)
// Index represents a database index
type Index struct {
IsRegular bool
Name string
Type int
Cols []string
}
func (index *Index) XName(tableName string) string {
if !strings.HasPrefix(index.Name, "UQE_") &&
!strings.HasPrefix(index.Name, "IDX_") {
tableParts := strings.Split(strings.Replace(tableName, `"`, "", -1), ".")
tableName = tableParts[len(tableParts)-1]
if index.Type == UniqueType {
return fmt.Sprintf("UQE_%v_%v", tableName, index.Name)
}
return fmt.Sprintf("IDX_%v_%v", tableName, index.Name)
}
return index.Name
}
// AddColumn add columns which will be composite index
func (index *Index) AddColumn(cols ...string) {
for _, col := range cols {
index.Cols = append(index.Cols, col)
}
}
func (index *Index) Equal(dst *Index) bool {
if index.Type != dst.Type {
return false
}
if len(index.Cols) != len(dst.Cols) {
return false
}
for i := 0; i < len(index.Cols); i++ {
var found bool
for j := 0; j < len(dst.Cols); j++ {
if index.Cols[i] == dst.Cols[j] {
found = true
break
}
}
if !found {
return false
}
}
return true
}
// NewIndex new an index object
func NewIndex(name string, indexType int) *Index {
return &Index{true, name, indexType, make([]string, 0)}
}

30
schemas/pk.go Normal file
View File

@ -0,0 +1,30 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package schemas
import (
"bytes"
"encoding/gob"
)
type PK []interface{}
func NewPK(pks ...interface{}) *PK {
p := PK(pks)
return &p
}
func (p *PK) ToString() (string, error) {
buf := new(bytes.Buffer)
enc := gob.NewEncoder(buf)
err := enc.Encode(*p)
return buf.String(), err
}
func (p *PK) FromString(content string) error {
dec := gob.NewDecoder(bytes.NewBufferString(content))
err := dec.Decode(p)
return err
}

36
schemas/pk_test.go Normal file
View File

@ -0,0 +1,36 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package schemas
import (
"reflect"
"testing"
)
func TestPK(t *testing.T) {
p := NewPK(1, 3, "string")
str, err := p.ToString()
if err != nil {
t.Error(err)
}
t.Log(str)
s := &PK{}
err = s.FromString(str)
if err != nil {
t.Error(err)
}
t.Log(s)
if len(*p) != len(*s) {
t.Fatal("p", *p, "should be equal", *s)
}
for i, ori := range *p {
if ori != (*s)[i] {
t.Fatal("ori", ori, reflect.ValueOf(ori), "should be equal", (*s)[i], reflect.ValueOf((*s)[i]))
}
}
}

156
schemas/table.go Normal file
View File

@ -0,0 +1,156 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package schemas
import (
"reflect"
"strings"
//"xorm.io/xorm/cache"
)
// Table represents a database table
type Table struct {
Name string
Type reflect.Type
columnsSeq []string
columnsMap map[string][]*Column
columns []*Column
Indexes map[string]*Index
PrimaryKeys []string
AutoIncrement string
Created map[string]bool
Updated string
Deleted string
Version string
//Cacher caches.Cacher
StoreEngine string
Charset string
Comment string
}
func (table *Table) Columns() []*Column {
return table.columns
}
func (table *Table) ColumnsSeq() []string {
return table.columnsSeq
}
func NewEmptyTable() *Table {
return NewTable("", nil)
}
// NewTable creates a new Table object
func NewTable(name string, t reflect.Type) *Table {
return &Table{Name: name, Type: t,
columnsSeq: make([]string, 0),
columns: make([]*Column, 0),
columnsMap: make(map[string][]*Column),
Indexes: make(map[string]*Index),
Created: make(map[string]bool),
PrimaryKeys: make([]string, 0),
}
}
func (table *Table) columnsByName(name string) []*Column {
n := len(name)
for k := range table.columnsMap {
if len(k) != n {
continue
}
if strings.EqualFold(k, name) {
return table.columnsMap[k]
}
}
return nil
}
func (table *Table) GetColumn(name string) *Column {
cols := table.columnsByName(name)
if cols != nil {
return cols[0]
}
return nil
}
func (table *Table) GetColumnIdx(name string, idx int) *Column {
cols := table.columnsByName(name)
if cols != nil && idx < len(cols) {
return cols[idx]
}
return nil
}
// PKColumns reprents all primary key columns
func (table *Table) PKColumns() []*Column {
columns := make([]*Column, len(table.PrimaryKeys))
for i, name := range table.PrimaryKeys {
columns[i] = table.GetColumn(name)
}
return columns
}
func (table *Table) ColumnType(name string) reflect.Type {
t, _ := table.Type.FieldByName(name)
return t.Type
}
func (table *Table) AutoIncrColumn() *Column {
return table.GetColumn(table.AutoIncrement)
}
func (table *Table) VersionColumn() *Column {
return table.GetColumn(table.Version)
}
func (table *Table) UpdatedColumn() *Column {
return table.GetColumn(table.Updated)
}
func (table *Table) DeletedColumn() *Column {
return table.GetColumn(table.Deleted)
}
// AddColumn adds a column to table
func (table *Table) AddColumn(col *Column) {
table.columnsSeq = append(table.columnsSeq, col.Name)
table.columns = append(table.columns, col)
colName := strings.ToLower(col.Name)
if c, ok := table.columnsMap[colName]; ok {
table.columnsMap[colName] = append(c, col)
} else {
table.columnsMap[colName] = []*Column{col}
}
if col.IsPrimaryKey {
table.PrimaryKeys = append(table.PrimaryKeys, col.Name)
}
if col.IsAutoIncrement {
table.AutoIncrement = col.Name
}
if col.IsCreated {
table.Created[col.Name] = true
}
if col.IsUpdated {
table.Updated = col.Name
}
if col.IsDeleted {
table.Deleted = col.Name
}
if col.IsVersion {
table.Version = col.Name
}
}
// AddIndex adds an index or an unique to table
func (table *Table) AddIndex(index *Index) {
table.Indexes[index.Name] = index
}

111
schemas/table_test.go Normal file
View File

@ -0,0 +1,111 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package schemas
import (
"strings"
"testing"
)
var testsGetColumn = []struct {
name string
idx int
}{
{"Id", 0},
{"Deleted", 0},
{"Caption", 0},
{"Code_1", 0},
{"Code_2", 0},
{"Code_3", 0},
{"Parent_Id", 0},
{"Latitude", 0},
{"Longitude", 0},
}
var table *Table
func init() {
table = NewEmptyTable()
var name string
for i := 0; i < len(testsGetColumn); i++ {
// as in Table.AddColumn func
name = strings.ToLower(testsGetColumn[i].name)
table.columnsMap[name] = append(table.columnsMap[name], &Column{})
}
}
func TestGetColumn(t *testing.T) {
for _, test := range testsGetColumn {
if table.GetColumn(test.name) == nil {
t.Error("Column not found!")
}
}
}
func TestGetColumnIdx(t *testing.T) {
for _, test := range testsGetColumn {
if table.GetColumnIdx(test.name, test.idx) == nil {
t.Errorf("Column %s with idx %d not found!", test.name, test.idx)
}
}
}
func BenchmarkGetColumnWithToLower(b *testing.B) {
for i := 0; i < b.N; i++ {
for _, test := range testsGetColumn {
if _, ok := table.columnsMap[strings.ToLower(test.name)]; !ok {
b.Errorf("Column not found:%s", test.name)
}
}
}
}
func BenchmarkGetColumnIdxWithToLower(b *testing.B) {
for i := 0; i < b.N; i++ {
for _, test := range testsGetColumn {
if c, ok := table.columnsMap[strings.ToLower(test.name)]; ok {
if test.idx < len(c) {
continue
} else {
b.Errorf("Bad idx in: %s, %d", test.name, test.idx)
}
} else {
b.Errorf("Column not found: %s, %d", test.name, test.idx)
}
}
}
}
func BenchmarkGetColumn(b *testing.B) {
for i := 0; i < b.N; i++ {
for _, test := range testsGetColumn {
if table.GetColumn(test.name) == nil {
b.Errorf("Column not found:%s", test.name)
}
}
}
}
func BenchmarkGetColumnIdx(b *testing.B) {
for i := 0; i < b.N; i++ {
for _, test := range testsGetColumn {
if table.GetColumnIdx(test.name, test.idx) == nil {
b.Errorf("Column not found:%s, %d", test.name, test.idx)
}
}
}
}

325
schemas/type.go Normal file
View File

@ -0,0 +1,325 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package schemas
import (
"reflect"
"sort"
"strings"
"time"
)
const (
POSTGRES = "postgres"
SQLITE = "sqlite3"
MYSQL = "mysql"
MSSQL = "mssql"
ORACLE = "oracle"
)
// SQLType represents SQL types
type SQLType struct {
Name string
DefaultLength int
DefaultLength2 int
}
const (
UNKNOW_TYPE = iota
TEXT_TYPE
BLOB_TYPE
TIME_TYPE
NUMERIC_TYPE
)
func (s *SQLType) IsType(st int) bool {
if t, ok := SqlTypes[s.Name]; ok && t == st {
return true
}
return false
}
func (s *SQLType) IsText() bool {
return s.IsType(TEXT_TYPE)
}
func (s *SQLType) IsBlob() bool {
return s.IsType(BLOB_TYPE)
}
func (s *SQLType) IsTime() bool {
return s.IsType(TIME_TYPE)
}
func (s *SQLType) IsNumeric() bool {
return s.IsType(NUMERIC_TYPE)
}
func (s *SQLType) IsJson() bool {
return s.Name == Json || s.Name == Jsonb
}
var (
Bit = "BIT"
TinyInt = "TINYINT"
SmallInt = "SMALLINT"
MediumInt = "MEDIUMINT"
Int = "INT"
Integer = "INTEGER"
BigInt = "BIGINT"
Enum = "ENUM"
Set = "SET"
Char = "CHAR"
Varchar = "VARCHAR"
NChar = "NCHAR"
NVarchar = "NVARCHAR"
TinyText = "TINYTEXT"
Text = "TEXT"
NText = "NTEXT"
Clob = "CLOB"
MediumText = "MEDIUMTEXT"
LongText = "LONGTEXT"
Uuid = "UUID"
UniqueIdentifier = "UNIQUEIDENTIFIER"
SysName = "SYSNAME"
Date = "DATE"
DateTime = "DATETIME"
SmallDateTime = "SMALLDATETIME"
Time = "TIME"
TimeStamp = "TIMESTAMP"
TimeStampz = "TIMESTAMPZ"
Year = "YEAR"
Decimal = "DECIMAL"
Numeric = "NUMERIC"
Money = "MONEY"
SmallMoney = "SMALLMONEY"
Real = "REAL"
Float = "FLOAT"
Double = "DOUBLE"
Binary = "BINARY"
VarBinary = "VARBINARY"
TinyBlob = "TINYBLOB"
Blob = "BLOB"
MediumBlob = "MEDIUMBLOB"
LongBlob = "LONGBLOB"
Bytea = "BYTEA"
Bool = "BOOL"
Boolean = "BOOLEAN"
Serial = "SERIAL"
BigSerial = "BIGSERIAL"
Json = "JSON"
Jsonb = "JSONB"
SqlTypes = map[string]int{
Bit: NUMERIC_TYPE,
TinyInt: NUMERIC_TYPE,
SmallInt: NUMERIC_TYPE,
MediumInt: NUMERIC_TYPE,
Int: NUMERIC_TYPE,
Integer: NUMERIC_TYPE,
BigInt: NUMERIC_TYPE,
Enum: TEXT_TYPE,
Set: TEXT_TYPE,
Json: TEXT_TYPE,
Jsonb: TEXT_TYPE,
Char: TEXT_TYPE,
NChar: TEXT_TYPE,
Varchar: TEXT_TYPE,
NVarchar: TEXT_TYPE,
TinyText: TEXT_TYPE,
Text: TEXT_TYPE,
NText: TEXT_TYPE,
MediumText: TEXT_TYPE,
LongText: TEXT_TYPE,
Uuid: TEXT_TYPE,
Clob: TEXT_TYPE,
SysName: TEXT_TYPE,
Date: TIME_TYPE,
DateTime: TIME_TYPE,
Time: TIME_TYPE,
TimeStamp: TIME_TYPE,
TimeStampz: TIME_TYPE,
SmallDateTime: TIME_TYPE,
Year: TIME_TYPE,
Decimal: NUMERIC_TYPE,
Numeric: NUMERIC_TYPE,
Real: NUMERIC_TYPE,
Float: NUMERIC_TYPE,
Double: NUMERIC_TYPE,
Money: NUMERIC_TYPE,
SmallMoney: NUMERIC_TYPE,
Binary: BLOB_TYPE,
VarBinary: BLOB_TYPE,
TinyBlob: BLOB_TYPE,
Blob: BLOB_TYPE,
MediumBlob: BLOB_TYPE,
LongBlob: BLOB_TYPE,
Bytea: BLOB_TYPE,
UniqueIdentifier: BLOB_TYPE,
Bool: NUMERIC_TYPE,
Serial: NUMERIC_TYPE,
BigSerial: NUMERIC_TYPE,
}
intTypes = sort.StringSlice{"*int", "*int16", "*int32", "*int8"}
uintTypes = sort.StringSlice{"*uint", "*uint16", "*uint32", "*uint8"}
)
// !nashtsai! treat following var as interal const values, these are used for reflect.TypeOf comparison
var (
c_EMPTY_STRING string
c_BOOL_DEFAULT bool
c_BYTE_DEFAULT byte
c_COMPLEX64_DEFAULT complex64
c_COMPLEX128_DEFAULT complex128
c_FLOAT32_DEFAULT float32
c_FLOAT64_DEFAULT float64
c_INT64_DEFAULT int64
c_UINT64_DEFAULT uint64
c_INT32_DEFAULT int32
c_UINT32_DEFAULT uint32
c_INT16_DEFAULT int16
c_UINT16_DEFAULT uint16
c_INT8_DEFAULT int8
c_UINT8_DEFAULT uint8
c_INT_DEFAULT int
c_UINT_DEFAULT uint
c_TIME_DEFAULT time.Time
)
var (
IntType = reflect.TypeOf(c_INT_DEFAULT)
Int8Type = reflect.TypeOf(c_INT8_DEFAULT)
Int16Type = reflect.TypeOf(c_INT16_DEFAULT)
Int32Type = reflect.TypeOf(c_INT32_DEFAULT)
Int64Type = reflect.TypeOf(c_INT64_DEFAULT)
UintType = reflect.TypeOf(c_UINT_DEFAULT)
Uint8Type = reflect.TypeOf(c_UINT8_DEFAULT)
Uint16Type = reflect.TypeOf(c_UINT16_DEFAULT)
Uint32Type = reflect.TypeOf(c_UINT32_DEFAULT)
Uint64Type = reflect.TypeOf(c_UINT64_DEFAULT)
Float32Type = reflect.TypeOf(c_FLOAT32_DEFAULT)
Float64Type = reflect.TypeOf(c_FLOAT64_DEFAULT)
Complex64Type = reflect.TypeOf(c_COMPLEX64_DEFAULT)
Complex128Type = reflect.TypeOf(c_COMPLEX128_DEFAULT)
StringType = reflect.TypeOf(c_EMPTY_STRING)
BoolType = reflect.TypeOf(c_BOOL_DEFAULT)
ByteType = reflect.TypeOf(c_BYTE_DEFAULT)
BytesType = reflect.SliceOf(ByteType)
TimeType = reflect.TypeOf(c_TIME_DEFAULT)
)
var (
PtrIntType = reflect.PtrTo(IntType)
PtrInt8Type = reflect.PtrTo(Int8Type)
PtrInt16Type = reflect.PtrTo(Int16Type)
PtrInt32Type = reflect.PtrTo(Int32Type)
PtrInt64Type = reflect.PtrTo(Int64Type)
PtrUintType = reflect.PtrTo(UintType)
PtrUint8Type = reflect.PtrTo(Uint8Type)
PtrUint16Type = reflect.PtrTo(Uint16Type)
PtrUint32Type = reflect.PtrTo(Uint32Type)
PtrUint64Type = reflect.PtrTo(Uint64Type)
PtrFloat32Type = reflect.PtrTo(Float32Type)
PtrFloat64Type = reflect.PtrTo(Float64Type)
PtrComplex64Type = reflect.PtrTo(Complex64Type)
PtrComplex128Type = reflect.PtrTo(Complex128Type)
PtrStringType = reflect.PtrTo(StringType)
PtrBoolType = reflect.PtrTo(BoolType)
PtrByteType = reflect.PtrTo(ByteType)
PtrTimeType = reflect.PtrTo(TimeType)
)
// Type2SQLType generate SQLType acorrding Go's type
func Type2SQLType(t reflect.Type) (st SQLType) {
switch k := t.Kind(); k {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
st = SQLType{Int, 0, 0}
case reflect.Int64, reflect.Uint64:
st = SQLType{BigInt, 0, 0}
case reflect.Float32:
st = SQLType{Float, 0, 0}
case reflect.Float64:
st = SQLType{Double, 0, 0}
case reflect.Complex64, reflect.Complex128:
st = SQLType{Varchar, 64, 0}
case reflect.Array, reflect.Slice, reflect.Map:
if t.Elem() == reflect.TypeOf(c_BYTE_DEFAULT) {
st = SQLType{Blob, 0, 0}
} else {
st = SQLType{Text, 0, 0}
}
case reflect.Bool:
st = SQLType{Bool, 0, 0}
case reflect.String:
st = SQLType{Varchar, 255, 0}
case reflect.Struct:
if t.ConvertibleTo(TimeType) {
st = SQLType{DateTime, 0, 0}
} else {
// TODO need to handle association struct
st = SQLType{Text, 0, 0}
}
case reflect.Ptr:
st = Type2SQLType(t.Elem())
default:
st = SQLType{Text, 0, 0}
}
return
}
// default sql type change to go types
func SQLType2Type(st SQLType) reflect.Type {
name := strings.ToUpper(st.Name)
switch name {
case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, Serial:
return reflect.TypeOf(1)
case BigInt, BigSerial:
return reflect.TypeOf(int64(1))
case Float, Real:
return reflect.TypeOf(float32(1))
case Double:
return reflect.TypeOf(float64(1))
case Char, NChar, Varchar, NVarchar, TinyText, Text, NText, MediumText, LongText, Enum, Set, Uuid, Clob, SysName:
return reflect.TypeOf("")
case TinyBlob, Blob, LongBlob, Bytea, Binary, MediumBlob, VarBinary, UniqueIdentifier:
return reflect.TypeOf([]byte{})
case Bool:
return reflect.TypeOf(true)
case DateTime, Date, Time, TimeStamp, TimeStampz, SmallDateTime, Year:
return reflect.TypeOf(c_TIME_DEFAULT)
case Decimal, Numeric, Money, SmallMoney:
return reflect.TypeOf("")
default:
return reflect.TypeOf("")
}
}

View File

@ -14,7 +14,8 @@ import (
"strings"
"time"
"xorm.io/core"
"xorm.io/xorm/core"
"xorm.io/xorm/schemas"
)
type sessionType int
@ -306,8 +307,8 @@ func (session *Session) doPrepare(db *core.DB, sqlStr string) (stmt *core.Stmt,
return
}
func (session *Session) getField(dataStruct *reflect.Value, key string, table *core.Table, idx int) (*reflect.Value, error) {
var col *core.Column
func (session *Session) getField(dataStruct *reflect.Value, key string, table *schemas.Table, idx int) (*reflect.Value, error) {
var col *schemas.Column
if col = table.GetColumnIdx(key, idx); col == nil {
return nil, ErrFieldIsNotExist{key, table.Name}
}
@ -328,8 +329,8 @@ func (session *Session) getField(dataStruct *reflect.Value, key string, table *c
type Cell *interface{}
func (session *Session) rows2Beans(rows *core.Rows, fields []string,
table *core.Table, newElemFunc func([]string) reflect.Value,
sliceValueSetFunc func(*reflect.Value, core.PK) error) error {
table *schemas.Table, newElemFunc func([]string) reflect.Value,
sliceValueSetFunc func(*reflect.Value, schemas.PK) error) error {
for rows.Next() {
var newValue = newElemFunc(fields)
bean := newValue.Interface()
@ -377,7 +378,7 @@ func (session *Session) row2Slice(rows *core.Rows, fields []string, bean interfa
return scanResults, nil
}
func (session *Session) slice2Bean(scanResults []interface{}, fields []string, bean interface{}, dataStruct *reflect.Value, table *core.Table) (core.PK, error) {
func (session *Session) slice2Bean(scanResults []interface{}, fields []string, bean interface{}, dataStruct *reflect.Value, table *schemas.Table) (schemas.PK, error) {
defer func() {
if b, hasAfterSet := bean.(AfterSetProcessor); hasAfterSet {
for ii, key := range fields {
@ -421,7 +422,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
}
var tempMap = make(map[string]int)
var pk core.PK
var pk schemas.PK
for ii, key := range fields {
var idx int
var ok bool
@ -451,7 +452,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
}
if fieldValue.CanAddr() {
if structConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok {
if structConvert, ok := fieldValue.Addr().Interface().(Conversion); ok {
if data, err := value2Bytes(&rawValue); err == nil {
if err := structConvert.FromDB(data); err != nil {
return nil, err
@ -463,12 +464,12 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
}
}
if _, ok := fieldValue.Interface().(core.Conversion); ok {
if _, ok := fieldValue.Interface().(Conversion); ok {
if data, err := value2Bytes(&rawValue); err == nil {
if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() {
fieldValue.Set(reflect.New(fieldValue.Type().Elem()))
}
fieldValue.Interface().(core.Conversion).FromDB(data)
fieldValue.Interface().(Conversion).FromDB(data)
} else {
return nil, err
}
@ -488,7 +489,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
var bs []byte
if rawValueType.Kind() == reflect.String {
bs = []byte(vv.String())
} else if rawValueType.ConvertibleTo(core.BytesType) {
} else if rawValueType.ConvertibleTo(schemas.BytesType) {
bs = vv.Bytes()
} else {
return nil, fmt.Errorf("unsupported database data type: %s %v", key, rawValueType.Kind())
@ -525,7 +526,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
var bs []byte
if rawValueType.Kind() == reflect.String {
bs = []byte(vv.String())
} else if rawValueType.ConvertibleTo(core.BytesType) {
} else if rawValueType.ConvertibleTo(schemas.BytesType) {
bs = vv.Bytes()
}
@ -607,16 +608,16 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
fieldValue.SetUint(uint64(vv.Int()))
}
case reflect.Struct:
if fieldType.ConvertibleTo(core.TimeType) {
if fieldType.ConvertibleTo(schemas.TimeType) {
dbTZ := session.engine.DatabaseTZ
if col.TimeZone != nil {
dbTZ = col.TimeZone
}
if rawValueType == core.TimeType {
if rawValueType == schemas.TimeType {
hasAssigned = true
t := vv.Convert(core.TimeType).Interface().(time.Time)
t := vv.Convert(schemas.TimeType).Interface().(time.Time)
z, _ := t.Zone()
// set new location if database don't save timezone or give an incorrect timezone
@ -628,8 +629,8 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
t = t.In(session.engine.TZLocation)
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType))
} else if rawValueType == core.IntType || rawValueType == core.Int64Type ||
rawValueType == core.Int32Type {
} else if rawValueType == schemas.IntType || rawValueType == schemas.Int64Type ||
rawValueType == schemas.Int32Type {
hasAssigned = true
t := time.Unix(vv.Int(), 0).In(session.engine.TZLocation)
@ -696,7 +697,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
if len(table.PrimaryKeys) != 1 {
return nil, errors.New("unsupported non or composited primary key cascade")
}
var pk = make(core.PK, len(table.PrimaryKeys))
var pk = make(schemas.PK, len(table.PrimaryKeys))
pk[0], err = asKind(vv, rawValueType)
if err != nil {
return nil, err
@ -722,97 +723,97 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
// !nashtsai! TODO merge duplicated codes above
switch fieldType {
// following types case matching ptr's native type, therefore assign ptr directly
case core.PtrStringType:
case schemas.PtrStringType:
if rawValueType.Kind() == reflect.String {
x := vv.String()
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrBoolType:
case schemas.PtrBoolType:
if rawValueType.Kind() == reflect.Bool {
x := vv.Bool()
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrTimeType:
if rawValueType == core.PtrTimeType {
case schemas.PtrTimeType:
if rawValueType == schemas.PtrTimeType {
hasAssigned = true
var x = rawValue.Interface().(time.Time)
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrFloat64Type:
case schemas.PtrFloat64Type:
if rawValueType.Kind() == reflect.Float64 {
x := vv.Float()
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrUint64Type:
case schemas.PtrUint64Type:
if rawValueType.Kind() == reflect.Int64 {
var x = uint64(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrInt64Type:
case schemas.PtrInt64Type:
if rawValueType.Kind() == reflect.Int64 {
x := vv.Int()
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrFloat32Type:
case schemas.PtrFloat32Type:
if rawValueType.Kind() == reflect.Float64 {
var x = float32(vv.Float())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrIntType:
case schemas.PtrIntType:
if rawValueType.Kind() == reflect.Int64 {
var x = int(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrInt32Type:
case schemas.PtrInt32Type:
if rawValueType.Kind() == reflect.Int64 {
var x = int32(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrInt8Type:
case schemas.PtrInt8Type:
if rawValueType.Kind() == reflect.Int64 {
var x = int8(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrInt16Type:
case schemas.PtrInt16Type:
if rawValueType.Kind() == reflect.Int64 {
var x = int16(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrUintType:
case schemas.PtrUintType:
if rawValueType.Kind() == reflect.Int64 {
var x = uint(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.PtrUint32Type:
case schemas.PtrUint32Type:
if rawValueType.Kind() == reflect.Int64 {
var x = uint32(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.Uint8Type:
case schemas.Uint8Type:
if rawValueType.Kind() == reflect.Int64 {
var x = uint8(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.Uint16Type:
case schemas.Uint16Type:
if rawValueType.Kind() == reflect.Int64 {
var x = uint16(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case core.Complex64Type:
case schemas.Complex64Type:
var x complex64
if len([]byte(vv.String())) > 0 {
err := DefaultJSONHandler.Unmarshal([]byte(vv.String()), &x)
@ -822,7 +823,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
fieldValue.Set(reflect.ValueOf(&x))
}
hasAssigned = true
case core.Complex128Type:
case schemas.Complex128Type:
var x complex128
if len([]byte(vv.String())) > 0 {
err := DefaultJSONHandler.Unmarshal([]byte(vv.String()), &x)

View File

@ -9,10 +9,10 @@ import (
"strings"
"time"
"xorm.io/core"
"xorm.io/xorm/schemas"
)
func setColumnInt(bean interface{}, col *core.Column, t int64) {
func setColumnInt(bean interface{}, col *schemas.Column, t int64) {
v, err := col.ValueOf(bean)
if err != nil {
return
@ -27,7 +27,7 @@ func setColumnInt(bean interface{}, col *core.Column, t int64) {
}
}
func setColumnTime(bean interface{}, col *core.Column, t time.Time) {
func setColumnTime(bean interface{}, col *schemas.Column, t time.Time) {
v, err := col.ValueOf(bean)
if err != nil {
return
@ -44,7 +44,7 @@ func setColumnTime(bean interface{}, col *core.Column, t time.Time) {
}
}
func getFlagForColumn(m map[string]bool, col *core.Column) (val bool, has bool) {
func getFlagForColumn(m map[string]bool, col *schemas.Column) (val bool, has bool) {
if len(m) == 0 {
return false, false
}

View File

@ -9,7 +9,7 @@ import (
"github.com/stretchr/testify/assert"
"xorm.io/builder"
"xorm.io/core"
"xorm.io/xorm/schemas"
)
func TestSetExpr(t *testing.T) {
@ -45,7 +45,7 @@ func TestSetExpr(t *testing.T) {
assert.EqualValues(t, 1, cnt)
var not = "NOT"
if testEngine.Dialect().DBType() == core.MSSQL {
if testEngine.Dialect().DBType() == schemas.MSSQL {
not = "~"
}
cnt, err = testEngine.SetExpr("show", not+" `show`").ID(1).Update(new(UserExpr))

View File

@ -14,10 +14,10 @@ import (
"strings"
"time"
"xorm.io/core"
"xorm.io/xorm/schemas"
)
func (session *Session) str2Time(col *core.Column, data string) (outTime time.Time, outErr error) {
func (session *Session) str2Time(col *schemas.Column, data string) (outTime time.Time, outErr error) {
sdata := strings.TrimSpace(data)
var x time.Time
var err error
@ -54,14 +54,14 @@ func (session *Session) str2Time(col *core.Column, data string) (outTime time.Ti
} else if len(sdata) == 10 && sdata[4] == '-' && sdata[7] == '-' {
x, err = time.ParseInLocation("2006-01-02", sdata, parseLoc)
//session.engine.logger.Debugf("time(5) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
} else if col.SQLType.Name == core.Time {
} else if col.SQLType.Name == schemas.Time {
if strings.Contains(sdata, " ") {
ssd := strings.Split(sdata, " ")
sdata = ssd[1]
}
sdata = strings.TrimSpace(sdata)
if session.engine.dialect.DBType() == core.MYSQL && len(sdata) > 8 {
if session.engine.dialect.DBType() == schemas.MYSQL && len(sdata) > 8 {
sdata = sdata[len(sdata)-8:]
}
@ -80,7 +80,7 @@ func (session *Session) str2Time(col *core.Column, data string) (outTime time.Ti
return
}
func (session *Session) byte2Time(col *core.Column, data []byte) (outTime time.Time, outErr error) {
func (session *Session) byte2Time(col *schemas.Column, data []byte) (outTime time.Time, outErr error) {
return session.str2Time(col, string(data))
}
@ -89,12 +89,12 @@ var (
)
// convert a db data([]byte) to a field value
func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, data []byte) error {
if structConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok {
func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Value, data []byte) error {
if structConvert, ok := fieldValue.Addr().Interface().(Conversion); ok {
return structConvert.FromDB(data)
}
if structConvert, ok := fieldValue.Interface().(core.Conversion); ok {
if structConvert, ok := fieldValue.Interface().(Conversion); ok {
return structConvert.FromDB(data)
}
@ -157,8 +157,8 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
var x int64
var err error
// for mysql, when use bit, it returned \x01
if col.SQLType.Name == core.Bit &&
session.engine.dialect.DBType() == core.MYSQL { // !nashtsai! TODO dialect needs to provide conversion interface API
if col.SQLType.Name == schemas.Bit &&
session.engine.dialect.DBType() == schemas.MYSQL { // !nashtsai! TODO dialect needs to provide conversion interface API
if len(data) == 1 {
x = int64(data[0])
} else {
@ -199,7 +199,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
return fmt.Errorf("sql.Scan(%v) failed: %s ", data, err.Error())
}
} else {
if fieldType.ConvertibleTo(core.TimeType) {
if fieldType.ConvertibleTo(schemas.TimeType) {
x, err := session.byte2Time(col, data)
if err != nil {
return err
@ -217,7 +217,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
return errors.New("unsupported composited primary key cascade")
}
var pk = make(core.PK, len(table.PrimaryKeys))
var pk = make(schemas.PK, len(table.PrimaryKeys))
rawValueType := table.ColumnType(table.PKColumns()[0].FieldName)
pk[0], err = str2PK(string(data), rawValueType)
if err != nil {
@ -247,11 +247,11 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
//typeStr := fieldType.String()
switch fieldType.Elem().Kind() {
// case "*string":
case core.StringType.Kind():
case schemas.StringType.Kind():
x := string(data)
fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType))
// case "*bool":
case core.BoolType.Kind():
case schemas.BoolType.Kind():
d := string(data)
v, err := strconv.ParseBool(d)
if err != nil {
@ -259,7 +259,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
}
fieldValue.Set(reflect.ValueOf(&v).Convert(fieldType))
// case "*complex64":
case core.Complex64Type.Kind():
case schemas.Complex64Type.Kind():
var x complex64
if len(data) > 0 {
err := DefaultJSONHandler.Unmarshal(data, &x)
@ -270,7 +270,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType))
}
// case "*complex128":
case core.Complex128Type.Kind():
case schemas.Complex128Type.Kind():
var x complex128
if len(data) > 0 {
err := DefaultJSONHandler.Unmarshal(data, &x)
@ -281,14 +281,14 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType))
}
// case "*float64":
case core.Float64Type.Kind():
case schemas.Float64Type.Kind():
x, err := strconv.ParseFloat(string(data), 64)
if err != nil {
return fmt.Errorf("arg %v as float64: %s", key, err.Error())
}
fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType))
// case "*float32":
case core.Float32Type.Kind():
case schemas.Float32Type.Kind():
var x float32
x1, err := strconv.ParseFloat(string(data), 32)
if err != nil {
@ -297,7 +297,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
x = float32(x1)
fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType))
// case "*uint64":
case core.Uint64Type.Kind():
case schemas.Uint64Type.Kind():
var x uint64
x, err := strconv.ParseUint(string(data), 10, 64)
if err != nil {
@ -305,7 +305,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
}
fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType))
// case "*uint":
case core.UintType.Kind():
case schemas.UintType.Kind():
var x uint
x1, err := strconv.ParseUint(string(data), 10, 64)
if err != nil {
@ -314,7 +314,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
x = uint(x1)
fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType))
// case "*uint32":
case core.Uint32Type.Kind():
case schemas.Uint32Type.Kind():
var x uint32
x1, err := strconv.ParseUint(string(data), 10, 64)
if err != nil {
@ -323,7 +323,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
x = uint32(x1)
fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType))
// case "*uint8":
case core.Uint8Type.Kind():
case schemas.Uint8Type.Kind():
var x uint8
x1, err := strconv.ParseUint(string(data), 10, 64)
if err != nil {
@ -332,7 +332,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
x = uint8(x1)
fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType))
// case "*uint16":
case core.Uint16Type.Kind():
case schemas.Uint16Type.Kind():
var x uint16
x1, err := strconv.ParseUint(string(data), 10, 64)
if err != nil {
@ -341,12 +341,12 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
x = uint16(x1)
fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType))
// case "*int64":
case core.Int64Type.Kind():
case schemas.Int64Type.Kind():
sdata := string(data)
var x int64
var err error
// for mysql, when use bit, it returned \x01
if col.SQLType.Name == core.Bit &&
if col.SQLType.Name == schemas.Bit &&
strings.Contains(session.engine.DriverName(), "mysql") {
if len(data) == 1 {
x = int64(data[0])
@ -365,13 +365,13 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
}
fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType))
// case "*int":
case core.IntType.Kind():
case schemas.IntType.Kind():
sdata := string(data)
var x int
var x1 int64
var err error
// for mysql, when use bit, it returned \x01
if col.SQLType.Name == core.Bit &&
if col.SQLType.Name == schemas.Bit &&
strings.Contains(session.engine.DriverName(), "mysql") {
if len(data) == 1 {
x = int(data[0])
@ -393,14 +393,14 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
}
fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType))
// case "*int32":
case core.Int32Type.Kind():
case schemas.Int32Type.Kind():
sdata := string(data)
var x int32
var x1 int64
var err error
// for mysql, when use bit, it returned \x01
if col.SQLType.Name == core.Bit &&
session.engine.dialect.DBType() == core.MYSQL {
if col.SQLType.Name == schemas.Bit &&
session.engine.dialect.DBType() == schemas.MYSQL {
if len(data) == 1 {
x = int32(data[0])
} else {
@ -421,13 +421,13 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
}
fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType))
// case "*int8":
case core.Int8Type.Kind():
case schemas.Int8Type.Kind():
sdata := string(data)
var x int8
var x1 int64
var err error
// for mysql, when use bit, it returned \x01
if col.SQLType.Name == core.Bit &&
if col.SQLType.Name == schemas.Bit &&
strings.Contains(session.engine.DriverName(), "mysql") {
if len(data) == 1 {
x = int8(data[0])
@ -449,13 +449,13 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
}
fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType))
// case "*int16":
case core.Int16Type.Kind():
case schemas.Int16Type.Kind():
sdata := string(data)
var x int16
var x1 int64
var err error
// for mysql, when use bit, it returned \x01
if col.SQLType.Name == core.Bit &&
if col.SQLType.Name == schemas.Bit &&
strings.Contains(session.engine.DriverName(), "mysql") {
if len(data) == 1 {
x = int16(data[0])
@ -480,7 +480,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
case reflect.Struct:
switch fieldType {
// case "*.time.Time":
case core.PtrTimeType:
case schemas.PtrTimeType:
x, err := session.byte2Time(col, data)
if err != nil {
return err
@ -499,7 +499,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
return errors.New("unsupported composited primary key cascade")
}
var pk = make(core.PK, len(table.PrimaryKeys))
var pk = make(schemas.PK, len(table.PrimaryKeys))
rawValueType := table.ColumnType(table.PKColumns()[0].FieldName)
pk[0], err = str2PK(string(data), rawValueType)
if err != nil {
@ -536,9 +536,9 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
}
// convert a field value of a struct to interface for put into db
func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Value) (interface{}, error) {
func (session *Session) value2Interface(col *schemas.Column, fieldValue reflect.Value) (interface{}, error) {
if fieldValue.CanAddr() {
if fieldConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok {
if fieldConvert, ok := fieldValue.Addr().Interface().(Conversion); ok {
data, err := fieldConvert.ToDB()
if err != nil {
return 0, err
@ -550,7 +550,7 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val
}
}
if fieldConvert, ok := fieldValue.Interface().(core.Conversion); ok {
if fieldConvert, ok := fieldValue.Interface().(Conversion); ok {
data, err := fieldConvert.ToDB()
if err != nil {
return 0, err
@ -583,8 +583,8 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val
case reflect.String:
return fieldValue.String(), nil
case reflect.Struct:
if fieldType.ConvertibleTo(core.TimeType) {
t := fieldValue.Convert(core.TimeType).Interface().(time.Time)
if fieldType.ConvertibleTo(schemas.TimeType) {
t := fieldValue.Convert(schemas.TimeType).Interface().(time.Time)
tf := session.engine.formatColTime(col, t)
return tf, nil
} else if fieldType.ConvertibleTo(nullFloatType) {

View File

@ -9,10 +9,11 @@ import (
"fmt"
"strconv"
"xorm.io/core"
"xorm.io/xorm/caches"
"xorm.io/xorm/schemas"
)
func (session *Session) cacheDelete(table *core.Table, tableName, sqlStr string, args ...interface{}) error {
func (session *Session) cacheDelete(table *schemas.Table, tableName, sqlStr string, args ...interface{}) error {
if table == nil ||
session.tx != nil {
return ErrCacheFailed
@ -29,17 +30,17 @@ func (session *Session) cacheDelete(table *core.Table, tableName, sqlStr string,
cacher := session.engine.getCacher(tableName)
pkColumns := table.PKColumns()
ids, err := core.GetCacheSql(cacher, tableName, newsql, args)
ids, err := caches.GetCacheSql(cacher, tableName, newsql, args)
if err != nil {
resultsSlice, err := session.queryBytes(newsql, args...)
if err != nil {
return err
}
ids = make([]core.PK, 0)
ids = make([]schemas.PK, 0)
if len(resultsSlice) > 0 {
for _, data := range resultsSlice {
var id int64
var pk core.PK = make([]interface{}, 0)
var pk schemas.PK = make([]interface{}, 0)
for _, col := range pkColumns {
if v, ok := data[col.Name]; !ok {
return errors.New("no id")
@ -127,14 +128,14 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
if len(orderSQL) > 0 {
switch session.engine.dialect.DBType() {
case core.POSTGRES:
case schemas.POSTGRES:
inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL)
if len(condSQL) > 0 {
deleteSQL += " AND " + inSQL
} else {
deleteSQL += " WHERE " + inSQL
}
case core.SQLITE:
case schemas.SQLITE:
inSQL := fmt.Sprintf("rowid IN (SELECT rowid FROM %s%s)", tableName, orderSQL)
if len(condSQL) > 0 {
deleteSQL += " AND " + inSQL
@ -142,7 +143,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
deleteSQL += " WHERE " + inSQL
}
// TODO: how to handle delete limit on mssql?
case core.MSSQL:
case schemas.MSSQL:
return 0, ErrNotImplemented
default:
deleteSQL += orderSQL
@ -156,7 +157,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
copy(argsForCache, condArgs)
argsForCache = append(condArgs, argsForCache...)
} else {
// !oinume! sqlStrForCache and argsForCache is needed to behave as executing "DELETE FROM ..." for cache.
// !oinume! sqlStrForCache and argsForCache is needed to behave as executing "DELETE FROM ..." for caches.
copy(argsForCache, condArgs)
argsForCache = append(condArgs, argsForCache...)
@ -168,14 +169,14 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
if len(orderSQL) > 0 {
switch session.engine.dialect.DBType() {
case core.POSTGRES:
case schemas.POSTGRES:
inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL)
if len(condSQL) > 0 {
realSQL += " AND " + inSQL
} else {
realSQL += " WHERE " + inSQL
}
case core.SQLITE:
case schemas.SQLITE:
inSQL := fmt.Sprintf("rowid IN (SELECT rowid FROM %s%s)", tableName, orderSQL)
if len(condSQL) > 0 {
realSQL += " AND " + inSQL
@ -183,7 +184,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
realSQL += " WHERE " + inSQL
}
// TODO: how to handle delete limit on mssql?
case core.MSSQL:
case schemas.MSSQL:
return 0, ErrNotImplemented
default:
realSQL += orderSQL

View File

@ -9,7 +9,8 @@ import (
"time"
"github.com/stretchr/testify/assert"
"xorm.io/core"
"xorm.io/xorm/caches"
"xorm.io/xorm/schemas"
)
func TestDelete(t *testing.T) {
@ -26,7 +27,7 @@ func TestDelete(t *testing.T) {
defer session.Close()
var err error
if testEngine.Dialect().DBType() == core.MSSQL {
if testEngine.Dialect().DBType() == schemas.MSSQL {
err = session.Begin()
assert.NoError(t, err)
_, err = session.Exec("SET IDENTITY_INSERT userinfo_delete ON")
@ -38,7 +39,7 @@ func TestDelete(t *testing.T) {
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
if testEngine.Dialect().DBType() == core.MSSQL {
if testEngine.Dialect().DBType() == schemas.MSSQL {
err = session.Commit()
assert.NoError(t, err)
}
@ -159,7 +160,7 @@ func TestCacheDelete(t *testing.T) {
assert.NoError(t, prepareEngine())
oldCacher := testEngine.GetDefaultCacher()
cacher := NewLRUCacher(NewMemoryStore(), 1000)
cacher := caches.NewLRUCacher(caches.NewMemoryStore(), 1000)
testEngine.SetDefaultCacher(cacher)
type CacheDeleteStruct struct {

View File

@ -10,7 +10,7 @@ import (
"reflect"
"xorm.io/builder"
"xorm.io/core"
"xorm.io/xorm/schemas"
)
// Exist returns true if the record exist otherwise return false
@ -45,18 +45,18 @@ func (session *Session) Exist(bean ...interface{}) (bool, error) {
return false, err
}
if session.engine.dialect.DBType() == core.MSSQL {
if session.engine.dialect.DBType() == schemas.MSSQL {
sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s %s WHERE %s", tableName, joinStr, condSQL)
} else if session.engine.dialect.DBType() == core.ORACLE {
} else if session.engine.dialect.DBType() == schemas.ORACLE {
sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE (%s) %s AND ROWNUM=1", tableName, joinStr, condSQL)
} else {
sqlStr = fmt.Sprintf("SELECT * FROM %s %s WHERE %s LIMIT 1", tableName, joinStr, condSQL)
}
args = condArgs
} else {
if session.engine.dialect.DBType() == core.MSSQL {
if session.engine.dialect.DBType() == schemas.MSSQL {
sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s %s", tableName, joinStr)
} else if session.engine.dialect.DBType() == core.ORACLE {
} else if session.engine.dialect.DBType() == schemas.ORACLE {
sqlStr = fmt.Sprintf("SELECT * FROM %s %s WHERE ROWNUM=1", tableName, joinStr)
} else {
sqlStr = fmt.Sprintf("SELECT * FROM %s %s LIMIT 1", tableName, joinStr)

View File

@ -11,7 +11,8 @@ import (
"strings"
"xorm.io/builder"
"xorm.io/core"
"xorm.io/xorm/caches"
"xorm.io/xorm/schemas"
)
const (
@ -197,7 +198,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
return session.noCacheFind(table, sliceValue, sqlStr, args...)
}
func (session *Session) noCacheFind(table *core.Table, containerValue reflect.Value, sqlStr string, args ...interface{}) error {
func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect.Value, sqlStr string, args ...interface{}) error {
rows, err := session.queryRows(sqlStr, args...)
if err != nil {
return err
@ -236,10 +237,10 @@ func (session *Session) noCacheFind(table *core.Table, containerValue reflect.Va
return reflect.New(elemType)
}
var containerValueSetFunc func(*reflect.Value, core.PK) error
var containerValueSetFunc func(*reflect.Value, schemas.PK) error
if containerValue.Kind() == reflect.Slice {
containerValueSetFunc = func(newValue *reflect.Value, pk core.PK) error {
containerValueSetFunc = func(newValue *reflect.Value, pk schemas.PK) error {
if isPointer {
containerValue.Set(reflect.Append(containerValue, newValue.Elem().Addr()))
} else {
@ -256,7 +257,7 @@ func (session *Session) noCacheFind(table *core.Table, containerValue reflect.Va
return errors.New("don't support multiple primary key's map has non-slice key type")
}
containerValueSetFunc = func(newValue *reflect.Value, pk core.PK) error {
containerValueSetFunc = func(newValue *reflect.Value, pk schemas.PK) error {
keyValue := reflect.New(keyType)
err := convertPKToValue(table, keyValue.Interface(), pk)
if err != nil {
@ -310,7 +311,7 @@ func (session *Session) noCacheFind(table *core.Table, containerValue reflect.Va
return nil
}
func convertPKToValue(table *core.Table, dst interface{}, pk core.PK) error {
func convertPKToValue(table *schemas.Table, dst interface{}, pk schemas.PK) error {
cols := table.PKColumns()
if len(cols) == 1 {
return convertAssign(dst, pk[0])
@ -343,7 +344,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
}
table := session.statement.RefTable
ids, err := core.GetCacheSql(cacher, tableName, newsql, args)
ids, err := caches.GetCacheSql(cacher, tableName, newsql, args)
if err != nil {
rows, err := session.queryRows(newsql, args...)
if err != nil {
@ -352,7 +353,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
defer rows.Close()
var i int
ids = make([]core.PK, 0)
ids = make([]schemas.PK, 0)
for rows.Next() {
i++
if i > 500 {
@ -364,7 +365,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
if err != nil {
return err
}
var pk core.PK = make([]interface{}, len(table.PrimaryKeys))
var pk schemas.PK = make([]interface{}, len(table.PrimaryKeys))
for i, col := range table.PKColumns() {
pk[i], err = session.engine.idTypeAssertion(col, res[i])
if err != nil {
@ -376,7 +377,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
}
session.engine.logger.Debug("[cacheFind] cache sql:", ids, tableName, sqlStr, newsql, args)
err = core.PutCacheSql(cacher, ids, tableName, newsql, args)
err = caches.PutCacheSql(cacher, ids, tableName, newsql, args)
if err != nil {
return err
}
@ -387,7 +388,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
ididxes := make(map[string]int)
var ides []core.PK
var ides []schemas.PK
var temps = make([]interface{}, len(ids))
for idx, id := range ids {
@ -502,7 +503,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
}
} else {
if keyType.Kind() != reflect.Slice {
return errors.New("table have multiple primary keys, key is not core.PK or slice")
return errors.New("table have multiple primary keys, key is not schemas.PK or slice")
}
ikey = key
}

View File

@ -11,7 +11,7 @@ import (
"time"
"github.com/stretchr/testify/assert"
"xorm.io/core"
"xorm.io/xorm/names"
)
func TestJoinLimit(t *testing.T) {
@ -300,7 +300,7 @@ func TestOrderSameMapper(t *testing.T) {
testEngine.UnMapType(rValue(new(Userinfo)).Type())
mapper := testEngine.GetTableMapper()
testEngine.SetMapper(core.SameMapper{})
testEngine.SetMapper(names.SameMapper{})
defer func() {
testEngine.UnMapType(rValue(new(Userinfo)).Type())
@ -325,7 +325,7 @@ func TestHavingSameMapper(t *testing.T) {
testEngine.UnMapType(rValue(new(Userinfo)).Type())
mapper := testEngine.GetTableMapper()
testEngine.SetMapper(core.SameMapper{})
testEngine.SetMapper(names.SameMapper{})
defer func() {
testEngine.UnMapType(rValue(new(Userinfo)).Type())
testEngine.SetMapper(mapper)

View File

@ -11,7 +11,8 @@ import (
"reflect"
"strconv"
"xorm.io/core"
"xorm.io/xorm/caches"
"xorm.io/xorm/schemas"
)
// Get retrieve one record from database, bean's non-empty fields
@ -99,7 +100,7 @@ func (session *Session) get(bean interface{}) (bool, error) {
return true, nil
}
func (session *Session) nocacheGet(beanKind reflect.Kind, table *core.Table, bean interface{}, sqlStr string, args ...interface{}) (bool, error) {
func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, bean interface{}, sqlStr string, args ...interface{}) (bool, error) {
rows, err := session.queryRows(sqlStr, args...)
if err != nil {
return false, err
@ -283,7 +284,7 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
session.engine.logger.Debug("[cacheGet] find sql:", newsql, args)
table := session.statement.RefTable
ids, err := core.GetCacheSql(cacher, tableName, newsql, args)
ids, err := caches.GetCacheSql(cacher, tableName, newsql, args)
if err != nil {
var res = make([]string, len(table.PrimaryKeys))
rows, err := session.NoCache().queryRows(newsql, args...)
@ -301,7 +302,7 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
return false, ErrCacheFailed
}
var pk core.PK = make([]interface{}, len(table.PrimaryKeys))
var pk schemas.PK = make([]interface{}, len(table.PrimaryKeys))
for i, col := range table.PKColumns() {
if col.SQLType.IsText() {
pk[i] = res[i]
@ -316,9 +317,9 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
}
}
ids = []core.PK{pk}
ids = []schemas.PK{pk}
session.engine.logger.Debug("[cacheGet] cache ids:", newsql, ids)
err = core.PutCacheSql(cacher, ids, tableName, newsql, args)
err = caches.PutCacheSql(cacher, ids, tableName, newsql, args)
if err != nil {
return false, err
}

View File

@ -11,7 +11,7 @@ import (
"time"
"github.com/stretchr/testify/assert"
"xorm.io/core"
"xorm.io/xorm/schemas"
)
func TestGetVar(t *testing.T) {
@ -153,7 +153,7 @@ func TestGetVar(t *testing.T) {
assert.Equal(t, "1.5", fmt.Sprintf("%.1f", money))
var money2 float64
if testEngine.Dialect().DBType() == core.MSSQL {
if testEngine.Dialect().DBType() == schemas.MSSQL {
has, err = testEngine.SQL("SELECT TOP 1 money FROM " + testEngine.TableName("get_var", true)).Get(&money2)
} else {
has, err = testEngine.SQL("SELECT money FROM " + testEngine.TableName("get_var", true) + " LIMIT 1").Get(&money2)
@ -233,7 +233,7 @@ func TestGetStruct(t *testing.T) {
defer session.Close()
var err error
if testEngine.Dialect().DBType() == core.MSSQL {
if testEngine.Dialect().DBType() == schemas.MSSQL {
err = session.Begin()
assert.NoError(t, err)
_, err = session.Exec("SET IDENTITY_INSERT userinfo_get ON")
@ -242,7 +242,7 @@ func TestGetStruct(t *testing.T) {
cnt, err := session.Insert(&UserinfoGet{Uid: 2})
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
if testEngine.Dialect().DBType() == core.MSSQL {
if testEngine.Dialect().DBType() == schemas.MSSQL {
err = session.Commit()
assert.NoError(t, err)
}

View File

@ -13,7 +13,7 @@ import (
"strings"
"xorm.io/builder"
"xorm.io/core"
"xorm.io/xorm/schemas"
)
// ErrNoElementsOnSlice represents an error there is no element when insert
@ -127,7 +127,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
var colNames []string
var colMultiPlaces []string
var args []interface{}
var cols []*core.Column
var cols []*schemas.Column
for i := 0; i < size; i++ {
v := sliceValue.Index(i)
@ -156,7 +156,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
if col.IsAutoIncrement && isZero(fieldValue.Interface()) {
continue
}
if col.MapType == core.ONLYFROMDB {
if col.MapType == schemas.ONLYFROMDB {
continue
}
if col.IsDeleted {
@ -207,7 +207,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
if col.IsAutoIncrement && isZero(fieldValue.Interface()) {
continue
}
if col.MapType == core.ONLYFROMDB {
if col.MapType == schemas.ONLYFROMDB {
continue
}
if col.IsDeleted {
@ -251,7 +251,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
cleanupProcessorsClosures(&session.beforeClosures)
var sql string
if session.engine.dialect.DBType() == core.ORACLE {
if session.engine.dialect.DBType() == schemas.ORACLE {
temp := fmt.Sprintf(") INTO %s (%v) VALUES (",
session.engine.Quote(tableName),
quoteColumns(colNames, session.engine.Quote, ","))
@ -358,7 +358,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
var tableName = session.statement.TableName()
var output string
if session.engine.dialect.DBType() == core.MSSQL && len(table.AutoIncrement) > 0 {
if session.engine.dialect.DBType() == schemas.MSSQL && len(table.AutoIncrement) > 0 {
output = fmt.Sprintf(" OUTPUT Inserted.%s", table.AutoIncrement)
}
@ -368,7 +368,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
}
if len(colPlaces) <= 0 {
if session.engine.dialect.DBType() == core.MYSQL {
if session.engine.dialect.DBType() == schemas.MYSQL {
if _, err := buf.WriteString(" VALUES ()"); err != nil {
return 0, err
}
@ -430,7 +430,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
}
}
if len(table.AutoIncrement) > 0 && session.engine.dialect.DBType() == core.POSTGRES {
if len(table.AutoIncrement) > 0 && session.engine.dialect.DBType() == schemas.POSTGRES {
if _, err := buf.WriteString(" RETURNING " + session.engine.Quote(table.AutoIncrement)); err != nil {
return 0, err
}
@ -469,7 +469,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
// for postgres, many of them didn't implement lastInsertId, so we should
// implemented it ourself.
if session.engine.dialect.DBType() == core.ORACLE && len(table.AutoIncrement) > 0 {
if session.engine.dialect.DBType() == schemas.ORACLE && len(table.AutoIncrement) > 0 {
res, err := session.queryBytes("select seq_atable.currval from dual", args...)
if err != nil {
return 0, err
@ -510,7 +510,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
aiValue.Set(int64ToIntValue(id, aiValue.Type()))
return 1, nil
} else if len(table.AutoIncrement) > 0 && (session.engine.dialect.DBType() == core.POSTGRES || session.engine.dialect.DBType() == core.MSSQL) {
} else if len(table.AutoIncrement) > 0 && (session.engine.dialect.DBType() == schemas.POSTGRES || session.engine.dialect.DBType() == schemas.MSSQL) {
res, err := session.queryBytes(sqlStr, args...)
if err != nil {
@ -626,7 +626,7 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac
args := make([]interface{}, 0, len(table.ColumnsSeq()))
for _, col := range table.Columns() {
if col.MapType == core.ONLYFROMDB {
if col.MapType == schemas.ONLYFROMDB {
continue
}

View File

@ -11,7 +11,7 @@ import (
"time"
"github.com/stretchr/testify/assert"
"xorm.io/core"
"xorm.io/xorm/schemas"
)
type IntId struct {
@ -726,7 +726,7 @@ func TestCompositeKey(t *testing.T) {
}
var compositeKeyVal CompositeKey
has, err := testEngine.ID(core.PK{11, 22}).Get(&compositeKeyVal)
has, err := testEngine.ID(schemas.PK{11, 22}).Get(&compositeKeyVal)
if err != nil {
t.Error(err)
} else if !has {
@ -735,7 +735,7 @@ func TestCompositeKey(t *testing.T) {
var compositeKeyVal2 CompositeKey
// test passing PK ptr, this test seem failed withCache
has, err = testEngine.ID(&core.PK{11, 22}).Get(&compositeKeyVal2)
has, err = testEngine.ID(&schemas.PK{11, 22}).Get(&compositeKeyVal2)
if err != nil {
t.Error(err)
} else if !has {
@ -772,14 +772,14 @@ func TestCompositeKey(t *testing.T) {
assert.EqualValues(t, compositeKeyVal, cps[0], "should be equeal")
compositeKeyVal = CompositeKey{UpdateStr: "test1"}
cnt, err = testEngine.ID(core.PK{11, 22}).Update(&compositeKeyVal)
cnt, err = testEngine.ID(schemas.PK{11, 22}).Update(&compositeKeyVal)
if err != nil {
t.Error(err)
} else if cnt != 1 {
t.Error(errors.New("can't update CompositeKey{11, 22}"))
}
cnt, err = testEngine.ID(core.PK{11, 22}).Delete(&CompositeKey{})
cnt, err = testEngine.ID(schemas.PK{11, 22}).Delete(&CompositeKey{})
if err != nil {
t.Error(err)
} else if cnt != 1 {
@ -823,7 +823,7 @@ func TestCompositeKey2(t *testing.T) {
}
var user User
has, err := testEngine.ID(core.PK{"11", 22}).Get(&user)
has, err := testEngine.ID(schemas.PK{"11", 22}).Get(&user)
if err != nil {
t.Error(err)
} else if !has {
@ -831,7 +831,7 @@ func TestCompositeKey2(t *testing.T) {
}
// test passing PK ptr, this test seem failed withCache
has, err = testEngine.ID(&core.PK{"11", 22}).Get(&user)
has, err = testEngine.ID(&schemas.PK{"11", 22}).Get(&user)
if err != nil {
t.Error(err)
} else if !has {
@ -839,14 +839,14 @@ func TestCompositeKey2(t *testing.T) {
}
user = User{NickName: "test1"}
cnt, err = testEngine.ID(core.PK{"11", 22}).Update(&user)
cnt, err = testEngine.ID(schemas.PK{"11", 22}).Update(&user)
if err != nil {
t.Error(err)
} else if cnt != 1 {
t.Error(errors.New("can't update User{11, 22}"))
}
cnt, err = testEngine.ID(core.PK{"11", 22}).Delete(&User{})
cnt, err = testEngine.ID(schemas.PK{"11", 22}).Delete(&User{})
if err != nil {
t.Error(err)
} else if cnt != 1 {
@ -891,7 +891,7 @@ func TestCompositeKey3(t *testing.T) {
}
var user UserPK2
has, err := testEngine.ID(core.PK{"11", 22}).Get(&user)
has, err := testEngine.ID(schemas.PK{"11", 22}).Get(&user)
if err != nil {
t.Error(err)
} else if !has {
@ -899,7 +899,7 @@ func TestCompositeKey3(t *testing.T) {
}
// test passing PK ptr, this test seem failed withCache
has, err = testEngine.ID(&core.PK{"11", 22}).Get(&user)
has, err = testEngine.ID(&schemas.PK{"11", 22}).Get(&user)
if err != nil {
t.Error(err)
} else if !has {
@ -907,14 +907,14 @@ func TestCompositeKey3(t *testing.T) {
}
user = UserPK2{NickName: "test1"}
cnt, err = testEngine.ID(core.PK{"11", 22}).Update(&user)
cnt, err = testEngine.ID(schemas.PK{"11", 22}).Update(&user)
if err != nil {
t.Error(err)
} else if cnt != 1 {
t.Error(errors.New("can't update User{11, 22}"))
}
cnt, err = testEngine.ID(core.PK{"11", 22}).Delete(&UserPK2{})
cnt, err = testEngine.ID(schemas.PK{"11", 22}).Delete(&UserPK2{})
if err != nil {
t.Error(err)
} else if cnt != 1 {
@ -1130,7 +1130,7 @@ func TestCompositePK(t *testing.T) {
assert.NoError(t, err)
assert.EqualValues(t, 1+len(tables1), len(tables2))
var table *core.Table
var table *schemas.Table
for _, t := range tables2 {
if t.Name == testEngine.GetTableMapper().Obj2Table("TaskSolution") {
table = t

View File

@ -12,7 +12,8 @@ import (
"time"
"xorm.io/builder"
"xorm.io/core"
"xorm.io/xorm/core"
"xorm.io/xorm/schemas"
)
func (session *Session) genQuerySQL(sqlOrArgs ...interface{}) (string, []interface{}, error) {
@ -116,8 +117,8 @@ func value2String(rawValue *reflect.Value) (str string, err error) {
}
// time type
case reflect.Struct:
if aa.ConvertibleTo(core.TimeType) {
str = vv.Convert(core.TimeType).Interface().(time.Time).Format(time.RFC3339Nano)
if aa.ConvertibleTo(schemas.TimeType) {
str = vv.Convert(schemas.TimeType).Interface().(time.Time).Format(time.RFC3339Nano)
} else {
err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name())
}

View File

@ -11,7 +11,7 @@ import (
"time"
"xorm.io/builder"
"xorm.io/core"
"xorm.io/xorm/schemas"
"github.com/stretchr/testify/assert"
)
@ -207,7 +207,7 @@ func TestQueryStringNoParam(t *testing.T) {
assert.NoError(t, err)
assert.EqualValues(t, 1, len(records))
assert.EqualValues(t, "1", records[0]["id"])
if testEngine.Dialect().DBType() == core.POSTGRES || testEngine.Dialect().DBType() == core.MSSQL {
if testEngine.Dialect().DBType() == schemas.POSTGRES || testEngine.Dialect().DBType() == schemas.MSSQL {
assert.EqualValues(t, "false", records[0]["msg"])
} else {
assert.EqualValues(t, "0", records[0]["msg"])
@ -217,7 +217,7 @@ func TestQueryStringNoParam(t *testing.T) {
assert.NoError(t, err)
assert.EqualValues(t, 1, len(records))
assert.EqualValues(t, "1", records[0]["id"])
if testEngine.Dialect().DBType() == core.POSTGRES || testEngine.Dialect().DBType() == core.MSSQL {
if testEngine.Dialect().DBType() == schemas.POSTGRES || testEngine.Dialect().DBType() == schemas.MSSQL {
assert.EqualValues(t, "false", records[0]["msg"])
} else {
assert.EqualValues(t, "0", records[0]["msg"])
@ -244,7 +244,7 @@ func TestQuerySliceStringNoParam(t *testing.T) {
assert.NoError(t, err)
assert.EqualValues(t, 1, len(records))
assert.EqualValues(t, "1", records[0][0])
if testEngine.Dialect().DBType() == core.POSTGRES || testEngine.Dialect().DBType() == core.MSSQL {
if testEngine.Dialect().DBType() == schemas.POSTGRES || testEngine.Dialect().DBType() == schemas.MSSQL {
assert.EqualValues(t, "false", records[0][1])
} else {
assert.EqualValues(t, "0", records[0][1])
@ -254,7 +254,7 @@ func TestQuerySliceStringNoParam(t *testing.T) {
assert.NoError(t, err)
assert.EqualValues(t, 1, len(records))
assert.EqualValues(t, "1", records[0][0])
if testEngine.Dialect().DBType() == core.POSTGRES || testEngine.Dialect().DBType() == core.MSSQL {
if testEngine.Dialect().DBType() == schemas.POSTGRES || testEngine.Dialect().DBType() == schemas.MSSQL {
assert.EqualValues(t, "false", records[0][1])
} else {
assert.EqualValues(t, "0", records[0][1])

View File

@ -10,7 +10,7 @@ import (
"time"
"xorm.io/builder"
"xorm.io/core"
"xorm.io/xorm/core"
)
func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) {

View File

@ -9,7 +9,7 @@ import (
"fmt"
"strings"
"xorm.io/core"
"xorm.io/xorm/schemas"
)
// Ping test if database is ok
@ -125,7 +125,7 @@ func (session *Session) dropTable(beanOrTableName interface{}) error {
tableName := session.engine.TableName(beanOrTableName)
var needDrop = true
if !session.engine.dialect.SupportDropIfExists() {
sqlStr, args := session.engine.dialect.TableCheckSql(tableName)
sqlStr, args := session.engine.dialect.TableCheckSQL(tableName)
results, err := session.queryBytes(sqlStr, args...)
if err != nil {
return err
@ -134,7 +134,7 @@ func (session *Session) dropTable(beanOrTableName interface{}) error {
}
if needDrop {
sqlStr := session.engine.Dialect().DropTableSql(session.engine.TableName(tableName, true))
sqlStr := session.engine.Dialect().DropTableSQL(session.engine.TableName(tableName, true))
_, err := session.exec(sqlStr)
return err
}
@ -153,7 +153,7 @@ func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error)
}
func (session *Session) isTableExist(tableName string) (bool, error) {
sqlStr, args := session.engine.dialect.TableCheckSql(tableName)
sqlStr, args := session.engine.dialect.TableCheckSQL(tableName)
results, err := session.queryBytes(sqlStr, args...)
return len(results) > 0, err
}
@ -190,9 +190,9 @@ func (session *Session) isIndexExist2(tableName string, cols []string, unique bo
for _, index := range indexes {
if sliceEq(index.Cols, cols) {
if unique {
return index.Type == core.UniqueType, nil
return index.Type == schemas.UniqueType, nil
}
return index.Type == core.IndexType, nil
return index.Type == schemas.IndexType, nil
}
}
return false, nil
@ -207,14 +207,14 @@ func (session *Session) addColumn(colName string) error {
func (session *Session) addIndex(tableName, idxName string) error {
index := session.statement.RefTable.Indexes[idxName]
sqlStr := session.engine.dialect.CreateIndexSql(tableName, index)
sqlStr := session.engine.dialect.CreateIndexSQL(tableName, index)
_, err := session.exec(sqlStr)
return err
}
func (session *Session) addUnique(tableName, uqeName string) error {
index := session.statement.RefTable.Indexes[uqeName]
sqlStr := session.engine.dialect.CreateIndexSql(tableName, index)
sqlStr := session.engine.dialect.CreateIndexSQL(tableName, index)
_, err := session.exec(sqlStr)
return err
}
@ -253,7 +253,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
}
tbNameWithSchema := engine.tbNameWithSchema(tbName)
var oriTable *core.Table
var oriTable *schemas.Table
for _, tb := range tables {
if strings.EqualFold(engine.tbNameWithSchema(tb.Name), engine.tbNameWithSchema(tbName)) {
oriTable = tb
@ -287,7 +287,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
// check columns
for _, col := range table.Columns() {
var oriCol *core.Column
var oriCol *schemas.Column
for _, col2 := range oriTable.Columns() {
if strings.EqualFold(col.Name, col2.Name) {
oriCol = col2
@ -306,27 +306,27 @@ func (session *Session) Sync2(beans ...interface{}) error {
}
err = nil
expectedType := engine.dialect.SqlType(col)
curType := engine.dialect.SqlType(oriCol)
expectedType := engine.dialect.SQLType(col)
curType := engine.dialect.SQLType(oriCol)
if expectedType != curType {
if expectedType == core.Text &&
strings.HasPrefix(curType, core.Varchar) {
if expectedType == schemas.Text &&
strings.HasPrefix(curType, schemas.Varchar) {
// currently only support mysql & postgres
if engine.dialect.DBType() == core.MYSQL ||
engine.dialect.DBType() == core.POSTGRES {
if engine.dialect.DBType() == schemas.MYSQL ||
engine.dialect.DBType() == schemas.POSTGRES {
engine.logger.Infof("Table %s column %s change type from %s to %s\n",
tbNameWithSchema, col.Name, curType, expectedType)
_, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col))
_, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col))
} else {
engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n",
tbNameWithSchema, col.Name, curType, expectedType)
}
} else if strings.HasPrefix(curType, core.Varchar) && strings.HasPrefix(expectedType, core.Varchar) {
if engine.dialect.DBType() == core.MYSQL {
} else if strings.HasPrefix(curType, schemas.Varchar) && strings.HasPrefix(expectedType, schemas.Varchar) {
if engine.dialect.DBType() == schemas.MYSQL {
if oriCol.Length < col.Length {
engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
tbNameWithSchema, col.Name, oriCol.Length, col.Length)
_, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col))
_, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col))
}
}
} else {
@ -335,12 +335,12 @@ func (session *Session) Sync2(beans ...interface{}) error {
tbNameWithSchema, col.Name, curType, expectedType)
}
}
} else if expectedType == core.Varchar {
if engine.dialect.DBType() == core.MYSQL {
} else if expectedType == schemas.Varchar {
if engine.dialect.DBType() == schemas.MYSQL {
if oriCol.Length < col.Length {
engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
tbNameWithSchema, col.Name, oriCol.Length, col.Length)
_, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col))
_, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col))
}
}
}
@ -348,7 +348,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
if col.Default != oriCol.Default {
switch {
case col.IsAutoIncrement: // For autoincrement column, don't check default
case (col.SQLType.Name == core.Bool || col.SQLType.Name == core.Boolean) &&
case (col.SQLType.Name == schemas.Bool || col.SQLType.Name == schemas.Boolean) &&
((strings.EqualFold(col.Default, "true") && oriCol.Default == "1") ||
(strings.EqualFold(col.Default, "false") && oriCol.Default == "0")):
default:
@ -367,10 +367,10 @@ func (session *Session) Sync2(beans ...interface{}) error {
}
var foundIndexNames = make(map[string]bool)
var addedNames = make(map[string]*core.Index)
var addedNames = make(map[string]*schemas.Index)
for name, index := range table.Indexes {
var oriIndex *core.Index
var oriIndex *schemas.Index
for name2, index2 := range oriTable.Indexes {
if index.Equal(index2) {
oriIndex = index2
@ -381,7 +381,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
if oriIndex != nil {
if oriIndex.Type != index.Type {
sql := engine.dialect.DropIndexSql(tbNameWithSchema, oriIndex)
sql := engine.dialect.DropIndexSQL(tbNameWithSchema, oriIndex)
_, err = session.exec(sql)
if err != nil {
return err
@ -397,7 +397,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
for name2, index2 := range oriTable.Indexes {
if _, ok := foundIndexNames[name2]; !ok {
sql := engine.dialect.DropIndexSql(tbNameWithSchema, index2)
sql := engine.dialect.DropIndexSQL(tbNameWithSchema, index2)
_, err = session.exec(sql)
if err != nil {
return err
@ -406,11 +406,11 @@ func (session *Session) Sync2(beans ...interface{}) error {
}
for name, index := range addedNames {
if index.Type == core.UniqueType {
if index.Type == schemas.UniqueType {
session.statement.RefTable = table
session.statement.tableName = tbNameWithSchema
err = session.addUnique(tbNameWithSchema, name)
} else if index.Type == core.IndexType {
} else if index.Type == schemas.IndexType {
session.statement.RefTable = table
session.statement.tableName = tbNameWithSchema
err = session.addIndex(tbNameWithSchema, name)

View File

@ -213,7 +213,7 @@ func TestCustomTableName(t *testing.T) {
func TestDump(t *testing.T) {
assert.NoError(t, prepareEngine())
fp := testEngine.Dialect().URI().DbName + ".sql"
fp := testEngine.Dialect().URI().DBName + ".sql"
os.Remove(fp)
assert.NoError(t, testEngine.DumpAllToFile(fp))
}

View File

@ -10,7 +10,7 @@ import (
"time"
"github.com/stretchr/testify/assert"
"xorm.io/core"
"xorm.io/xorm/names"
)
func TestTransaction(t *testing.T) {
@ -89,7 +89,7 @@ func TestCombineTransactionSameMapper(t *testing.T) {
oldMapper := testEngine.GetColumnMapper()
testEngine.UnMapType(rValue(new(Userinfo)).Type())
testEngine.SetMapper(core.SameMapper{})
testEngine.SetMapper(names.SameMapper{})
defer func() {
testEngine.UnMapType(rValue(new(Userinfo)).Type())
testEngine.SetMapper(oldMapper)

View File

@ -12,10 +12,11 @@ import (
"strings"
"xorm.io/builder"
"xorm.io/core"
"xorm.io/xorm/caches"
"xorm.io/xorm/schemas"
)
func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string, args ...interface{}) error {
func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr string, args ...interface{}) error {
if table == nil ||
session.tx != nil {
return ErrCacheFailed
@ -42,7 +43,7 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string,
cacher := session.engine.getCacher(tableName)
session.engine.logger.Debug("[cacheUpdate] get cache sql", newsql, args[nStart:])
ids, err := core.GetCacheSql(cacher, tableName, newsql, args[nStart:])
ids, err := caches.GetCacheSql(cacher, tableName, newsql, args[nStart:])
if err != nil {
rows, err := session.NoCache().queryRows(newsql, args[nStart:]...)
if err != nil {
@ -50,14 +51,14 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string,
}
defer rows.Close()
ids = make([]core.PK, 0)
ids = make([]schemas.PK, 0)
for rows.Next() {
var res = make([]string, len(table.PrimaryKeys))
err = rows.ScanSlice(&res)
if err != nil {
return err
}
var pk core.PK = make([]interface{}, len(table.PrimaryKeys))
var pk schemas.PK = make([]interface{}, len(table.PrimaryKeys))
for i, col := range table.PKColumns() {
if col.SQLType.IsNumeric() {
n, err := strconv.ParseInt(res[i], 10, 64)
@ -339,9 +340,9 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
var top string
if st.LimitN != nil {
limitValue := *st.LimitN
if st.Engine.dialect.DBType() == core.MYSQL {
if st.Engine.dialect.DBType() == schemas.MYSQL {
condSQL = condSQL + fmt.Sprintf(" LIMIT %d", limitValue)
} else if st.Engine.dialect.DBType() == core.SQLITE {
} else if st.Engine.dialect.DBType() == schemas.SQLITE {
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue)
cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)",
session.engine.Quote(tableName), tempCondSQL), condArgs...))
@ -352,7 +353,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
if len(condSQL) > 0 {
condSQL = "WHERE " + condSQL
}
} else if st.Engine.dialect.DBType() == core.POSTGRES {
} else if st.Engine.dialect.DBType() == schemas.POSTGRES {
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue)
cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)",
session.engine.Quote(tableName), tempCondSQL), condArgs...))
@ -364,8 +365,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
if len(condSQL) > 0 {
condSQL = "WHERE " + condSQL
}
} else if st.Engine.dialect.DBType() == core.MSSQL {
if st.OrderStr != "" && st.Engine.dialect.DBType() == core.MSSQL &&
} else if st.Engine.dialect.DBType() == schemas.MSSQL {
if st.OrderStr != "" && st.Engine.dialect.DBType() == schemas.MSSQL &&
table != nil && len(table.PrimaryKeys) == 1 {
cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)",
table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0],
@ -392,7 +393,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
var fromSQL string
if session.statement.TableAlias != "" {
switch session.engine.dialect.DBType() {
case core.MSSQL:
case schemas.MSSQL:
fromSQL = fmt.Sprintf("FROM %s %s ", tableAlias, session.statement.TableAlias)
tableAlias = session.statement.TableAlias
default:
@ -467,7 +468,7 @@ func (session *Session) genUpdateColumns(bean interface{}) ([]string, []interfac
continue
}
}
if col.MapType == core.ONLYFROMDB {
if col.MapType == schemas.ONLYFROMDB {
continue
}

View File

@ -12,7 +12,7 @@ import (
"time"
"github.com/stretchr/testify/assert"
"xorm.io/core"
"xorm.io/xorm/names"
)
func TestUpdateMap(t *testing.T) {
@ -691,7 +691,7 @@ func TestUpdateSameMapper(t *testing.T) {
testEngine.UnMapType(rValue(new(UpdateAllCols)).Type())
testEngine.UnMapType(rValue(new(UpdateMustCols)).Type())
testEngine.UnMapType(rValue(new(UpdateIncr)).Type())
testEngine.SetMapper(core.SameMapper{})
testEngine.SetMapper(names.SameMapper{})
defer func() {
testEngine.UnMapType(rValue(new(Userinfo)).Type())
testEngine.UnMapType(rValue(new(Condi)).Type())
@ -1419,7 +1419,7 @@ func TestUpdateMap3(t *testing.T) {
testEngine.SetColumnMapper(oldMapper)
}()
mapper := core.NewPrefixMapper(core.SnakeMapper{}, "F")
mapper := names.NewPrefixMapper(names.SnakeMapper{}, "F")
testEngine.SetColumnMapper(mapper)
assertSync(t, new(UpdateMapUser))

View File

@ -12,16 +12,17 @@ import (
"time"
"xorm.io/builder"
"xorm.io/core"
"xorm.io/xorm/dialects"
"xorm.io/xorm/schemas"
)
// Statement save all the sql info for executing SQL
type Statement struct {
RefTable *core.Table
RefTable *schemas.Table
Engine *Engine
Start int
LimitN *int
idParam *core.PK
idParam *schemas.PK
OrderStr string
JoinStr string
joinArgs []interface{}
@ -266,7 +267,7 @@ func (statement *Statement) buildUpdates(bean interface{},
continue
}
if col.MapType == core.ONLYFROMDB {
if col.MapType == schemas.ONLYFROMDB {
continue
}
@ -314,7 +315,7 @@ func (statement *Statement) buildUpdates(bean interface{},
var val interface{}
if fieldValue.CanAddr() {
if structConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok {
if structConvert, ok := fieldValue.Addr().Interface().(Conversion); ok {
data, err := structConvert.ToDB()
if err != nil {
engine.logger.Error(err)
@ -325,7 +326,7 @@ func (statement *Statement) buildUpdates(bean interface{},
}
}
if structConvert, ok := fieldValue.Interface().(core.Conversion); ok {
if structConvert, ok := fieldValue.Interface().(Conversion); ok {
data, err := structConvert.ToDB()
if err != nil {
engine.logger.Error(err)
@ -388,8 +389,8 @@ func (statement *Statement) buildUpdates(bean interface{},
t := int64(fieldValue.Uint())
val = reflect.ValueOf(&t).Interface()
case reflect.Struct:
if fieldType.ConvertibleTo(core.TimeType) {
t := fieldValue.Convert(core.TimeType).Interface().(time.Time)
if fieldType.ConvertibleTo(schemas.TimeType) {
t := fieldValue.Convert(schemas.TimeType).Interface().(time.Time)
if !requiredField && (t.IsZero() || !fieldValue.IsValid()) {
continue
}
@ -496,7 +497,7 @@ func (statement *Statement) needTableName() bool {
return len(statement.JoinStr) > 0
}
func (statement *Statement) colName(col *core.Column, tableName string) string {
func (statement *Statement) colName(col *schemas.Column, tableName string) string {
if statement.needTableName() {
var nm = tableName
if len(statement.TableAlias) > 0 {
@ -523,12 +524,12 @@ func (statement *Statement) ID(id interface{}) *Statement {
switch idType {
case ptrPkType:
if pkPtr, ok := (id).(*core.PK); ok {
if pkPtr, ok := (id).(*schemas.PK); ok {
statement.idParam = pkPtr
return statement
}
case pkType:
if pk, ok := (id).(core.PK); ok {
if pk, ok := (id).(schemas.PK); ok {
statement.idParam = &pk
return statement
}
@ -536,11 +537,11 @@ func (statement *Statement) ID(id interface{}) *Statement {
switch idType.Kind() {
case reflect.String:
statement.idParam = &core.PK{idValue.Convert(reflect.TypeOf("")).Interface()}
statement.idParam = &schemas.PK{idValue.Convert(reflect.TypeOf("")).Interface()}
return statement
}
statement.idParam = &core.PK{id}
statement.idParam = &schemas.PK{id}
return statement
}
@ -807,7 +808,7 @@ func (statement *Statement) genColumnStr() string {
continue
}
if col.MapType == core.ONLYTODB {
if col.MapType == schemas.ONLYTODB {
continue
}
@ -832,7 +833,7 @@ func (statement *Statement) genColumnStr() string {
}
func (statement *Statement) genCreateTableSQL() string {
return statement.Engine.dialect.CreateTableSql(statement.RefTable, statement.TableName(),
return statement.Engine.dialect.CreateTableSQL(statement.RefTable, statement.TableName(),
statement.StoreEngine, statement.Charset)
}
@ -840,8 +841,8 @@ func (statement *Statement) genIndexSQL() []string {
var sqls []string
tbName := statement.TableName()
for _, index := range statement.RefTable.Indexes {
if index.Type == core.IndexType {
sql := statement.Engine.dialect.CreateIndexSql(tbName, index)
if index.Type == schemas.IndexType {
sql := statement.Engine.dialect.CreateIndexSQL(tbName, index)
/*idxTBName := strings.Replace(tbName, ".", "_", -1)
idxTBName = strings.Replace(idxTBName, `"`, "", -1)
sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(idxTBName, idxName)),
@ -860,8 +861,8 @@ func (statement *Statement) genUniqueSQL() []string {
var sqls []string
tbName := statement.TableName()
for _, index := range statement.RefTable.Indexes {
if index.Type == core.UniqueType {
sql := statement.Engine.dialect.CreateIndexSql(tbName, index)
if index.Type == schemas.UniqueType {
sql := statement.Engine.dialect.CreateIndexSQL(tbName, index)
sqls = append(sqls, sql)
}
}
@ -871,13 +872,17 @@ func (statement *Statement) genUniqueSQL() []string {
func (statement *Statement) genDelIndexSQL() []string {
var sqls []string
tbName := statement.TableName()
idx := strings.Index(tbName, ".")
if idx > -1 {
tbName = tbName[idx+1:]
}
idxPrefixName := strings.Replace(tbName, `"`, "", -1)
idxPrefixName = strings.Replace(idxPrefixName, `.`, "_", -1)
for idxName, index := range statement.RefTable.Indexes {
var rIdxName string
if index.Type == core.UniqueType {
if index.Type == schemas.UniqueType {
rIdxName = uniqueName(idxPrefixName, idxName)
} else if index.Type == core.IndexType {
} else if index.Type == schemas.IndexType {
rIdxName = indexName(idxPrefixName, idxName)
}
sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(statement.Engine.TableName(rIdxName, true)))
@ -889,18 +894,18 @@ func (statement *Statement) genDelIndexSQL() []string {
return sqls
}
func (statement *Statement) genAddColumnStr(col *core.Column) (string, []interface{}) {
func (statement *Statement) genAddColumnStr(col *schemas.Column) (string, []interface{}) {
quote := statement.Engine.Quote
sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quote(statement.TableName()),
col.String(statement.Engine.dialect))
if statement.Engine.dialect.DBType() == core.MYSQL && len(col.Comment) > 0 {
dialects.String(statement.Engine.dialect, col))
if statement.Engine.dialect.DBType() == schemas.MYSQL && len(col.Comment) > 0 {
sql += " COMMENT '" + col.Comment + "'"
}
sql += ";"
return sql, []interface{}{}
}
func (statement *Statement) buildConds(table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) {
func (statement *Statement) buildConds(table *schemas.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) {
return statement.Engine.buildConds(table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols,
statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, addedTableName)
}
@ -1054,14 +1059,14 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, n
whereStr = " WHERE " + condSQL
}
if dialect.DBType() == core.MSSQL && strings.Contains(statement.TableName(), "..") {
if dialect.DBType() == schemas.MSSQL && strings.Contains(statement.TableName(), "..") {
fromStr += statement.TableName()
} else {
fromStr += quote(statement.TableName())
}
if statement.TableAlias != "" {
if dialect.DBType() == core.ORACLE {
if dialect.DBType() == schemas.ORACLE {
fromStr += " " + quote(statement.TableAlias)
} else {
fromStr += " AS " + quote(statement.TableAlias)
@ -1072,7 +1077,7 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, n
}
pLimitN := statement.LimitN
if dialect.DBType() == core.MSSQL {
if dialect.DBType() == schemas.MSSQL {
if pLimitN != nil {
LimitNValue := *pLimitN
top = fmt.Sprintf("TOP %d ", LimitNValue)
@ -1134,7 +1139,7 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, n
fmt.Fprint(&buf, " ORDER BY ", statement.OrderStr)
}
if needLimit {
if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE {
if dialect.DBType() != schemas.MSSQL && dialect.DBType() != schemas.ORACLE {
if statement.Start > 0 {
if pLimitN != nil {
fmt.Fprintf(&buf, " LIMIT %v OFFSET %v", *pLimitN, statement.Start)
@ -1144,7 +1149,7 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, n
} else if pLimitN != nil {
fmt.Fprint(&buf, " LIMIT ", *pLimitN)
}
} else if dialect.DBType() == core.ORACLE {
} else if dialect.DBType() == schemas.ORACLE {
if statement.Start != 0 || pLimitN != nil {
oldString := buf.String()
buf.Reset()
@ -1158,7 +1163,7 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, n
}
}
if statement.IsForUpdate {
return dialect.ForUpdateSql(buf.String()), nil
return dialect.ForUpdateSQL(buf.String()), nil
}
return buf.String(), nil
@ -1183,7 +1188,7 @@ func (statement *Statement) processIDParam() error {
return nil
}
func (statement *Statement) joinColumns(cols []*core.Column, includeTableName bool) string {
func (statement *Statement) joinColumns(cols []*schemas.Column, includeTableName bool) string {
var colnames = make([]string, len(cols))
for i, col := range cols {
if includeTableName {
@ -1211,7 +1216,7 @@ func (statement *Statement) convertIDSQL(sqlStr string) string {
var top string
pLimitN := statement.LimitN
if pLimitN != nil && statement.Engine.dialect.DBType() == core.MSSQL {
if pLimitN != nil && statement.Engine.dialect.DBType() == schemas.MSSQL {
top = fmt.Sprintf("TOP %d ", *pLimitN)
}
@ -1240,9 +1245,9 @@ func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) {
// TODO: for postgres only, if any other database?
var paraStr string
if statement.Engine.dialect.DBType() == core.POSTGRES {
if statement.Engine.dialect.DBType() == schemas.POSTGRES {
paraStr = "$"
} else if statement.Engine.dialect.DBType() == core.MSSQL {
} else if statement.Engine.dialect.DBType() == schemas.MSSQL {
paraStr = ":"
}

View File

@ -11,7 +11,7 @@ import (
"time"
"xorm.io/builder"
"xorm.io/core"
"xorm.io/xorm/schemas"
)
func quoteNeeded(a interface{}) bool {
@ -80,7 +80,7 @@ const insertSelectPlaceHolder = true
func (statement *Statement) writeArg(w *builder.BytesWriter, arg interface{}) error {
switch argv := arg.(type) {
case bool:
if statement.Engine.dialect.DBType() == core.MSSQL {
if statement.Engine.dialect.DBType() == schemas.MSSQL {
if argv {
if _, err := w.WriteString("1"); err != nil {
return err
@ -119,7 +119,7 @@ func (statement *Statement) writeArg(w *builder.BytesWriter, arg interface{}) er
w.Append(arg)
} else {
var convertFunc = convertStringSingleQuote
if statement.Engine.dialect.DBType() == core.MYSQL {
if statement.Engine.dialect.DBType() == schemas.MYSQL {
convertFunc = convertString
}
if _, err := w.WriteString(convertArg(arg, convertFunc)); err != nil {

View File

@ -10,7 +10,7 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"xorm.io/core"
"xorm.io/xorm/schemas"
)
var colStrTests = []struct {
@ -42,7 +42,7 @@ func TestColumnsStringGeneration(t *testing.T) {
columns := statement.RefTable.Columns()
if testCase.onlyToDBColumnNdx >= 0 {
columns[testCase.onlyToDBColumnNdx].MapType = core.ONLYTODB
columns[testCase.onlyToDBColumnNdx].MapType = schemas.ONLYTODB
}
actual := statement.genColumnStr()
@ -51,7 +51,7 @@ func TestColumnsStringGeneration(t *testing.T) {
t.Errorf("[test #%d] Unexpected columns string:\nwant:\t%s\nhave:\t%s", ndx, testCase.expected, actual)
}
if testCase.onlyToDBColumnNdx >= 0 {
columns[testCase.onlyToDBColumnNdx].MapType = core.TWOSIDES
columns[testCase.onlyToDBColumnNdx].MapType = schemas.TWOSIDES
}
}
}
@ -69,7 +69,7 @@ func BenchmarkColumnsStringGeneration(b *testing.B) {
if testCase.onlyToDBColumnNdx >= 0 {
columns := statement.RefTable.Columns()
columns[testCase.onlyToDBColumnNdx].MapType = core.ONLYTODB // !nemec784! Column must be skipped
columns[testCase.onlyToDBColumnNdx].MapType = schemas.ONLYTODB // !nemec784! Column must be skipped
}
b.StartTimer()
@ -88,7 +88,7 @@ func BenchmarkGetFlagForColumnWithICKey_ContainsKey(b *testing.B) {
b.StopTimer()
mapCols := make(map[string]bool)
cols := []*core.Column{
cols := []*schemas.Column{
{Name: `ID`},
{Name: `IsDeleted`},
{Name: `Caption`},
@ -122,7 +122,7 @@ func BenchmarkGetFlagForColumnWithICKey_EmptyMap(b *testing.B) {
b.StopTimer()
mapCols := make(map[string]bool)
cols := []*core.Column{
cols := []*schemas.Column{
{Name: `ID`},
{Name: `IsDeleted`},
{Name: `Caption`},

43
tag.go
View File

@ -11,15 +11,36 @@ import (
"strings"
"time"
"xorm.io/core"
"xorm.io/xorm/schemas"
)
func splitTag(tag string) (tags []string) {
tag = strings.TrimSpace(tag)
var hasQuote = false
var lastIdx = 0
for i, t := range tag {
if t == '\'' {
hasQuote = !hasQuote
} else if t == ' ' {
if lastIdx < i && !hasQuote {
tags = append(tags, strings.TrimSpace(tag[lastIdx:i]))
lastIdx = i + 1
}
}
}
if lastIdx < len(tag) {
tags = append(tags, strings.TrimSpace(tag[lastIdx:]))
}
return
}
// tagContext represents a context for xorm tag parse.
type tagContext struct {
tagName string
params []string
preTag, nextTag string
table *core.Table
col *core.Column
table *schemas.Table
col *schemas.Column
fieldValue reflect.Value
isIndex bool
isUnique bool
@ -59,7 +80,7 @@ var (
)
func init() {
for k := range core.SqlTypes {
for k := range schemas.SqlTypes {
defaultTagHandlers[k] = SQLTypeTagHandler
}
}
@ -71,13 +92,13 @@ func IgnoreTagHandler(ctx *tagContext) error {
// OnlyFromDBTagHandler describes mapping direction tag handler
func OnlyFromDBTagHandler(ctx *tagContext) error {
ctx.col.MapType = core.ONLYFROMDB
ctx.col.MapType = schemas.ONLYFROMDB
return nil
}
// OnlyToDBTagHandler describes mapping direction tag handler
func OnlyToDBTagHandler(ctx *tagContext) error {
ctx.col.MapType = core.ONLYTODB
ctx.col.MapType = schemas.ONLYTODB
return nil
}
@ -177,7 +198,7 @@ func DeletedTagHandler(ctx *tagContext) error {
// IndexTagHandler describes index tag handler
func IndexTagHandler(ctx *tagContext) error {
if len(ctx.params) > 0 {
ctx.indexNames[ctx.params[0]] = core.IndexType
ctx.indexNames[ctx.params[0]] = schemas.IndexType
} else {
ctx.isIndex = true
}
@ -187,7 +208,7 @@ func IndexTagHandler(ctx *tagContext) error {
// UniqueTagHandler describes unique tag handler
func UniqueTagHandler(ctx *tagContext) error {
if len(ctx.params) > 0 {
ctx.indexNames[ctx.params[0]] = core.UniqueType
ctx.indexNames[ctx.params[0]] = schemas.UniqueType
} else {
ctx.isUnique = true
}
@ -204,16 +225,16 @@ func CommentTagHandler(ctx *tagContext) error {
// SQLTypeTagHandler describes SQL Type tag handler
func SQLTypeTagHandler(ctx *tagContext) error {
ctx.col.SQLType = core.SQLType{Name: ctx.tagName}
ctx.col.SQLType = schemas.SQLType{Name: ctx.tagName}
if len(ctx.params) > 0 {
if ctx.tagName == core.Enum {
if ctx.tagName == schemas.Enum {
ctx.col.EnumOptions = make(map[string]int)
for k, v := range ctx.params {
v = strings.TrimSpace(v)
v = strings.Trim(v, "'")
ctx.col.EnumOptions[v] = k
}
} else if ctx.tagName == core.Set {
} else if ctx.tagName == schemas.Set {
ctx.col.SetOptions = make(map[string]int)
for k, v := range ctx.params {
v = strings.TrimSpace(v)

View File

@ -11,7 +11,7 @@ import (
"time"
"github.com/stretchr/testify/assert"
"xorm.io/core"
"xorm.io/xorm/schemas"
)
type tempUser struct {
@ -269,7 +269,7 @@ func TestExtends2(t *testing.T) {
defer session.Close()
// MSSQL deny insert identity column excep declare as below
if testEngine.Dialect().DBType() == core.MSSQL {
if testEngine.Dialect().DBType() == schemas.MSSQL {
err = session.Begin()
assert.NoError(t, err)
_, err = session.Exec("SET IDENTITY_INSERT message ON")
@ -279,7 +279,7 @@ func TestExtends2(t *testing.T) {
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
if testEngine.Dialect().DBType() == core.MSSQL {
if testEngine.Dialect().DBType() == schemas.MSSQL {
err = session.Commit()
assert.NoError(t, err)
}
@ -339,7 +339,7 @@ func TestExtends3(t *testing.T) {
defer session.Close()
// MSSQL deny insert identity column excep declare as below
if testEngine.Dialect().DBType() == core.MSSQL {
if testEngine.Dialect().DBType() == schemas.MSSQL {
err = session.Begin()
assert.NoError(t, err)
_, err = session.Exec("SET IDENTITY_INSERT message ON")
@ -348,7 +348,7 @@ func TestExtends3(t *testing.T) {
_, err = session.Insert(&msg)
assert.NoError(t, err)
if testEngine.Dialect().DBType() == core.MSSQL {
if testEngine.Dialect().DBType() == schemas.MSSQL {
err = session.Commit()
assert.NoError(t, err)
}
@ -433,7 +433,7 @@ func TestExtends4(t *testing.T) {
defer session.Close()
// MSSQL deny insert identity column excep declare as below
if testEngine.Dialect().DBType() == core.MSSQL {
if testEngine.Dialect().DBType() == schemas.MSSQL {
err = session.Begin()
assert.NoError(t, err)
_, err = session.Exec("SET IDENTITY_INSERT message ON")
@ -442,7 +442,7 @@ func TestExtends4(t *testing.T) {
_, err = session.Insert(&msg)
assert.NoError(t, err)
if testEngine.Dialect().DBType() == core.MSSQL {
if testEngine.Dialect().DBType() == schemas.MSSQL {
err = session.Commit()
assert.NoError(t, err)
}

View File

@ -8,7 +8,7 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"xorm.io/core"
"xorm.io/xorm/names"
)
type IDGonicMapper struct {
@ -20,7 +20,7 @@ func TestGonicMapperID(t *testing.T) {
oldMapper := testEngine.GetColumnMapper()
testEngine.UnMapType(rValue(new(IDGonicMapper)).Type())
testEngine.SetMapper(core.LintGonicMapper)
testEngine.SetMapper(names.LintGonicMapper)
defer func() {
testEngine.UnMapType(rValue(new(IDGonicMapper)).Type())
testEngine.SetMapper(oldMapper)
@ -57,7 +57,7 @@ func TestSameMapperID(t *testing.T) {
oldMapper := testEngine.GetColumnMapper()
testEngine.UnMapType(rValue(new(IDSameMapper)).Type())
testEngine.SetMapper(core.SameMapper{})
testEngine.SetMapper(names.SameMapper{})
defer func() {
testEngine.UnMapType(rValue(new(IDSameMapper)).Type())
testEngine.SetMapper(oldMapper)

View File

@ -11,7 +11,7 @@ import (
"time"
"github.com/stretchr/testify/assert"
"xorm.io/core"
"xorm.io/xorm/schemas"
)
type UserCU struct {
@ -203,7 +203,7 @@ func TestAutoIncrTag(t *testing.T) {
func TestTagComment(t *testing.T) {
assert.NoError(t, prepareEngine())
// FIXME: only support mysql
if testEngine.Dialect().DriverName() != core.MYSQL {
if testEngine.Dialect().DriverName() != schemas.MYSQL {
return
}

View File

@ -7,10 +7,10 @@ package xorm
import (
"reflect"
"xorm.io/core"
"xorm.io/xorm/schemas"
)
var (
ptrPkType = reflect.TypeOf(&core.PK{})
pkType = reflect.TypeOf(core.PK{})
ptrPkType = reflect.TypeOf(&schemas.PK{})
pkType = reflect.TypeOf(schemas.PK{})
)

View File

@ -10,7 +10,7 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"xorm.io/core"
"xorm.io/xorm/schemas"
)
func TestArrayField(t *testing.T) {
@ -137,8 +137,8 @@ type ConvStruct struct {
Conv ConvString
Conv2 *ConvString
Cfg1 ConvConfig
Cfg2 *ConvConfig `xorm:"TEXT"`
Cfg3 core.Conversion `xorm:"BLOB"`
Cfg2 *ConvConfig `xorm:"TEXT"`
Cfg3 Conversion `xorm:"BLOB"`
Slice SliceType
}
@ -267,7 +267,7 @@ type Status struct {
}
var (
_ core.Conversion = &Status{}
_ Conversion = &Status{}
Registered Status = Status{"Registered", "white"}
Approved Status = Status{"Approved", "green"}
Removed Status = Status{"Removed", "red"}
@ -311,7 +311,7 @@ func TestCustomType2(t *testing.T) {
session := testEngine.NewSession()
defer session.Close()
if testEngine.Dialect().DBType() == core.MSSQL {
if testEngine.Dialect().DBType() == schemas.MSSQL {
err = session.Begin()
assert.NoError(t, err)
_, err = session.Exec("set IDENTITY_INSERT " + tableName + " on")
@ -322,7 +322,7 @@ func TestCustomType2(t *testing.T) {
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
if testEngine.Dialect().DBType() == core.MSSQL {
if testEngine.Dialect().DBType() == schemas.MSSQL {
err = session.Commit()
assert.NoError(t, err)
}

57
xorm.go
View File

@ -15,7 +15,12 @@ import (
"sync"
"time"
"xorm.io/core"
"xorm.io/xorm/caches"
"xorm.io/xorm/core"
"xorm.io/xorm/dialects"
"xorm.io/xorm/log"
"xorm.io/xorm/names"
"xorm.io/xorm/schemas"
)
const (
@ -23,44 +28,14 @@ const (
Version string = "0.8.0.1015"
)
func regDrvsNDialects() bool {
providedDrvsNDialects := map[string]struct {
dbType core.DbType
getDriver func() core.Driver
getDialect func() core.Dialect
}{
"mssql": {"mssql", func() core.Driver { return &odbcDriver{} }, func() core.Dialect { return &mssql{} }},
"odbc": {"mssql", func() core.Driver { return &odbcDriver{} }, func() core.Dialect { return &mssql{} }}, // !nashtsai! TODO change this when supporting MS Access
"mysql": {"mysql", func() core.Driver { return &mysqlDriver{} }, func() core.Dialect { return &mysql{} }},
"mymysql": {"mysql", func() core.Driver { return &mymysqlDriver{} }, func() core.Dialect { return &mysql{} }},
"postgres": {"postgres", func() core.Driver { return &pqDriver{} }, func() core.Dialect { return &postgres{} }},
"pgx": {"postgres", func() core.Driver { return &pqDriverPgx{} }, func() core.Dialect { return &postgres{} }},
"sqlite3": {"sqlite3", func() core.Driver { return &sqlite3Driver{} }, func() core.Dialect { return &sqlite3{} }},
"oci8": {"oracle", func() core.Driver { return &oci8Driver{} }, func() core.Dialect { return &oracle{} }},
"goracle": {"oracle", func() core.Driver { return &goracleDriver{} }, func() core.Dialect { return &oracle{} }},
}
for driverName, v := range providedDrvsNDialects {
if driver := core.QueryDriver(driverName); driver == nil {
core.RegisterDriver(driverName, v.getDriver())
core.RegisterDialect(v.dbType, v.getDialect)
}
}
return true
}
func close(engine *Engine) {
engine.Close()
}
func init() {
regDrvsNDialects()
}
// NewEngine new a db manager according to the parameter. Currently support four
// drivers
func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
driver := core.QueryDriver(driverName)
driver := dialects.QueryDriver(driverName)
if driver == nil {
return nil, fmt.Errorf("Unsupported driver name: %v", driverName)
}
@ -70,9 +45,9 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
return nil, err
}
dialect := core.QueryDialect(uri.DbType)
dialect := dialects.QueryDialect(uri.DBType)
if dialect == nil {
return nil, fmt.Errorf("Unsupported dialect type: %v", uri.DbType)
return nil, fmt.Errorf("Unsupported dialect type: %v", uri.DBType)
}
db, err := core.Open(driverName, dataSourceName)
@ -88,32 +63,32 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
engine := &Engine{
db: db,
dialect: dialect,
Tables: make(map[reflect.Type]*core.Table),
Tables: make(map[reflect.Type]*schemas.Table),
mutex: &sync.RWMutex{},
TagIdentifier: "xorm",
TZLocation: time.Local,
tagHandlers: defaultTagHandlers,
cachers: make(map[string]core.Cacher),
cachers: make(map[string]caches.Cacher),
defaultContext: context.Background(),
}
if uri.DbType == core.SQLITE {
if uri.DBType == schemas.SQLITE {
engine.DatabaseTZ = time.UTC
} else {
engine.DatabaseTZ = time.Local
}
logger := NewSimpleLogger(os.Stdout)
logger.SetLevel(core.LOG_INFO)
logger := log.NewSimpleLogger(os.Stdout)
logger.SetLevel(log.LOG_INFO)
engine.SetLogger(logger)
engine.SetMapper(core.NewCacheMapper(new(core.SnakeMapper)))
engine.SetMapper(names.NewCacheMapper(new(names.SnakeMapper)))
runtime.SetFinalizer(engine, close)
return engine, nil
}
// NewEngineWithParams new a db manager with params. The params will be passed to dialect.
// NewEngineWithParams new a db manager with params. The params will be passed to dialects.
func NewEngineWithParams(driverName string, dataSourceName string, params map[string]string) (*Engine, error) {
engine, err := NewEngine(driverName, dataSourceName)
engine.dialect.SetParams(params)

View File

@ -8,7 +8,6 @@ import (
"database/sql"
"flag"
"fmt"
"log"
"os"
"strings"
"testing"
@ -18,7 +17,10 @@ import (
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
_ "github.com/ziutek/mymysql/godrv"
"xorm.io/core"
"xorm.io/xorm/caches"
"xorm.io/xorm/log"
"xorm.io/xorm/names"
"xorm.io/xorm/schemas"
)
var (
@ -30,14 +32,14 @@ var (
showSQL = flag.Bool("show_sql", true, "show generated SQLs")
ptrConnStr = flag.String("conn_str", "./test.db?cache=shared&mode=rwc", "test database connection string")
mapType = flag.String("map_type", "snake", "indicate the name mapping")
cache = flag.Bool("cache", false, "if enable cache")
cacheFlag = flag.Bool("cache", false, "if enable cache")
cluster = flag.Bool("cluster", false, "if this is a cluster")
splitter = flag.String("splitter", ";", "the splitter on connstr for cluster")
schema = flag.String("schema", "", "specify the schema")
ignoreSelectUpdate = flag.Bool("ignore_select_update", false, "ignore select update if implementation difference, only for tidb")
tableMapper core.IMapper
colMapper core.IMapper
tableMapper names.Mapper
colMapper names.Mapper
)
func createEngine(dbType, connStr string) error {
@ -46,7 +48,7 @@ func createEngine(dbType, connStr string) error {
if !*cluster {
switch strings.ToLower(dbType) {
case core.MSSQL:
case schemas.MSSQL:
db, err := sql.Open(dbType, strings.Replace(connStr, "xorm_test", "master", -1))
if err != nil {
return err
@ -56,7 +58,7 @@ func createEngine(dbType, connStr string) error {
}
db.Close()
*ignoreSelectUpdate = true
case core.POSTGRES:
case schemas.POSTGRES:
db, err := sql.Open(dbType, connStr)
if err != nil {
return err
@ -79,7 +81,7 @@ func createEngine(dbType, connStr string) error {
}
db.Close()
*ignoreSelectUpdate = true
case core.MYSQL:
case schemas.MYSQL:
db, err := sql.Open(dbType, strings.Replace(connStr, "xorm_test", "mysql", -1))
if err != nil {
return err
@ -107,20 +109,20 @@ func createEngine(dbType, connStr string) error {
testEngine.SetSchema(*schema)
}
testEngine.ShowSQL(*showSQL)
testEngine.SetLogLevel(core.LOG_DEBUG)
if *cache {
cacher := NewLRUCacher(NewMemoryStore(), 100000)
testEngine.SetLogLevel(log.LOG_DEBUG)
if *cacheFlag {
cacher := caches.NewLRUCacher(caches.NewMemoryStore(), 100000)
testEngine.SetDefaultCacher(cacher)
}
if len(*mapType) > 0 {
switch *mapType {
case "snake":
testEngine.SetMapper(core.SnakeMapper{})
testEngine.SetMapper(names.SnakeMapper{})
case "same":
testEngine.SetMapper(core.SameMapper{})
testEngine.SetMapper(names.SameMapper{})
case "gonic":
testEngine.SetMapper(core.LintGonicMapper)
testEngine.SetMapper(names.LintGonicMapper)
}
}
}
@ -158,7 +160,7 @@ func TestMain(m *testing.M) {
}
} else {
if ptrConnStr == nil {
log.Fatal("you should indicate conn string")
fmt.Println("you should indicate conn string")
return
}
connString = *ptrConnStr
@ -175,7 +177,7 @@ func TestMain(m *testing.M) {
fmt.Println("testing", dbType, connString)
if err := prepareEngine(); err != nil {
log.Fatal(err)
fmt.Println(err)
return
}