Fix mssql quote #1535

Merged
lunny merged 2 commits from lunny/fix_mssql_quote into master 2020-02-24 10:03:52 +00:00
7 changed files with 31 additions and 28 deletions

View File

@ -263,6 +263,7 @@ func (db *Base) CreateTableIfNotExists(table *Table, tableName, storeEngine, cha
}*/
func (db *Base) CreateIndexSQL(tableName string, index *schemas.Index) string {
quotes := db.dialect.Quote("")
quote := db.dialect.Quote
var unique string
var idxName string
@ -272,7 +273,7 @@ func (db *Base) CreateIndexSQL(tableName string, index *schemas.Index) string {
idxName = index.XName(tableName)
return fmt.Sprintf("CREATE%s INDEX %v ON %v (%v)", unique,
quote(idxName), quote(tableName),
quote(strings.Join(index.Cols, quote(","))))
quote(strings.Join(index.Cols, fmt.Sprintf("%c,%c", quotes[1], quotes[0]))))
}
func (db *Base) DropIndexSQL(tableName string, index *schemas.Index) string {
@ -300,6 +301,8 @@ func (b *Base) CreateTableSQL(table *schemas.Table, tableName, storeEngine, char
sql += b.dialect.Quote(tableName)
sql += " ("
quotes := b.dialect.Quote("")
if len(table.ColumnsSeq()) > 0 {
pkList := table.PrimaryKeys
@ -319,7 +322,7 @@ func (b *Base) CreateTableSQL(table *schemas.Table, tableName, storeEngine, char
if len(pkList) > 1 {
sql += "PRIMARY KEY ( "
sql += b.dialect.Quote(strings.Join(pkList, b.dialect.Quote(",")))
sql += b.dialect.Quote(strings.Join(pkList, fmt.Sprintf("%c,%c", quotes[1], quotes[0])))
sql += " ), "
}

View File

@ -287,7 +287,7 @@ func (db *mssql) IsReserved(name string) bool {
}
func (db *mssql) Quote(name string) string {
return "\"" + name + "\""
return "[" + name + "]"
}
func (db *mssql) SupportEngine() bool {

View File

@ -507,12 +507,13 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*schemas.Index, error)
}
func (db *mysql) CreateTableSQL(table *schemas.Table, tableName, storeEngine, charset string) string {
var sql string
sql = "CREATE TABLE IF NOT EXISTS "
var sql = "CREATE TABLE IF NOT EXISTS "
if tableName == "" {
tableName = table.Name
}
quotes := db.Quote("")
sql += db.Quote(tableName)
sql += " ("
@ -535,7 +536,7 @@ func (db *mysql) CreateTableSQL(table *schemas.Table, tableName, storeEngine, ch
if len(pkList) > 1 {
sql += "PRIMARY KEY ( "
sql += db.Quote(strings.Join(pkList, db.Quote(",")))
sql += db.Quote(strings.Join(pkList, fmt.Sprintf("%c,%c", quotes[1], quotes[0])))
sql += " ), "
}

View File

@ -577,8 +577,7 @@ func (db *oracle) DropTableSQL(tableName string) string {
}
func (db *oracle) CreateTableSQL(table *schemas.Table, tableName, storeEngine, charset string) string {
var sql string
sql = "CREATE TABLE "
var sql = "CREATE TABLE "
if tableName == "" {
tableName = table.Name
}
@ -598,9 +597,11 @@ func (db *oracle) CreateTableSQL(table *schemas.Table, tableName, storeEngine, c
sql += ", "
}
quotes := db.Quote("")
if len(pkList) > 0 {
sql += "PRIMARY KEY ( "
sql += db.Quote(strings.Join(pkList, db.Quote(",")))
sql += db.Quote(strings.Join(pkList, fmt.Sprintf("%c,%c", quotes[1], quotes[0])))
sql += " ), "
}

View File

@ -23,6 +23,11 @@ type Index struct {
Cols []string
}
// NewIndex new an index object
func NewIndex(name string, indexType int) *Index {
return &Index{true, name, indexType, make([]string, 0)}
}
func (index *Index) XName(tableName string) string {
if !strings.HasPrefix(index.Name, "UQE_") &&
!strings.HasPrefix(index.Name, "IDX_") {
@ -65,8 +70,3 @@ func (index *Index) Equal(dst *Index) bool {
}
return true
}
// NewIndex new an index object
func NewIndex(name string, indexType int) *Index {
return &Index{true, name, indexType, make([]string, 0)}
}

View File

@ -30,14 +30,6 @@ type Table struct {
Comment string
}
func (table *Table) Columns() []*Column {
return table.columns
}
func (table *Table) ColumnsSeq() []string {
return table.columnsSeq
}
func NewEmptyTable() *Table {
return NewTable("", nil)
}
@ -54,9 +46,16 @@ func NewTable(name string, t reflect.Type) *Table {
}
}
func (table *Table) Columns() []*Column {
return table.columns
}
func (table *Table) ColumnsSeq() []string {
return table.columnsSeq
}
func (table *Table) columnsByName(name string) []*Column {
n := len(name)
for k := range table.columnsMap {
if len(k) != n {
continue
@ -69,9 +68,7 @@ func (table *Table) columnsByName(name string) []*Column {
}
func (table *Table) GetColumn(name string) *Column {
cols := table.columnsByName(name)
if cols != nil {
return cols[0]
}
@ -81,7 +78,6 @@ func (table *Table) GetColumn(name string) *Column {
func (table *Table) GetColumnIdx(name string, idx int) *Column {
cols := table.columnsByName(name)
if cols != nil && idx < len(cols) {
return cols[idx]
}

View File

@ -64,8 +64,10 @@ func TestJoinLimit(t *testing.T) {
func assertSync(t *testing.T, beans ...interface{}) {
for _, bean := range beans {
assert.NoError(t, testEngine.DropTables(bean))
assert.NoError(t, testEngine.Sync2(bean))
t.Run(testEngine.TableName(bean, true), func(t *testing.T) {
assert.NoError(t, testEngine.DropTables(bean))
assert.NoError(t, testEngine.Sync2(bean))
})
}
}