diff --git a/engine.go b/engine.go index d5e599d6..415191d7 100644 --- a/engine.go +++ b/engine.go @@ -1010,6 +1010,24 @@ func (engine *Engine) DropTables(beans ...interface{}) error { return session.Commit() } +// DropTableCols drop specify columns of a table +func (engine *Engine) DropTableCols(bean interface{}, cols ...string) error { + session := engine.NewSession() + defer session.Close() + + err := session.Begin() + if err != nil { + return err + } + + err = session.dropTableCols(bean, cols) + if err != nil { + session.Rollback() + return err + } + return session.Commit() +} + // DropIndexes drop indexes of a table func (engine *Engine) DropIndexes(bean interface{}) error { session := engine.NewSession() diff --git a/integrations/session_schema_test.go b/integrations/session_schema_test.go index 005b6619..74e067c1 100644 --- a/integrations/session_schema_test.go +++ b/integrations/session_schema_test.go @@ -330,3 +330,24 @@ func TestSync2_Default(t *testing.T) { assertSync(t, new(TestSync2Default)) assert.NoError(t, testEngine.Sync2(new(TestSync2Default))) } + +func TestDropTableCols(t *testing.T) { + type TestDropTableCols struct { + Id int64 + UserId int64 `xorm:"default(1)"` + ToDrop bool `xorm:"default(true)"` + Name string `xorm:"default('my_name')"` + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(TestDropTableCols)) + + schema, err := testEngine.TableInfo(new(TestDropTableCols)) + assert.NoError(t, err) + assert.NotNil(t, schema.GetColumn("to_drop")) + + assert.NoError(t, testEngine.DropTableCols(new(TestDropTableCols), "name", "to_drop")) + schema, err = testEngine.TableInfo(new(TestDropTableCols)) + assert.NoError(t, err) + assert.Nil(t, schema.GetColumn("to_drop")) +} diff --git a/interface.go b/interface.go index 6aac4ae8..ab150fe6 100644 --- a/interface.go +++ b/interface.go @@ -33,6 +33,7 @@ type Interface interface { Delete(interface{}) (int64, error) Distinct(columns ...string) *Session DropIndexes(bean interface{}) error + DropTableCols(bean interface{}, cols ...string) error Exec(sqlOrArgs ...interface{}) (sql.Result, error) Exist(bean ...interface{}) (bool, error) Find(interface{}, ...interface{}) error diff --git a/internal/utils/strings.go b/internal/utils/strings.go index b5dc37b7..72466705 100644 --- a/internal/utils/strings.go +++ b/internal/utils/strings.go @@ -27,4 +27,3 @@ func SplitNNoCase(s, sep string, n int) []string { } return strings.SplitN(s, s[idx:idx+len(sep)], n) } - diff --git a/session_schema.go b/session_schema.go index 9ccf8abe..93a7416b 100644 --- a/session_schema.go +++ b/session_schema.go @@ -10,6 +10,7 @@ import ( "fmt" "io" "os" + "regexp" "strings" "xorm.io/xorm/internal/utils" @@ -121,6 +122,178 @@ func (session *Session) dropIndexes(bean interface{}) error { return nil } +// DropTableCols drop specify columns of a table +func (session *Session) DropTableCols(beanOrTableName interface{}, cols ...string) error { + if session.isAutoClose { + defer session.Close() + } + + return session.dropTableCols(beanOrTableName, cols) +} + +func (session *Session) dropTableCols(beanOrTableName interface{}, cols []string) error { + tableName := session.engine.TableName(beanOrTableName) + + if tableName == "" || len(cols) == 0 { + return nil + } + + // TODO: This will not work if there are foreign keys + + switch session.engine.dialect.URI().DBType { + case schemas.SQLITE: + // First drop the indexes on the columns + res, errIndex := session.Query(fmt.Sprintf("PRAGMA index_list(`%s`)", tableName)) + if errIndex != nil { + return errIndex + } + for _, row := range res { + indexName := row["name"] + indexRes, err := session.Query(fmt.Sprintf("PRAGMA index_info(`%s`)", indexName)) + if err != nil { + return err + } + if len(indexRes) != 1 { + continue + } + indexColumn := string(indexRes[0]["name"]) + for _, name := range cols { + if name == indexColumn { + _, err := session.Exec(fmt.Sprintf("DROP INDEX `%s`", indexName)) + if err != nil { + return err + } + } + } + } + + // Here we need to get the columns from the original table + sql := fmt.Sprintf("SELECT sql FROM sqlite_master WHERE tbl_name='%s' and type='table'", tableName) + res, err := session.Query(sql) + if err != nil { + return err + } + tableSQL := string(res[0]["sql"]) + + // Separate out the column definitions + tableSQL = tableSQL[strings.Index(tableSQL, "("):] + + // Remove the required cols + for _, name := range cols { + tableSQL = regexp.MustCompile(regexp.QuoteMeta("`"+name+"`")+"[^`,)]*?[,)]").ReplaceAllString(tableSQL, "") + } + + // Ensure the query is ended properly + tableSQL = strings.TrimSpace(tableSQL) + if tableSQL[len(tableSQL)-1] != ')' { + if tableSQL[len(tableSQL)-1] == ',' { + tableSQL = tableSQL[:len(tableSQL)-1] + } + tableSQL += ")" + } + + // Find all the columns in the table + columns := regexp.MustCompile("`([^`]*)`").FindAllString(tableSQL, -1) + + tableSQL = fmt.Sprintf("CREATE TABLE `new_%s_new` ", tableName) + tableSQL + if _, err := session.Exec(tableSQL); err != nil { + return err + } + + // Now restore the data + columnsSeparated := strings.Join(columns, ",") + insertSQL := fmt.Sprintf("INSERT INTO `new_%s_new` (%s) SELECT %s FROM %s", tableName, columnsSeparated, columnsSeparated, tableName) + if _, err := session.Exec(insertSQL); err != nil { + return err + } + + // Now drop the old table + if _, err := session.Exec(fmt.Sprintf("DROP TABLE `%s`", tableName)); err != nil { + return err + } + + // Rename the table + if _, err := session.Exec(fmt.Sprintf("ALTER TABLE `new_%s_new` RENAME TO `%s`", tableName, tableName)); err != nil { + return err + } + + case schemas.POSTGRES: + columns := "" + for _, col := range cols { + if columns != "" { + columns += ", " + } + columns += "DROP COLUMN `" + col + "` CASCADE" + } + if _, err := session.Exec(fmt.Sprintf("ALTER TABLE `%s` %s", tableName, columns)); err != nil { + return fmt.Errorf("Drop table `%s` columns %v: %v", tableName, cols, err) + } + + case schemas.MYSQL: + // Drop indexes on columns first + sql := fmt.Sprintf("SHOW INDEX FROM %s WHERE column_name IN ('%s')", tableName, strings.Join(cols, "','")) + res, err := session.Query(sql) + if err != nil { + return err + } + for _, index := range res { + indexName := index["column_name"] + if len(indexName) > 0 { + _, err := session.Exec(fmt.Sprintf("DROP INDEX `%s` ON `%s`", indexName, tableName)) + if err != nil { + return err + } + } + } + + // Now drop the columns + columns := "" + for _, col := range cols { + if columns != "" { + columns += ", " + } + columns += "DROP COLUMN `" + col + "`" + } + if _, err := session.Exec(fmt.Sprintf("ALTER TABLE `%s` %s", tableName, columns)); err != nil { + return fmt.Errorf("Drop table `%s` columns %v: %v", tableName, cols, err) + } + + case schemas.MSSQL: + columns := "" + for _, col := range cols { + if columns != "" { + columns += ", " + } + columns += "`" + strings.ToLower(col) + "`" + } + sql := fmt.Sprintf("SELECT Name FROM SYS.DEFAULT_CONSTRAINTS WHERE PARENT_OBJECT_ID = OBJECT_ID('%[1]s') AND PARENT_COLUMN_ID IN (SELECT column_id FROM sys.columns WHERE lower(NAME) IN (%[2]s) AND object_id = OBJECT_ID('%[1]s'))", + tableName, strings.Replace(columns, "`", "'", -1)) + constraints := make([]string, 0) + if err := session.SQL(sql).Find(&constraints); err != nil { + session.Rollback() + return fmt.Errorf("Find constraints: %v", err) + } + for _, constraint := range constraints { + if _, err := session.Exec(fmt.Sprintf("ALTER TABLE `%s` DROP CONSTRAINT `%s`", tableName, constraint)); err != nil { + session.Rollback() + return fmt.Errorf("Drop table `%s` constraint `%s`: %v", tableName, constraint, err) + } + } + if _, err := session.Exec(fmt.Sprintf("ALTER TABLE `%s` DROP COLUMN %s", tableName, columns)); err != nil { + session.Rollback() + return fmt.Errorf("Drop table `%s` columns %v: %v", tableName, cols, err) + } + + case schemas.ORACLE: + return fmt.Errorf("not implemented for oracle") + + default: + return fmt.Errorf("unrecognized DB") + } + + return nil +} + // DropTable drop table will drop table if exist, if drop failed, it will return error func (session *Session) DropTable(beanOrTableName interface{}) error { if session.isAutoClose {