Fix bug on dumptable #1984

Merged
lunny merged 10 commits from lunny/fix_dumptable into master 7 months ago
  1. 1
      convert.go
  2. 9
      convert/time.go
  3. 30
      convert/time_test.go
  4. 212
      engine.go
  5. 13
      integrations/engine_test.go
  6. 8
      schemas/type.go

1
convert.go

@ -348,7 +348,6 @@ func convertAssign(dest, src interface{}, originalLocation *time.Location, conve @@ -348,7 +348,6 @@ func convertAssign(dest, src interface{}, originalLocation *time.Location, conve
*d = cloneBytes(s)
return nil
}
case time.Time:
switch d := dest.(type) {
case *string:

9
convert/time.go

@ -19,7 +19,14 @@ func String2Time(s string, originalLocation *time.Location, convertedLocation *t @@ -19,7 +19,14 @@ func String2Time(s string, originalLocation *time.Location, convertedLocation *t
dt = dt.In(convertedLocation)
return &dt, nil
} else if len(s) == 20 && s[10] == 'T' && s[19] == 'Z' {
dt, err := time.ParseInLocation("2006-01-02T15:04:05Z", s, originalLocation)
dt, err := time.ParseInLocation(time.RFC3339, s, originalLocation)
if err != nil {
return nil, err
}
dt = dt.In(convertedLocation)
return &dt, nil
} else if len(s) == 25 && s[10] == 'T' && s[19] == '+' && s[22] == ':' {
dt, err := time.Parse(time.RFC3339, s)
if err != nil {
return nil, err
}

30
convert/time_test.go

@ -0,0 +1,30 @@ @@ -0,0 +1,30 @@
// Copyright 2021 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package convert
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestString2Time(t *testing.T) {
expectedLoc, err := time.LoadLocation("Asia/Shanghai")
assert.NoError(t, err)
var kases = map[string]time.Time{
"2021-06-06T22:58:20+08:00": time.Date(2021, 6, 6, 22, 58, 20, 0, expectedLoc),
"2021-07-11 10:44:00": time.Date(2021, 7, 11, 18, 44, 0, 0, expectedLoc),
"2021-08-10T10:33:04Z": time.Date(2021, 8, 10, 18, 33, 04, 0, expectedLoc),
}
for layout, tm := range kases {
t.Run(layout, func(t *testing.T) {
target, err := String2Time(layout, time.UTC, expectedLoc)
assert.NoError(t, err)
assert.EqualValues(t, tm, *target)
})
}
}

212
engine.go

@ -13,7 +13,6 @@ import ( @@ -13,7 +13,6 @@ import (
"os"
"reflect"
"runtime"
"strconv"
"strings"
"time"
@ -21,7 +20,6 @@ import ( @@ -21,7 +20,6 @@ import (
"xorm.io/xorm/contexts"
"xorm.io/xorm/core"
"xorm.io/xorm/dialects"
"xorm.io/xorm/internal/json"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/log"
"xorm.io/xorm/names"
@ -446,93 +444,14 @@ func (engine *Engine) DumpTables(tables []*schemas.Table, w io.Writer, tp ...sch @@ -446,93 +444,14 @@ func (engine *Engine) DumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
return engine.dumpTables(tables, w, tp...)
}
func formatColumnValue(dbLocation *time.Location, dstDialect dialects.Dialect, d interface{}, col *schemas.Column) string {
if d == nil {
return "NULL"
}
if dq, ok := d.(bool); ok && (dstDialect.URI().DBType == schemas.SQLITE ||
dstDialect.URI().DBType == schemas.MSSQL) {
if dq {
func formatBool(s string, dstDialect dialects.Dialect) string {
if dstDialect.URI().DBType == schemas.MSSQL {
switch s {
case "true":
return "1"
case "false":
return "0"
}
return "0"
}
if col.SQLType.IsText() {
var v string
switch reflect.TypeOf(d).Kind() {
case reflect.Struct, reflect.Array, reflect.Slice, reflect.Map:
bytes, err := json.DefaultJSONHandler.Marshal(d)
if err != nil {
v = fmt.Sprintf("%s", d)
} else {
v = string(bytes)
}
default:
v = fmt.Sprintf("%s", d)
}
return "'" + strings.Replace(v, "'", "''", -1) + "'"
} else if col.SQLType.IsTime() {
if t, ok := d.(time.Time); ok {
return "'" + t.In(dbLocation).Format("2006-01-02 15:04:05") + "'"
}
var v = fmt.Sprintf("%s", d)
if strings.HasSuffix(v, " +0000 UTC") {
return fmt.Sprintf("'%s'", v[0:len(v)-len(" +0000 UTC")])
} else if strings.HasSuffix(v, " +0000 +0000") {
return fmt.Sprintf("'%s'", v[0:len(v)-len(" +0000 +0000")])
}
return "'" + strings.Replace(v, "'", "''", -1) + "'"
} else if col.SQLType.IsBlob() {
if reflect.TypeOf(d).Kind() == reflect.Slice {
return fmt.Sprintf("%s", dstDialect.FormatBytes(d.([]byte)))
} else if reflect.TypeOf(d).Kind() == reflect.String {
return fmt.Sprintf("'%s'", d.(string))
}
} else if col.SQLType.IsNumeric() {
switch reflect.TypeOf(d).Kind() {
case reflect.Slice:
if col.SQLType.Name == schemas.Bool {
return fmt.Sprintf("%v", strconv.FormatBool(d.([]byte)[0] != byte('0')))
}
return fmt.Sprintf("%s", string(d.([]byte)))
case reflect.Int16, reflect.Int8, reflect.Int32, reflect.Int64, reflect.Int:
if col.SQLType.Name == schemas.Bool {
v := reflect.ValueOf(d).Int() > 0
if dstDialect.URI().DBType == schemas.SQLITE {
if v {
return "1"
}
return "0"
}
return fmt.Sprintf("%v", strconv.FormatBool(v))
}
return fmt.Sprintf("%d", d)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
if col.SQLType.Name == schemas.Bool {
v := reflect.ValueOf(d).Uint() > 0
if dstDialect.URI().DBType == schemas.SQLITE {
if v {
return "1"
}
return "0"
}
return fmt.Sprintf("%v", strconv.FormatBool(v))
}
return fmt.Sprintf("%d", d)
default:
return fmt.Sprintf("%v", d)
}
}
s := fmt.Sprintf("%v", d)
if strings.Contains(s, ":") || strings.Contains(s, "-") {
if strings.HasSuffix(s, " +0000 UTC") {
return fmt.Sprintf("'%s'", s[0:len(s)-len(" +0000 UTC")])
}
return fmt.Sprintf("'%s'", s)
}
return s
}
@ -545,7 +464,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch @@ -545,7 +464,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
} else {
dstDialect = dialects.QueryDialect(tp[0])
if dstDialect == nil {
return errors.New("Unsupported database type")
return fmt.Errorf("unsupported database type %v", tp[0])
}
uri := engine.dialect.URI()
@ -619,74 +538,69 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch @@ -619,74 +538,69 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
}
defer rows.Close()
if table.Type != nil {
sess := engine.NewSession()
defer sess.Close()
for rows.Next() {
beanValue := reflect.New(table.Type)
bean := beanValue.Interface()
fields, err := rows.Columns()
if err != nil {
return err
}
scanResults, err := sess.row2Slice(rows, fields, bean)
if err != nil {
return err
}
dataStruct := utils.ReflectValue(bean)
_, err = sess.slice2Bean(scanResults, fields, bean, &dataStruct, table)
if err != nil {
return err
}
_, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(dstTableName)+" ("+destColNames+") VALUES (")
if err != nil {
return err
}
var temp string
for _, d := range dstCols {
col := table.GetColumn(d)
if col == nil {
return errors.New("unknown column error")
}
types, err := rows.ColumnTypes()
if err != nil {
return err
}
field := dataStruct.FieldByIndex(col.FieldIndex)
temp += "," + formatColumnValue(engine.DatabaseTZ, dstDialect, field.Interface(), col)
}
_, err = io.WriteString(w, temp[1:]+");\n")
if err != nil {
return err
}
sess := engine.NewSession()
defer sess.Close()
for rows.Next() {
_, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(dstTableName)+" ("+destColNames+") VALUES (")
if err != nil {
return err
}
} else {
for rows.Next() {
dest := make([]interface{}, len(cols))
err = rows.ScanSlice(&dest)
if err != nil {
return err
}
_, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(dstTableName)+" ("+destColNames+") VALUES (")
if err != nil {
return err
}
var temp string
for i, d := range dest {
col := table.GetColumn(cols[i])
if col == nil {
return errors.New("unknow column error")
scanResults, err := sess.engine.scanStringInterface(rows, types)
if err != nil {
return err
}
for i, scanResult := range scanResults {
stp := schemas.SQLType{Name: types[i].DatabaseTypeName()}
if stp.IsNumeric() {
s := scanResult.(*sql.NullString)
if s.Valid {
if _, err = io.WriteString(w, formatBool(s.String, dstDialect)); err != nil {
return err
}
} else {
if _, err = io.WriteString(w, "NULL"); err != nil {
return err
}
}
} else if stp.IsBool() {
s := scanResult.(*sql.NullString)
if s.Valid {
if _, err = io.WriteString(w, formatBool(s.String, dstDialect)); err != nil {
return err
}
} else {
if _, err = io.WriteString(w, "NULL"); err != nil {
return err
}
}
} else {
s := scanResult.(*sql.NullString)
if s.Valid {
if _, err = io.WriteString(w, "'"+strings.ReplaceAll(s.String, "'", "''")+"'"); err != nil {
return err
}
} else {
if _, err = io.WriteString(w, "NULL"); err != nil {
return err
}
}
temp += "," + formatColumnValue(engine.DatabaseTZ, dstDialect, d, col)
}
_, err = io.WriteString(w, temp[1:]+");\n")
if err != nil {
return err
if i < len(scanResults)-1 {
if _, err = io.WriteString(w, ","); err != nil {
return err
}
}
}
_, err = io.WriteString(w, ");\n")
if err != nil {
return err
}
}
// FIXME: Hack for postgres

13
integrations/engine_test.go

@ -172,8 +172,21 @@ func TestDumpTables(t *testing.T) { @@ -172,8 +172,21 @@ func TestDumpTables(t *testing.T) {
name := fmt.Sprintf("dump_%v-table.sql", tp)
t.Run(name, func(t *testing.T) {
assert.NoError(t, testEngine.(*xorm.Engine).DumpTablesToFile([]*schemas.Table{tb}, name, tp))
})
}
assert.NoError(t, testEngine.DropTables(new(TestDumpTableStruct)))
importPath := fmt.Sprintf("dump_%v-table.sql", testEngine.Dialect().URI().DBType)
t.Run("import_"+importPath, func(t *testing.T) {
sess := testEngine.NewSession()
defer sess.Close()
assert.NoError(t, sess.Begin())
_, err = sess.ImportFile(importPath)
assert.NoError(t, err)
assert.NoError(t, sess.Commit())
})
}
func TestDumpTables2(t *testing.T) {

8
schemas/type.go

@ -39,6 +39,7 @@ const ( @@ -39,6 +39,7 @@ const (
TIME_TYPE
NUMERIC_TYPE
ARRAY_TYPE
BOOL_TYPE
)
// IsType reutrns ture if the column type is the same as the parameter
@ -64,6 +65,10 @@ func (s *SQLType) IsTime() bool { @@ -64,6 +65,10 @@ func (s *SQLType) IsTime() bool {
return s.IsType(TIME_TYPE)
}
func (s *SQLType) IsBool() bool {
return s.IsType(BOOL_TYPE)
}
// IsNumeric returns true if column is a numeric type
func (s *SQLType) IsNumeric() bool {
return s.IsType(NUMERIC_TYPE)
@ -209,7 +214,8 @@ var ( @@ -209,7 +214,8 @@ var (
Bytea: BLOB_TYPE,
UniqueIdentifier: BLOB_TYPE,
Bool: NUMERIC_TYPE,
Bool: BOOL_TYPE,
Boolean: BOOL_TYPE,
Serial: NUMERIC_TYPE,
BigSerial: NUMERIC_TYPE,

Loading…
Cancel
Save