Browse Source

Move column string to standalone method (#1633)

Move column string to standalone method

Reviewed-on: #1633
pull/1632/head
Lunny Xiao 1 week ago
parent
commit
6254e7899f
6 changed files with 76 additions and 71 deletions
  1. +64
    -48
      dialects/dialect.go
  2. +2
    -5
      dialects/mssql.go
  3. +4
    -7
      dialects/mysql.go
  4. +2
    -1
      dialects/oracle.go
  5. +2
    -5
      dialects/postgres.go
  6. +2
    -5
      dialects/sqlite3.go

+ 64
- 48
dialects/dialect.go View File

@@ -96,51 +96,6 @@ func (b *Base) DBType() schemas.DBType {
return b.uri.DBType
}

// String generate column description string according dialect
func (b *Base) String(col *schemas.Column) string {
sql := b.dialect.Quoter().Quote(col.Name) + " "

sql += b.dialect.SQLType(col) + " "

if col.IsPrimaryKey {
sql += "PRIMARY KEY "
if col.IsAutoIncrement {
sql += b.dialect.AutoIncrStr() + " "
}
}

if col.Default != "" {
sql += "DEFAULT " + col.Default + " "
}

if col.Nullable {
sql += "NULL "
} else {
sql += "NOT NULL "
}

return sql
}

// StringNoPk generate column description string according dialect without primary keys
func (b *Base) StringNoPk(col *schemas.Column) string {
sql := b.dialect.Quoter().Quote(col.Name) + " "

sql += b.dialect.SQLType(col) + " "

if col.Default != "" {
sql += "DEFAULT " + col.Default + " "
}

if col.Nullable {
sql += "NULL "
} else {
sql += "NOT NULL "
}

return sql
}

func (b *Base) FormatBytes(bs []byte) string {
return fmt.Sprintf("0x%x", bs)
}
@@ -178,8 +133,8 @@ func (db *Base) IsColumnExist(queryer core.Queryer, ctx context.Context, tableNa
}

func (db *Base) AddColumnSQL(tableName string, col *schemas.Column) string {
return fmt.Sprintf("ALTER TABLE %v ADD %v", db.dialect.Quoter().Quote(tableName),
db.String(col))
s, _ := ColumnString(db.dialect, col, true)
return fmt.Sprintf("ALTER TABLE %v ADD %v", db.dialect.Quoter().Quote(tableName), s)
}

func (db *Base) CreateIndexSQL(tableName string, index *schemas.Index) string {
@@ -207,7 +162,8 @@ func (db *Base) DropIndexSQL(tableName string, index *schemas.Index) string {
}

func (db *Base) ModifyColumnSQL(tableName string, col *schemas.Column) string {
return fmt.Sprintf("alter table %s MODIFY COLUMN %s", tableName, db.StringNoPk(col))
s, _ := ColumnString(db.dialect, col, false)
return fmt.Sprintf("alter table %s MODIFY COLUMN %s", tableName, s)
}

func (b *Base) ForUpdateSQL(query string) string {
@@ -266,3 +222,63 @@ func regDrvsNDialects() bool {
func init() {
regDrvsNDialects()
}

// ColumnString generate column description string according dialect
func ColumnString(dialect Dialect, col *schemas.Column, includePrimaryKey bool) (string, error) {
bd := strings.Builder{}

if err := dialect.Quoter().QuoteTo(&bd, col.Name); err != nil {
return "", err
}

if err := bd.WriteByte(' '); err != nil {
return "", err
}

if _, err := bd.WriteString(dialect.SQLType(col)); err != nil {
return "", err
}

if err := bd.WriteByte(' '); err != nil {
return "", err
}

if includePrimaryKey && col.IsPrimaryKey {
if _, err := bd.WriteString("PRIMARY KEY "); err != nil {
return "", err
}

if col.IsAutoIncrement {
if _, err := bd.WriteString(dialect.AutoIncrStr()); err != nil {
return "", err
}
if err := bd.WriteByte(' '); err != nil {
return "", err
}
}
}

if col.Default != "" {
if _, err := bd.WriteString("DEFAULT "); err != nil {
return "", err
}
if _, err := bd.WriteString(col.Default); err != nil {
return "", err
}
if err := bd.WriteByte(' '); err != nil {
return "", err
}
}

if col.Nullable {
if _, err := bd.WriteString("NULL "); err != nil {
return "", err
}
} else {
if _, err := bd.WriteString("NOT NULL "); err != nil {
return "", err
}
}

return bd.String(), nil
}

+ 2
- 5
dialects/mssql.go View File

@@ -501,11 +501,8 @@ func (db *mssql) CreateTableSQL(table *schemas.Table, tableName string) ([]strin

for _, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName)
if col.IsPrimaryKey && len(pkList) == 1 {
sql += db.String(col)
} else {
sql += db.StringNoPk(col)
}
s, _ := ColumnString(db, col, col.IsPrimaryKey && len(pkList) == 1)
sql += s
sql = strings.TrimSpace(sql)
sql += ", "
}


+ 4
- 7
dialects/mysql.go View File

@@ -293,8 +293,8 @@ func (db *mysql) IsTableExist(queryer core.Queryer, ctx context.Context, tableNa

func (db *mysql) AddColumnSQL(tableName string, col *schemas.Column) string {
quoter := db.dialect.Quoter()
sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quoter.Quote(tableName),
db.String(col))
s, _ := ColumnString(db, col, true)
sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quoter.Quote(tableName), s)
if len(col.Comment) > 0 {
sql += " COMMENT '" + col.Comment + "'"
}
@@ -525,11 +525,8 @@ func (db *mysql) CreateTableSQL(table *schemas.Table, tableName string) ([]strin

for _, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName)
if col.IsPrimaryKey && len(pkList) == 1 {
sql += db.String(col)
} else {
sql += db.StringNoPk(col)
}
s, _ := ColumnString(db, col, col.IsPrimaryKey && len(pkList) == 1)
sql += s
sql = strings.TrimSpace(sql)
if len(col.Comment) > 0 {
sql += " COMMENT '" + col.Comment + "'"


+ 2
- 1
dialects/oracle.go View File

@@ -572,7 +572,8 @@ func (db *oracle) CreateTableSQL(table *schemas.Table, tableName string) ([]stri
/*if col.IsPrimaryKey && len(pkList) == 1 {
sql += col.String(b.dialect)
} else {*/
sql += db.StringNoPk(col)
s, _ := ColumnString(db, col, false)
sql += s
// }
sql = strings.TrimSpace(sql)
sql += ", "


+ 2
- 5
dialects/postgres.go View File

@@ -908,11 +908,8 @@ func (db *postgres) CreateTableSQL(table *schemas.Table, tableName string) ([]st

for _, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName)
if col.IsPrimaryKey && len(pkList) == 1 {
sql += db.String(col)
} else {
sql += db.StringNoPk(col)
}
s, _ := ColumnString(db, col, col.IsPrimaryKey && len(pkList) == 1)
sql += s
sql = strings.TrimSpace(sql)
sql += ", "
}


+ 2
- 5
dialects/sqlite3.go View File

@@ -260,11 +260,8 @@ func (db *sqlite3) CreateTableSQL(table *schemas.Table, tableName string) ([]str

for _, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName)
if col.IsPrimaryKey && len(pkList) == 1 {
sql += db.String(col)
} else {
sql += db.StringNoPk(col)
}
s, _ := ColumnString(db, col, col.IsPrimaryKey && len(pkList) == 1)
sql += s
sql = strings.TrimSpace(sql)
sql += ", "
}


Loading…
Cancel
Save