Fix mssql quote #1535

Merged
lunny merged 2 commits from lunny/fix_mssql_quote into master 2020-02-24 10:03:52 +00:00
7 changed files with 31 additions and 28 deletions

View File

@ -263,6 +263,7 @@ func (db *Base) CreateTableIfNotExists(table *Table, tableName, storeEngine, cha
}*/ }*/
func (db *Base) CreateIndexSQL(tableName string, index *schemas.Index) string { func (db *Base) CreateIndexSQL(tableName string, index *schemas.Index) string {
quotes := db.dialect.Quote("")
quote := db.dialect.Quote quote := db.dialect.Quote
var unique string var unique string
var idxName string var idxName string
@ -272,7 +273,7 @@ func (db *Base) CreateIndexSQL(tableName string, index *schemas.Index) string {
idxName = index.XName(tableName) idxName = index.XName(tableName)
return fmt.Sprintf("CREATE%s INDEX %v ON %v (%v)", unique, return fmt.Sprintf("CREATE%s INDEX %v ON %v (%v)", unique,
quote(idxName), quote(tableName), quote(idxName), quote(tableName),
quote(strings.Join(index.Cols, quote(",")))) quote(strings.Join(index.Cols, fmt.Sprintf("%c,%c", quotes[1], quotes[0]))))
} }
func (db *Base) DropIndexSQL(tableName string, index *schemas.Index) string { func (db *Base) DropIndexSQL(tableName string, index *schemas.Index) string {
@ -300,6 +301,8 @@ func (b *Base) CreateTableSQL(table *schemas.Table, tableName, storeEngine, char
sql += b.dialect.Quote(tableName) sql += b.dialect.Quote(tableName)
sql += " (" sql += " ("
quotes := b.dialect.Quote("")
if len(table.ColumnsSeq()) > 0 { if len(table.ColumnsSeq()) > 0 {
pkList := table.PrimaryKeys pkList := table.PrimaryKeys
@ -319,7 +322,7 @@ func (b *Base) CreateTableSQL(table *schemas.Table, tableName, storeEngine, char
if len(pkList) > 1 { if len(pkList) > 1 {
sql += "PRIMARY KEY ( " sql += "PRIMARY KEY ( "
sql += b.dialect.Quote(strings.Join(pkList, b.dialect.Quote(","))) sql += b.dialect.Quote(strings.Join(pkList, fmt.Sprintf("%c,%c", quotes[1], quotes[0])))
sql += " ), " sql += " ), "
} }

View File

@ -287,7 +287,7 @@ func (db *mssql) IsReserved(name string) bool {
} }
func (db *mssql) Quote(name string) string { func (db *mssql) Quote(name string) string {
return "\"" + name + "\"" return "[" + name + "]"
} }
func (db *mssql) SupportEngine() bool { func (db *mssql) SupportEngine() bool {

View File

@ -507,12 +507,13 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*schemas.Index, error)
} }
func (db *mysql) CreateTableSQL(table *schemas.Table, tableName, storeEngine, charset string) string { func (db *mysql) CreateTableSQL(table *schemas.Table, tableName, storeEngine, charset string) string {
var sql string var sql = "CREATE TABLE IF NOT EXISTS "
sql = "CREATE TABLE IF NOT EXISTS "
if tableName == "" { if tableName == "" {
tableName = table.Name tableName = table.Name
} }
quotes := db.Quote("")
sql += db.Quote(tableName) sql += db.Quote(tableName)
sql += " (" sql += " ("
@ -535,7 +536,7 @@ func (db *mysql) CreateTableSQL(table *schemas.Table, tableName, storeEngine, ch
if len(pkList) > 1 { if len(pkList) > 1 {
sql += "PRIMARY KEY ( " sql += "PRIMARY KEY ( "
sql += db.Quote(strings.Join(pkList, db.Quote(","))) sql += db.Quote(strings.Join(pkList, fmt.Sprintf("%c,%c", quotes[1], quotes[0])))
sql += " ), " sql += " ), "
} }

View File

@ -577,8 +577,7 @@ func (db *oracle) DropTableSQL(tableName string) string {
} }
func (db *oracle) CreateTableSQL(table *schemas.Table, tableName, storeEngine, charset string) string { func (db *oracle) CreateTableSQL(table *schemas.Table, tableName, storeEngine, charset string) string {
var sql string var sql = "CREATE TABLE "
sql = "CREATE TABLE "
if tableName == "" { if tableName == "" {
tableName = table.Name tableName = table.Name
} }
@ -598,9 +597,11 @@ func (db *oracle) CreateTableSQL(table *schemas.Table, tableName, storeEngine, c
sql += ", " sql += ", "
} }
quotes := db.Quote("")
if len(pkList) > 0 { if len(pkList) > 0 {
sql += "PRIMARY KEY ( " sql += "PRIMARY KEY ( "
sql += db.Quote(strings.Join(pkList, db.Quote(","))) sql += db.Quote(strings.Join(pkList, fmt.Sprintf("%c,%c", quotes[1], quotes[0])))
sql += " ), " sql += " ), "
} }

View File

@ -23,6 +23,11 @@ type Index struct {
Cols []string Cols []string
} }
// NewIndex new an index object
func NewIndex(name string, indexType int) *Index {
return &Index{true, name, indexType, make([]string, 0)}
}
func (index *Index) XName(tableName string) string { func (index *Index) XName(tableName string) string {
if !strings.HasPrefix(index.Name, "UQE_") && if !strings.HasPrefix(index.Name, "UQE_") &&
!strings.HasPrefix(index.Name, "IDX_") { !strings.HasPrefix(index.Name, "IDX_") {
@ -65,8 +70,3 @@ func (index *Index) Equal(dst *Index) bool {
} }
return true return true
} }
// NewIndex new an index object
func NewIndex(name string, indexType int) *Index {
return &Index{true, name, indexType, make([]string, 0)}
}

View File

@ -30,14 +30,6 @@ type Table struct {
Comment string Comment string
} }
func (table *Table) Columns() []*Column {
return table.columns
}
func (table *Table) ColumnsSeq() []string {
return table.columnsSeq
}
func NewEmptyTable() *Table { func NewEmptyTable() *Table {
return NewTable("", nil) return NewTable("", nil)
} }
@ -54,9 +46,16 @@ func NewTable(name string, t reflect.Type) *Table {
} }
} }
func (table *Table) Columns() []*Column {
return table.columns
}
func (table *Table) ColumnsSeq() []string {
return table.columnsSeq
}
func (table *Table) columnsByName(name string) []*Column { func (table *Table) columnsByName(name string) []*Column {
n := len(name) n := len(name)
for k := range table.columnsMap { for k := range table.columnsMap {
if len(k) != n { if len(k) != n {
continue continue
@ -69,9 +68,7 @@ func (table *Table) columnsByName(name string) []*Column {
} }
func (table *Table) GetColumn(name string) *Column { func (table *Table) GetColumn(name string) *Column {
cols := table.columnsByName(name) cols := table.columnsByName(name)
if cols != nil { if cols != nil {
return cols[0] return cols[0]
} }
@ -81,7 +78,6 @@ func (table *Table) GetColumn(name string) *Column {
func (table *Table) GetColumnIdx(name string, idx int) *Column { func (table *Table) GetColumnIdx(name string, idx int) *Column {
cols := table.columnsByName(name) cols := table.columnsByName(name)
if cols != nil && idx < len(cols) { if cols != nil && idx < len(cols) {
return cols[idx] return cols[idx]
} }

View File

@ -64,8 +64,10 @@ func TestJoinLimit(t *testing.T) {
func assertSync(t *testing.T, beans ...interface{}) { func assertSync(t *testing.T, beans ...interface{}) {
for _, bean := range beans { for _, bean := range beans {
assert.NoError(t, testEngine.DropTables(bean)) t.Run(testEngine.TableName(bean, true), func(t *testing.T) {
assert.NoError(t, testEngine.Sync2(bean)) assert.NoError(t, testEngine.DropTables(bean))
assert.NoError(t, testEngine.Sync2(bean))
})
} }
} }