WIP: Add dropTableCols from gitea (dropTableColumns at migrations.go) #1523
18
engine.go
18
engine.go
|
@ -1010,6 +1010,24 @@ func (engine *Engine) DropTables(beans ...interface{}) error {
|
|||
return session.Commit()
|
||||
}
|
||||
|
||||
// DropTableCols drop specify columns of a table
|
||||
func (engine *Engine) DropTableCols(bean interface{}, cols ...string) error {
|
||||
session := engine.NewSession()
|
||||
defer session.Close()
|
||||
|
||||
err := session.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = session.dropTableCols(bean, cols)
|
||||
if err != nil {
|
||||
session.Rollback()
|
||||
return err
|
||||
}
|
||||
return session.Commit()
|
||||
}
|
||||
|
||||
// DropIndexes drop indexes of a table
|
||||
func (engine *Engine) DropIndexes(bean interface{}) error {
|
||||
session := engine.NewSession()
|
||||
|
|
|
@ -330,3 +330,24 @@ func TestSync2_Default(t *testing.T) {
|
|||
assertSync(t, new(TestSync2Default))
|
||||
assert.NoError(t, testEngine.Sync2(new(TestSync2Default)))
|
||||
}
|
||||
|
||||
func TestDropTableCols(t *testing.T) {
|
||||
type TestDropTableCols struct {
|
||||
Id int64
|
||||
UserId int64 `xorm:"default(1)"`
|
||||
ToDrop bool `xorm:"default(true)"`
|
||||
Name string `xorm:"default('my_name')"`
|
||||
}
|
||||
|
||||
assert.NoError(t, PrepareEngine())
|
||||
assertSync(t, new(TestDropTableCols))
|
||||
|
||||
schema, err := testEngine.TableInfo(new(TestDropTableCols))
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, schema.GetColumn("to_drop"))
|
||||
|
||||
assert.NoError(t, testEngine.DropTableCols(new(TestDropTableCols), "name", "to_drop"))
|
||||
schema, err = testEngine.TableInfo(new(TestDropTableCols))
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, schema.GetColumn("to_drop"))
|
||||
}
|
||||
|
|
|
@ -33,6 +33,7 @@ type Interface interface {
|
|||
Delete(interface{}) (int64, error)
|
||||
Distinct(columns ...string) *Session
|
||||
DropIndexes(bean interface{}) error
|
||||
DropTableCols(bean interface{}, cols ...string) error
|
||||
Exec(sqlOrArgs ...interface{}) (sql.Result, error)
|
||||
Exist(bean ...interface{}) (bool, error)
|
||||
Find(interface{}, ...interface{}) error
|
||||
|
|
|
@ -27,4 +27,3 @@ func SplitNNoCase(s, sep string, n int) []string {
|
|||
}
|
||||
return strings.SplitN(s, s[idx:idx+len(sep)], n)
|
||||
}
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"xorm.io/xorm/internal/utils"
|
||||
|
@ -121,6 +122,178 @@ func (session *Session) dropIndexes(bean interface{}) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// DropTableCols drop specify columns of a table
|
||||
func (session *Session) DropTableCols(beanOrTableName interface{}, cols ...string) error {
|
||||
if session.isAutoClose {
|
||||
defer session.Close()
|
||||
}
|
||||
|
||||
return session.dropTableCols(beanOrTableName, cols)
|
||||
}
|
||||
|
||||
func (session *Session) dropTableCols(beanOrTableName interface{}, cols []string) error {
|
||||
tableName := session.engine.TableName(beanOrTableName)
|
||||
|
||||
if tableName == "" || len(cols) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: This will not work if there are foreign keys
|
||||
|
||||
switch session.engine.dialect.URI().DBType {
|
||||
case schemas.SQLITE:
|
||||
// First drop the indexes on the columns
|
||||
res, errIndex := session.Query(fmt.Sprintf("PRAGMA index_list(`%s`)", tableName))
|
||||
if errIndex != nil {
|
||||
return errIndex
|
||||
}
|
||||
for _, row := range res {
|
||||
indexName := row["name"]
|
||||
indexRes, err := session.Query(fmt.Sprintf("PRAGMA index_info(`%s`)", indexName))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(indexRes) != 1 {
|
||||
continue
|
||||
}
|
||||
indexColumn := string(indexRes[0]["name"])
|
||||
for _, name := range cols {
|
||||
if name == indexColumn {
|
||||
_, err := session.Exec(fmt.Sprintf("DROP INDEX `%s`", indexName))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Here we need to get the columns from the original table
|
||||
sql := fmt.Sprintf("SELECT sql FROM sqlite_master WHERE tbl_name='%s' and type='table'", tableName)
|
||||
res, err := session.Query(sql)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tableSQL := string(res[0]["sql"])
|
||||
|
||||
// Separate out the column definitions
|
||||
tableSQL = tableSQL[strings.Index(tableSQL, "("):]
|
||||
|
||||
// Remove the required cols
|
||||
for _, name := range cols {
|
||||
tableSQL = regexp.MustCompile(regexp.QuoteMeta("`"+name+"`")+"[^`,)]*?[,)]").ReplaceAllString(tableSQL, "")
|
||||
}
|
||||
|
||||
// Ensure the query is ended properly
|
||||
tableSQL = strings.TrimSpace(tableSQL)
|
||||
if tableSQL[len(tableSQL)-1] != ')' {
|
||||
if tableSQL[len(tableSQL)-1] == ',' {
|
||||
tableSQL = tableSQL[:len(tableSQL)-1]
|
||||
}
|
||||
tableSQL += ")"
|
||||
}
|
||||
|
||||
// Find all the columns in the table
|
||||
columns := regexp.MustCompile("`([^`]*)`").FindAllString(tableSQL, -1)
|
||||
|
||||
tableSQL = fmt.Sprintf("CREATE TABLE `new_%s_new` ", tableName) + tableSQL
|
||||
if _, err := session.Exec(tableSQL); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Now restore the data
|
||||
columnsSeparated := strings.Join(columns, ",")
|
||||
insertSQL := fmt.Sprintf("INSERT INTO `new_%s_new` (%s) SELECT %s FROM %s", tableName, columnsSeparated, columnsSeparated, tableName)
|
||||
if _, err := session.Exec(insertSQL); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Now drop the old table
|
||||
if _, err := session.Exec(fmt.Sprintf("DROP TABLE `%s`", tableName)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Rename the table
|
||||
if _, err := session.Exec(fmt.Sprintf("ALTER TABLE `new_%s_new` RENAME TO `%s`", tableName, tableName)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case schemas.POSTGRES:
|
||||
columns := ""
|
||||
for _, col := range cols {
|
||||
if columns != "" {
|
||||
columns += ", "
|
||||
}
|
||||
columns += "DROP COLUMN `" + col + "` CASCADE"
|
||||
}
|
||||
if _, err := session.Exec(fmt.Sprintf("ALTER TABLE `%s` %s", tableName, columns)); err != nil {
|
||||
return fmt.Errorf("Drop table `%s` columns %v: %v", tableName, cols, err)
|
||||
}
|
||||
|
||||
case schemas.MYSQL:
|
||||
// Drop indexes on columns first
|
||||
sql := fmt.Sprintf("SHOW INDEX FROM %s WHERE column_name IN ('%s')", tableName, strings.Join(cols, "','"))
|
||||
res, err := session.Query(sql)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, index := range res {
|
||||
indexName := index["column_name"]
|
||||
if len(indexName) > 0 {
|
||||
_, err := session.Exec(fmt.Sprintf("DROP INDEX `%s` ON `%s`", indexName, tableName))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Now drop the columns
|
||||
columns := ""
|
||||
for _, col := range cols {
|
||||
if columns != "" {
|
||||
columns += ", "
|
||||
}
|
||||
columns += "DROP COLUMN `" + col + "`"
|
||||
}
|
||||
if _, err := session.Exec(fmt.Sprintf("ALTER TABLE `%s` %s", tableName, columns)); err != nil {
|
||||
return fmt.Errorf("Drop table `%s` columns %v: %v", tableName, cols, err)
|
||||
}
|
||||
|
||||
case schemas.MSSQL:
|
||||
columns := ""
|
||||
for _, col := range cols {
|
||||
if columns != "" {
|
||||
columns += ", "
|
||||
}
|
||||
columns += "`" + strings.ToLower(col) + "`"
|
||||
}
|
||||
sql := fmt.Sprintf("SELECT Name FROM SYS.DEFAULT_CONSTRAINTS WHERE PARENT_OBJECT_ID = OBJECT_ID('%[1]s') AND PARENT_COLUMN_ID IN (SELECT column_id FROM sys.columns WHERE lower(NAME) IN (%[2]s) AND object_id = OBJECT_ID('%[1]s'))",
|
||||
tableName, strings.Replace(columns, "`", "'", -1))
|
||||
constraints := make([]string, 0)
|
||||
if err := session.SQL(sql).Find(&constraints); err != nil {
|
||||
session.Rollback()
|
||||
return fmt.Errorf("Find constraints: %v", err)
|
||||
}
|
||||
for _, constraint := range constraints {
|
||||
if _, err := session.Exec(fmt.Sprintf("ALTER TABLE `%s` DROP CONSTRAINT `%s`", tableName, constraint)); err != nil {
|
||||
session.Rollback()
|
||||
return fmt.Errorf("Drop table `%s` constraint `%s`: %v", tableName, constraint, err)
|
||||
}
|
||||
}
|
||||
if _, err := session.Exec(fmt.Sprintf("ALTER TABLE `%s` DROP COLUMN %s", tableName, columns)); err != nil {
|
||||
session.Rollback()
|
||||
return fmt.Errorf("Drop table `%s` columns %v: %v", tableName, cols, err)
|
||||
}
|
||||
|
||||
case schemas.ORACLE:
|
||||
return fmt.Errorf("not implemented for oracle")
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unrecognized DB")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DropTable drop table will drop table if exist, if drop failed, it will return error
|
||||
func (session *Session) DropTable(beanOrTableName interface{}) error {
|
||||
if session.isAutoClose {
|
||||
|
|
Loading…
Reference in New Issue
Block a user