Exec with time arg now will obey time zone settings on engine #1989

Merged
lunny merged 1 commits from lunny/exec_time into master 2021-07-14 04:20:27 +00:00
10 changed files with 139 additions and 22 deletions

View File

@ -42,11 +42,12 @@ func (uri *URI) SetSchema(schema string) {
type Dialect interface {
Init(*URI) error
URI() *URI
SQLType(*schemas.Column) string
Alias(string) string // return what a sql type's alias of
FormatBytes(b []byte) string
Version(ctx context.Context, queryer core.Queryer) (*schemas.Version, error)
SQLType(*schemas.Column) string
Alias(string) string // return what a sql type's alias of
ColumnTypeKind(string) int // database column type kind
IsReserved(string) bool
Quoter() schemas.Quoter
SetQuotePolicy(quotePolicy QuotePolicy)
@ -102,11 +103,6 @@ func (db *Base) URI() *URI {
return db.uri
}
// FormatBytes formats bytes
func (db *Base) FormatBytes(bs []byte) string {
return fmt.Sprintf("0x%x", bs)
}
// DropTableSQL returns drop table SQL
func (db *Base) DropTableSQL(tableName string) (string, bool) {
quote := db.dialect.Quoter().Quote

View File

@ -364,6 +364,19 @@ func (db *mssql) SQLType(c *schemas.Column) string {
return res
}
func (db *mssql) ColumnTypeKind(t string) int {
switch strings.ToUpper(t) {
case "DATE", "DATETIME", "DATETIME2", "TIME":
return schemas.TIME_TYPE
case "VARCHAR", "TEXT", "CHAR", "NVARCHAR", "NCHAR", "NTEXT":
return schemas.TEXT_TYPE
case "FLOAT", "REAL", "BIGINT", "DATETIMEOFFSET", "TINYINT", "SMALLINT", "INT":
return schemas.NUMERIC_TYPE
default:
return schemas.UNKNOW_TYPE
}
}
func (db *mssql) IsReserved(name string) bool {
_, ok := mssqlReservedWords[strings.ToUpper(name)]
return ok

View File

@ -337,6 +337,21 @@ func (db *mysql) SQLType(c *schemas.Column) string {
return res
}
func (db *mysql) ColumnTypeKind(t string) int {
switch strings.ToUpper(t) {
case "DATETIME":
return schemas.TIME_TYPE
case "CHAR", "VARCHAR", "TINYTEXT", "TEXT", "MEDIUMTEXT", "LONGTEXT", "ENUM", "SET":
return schemas.TEXT_TYPE
case "BIGINT", "TINYINT", "SMALLINT", "MEDIUMINT", "INT", "FLOAT", "REAL", "DOUBLE PRECISION", "DECIMAL", "NUMERIC", "BIT":
return schemas.NUMERIC_TYPE
case "BINARY", "VARBINARY", "TINYBLOB", "BLOB", "MEDIUMBLOB", "LONGBLOB":
return schemas.BLOB_TYPE
default:
return schemas.UNKNOW_TYPE
}
}
func (db *mysql) IsReserved(name string) bool {
_, ok := mysqlReservedWords[strings.ToUpper(name)]
return ok

View File

@ -568,6 +568,21 @@ func (db *oracle) SQLType(c *schemas.Column) string {
return res
}
func (db *oracle) ColumnTypeKind(t string) int {
switch strings.ToUpper(t) {
case "DATE":
return schemas.TIME_TYPE
case "CHAR", "NCHAR", "VARCHAR", "VARCHAR2", "NVARCHAR2", "LONG", "CLOB", "NCLOB":
return schemas.TEXT_TYPE
case "NUMBER":
return schemas.NUMERIC_TYPE
case "BLOB":
return schemas.BLOB_TYPE
default:
return schemas.UNKNOW_TYPE
}
}
func (db *oracle) AutoIncrStr() string {
return "AUTO_INCREMENT"
}

View File

@ -873,11 +873,6 @@ func (db *postgres) SetQuotePolicy(quotePolicy QuotePolicy) {
}
}
// FormatBytes formats bytes
func (db *postgres) FormatBytes(bs []byte) string {
return fmt.Sprintf("E'\\x%x'", bs)
}
func (db *postgres) SQLType(c *schemas.Column) string {
var res string
switch t := c.SQLType.Name; t {
@ -943,6 +938,21 @@ func (db *postgres) SQLType(c *schemas.Column) string {
return res
}
func (db *postgres) ColumnTypeKind(t string) int {
switch strings.ToUpper(t) {
case "DATETIME", "TIMESTAMP":
return schemas.TIME_TYPE
case "VARCHAR", "TEXT":
return schemas.TEXT_TYPE
case "BIGINT", "BIGSERIAL", "SMALLINT", "INT", "INT8", "INT4", "INTEGER", "SERIAL", "FLOAT", "FLOAT4", "REAL", "DOUBLE PRECISION":
return schemas.NUMERIC_TYPE
case "BOOL":
return schemas.BOOL_TYPE
default:
return schemas.UNKNOW_TYPE
}
}
func (db *postgres) IsReserved(name string) bool {
_, ok := postgresReservedWords[strings.ToUpper(name)]
return ok

View File

@ -233,8 +233,19 @@ func (db *sqlite3) SQLType(c *schemas.Column) string {
}
}
func (db *sqlite3) FormatBytes(bs []byte) string {
return fmt.Sprintf("X'%x'", bs)
func (db *sqlite3) ColumnTypeKind(t string) int {
switch strings.ToUpper(t) {
case "DATETIME":
return schemas.TIME_TYPE
case "TEXT":
return schemas.TEXT_TYPE
case "INTEGER", "REAL", "NUMERIC", "DECIMAL":
return schemas.NUMERIC_TYPE
case "BLOB":
return schemas.BLOB_TYPE
default:
return schemas.UNKNOW_TYPE
}
}
func (db *sqlite3) IsReserved(name string) bool {

View File

@ -7,6 +7,7 @@ package integrations
import (
"strconv"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
@ -35,3 +36,32 @@ func TestExecAndQuery(t *testing.T) {
assert.EqualValues(t, 1, id)
assert.Equal(t, "user", string(results[0]["name"]))
}
func TestExecTime(t *testing.T) {
assert.NoError(t, PrepareEngine())
type UserinfoExecTime struct {
Uid int
Name string
Created time.Time
}
assert.NoError(t, testEngine.Sync2(new(UserinfoExecTime)))
now := time.Now()
res, err := testEngine.Exec("INSERT INTO "+testEngine.TableName("`userinfo_exec_time`", true)+" (uid, name, created) VALUES (?, ?, ?)", 1, "user", now)
assert.NoError(t, err)
cnt, err := res.RowsAffected()
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
results, err := testEngine.QueryString("SELECT * FROM " + testEngine.TableName("`userinfo_exec_time`", true))
assert.NoError(t, err)
assert.EqualValues(t, 1, len(results))
assert.EqualValues(t, now.In(testEngine.GetTZLocation()).Format("2006-01-02 15:04:05"), results[0]["created"])
var uet UserinfoExecTime
has, err := testEngine.Where("uid=?", 1).Get(&uet)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, now.In(testEngine.GetTZLocation()).Format("2006-01-02 15:04:05"), uet.Created.Format("2006-01-02 15:04:05"))
}

View File

@ -942,16 +942,29 @@ func (statement *Statement) quoteColumnStr(columnStr string) string {
// ConvertSQLOrArgs converts sql or args
func (statement *Statement) ConvertSQLOrArgs(sqlOrArgs ...interface{}) (string, []interface{}, error) {
sql, args, err := convertSQLOrArgs(sqlOrArgs...)
sql, args, err := statement.convertSQLOrArgs(sqlOrArgs...)
if err != nil {
return "", nil, err
}
return statement.ReplaceQuote(sql), args, nil
}
func convertSQLOrArgs(sqlOrArgs ...interface{}) (string, []interface{}, error) {
func (statement *Statement) convertSQLOrArgs(sqlOrArgs ...interface{}) (string, []interface{}, error) {
switch sqlOrArgs[0].(type) {
case string:
if len(sqlOrArgs) > 1 {
var newArgs = make([]interface{}, 0, len(sqlOrArgs)-1)
for _, arg := range sqlOrArgs[1:] {
if v, ok := arg.(*time.Time); ok {
newArgs = append(newArgs, v.In(statement.defaultTimeZone).Format("2006-01-02 15:04:05"))
} else if v, ok := arg.(time.Time); ok {
newArgs = append(newArgs, v.In(statement.defaultTimeZone).Format("2006-01-02 15:04:05"))
} else {
newArgs = append(newArgs, arg)
}
}
return sqlOrArgs[0].(string), newArgs, nil
}
return sqlOrArgs[0].(string), sqlOrArgs[1:], nil
case *builder.Builder:
return sqlOrArgs[0].(*builder.Builder).ToSQL()

22
scan.go
View File

@ -14,6 +14,7 @@ import (
"xorm.io/xorm/convert"
"xorm.io/xorm/core"
"xorm.io/xorm/dialects"
"xorm.io/xorm/schemas"
)
// genScanResultsByBeanNullabale generates scan result
@ -123,7 +124,7 @@ func genScanResultsByBean(bean interface{}) (interface{}, bool, error) {
}
}
func row2mapStr(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[string]string, error) {
func (engine *Engine) row2mapStr(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[string]string, error) {
var scanResults = make([]interface{}, len(fields))
for i := 0; i < len(fields); i++ {
var s sql.NullString
@ -135,9 +136,22 @@ func row2mapStr(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[
}
result := make(map[string]string, len(fields))
for ii, key := range fields {
s := scanResults[ii].(*sql.NullString)
result[key] = s.String
for i, key := range fields {
s := scanResults[i].(*sql.NullString)
if s.String == "" {
result[key] = ""
continue
}
if schemas.TIME_TYPE == engine.dialect.ColumnTypeKind(types[i].DatabaseTypeName()) {
t, err := convert.String2Time(s.String, engine.DatabaseTZ, engine.TZLocation)
if err != nil {
return nil, err
}
result[key] = t.Format("2006-01-02 15:04:05")
} else {
result[key] = s.String
}
}
return result, nil
}

View File

@ -33,7 +33,7 @@ func (session *Session) rows2Strings(rows *core.Rows) (resultsSlice []map[string
}
for rows.Next() {
result, err := row2mapStr(rows, types, fields)
result, err := session.engine.row2mapStr(rows, types, fields)
if err != nil {
return nil, err
}