From 9a7b4e7af526d67e8dbbdd8a8efb07f437d0aa9c Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Wed, 4 Mar 2020 14:59:03 +0800 Subject: [PATCH 1/4] Move some codes to statement sub package --- engine.go | 4 - error.go | 4 - internal/statements/delete.go | 139 ++++++++++++++++++++++++++++++++++ internal/statements/query.go | 50 +++++------- session_convert.go | 3 +- session_delete.go | 122 +++++------------------------ 6 files changed, 177 insertions(+), 145 deletions(-) create mode 100644 internal/statements/delete.go diff --git a/engine.go b/engine.go index 221b7488..8b4f3931 100644 --- a/engine.go +++ b/engine.go @@ -1211,10 +1211,6 @@ func (engine *Engine) nowTime(col *schemas.Column) (interface{}, time.Time) { return dialects.FormatTime(engine.dialect, col.SQLType.Name, t.In(tz)), t.In(engine.TZLocation) } -func (engine *Engine) formatColTime(col *schemas.Column, t time.Time) (v interface{}) { - return dialects.FormatColumnTime(engine.dialect, engine.DatabaseTZ, col, t) -} - // GetColumnMapper returns the column name mapper func (engine *Engine) GetColumnMapper() names.Mapper { return engine.tagParser.GetColumnMapper() diff --git a/error.go b/error.go index a19860e3..21a83f47 100644 --- a/error.go +++ b/error.go @@ -20,10 +20,6 @@ var ( ErrNotExist = errors.New("Record does not exist") // ErrCacheFailed cache failed error ErrCacheFailed = errors.New("Cache failed") - // ErrNeedDeletedCond delete needs less one condition error - ErrNeedDeletedCond = errors.New("Delete action needs at least one condition") - // ErrNotImplemented not implemented - ErrNotImplemented = errors.New("Not implemented") // ErrConditionType condition type unsupported ErrConditionType = errors.New("Unsupported condition type") ) diff --git a/internal/statements/delete.go b/internal/statements/delete.go new file mode 100644 index 00000000..de4f9f0f --- /dev/null +++ b/internal/statements/delete.go @@ -0,0 +1,139 @@ +// Copyright 2020 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package statements + +import ( + "errors" + "fmt" + "time" + + "xorm.io/xorm/dialects" + "xorm.io/xorm/schemas" +) + +var ( + // ErrNeedDeletedCond delete needs less one condition error + ErrNeedDeletedCond = errors.New("Delete action needs at least one condition") + + // ErrNotImplemented not implemented + ErrNotImplemented = errors.New("Not implemented") +) + +// GenDeleteSQL generated delete SQL according conditions +func (statement *Statement) GenDeleteSQL(bean interface{}) (string, string, []interface{}, error) { + condSQL, condArgs, err := statement.GenConds(bean) + if err != nil { + return "", "", nil, err + } + pLimitN := statement.LimitN + if len(condSQL) == 0 && (pLimitN == nil || *pLimitN == 0) { + return "", "", nil, ErrNeedDeletedCond + } + + var tableNameNoQuote = statement.TableName() + var tableName = statement.quote(tableNameNoQuote) + var table = statement.RefTable + var deleteSQL string + if len(condSQL) > 0 { + deleteSQL = fmt.Sprintf("DELETE FROM %v WHERE %v", tableName, condSQL) + } else { + deleteSQL = fmt.Sprintf("DELETE FROM %v", tableName) + } + + var orderSQL string + if len(statement.OrderStr) > 0 { + orderSQL += fmt.Sprintf(" ORDER BY %s", statement.OrderStr) + } + if pLimitN != nil && *pLimitN > 0 { + limitNValue := *pLimitN + orderSQL += fmt.Sprintf(" LIMIT %d", limitNValue) + } + + if len(orderSQL) > 0 { + switch statement.dialect.DBType() { + case schemas.POSTGRES: + inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL) + if len(condSQL) > 0 { + deleteSQL += " AND " + inSQL + } else { + deleteSQL += " WHERE " + inSQL + } + case schemas.SQLITE: + inSQL := fmt.Sprintf("rowid IN (SELECT rowid FROM %s%s)", tableName, orderSQL) + if len(condSQL) > 0 { + deleteSQL += " AND " + inSQL + } else { + deleteSQL += " WHERE " + inSQL + } + // TODO: how to handle delete limit on mssql? + case schemas.MSSQL: + return "", "", nil, ErrNotImplemented + default: + deleteSQL += orderSQL + } + } + + var realSQL string + argsForCache := make([]interface{}, 0, len(condArgs)*2) + if statement.GetUnscoped() || table.DeletedColumn() == nil { // tag "deleted" is disabled + realSQL = deleteSQL + copy(argsForCache, condArgs) + argsForCache = append(condArgs, argsForCache...) + } else { + // !oinume! sqlStrForCache and argsForCache is needed to behave as executing "DELETE FROM ..." for caches. + copy(argsForCache, condArgs) + argsForCache = append(condArgs, argsForCache...) + + deletedColumn := table.DeletedColumn() + realSQL = fmt.Sprintf("UPDATE %v SET %v = ? WHERE %v", + statement.quote(statement.TableName()), + statement.quote(deletedColumn.Name), + condSQL) + + if len(orderSQL) > 0 { + switch statement.dialect.DBType() { + case schemas.POSTGRES: + inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL) + if len(condSQL) > 0 { + realSQL += " AND " + inSQL + } else { + realSQL += " WHERE " + inSQL + } + case schemas.SQLITE: + inSQL := fmt.Sprintf("rowid IN (SELECT rowid FROM %s%s)", tableName, orderSQL) + if len(condSQL) > 0 { + realSQL += " AND " + inSQL + } else { + realSQL += " WHERE " + inSQL + } + // TODO: how to handle delete limit on mssql? + case schemas.MSSQL: + return "", "", nil, ErrNotImplemented + default: + realSQL += orderSQL + } + } + + // !oinume! Insert nowTime to the head of statement.Params + condArgs = append(condArgs, "") + paramsLen := len(condArgs) + copy(condArgs[1:paramsLen], condArgs[0:paramsLen-1]) + + now := ColumnNow(deletedColumn, statement.defaultTimeZone) + val := dialects.FormatTime(statement.dialect, deletedColumn.SQLType.Name, now) + condArgs[0] = val + } + return realSQL, deleteSQL, condArgs, nil +} + +// ColumnNow returns the current time for a column +func ColumnNow(col *schemas.Column, defaultTimeZone *time.Location) time.Time { + t := time.Now() + tz := defaultTimeZone + if !col.DisableTimeZone && col.TimeZone != nil { + tz = col.TimeZone + } + return t.In(tz) +} diff --git a/internal/statements/query.go b/internal/statements/query.go index 1519cb08..a058f752 100644 --- a/internal/statements/query.go +++ b/internal/statements/query.go @@ -57,16 +57,12 @@ func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []int return "", nil, err } - condSQL, condArgs, err := builder.ToSQL(statement.cond) + sqlStr, condArgs, err := statement.genSelectSQL(columnStr, true, true) if err != nil { return "", nil, err } - args := append(statement.joinArgs, condArgs...) - sqlStr, err := statement.GenSelectSQL(columnStr, condSQL, true, true) - if err != nil { - return "", nil, err - } + // for mssql and use limit qs := strings.Count(sqlStr, "?") if len(args)*2 == qs { @@ -92,12 +88,11 @@ func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (stri } sumSelect := strings.Join(sumStrs, ", ") - condSQL, condArgs, err := statement.GenConds(bean) - if err != nil { + if err := statement.mergeConds(bean); err != nil { return "", nil, err } - sqlStr, err := statement.GenSelectSQL(sumSelect, condSQL, true, true) + sqlStr, condArgs, err := statement.genSelectSQL(sumSelect, true, true) if err != nil { return "", nil, err } @@ -147,12 +142,8 @@ func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{}, return "", nil, err } } - condSQL, condArgs, err := builder.ToSQL(statement.cond) - if err != nil { - return "", nil, err - } - sqlStr, err := statement.GenSelectSQL(columnStr, condSQL, true, true) + sqlStr, condArgs, err := statement.genSelectSQL(columnStr, true, true) if err != nil { return "", nil, err } @@ -165,17 +156,13 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa return statement.RawSQL, statement.RawParams, nil } - var condSQL string var condArgs []interface{} var err error if len(beans) > 0 { statement.SetRefBean(beans[0]) - condSQL, condArgs, err = statement.GenConds(beans[0]) - } else { - condSQL, condArgs, err = builder.ToSQL(statement.cond) - } - if err != nil { - return "", nil, err + if err := statement.mergeConds(beans[0]); err != nil { + return "", nil, err + } } var selectSQL = statement.SelectStr @@ -186,7 +173,7 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa selectSQL = "count(*)" } } - sqlStr, err := statement.GenSelectSQL(selectSQL, condSQL, false, false) + sqlStr, condArgs, err := statement.genSelectSQL(selectSQL, false, false) if err != nil { return "", nil, err } @@ -194,7 +181,7 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa return sqlStr, append(statement.joinArgs, condArgs...), nil } -func (statement *Statement) GenSelectSQL(columnStr, condSQL string, needLimit, needOrderBy bool) (string, error) { +func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderBy bool) (string, []interface{}, error) { var ( distinct string dialect = statement.dialect @@ -205,6 +192,11 @@ func (statement *Statement) GenSelectSQL(columnStr, condSQL string, needLimit, n if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") { distinct = "DISTINCT " } + + condSQL, condArgs, err := builder.ToSQL(statement.cond) + if err != nil { + return "", nil, err + } if len(condSQL) > 0 { whereStr = " WHERE " + condSQL } @@ -313,10 +305,10 @@ func (statement *Statement) GenSelectSQL(columnStr, condSQL string, needLimit, n } } if statement.IsForUpdate { - return dialect.ForUpdateSQL(buf.String()), nil + return dialect.ForUpdateSQL(buf.String()), condArgs, nil } - return buf.String(), nil + return buf.String(), condArgs, nil } func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interface{}, error) { @@ -428,16 +420,12 @@ func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interfa } statement.cond = statement.cond.And(autoCond) - condSQL, condArgs, err := builder.ToSQL(statement.cond) - if err != nil { - return "", nil, err - } - args = append(statement.joinArgs, condArgs...) - sqlStr, err = statement.GenSelectSQL(columnStr, condSQL, true, true) + sqlStr, condArgs, err := statement.genSelectSQL(columnStr, true, true) if err != nil { return "", nil, err } + args = append(statement.joinArgs, condArgs...) // for mssql and use limit qs := strings.Count(sqlStr, "?") if len(args)*2 == qs { diff --git a/session_convert.go b/session_convert.go index 1cd00627..0776bc45 100644 --- a/session_convert.go +++ b/session_convert.go @@ -15,6 +15,7 @@ import ( "time" "xorm.io/xorm/convert" + "xorm.io/xorm/dialects" "xorm.io/xorm/internal/json" "xorm.io/xorm/internal/utils" "xorm.io/xorm/schemas" @@ -583,7 +584,7 @@ func (session *Session) value2Interface(col *schemas.Column, fieldValue reflect. case reflect.Struct: if fieldType.ConvertibleTo(schemas.TimeType) { t := fieldValue.Convert(schemas.TimeType).Interface().(time.Time) - tf := session.engine.formatColTime(col, t) + tf := dialects.FormatColumnTime(session.engine.dialect, session.engine.DatabaseTZ, col, t) return tf, nil } else if fieldType.ConvertibleTo(nullFloatType) { t := fieldValue.Convert(nullFloatType).Interface().(sql.NullFloat64) diff --git a/session_delete.go b/session_delete.go index 04200035..3373d89e 100644 --- a/session_delete.go +++ b/session_delete.go @@ -6,8 +6,8 @@ package xorm import ( "errors" - "fmt" "strconv" + "time" "xorm.io/xorm/caches" "xorm.io/xorm/schemas" @@ -98,119 +98,31 @@ func (session *Session) Delete(bean interface{}) (int64, error) { processor.BeforeDelete() } - condSQL, condArgs, err := session.statement.GenConds(bean) + realSQL, deleteSQL, condArgs, err := session.statement.GenDeleteSQL(bean) if err != nil { return 0, err } - pLimitN := session.statement.LimitN - if len(condSQL) == 0 && (pLimitN == nil || *pLimitN == 0) { - return 0, ErrNeedDeletedCond + + argsForCache := make([]interface{}, 0, len(condArgs)*2) + copy(argsForCache, condArgs) + argsForCache = append(condArgs, argsForCache...) + + if !session.statement.GetUnscoped() && session.statement.RefTable.DeletedColumn() != nil { + deletedColumn := session.statement.RefTable.DeletedColumn() + + session.afterClosures = append(session.afterClosures, func(col *schemas.Column, tz *time.Location) func(interface{}) { + return func(bean interface{}) { + t := time.Now().In(tz) + setColumnTime(bean, col, t) + } + }(deletedColumn, session.engine.TZLocation)) } var tableNameNoQuote = session.statement.TableName() - var tableName = session.engine.Quote(tableNameNoQuote) - var table = session.statement.RefTable - var deleteSQL string - if len(condSQL) > 0 { - deleteSQL = fmt.Sprintf("DELETE FROM %v WHERE %v", tableName, condSQL) - } else { - deleteSQL = fmt.Sprintf("DELETE FROM %v", tableName) - } - - var orderSQL string - if len(session.statement.OrderStr) > 0 { - orderSQL += fmt.Sprintf(" ORDER BY %s", session.statement.OrderStr) - } - if pLimitN != nil && *pLimitN > 0 { - limitNValue := *pLimitN - orderSQL += fmt.Sprintf(" LIMIT %d", limitNValue) - } - - if len(orderSQL) > 0 { - switch session.engine.dialect.DBType() { - case schemas.POSTGRES: - inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL) - if len(condSQL) > 0 { - deleteSQL += " AND " + inSQL - } else { - deleteSQL += " WHERE " + inSQL - } - case schemas.SQLITE: - inSQL := fmt.Sprintf("rowid IN (SELECT rowid FROM %s%s)", tableName, orderSQL) - if len(condSQL) > 0 { - deleteSQL += " AND " + inSQL - } else { - deleteSQL += " WHERE " + inSQL - } - // TODO: how to handle delete limit on mssql? - case schemas.MSSQL: - return 0, ErrNotImplemented - default: - deleteSQL += orderSQL - } - } - - var realSQL string - argsForCache := make([]interface{}, 0, len(condArgs)*2) - if session.statement.GetUnscoped() || table.DeletedColumn() == nil { // tag "deleted" is disabled - realSQL = deleteSQL - copy(argsForCache, condArgs) - argsForCache = append(condArgs, argsForCache...) - } else { - // !oinume! sqlStrForCache and argsForCache is needed to behave as executing "DELETE FROM ..." for caches. - copy(argsForCache, condArgs) - argsForCache = append(condArgs, argsForCache...) - - deletedColumn := table.DeletedColumn() - realSQL = fmt.Sprintf("UPDATE %v SET %v = ? WHERE %v", - session.engine.Quote(session.statement.TableName()), - session.engine.Quote(deletedColumn.Name), - condSQL) - - if len(orderSQL) > 0 { - switch session.engine.dialect.DBType() { - case schemas.POSTGRES: - inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL) - if len(condSQL) > 0 { - realSQL += " AND " + inSQL - } else { - realSQL += " WHERE " + inSQL - } - case schemas.SQLITE: - inSQL := fmt.Sprintf("rowid IN (SELECT rowid FROM %s%s)", tableName, orderSQL) - if len(condSQL) > 0 { - realSQL += " AND " + inSQL - } else { - realSQL += " WHERE " + inSQL - } - // TODO: how to handle delete limit on mssql? - case schemas.MSSQL: - return 0, ErrNotImplemented - default: - realSQL += orderSQL - } - } - - // !oinume! Insert nowTime to the head of session.statement.Params - condArgs = append(condArgs, "") - paramsLen := len(condArgs) - copy(condArgs[1:paramsLen], condArgs[0:paramsLen-1]) - - val, t := session.engine.nowTime(deletedColumn) - condArgs[0] = val - - var colName = deletedColumn.Name - session.afterClosures = append(session.afterClosures, func(bean interface{}) { - col := table.GetColumn(colName) - setColumnTime(bean, col, t) - }) - } - if cacher := session.engine.GetCacher(tableNameNoQuote); cacher != nil && session.statement.UseCache { - session.cacheDelete(table, tableNameNoQuote, deleteSQL, argsForCache...) + session.cacheDelete(session.statement.RefTable, tableNameNoQuote, deleteSQL, argsForCache...) } - session.statement.RefTable = table res, err := session.exec(realSQL, condArgs...) if err != nil { return 0, err -- 2.40.1 From 51d6afa3300f978c9071952227aa3134ac77ecde Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Wed, 4 Mar 2020 16:49:41 +0800 Subject: [PATCH 2/4] fix tests --- .drone.yml | 4 +- dialects/dialect.go | 28 ++----- dialects/driver.go | 32 +++++++ dialects/mssql.go | 4 +- dialects/mysql.go | 4 +- dialects/oracle.go | 4 +- dialects/postgres.go | 13 ++- dialects/sqlite3.go | 4 +- engine.go | 11 ++- interface.go | 1 + internal/statements/delete.go | 98 ++++++++++------------ internal/statements/statement_test.go | 116 ++++++++++++++------------ session_delete.go | 7 +- session_get_test.go | 2 +- xorm.go | 4 +- 15 files changed, 178 insertions(+), 154 deletions(-) diff --git a/.drone.yml b/.drone.yml index dac49cdf..9a62c6bd 100644 --- a/.drone.yml +++ b/.drone.yml @@ -22,8 +22,10 @@ steps: commands: - make test-sqlite - TEST_CACHE_ENABLE=true make test-sqlite - - go test ./caches/... ./convert/... ./core/... ./dialects/... \ + - go test ./caches/... ./contexts/... ./convert/... ./core/... ./dialects/... \ + ./internal/json/... ./internal/statements/... ./internal/utils/... \ ./log/... ./migrate/... ./names/... ./schemas/... ./tags/... + when: event: - push diff --git a/dialects/dialect.go b/dialects/dialect.go index a0139d9f..7d816bda 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -31,7 +31,7 @@ type URI struct { // Dialect represents a kind of database type Dialect interface { - Init(*core.DB, *URI, string, string) error + Init(*core.DB, *URI /*, string, string*/) error URI() *URI DB() *core.DB DBType() schemas.DBType @@ -39,9 +39,6 @@ type Dialect interface { FormatBytes(b []byte) string DefaultSchema() string - DriverName() string - DataSourceName() string - IsReserved(string) bool Quoter() schemas.Quoter @@ -77,17 +74,11 @@ type Dialect interface { SetParams(params map[string]string) } -func OpenDialect(dialect Dialect) (*core.DB, error) { - return core.Open(dialect.DriverName(), dialect.DataSourceName()) -} - // Base represents a basic dialect and all real dialects could embed this struct type Base struct { - db *core.DB - dialect Dialect - driverName string - dataSourceName string - uri *URI + db *core.DB + dialect Dialect + uri *URI } func (b *Base) DB() *core.DB { @@ -98,9 +89,8 @@ func (b *Base) DefaultSchema() string { return "" } -func (b *Base) Init(db *core.DB, dialect Dialect, uri *URI, drivername, dataSourceName string) error { +func (b *Base) Init(db *core.DB, dialect Dialect, uri *URI) error { b.db, b.dialect, b.uri = db, dialect, uri - b.driverName, b.dataSourceName = drivername, dataSourceName return nil } @@ -165,18 +155,10 @@ func (b *Base) FormatBytes(bs []byte) string { return fmt.Sprintf("0x%x", bs) } -func (b *Base) DriverName() string { - return b.driverName -} - func (b *Base) ShowCreateNull() bool { return true } -func (b *Base) DataSourceName() string { - return b.dataSourceName -} - func (db *Base) SupportDropIfExists() bool { return true } diff --git a/dialects/driver.go b/dialects/driver.go index 5343d594..89d21bfc 100644 --- a/dialects/driver.go +++ b/dialects/driver.go @@ -4,6 +4,12 @@ package dialects +import ( + "fmt" + + "xorm.io/xorm/core" +) + type Driver interface { Parse(string, string) (*URI, error) } @@ -29,3 +35,29 @@ func QueryDriver(driverName string) Driver { func RegisteredDriverSize() int { return len(drivers) } + +// OpenDialect opens a dialect via driver name and connection string +func OpenDialect(driverName, connstr string) (Dialect, error) { + driver := QueryDriver(driverName) + if driver == nil { + return nil, fmt.Errorf("Unsupported driver name: %v", driverName) + } + + uri, err := driver.Parse(driverName, connstr) + if err != nil { + return nil, err + } + + dialect := QueryDialect(uri.DBType) + if dialect == nil { + return nil, fmt.Errorf("Unsupported dialect type: %v", uri.DBType) + } + + db, err := core.Open(driverName, connstr) + if err != nil { + return nil, err + } + dialect.Init(db, uri) + + return dialect, nil +} diff --git a/dialects/mssql.go b/dialects/mssql.go index 9963fc4f..3c95dd20 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -210,8 +210,8 @@ type mssql struct { Base } -func (db *mssql) Init(d *core.DB, uri *URI, drivername, dataSourceName string) error { - return db.Base.Init(d, db, uri, drivername, dataSourceName) +func (db *mssql) Init(d *core.DB, uri *URI /*, drivername, dataSourceName string*/) error { + return db.Base.Init(d, db, uri /*, drivername, dataSourceName*/) } func (db *mssql) SQLType(c *schemas.Column) string { diff --git a/dialects/mysql.go b/dialects/mysql.go index 5ed2d8f1..7c41ecf6 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -177,8 +177,8 @@ type mysql struct { rowFormat string } -func (db *mysql) Init(d *core.DB, uri *URI, drivername, dataSourceName string) error { - return db.Base.Init(d, db, uri, drivername, dataSourceName) +func (db *mysql) Init(d *core.DB, uri *URI /*, drivername, dataSourceName string*/) error { + return db.Base.Init(d, db, uri /*, drivername, dataSourceName*/) } func (db *mysql) SetParams(params map[string]string) { diff --git a/dialects/oracle.go b/dialects/oracle.go index e5c438bc..49c65837 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -504,8 +504,8 @@ type oracle struct { Base } -func (db *oracle) Init(d *core.DB, uri *URI, drivername, dataSourceName string) error { - return db.Base.Init(d, db, uri, drivername, dataSourceName) +func (db *oracle) Init(d *core.DB, uri *URI /*, drivername, dataSourceName string*/) error { + return db.Base.Init(d, db, uri /*, drivername, dataSourceName*/) } func (db *oracle) SQLType(c *schemas.Column) string { diff --git a/dialects/postgres.go b/dialects/postgres.go index 623b59ed..ad3c8c68 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -766,30 +766,27 @@ var ( "YES": true, "ZONE": true, } - - // DefaultPostgresSchema default postgres schema - DefaultPostgresSchema = "public" ) -const PostgresPublicSchema = "public" +const postgresPublicSchema = "public" type postgres struct { Base } -func (db *postgres) Init(d *core.DB, uri *URI, drivername, dataSourceName string) error { - err := db.Base.Init(d, db, uri, drivername, dataSourceName) +func (db *postgres) Init(d *core.DB, uri *URI /*, drivername, dataSourceName string*/) error { + err := db.Base.Init(d, db, uri /*, drivername, dataSourceName*/) if err != nil { return err } if db.uri.Schema == "" { - db.uri.Schema = DefaultPostgresSchema + db.uri.Schema = postgresPublicSchema } return nil } func (db *postgres) DefaultSchema() string { - return PostgresPublicSchema + return postgresPublicSchema } func (db *postgres) SQLType(c *schemas.Column) string { diff --git a/dialects/sqlite3.go b/dialects/sqlite3.go index 7dfa7fca..3b9cb97c 100644 --- a/dialects/sqlite3.go +++ b/dialects/sqlite3.go @@ -149,8 +149,8 @@ type sqlite3 struct { Base } -func (db *sqlite3) Init(d *core.DB, uri *URI, drivername, dataSourceName string) error { - return db.Base.Init(d, db, uri, drivername, dataSourceName) +func (db *sqlite3) Init(d *core.DB, uri *URI /*, drivername, dataSourceName string*/) error { + return db.Base.Init(d, db, uri /*, drivername, dataSourceName*/) } func (db *sqlite3) SQLType(c *schemas.Column) string { diff --git a/engine.go b/engine.go index 8b4f3931..31d891dc 100644 --- a/engine.go +++ b/engine.go @@ -39,6 +39,9 @@ type Engine struct { logger log.ContextLogger tagParser *tags.Parser + driverName string + dataSourceName string + TZLocation *time.Location // The timezone of the application DatabaseTZ *time.Location // The timezone of the database } @@ -94,12 +97,12 @@ func (engine *Engine) SetDisableGlobalCache(disable bool) { // DriverName return the current sql driver's name func (engine *Engine) DriverName() string { - return engine.dialect.DriverName() + return engine.driverName } // DataSourceName return the current connection string func (engine *Engine) DataSourceName() string { - return engine.dialect.DataSourceName() + return engine.dataSourceName } // SetMapper set the name mapping rules @@ -210,7 +213,7 @@ func (engine *Engine) MapCacher(bean interface{}, cacher caches.Cacher) error { // NewDB provides an interface to operate database directly func (engine *Engine) NewDB() (*core.DB, error) { - return dialects.OpenDialect(engine.dialect) + return core.Open(engine.driverName, engine.dataSourceName) } // DB return the wrapper of sql.DB @@ -364,7 +367,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch if dialect == nil { return errors.New("Unsupported database type") } - dialect.Init(nil, engine.dialect.URI(), "", "") + dialect.Init(nil, engine.dialect.URI()) distDBName = string(tp[0]) } diff --git a/interface.go b/interface.go index 13f1e12a..8d2402f0 100644 --- a/interface.go +++ b/interface.go @@ -82,6 +82,7 @@ type EngineInterface interface { CreateTables(...interface{}) error DBMetas() ([]*schemas.Table, error) Dialect() dialects.Dialect + DriverName() string DropTables(...interface{}) error DumpAllToFile(fp string, tp ...schemas.DBType) error GetCacher(string) caches.Cacher diff --git a/internal/statements/delete.go b/internal/statements/delete.go index de4f9f0f..2cb91f2a 100644 --- a/internal/statements/delete.go +++ b/internal/statements/delete.go @@ -22,14 +22,14 @@ var ( ) // GenDeleteSQL generated delete SQL according conditions -func (statement *Statement) GenDeleteSQL(bean interface{}) (string, string, []interface{}, error) { +func (statement *Statement) GenDeleteSQL(bean interface{}) (string, string, []interface{}, *time.Time, error) { condSQL, condArgs, err := statement.GenConds(bean) if err != nil { - return "", "", nil, err + return "", "", nil, nil, err } pLimitN := statement.LimitN if len(condSQL) == 0 && (pLimitN == nil || *pLimitN == 0) { - return "", "", nil, ErrNeedDeletedCond + return "", "", nil, nil, ErrNeedDeletedCond } var tableNameNoQuote = statement.TableName() @@ -69,63 +69,57 @@ func (statement *Statement) GenDeleteSQL(bean interface{}) (string, string, []in } // TODO: how to handle delete limit on mssql? case schemas.MSSQL: - return "", "", nil, ErrNotImplemented + return "", "", nil, nil, ErrNotImplemented default: deleteSQL += orderSQL } } var realSQL string - argsForCache := make([]interface{}, 0, len(condArgs)*2) if statement.GetUnscoped() || table.DeletedColumn() == nil { // tag "deleted" is disabled - realSQL = deleteSQL - copy(argsForCache, condArgs) - argsForCache = append(condArgs, argsForCache...) - } else { - // !oinume! sqlStrForCache and argsForCache is needed to behave as executing "DELETE FROM ..." for caches. - copy(argsForCache, condArgs) - argsForCache = append(condArgs, argsForCache...) - - deletedColumn := table.DeletedColumn() - realSQL = fmt.Sprintf("UPDATE %v SET %v = ? WHERE %v", - statement.quote(statement.TableName()), - statement.quote(deletedColumn.Name), - condSQL) - - if len(orderSQL) > 0 { - switch statement.dialect.DBType() { - case schemas.POSTGRES: - inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL) - if len(condSQL) > 0 { - realSQL += " AND " + inSQL - } else { - realSQL += " WHERE " + inSQL - } - case schemas.SQLITE: - inSQL := fmt.Sprintf("rowid IN (SELECT rowid FROM %s%s)", tableName, orderSQL) - if len(condSQL) > 0 { - realSQL += " AND " + inSQL - } else { - realSQL += " WHERE " + inSQL - } - // TODO: how to handle delete limit on mssql? - case schemas.MSSQL: - return "", "", nil, ErrNotImplemented - default: - realSQL += orderSQL - } - } - - // !oinume! Insert nowTime to the head of statement.Params - condArgs = append(condArgs, "") - paramsLen := len(condArgs) - copy(condArgs[1:paramsLen], condArgs[0:paramsLen-1]) - - now := ColumnNow(deletedColumn, statement.defaultTimeZone) - val := dialects.FormatTime(statement.dialect, deletedColumn.SQLType.Name, now) - condArgs[0] = val + return deleteSQL, deleteSQL, condArgs, nil, nil } - return realSQL, deleteSQL, condArgs, nil + + deletedColumn := table.DeletedColumn() + realSQL = fmt.Sprintf("UPDATE %v SET %v = ? WHERE %v", + statement.quote(statement.TableName()), + statement.quote(deletedColumn.Name), + condSQL) + + if len(orderSQL) > 0 { + switch statement.dialect.DBType() { + case schemas.POSTGRES: + inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL) + if len(condSQL) > 0 { + realSQL += " AND " + inSQL + } else { + realSQL += " WHERE " + inSQL + } + case schemas.SQLITE: + inSQL := fmt.Sprintf("rowid IN (SELECT rowid FROM %s%s)", tableName, orderSQL) + if len(condSQL) > 0 { + realSQL += " AND " + inSQL + } else { + realSQL += " WHERE " + inSQL + } + // TODO: how to handle delete limit on mssql? + case schemas.MSSQL: + return "", "", nil, nil, ErrNotImplemented + default: + realSQL += orderSQL + } + } + + // !oinume! Insert nowTime to the head of statement.Params + condArgs = append(condArgs, "") + paramsLen := len(condArgs) + copy(condArgs[1:paramsLen], condArgs[0:paramsLen-1]) + + now := ColumnNow(deletedColumn, statement.defaultTimeZone) + val := dialects.FormatTime(statement.dialect, deletedColumn.SQLType.Name, now) + condArgs[0] = val + + return realSQL, deleteSQL, condArgs, &now, nil } // ColumnNow returns the current time for a column diff --git a/internal/statements/statement_test.go b/internal/statements/statement_test.go index 3b6e3ae2..15f446f4 100644 --- a/internal/statements/statement_test.go +++ b/internal/statements/statement_test.go @@ -8,10 +8,37 @@ import ( "reflect" "strings" "testing" + "time" + "github.com/stretchr/testify/assert" + "xorm.io/xorm/caches" + "xorm.io/xorm/dialects" + "xorm.io/xorm/names" "xorm.io/xorm/schemas" + "xorm.io/xorm/tags" + + _ "github.com/mattn/go-sqlite3" ) +var ( + dialect dialects.Dialect + tagParser *tags.Parser +) + +func TestMain(m *testing.M) { + var err error + dialect, err = dialects.OpenDialect("sqlite3", "./test.db") + if err != nil { + panic("unknow dialect") + } + + tagParser = tags.NewParser("xorm", dialect, names.SnakeMapper{}, names.SnakeMapper{}, caches.NewManager()) + if tagParser == nil { + panic("tags parser is nil") + } + m.Run() +} + var colStrTests = []struct { omitColumn string onlyToDBColumnNdx int @@ -26,14 +53,9 @@ var colStrTests = []struct { } func TestColumnsStringGeneration(t *testing.T) { - if dbType == "postgres" || dbType == "mssql" { - return - } - - var statement *Statement - for ndx, testCase := range colStrTests { - statement = createTestStatement() + statement, err := createTestStatement() + assert.NoError(t, err) if testCase.omitColumn != "" { statement.Omit(testCase.omitColumn) @@ -55,33 +77,6 @@ func TestColumnsStringGeneration(t *testing.T) { } } -func BenchmarkColumnsStringGeneration(b *testing.B) { - b.StopTimer() - - statement := createTestStatement() - - testCase := colStrTests[0] - - if testCase.omitColumn != "" { - statement.Omit(testCase.omitColumn) // !nemec784! Column must be skipped - } - - if testCase.onlyToDBColumnNdx >= 0 { - columns := statement.RefTable.Columns() - columns[testCase.onlyToDBColumnNdx].MapType = schemas.ONLYTODB // !nemec784! Column must be skipped - } - - b.StartTimer() - - for i := 0; i < b.N; i++ { - actual := statement.genColumnStr() - - if actual != testCase.expected { - b.Errorf("Unexpected columns string:\nwant:\t%s\nhave:\t%s", testCase.expected, actual) - } - } -} - func BenchmarkGetFlagForColumnWithICKey_ContainsKey(b *testing.B) { b.StopTimer() @@ -162,23 +157,40 @@ func (TestType) TableName() string { return "TestTable" } -func createTestStatement() *Statement { - if engine, ok := testEngine.(*Engine); ok { - statement := &Statement{} - statement.Reset() - statement.Engine = engine - statement.dialect = engine.dialect - statement.SetRefValue(reflect.ValueOf(TestType{})) - - return statement - } else if eg, ok := testEngine.(*EngineGroup); ok { - statement := &Statement{} - statement.Reset() - statement.Engine = eg.Engine - statement.dialect = eg.Engine.dialect - statement.SetRefValue(reflect.ValueOf(TestType{})) - - return statement +func createTestStatement() (*Statement, error) { + statement := NewStatement(dialect, tagParser, time.Local) + if err := statement.SetRefValue(reflect.ValueOf(TestType{})); err != nil { + return nil, err + } + return statement, nil +} + +func BenchmarkColumnsStringGeneration(b *testing.B) { + b.StopTimer() + + statement, err := createTestStatement() + if err != nil { + panic(err) + } + + testCase := colStrTests[0] + + if testCase.omitColumn != "" { + statement.Omit(testCase.omitColumn) // !nemec784! Column must be skipped + } + + if testCase.onlyToDBColumnNdx >= 0 { + columns := statement.RefTable.Columns() + columns[testCase.onlyToDBColumnNdx].MapType = schemas.ONLYTODB // !nemec784! Column must be skipped + } + + b.StartTimer() + + for i := 0; i < b.N; i++ { + actual := statement.genColumnStr() + + if actual != testCase.expected { + b.Errorf("Unexpected columns string:\nwant:\t%s\nhave:\t%s", testCase.expected, actual) + } } - return nil } diff --git a/session_delete.go b/session_delete.go index 3373d89e..16434bac 100644 --- a/session_delete.go +++ b/session_delete.go @@ -98,7 +98,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) { processor.BeforeDelete() } - realSQL, deleteSQL, condArgs, err := session.statement.GenDeleteSQL(bean) + realSQL, deleteSQL, condArgs, now, err := session.statement.GenDeleteSQL(bean) if err != nil { return 0, err } @@ -110,12 +110,11 @@ func (session *Session) Delete(bean interface{}) (int64, error) { if !session.statement.GetUnscoped() && session.statement.RefTable.DeletedColumn() != nil { deletedColumn := session.statement.RefTable.DeletedColumn() - session.afterClosures = append(session.afterClosures, func(col *schemas.Column, tz *time.Location) func(interface{}) { + session.afterClosures = append(session.afterClosures, func(col *schemas.Column, t time.Time) func(interface{}) { return func(bean interface{}) { - t := time.Now().In(tz) setColumnTime(bean, col, t) } - }(deletedColumn, session.engine.TZLocation)) + }(deletedColumn, now.In(session.engine.TZLocation))) } var tableNameNoQuote = session.statement.TableName() diff --git a/session_get_test.go b/session_get_test.go index 5bac9cd7..7e10bf54 100644 --- a/session_get_test.go +++ b/session_get_test.go @@ -179,7 +179,7 @@ func TestGetVar(t *testing.T) { assert.Equal(t, "1.5", valuesString["money"]) // for mymysql driver, interface{} will be []byte, so ignore it currently - if testEngine.Dialect().DriverName() != "mymysql" { + if testEngine.DriverName() != "mymysql" { var valuesInter = make(map[string]interface{}) has, err = testEngine.Table("get_var").Where("id = ?", 1).Select("*").Get(&valuesInter) assert.NoError(t, err) diff --git a/xorm.go b/xorm.go index 724a37cb..51915940 100644 --- a/xorm.go +++ b/xorm.go @@ -54,7 +54,7 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) { return nil, err } - err = dialect.Init(db, uri, driverName, dataSourceName) + err = dialect.Init(db, uri /*, driverName, dataSourceName*/) if err != nil { return nil, err } @@ -70,6 +70,8 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) { defaultContext: context.Background(), cacherMgr: cacherMgr, tagParser: tagParser, + driverName: driverName, + dataSourceName: dataSourceName, } if uri.DBType == schemas.SQLITE { -- 2.40.1 From aed961a5a9cbc330ae79bd89387e3a0a27af7995 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Wed, 4 Mar 2020 17:14:11 +0800 Subject: [PATCH 3/4] refactor new engine --- dialects/dialect.go | 2 +- dialects/mssql.go | 4 ++-- dialects/mysql.go | 4 ++-- dialects/oracle.go | 4 ++-- dialects/postgres.go | 4 ++-- dialects/sqlite3.go | 4 ++-- engine.go | 14 +++++++------- session.go | 2 +- session_tx.go | 4 ++-- xorm.go | 27 ++------------------------- 10 files changed, 23 insertions(+), 46 deletions(-) diff --git a/dialects/dialect.go b/dialects/dialect.go index 7d816bda..c591cc7b 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -31,7 +31,7 @@ type URI struct { // Dialect represents a kind of database type Dialect interface { - Init(*core.DB, *URI /*, string, string*/) error + Init(*core.DB, *URI) error URI() *URI DB() *core.DB DBType() schemas.DBType diff --git a/dialects/mssql.go b/dialects/mssql.go index 3c95dd20..558abdfc 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -210,8 +210,8 @@ type mssql struct { Base } -func (db *mssql) Init(d *core.DB, uri *URI /*, drivername, dataSourceName string*/) error { - return db.Base.Init(d, db, uri /*, drivername, dataSourceName*/) +func (db *mssql) Init(d *core.DB, uri *URI) error { + return db.Base.Init(d, db, uri) } func (db *mssql) SQLType(c *schemas.Column) string { diff --git a/dialects/mysql.go b/dialects/mysql.go index 7c41ecf6..939a7cf1 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -177,8 +177,8 @@ type mysql struct { rowFormat string } -func (db *mysql) Init(d *core.DB, uri *URI /*, drivername, dataSourceName string*/) error { - return db.Base.Init(d, db, uri /*, drivername, dataSourceName*/) +func (db *mysql) Init(d *core.DB, uri *URI) error { + return db.Base.Init(d, db, uri) } func (db *mysql) SetParams(params map[string]string) { diff --git a/dialects/oracle.go b/dialects/oracle.go index 49c65837..4a8162ac 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -504,8 +504,8 @@ type oracle struct { Base } -func (db *oracle) Init(d *core.DB, uri *URI /*, drivername, dataSourceName string*/) error { - return db.Base.Init(d, db, uri /*, drivername, dataSourceName*/) +func (db *oracle) Init(d *core.DB, uri *URI) error { + return db.Base.Init(d, db, uri) } func (db *oracle) SQLType(c *schemas.Column) string { diff --git a/dialects/postgres.go b/dialects/postgres.go index ad3c8c68..f92202cd 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -774,8 +774,8 @@ type postgres struct { Base } -func (db *postgres) Init(d *core.DB, uri *URI /*, drivername, dataSourceName string*/) error { - err := db.Base.Init(d, db, uri /*, drivername, dataSourceName*/) +func (db *postgres) Init(d *core.DB, uri *URI) error { + err := db.Base.Init(d, db, uri) if err != nil { return err } diff --git a/dialects/sqlite3.go b/dialects/sqlite3.go index 3b9cb97c..39138b13 100644 --- a/dialects/sqlite3.go +++ b/dialects/sqlite3.go @@ -149,8 +149,8 @@ type sqlite3 struct { Base } -func (db *sqlite3) Init(d *core.DB, uri *URI /*, drivername, dataSourceName string*/) error { - return db.Base.Init(d, db, uri /*, drivername, dataSourceName*/) +func (db *sqlite3) Init(d *core.DB, uri *URI) error { + return db.Base.Init(d, db, uri) } func (db *sqlite3) SQLType(c *schemas.Column) string { diff --git a/engine.go b/engine.go index 31d891dc..cc8a74a0 100644 --- a/engine.go +++ b/engine.go @@ -64,7 +64,7 @@ func (engine *Engine) BufferSize(size int) *Session { // ShowSQL show SQL statement or not on logger if log level is great than INFO func (engine *Engine) ShowSQL(show ...bool) { engine.logger.ShowSQL(show...) - engine.db.Logger = engine.logger + engine.DB().Logger = engine.logger } // Logger return the logger interface @@ -82,7 +82,7 @@ func (engine *Engine) SetLogger(logger interface{}) { realLogger = t } engine.logger = realLogger - engine.db.Logger = realLogger + engine.DB().Logger = realLogger } // SetLogLevel sets the logger level @@ -167,17 +167,17 @@ func (engine *Engine) AutoIncrStr() string { // SetConnMaxLifetime sets the maximum amount of time a connection may be reused. func (engine *Engine) SetConnMaxLifetime(d time.Duration) { - engine.db.SetConnMaxLifetime(d) + engine.DB().SetConnMaxLifetime(d) } // SetMaxOpenConns is only available for go 1.2+ func (engine *Engine) SetMaxOpenConns(conns int) { - engine.db.SetMaxOpenConns(conns) + engine.DB().SetMaxOpenConns(conns) } // SetMaxIdleConns set the max idle connections on pool, default is 2 func (engine *Engine) SetMaxIdleConns(conns int) { - engine.db.SetMaxIdleConns(conns) + engine.DB().SetMaxIdleConns(conns) } // SetDefaultCacher set the default cacher. Xorm's default not enable cacher. @@ -218,7 +218,7 @@ func (engine *Engine) NewDB() (*core.DB, error) { // DB return the wrapper of sql.DB func (engine *Engine) DB() *core.DB { - return engine.db + return engine.dialect.DB() } // Dialect return database dialect @@ -235,7 +235,7 @@ func (engine *Engine) NewSession() *Session { // Close the engine func (engine *Engine) Close() error { - return engine.db.Close() + return engine.DB().Close() } // Ping tests if database is alive diff --git a/session.go b/session.go index db990684..07b99594 100644 --- a/session.go +++ b/session.go @@ -284,7 +284,7 @@ func (session *Session) Having(conditions string) *Session { // DB db return the wrapper of sql.DB func (session *Session) DB() *core.DB { if session.db == nil { - session.db = session.engine.db + session.db = session.engine.DB() session.stmtCache = make(map[uint32]*core.Stmt, 0) } return session.db diff --git a/session_tx.go b/session_tx.go index 489489f3..cd23cf89 100644 --- a/session_tx.go +++ b/session_tx.go @@ -34,7 +34,7 @@ func (session *Session) Rollback() error { session.isAutoCommit = true start := time.Now() - needSQL := session.engine.db.NeedLogSQL(session.ctx) + needSQL := session.DB().NeedLogSQL(session.ctx) if needSQL { session.engine.logger.BeforeSQL(log.LogContext{ Ctx: session.ctx, @@ -63,7 +63,7 @@ func (session *Session) Commit() error { session.isAutoCommit = true start := time.Now() - needSQL := session.engine.db.NeedLogSQL(session.ctx) + needSQL := session.DB().NeedLogSQL(session.ctx) if needSQL { session.engine.logger.BeforeSQL(log.LogContext{ Ctx: session.ctx, diff --git a/xorm.go b/xorm.go index 51915940..3618b718 100644 --- a/xorm.go +++ b/xorm.go @@ -8,13 +8,11 @@ package xorm import ( "context" - "fmt" "os" "runtime" "time" "xorm.io/xorm/caches" - "xorm.io/xorm/core" "xorm.io/xorm/dialects" "xorm.io/xorm/log" "xorm.io/xorm/names" @@ -34,27 +32,7 @@ func close(engine *Engine) { // NewEngine new a db manager according to the parameter. Currently support four // drivers func NewEngine(driverName string, dataSourceName string) (*Engine, error) { - driver := dialects.QueryDriver(driverName) - if driver == nil { - return nil, fmt.Errorf("Unsupported driver name: %v", driverName) - } - - uri, err := driver.Parse(driverName, dataSourceName) - if err != nil { - return nil, err - } - - dialect := dialects.QueryDialect(uri.DBType) - if dialect == nil { - return nil, fmt.Errorf("Unsupported dialect type: %v", uri.DBType) - } - - db, err := core.Open(driverName, dataSourceName) - if err != nil { - return nil, err - } - - err = dialect.Init(db, uri /*, driverName, dataSourceName*/) + dialect, err := dialects.OpenDialect(driverName, dataSourceName) if err != nil { return nil, err } @@ -64,7 +42,6 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) { tagParser := tags.NewParser("xorm", dialect, mapper, mapper, cacherMgr) engine := &Engine{ - db: db, dialect: dialect, TZLocation: time.Local, defaultContext: context.Background(), @@ -74,7 +51,7 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) { dataSourceName: dataSourceName, } - if uri.DBType == schemas.SQLITE { + if dialect.URI().DBType == schemas.SQLITE { engine.DatabaseTZ = time.UTC } else { engine.DatabaseTZ = time.Local -- 2.40.1 From dd633142579c475e0297add705f71a46e97822de Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Fri, 6 Mar 2020 14:03:28 +0800 Subject: [PATCH 4/4] revert change for delete --- internal/statements/delete.go | 133 ---------------------------------- session_delete.go | 131 ++++++++++++++++++++++++++++----- 2 files changed, 114 insertions(+), 150 deletions(-) delete mode 100644 internal/statements/delete.go diff --git a/internal/statements/delete.go b/internal/statements/delete.go deleted file mode 100644 index 2cb91f2a..00000000 --- a/internal/statements/delete.go +++ /dev/null @@ -1,133 +0,0 @@ -// Copyright 2020 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package statements - -import ( - "errors" - "fmt" - "time" - - "xorm.io/xorm/dialects" - "xorm.io/xorm/schemas" -) - -var ( - // ErrNeedDeletedCond delete needs less one condition error - ErrNeedDeletedCond = errors.New("Delete action needs at least one condition") - - // ErrNotImplemented not implemented - ErrNotImplemented = errors.New("Not implemented") -) - -// GenDeleteSQL generated delete SQL according conditions -func (statement *Statement) GenDeleteSQL(bean interface{}) (string, string, []interface{}, *time.Time, error) { - condSQL, condArgs, err := statement.GenConds(bean) - if err != nil { - return "", "", nil, nil, err - } - pLimitN := statement.LimitN - if len(condSQL) == 0 && (pLimitN == nil || *pLimitN == 0) { - return "", "", nil, nil, ErrNeedDeletedCond - } - - var tableNameNoQuote = statement.TableName() - var tableName = statement.quote(tableNameNoQuote) - var table = statement.RefTable - var deleteSQL string - if len(condSQL) > 0 { - deleteSQL = fmt.Sprintf("DELETE FROM %v WHERE %v", tableName, condSQL) - } else { - deleteSQL = fmt.Sprintf("DELETE FROM %v", tableName) - } - - var orderSQL string - if len(statement.OrderStr) > 0 { - orderSQL += fmt.Sprintf(" ORDER BY %s", statement.OrderStr) - } - if pLimitN != nil && *pLimitN > 0 { - limitNValue := *pLimitN - orderSQL += fmt.Sprintf(" LIMIT %d", limitNValue) - } - - if len(orderSQL) > 0 { - switch statement.dialect.DBType() { - case schemas.POSTGRES: - inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL) - if len(condSQL) > 0 { - deleteSQL += " AND " + inSQL - } else { - deleteSQL += " WHERE " + inSQL - } - case schemas.SQLITE: - inSQL := fmt.Sprintf("rowid IN (SELECT rowid FROM %s%s)", tableName, orderSQL) - if len(condSQL) > 0 { - deleteSQL += " AND " + inSQL - } else { - deleteSQL += " WHERE " + inSQL - } - // TODO: how to handle delete limit on mssql? - case schemas.MSSQL: - return "", "", nil, nil, ErrNotImplemented - default: - deleteSQL += orderSQL - } - } - - var realSQL string - if statement.GetUnscoped() || table.DeletedColumn() == nil { // tag "deleted" is disabled - return deleteSQL, deleteSQL, condArgs, nil, nil - } - - deletedColumn := table.DeletedColumn() - realSQL = fmt.Sprintf("UPDATE %v SET %v = ? WHERE %v", - statement.quote(statement.TableName()), - statement.quote(deletedColumn.Name), - condSQL) - - if len(orderSQL) > 0 { - switch statement.dialect.DBType() { - case schemas.POSTGRES: - inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL) - if len(condSQL) > 0 { - realSQL += " AND " + inSQL - } else { - realSQL += " WHERE " + inSQL - } - case schemas.SQLITE: - inSQL := fmt.Sprintf("rowid IN (SELECT rowid FROM %s%s)", tableName, orderSQL) - if len(condSQL) > 0 { - realSQL += " AND " + inSQL - } else { - realSQL += " WHERE " + inSQL - } - // TODO: how to handle delete limit on mssql? - case schemas.MSSQL: - return "", "", nil, nil, ErrNotImplemented - default: - realSQL += orderSQL - } - } - - // !oinume! Insert nowTime to the head of statement.Params - condArgs = append(condArgs, "") - paramsLen := len(condArgs) - copy(condArgs[1:paramsLen], condArgs[0:paramsLen-1]) - - now := ColumnNow(deletedColumn, statement.defaultTimeZone) - val := dialects.FormatTime(statement.dialect, deletedColumn.SQLType.Name, now) - condArgs[0] = val - - return realSQL, deleteSQL, condArgs, &now, nil -} - -// ColumnNow returns the current time for a column -func ColumnNow(col *schemas.Column, defaultTimeZone *time.Location) time.Time { - t := time.Now() - tz := defaultTimeZone - if !col.DisableTimeZone && col.TimeZone != nil { - tz = col.TimeZone - } - return t.In(tz) -} diff --git a/session_delete.go b/session_delete.go index 16434bac..eb5e2aea 100644 --- a/session_delete.go +++ b/session_delete.go @@ -6,13 +6,21 @@ package xorm import ( "errors" + "fmt" "strconv" - "time" "xorm.io/xorm/caches" "xorm.io/xorm/schemas" ) +var ( + // ErrNeedDeletedCond delete needs less one condition error + ErrNeedDeletedCond = errors.New("Delete action needs at least one condition") + + // ErrNotImplemented not implemented + ErrNotImplemented = errors.New("Not implemented") +) + func (session *Session) cacheDelete(table *schemas.Table, tableName, sqlStr string, args ...interface{}) error { if table == nil || session.tx != nil { @@ -98,30 +106,119 @@ func (session *Session) Delete(bean interface{}) (int64, error) { processor.BeforeDelete() } - realSQL, deleteSQL, condArgs, now, err := session.statement.GenDeleteSQL(bean) + condSQL, condArgs, err := session.statement.GenConds(bean) if err != nil { return 0, err } - - argsForCache := make([]interface{}, 0, len(condArgs)*2) - copy(argsForCache, condArgs) - argsForCache = append(condArgs, argsForCache...) - - if !session.statement.GetUnscoped() && session.statement.RefTable.DeletedColumn() != nil { - deletedColumn := session.statement.RefTable.DeletedColumn() - - session.afterClosures = append(session.afterClosures, func(col *schemas.Column, t time.Time) func(interface{}) { - return func(bean interface{}) { - setColumnTime(bean, col, t) - } - }(deletedColumn, now.In(session.engine.TZLocation))) + pLimitN := session.statement.LimitN + if len(condSQL) == 0 && (pLimitN == nil || *pLimitN == 0) { + return 0, ErrNeedDeletedCond } var tableNameNoQuote = session.statement.TableName() - if cacher := session.engine.GetCacher(tableNameNoQuote); cacher != nil && session.statement.UseCache { - session.cacheDelete(session.statement.RefTable, tableNameNoQuote, deleteSQL, argsForCache...) + var tableName = session.engine.Quote(tableNameNoQuote) + var table = session.statement.RefTable + var deleteSQL string + if len(condSQL) > 0 { + deleteSQL = fmt.Sprintf("DELETE FROM %v WHERE %v", tableName, condSQL) + } else { + deleteSQL = fmt.Sprintf("DELETE FROM %v", tableName) } + var orderSQL string + if len(session.statement.OrderStr) > 0 { + orderSQL += fmt.Sprintf(" ORDER BY %s", session.statement.OrderStr) + } + if pLimitN != nil && *pLimitN > 0 { + limitNValue := *pLimitN + orderSQL += fmt.Sprintf(" LIMIT %d", limitNValue) + } + + if len(orderSQL) > 0 { + switch session.engine.dialect.DBType() { + case schemas.POSTGRES: + inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL) + if len(condSQL) > 0 { + deleteSQL += " AND " + inSQL + } else { + deleteSQL += " WHERE " + inSQL + } + case schemas.SQLITE: + inSQL := fmt.Sprintf("rowid IN (SELECT rowid FROM %s%s)", tableName, orderSQL) + if len(condSQL) > 0 { + deleteSQL += " AND " + inSQL + } else { + deleteSQL += " WHERE " + inSQL + } + // TODO: how to handle delete limit on mssql? + case schemas.MSSQL: + return 0, ErrNotImplemented + default: + deleteSQL += orderSQL + } + } + + var realSQL string + argsForCache := make([]interface{}, 0, len(condArgs)*2) + if session.statement.GetUnscoped() || table.DeletedColumn() == nil { // tag "deleted" is disabled + realSQL = deleteSQL + copy(argsForCache, condArgs) + argsForCache = append(condArgs, argsForCache...) + } else { + // !oinume! sqlStrForCache and argsForCache is needed to behave as executing "DELETE FROM ..." for caches. + copy(argsForCache, condArgs) + argsForCache = append(condArgs, argsForCache...) + + deletedColumn := table.DeletedColumn() + realSQL = fmt.Sprintf("UPDATE %v SET %v = ? WHERE %v", + session.engine.Quote(session.statement.TableName()), + session.engine.Quote(deletedColumn.Name), + condSQL) + + if len(orderSQL) > 0 { + switch session.engine.dialect.DBType() { + case schemas.POSTGRES: + inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL) + if len(condSQL) > 0 { + realSQL += " AND " + inSQL + } else { + realSQL += " WHERE " + inSQL + } + case schemas.SQLITE: + inSQL := fmt.Sprintf("rowid IN (SELECT rowid FROM %s%s)", tableName, orderSQL) + if len(condSQL) > 0 { + realSQL += " AND " + inSQL + } else { + realSQL += " WHERE " + inSQL + } + // TODO: how to handle delete limit on mssql? + case schemas.MSSQL: + return 0, ErrNotImplemented + default: + realSQL += orderSQL + } + } + + // !oinume! Insert nowTime to the head of session.statement.Params + condArgs = append(condArgs, "") + paramsLen := len(condArgs) + copy(condArgs[1:paramsLen], condArgs[0:paramsLen-1]) + + val, t := session.engine.nowTime(deletedColumn) + condArgs[0] = val + + var colName = deletedColumn.Name + session.afterClosures = append(session.afterClosures, func(bean interface{}) { + col := table.GetColumn(colName) + setColumnTime(bean, col, t) + }) + } + + if cacher := session.engine.GetCacher(tableNameNoQuote); cacher != nil && session.statement.UseCache { + session.cacheDelete(table, tableNameNoQuote, deleteSQL, argsForCache...) + } + + session.statement.RefTable = table res, err := session.exec(realSQL, condArgs...) if err != nil { return 0, err -- 2.40.1