diff --git a/dialects/mssql.go b/dialects/mssql.go index 8ef924b8..f766950c 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -231,6 +231,7 @@ func (db *mssql) SQLType(c *schemas.Column) string { } else if strings.EqualFold(c.Default, "false") { c.Default = "0" } + return res case schemas.Serial: c.IsAutoIncrement = true c.IsPrimaryKey = true @@ -270,8 +271,8 @@ func (db *mssql) SQLType(c *schemas.Column) string { res = t } - if res == schemas.Int { - return schemas.Int + if res == schemas.Int || res == schemas.Bit || res == schemas.DateTime { + return res } hasLen1 := (c.Length > 0) diff --git a/engine.go b/engine.go index d5e599d6..8b137b27 100644 --- a/engine.go +++ b/engine.go @@ -412,6 +412,82 @@ func (engine *Engine) DumpTables(tables []*schemas.Table, w io.Writer, tp ...sch return engine.dumpTables(tables, w, tp...) } +func formatColumnValue(dstDialect dialects.Dialect, d interface{}, col *schemas.Column) string { + if d == nil { + return "NULL" + } + + if dq, ok := d.(bool); ok && (dstDialect.URI().DBType == schemas.SQLITE || + dstDialect.URI().DBType == schemas.MSSQL) { + if dq { + return "1" + } + return "0" + } + + if col.SQLType.IsText() { + var v = fmt.Sprintf("%s", d) + return "'" + strings.Replace(v, "'", "''", -1) + "'" + } else if col.SQLType.IsTime() { + var v = fmt.Sprintf("%s", d) + if strings.HasSuffix(v, " +0000 UTC") { + return fmt.Sprintf("'%s'", v[0:len(v)-len(" +0000 UTC")]) + } else if strings.HasSuffix(v, " +0000 +0000") { + return fmt.Sprintf("'%s'", v[0:len(v)-len(" +0000 +0000")]) + } + return "'" + strings.Replace(v, "'", "''", -1) + "'" + } else if col.SQLType.IsBlob() { + if reflect.TypeOf(d).Kind() == reflect.Slice { + return fmt.Sprintf("%s", dstDialect.FormatBytes(d.([]byte))) + } else if reflect.TypeOf(d).Kind() == reflect.String { + return fmt.Sprintf("'%s'", d.(string)) + } + } else if col.SQLType.IsNumeric() { + switch reflect.TypeOf(d).Kind() { + case reflect.Slice: + if col.SQLType.Name == schemas.Bool { + return fmt.Sprintf("%v", strconv.FormatBool(d.([]byte)[0] != byte('0'))) + } + return fmt.Sprintf("%s", string(d.([]byte))) + case reflect.Int16, reflect.Int8, reflect.Int32, reflect.Int64, reflect.Int: + if col.SQLType.Name == schemas.Bool { + v := reflect.ValueOf(d).Int() > 0 + if dstDialect.URI().DBType == schemas.SQLITE { + if v { + return "1" + } + return "0" + } + return fmt.Sprintf("%v", strconv.FormatBool(v)) + } + return fmt.Sprintf("%v", d) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if col.SQLType.Name == schemas.Bool { + v := reflect.ValueOf(d).Uint() > 0 + if dstDialect.URI().DBType == schemas.SQLITE { + if v { + return "1" + } + return "0" + } + return fmt.Sprintf("%v", strconv.FormatBool(v)) + } + return fmt.Sprintf("%v", d) + default: + return fmt.Sprintf("%v", d) + } + } + + s := fmt.Sprintf("%v", d) + if strings.Contains(s, ":") || strings.Contains(s, "-") { + if strings.HasSuffix(s, " +0000 UTC") { + return fmt.Sprintf("'%s'", s[0:len(s)-len(" +0000 UTC")]) + } + return fmt.Sprintf("'%s'", s) + } + return s +} + // dumpTables dump database all table structs and data to w with specify db type func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...schemas.DBType) error { var dstDialect dialects.Dialect @@ -424,7 +500,10 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch } uri := engine.dialect.URI() - destURI := *uri + destURI := dialects.URI{ + DBType: tp[0], + DBName: uri.DBName, + } dstDialect.Init(&destURI) } @@ -495,59 +574,9 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch if col == nil { return errors.New("unknow column error") } - - if d == nil { - temp += ", NULL" - } else if col.SQLType.IsText() || col.SQLType.IsTime() { - var v = fmt.Sprintf("%s", d) - if strings.HasSuffix(v, " +0000 UTC") { - temp += fmt.Sprintf(", '%s'", v[0:len(v)-len(" +0000 UTC")]) - } else { - temp += ", '" + strings.Replace(v, "'", "''", -1) + "'" - } - } else if col.SQLType.IsBlob() { - if reflect.TypeOf(d).Kind() == reflect.Slice { - temp += fmt.Sprintf(", %s", dstDialect.FormatBytes(d.([]byte))) - } else if reflect.TypeOf(d).Kind() == reflect.String { - temp += fmt.Sprintf(", '%s'", d.(string)) - } - } else if col.SQLType.IsNumeric() { - switch reflect.TypeOf(d).Kind() { - case reflect.Slice: - if col.SQLType.Name == schemas.Bool { - temp += fmt.Sprintf(", %v", strconv.FormatBool(d.([]byte)[0] != byte('0'))) - } else { - temp += fmt.Sprintf(", %s", string(d.([]byte))) - } - case reflect.Int16, reflect.Int8, reflect.Int32, reflect.Int64, reflect.Int: - if col.SQLType.Name == schemas.Bool { - temp += fmt.Sprintf(", %v", strconv.FormatBool(reflect.ValueOf(d).Int() > 0)) - } else { - temp += fmt.Sprintf(", %v", d) - } - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - if col.SQLType.Name == schemas.Bool { - temp += fmt.Sprintf(", %v", strconv.FormatBool(reflect.ValueOf(d).Uint() > 0)) - } else { - temp += fmt.Sprintf(", %v", d) - } - default: - temp += fmt.Sprintf(", %v", d) - } - } else { - s := fmt.Sprintf("%v", d) - if strings.Contains(s, ":") || strings.Contains(s, "-") { - if strings.HasSuffix(s, " +0000 UTC") { - temp += fmt.Sprintf(", '%s'", s[0:len(s)-len(" +0000 UTC")]) - } else { - temp += fmt.Sprintf(", '%s'", s) - } - } else { - temp += fmt.Sprintf(", %s", s) - } - } + temp += "," + formatColumnValue(dstDialect, d, col) } - _, err = io.WriteString(w, temp[2:]+");\n") + _, err = io.WriteString(w, temp[1:]+");\n") if err != nil { return err } diff --git a/integrations/engine_test.go b/integrations/engine_test.go index 19c5285d..0e5d3424 100644 --- a/integrations/engine_test.go +++ b/integrations/engine_test.go @@ -93,19 +93,23 @@ func TestDump(t *testing.T) { assert.NoError(t, PrepareEngine()) type TestDumpStruct struct { - Id int64 - Name string + Id int64 + Name string + IsMan bool + Created time.Time `xorm:"created"` } assertSync(t, new(TestDumpStruct)) - testEngine.Insert([]TestDumpStruct{ - {Name: "1"}, + cnt, err := testEngine.Insert([]TestDumpStruct{ + {Name: "1", IsMan: true}, {Name: "2\n"}, {Name: "3;"}, {Name: "4\n;\n''"}, {Name: "5'\n"}, }) + assert.NoError(t, err) + assert.EqualValues(t, 5, cnt) fp := fmt.Sprintf("%v.sql", testEngine.Dialect().URI().DBType) os.Remove(fp) @@ -116,7 +120,7 @@ func TestDump(t *testing.T) { sess := testEngine.NewSession() defer sess.Close() assert.NoError(t, sess.Begin()) - _, err := sess.ImportFile(fp) + _, err = sess.ImportFile(fp) assert.NoError(t, err) assert.NoError(t, sess.Commit()) @@ -128,6 +132,49 @@ func TestDump(t *testing.T) { } } +func TestDumpTables(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type TestDumpTableStruct struct { + Id int64 + Name string + IsMan bool + Created time.Time `xorm:"created"` + } + + assertSync(t, new(TestDumpTableStruct)) + + testEngine.Insert([]TestDumpTableStruct{ + {Name: "1", IsMan: true}, + {Name: "2\n"}, + {Name: "3;"}, + {Name: "4\n;\n''"}, + {Name: "5'\n"}, + }) + + fp := fmt.Sprintf("%v-table.sql", testEngine.Dialect().URI().DBType) + os.Remove(fp) + tb, err := testEngine.TableInfo(new(TestDumpTableStruct)) + assert.NoError(t, err) + assert.NoError(t, testEngine.(*xorm.Engine).DumpTablesToFile([]*schemas.Table{tb}, fp)) + + assert.NoError(t, PrepareEngine()) + + sess := testEngine.NewSession() + defer sess.Close() + assert.NoError(t, sess.Begin()) + _, err = sess.ImportFile(fp) + assert.NoError(t, err) + assert.NoError(t, sess.Commit()) + + for _, tp := range []schemas.DBType{schemas.SQLITE, schemas.MYSQL, schemas.POSTGRES, schemas.MSSQL} { + name := fmt.Sprintf("dump_%v-table.sql", tp) + t.Run(name, func(t *testing.T) { + assert.NoError(t, testEngine.(*xorm.Engine).DumpTablesToFile([]*schemas.Table{tb}, name, tp)) + }) + } +} + func TestSetSchema(t *testing.T) { assert.NoError(t, PrepareEngine())