383 lines
9.5 KiB
Go
383 lines
9.5 KiB
Go
// Copyright 2019 The Xorm Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package dialects
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"xorm.io/xorm/core"
|
|
"xorm.io/xorm/log"
|
|
"xorm.io/xorm/schemas"
|
|
)
|
|
|
|
type (
	// DBType names a kind of database. The built-in registrations in this
	// file use "mssql", "mysql", "postgres", "sqlite3" and "oracle".
	DBType string

	// URI holds the connection settings a dialect is initialised with.
	URI struct {
		// DBType is what Base.DBType reports.
		DBType DBType

		Proto, Host, Port string

		// DBName is used as the TABLE_SCHEMA filter by Base.IsColumnExist.
		DBName       string
		User, Passwd string

		// Charset is the fallback for DEFAULT CHARSET in Base.CreateTableSQL.
		Charset string

		Laddr, Raddr string

		Timeout time.Duration
		Schema  string
	}
)
|
|
|
|
// a dialect is a driver's wrapper
|
|
type Dialect interface {
|
|
SetLogger(logger log.Logger)
|
|
Init(*core.DB, *URI, string, string) error
|
|
URI() *URI
|
|
DB() *core.DB
|
|
DBType() DBType
|
|
SQLType(*schemas.Column) string
|
|
FormatBytes(b []byte) string
|
|
|
|
DriverName() string
|
|
DataSourceName() string
|
|
|
|
IsReserved(string) bool
|
|
Quoter() schemas.Quoter
|
|
|
|
RollBackStr() string
|
|
AutoIncrStr() string
|
|
|
|
SupportInsertMany() bool
|
|
SupportEngine() bool
|
|
SupportCharset() bool
|
|
SupportDropIfExists() bool
|
|
IndexOnTable() bool
|
|
ShowCreateNull() bool
|
|
|
|
IndexCheckSQL(tableName, idxName string) (string, []interface{})
|
|
TableCheckSQL(tableName string) (string, []interface{})
|
|
|
|
IsColumnExist(ctx context.Context, tableName string, colName string) (bool, error)
|
|
|
|
CreateTableSQL(table *schemas.Table, tableName, storeEngine, charset string) string
|
|
DropTableSQL(tableName string) string
|
|
CreateIndexSQL(tableName string, index *schemas.Index) string
|
|
DropIndexSQL(tableName string, index *schemas.Index) string
|
|
|
|
AddColumnSQL(tableName string, col *schemas.Column) string
|
|
ModifyColumnSQL(tableName string, col *schemas.Column) string
|
|
|
|
ForUpdateSQL(query string) string
|
|
|
|
GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error)
|
|
GetTables(ctx context.Context) ([]*schemas.Table, error)
|
|
GetIndexes(ctx context.Context, tableName string) (map[string]*schemas.Index, error)
|
|
|
|
Filters() []Filter
|
|
SetParams(params map[string]string)
|
|
}
|
|
|
|
func OpenDialect(dialect Dialect) (*core.DB, error) {
|
|
return core.Open(dialect.DriverName(), dialect.DataSourceName())
|
|
}
|
|
|
|
// Base represents a basic dialect and all real dialects could embed this struct
|
|
type Base struct {
|
|
db *core.DB
|
|
dialect Dialect
|
|
driverName string
|
|
dataSourceName string
|
|
logger log.Logger
|
|
uri *URI
|
|
}
|
|
|
|
func (b *Base) DB() *core.DB {
|
|
return b.db
|
|
}
|
|
|
|
func (b *Base) SetLogger(logger log.Logger) {
|
|
b.logger = logger
|
|
}
|
|
|
|
func (b *Base) Init(db *core.DB, dialect Dialect, uri *URI, drivername, dataSourceName string) error {
|
|
b.db, b.dialect, b.uri = db, dialect, uri
|
|
b.driverName, b.dataSourceName = drivername, dataSourceName
|
|
return nil
|
|
}
|
|
|
|
func (b *Base) URI() *URI {
|
|
return b.uri
|
|
}
|
|
|
|
func (b *Base) DBType() DBType {
|
|
return b.uri.DBType
|
|
}
|
|
|
|
// String generate column description string according dialect
|
|
func (b *Base) String(col *schemas.Column) string {
|
|
sql := b.dialect.Quoter().Quote(col.Name) + " "
|
|
|
|
sql += b.dialect.SQLType(col) + " "
|
|
|
|
if col.IsPrimaryKey {
|
|
sql += "PRIMARY KEY "
|
|
if col.IsAutoIncrement {
|
|
sql += b.dialect.AutoIncrStr() + " "
|
|
}
|
|
}
|
|
|
|
if col.Default != "" {
|
|
sql += "DEFAULT " + col.Default + " "
|
|
}
|
|
|
|
if b.dialect.ShowCreateNull() {
|
|
if col.Nullable {
|
|
sql += "NULL "
|
|
} else {
|
|
sql += "NOT NULL "
|
|
}
|
|
}
|
|
|
|
return sql
|
|
}
|
|
|
|
// StringNoPk generate column description string according dialect without primary keys
|
|
func (b *Base) StringNoPk(col *schemas.Column) string {
|
|
sql := b.dialect.Quoter().Quote(col.Name) + " "
|
|
|
|
sql += b.dialect.SQLType(col) + " "
|
|
|
|
if col.Default != "" {
|
|
sql += "DEFAULT " + col.Default + " "
|
|
}
|
|
|
|
if b.dialect.ShowCreateNull() {
|
|
if col.Nullable {
|
|
sql += "NULL "
|
|
} else {
|
|
sql += "NOT NULL "
|
|
}
|
|
}
|
|
|
|
return sql
|
|
}
|
|
|
|
func (b *Base) FormatBytes(bs []byte) string {
|
|
return fmt.Sprintf("0x%x", bs)
|
|
}
|
|
|
|
func (b *Base) DriverName() string {
|
|
return b.driverName
|
|
}
|
|
|
|
func (b *Base) ShowCreateNull() bool {
|
|
return true
|
|
}
|
|
|
|
func (b *Base) DataSourceName() string {
|
|
return b.dataSourceName
|
|
}
|
|
|
|
func (db *Base) RollBackStr() string {
|
|
return "ROLL BACK"
|
|
}
|
|
|
|
func (db *Base) SupportDropIfExists() bool {
|
|
return true
|
|
}
|
|
|
|
func (db *Base) DropTableSQL(tableName string) string {
|
|
quote := db.dialect.Quoter().Quote
|
|
return fmt.Sprintf("DROP TABLE IF EXISTS %s", quote(tableName))
|
|
}
|
|
|
|
func (db *Base) HasRecords(ctx context.Context, query string, args ...interface{}) (bool, error) {
|
|
db.LogSQL(query, args)
|
|
rows, err := db.DB().QueryContext(ctx, query, args...)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
if rows.Next() {
|
|
return true, nil
|
|
}
|
|
return false, nil
|
|
}
|
|
|
|
func (db *Base) IsColumnExist(ctx context.Context, tableName, colName string) (bool, error) {
|
|
quote := db.dialect.Quoter().Quote
|
|
query := fmt.Sprintf(
|
|
"SELECT %v FROM %v.%v WHERE %v = ? AND %v = ? AND %v = ?",
|
|
quote("COLUMN_NAME"),
|
|
quote("INFORMATION_SCHEMA"),
|
|
quote("COLUMNS"),
|
|
quote("TABLE_SCHEMA"),
|
|
quote("TABLE_NAME"),
|
|
quote("COLUMN_NAME"),
|
|
)
|
|
return db.HasRecords(ctx, query, db.uri.DBName, tableName, colName)
|
|
}
|
|
|
|
func (db *Base) AddColumnSQL(tableName string, col *schemas.Column) string {
|
|
quoter := db.dialect.Quoter()
|
|
sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quoter.Quote(tableName),
|
|
db.String(col))
|
|
if db.dialect.DBType() == schemas.MYSQL && len(col.Comment) > 0 {
|
|
sql += " COMMENT '" + col.Comment + "'"
|
|
}
|
|
return sql
|
|
}
|
|
|
|
func (db *Base) CreateIndexSQL(tableName string, index *schemas.Index) string {
|
|
quoter := db.dialect.Quoter()
|
|
var unique string
|
|
var idxName string
|
|
if index.Type == schemas.UniqueType {
|
|
unique = " UNIQUE"
|
|
}
|
|
idxName = index.XName(tableName)
|
|
return fmt.Sprintf("CREATE%s INDEX %v ON %v (%v)", unique,
|
|
quoter.Quote(idxName), quoter.Quote(tableName),
|
|
quoter.Quote(strings.Join(index.Cols, quoter.ReverseQuote(","))))
|
|
}
|
|
|
|
func (db *Base) DropIndexSQL(tableName string, index *schemas.Index) string {
|
|
quote := db.dialect.Quoter().Quote
|
|
var name string
|
|
if index.IsRegular {
|
|
name = index.XName(tableName)
|
|
} else {
|
|
name = index.Name
|
|
}
|
|
return fmt.Sprintf("DROP INDEX %v ON %s", quote(name), quote(tableName))
|
|
}
|
|
|
|
func (db *Base) ModifyColumnSQL(tableName string, col *schemas.Column) string {
|
|
return fmt.Sprintf("alter table %s MODIFY COLUMN %s", tableName, db.StringNoPk(col))
|
|
}
|
|
|
|
func (b *Base) CreateTableSQL(table *schemas.Table, tableName, storeEngine, charset string) string {
|
|
var sql string
|
|
sql = "CREATE TABLE IF NOT EXISTS "
|
|
if tableName == "" {
|
|
tableName = table.Name
|
|
}
|
|
|
|
quoter := b.dialect.Quoter()
|
|
sql += quoter.Quote(tableName)
|
|
sql += " ("
|
|
|
|
if len(table.ColumnsSeq()) > 0 {
|
|
pkList := table.PrimaryKeys
|
|
|
|
for _, colName := range table.ColumnsSeq() {
|
|
col := table.GetColumn(colName)
|
|
if col.IsPrimaryKey && len(pkList) == 1 {
|
|
sql += b.String(col)
|
|
} else {
|
|
sql += b.StringNoPk(col)
|
|
}
|
|
sql = strings.TrimSpace(sql)
|
|
if b.DBType() == schemas.MYSQL && len(col.Comment) > 0 {
|
|
sql += " COMMENT '" + col.Comment + "'"
|
|
}
|
|
sql += ", "
|
|
}
|
|
|
|
if len(pkList) > 1 {
|
|
sql += "PRIMARY KEY ( "
|
|
sql += quoter.Quote(strings.Join(pkList, quoter.ReverseQuote(",")))
|
|
sql += " ), "
|
|
}
|
|
|
|
sql = sql[:len(sql)-2]
|
|
}
|
|
sql += ")"
|
|
|
|
if b.dialect.SupportEngine() && storeEngine != "" {
|
|
sql += " ENGINE=" + storeEngine
|
|
}
|
|
if b.dialect.SupportCharset() {
|
|
if len(charset) == 0 {
|
|
charset = b.dialect.URI().Charset
|
|
}
|
|
if len(charset) > 0 {
|
|
sql += " DEFAULT CHARSET " + charset
|
|
}
|
|
}
|
|
|
|
return sql
|
|
}
|
|
|
|
func (b *Base) ForUpdateSQL(query string) string {
|
|
return query + " FOR UPDATE"
|
|
}
|
|
|
|
func (b *Base) LogSQL(sql string, args []interface{}) {
|
|
if b.logger != nil && b.logger.IsShowSQL() {
|
|
if len(args) > 0 {
|
|
b.logger.Infof("[SQL] %v %v", sql, args)
|
|
} else {
|
|
b.logger.Infof("[SQL] %v", sql)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (b *Base) SetParams(params map[string]string) {
|
|
}
|
|
|
|
var (
|
|
dialects = map[string]func() Dialect{}
|
|
)
|
|
|
|
// RegisterDialect register database dialect
|
|
func RegisterDialect(dbName DBType, dialectFunc func() Dialect) {
|
|
if dialectFunc == nil {
|
|
panic("core: Register dialect is nil")
|
|
}
|
|
dialects[strings.ToLower(string(dbName))] = dialectFunc // !nashtsai! allow override dialect
|
|
}
|
|
|
|
// QueryDialect query if registered database dialect
|
|
func QueryDialect(dbName DBType) Dialect {
|
|
if d, ok := dialects[strings.ToLower(string(dbName))]; ok {
|
|
return d()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func regDrvsNDialects() bool {
|
|
providedDrvsNDialects := map[string]struct {
|
|
dbType DBType
|
|
getDriver func() Driver
|
|
getDialect func() Dialect
|
|
}{
|
|
"mssql": {"mssql", func() Driver { return &odbcDriver{} }, func() Dialect { return &mssql{} }},
|
|
"odbc": {"mssql", func() Driver { return &odbcDriver{} }, func() Dialect { return &mssql{} }}, // !nashtsai! TODO change this when supporting MS Access
|
|
"mysql": {"mysql", func() Driver { return &mysqlDriver{} }, func() Dialect { return &mysql{} }},
|
|
"mymysql": {"mysql", func() Driver { return &mymysqlDriver{} }, func() Dialect { return &mysql{} }},
|
|
"postgres": {"postgres", func() Driver { return &pqDriver{} }, func() Dialect { return &postgres{} }},
|
|
"pgx": {"postgres", func() Driver { return &pqDriverPgx{} }, func() Dialect { return &postgres{} }},
|
|
"sqlite3": {"sqlite3", func() Driver { return &sqlite3Driver{} }, func() Dialect { return &sqlite3{} }},
|
|
"oci8": {"oracle", func() Driver { return &oci8Driver{} }, func() Dialect { return &oracle{} }},
|
|
"goracle": {"oracle", func() Driver { return &goracleDriver{} }, func() Dialect { return &oracle{} }},
|
|
}
|
|
|
|
for driverName, v := range providedDrvsNDialects {
|
|
if driver := QueryDriver(driverName); driver == nil {
|
|
RegisterDriver(driverName, v.getDriver())
|
|
RegisterDialect(v.dbType, v.getDialect)
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
func init() {
|
|
regDrvsNDialects()
|
|
}
|