Move some codes to statement sub package #1574
|
@ -22,8 +22,10 @@ steps:
|
||||||
commands:
|
commands:
|
||||||
- make test-sqlite
|
- make test-sqlite
|
||||||
- TEST_CACHE_ENABLE=true make test-sqlite
|
- TEST_CACHE_ENABLE=true make test-sqlite
|
||||||
- go test ./caches/... ./convert/... ./core/... ./dialects/... \
|
- go test ./caches/... ./contexts/... ./convert/... ./core/... ./dialects/... \
|
||||||
|
./internal/json/... ./internal/statements/... ./internal/utils/... \
|
||||||
./log/... ./migrate/... ./names/... ./schemas/... ./tags/...
|
./log/... ./migrate/... ./names/... ./schemas/... ./tags/...
|
||||||
|
|
||||||
when:
|
when:
|
||||||
event:
|
event:
|
||||||
- push
|
- push
|
||||||
|
|
|
@ -31,7 +31,7 @@ type URI struct {
|
||||||
|
|
||||||
// Dialect represents a kind of database
|
// Dialect represents a kind of database
|
||||||
type Dialect interface {
|
type Dialect interface {
|
||||||
Init(*core.DB, *URI, string, string) error
|
Init(*core.DB, *URI) error
|
||||||
URI() *URI
|
URI() *URI
|
||||||
DB() *core.DB
|
DB() *core.DB
|
||||||
DBType() schemas.DBType
|
DBType() schemas.DBType
|
||||||
|
@ -39,9 +39,6 @@ type Dialect interface {
|
||||||
FormatBytes(b []byte) string
|
FormatBytes(b []byte) string
|
||||||
DefaultSchema() string
|
DefaultSchema() string
|
||||||
|
|
||||||
DriverName() string
|
|
||||||
DataSourceName() string
|
|
||||||
|
|
||||||
IsReserved(string) bool
|
IsReserved(string) bool
|
||||||
Quoter() schemas.Quoter
|
Quoter() schemas.Quoter
|
||||||
|
|
||||||
|
@ -77,17 +74,11 @@ type Dialect interface {
|
||||||
SetParams(params map[string]string)
|
SetParams(params map[string]string)
|
||||||
}
|
}
|
||||||
|
|
||||||
func OpenDialect(dialect Dialect) (*core.DB, error) {
|
|
||||||
return core.Open(dialect.DriverName(), dialect.DataSourceName())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Base represents a basic dialect and all real dialects could embed this struct
|
// Base represents a basic dialect and all real dialects could embed this struct
|
||||||
type Base struct {
|
type Base struct {
|
||||||
db *core.DB
|
db *core.DB
|
||||||
dialect Dialect
|
dialect Dialect
|
||||||
driverName string
|
uri *URI
|
||||||
dataSourceName string
|
|
||||||
uri *URI
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Base) DB() *core.DB {
|
func (b *Base) DB() *core.DB {
|
||||||
|
@ -98,9 +89,8 @@ func (b *Base) DefaultSchema() string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Base) Init(db *core.DB, dialect Dialect, uri *URI, drivername, dataSourceName string) error {
|
func (b *Base) Init(db *core.DB, dialect Dialect, uri *URI) error {
|
||||||
b.db, b.dialect, b.uri = db, dialect, uri
|
b.db, b.dialect, b.uri = db, dialect, uri
|
||||||
b.driverName, b.dataSourceName = drivername, dataSourceName
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -165,18 +155,10 @@ func (b *Base) FormatBytes(bs []byte) string {
|
||||||
return fmt.Sprintf("0x%x", bs)
|
return fmt.Sprintf("0x%x", bs)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Base) DriverName() string {
|
|
||||||
return b.driverName
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Base) ShowCreateNull() bool {
|
func (b *Base) ShowCreateNull() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Base) DataSourceName() string {
|
|
||||||
return b.dataSourceName
|
|
||||||
}
|
|
||||||
|
|
||||||
func (db *Base) SupportDropIfExists() bool {
|
func (db *Base) SupportDropIfExists() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,6 +4,12 @@
|
||||||
|
|
||||||
package dialects
|
package dialects
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"xorm.io/xorm/core"
|
||||||
|
)
|
||||||
|
|
||||||
type Driver interface {
|
type Driver interface {
|
||||||
Parse(string, string) (*URI, error)
|
Parse(string, string) (*URI, error)
|
||||||
}
|
}
|
||||||
|
@ -29,3 +35,29 @@ func QueryDriver(driverName string) Driver {
|
||||||
func RegisteredDriverSize() int {
|
func RegisteredDriverSize() int {
|
||||||
return len(drivers)
|
return len(drivers)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OpenDialect opens a dialect via driver name and connection string
|
||||||
|
func OpenDialect(driverName, connstr string) (Dialect, error) {
|
||||||
|
driver := QueryDriver(driverName)
|
||||||
|
if driver == nil {
|
||||||
|
return nil, fmt.Errorf("Unsupported driver name: %v", driverName)
|
||||||
|
}
|
||||||
|
|
||||||
|
uri, err := driver.Parse(driverName, connstr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
dialect := QueryDialect(uri.DBType)
|
||||||
|
if dialect == nil {
|
||||||
|
return nil, fmt.Errorf("Unsupported dialect type: %v", uri.DBType)
|
||||||
|
}
|
||||||
|
|
||||||
|
db, err := core.Open(driverName, connstr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
dialect.Init(db, uri)
|
||||||
|
|
||||||
|
return dialect, nil
|
||||||
|
}
|
||||||
|
|
|
@ -210,8 +210,8 @@ type mssql struct {
|
||||||
Base
|
Base
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *mssql) Init(d *core.DB, uri *URI, drivername, dataSourceName string) error {
|
func (db *mssql) Init(d *core.DB, uri *URI) error {
|
||||||
return db.Base.Init(d, db, uri, drivername, dataSourceName)
|
return db.Base.Init(d, db, uri)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *mssql) SQLType(c *schemas.Column) string {
|
func (db *mssql) SQLType(c *schemas.Column) string {
|
||||||
|
|
|
@ -177,8 +177,8 @@ type mysql struct {
|
||||||
rowFormat string
|
rowFormat string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *mysql) Init(d *core.DB, uri *URI, drivername, dataSourceName string) error {
|
func (db *mysql) Init(d *core.DB, uri *URI) error {
|
||||||
return db.Base.Init(d, db, uri, drivername, dataSourceName)
|
return db.Base.Init(d, db, uri)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *mysql) SetParams(params map[string]string) {
|
func (db *mysql) SetParams(params map[string]string) {
|
||||||
|
|
|
@ -504,8 +504,8 @@ type oracle struct {
|
||||||
Base
|
Base
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *oracle) Init(d *core.DB, uri *URI, drivername, dataSourceName string) error {
|
func (db *oracle) Init(d *core.DB, uri *URI) error {
|
||||||
return db.Base.Init(d, db, uri, drivername, dataSourceName)
|
return db.Base.Init(d, db, uri)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *oracle) SQLType(c *schemas.Column) string {
|
func (db *oracle) SQLType(c *schemas.Column) string {
|
||||||
|
|
|
@ -766,30 +766,27 @@ var (
|
||||||
"YES": true,
|
"YES": true,
|
||||||
"ZONE": true,
|
"ZONE": true,
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultPostgresSchema default postgres schema
|
|
||||||
DefaultPostgresSchema = "public"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const PostgresPublicSchema = "public"
|
const postgresPublicSchema = "public"
|
||||||
|
|
||||||
type postgres struct {
|
type postgres struct {
|
||||||
Base
|
Base
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *postgres) Init(d *core.DB, uri *URI, drivername, dataSourceName string) error {
|
func (db *postgres) Init(d *core.DB, uri *URI) error {
|
||||||
err := db.Base.Init(d, db, uri, drivername, dataSourceName)
|
err := db.Base.Init(d, db, uri)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if db.uri.Schema == "" {
|
if db.uri.Schema == "" {
|
||||||
db.uri.Schema = DefaultPostgresSchema
|
db.uri.Schema = postgresPublicSchema
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *postgres) DefaultSchema() string {
|
func (db *postgres) DefaultSchema() string {
|
||||||
return PostgresPublicSchema
|
return postgresPublicSchema
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *postgres) SQLType(c *schemas.Column) string {
|
func (db *postgres) SQLType(c *schemas.Column) string {
|
||||||
|
|
|
@ -149,8 +149,8 @@ type sqlite3 struct {
|
||||||
Base
|
Base
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *sqlite3) Init(d *core.DB, uri *URI, drivername, dataSourceName string) error {
|
func (db *sqlite3) Init(d *core.DB, uri *URI) error {
|
||||||
return db.Base.Init(d, db, uri, drivername, dataSourceName)
|
return db.Base.Init(d, db, uri)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *sqlite3) SQLType(c *schemas.Column) string {
|
func (db *sqlite3) SQLType(c *schemas.Column) string {
|
||||||
|
|
29
engine.go
29
engine.go
|
@ -39,6 +39,9 @@ type Engine struct {
|
||||||
logger log.ContextLogger
|
logger log.ContextLogger
|
||||||
tagParser *tags.Parser
|
tagParser *tags.Parser
|
||||||
|
|
||||||
|
driverName string
|
||||||
|
dataSourceName string
|
||||||
|
|
||||||
TZLocation *time.Location // The timezone of the application
|
TZLocation *time.Location // The timezone of the application
|
||||||
DatabaseTZ *time.Location // The timezone of the database
|
DatabaseTZ *time.Location // The timezone of the database
|
||||||
}
|
}
|
||||||
|
@ -61,7 +64,7 @@ func (engine *Engine) BufferSize(size int) *Session {
|
||||||
// ShowSQL show SQL statement or not on logger if log level is great than INFO
|
// ShowSQL show SQL statement or not on logger if log level is great than INFO
|
||||||
func (engine *Engine) ShowSQL(show ...bool) {
|
func (engine *Engine) ShowSQL(show ...bool) {
|
||||||
engine.logger.ShowSQL(show...)
|
engine.logger.ShowSQL(show...)
|
||||||
engine.db.Logger = engine.logger
|
engine.DB().Logger = engine.logger
|
||||||
}
|
}
|
||||||
|
|
||||||
// Logger return the logger interface
|
// Logger return the logger interface
|
||||||
|
@ -79,7 +82,7 @@ func (engine *Engine) SetLogger(logger interface{}) {
|
||||||
realLogger = t
|
realLogger = t
|
||||||
}
|
}
|
||||||
engine.logger = realLogger
|
engine.logger = realLogger
|
||||||
engine.db.Logger = realLogger
|
engine.DB().Logger = realLogger
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetLogLevel sets the logger level
|
// SetLogLevel sets the logger level
|
||||||
|
@ -94,12 +97,12 @@ func (engine *Engine) SetDisableGlobalCache(disable bool) {
|
||||||
|
|
||||||
// DriverName return the current sql driver's name
|
// DriverName return the current sql driver's name
|
||||||
func (engine *Engine) DriverName() string {
|
func (engine *Engine) DriverName() string {
|
||||||
return engine.dialect.DriverName()
|
return engine.driverName
|
||||||
}
|
}
|
||||||
|
|
||||||
// DataSourceName return the current connection string
|
// DataSourceName return the current connection string
|
||||||
func (engine *Engine) DataSourceName() string {
|
func (engine *Engine) DataSourceName() string {
|
||||||
return engine.dialect.DataSourceName()
|
return engine.dataSourceName
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetMapper set the name mapping rules
|
// SetMapper set the name mapping rules
|
||||||
|
@ -164,17 +167,17 @@ func (engine *Engine) AutoIncrStr() string {
|
||||||
|
|
||||||
// SetConnMaxLifetime sets the maximum amount of time a connection may be reused.
|
// SetConnMaxLifetime sets the maximum amount of time a connection may be reused.
|
||||||
func (engine *Engine) SetConnMaxLifetime(d time.Duration) {
|
func (engine *Engine) SetConnMaxLifetime(d time.Duration) {
|
||||||
engine.db.SetConnMaxLifetime(d)
|
engine.DB().SetConnMaxLifetime(d)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetMaxOpenConns is only available for go 1.2+
|
// SetMaxOpenConns is only available for go 1.2+
|
||||||
func (engine *Engine) SetMaxOpenConns(conns int) {
|
func (engine *Engine) SetMaxOpenConns(conns int) {
|
||||||
engine.db.SetMaxOpenConns(conns)
|
engine.DB().SetMaxOpenConns(conns)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetMaxIdleConns set the max idle connections on pool, default is 2
|
// SetMaxIdleConns set the max idle connections on pool, default is 2
|
||||||
func (engine *Engine) SetMaxIdleConns(conns int) {
|
func (engine *Engine) SetMaxIdleConns(conns int) {
|
||||||
engine.db.SetMaxIdleConns(conns)
|
engine.DB().SetMaxIdleConns(conns)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetDefaultCacher set the default cacher. Xorm's default not enable cacher.
|
// SetDefaultCacher set the default cacher. Xorm's default not enable cacher.
|
||||||
|
@ -210,12 +213,12 @@ func (engine *Engine) MapCacher(bean interface{}, cacher caches.Cacher) error {
|
||||||
|
|
||||||
// NewDB provides an interface to operate database directly
|
// NewDB provides an interface to operate database directly
|
||||||
func (engine *Engine) NewDB() (*core.DB, error) {
|
func (engine *Engine) NewDB() (*core.DB, error) {
|
||||||
return dialects.OpenDialect(engine.dialect)
|
return core.Open(engine.driverName, engine.dataSourceName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DB return the wrapper of sql.DB
|
// DB return the wrapper of sql.DB
|
||||||
func (engine *Engine) DB() *core.DB {
|
func (engine *Engine) DB() *core.DB {
|
||||||
return engine.db
|
return engine.dialect.DB()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Dialect return database dialect
|
// Dialect return database dialect
|
||||||
|
@ -232,7 +235,7 @@ func (engine *Engine) NewSession() *Session {
|
||||||
|
|
||||||
// Close the engine
|
// Close the engine
|
||||||
func (engine *Engine) Close() error {
|
func (engine *Engine) Close() error {
|
||||||
return engine.db.Close()
|
return engine.DB().Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ping tests if database is alive
|
// Ping tests if database is alive
|
||||||
|
@ -364,7 +367,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
|
||||||
if dialect == nil {
|
if dialect == nil {
|
||||||
return errors.New("Unsupported database type")
|
return errors.New("Unsupported database type")
|
||||||
}
|
}
|
||||||
dialect.Init(nil, engine.dialect.URI(), "", "")
|
dialect.Init(nil, engine.dialect.URI())
|
||||||
distDBName = string(tp[0])
|
distDBName = string(tp[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1211,10 +1214,6 @@ func (engine *Engine) nowTime(col *schemas.Column) (interface{}, time.Time) {
|
||||||
return dialects.FormatTime(engine.dialect, col.SQLType.Name, t.In(tz)), t.In(engine.TZLocation)
|
return dialects.FormatTime(engine.dialect, col.SQLType.Name, t.In(tz)), t.In(engine.TZLocation)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (engine *Engine) formatColTime(col *schemas.Column, t time.Time) (v interface{}) {
|
|
||||||
return dialects.FormatColumnTime(engine.dialect, engine.DatabaseTZ, col, t)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetColumnMapper returns the column name mapper
|
// GetColumnMapper returns the column name mapper
|
||||||
func (engine *Engine) GetColumnMapper() names.Mapper {
|
func (engine *Engine) GetColumnMapper() names.Mapper {
|
||||||
return engine.tagParser.GetColumnMapper()
|
return engine.tagParser.GetColumnMapper()
|
||||||
|
|
4
error.go
4
error.go
|
@ -20,10 +20,6 @@ var (
|
||||||
ErrNotExist = errors.New("Record does not exist")
|
ErrNotExist = errors.New("Record does not exist")
|
||||||
// ErrCacheFailed cache failed error
|
// ErrCacheFailed cache failed error
|
||||||
ErrCacheFailed = errors.New("Cache failed")
|
ErrCacheFailed = errors.New("Cache failed")
|
||||||
// ErrNeedDeletedCond delete needs less one condition error
|
|
||||||
ErrNeedDeletedCond = errors.New("Delete action needs at least one condition")
|
|
||||||
// ErrNotImplemented not implemented
|
|
||||||
ErrNotImplemented = errors.New("Not implemented")
|
|
||||||
// ErrConditionType condition type unsupported
|
// ErrConditionType condition type unsupported
|
||||||
ErrConditionType = errors.New("Unsupported condition type")
|
ErrConditionType = errors.New("Unsupported condition type")
|
||||||
)
|
)
|
||||||
|
|
|
@ -82,6 +82,7 @@ type EngineInterface interface {
|
||||||
CreateTables(...interface{}) error
|
CreateTables(...interface{}) error
|
||||||
DBMetas() ([]*schemas.Table, error)
|
DBMetas() ([]*schemas.Table, error)
|
||||||
Dialect() dialects.Dialect
|
Dialect() dialects.Dialect
|
||||||
|
DriverName() string
|
||||||
DropTables(...interface{}) error
|
DropTables(...interface{}) error
|
||||||
DumpAllToFile(fp string, tp ...schemas.DBType) error
|
DumpAllToFile(fp string, tp ...schemas.DBType) error
|
||||||
GetCacher(string) caches.Cacher
|
GetCacher(string) caches.Cacher
|
||||||
|
|
|
@ -57,16 +57,12 @@ func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []int
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
condSQL, condArgs, err := builder.ToSQL(statement.cond)
|
sqlStr, condArgs, err := statement.genSelectSQL(columnStr, true, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
args := append(statement.joinArgs, condArgs...)
|
args := append(statement.joinArgs, condArgs...)
|
||||||
sqlStr, err := statement.GenSelectSQL(columnStr, condSQL, true, true)
|
|
||||||
if err != nil {
|
|
||||||
return "", nil, err
|
|
||||||
}
|
|
||||||
// for mssql and use limit
|
// for mssql and use limit
|
||||||
qs := strings.Count(sqlStr, "?")
|
qs := strings.Count(sqlStr, "?")
|
||||||
if len(args)*2 == qs {
|
if len(args)*2 == qs {
|
||||||
|
@ -92,12 +88,11 @@ func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (stri
|
||||||
}
|
}
|
||||||
sumSelect := strings.Join(sumStrs, ", ")
|
sumSelect := strings.Join(sumStrs, ", ")
|
||||||
|
|
||||||
condSQL, condArgs, err := statement.GenConds(bean)
|
if err := statement.mergeConds(bean); err != nil {
|
||||||
if err != nil {
|
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
sqlStr, err := statement.GenSelectSQL(sumSelect, condSQL, true, true)
|
sqlStr, condArgs, err := statement.genSelectSQL(sumSelect, true, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
@ -147,12 +142,8 @@ func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{},
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
condSQL, condArgs, err := builder.ToSQL(statement.cond)
|
|
||||||
if err != nil {
|
|
||||||
return "", nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
sqlStr, err := statement.GenSelectSQL(columnStr, condSQL, true, true)
|
sqlStr, condArgs, err := statement.genSelectSQL(columnStr, true, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
@ -165,17 +156,13 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa
|
||||||
return statement.RawSQL, statement.RawParams, nil
|
return statement.RawSQL, statement.RawParams, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var condSQL string
|
|
||||||
var condArgs []interface{}
|
var condArgs []interface{}
|
||||||
var err error
|
var err error
|
||||||
if len(beans) > 0 {
|
if len(beans) > 0 {
|
||||||
statement.SetRefBean(beans[0])
|
statement.SetRefBean(beans[0])
|
||||||
condSQL, condArgs, err = statement.GenConds(beans[0])
|
if err := statement.mergeConds(beans[0]); err != nil {
|
||||||
} else {
|
return "", nil, err
|
||||||
condSQL, condArgs, err = builder.ToSQL(statement.cond)
|
}
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return "", nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var selectSQL = statement.SelectStr
|
var selectSQL = statement.SelectStr
|
||||||
|
@ -186,7 +173,7 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa
|
||||||
selectSQL = "count(*)"
|
selectSQL = "count(*)"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
sqlStr, err := statement.GenSelectSQL(selectSQL, condSQL, false, false)
|
sqlStr, condArgs, err := statement.genSelectSQL(selectSQL, false, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
@ -194,7 +181,7 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa
|
||||||
return sqlStr, append(statement.joinArgs, condArgs...), nil
|
return sqlStr, append(statement.joinArgs, condArgs...), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (statement *Statement) GenSelectSQL(columnStr, condSQL string, needLimit, needOrderBy bool) (string, error) {
|
func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderBy bool) (string, []interface{}, error) {
|
||||||
var (
|
var (
|
||||||
distinct string
|
distinct string
|
||||||
dialect = statement.dialect
|
dialect = statement.dialect
|
||||||
|
@ -205,6 +192,11 @@ func (statement *Statement) GenSelectSQL(columnStr, condSQL string, needLimit, n
|
||||||
if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") {
|
if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") {
|
||||||
distinct = "DISTINCT "
|
distinct = "DISTINCT "
|
||||||
}
|
}
|
||||||
|
|
||||||
|
condSQL, condArgs, err := builder.ToSQL(statement.cond)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
if len(condSQL) > 0 {
|
if len(condSQL) > 0 {
|
||||||
whereStr = " WHERE " + condSQL
|
whereStr = " WHERE " + condSQL
|
||||||
}
|
}
|
||||||
|
@ -313,10 +305,10 @@ func (statement *Statement) GenSelectSQL(columnStr, condSQL string, needLimit, n
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if statement.IsForUpdate {
|
if statement.IsForUpdate {
|
||||||
return dialect.ForUpdateSQL(buf.String()), nil
|
return dialect.ForUpdateSQL(buf.String()), condArgs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return buf.String(), nil
|
return buf.String(), condArgs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interface{}, error) {
|
func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interface{}, error) {
|
||||||
|
@ -428,16 +420,12 @@ func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interfa
|
||||||
}
|
}
|
||||||
|
|
||||||
statement.cond = statement.cond.And(autoCond)
|
statement.cond = statement.cond.And(autoCond)
|
||||||
condSQL, condArgs, err := builder.ToSQL(statement.cond)
|
|
||||||
if err != nil {
|
|
||||||
return "", nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
args = append(statement.joinArgs, condArgs...)
|
sqlStr, condArgs, err := statement.genSelectSQL(columnStr, true, true)
|
||||||
sqlStr, err = statement.GenSelectSQL(columnStr, condSQL, true, true)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
args = append(statement.joinArgs, condArgs...)
|
||||||
// for mssql and use limit
|
// for mssql and use limit
|
||||||
qs := strings.Count(sqlStr, "?")
|
qs := strings.Count(sqlStr, "?")
|
||||||
if len(args)*2 == qs {
|
if len(args)*2 == qs {
|
||||||
|
|
|
@ -8,10 +8,37 @@ import (
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"xorm.io/xorm/caches"
|
||||||
|
"xorm.io/xorm/dialects"
|
||||||
|
"xorm.io/xorm/names"
|
||||||
"xorm.io/xorm/schemas"
|
"xorm.io/xorm/schemas"
|
||||||
|
"xorm.io/xorm/tags"
|
||||||
|
|
||||||
|
_ "github.com/mattn/go-sqlite3"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
dialect dialects.Dialect
|
||||||
|
tagParser *tags.Parser
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMain(m *testing.M) {
|
||||||
|
var err error
|
||||||
|
dialect, err = dialects.OpenDialect("sqlite3", "./test.db")
|
||||||
|
if err != nil {
|
||||||
|
panic("unknow dialect")
|
||||||
|
}
|
||||||
|
|
||||||
|
tagParser = tags.NewParser("xorm", dialect, names.SnakeMapper{}, names.SnakeMapper{}, caches.NewManager())
|
||||||
|
if tagParser == nil {
|
||||||
|
panic("tags parser is nil")
|
||||||
|
}
|
||||||
|
m.Run()
|
||||||
|
}
|
||||||
|
|
||||||
var colStrTests = []struct {
|
var colStrTests = []struct {
|
||||||
omitColumn string
|
omitColumn string
|
||||||
onlyToDBColumnNdx int
|
onlyToDBColumnNdx int
|
||||||
|
@ -26,14 +53,9 @@ var colStrTests = []struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestColumnsStringGeneration(t *testing.T) {
|
func TestColumnsStringGeneration(t *testing.T) {
|
||||||
if dbType == "postgres" || dbType == "mssql" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var statement *Statement
|
|
||||||
|
|
||||||
for ndx, testCase := range colStrTests {
|
for ndx, testCase := range colStrTests {
|
||||||
statement = createTestStatement()
|
statement, err := createTestStatement()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
if testCase.omitColumn != "" {
|
if testCase.omitColumn != "" {
|
||||||
statement.Omit(testCase.omitColumn)
|
statement.Omit(testCase.omitColumn)
|
||||||
|
@ -55,33 +77,6 @@ func TestColumnsStringGeneration(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkColumnsStringGeneration(b *testing.B) {
|
|
||||||
b.StopTimer()
|
|
||||||
|
|
||||||
statement := createTestStatement()
|
|
||||||
|
|
||||||
testCase := colStrTests[0]
|
|
||||||
|
|
||||||
if testCase.omitColumn != "" {
|
|
||||||
statement.Omit(testCase.omitColumn) // !nemec784! Column must be skipped
|
|
||||||
}
|
|
||||||
|
|
||||||
if testCase.onlyToDBColumnNdx >= 0 {
|
|
||||||
columns := statement.RefTable.Columns()
|
|
||||||
columns[testCase.onlyToDBColumnNdx].MapType = schemas.ONLYTODB // !nemec784! Column must be skipped
|
|
||||||
}
|
|
||||||
|
|
||||||
b.StartTimer()
|
|
||||||
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
actual := statement.genColumnStr()
|
|
||||||
|
|
||||||
if actual != testCase.expected {
|
|
||||||
b.Errorf("Unexpected columns string:\nwant:\t%s\nhave:\t%s", testCase.expected, actual)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkGetFlagForColumnWithICKey_ContainsKey(b *testing.B) {
|
func BenchmarkGetFlagForColumnWithICKey_ContainsKey(b *testing.B) {
|
||||||
|
|
||||||
b.StopTimer()
|
b.StopTimer()
|
||||||
|
@ -162,23 +157,40 @@ func (TestType) TableName() string {
|
||||||
return "TestTable"
|
return "TestTable"
|
||||||
}
|
}
|
||||||
|
|
||||||
func createTestStatement() *Statement {
|
func createTestStatement() (*Statement, error) {
|
||||||
if engine, ok := testEngine.(*Engine); ok {
|
statement := NewStatement(dialect, tagParser, time.Local)
|
||||||
statement := &Statement{}
|
if err := statement.SetRefValue(reflect.ValueOf(TestType{})); err != nil {
|
||||||
statement.Reset()
|
return nil, err
|
||||||
statement.Engine = engine
|
}
|
||||||
statement.dialect = engine.dialect
|
return statement, nil
|
||||||
statement.SetRefValue(reflect.ValueOf(TestType{}))
|
}
|
||||||
|
|
||||||
return statement
|
func BenchmarkColumnsStringGeneration(b *testing.B) {
|
||||||
} else if eg, ok := testEngine.(*EngineGroup); ok {
|
b.StopTimer()
|
||||||
statement := &Statement{}
|
|
||||||
statement.Reset()
|
statement, err := createTestStatement()
|
||||||
statement.Engine = eg.Engine
|
if err != nil {
|
||||||
statement.dialect = eg.Engine.dialect
|
panic(err)
|
||||||
statement.SetRefValue(reflect.ValueOf(TestType{}))
|
}
|
||||||
|
|
||||||
return statement
|
testCase := colStrTests[0]
|
||||||
|
|
||||||
|
if testCase.omitColumn != "" {
|
||||||
|
statement.Omit(testCase.omitColumn) // !nemec784! Column must be skipped
|
||||||
|
}
|
||||||
|
|
||||||
|
if testCase.onlyToDBColumnNdx >= 0 {
|
||||||
|
columns := statement.RefTable.Columns()
|
||||||
|
columns[testCase.onlyToDBColumnNdx].MapType = schemas.ONLYTODB // !nemec784! Column must be skipped
|
||||||
|
}
|
||||||
|
|
||||||
|
b.StartTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
actual := statement.genColumnStr()
|
||||||
|
|
||||||
|
if actual != testCase.expected {
|
||||||
|
b.Errorf("Unexpected columns string:\nwant:\t%s\nhave:\t%s", testCase.expected, actual)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -284,7 +284,7 @@ func (session *Session) Having(conditions string) *Session {
|
||||||
// DB db return the wrapper of sql.DB
|
// DB db return the wrapper of sql.DB
|
||||||
func (session *Session) DB() *core.DB {
|
func (session *Session) DB() *core.DB {
|
||||||
if session.db == nil {
|
if session.db == nil {
|
||||||
session.db = session.engine.db
|
session.db = session.engine.DB()
|
||||||
session.stmtCache = make(map[uint32]*core.Stmt, 0)
|
session.stmtCache = make(map[uint32]*core.Stmt, 0)
|
||||||
}
|
}
|
||||||
return session.db
|
return session.db
|
||||||
|
|
|
@ -15,6 +15,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"xorm.io/xorm/convert"
|
"xorm.io/xorm/convert"
|
||||||
|
"xorm.io/xorm/dialects"
|
||||||
"xorm.io/xorm/internal/json"
|
"xorm.io/xorm/internal/json"
|
||||||
"xorm.io/xorm/internal/utils"
|
"xorm.io/xorm/internal/utils"
|
||||||
"xorm.io/xorm/schemas"
|
"xorm.io/xorm/schemas"
|
||||||
|
@ -583,7 +584,7 @@ func (session *Session) value2Interface(col *schemas.Column, fieldValue reflect.
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
if fieldType.ConvertibleTo(schemas.TimeType) {
|
if fieldType.ConvertibleTo(schemas.TimeType) {
|
||||||
t := fieldValue.Convert(schemas.TimeType).Interface().(time.Time)
|
t := fieldValue.Convert(schemas.TimeType).Interface().(time.Time)
|
||||||
tf := session.engine.formatColTime(col, t)
|
tf := dialects.FormatColumnTime(session.engine.dialect, session.engine.DatabaseTZ, col, t)
|
||||||
return tf, nil
|
return tf, nil
|
||||||
} else if fieldType.ConvertibleTo(nullFloatType) {
|
} else if fieldType.ConvertibleTo(nullFloatType) {
|
||||||
t := fieldValue.Convert(nullFloatType).Interface().(sql.NullFloat64)
|
t := fieldValue.Convert(nullFloatType).Interface().(sql.NullFloat64)
|
||||||
|
|
|
@ -13,6 +13,14 @@ import (
|
||||||
"xorm.io/xorm/schemas"
|
"xorm.io/xorm/schemas"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrNeedDeletedCond delete needs less one condition error
|
||||||
|
ErrNeedDeletedCond = errors.New("Delete action needs at least one condition")
|
||||||
|
|
||||||
|
// ErrNotImplemented not implemented
|
||||||
|
ErrNotImplemented = errors.New("Not implemented")
|
||||||
|
)
|
||||||
|
|
||||||
func (session *Session) cacheDelete(table *schemas.Table, tableName, sqlStr string, args ...interface{}) error {
|
func (session *Session) cacheDelete(table *schemas.Table, tableName, sqlStr string, args ...interface{}) error {
|
||||||
if table == nil ||
|
if table == nil ||
|
||||||
session.tx != nil {
|
session.tx != nil {
|
||||||
|
|
|
@ -179,7 +179,7 @@ func TestGetVar(t *testing.T) {
|
||||||
assert.Equal(t, "1.5", valuesString["money"])
|
assert.Equal(t, "1.5", valuesString["money"])
|
||||||
|
|
||||||
// for mymysql driver, interface{} will be []byte, so ignore it currently
|
// for mymysql driver, interface{} will be []byte, so ignore it currently
|
||||||
if testEngine.Dialect().DriverName() != "mymysql" {
|
if testEngine.DriverName() != "mymysql" {
|
||||||
var valuesInter = make(map[string]interface{})
|
var valuesInter = make(map[string]interface{})
|
||||||
has, err = testEngine.Table("get_var").Where("id = ?", 1).Select("*").Get(&valuesInter)
|
has, err = testEngine.Table("get_var").Where("id = ?", 1).Select("*").Get(&valuesInter)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
|
@ -34,7 +34,7 @@ func (session *Session) Rollback() error {
|
||||||
session.isAutoCommit = true
|
session.isAutoCommit = true
|
||||||
|
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
needSQL := session.engine.db.NeedLogSQL(session.ctx)
|
needSQL := session.DB().NeedLogSQL(session.ctx)
|
||||||
if needSQL {
|
if needSQL {
|
||||||
session.engine.logger.BeforeSQL(log.LogContext{
|
session.engine.logger.BeforeSQL(log.LogContext{
|
||||||
Ctx: session.ctx,
|
Ctx: session.ctx,
|
||||||
|
@ -63,7 +63,7 @@ func (session *Session) Commit() error {
|
||||||
session.isAutoCommit = true
|
session.isAutoCommit = true
|
||||||
|
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
needSQL := session.engine.db.NeedLogSQL(session.ctx)
|
needSQL := session.DB().NeedLogSQL(session.ctx)
|
||||||
if needSQL {
|
if needSQL {
|
||||||
session.engine.logger.BeforeSQL(log.LogContext{
|
session.engine.logger.BeforeSQL(log.LogContext{
|
||||||
Ctx: session.ctx,
|
Ctx: session.ctx,
|
||||||
|
|
29
xorm.go
29
xorm.go
|
@ -8,13 +8,11 @@ package xorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"xorm.io/xorm/caches"
|
"xorm.io/xorm/caches"
|
||||||
"xorm.io/xorm/core"
|
|
||||||
"xorm.io/xorm/dialects"
|
"xorm.io/xorm/dialects"
|
||||||
"xorm.io/xorm/log"
|
"xorm.io/xorm/log"
|
||||||
"xorm.io/xorm/names"
|
"xorm.io/xorm/names"
|
||||||
|
@ -34,27 +32,7 @@ func close(engine *Engine) {
|
||||||
// NewEngine new a db manager according to the parameter. Currently support four
|
// NewEngine new a db manager according to the parameter. Currently support four
|
||||||
// drivers
|
// drivers
|
||||||
func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
|
func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
|
||||||
driver := dialects.QueryDriver(driverName)
|
dialect, err := dialects.OpenDialect(driverName, dataSourceName)
|
||||||
if driver == nil {
|
|
||||||
return nil, fmt.Errorf("Unsupported driver name: %v", driverName)
|
|
||||||
}
|
|
||||||
|
|
||||||
uri, err := driver.Parse(driverName, dataSourceName)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
dialect := dialects.QueryDialect(uri.DBType)
|
|
||||||
if dialect == nil {
|
|
||||||
return nil, fmt.Errorf("Unsupported dialect type: %v", uri.DBType)
|
|
||||||
}
|
|
||||||
|
|
||||||
db, err := core.Open(driverName, dataSourceName)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = dialect.Init(db, uri, driverName, dataSourceName)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -64,15 +42,16 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
|
||||||
tagParser := tags.NewParser("xorm", dialect, mapper, mapper, cacherMgr)
|
tagParser := tags.NewParser("xorm", dialect, mapper, mapper, cacherMgr)
|
||||||
|
|
||||||
engine := &Engine{
|
engine := &Engine{
|
||||||
db: db,
|
|
||||||
dialect: dialect,
|
dialect: dialect,
|
||||||
TZLocation: time.Local,
|
TZLocation: time.Local,
|
||||||
defaultContext: context.Background(),
|
defaultContext: context.Background(),
|
||||||
cacherMgr: cacherMgr,
|
cacherMgr: cacherMgr,
|
||||||
tagParser: tagParser,
|
tagParser: tagParser,
|
||||||
|
driverName: driverName,
|
||||||
|
dataSourceName: dataSourceName,
|
||||||
}
|
}
|
||||||
|
|
||||||
if uri.DBType == schemas.SQLITE {
|
if dialect.URI().DBType == schemas.SQLITE {
|
||||||
engine.DatabaseTZ = time.UTC
|
engine.DatabaseTZ = time.UTC
|
||||||
} else {
|
} else {
|
||||||
engine.DatabaseTZ = time.Local
|
engine.DatabaseTZ = time.Local
|
||||||
|
|
Loading…
Reference in New Issue
Block a user