From ad9580249f241dacc1ee8a9cbb3e6f74dcf784a2 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Wed, 14 Jul 2021 11:14:04 +0800 Subject: [PATCH] Exec with time arg now will obey time zone settings on engine --- dialects/dialect.go | 12 ++++-------- dialects/mssql.go | 13 +++++++++++++ dialects/mysql.go | 15 +++++++++++++++ dialects/oracle.go | 15 +++++++++++++++ dialects/postgres.go | 20 +++++++++++++++----- dialects/sqlite3.go | 15 +++++++++++++-- integrations/session_raw_test.go | 30 ++++++++++++++++++++++++++++++ internal/statements/statement.go | 17 +++++++++++++++-- scan.go | 22 ++++++++++++++++++---- session_query.go | 2 +- 10 files changed, 139 insertions(+), 22 deletions(-) diff --git a/dialects/dialect.go b/dialects/dialect.go index b3d374cc..df33155d 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -42,11 +42,12 @@ func (uri *URI) SetSchema(schema string) { type Dialect interface { Init(*URI) error URI() *URI - SQLType(*schemas.Column) string - Alias(string) string // return what a sql type's alias of - FormatBytes(b []byte) string Version(ctx context.Context, queryer core.Queryer) (*schemas.Version, error) + SQLType(*schemas.Column) string + Alias(string) string // return what a sql type's alias of + ColumnTypeKind(string) int // database column type kind + IsReserved(string) bool Quoter() schemas.Quoter SetQuotePolicy(quotePolicy QuotePolicy) @@ -102,11 +103,6 @@ func (db *Base) URI() *URI { return db.uri } -// FormatBytes formats bytes -func (db *Base) FormatBytes(bs []byte) string { - return fmt.Sprintf("0x%x", bs) -} - // DropTableSQL returns drop table SQL func (db *Base) DropTableSQL(tableName string) (string, bool) { quote := db.dialect.Quoter().Quote diff --git a/dialects/mssql.go b/dialects/mssql.go index c3c15077..e708ba80 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -364,6 +364,19 @@ func (db *mssql) SQLType(c *schemas.Column) string { return res } +func (db *mssql) ColumnTypeKind(t string) int { + switch strings.ToUpper(t) { + case "DATE", "DATETIME", "DATETIME2", "TIME": + return schemas.TIME_TYPE + case "VARCHAR", "TEXT", "CHAR", "NVARCHAR", "NCHAR", "NTEXT": + return schemas.TEXT_TYPE + case "FLOAT", "REAL", "BIGINT", "DATETIMEOFFSET", "TINYINT", "SMALLINT", "INT": + return schemas.NUMERIC_TYPE + default: + return schemas.UNKNOW_TYPE + } +} + func (db *mssql) IsReserved(name string) bool { _, ok := mssqlReservedWords[strings.ToUpper(name)] return ok diff --git a/dialects/mysql.go b/dialects/mysql.go index da19b820..db45cd62 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -337,6 +337,21 @@ func (db *mysql) SQLType(c *schemas.Column) string { return res } +func (db *mysql) ColumnTypeKind(t string) int { + switch strings.ToUpper(t) { + case "DATETIME": + return schemas.TIME_TYPE + case "CHAR", "VARCHAR", "TINYTEXT", "TEXT", "MEDIUMTEXT", "LONGTEXT", "ENUM", "SET": + return schemas.TEXT_TYPE + case "BIGINT", "TINYINT", "SMALLINT", "MEDIUMINT", "INT", "FLOAT", "REAL", "DOUBLE PRECISION", "DECIMAL", "NUMERIC", "BIT": + return schemas.NUMERIC_TYPE + case "BINARY", "VARBINARY", "TINYBLOB", "BLOB", "MEDIUMBLOB", "LONGBLOB": + return schemas.BLOB_TYPE + default: + return schemas.UNKNOW_TYPE + } +} + func (db *mysql) IsReserved(name string) bool { _, ok := mysqlReservedWords[strings.ToUpper(name)] return ok diff --git a/dialects/oracle.go b/dialects/oracle.go index 7043972b..5dd92887 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -568,6 +568,21 @@ func (db *oracle) SQLType(c *schemas.Column) string { return res } +func (db *oracle) ColumnTypeKind(t string) int { + switch strings.ToUpper(t) { + case "DATE": + return schemas.TIME_TYPE + case "CHAR", "NCHAR", "VARCHAR", "VARCHAR2", "NVARCHAR2", "LONG", "CLOB", "NCLOB": + return schemas.TEXT_TYPE + case "NUMBER": + return schemas.NUMERIC_TYPE + case "BLOB": + return schemas.BLOB_TYPE + default: + return schemas.UNKNOW_TYPE + } +} + func (db *oracle) AutoIncrStr() string { return "AUTO_INCREMENT" } diff --git a/dialects/postgres.go b/dialects/postgres.go index 9f3c7275..4ec780e8 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -873,11 +873,6 @@ func (db *postgres) SetQuotePolicy(quotePolicy QuotePolicy) { } } -// FormatBytes formats bytes -func (db *postgres) FormatBytes(bs []byte) string { - return fmt.Sprintf("E'\\x%x'", bs) -} - func (db *postgres) SQLType(c *schemas.Column) string { var res string switch t := c.SQLType.Name; t { @@ -943,6 +938,21 @@ func (db *postgres) SQLType(c *schemas.Column) string { return res } +func (db *postgres) ColumnTypeKind(t string) int { + switch strings.ToUpper(t) { + case "DATETIME", "TIMESTAMP": + return schemas.TIME_TYPE + case "VARCHAR", "TEXT": + return schemas.TEXT_TYPE + case "BIGINT", "BIGSERIAL", "SMALLINT", "INT", "INT8", "INT4", "INTEGER", "SERIAL", "FLOAT", "FLOAT4", "REAL", "DOUBLE PRECISION": + return schemas.NUMERIC_TYPE + case "BOOL": + return schemas.BOOL_TYPE + default: + return schemas.UNKNOW_TYPE + } +} + func (db *postgres) IsReserved(name string) bool { _, ok := postgresReservedWords[strings.ToUpper(name)] return ok diff --git a/dialects/sqlite3.go b/dialects/sqlite3.go index 04e5b457..581272ad 100644 --- a/dialects/sqlite3.go +++ b/dialects/sqlite3.go @@ -233,8 +233,19 @@ func (db *sqlite3) SQLType(c *schemas.Column) string { } } -func (db *sqlite3) FormatBytes(bs []byte) string { - return fmt.Sprintf("X'%x'", bs) +func (db *sqlite3) ColumnTypeKind(t string) int { + switch strings.ToUpper(t) { + case "DATETIME": + return schemas.TIME_TYPE + case "TEXT": + return schemas.TEXT_TYPE + case "INTEGER", "REAL", "NUMERIC", "DECIMAL": + return schemas.NUMERIC_TYPE + case "BLOB": + return schemas.BLOB_TYPE + default: + return schemas.UNKNOW_TYPE + } } func (db *sqlite3) IsReserved(name string) bool { diff --git a/integrations/session_raw_test.go b/integrations/session_raw_test.go index 8b9d6766..36677683 100644 --- a/integrations/session_raw_test.go +++ b/integrations/session_raw_test.go @@ -7,6 +7,7 @@ package integrations import ( "strconv" "testing" + "time" "github.com/stretchr/testify/assert" ) @@ -35,3 +36,32 @@ func TestExecAndQuery(t *testing.T) { assert.EqualValues(t, 1, id) assert.Equal(t, "user", string(results[0]["name"])) } + +func TestExecTime(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type UserinfoExecTime struct { + Uid int + Name string + Created time.Time + } + + assert.NoError(t, testEngine.Sync2(new(UserinfoExecTime))) + now := time.Now() + res, err := testEngine.Exec("INSERT INTO "+testEngine.TableName("`userinfo_exec_time`", true)+" (uid, name, created) VALUES (?, ?, ?)", 1, "user", now) + assert.NoError(t, err) + cnt, err := res.RowsAffected() + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + results, err := testEngine.QueryString("SELECT * FROM " + testEngine.TableName("`userinfo_exec_time`", true)) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(results)) + assert.EqualValues(t, now.In(testEngine.GetTZLocation()).Format("2006-01-02 15:04:05"), results[0]["created"]) + + var uet UserinfoExecTime + has, err := testEngine.Where("uid=?", 1).Get(&uet) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, now.In(testEngine.GetTZLocation()).Format("2006-01-02 15:04:05"), uet.Created.Format("2006-01-02 15:04:05")) +} diff --git a/internal/statements/statement.go b/internal/statements/statement.go index 2d173b87..bfe9987f 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -942,16 +942,29 @@ func (statement *Statement) quoteColumnStr(columnStr string) string { // ConvertSQLOrArgs converts sql or args func (statement *Statement) ConvertSQLOrArgs(sqlOrArgs ...interface{}) (string, []interface{}, error) { - sql, args, err := convertSQLOrArgs(sqlOrArgs...) + sql, args, err := statement.convertSQLOrArgs(sqlOrArgs...) if err != nil { return "", nil, err } return statement.ReplaceQuote(sql), args, nil } -func convertSQLOrArgs(sqlOrArgs ...interface{}) (string, []interface{}, error) { +func (statement *Statement) convertSQLOrArgs(sqlOrArgs ...interface{}) (string, []interface{}, error) { switch sqlOrArgs[0].(type) { case string: + if len(sqlOrArgs) > 1 { + var newArgs = make([]interface{}, 0, len(sqlOrArgs)-1) + for _, arg := range sqlOrArgs[1:] { + if v, ok := arg.(*time.Time); ok { + newArgs = append(newArgs, v.In(statement.defaultTimeZone).Format("2006-01-02 15:04:05")) + } else if v, ok := arg.(time.Time); ok { + newArgs = append(newArgs, v.In(statement.defaultTimeZone).Format("2006-01-02 15:04:05")) + } else { + newArgs = append(newArgs, arg) + } + } + return sqlOrArgs[0].(string), newArgs, nil + } return sqlOrArgs[0].(string), sqlOrArgs[1:], nil case *builder.Builder: return sqlOrArgs[0].(*builder.Builder).ToSQL() diff --git a/scan.go b/scan.go index d668208a..2fedd415 100644 --- a/scan.go +++ b/scan.go @@ -14,6 +14,7 @@ import ( "xorm.io/xorm/convert" "xorm.io/xorm/core" "xorm.io/xorm/dialects" + "xorm.io/xorm/schemas" ) // genScanResultsByBeanNullabale generates scan result @@ -123,7 +124,7 @@ func genScanResultsByBean(bean interface{}) (interface{}, bool, error) { } } -func row2mapStr(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[string]string, error) { +func (engine *Engine) row2mapStr(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[string]string, error) { var scanResults = make([]interface{}, len(fields)) for i := 0; i < len(fields); i++ { var s sql.NullString @@ -135,9 +136,22 @@ func row2mapStr(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[ } result := make(map[string]string, len(fields)) - for ii, key := range fields { - s := scanResults[ii].(*sql.NullString) - result[key] = s.String + for i, key := range fields { + s := scanResults[i].(*sql.NullString) + if s.String == "" { + result[key] = "" + continue + } + + if schemas.TIME_TYPE == engine.dialect.ColumnTypeKind(types[i].DatabaseTypeName()) { + t, err := convert.String2Time(s.String, engine.DatabaseTZ, engine.TZLocation) + if err != nil { + return nil, err + } + result[key] = t.Format("2006-01-02 15:04:05") + } else { + result[key] = s.String + } } return result, nil } diff --git a/session_query.go b/session_query.go index fa33496d..d14c3908 100644 --- a/session_query.go +++ b/session_query.go @@ -33,7 +33,7 @@ func (session *Session) rows2Strings(rows *core.Rows) (resultsSlice []map[string } for rows.Next() { - result, err := row2mapStr(rows, types, fields) + result, err := session.engine.row2mapStr(rows, types, fields) if err != nil { return nil, err } -- 2.40.1