Fix bug on dumptable #1984

Merged
lunny merged 10 commits from lunny/fix_dumptable into master 2021-07-11 12:05:44 +00:00
6 changed files with 121 additions and 152 deletions

View File

@ -348,7 +348,6 @@ func convertAssign(dest, src interface{}, originalLocation *time.Location, conve
*d = cloneBytes(s) *d = cloneBytes(s)
return nil return nil
} }
case time.Time: case time.Time:
switch d := dest.(type) { switch d := dest.(type) {
case *string: case *string:

View File

@ -19,7 +19,14 @@ func String2Time(s string, originalLocation *time.Location, convertedLocation *t
dt = dt.In(convertedLocation) dt = dt.In(convertedLocation)
return &dt, nil return &dt, nil
} else if len(s) == 20 && s[10] == 'T' && s[19] == 'Z' { } else if len(s) == 20 && s[10] == 'T' && s[19] == 'Z' {
dt, err := time.ParseInLocation("2006-01-02T15:04:05Z", s, originalLocation) dt, err := time.ParseInLocation(time.RFC3339, s, originalLocation)
if err != nil {
return nil, err
}
dt = dt.In(convertedLocation)
return &dt, nil
} else if len(s) == 25 && s[10] == 'T' && s[19] == '+' && s[22] == ':' {
dt, err := time.Parse(time.RFC3339, s)
if err != nil { if err != nil {
return nil, err return nil, err
} }

30
convert/time_test.go Normal file
View File

@ -0,0 +1,30 @@
// Copyright 2021 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package convert
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestString2Time(t *testing.T) {
expectedLoc, err := time.LoadLocation("Asia/Shanghai")
assert.NoError(t, err)
var kases = map[string]time.Time{
"2021-06-06T22:58:20+08:00": time.Date(2021, 6, 6, 22, 58, 20, 0, expectedLoc),
"2021-07-11 10:44:00": time.Date(2021, 7, 11, 18, 44, 0, 0, expectedLoc),
"2021-08-10T10:33:04Z": time.Date(2021, 8, 10, 18, 33, 04, 0, expectedLoc),
}
for layout, tm := range kases {
t.Run(layout, func(t *testing.T) {
target, err := String2Time(layout, time.UTC, expectedLoc)
assert.NoError(t, err)
assert.EqualValues(t, tm, *target)
})
}
}

212
engine.go
View File

@ -13,7 +13,6 @@ import (
"os" "os"
"reflect" "reflect"
"runtime" "runtime"
"strconv"
"strings" "strings"
"time" "time"
@ -21,7 +20,6 @@ import (
"xorm.io/xorm/contexts" "xorm.io/xorm/contexts"
"xorm.io/xorm/core" "xorm.io/xorm/core"
"xorm.io/xorm/dialects" "xorm.io/xorm/dialects"
"xorm.io/xorm/internal/json"
"xorm.io/xorm/internal/utils" "xorm.io/xorm/internal/utils"
"xorm.io/xorm/log" "xorm.io/xorm/log"
"xorm.io/xorm/names" "xorm.io/xorm/names"
@ -446,93 +444,14 @@ func (engine *Engine) DumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
return engine.dumpTables(tables, w, tp...) return engine.dumpTables(tables, w, tp...)
} }
func formatColumnValue(dbLocation *time.Location, dstDialect dialects.Dialect, d interface{}, col *schemas.Column) string { func formatBool(s string, dstDialect dialects.Dialect) string {
if d == nil { if dstDialect.URI().DBType == schemas.MSSQL {
return "NULL" switch s {
} case "true":
if dq, ok := d.(bool); ok && (dstDialect.URI().DBType == schemas.SQLITE ||
dstDialect.URI().DBType == schemas.MSSQL) {
if dq {
return "1" return "1"
case "false":
return "0"
} }
return "0"
}
if col.SQLType.IsText() {
var v string
switch reflect.TypeOf(d).Kind() {
case reflect.Struct, reflect.Array, reflect.Slice, reflect.Map:
bytes, err := json.DefaultJSONHandler.Marshal(d)
if err != nil {
v = fmt.Sprintf("%s", d)
} else {
v = string(bytes)
}
default:
v = fmt.Sprintf("%s", d)
}
return "'" + strings.Replace(v, "'", "''", -1) + "'"
} else if col.SQLType.IsTime() {
if t, ok := d.(time.Time); ok {
return "'" + t.In(dbLocation).Format("2006-01-02 15:04:05") + "'"
}
var v = fmt.Sprintf("%s", d)
if strings.HasSuffix(v, " +0000 UTC") {
return fmt.Sprintf("'%s'", v[0:len(v)-len(" +0000 UTC")])
} else if strings.HasSuffix(v, " +0000 +0000") {
return fmt.Sprintf("'%s'", v[0:len(v)-len(" +0000 +0000")])
}
return "'" + strings.Replace(v, "'", "''", -1) + "'"
} else if col.SQLType.IsBlob() {
if reflect.TypeOf(d).Kind() == reflect.Slice {
return fmt.Sprintf("%s", dstDialect.FormatBytes(d.([]byte)))
} else if reflect.TypeOf(d).Kind() == reflect.String {
return fmt.Sprintf("'%s'", d.(string))
}
} else if col.SQLType.IsNumeric() {
switch reflect.TypeOf(d).Kind() {
case reflect.Slice:
if col.SQLType.Name == schemas.Bool {
return fmt.Sprintf("%v", strconv.FormatBool(d.([]byte)[0] != byte('0')))
}
return fmt.Sprintf("%s", string(d.([]byte)))
case reflect.Int16, reflect.Int8, reflect.Int32, reflect.Int64, reflect.Int:
if col.SQLType.Name == schemas.Bool {
v := reflect.ValueOf(d).Int() > 0
if dstDialect.URI().DBType == schemas.SQLITE {
if v {
return "1"
}
return "0"
}
return fmt.Sprintf("%v", strconv.FormatBool(v))
}
return fmt.Sprintf("%d", d)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
if col.SQLType.Name == schemas.Bool {
v := reflect.ValueOf(d).Uint() > 0
if dstDialect.URI().DBType == schemas.SQLITE {
if v {
return "1"
}
return "0"
}
return fmt.Sprintf("%v", strconv.FormatBool(v))
}
return fmt.Sprintf("%d", d)
default:
return fmt.Sprintf("%v", d)
}
}
s := fmt.Sprintf("%v", d)
if strings.Contains(s, ":") || strings.Contains(s, "-") {
if strings.HasSuffix(s, " +0000 UTC") {
return fmt.Sprintf("'%s'", s[0:len(s)-len(" +0000 UTC")])
}
return fmt.Sprintf("'%s'", s)
} }
return s return s
} }
@ -545,7 +464,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
} else { } else {
dstDialect = dialects.QueryDialect(tp[0]) dstDialect = dialects.QueryDialect(tp[0])
if dstDialect == nil { if dstDialect == nil {
return errors.New("Unsupported database type") return fmt.Errorf("unsupported database type %v", tp[0])
} }
uri := engine.dialect.URI() uri := engine.dialect.URI()
@ -619,73 +538,68 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
} }
defer rows.Close() defer rows.Close()
if table.Type != nil { types, err := rows.ColumnTypes()
sess := engine.NewSession() if err != nil {
defer sess.Close() return err
for rows.Next() { }
beanValue := reflect.New(table.Type)
bean := beanValue.Interface()
fields, err := rows.Columns()
if err != nil {
return err
}
scanResults, err := sess.row2Slice(rows, fields, bean)
if err != nil {
return err
}
dataStruct := utils.ReflectValue(bean) sess := engine.NewSession()
_, err = sess.slice2Bean(scanResults, fields, bean, &dataStruct, table) defer sess.Close()
if err != nil { for rows.Next() {
return err _, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(dstTableName)+" ("+destColNames+") VALUES (")
} if err != nil {
return err
}
_, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(dstTableName)+" ("+destColNames+") VALUES (") scanResults, err := sess.engine.scanStringInterface(rows, types)
if err != nil { if err != nil {
return err return err
} }
for i, scanResult := range scanResults {
var temp string stp := schemas.SQLType{Name: types[i].DatabaseTypeName()}
for _, d := range dstCols { if stp.IsNumeric() {
col := table.GetColumn(d) s := scanResult.(*sql.NullString)
if col == nil { if s.Valid {
return errors.New("unknown column error") if _, err = io.WriteString(w, formatBool(s.String, dstDialect)); err != nil {
return err
}
} else {
if _, err = io.WriteString(w, "NULL"); err != nil {
return err
}
}
} else if stp.IsBool() {
s := scanResult.(*sql.NullString)
if s.Valid {
if _, err = io.WriteString(w, formatBool(s.String, dstDialect)); err != nil {
return err
}
} else {
if _, err = io.WriteString(w, "NULL"); err != nil {
return err
}
}
} else {
s := scanResult.(*sql.NullString)
if s.Valid {
if _, err = io.WriteString(w, "'"+strings.ReplaceAll(s.String, "'", "''")+"'"); err != nil {
return err
}
} else {
if _, err = io.WriteString(w, "NULL"); err != nil {
return err
}
} }
field := dataStruct.FieldByIndex(col.FieldIndex)
temp += "," + formatColumnValue(engine.DatabaseTZ, dstDialect, field.Interface(), col)
} }
_, err = io.WriteString(w, temp[1:]+");\n") if i < len(scanResults)-1 {
if err != nil { if _, err = io.WriteString(w, ","); err != nil {
return err return err
}
} }
} }
} else { _, err = io.WriteString(w, ");\n")
for rows.Next() { if err != nil {
dest := make([]interface{}, len(cols)) return err
err = rows.ScanSlice(&dest)
if err != nil {
return err
}
_, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(dstTableName)+" ("+destColNames+") VALUES (")
if err != nil {
return err
}
var temp string
for i, d := range dest {
col := table.GetColumn(cols[i])
if col == nil {
return errors.New("unknow column error")
}
temp += "," + formatColumnValue(engine.DatabaseTZ, dstDialect, d, col)
}
_, err = io.WriteString(w, temp[1:]+");\n")
if err != nil {
return err
}
} }
} }

View File

@ -172,8 +172,21 @@ func TestDumpTables(t *testing.T) {
name := fmt.Sprintf("dump_%v-table.sql", tp) name := fmt.Sprintf("dump_%v-table.sql", tp)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
assert.NoError(t, testEngine.(*xorm.Engine).DumpTablesToFile([]*schemas.Table{tb}, name, tp)) assert.NoError(t, testEngine.(*xorm.Engine).DumpTablesToFile([]*schemas.Table{tb}, name, tp))
}) })
} }
assert.NoError(t, testEngine.DropTables(new(TestDumpTableStruct)))
importPath := fmt.Sprintf("dump_%v-table.sql", testEngine.Dialect().URI().DBType)
t.Run("import_"+importPath, func(t *testing.T) {
sess := testEngine.NewSession()
defer sess.Close()
assert.NoError(t, sess.Begin())
_, err = sess.ImportFile(importPath)
assert.NoError(t, err)
assert.NoError(t, sess.Commit())
})
} }
func TestDumpTables2(t *testing.T) { func TestDumpTables2(t *testing.T) {

View File

@ -39,6 +39,7 @@ const (
TIME_TYPE TIME_TYPE
NUMERIC_TYPE NUMERIC_TYPE
ARRAY_TYPE ARRAY_TYPE
BOOL_TYPE
) )
// IsType reutrns ture if the column type is the same as the parameter // IsType reutrns ture if the column type is the same as the parameter
@ -64,6 +65,10 @@ func (s *SQLType) IsTime() bool {
return s.IsType(TIME_TYPE) return s.IsType(TIME_TYPE)
} }
func (s *SQLType) IsBool() bool {
return s.IsType(BOOL_TYPE)
}
// IsNumeric returns true if column is a numeric type // IsNumeric returns true if column is a numeric type
func (s *SQLType) IsNumeric() bool { func (s *SQLType) IsNumeric() bool {
return s.IsType(NUMERIC_TYPE) return s.IsType(NUMERIC_TYPE)
@ -209,7 +214,8 @@ var (
Bytea: BLOB_TYPE, Bytea: BLOB_TYPE,
UniqueIdentifier: BLOB_TYPE, UniqueIdentifier: BLOB_TYPE,
Bool: NUMERIC_TYPE, Bool: BOOL_TYPE,
Boolean: BOOL_TYPE,
Serial: NUMERIC_TYPE, Serial: NUMERIC_TYPE,
BigSerial: NUMERIC_TYPE, BigSerial: NUMERIC_TYPE,