WIP: Add dropTableCols from gitea (dropTableColumns at migrations.go) #1523

Closed
6543 wants to merge 7 commits from DropTableCols_fromGitea into master
5 changed files with 213 additions and 1 deletions

View File

@ -1010,6 +1010,24 @@ func (engine *Engine) DropTables(beans ...interface{}) error {
return session.Commit()
}
// DropTableCols drop specify columns of a table
func (engine *Engine) DropTableCols(bean interface{}, cols ...string) error {
session := engine.NewSession()
defer session.Close()
err := session.Begin()
if err != nil {
return err
}
err = session.dropTableCols(bean, cols)
if err != nil {
session.Rollback()
return err
}
return session.Commit()
}
// DropIndexes drop indexes of a table
func (engine *Engine) DropIndexes(bean interface{}) error {
session := engine.NewSession()

View File

@ -330,3 +330,24 @@ func TestSync2_Default(t *testing.T) {
assertSync(t, new(TestSync2Default))
assert.NoError(t, testEngine.Sync2(new(TestSync2Default)))
}
func TestDropTableCols(t *testing.T) {
type TestDropTableCols struct {
Id int64
UserId int64 `xorm:"default(1)"`
ToDrop bool `xorm:"default(true)"`
Name string `xorm:"default('my_name')"`
}
assert.NoError(t, PrepareEngine())
assertSync(t, new(TestDropTableCols))
schema, err := testEngine.TableInfo(new(TestDropTableCols))
assert.NoError(t, err)
assert.NotNil(t, schema.GetColumn("to_drop"))
assert.NoError(t, testEngine.DropTableCols(new(TestDropTableCols), "name", "to_drop"))
schema, err = testEngine.TableInfo(new(TestDropTableCols))
assert.NoError(t, err)
assert.Nil(t, schema.GetColumn("to_drop"))
}

View File

@ -33,6 +33,7 @@ type Interface interface {
Delete(interface{}) (int64, error)
Distinct(columns ...string) *Session
DropIndexes(bean interface{}) error
DropTableCols(bean interface{}, cols ...string) error
Exec(sqlOrArgs ...interface{}) (sql.Result, error)
Exist(bean ...interface{}) (bool, error)
Find(interface{}, ...interface{}) error

View File

@ -27,4 +27,3 @@ func SplitNNoCase(s, sep string, n int) []string {
}
return strings.SplitN(s, s[idx:idx+len(sep)], n)
}

View File

@ -10,6 +10,7 @@ import (
"fmt"
"io"
"os"
"regexp"
"strings"
"xorm.io/xorm/internal/utils"
@ -121,6 +122,178 @@ func (session *Session) dropIndexes(bean interface{}) error {
return nil
}
// DropTableCols drop specify columns of a table
func (session *Session) DropTableCols(beanOrTableName interface{}, cols ...string) error {
if session.isAutoClose {
defer session.Close()
}
return session.dropTableCols(beanOrTableName, cols)
}
func (session *Session) dropTableCols(beanOrTableName interface{}, cols []string) error {
tableName := session.engine.TableName(beanOrTableName)
if tableName == "" || len(cols) == 0 {
return nil
}
// TODO: This will not work if there are foreign keys
switch session.engine.dialect.URI().DBType {
case schemas.SQLITE:
// First drop the indexes on the columns
res, errIndex := session.Query(fmt.Sprintf("PRAGMA index_list(`%s`)", tableName))
if errIndex != nil {
return errIndex
}
for _, row := range res {
indexName := row["name"]
indexRes, err := session.Query(fmt.Sprintf("PRAGMA index_info(`%s`)", indexName))
if err != nil {
return err
}
if len(indexRes) != 1 {
continue
}
indexColumn := string(indexRes[0]["name"])
for _, name := range cols {
if name == indexColumn {
_, err := session.Exec(fmt.Sprintf("DROP INDEX `%s`", indexName))
if err != nil {
return err
}
}
}
}
// Here we need to get the columns from the original table
sql := fmt.Sprintf("SELECT sql FROM sqlite_master WHERE tbl_name='%s' and type='table'", tableName)
res, err := session.Query(sql)
if err != nil {
return err
}
tableSQL := string(res[0]["sql"])
// Separate out the column definitions
tableSQL = tableSQL[strings.Index(tableSQL, "("):]
// Remove the required cols
for _, name := range cols {
tableSQL = regexp.MustCompile(regexp.QuoteMeta("`"+name+"`")+"[^`,)]*?[,)]").ReplaceAllString(tableSQL, "")
}
// Ensure the query is ended properly
tableSQL = strings.TrimSpace(tableSQL)
if tableSQL[len(tableSQL)-1] != ')' {
if tableSQL[len(tableSQL)-1] == ',' {
tableSQL = tableSQL[:len(tableSQL)-1]
}
tableSQL += ")"
}
// Find all the columns in the table
columns := regexp.MustCompile("`([^`]*)`").FindAllString(tableSQL, -1)
tableSQL = fmt.Sprintf("CREATE TABLE `new_%s_new` ", tableName) + tableSQL
if _, err := session.Exec(tableSQL); err != nil {
return err
}
// Now restore the data
columnsSeparated := strings.Join(columns, ",")
insertSQL := fmt.Sprintf("INSERT INTO `new_%s_new` (%s) SELECT %s FROM %s", tableName, columnsSeparated, columnsSeparated, tableName)
if _, err := session.Exec(insertSQL); err != nil {
return err
}
// Now drop the old table
if _, err := session.Exec(fmt.Sprintf("DROP TABLE `%s`", tableName)); err != nil {
return err
}
// Rename the table
if _, err := session.Exec(fmt.Sprintf("ALTER TABLE `new_%s_new` RENAME TO `%s`", tableName, tableName)); err != nil {
return err
}
case schemas.POSTGRES:
columns := ""
for _, col := range cols {
if columns != "" {
columns += ", "
}
columns += "DROP COLUMN `" + col + "` CASCADE"
}
if _, err := session.Exec(fmt.Sprintf("ALTER TABLE `%s` %s", tableName, columns)); err != nil {
return fmt.Errorf("Drop table `%s` columns %v: %v", tableName, cols, err)
}
case schemas.MYSQL:
// Drop indexes on columns first
sql := fmt.Sprintf("SHOW INDEX FROM %s WHERE column_name IN ('%s')", tableName, strings.Join(cols, "','"))
res, err := session.Query(sql)
if err != nil {
return err
}
for _, index := range res {
indexName := index["column_name"]
if len(indexName) > 0 {
_, err := session.Exec(fmt.Sprintf("DROP INDEX `%s` ON `%s`", indexName, tableName))
if err != nil {
return err
}
}
}
// Now drop the columns
columns := ""
for _, col := range cols {
if columns != "" {
columns += ", "
}
columns += "DROP COLUMN `" + col + "`"
}
if _, err := session.Exec(fmt.Sprintf("ALTER TABLE `%s` %s", tableName, columns)); err != nil {
return fmt.Errorf("Drop table `%s` columns %v: %v", tableName, cols, err)
}
case schemas.MSSQL:
columns := ""
for _, col := range cols {
if columns != "" {
columns += ", "
}
columns += "`" + strings.ToLower(col) + "`"
}
sql := fmt.Sprintf("SELECT Name FROM SYS.DEFAULT_CONSTRAINTS WHERE PARENT_OBJECT_ID = OBJECT_ID('%[1]s') AND PARENT_COLUMN_ID IN (SELECT column_id FROM sys.columns WHERE lower(NAME) IN (%[2]s) AND object_id = OBJECT_ID('%[1]s'))",
tableName, strings.Replace(columns, "`", "'", -1))
constraints := make([]string, 0)
if err := session.SQL(sql).Find(&constraints); err != nil {
session.Rollback()
return fmt.Errorf("Find constraints: %v", err)
}
for _, constraint := range constraints {
if _, err := session.Exec(fmt.Sprintf("ALTER TABLE `%s` DROP CONSTRAINT `%s`", tableName, constraint)); err != nil {
session.Rollback()
return fmt.Errorf("Drop table `%s` constraint `%s`: %v", tableName, constraint, err)
}
}
if _, err := session.Exec(fmt.Sprintf("ALTER TABLE `%s` DROP COLUMN %s", tableName, columns)); err != nil {
session.Rollback()
return fmt.Errorf("Drop table `%s` columns %v: %v", tableName, cols, err)
}
case schemas.ORACLE:
return fmt.Errorf("not implemented for oracle")
default:
return fmt.Errorf("unrecognized DB")
}
return nil
}
// DropTable drop table will drop table if exist, if drop failed, it will return error
func (session *Session) DropTable(beanOrTableName interface{}) error {
if session.isAutoClose {