Browse Source

Improve dialect interface (#1578)

Improve dialect interface

Reviewed-on: #1578
tags/v1.0.0
Lunny Xiao 3 weeks ago
parent
commit
7f22948be9
23 changed files with 162 additions and 240 deletions
  1. +16
    -85
      dialects/dialect.go
  2. +1
    -19
      dialects/mssql.go
  3. +4
    -15
      dialects/mysql.go
  4. +1
    -24
      dialects/oracle.go
  5. +34
    -8
      dialects/postgres.go
  6. +38
    -12
      dialects/sqlite3.go
  7. +1
    -1
      dialects/time.go
  8. +4
    -4
      engine.go
  9. +3
    -3
      internal/statements/cache.go
  10. +9
    -9
      internal/statements/query.go
  11. +8
    -18
      internal/statements/statement.go
  12. +2
    -2
      internal/statements/statement_args.go
  13. +1
    -1
      session_cols_test.go
  14. +3
    -3
      session_convert.go
  15. +2
    -2
      session_delete.go
  16. +2
    -2
      session_delete_test.go
  17. +3
    -3
      session_get_test.go
  18. +7
    -6
      session_insert.go
  19. +4
    -4
      session_query_test.go
  20. +4
    -4
      session_schema.go
  21. +6
    -6
      session_update.go
  22. +7
    -7
      tags_test.go
  23. +2
    -2
      types_test.go

+ 16
- 85
dialects/dialect.go View File

@@ -34,7 +34,6 @@ type Dialect interface {
Init(*core.DB, *URI) error
URI() *URI
DB() *core.DB
DBType() schemas.DBType
SQLType(*schemas.Column) string
FormatBytes(b []byte) string
DefaultSchema() string
@@ -44,33 +43,26 @@ type Dialect interface {
SetQuotePolicy(quotePolicy QuotePolicy)

AutoIncrStr() string

SupportInsertMany() bool
SupportEngine() bool
SupportCharset() bool
SupportDropIfExists() bool
IndexOnTable() bool
ShowCreateNull() bool

GetIndexes(ctx context.Context, tableName string) (map[string]*schemas.Index, error)
IndexCheckSQL(tableName, idxName string) (string, []interface{})
TableCheckSQL(tableName string) (string, []interface{})

IsColumnExist(ctx context.Context, tableName string, colName string) (bool, error)

CreateTableSQL(table *schemas.Table, tableName, storeEngine, charset string) string
DropTableSQL(tableName string) string
CreateIndexSQL(tableName string, index *schemas.Index) string
DropIndexSQL(tableName string, index *schemas.Index) string

GetTables(ctx context.Context) ([]*schemas.Table, error)
TableCheckSQL(tableName string) (string, []interface{})
CreateTableSQL(table *schemas.Table, tableName string) string
DropTableSQL(tableName string) string

GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error)
IsColumnExist(ctx context.Context, tableName string, colName string) (bool, error)
AddColumnSQL(tableName string, col *schemas.Column) string
ModifyColumnSQL(tableName string, col *schemas.Column) string

ForUpdateSQL(query string) string

GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error)
GetTables(ctx context.Context) ([]*schemas.Table, error)
GetIndexes(ctx context.Context, tableName string) (map[string]*schemas.Index, error)

Filters() []Filter
SetParams(params map[string]string)
}
@@ -125,12 +117,10 @@ func (b *Base) String(col *schemas.Column) string {
sql += "DEFAULT " + col.Default + " "
}

if b.dialect.ShowCreateNull() {
if col.Nullable {
sql += "NULL "
} else {
sql += "NOT NULL "
}
if col.Nullable {
sql += "NULL "
} else {
sql += "NOT NULL "
}

return sql
@@ -146,12 +136,10 @@ func (b *Base) StringNoPk(col *schemas.Column) string {
sql += "DEFAULT " + col.Default + " "
}

if b.dialect.ShowCreateNull() {
if col.Nullable {
sql += "NULL "
} else {
sql += "NOT NULL "
}
if col.Nullable {
sql += "NULL "
} else {
sql += "NOT NULL "
}

return sql
@@ -161,10 +149,6 @@ func (b *Base) FormatBytes(bs []byte) string {
return fmt.Sprintf("0x%x", bs)
}

func (b *Base) ShowCreateNull() bool {
return true
}

func (db *Base) SupportDropIfExists() bool {
return true
}
@@ -234,59 +218,6 @@ func (db *Base) ModifyColumnSQL(tableName string, col *schemas.Column) string {
return fmt.Sprintf("alter table %s MODIFY COLUMN %s", tableName, db.StringNoPk(col))
}

func (b *Base) CreateTableSQL(table *schemas.Table, tableName, storeEngine, charset string) string {
var sql string
sql = "CREATE TABLE IF NOT EXISTS "
if tableName == "" {
tableName = table.Name
}

quoter := b.dialect.Quoter()
sql += quoter.Quote(tableName)
sql += " ("

if len(table.ColumnsSeq()) > 0 {
pkList := table.PrimaryKeys

for _, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName)
if col.IsPrimaryKey && len(pkList) == 1 {
sql += b.String(col)
} else {
sql += b.StringNoPk(col)
}
sql = strings.TrimSpace(sql)
if b.DBType() == schemas.MYSQL && len(col.Comment) > 0 {
sql += " COMMENT '" + col.Comment + "'"
}
sql += ", "
}

if len(pkList) > 1 {
sql += "PRIMARY KEY ( "
sql += quoter.Join(pkList, ",")
sql += " ), "
}

sql = sql[:len(sql)-2]
}
sql += ")"

if b.dialect.SupportEngine() && storeEngine != "" {
sql += " ENGINE=" + storeEngine
}
if b.dialect.SupportCharset() {
if len(charset) == 0 {
charset = b.dialect.URI().Charset
}
if len(charset) > 0 {
sql += " DEFAULT CHARSET " + charset
}
}

return sql
}

func (b *Base) ForUpdateSQL(query string) string {
return query + " FOR UPDATE"
}


+ 1
- 19
dialects/mssql.go View File

@@ -307,10 +307,6 @@ func (db *mssql) SetQuotePolicy(quotePolicy QuotePolicy) {
}
}

func (db *mssql) SupportEngine() bool {
return false
}

func (db *mssql) AutoIncrStr() string {
return "IDENTITY"
}
@@ -321,26 +317,12 @@ func (db *mssql) DropTableSQL(tableName string) string {
"DROP TABLE \"%s\"", tableName, tableName)
}

func (db *mssql) SupportCharset() bool {
return false
}

func (db *mssql) IndexOnTable() bool {
return true
}

func (db *mssql) IndexCheckSQL(tableName, idxName string) (string, []interface{}) {
args := []interface{}{idxName}
sql := "select name from sysindexes where id=object_id('" + tableName + "') and name=?"
return sql, args
}

/*func (db *mssql) ColumnCheckSql(tableName, colName string) (string, []interface{}) {
args := []interface{}{tableName, colName}
sql := `SELECT "COLUMN_NAME" FROM "INFORMATION_SCHEMA"."COLUMNS" WHERE "TABLE_NAME" = ? AND "COLUMN_NAME" = ?`
return sql, args
}*/

func (db *mssql) IsColumnExist(ctx context.Context, tableName, colName string) (bool, error) {
query := `SELECT "COLUMN_NAME" FROM "INFORMATION_SCHEMA"."COLUMNS" WHERE "TABLE_NAME" = ? AND "COLUMN_NAME" = ?`

@@ -509,7 +491,7 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =?
return indexes, nil
}

func (db *mssql) CreateTableSQL(table *schemas.Table, tableName, storeEngine, charset string) string {
func (db *mssql) CreateTableSQL(table *schemas.Table, tableName string) string {
var sql string
if tableName == "" {
tableName = table.Name


+ 4
- 15
dialects/mysql.go View File

@@ -279,22 +279,10 @@ func (db *mysql) IsReserved(name string) bool {
return ok
}

func (db *mysql) SupportEngine() bool {
return true
}

func (db *mysql) AutoIncrStr() string {
return "AUTO_INCREMENT"
}

func (db *mysql) SupportCharset() bool {
return true
}

func (db *mysql) IndexOnTable() bool {
return true
}

func (db *mysql) IndexCheckSQL(tableName, idxName string) (string, []interface{}) {
args := []interface{}{db.uri.DBName, tableName, idxName}
sql := "SELECT `INDEX_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS`"
@@ -524,7 +512,7 @@ func (db *mysql) GetIndexes(ctx context.Context, tableName string) (map[string]*
return indexes, nil
}

func (db *mysql) CreateTableSQL(table *schemas.Table, tableName, storeEngine, charset string) string {
func (db *mysql) CreateTableSQL(table *schemas.Table, tableName string) string {
var sql = "CREATE TABLE IF NOT EXISTS "
if tableName == "" {
tableName = table.Name
@@ -562,10 +550,11 @@ func (db *mysql) CreateTableSQL(table *schemas.Table, tableName, storeEngine, ch
}
sql += ")"

if storeEngine != "" {
sql += " ENGINE=" + storeEngine
if table.StoreEngine != "" {
sql += " ENGINE=" + table.StoreEngine
}

var charset = table.Charset
if len(charset) == 0 {
charset = db.URI().Charset
}


+ 1
- 24
dialects/oracle.go View File

@@ -556,27 +556,15 @@ func (db *oracle) IsReserved(name string) bool {
return ok
}

func (db *oracle) SupportEngine() bool {
return false
}

func (db *oracle) SupportCharset() bool {
return false
}

func (db *oracle) SupportDropIfExists() bool {
return false
}

func (db *oracle) IndexOnTable() bool {
return false
}

func (db *oracle) DropTableSQL(tableName string) string {
return fmt.Sprintf("DROP TABLE `%s`", tableName)
}

func (db *oracle) CreateTableSQL(table *schemas.Table, tableName, storeEngine, charset string) string {
func (db *oracle) CreateTableSQL(table *schemas.Table, tableName string) string {
var sql = "CREATE TABLE "
if tableName == "" {
tableName = table.Name
@@ -605,17 +593,6 @@ func (db *oracle) CreateTableSQL(table *schemas.Table, tableName, storeEngine, c
}

sql = sql[:len(sql)-2] + ")"
if db.SupportEngine() && storeEngine != "" {
sql += " ENGINE=" + storeEngine
}
if db.SupportCharset() {
if len(charset) == 0 {
charset = db.URI().Charset
}
if len(charset) > 0 {
sql += " DEFAULT CHARSET " + charset
}
}
return sql
}



+ 34
- 8
dialects/postgres.go View File

@@ -897,16 +897,42 @@ func (db *postgres) AutoIncrStr() string {
return ""
}

func (db *postgres) SupportEngine() bool {
return false
}
func (db *postgres) CreateTableSQL(table *schemas.Table, tableName string) string {
var sql string
sql = "CREATE TABLE IF NOT EXISTS "
if tableName == "" {
tableName = table.Name
}

func (db *postgres) SupportCharset() bool {
return false
}
quoter := db.Quoter()
sql += quoter.Quote(tableName)
sql += " ("

func (db *postgres) IndexOnTable() bool {
return false
if len(table.ColumnsSeq()) > 0 {
pkList := table.PrimaryKeys

for _, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName)
if col.IsPrimaryKey && len(pkList) == 1 {
sql += db.String(col)
} else {
sql += db.StringNoPk(col)
}
sql = strings.TrimSpace(sql)
sql += ", "
}

if len(pkList) > 1 {
sql += "PRIMARY KEY ( "
sql += quoter.Join(pkList, ",")
sql += " ), "
}

sql = sql[:len(sql)-2]
}
sql += ")"

return sql
}

func (db *postgres) IndexCheckSQL(tableName, idxName string) (string, []interface{}) {


+ 38
- 12
dialects/sqlite3.go View File

@@ -224,18 +224,6 @@ func (db *sqlite3) AutoIncrStr() string {
return "AUTOINCREMENT"
}

func (db *sqlite3) SupportEngine() bool {
return false
}

func (db *sqlite3) SupportCharset() bool {
return false
}

func (db *sqlite3) IndexOnTable() bool {
return false
}

func (db *sqlite3) IndexCheckSQL(tableName, idxName string) (string, []interface{}) {
args := []interface{}{idxName}
return "SELECT name FROM sqlite_master WHERE type='index' and name = ?", args
@@ -261,6 +249,44 @@ func (db *sqlite3) DropIndexSQL(tableName string, index *schemas.Index) string {
return fmt.Sprintf("DROP INDEX %v", db.Quoter().Quote(idxName))
}

func (db *sqlite3) CreateTableSQL(table *schemas.Table, tableName string) string {
var sql string
sql = "CREATE TABLE IF NOT EXISTS "
if tableName == "" {
tableName = table.Name
}

quoter := db.Quoter()
sql += quoter.Quote(tableName)
sql += " ("

if len(table.ColumnsSeq()) > 0 {
pkList := table.PrimaryKeys

for _, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName)
if col.IsPrimaryKey && len(pkList) == 1 {
sql += db.String(col)
} else {
sql += db.StringNoPk(col)
}
sql = strings.TrimSpace(sql)
sql += ", "
}

if len(pkList) > 1 {
sql += "PRIMARY KEY ( "
sql += quoter.Join(pkList, ",")
sql += " ), "
}

sql = sql[:len(sql)-2]
}
sql += ")"

return sql
}

func (db *sqlite3) ForUpdateSQL(query string) string {
return query
}


+ 1
- 1
dialects/time.go View File

@@ -21,7 +21,7 @@ func FormatTime(dialect Dialect, sqlTypeName string, t time.Time) (v interface{}
case schemas.DateTime, schemas.TimeStamp, schemas.Varchar: // !DarthPestilane! format time when sqlTypeName is schemas.Varchar.
v = t.Format("2006-01-02 15:04:05")
case schemas.TimeStampz:
if dialect.DBType() == schemas.MSSQL {
if dialect.URI().DBType == schemas.MSSQL {
v = t.Format("2006-01-02T15:04:05.9999999Z07:00")
} else {
v = t.Format(time.RFC3339Nano)


+ 4
- 4
engine.go View File

@@ -365,7 +365,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
var distDBName string
if len(tp) == 0 {
dialect = engine.dialect
distDBName = string(engine.dialect.DBType())
distDBName = string(engine.dialect.URI().DBType)
} else {
dialect = dialects.QueryDialect(tp[0])
if dialect == nil {
@@ -376,7 +376,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
}

_, err := io.WriteString(w, fmt.Sprintf("/*Generated by xorm v%s %s, from %s to %s*/\n\n",
Version, time.Now().In(engine.TZLocation).Format("2006-01-02 15:04:05"), engine.dialect.DBType(), strings.ToUpper(distDBName)))
Version, time.Now().In(engine.TZLocation).Format("2006-01-02 15:04:05"), engine.dialect.URI().DBType, strings.ToUpper(distDBName)))
if err != nil {
return err
}
@@ -388,7 +388,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
return err
}
}
_, err = io.WriteString(w, dialect.CreateTableSQL(table, "", table.StoreEngine, "")+";\n")
_, err = io.WriteString(w, dialect.CreateTableSQL(table, "")+";\n")
if err != nil {
return err
}
@@ -486,7 +486,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
}

// FIXME: Hack for postgres
if dialect.DBType() == schemas.POSTGRES && table.AutoIncrColumn() != nil {
if dialect.URI().DBType == schemas.POSTGRES && table.AutoIncrColumn() != nil {
_, err = io.WriteString(w, "SELECT setval('"+table.Name+"_id_seq', COALESCE((SELECT MAX("+table.AutoIncrColumn().Name+") + 1 FROM "+dialect.Quoter().Quote(table.Name)+"), 1), false);\n")
if err != nil {
return err


+ 3
- 3
internal/statements/cache.go View File

@@ -27,7 +27,7 @@ func (statement *Statement) ConvertIDSQL(sqlStr string) string {

var top string
pLimitN := statement.LimitN
if pLimitN != nil && statement.dialect.DBType() == schemas.MSSQL {
if pLimitN != nil && statement.dialect.URI().DBType == schemas.MSSQL {
top = fmt.Sprintf("TOP %d ", *pLimitN)
}

@@ -56,9 +56,9 @@ func (statement *Statement) ConvertUpdateSQL(sqlStr string) (string, string) {

// TODO: for postgres only, if any other database?
var paraStr string
if statement.dialect.DBType() == schemas.POSTGRES {
if statement.dialect.URI().DBType == schemas.POSTGRES {
paraStr = "$"
} else if statement.dialect.DBType() == schemas.MSSQL {
} else if statement.dialect.URI().DBType == schemas.MSSQL {
paraStr = ":"
}



+ 9
- 9
internal/statements/query.go View File

@@ -201,14 +201,14 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB
whereStr = " WHERE " + condSQL
}

if dialect.DBType() == schemas.MSSQL && strings.Contains(statement.TableName(), "..") {
if dialect.URI().DBType == schemas.MSSQL && strings.Contains(statement.TableName(), "..") {
fromStr += statement.TableName()
} else {
fromStr += quote(statement.TableName())
}

if statement.TableAlias != "" {
if dialect.DBType() == schemas.ORACLE {
if dialect.URI().DBType == schemas.ORACLE {
fromStr += " " + quote(statement.TableAlias)
} else {
fromStr += " AS " + quote(statement.TableAlias)
@@ -219,7 +219,7 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB
}

pLimitN := statement.LimitN
if dialect.DBType() == schemas.MSSQL {
if dialect.URI().DBType == schemas.MSSQL {
if pLimitN != nil {
LimitNValue := *pLimitN
top = fmt.Sprintf("TOP %d ", LimitNValue)
@@ -281,7 +281,7 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB
fmt.Fprint(&buf, " ORDER BY ", statement.OrderStr)
}
if needLimit {
if dialect.DBType() != schemas.MSSQL && dialect.DBType() != schemas.ORACLE {
if dialect.URI().DBType != schemas.MSSQL && dialect.URI().DBType != schemas.ORACLE {
if statement.Start > 0 {
if pLimitN != nil {
fmt.Fprintf(&buf, " LIMIT %v OFFSET %v", *pLimitN, statement.Start)
@@ -291,7 +291,7 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB
} else if pLimitN != nil {
fmt.Fprint(&buf, " LIMIT ", *pLimitN)
}
} else if dialect.DBType() == schemas.ORACLE {
} else if dialect.URI().DBType == schemas.ORACLE {
if statement.Start != 0 || pLimitN != nil {
oldString := buf.String()
buf.Reset()
@@ -337,18 +337,18 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac
return "", nil, err
}

if statement.dialect.DBType() == schemas.MSSQL {
if statement.dialect.URI().DBType == schemas.MSSQL {
sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s %s WHERE %s", tableName, joinStr, condSQL)
} else if statement.dialect.DBType() == schemas.ORACLE {
} else if statement.dialect.URI().DBType == schemas.ORACLE {
sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE (%s) %s AND ROWNUM=1", tableName, joinStr, condSQL)
} else {
sqlStr = fmt.Sprintf("SELECT * FROM %s %s WHERE %s LIMIT 1", tableName, joinStr, condSQL)
}
args = condArgs
} else {
if statement.dialect.DBType() == schemas.MSSQL {
if statement.dialect.URI().DBType == schemas.MSSQL {
sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s %s", tableName, joinStr)
} else if statement.dialect.DBType() == schemas.ORACLE {
} else if statement.dialect.URI().DBType == schemas.ORACLE {
sqlStr = fmt.Sprintf("SELECT * FROM %s %s WHERE ROWNUM=1", tableName, joinStr)
} else {
sqlStr = fmt.Sprintf("SELECT * FROM %s %s LIMIT 1", tableName, joinStr)


+ 8
- 18
internal/statements/statement.go View File

@@ -641,8 +641,9 @@ func (statement *Statement) genColumnStr() string {
}

func (statement *Statement) GenCreateTableSQL() string {
return statement.dialect.CreateTableSQL(statement.RefTable, statement.TableName(),
statement.StoreEngine, statement.Charset)
statement.RefTable.StoreEngine = statement.StoreEngine
statement.RefTable.Charset = statement.Charset
return statement.dialect.CreateTableSQL(statement.RefTable, statement.TableName())
}

func (statement *Statement) GenIndexSQL() []string {
@@ -680,20 +681,8 @@ func (statement *Statement) GenDelIndexSQL() []string {
if idx > -1 {
tbName = tbName[idx+1:]
}
idxPrefixName := strings.Replace(tbName, `"`, "", -1)
idxPrefixName = strings.Replace(idxPrefixName, `.`, "_", -1)
for idxName, index := range statement.RefTable.Indexes {
var rIdxName string
if index.Type == schemas.UniqueType {
rIdxName = uniqueName(idxPrefixName, idxName)
} else if index.Type == schemas.IndexType {
rIdxName = utils.IndexName(idxPrefixName, idxName)
}
sql := fmt.Sprintf("DROP INDEX %v", statement.quote(dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), rIdxName, true)))
if statement.dialect.IndexOnTable() {
sql += fmt.Sprintf(" ON %v", statement.quote(tbName))
}
sqls = append(sqls, sql)
for _, index := range statement.RefTable.Indexes {
sqls = append(sqls, statement.dialect.DropIndexSQL(tbName, index))
}
return sqls
}
@@ -714,7 +703,8 @@ func (statement *Statement) buildConds2(table *schemas.Table, bean interface{},
continue
}

if statement.dialect.DBType() == schemas.MSSQL && (col.SQLType.Name == schemas.Text || col.SQLType.IsBlob() || col.SQLType.Name == schemas.TimeStampz) {
if statement.dialect.URI().DBType == schemas.MSSQL && (col.SQLType.Name == schemas.Text ||
col.SQLType.IsBlob() || col.SQLType.Name == schemas.TimeStampz) {
continue
}
if col.SQLType.IsJson() {
@@ -1002,7 +992,7 @@ func (statement *Statement) CondDeleted(col *schemas.Column) builder.Cond {
cond = builder.Eq{colName: 0}
} else {
// FIXME: mssql: The conversion of a nvarchar data type to a datetime data type resulted in an out-of-range value.
if statement.dialect.DBType() != schemas.MSSQL {
if statement.dialect.URI().DBType != schemas.MSSQL {
cond = builder.Eq{colName: utils.ZeroTime1}
}
}


+ 2
- 2
internal/statements/statement_args.go View File

@@ -80,7 +80,7 @@ const insertSelectPlaceHolder = true
func (statement *Statement) WriteArg(w *builder.BytesWriter, arg interface{}) error {
switch argv := arg.(type) {
case bool:
if statement.dialect.DBType() == schemas.MSSQL {
if statement.dialect.URI().DBType == schemas.MSSQL {
if argv {
if _, err := w.WriteString("1"); err != nil {
return err
@@ -119,7 +119,7 @@ func (statement *Statement) WriteArg(w *builder.BytesWriter, arg interface{}) er
w.Append(arg)
} else {
var convertFunc = convertStringSingleQuote
if statement.dialect.DBType() == schemas.MYSQL {
if statement.dialect.URI().DBType == schemas.MYSQL {
convertFunc = convertString
}
if _, err := w.WriteString(convertArg(arg, convertFunc)); err != nil {


+ 1
- 1
session_cols_test.go View File

@@ -45,7 +45,7 @@ func TestSetExpr(t *testing.T) {
assert.EqualValues(t, 1, cnt)

var not = "NOT"
if testEngine.Dialect().DBType() == schemas.MSSQL {
if testEngine.Dialect().URI().DBType == schemas.MSSQL {
not = "~"
}
cnt, err = testEngine.SetExpr("show", not+" `show`").ID(1).Update(new(UserExpr))


+ 3
- 3
session_convert.go View File

@@ -65,7 +65,7 @@ func (session *Session) str2Time(col *schemas.Column, data string) (outTime time
}

sdata = strings.TrimSpace(sdata)
if session.engine.dialect.DBType() == schemas.MYSQL && len(sdata) > 8 {
if session.engine.dialect.URI().DBType == schemas.MYSQL && len(sdata) > 8 {
sdata = sdata[len(sdata)-8:]
}

@@ -159,7 +159,7 @@ func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Val
var err error
// for mysql, when use bit, it returned \x01
if col.SQLType.Name == schemas.Bit &&
session.engine.dialect.DBType() == schemas.MYSQL { // !nashtsai! TODO dialect needs to provide conversion interface API
session.engine.dialect.URI().DBType == schemas.MYSQL { // !nashtsai! TODO dialect needs to provide conversion interface API
if len(data) == 1 {
x = int64(data[0])
} else {
@@ -399,7 +399,7 @@ func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Val
var err error
// for mysql, when use bit, it returned \x01
if col.SQLType.Name == schemas.Bit &&
session.engine.dialect.DBType() == schemas.MYSQL {
session.engine.dialect.URI().DBType == schemas.MYSQL {
if len(data) == 1 {
x = int32(data[0])
} else {


+ 2
- 2
session_delete.go View File

@@ -135,7 +135,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
}

if len(orderSQL) > 0 {
switch session.engine.dialect.DBType() {
switch session.engine.dialect.URI().DBType {
case schemas.POSTGRES:
inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL)
if len(condSQL) > 0 {
@@ -176,7 +176,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
condSQL)

if len(orderSQL) > 0 {
switch session.engine.dialect.DBType() {
switch session.engine.dialect.URI().DBType {
case schemas.POSTGRES:
inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL)
if len(condSQL) > 0 {


+ 2
- 2
session_delete_test.go View File

@@ -28,7 +28,7 @@ func TestDelete(t *testing.T) {
defer session.Close()

var err error
if testEngine.Dialect().DBType() == schemas.MSSQL {
if testEngine.Dialect().URI().DBType == schemas.MSSQL {
err = session.Begin()
assert.NoError(t, err)
_, err = session.Exec("SET IDENTITY_INSERT userinfo_delete ON")
@@ -40,7 +40,7 @@ func TestDelete(t *testing.T) {
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)

if testEngine.Dialect().DBType() == schemas.MSSQL {
if testEngine.Dialect().URI().DBType == schemas.MSSQL {
err = session.Commit()
assert.NoError(t, err)
}


+ 3
- 3
session_get_test.go View File

@@ -154,7 +154,7 @@ func TestGetVar(t *testing.T) {
assert.Equal(t, "1.5", fmt.Sprintf("%.1f", money))

var money2 float64
if testEngine.Dialect().DBType() == schemas.MSSQL {
if testEngine.Dialect().URI().DBType == schemas.MSSQL {
has, err = testEngine.SQL("SELECT TOP 1 money FROM " + testEngine.TableName("get_var", true)).Get(&money2)
} else {
has, err = testEngine.SQL("SELECT money FROM " + testEngine.TableName("get_var", true) + " LIMIT 1").Get(&money2)
@@ -234,7 +234,7 @@ func TestGetStruct(t *testing.T) {
defer session.Close()

var err error
if testEngine.Dialect().DBType() == schemas.MSSQL {
if testEngine.Dialect().URI().DBType == schemas.MSSQL {
err = session.Begin()
assert.NoError(t, err)
_, err = session.Exec("SET IDENTITY_INSERT userinfo_get ON")
@@ -243,7 +243,7 @@ func TestGetStruct(t *testing.T) {
cnt, err := session.Insert(&UserinfoGet{Uid: 2})
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
if testEngine.Dialect().DBType() == schemas.MSSQL {
if testEngine.Dialect().URI().DBType == schemas.MSSQL {
err = session.Commit()
assert.NoError(t, err)
}


+ 7
- 6
session_insert.go View File

@@ -254,7 +254,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
quoter := session.engine.dialect.Quoter()
var sql string
colStr := quoter.Join(colNames, ",")
if session.engine.dialect.DBType() == schemas.ORACLE {
if session.engine.dialect.URI().DBType == schemas.ORACLE {
temp := fmt.Sprintf(") INTO %s (%v) VALUES (",
quoter.Quote(tableName),
colStr)
@@ -361,7 +361,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {

var tableName = session.statement.TableName()
var output string
if session.engine.dialect.DBType() == schemas.MSSQL && len(table.AutoIncrement) > 0 {
if session.engine.dialect.URI().DBType == schemas.MSSQL && len(table.AutoIncrement) > 0 {
output = fmt.Sprintf(" OUTPUT Inserted.%s", table.AutoIncrement)
}

@@ -371,7 +371,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
}

if len(colPlaces) <= 0 {
if session.engine.dialect.DBType() == schemas.MYSQL {
if session.engine.dialect.URI().DBType == schemas.MYSQL {
if _, err := buf.WriteString(" VALUES ()"); err != nil {
return 0, err
}
@@ -433,7 +433,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
}
}

if len(table.AutoIncrement) > 0 && session.engine.dialect.DBType() == schemas.POSTGRES {
if len(table.AutoIncrement) > 0 && session.engine.dialect.URI().DBType == schemas.POSTGRES {
if _, err := buf.WriteString(" RETURNING " + session.engine.Quote(table.AutoIncrement)); err != nil {
return 0, err
}
@@ -472,7 +472,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {

// for postgres, many of them didn't implement lastInsertId, so we should
// implemented it ourself.
if session.engine.dialect.DBType() == schemas.ORACLE && len(table.AutoIncrement) > 0 {
if session.engine.dialect.URI().DBType == schemas.ORACLE && len(table.AutoIncrement) > 0 {
res, err := session.queryBytes("select seq_atable.currval from dual", args...)
if err != nil {
return 0, err
@@ -513,7 +513,8 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
aiValue.Set(int64ToIntValue(id, aiValue.Type()))

return 1, nil
} else if len(table.AutoIncrement) > 0 && (session.engine.dialect.DBType() == schemas.POSTGRES || session.engine.dialect.DBType() == schemas.MSSQL) {
} else if len(table.AutoIncrement) > 0 && (session.engine.dialect.URI().DBType == schemas.POSTGRES ||
session.engine.dialect.URI().DBType == schemas.MSSQL) {
res, err := session.queryBytes(sqlStr, args...)

if err != nil {


+ 4
- 4
session_query_test.go View File

@@ -207,7 +207,7 @@ func TestQueryStringNoParam(t *testing.T) {
assert.NoError(t, err)
assert.EqualValues(t, 1, len(records))
assert.EqualValues(t, "1", records[0]["id"])
if testEngine.Dialect().DBType() == schemas.POSTGRES || testEngine.Dialect().DBType() == schemas.MSSQL {
if testEngine.Dialect().URI().DBType == schemas.POSTGRES || testEngine.Dialect().URI().DBType == schemas.MSSQL {
assert.EqualValues(t, "false", records[0]["msg"])
} else {
assert.EqualValues(t, "0", records[0]["msg"])
@@ -217,7 +217,7 @@ func TestQueryStringNoParam(t *testing.T) {
assert.NoError(t, err)
assert.EqualValues(t, 1, len(records))
assert.EqualValues(t, "1", records[0]["id"])
if testEngine.Dialect().DBType() == schemas.POSTGRES || testEngine.Dialect().DBType() == schemas.MSSQL {
if testEngine.Dialect().URI().DBType == schemas.POSTGRES || testEngine.Dialect().URI().DBType == schemas.MSSQL {
assert.EqualValues(t, "false", records[0]["msg"])
} else {
assert.EqualValues(t, "0", records[0]["msg"])
@@ -244,7 +244,7 @@ func TestQuerySliceStringNoParam(t *testing.T) {
assert.NoError(t, err)
assert.EqualValues(t, 1, len(records))
assert.EqualValues(t, "1", records[0][0])
if testEngine.Dialect().DBType() == schemas.POSTGRES || testEngine.Dialect().DBType() == schemas.MSSQL {
if testEngine.Dialect().URI().DBType == schemas.POSTGRES || testEngine.Dialect().URI().DBType == schemas.MSSQL {
assert.EqualValues(t, "false", records[0][1])
} else {
assert.EqualValues(t, "0", records[0][1])
@@ -254,7 +254,7 @@ func TestQuerySliceStringNoParam(t *testing.T) {
assert.NoError(t, err)
assert.EqualValues(t, 1, len(records))
assert.EqualValues(t, "1", records[0][0])
if testEngine.Dialect().DBType() == schemas.POSTGRES || testEngine.Dialect().DBType() == schemas.MSSQL {
if testEngine.Dialect().URI().DBType == schemas.POSTGRES || testEngine.Dialect().URI().DBType == schemas.MSSQL {
assert.EqualValues(t, "false", records[0][1])
} else {
assert.EqualValues(t, "0", records[0][1])


+ 4
- 4
session_schema.go View File

@@ -313,8 +313,8 @@ func (session *Session) Sync2(beans ...interface{}) error {
if expectedType == schemas.Text &&
strings.HasPrefix(curType, schemas.Varchar) {
// currently only support mysql & postgres
if engine.dialect.DBType() == schemas.MYSQL ||
engine.dialect.DBType() == schemas.POSTGRES {
if engine.dialect.URI().DBType == schemas.MYSQL ||
engine.dialect.URI().DBType == schemas.POSTGRES {
engine.logger.Infof("Table %s column %s change type from %s to %s\n",
tbNameWithSchema, col.Name, curType, expectedType)
_, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col))
@@ -323,7 +323,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
tbNameWithSchema, col.Name, curType, expectedType)
}
} else if strings.HasPrefix(curType, schemas.Varchar) && strings.HasPrefix(expectedType, schemas.Varchar) {
if engine.dialect.DBType() == schemas.MYSQL {
if engine.dialect.URI().DBType == schemas.MYSQL {
if oriCol.Length < col.Length {
engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
tbNameWithSchema, col.Name, oriCol.Length, col.Length)
@@ -337,7 +337,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
}
}
} else if expectedType == schemas.Varchar {
if engine.dialect.DBType() == schemas.MYSQL {
if engine.dialect.URI().DBType == schemas.MYSQL {
if oriCol.Length < col.Length {
engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
tbNameWithSchema, col.Name, oriCol.Length, col.Length)


+ 6
- 6
session_update.go View File

@@ -335,9 +335,9 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
var top string
if st.LimitN != nil {
limitValue := *st.LimitN
if session.engine.dialect.DBType() == schemas.MYSQL {
if session.engine.dialect.URI().DBType == schemas.MYSQL {
condSQL = condSQL + fmt.Sprintf(" LIMIT %d", limitValue)
} else if session.engine.dialect.DBType() == schemas.SQLITE {
} else if session.engine.dialect.URI().DBType == schemas.SQLITE {
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue)
cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)",
session.engine.Quote(tableName), tempCondSQL), condArgs...))
@@ -348,7 +348,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
if len(condSQL) > 0 {
condSQL = "WHERE " + condSQL
}
} else if session.engine.dialect.DBType() == schemas.POSTGRES {
} else if session.engine.dialect.URI().DBType == schemas.POSTGRES {
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue)
cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)",
session.engine.Quote(tableName), tempCondSQL), condArgs...))
@@ -360,8 +360,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
if len(condSQL) > 0 {
condSQL = "WHERE " + condSQL
}
} else if session.engine.dialect.DBType() == schemas.MSSQL {
if st.OrderStr != "" && session.engine.dialect.DBType() == schemas.MSSQL &&
} else if session.engine.dialect.URI().DBType == schemas.MSSQL {
if st.OrderStr != "" && session.engine.dialect.URI().DBType == schemas.MSSQL &&
table != nil && len(table.PrimaryKeys) == 1 {
cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)",
table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0],
@@ -387,7 +387,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
var tableAlias = session.engine.Quote(tableName)
var fromSQL string
if session.statement.TableAlias != "" {
switch session.engine.dialect.DBType() {
switch session.engine.dialect.URI().DBType {
case schemas.MSSQL:
fromSQL = fmt.Sprintf("FROM %s %s ", tableAlias, session.statement.TableAlias)
tableAlias = session.statement.TableAlias


+ 7
- 7
tags_test.go View File

@@ -238,7 +238,7 @@ func TestExtends2(t *testing.T) {
defer session.Close()

// MSSQL deny insert identity column excep declare as below
if testEngine.Dialect().DBType() == schemas.MSSQL {
if testEngine.Dialect().URI().DBType == schemas.MSSQL {
err = session.Begin()
assert.NoError(t, err)
_, err = session.Exec("SET IDENTITY_INSERT message ON")
@@ -248,7 +248,7 @@ func TestExtends2(t *testing.T) {
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)

if testEngine.Dialect().DBType() == schemas.MSSQL {
if testEngine.Dialect().URI().DBType == schemas.MSSQL {
err = session.Commit()
assert.NoError(t, err)
}
@@ -299,7 +299,7 @@ func TestExtends3(t *testing.T) {
defer session.Close()

// MSSQL deny insert identity column excep declare as below
if testEngine.Dialect().DBType() == schemas.MSSQL {
if testEngine.Dialect().URI().DBType == schemas.MSSQL {
err = session.Begin()
assert.NoError(t, err)
_, err = session.Exec("SET IDENTITY_INSERT message ON")
@@ -308,7 +308,7 @@ func TestExtends3(t *testing.T) {
_, err = session.Insert(&msg)
assert.NoError(t, err)

if testEngine.Dialect().DBType() == schemas.MSSQL {
if testEngine.Dialect().URI().DBType == schemas.MSSQL {
err = session.Commit()
assert.NoError(t, err)
}
@@ -362,7 +362,7 @@ func TestExtends4(t *testing.T) {
defer session.Close()

// MSSQL deny insert identity column excep declare as below
if testEngine.Dialect().DBType() == schemas.MSSQL {
if testEngine.Dialect().URI().DBType == schemas.MSSQL {
err = session.Begin()
assert.NoError(t, err)
_, err = session.Exec("SET IDENTITY_INSERT message ON")
@@ -371,7 +371,7 @@ func TestExtends4(t *testing.T) {
_, err = session.Insert(&msg)
assert.NoError(t, err)

if testEngine.Dialect().DBType() == schemas.MSSQL {
if testEngine.Dialect().URI().DBType == schemas.MSSQL {
err = session.Commit()
assert.NoError(t, err)
}
@@ -800,7 +800,7 @@ func TestAutoIncrTag(t *testing.T) {
func TestTagComment(t *testing.T) {
assert.NoError(t, prepareEngine())
// FIXME: only support mysql
if testEngine.Dialect().DBType() != schemas.MYSQL {
if testEngine.Dialect().URI().DBType != schemas.MYSQL {
return
}



+ 2
- 2
types_test.go View File

@@ -314,7 +314,7 @@ func TestCustomType2(t *testing.T) {
session := testEngine.NewSession()
defer session.Close()

if testEngine.Dialect().DBType() == schemas.MSSQL {
if testEngine.Dialect().URI().DBType == schemas.MSSQL {
err = session.Begin()
assert.NoError(t, err)
_, err = session.Exec("set IDENTITY_INSERT " + tableName + " on")
@@ -325,7 +325,7 @@ func TestCustomType2(t *testing.T) {
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)

if testEngine.Dialect().DBType() == schemas.MSSQL {
if testEngine.Dialect().URI().DBType == schemas.MSSQL {
err = session.Commit()
assert.NoError(t, err)
}


Loading…
Cancel
Save