diff --git a/dialect_mssql.go b/dialect_mssql.go index ce4dd00c..524d05a4 100644 --- a/dialect_mssql.go +++ b/dialect_mssql.go @@ -338,6 +338,7 @@ func (db *mssql) TableCheckSql(tableName string) (string, []interface{}) { func (db *mssql) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { args := []interface{}{} s := `select a.name as name, b.name as ctype,a.max_length,a.precision,a.scale,a.is_nullable as nullable, + "default_is_null" = (CASE WHEN c.text is null THEN 1 ELSE 0 END), replace(replace(isnull(c.text,''),'(',''),')','') as vdefault, ISNULL(i.is_primary_key, 0) from sys.columns a @@ -361,8 +362,8 @@ func (db *mssql) GetColumns(tableName string) ([]string, map[string]*core.Column for rows.Next() { var name, ctype, vdefault string var maxLen, precision, scale int - var nullable, isPK bool - err = rows.Scan(&name, &ctype, &maxLen, &precision, &scale, &nullable, &vdefault, &isPK) + var nullable, isPK, defaultIsNull bool + err = rows.Scan(&name, &ctype, &maxLen, &precision, &scale, &nullable, &defaultIsNull, &vdefault, &isPK) if err != nil { return nil, nil, err } @@ -371,7 +372,10 @@ func (db *mssql) GetColumns(tableName string) ([]string, map[string]*core.Column col.Indexes = make(map[string]int) col.Name = strings.Trim(name, "` ") col.Nullable = nullable - col.Default = vdefault + col.DefaultIsEmpty = defaultIsNull + if !defaultIsNull { + col.Default = vdefault + } col.IsPrimaryKey = isPK ct := strings.ToUpper(ctype) if ct == "DECIMAL" { @@ -395,15 +399,6 @@ func (db *mssql) GetColumns(tableName string) ([]string, map[string]*core.Column } } - if col.SQLType.IsText() || col.SQLType.IsTime() { - if col.Default != "" { - col.Default = "'" + col.Default + "'" - } else { - if col.DefaultIsEmpty { - col.Default = "''" - } - } - } cols[col.Name] = col colSeq = append(colSeq, col.Name) } diff --git a/dialect_mysql.go b/dialect_mysql.go index a108b81f..cf1dbb6f 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -345,9 +345,9 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column if colDefault != nil { col.Default = *colDefault - if col.Default == "" { - col.DefaultIsEmpty = true - } + col.DefaultIsEmpty = false + } else { + col.DefaultIsEmpty = true } cts := strings.Split(colType, "(") @@ -411,13 +411,11 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column col.IsAutoIncrement = true } - if col.SQLType.IsText() || col.SQLType.IsTime() { - if col.Default != "" { + if !col.DefaultIsEmpty { + if col.SQLType.IsText() { + col.Default = "'" + col.Default + "'" + } else if col.SQLType.IsTime() && col.Default != "CURRENT_TIMESTAMP" { col.Default = "'" + col.Default + "'" - } else { - if col.DefaultIsEmpty { - col.Default = "''" - } } } cols[col.Name] = col diff --git a/dialect_postgres.go b/dialect_postgres.go index 3df682e8..ccef3086 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -1005,16 +1005,18 @@ WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.att col.Name = strings.Trim(colName, `" `) - if colDefault != nil || isPK { - if isPK { - col.IsPrimaryKey = true - } else { - col.Default = *colDefault + if colDefault != nil { + col.Default = *colDefault + col.DefaultIsEmpty = false + if strings.HasPrefix(col.Default, "nextval(") { + col.IsAutoIncrement = true } + } else { + col.DefaultIsEmpty = true } - if colDefault != nil && strings.HasPrefix(*colDefault, "nextval(") { - col.IsAutoIncrement = true + if isPK { + col.IsPrimaryKey = true } col.Nullable = (isNullable == "YES") @@ -1043,12 +1045,16 @@ WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.att col.Length = maxLen - if col.SQLType.IsText() || col.SQLType.IsTime() { - if col.Default != "" { - col.Default = "'" + col.Default + "'" - } else { - if col.DefaultIsEmpty { - col.Default = "''" + if !col.DefaultIsEmpty { + if col.SQLType.IsText() { + if strings.HasSuffix(col.Default, "::character varying") { + col.Default = strings.TrimRight(col.Default, "::character varying") + } else if !strings.HasPrefix(col.Default, "'") { + col.Default = "'" + col.Default + "'" + } + } else if col.SQLType.IsTime() { + if strings.HasSuffix(col.Default, "::timestamp without time zone") { + col.Default = strings.TrimRight(col.Default, "::timestamp without time zone") } } } diff --git a/dialect_sqlite3.go b/dialect_sqlite3.go index 60f07295..d1852e9b 100644 --- a/dialect_sqlite3.go +++ b/dialect_sqlite3.go @@ -270,6 +270,34 @@ func (db *sqlite3) IsColumnExist(tableName, colName string) (bool, error) { return false, nil } +// splitColStr splits a sqlite col strings as fields +func splitColStr(colStr string) []string { + colStr = strings.TrimSpace(colStr) + var results = make([]string, 0, 10) + var lastIdx int + var hasC, hasQuote bool + for i, c := range colStr { + if c == ' ' && !hasQuote { + if hasC { + results = append(results, colStr[lastIdx:i]) + hasC = false + } + } else { + if c == '\'' { + hasQuote = !hasQuote + } + if !hasC { + lastIdx = i + } + hasC = true + if i == len(colStr)-1 { + results = append(results, colStr[lastIdx:i+1]) + } + } + } + return results +} + func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { args := []interface{}{tableName} s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?" @@ -315,7 +343,7 @@ func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*core.Colu continue } - fields := strings.Fields(strings.TrimSpace(colStr)) + fields := splitColStr(colStr) col := new(core.Column) col.Indexes = make(map[string]int) col.Nullable = true @@ -344,9 +372,6 @@ func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*core.Colu col.DefaultIsEmpty = false } } - if !col.SQLType.IsNumeric() && !col.DefaultIsEmpty { - col.Default = "'" + col.Default + "'" - } cols[col.Name] = col colSeq = append(colSeq, col.Name) } diff --git a/dialect_sqlite3_test.go b/dialect_sqlite3_test.go new file mode 100644 index 00000000..a2036159 --- /dev/null +++ b/dialect_sqlite3_test.go @@ -0,0 +1,35 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSplitColStr(t *testing.T) { + var kases = []struct { + colStr string + fields []string + }{ + { + colStr: "`id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL", + fields: []string{ + "`id`", "INTEGER", "PRIMARY", "KEY", "AUTOINCREMENT", "NOT", "NULL", + }, + }, + { + colStr: "`created` DATETIME DEFAULT '2006-01-02 15:04:05' NULL", + fields: []string{ + "`created`", "DATETIME", "DEFAULT", "'2006-01-02 15:04:05'", "NULL", + }, + }, + } + + for _, kase := range kases { + assert.EqualValues(t, kase.fields, splitColStr(kase.colStr)) + } +} diff --git a/engine.go b/engine.go index 649fd1e3..96100fce 100644 --- a/engine.go +++ b/engine.go @@ -907,8 +907,15 @@ func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) { fieldType := fieldValue.Type() if ormTagStr != "" { - col = &core.Column{FieldName: t.Field(i).Name, Nullable: true, IsPrimaryKey: false, - IsAutoIncrement: false, MapType: core.TWOSIDES, Indexes: make(map[string]int)} + col = &core.Column{ + FieldName: t.Field(i).Name, + Nullable: true, + IsPrimaryKey: false, + IsAutoIncrement: false, + MapType: core.TWOSIDES, + Indexes: make(map[string]int), + DefaultIsEmpty: true, + } tags := splitTag(ormTagStr) if len(tags) > 0 { diff --git a/helpers_test.go b/helpers_test.go index 7e317126..caf7b9f0 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -10,25 +10,6 @@ import ( "github.com/stretchr/testify/assert" ) -func TestSplitTag(t *testing.T) { - var cases = []struct { - tag string - tags []string - }{ - {"not null default '2000-01-01 00:00:00' TIMESTAMP", []string{"not", "null", "default", "'2000-01-01 00:00:00'", "TIMESTAMP"}}, - {"TEXT", []string{"TEXT"}}, - {"default('2000-01-01 00:00:00')", []string{"default('2000-01-01 00:00:00')"}}, - {"json binary", []string{"json", "binary"}}, - } - - for _, kase := range cases { - tags := splitTag(kase.tag) - if !sliceEq(tags, kase.tags) { - t.Fatalf("[%d]%v is not equal [%d]%v", len(tags), tags, len(kase.tags), kase.tags) - } - } -} - func TestEraseAny(t *testing.T) { raw := "SELECT * FROM `table`.[table_name]" assert.EqualValues(t, raw, eraseAny(raw)) diff --git a/tag.go b/tag.go index 6feb581a..ec8d5cf0 100644 --- a/tag.go +++ b/tag.go @@ -125,6 +125,7 @@ func DefaultTagHandler(ctx *tagContext) error { ctx.col.Default = ctx.nextTag ctx.ignoreNext = true } + ctx.col.DefaultIsEmpty = false return nil } diff --git a/tag_test.go b/tag_test.go index cfb16b3b..891c6ffc 100644 --- a/tag_test.go +++ b/tag_test.go @@ -5,7 +5,6 @@ package xorm import ( - "errors" "fmt" "strings" "testing" @@ -27,58 +26,27 @@ func TestCreatedAndUpdated(t *testing.T) { u := new(UserCU) err := testEngine.DropTables(u) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) err = testEngine.CreateTables(u) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) u.Name = "sss" cnt, err := testEngine.Insert(u) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("insert not returned 1") - t.Error(err) - panic(err) - return - } + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) u.Name = "xxx" cnt, err = testEngine.ID(u.Id).Update(u) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("update not returned 1") - t.Error(err) - panic(err) - return - } + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) u.Id = 0 u.Created = time.Now().Add(-time.Hour * 24 * 365) u.Updated = u.Created - fmt.Println(u) cnt, err = testEngine.NoAutoTime().Insert(u) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("insert not returned 1") - t.Error(err) - panic(err) - return - } + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) } type StrangeName struct { @@ -90,25 +58,17 @@ func TestStrangeName(t *testing.T) { assert.NoError(t, prepareEngine()) err := testEngine.DropTables(new(StrangeName)) - if err != nil { - t.Error(err) - } + assert.NoError(t, err) err = testEngine.CreateTables(new(StrangeName)) - if err != nil { - t.Error(err) - } + assert.NoError(t, err) _, err = testEngine.Insert(&StrangeName{Name: "sfsfdsfds"}) - if err != nil { - t.Error(err) - } + assert.NoError(t, err) beans := make([]StrangeName, 0) err = testEngine.Find(&beans) - if err != nil { - t.Error(err) - } + assert.NoError(t, err) } func TestCreatedUpdated(t *testing.T) { @@ -179,29 +139,17 @@ func TestLowerCase(t *testing.T) { assert.NoError(t, prepareEngine()) err := testEngine.Sync2(&Lowercase{}) - _, err = testEngine.Where("(id) > 0").Delete(&Lowercase{}) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) + _, err = testEngine.Where("id > 0").Delete(&Lowercase{}) + assert.NoError(t, err) + _, err = testEngine.Insert(&Lowercase{ended: 1}) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) ls := make([]Lowercase, 0) err = testEngine.Find(&ls) - if err != nil { - t.Error(err) - panic(err) - } - - if len(ls) != 1 { - err = errors.New("should be 1") - t.Error(err) - panic(err) - } + assert.NoError(t, err) + assert.EqualValues(t, 1, len(ls)) } func TestAutoIncrTag(t *testing.T) { @@ -297,6 +245,24 @@ func TestTagDefault(t *testing.T) { assertSync(t, new(DefaultStruct)) + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + + var defaultVal string + var isDefaultExist bool + tableName := testEngine.GetColumnMapper().Obj2Table("DefaultStruct") + for _, table := range tables { + if table.Name == tableName { + col := table.GetColumn("age") + assert.NotNil(t, col) + defaultVal = col.Default + isDefaultExist = !col.DefaultIsEmpty + break + } + } + assert.True(t, isDefaultExist) + assert.EqualValues(t, "10", defaultVal) + cnt, err := testEngine.Omit("age").Insert(&DefaultStruct{ Name: "test", Age: 20, @@ -312,6 +278,163 @@ func TestTagDefault(t *testing.T) { assert.EqualValues(t, "test", s.Name) } +func TestTagDefault2(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type DefaultStruct2 struct { + Id int64 + Name string + } + + assertSync(t, new(DefaultStruct2)) + + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + + var defaultVal string + var isDefaultExist bool + tableName := testEngine.GetColumnMapper().Obj2Table("DefaultStruct2") + for _, table := range tables { + if table.Name == tableName { + col := table.GetColumn("name") + assert.NotNil(t, col) + defaultVal = col.Default + isDefaultExist = !col.DefaultIsEmpty + break + } + } + assert.False(t, isDefaultExist, fmt.Sprintf("default value is --%v--", defaultVal)) + assert.EqualValues(t, "", defaultVal) +} + +func TestTagDefault3(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type DefaultStruct3 struct { + Id int64 + Name string `xorm:"default('myname')"` + } + + assertSync(t, new(DefaultStruct3)) + + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + + var defaultVal string + var isDefaultExist bool + tableName := testEngine.GetColumnMapper().Obj2Table("DefaultStruct3") + for _, table := range tables { + if table.Name == tableName { + col := table.GetColumn("name") + assert.NotNil(t, col) + defaultVal = col.Default + isDefaultExist = !col.DefaultIsEmpty + break + } + } + assert.True(t, isDefaultExist) + assert.EqualValues(t, "'myname'", defaultVal) +} + +func TestTagDefault4(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type DefaultStruct4 struct { + Id int64 + Created time.Time `xorm:"default(CURRENT_TIMESTAMP)"` + } + + assertSync(t, new(DefaultStruct4)) + + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + + var defaultVal string + var isDefaultExist bool + tableName := testEngine.GetColumnMapper().Obj2Table("DefaultStruct4") + for _, table := range tables { + if table.Name == tableName { + col := table.GetColumn("created") + assert.NotNil(t, col) + defaultVal = col.Default + isDefaultExist = !col.DefaultIsEmpty + break + } + } + assert.True(t, isDefaultExist) + assert.True(t, "CURRENT_TIMESTAMP" == defaultVal || + "now()" == defaultVal || + "getdate" == defaultVal, defaultVal) +} + +func TestTagDefault5(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type DefaultStruct5 struct { + Id int64 + Created time.Time `xorm:"default('2006-01-02 15:04:05')"` + } + + assertSync(t, new(DefaultStruct5)) + table := testEngine.TableInfo(new(DefaultStruct5)) + createdCol := table.GetColumn("created") + assert.NotNil(t, createdCol) + assert.EqualValues(t, "'2006-01-02 15:04:05'", createdCol.Default) + assert.False(t, createdCol.DefaultIsEmpty) + + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + + var defaultVal string + var isDefaultExist bool + tableName := testEngine.GetColumnMapper().Obj2Table("DefaultStruct5") + for _, table := range tables { + if table.Name == tableName { + col := table.GetColumn("created") + assert.NotNil(t, col) + defaultVal = col.Default + isDefaultExist = !col.DefaultIsEmpty + break + } + } + assert.True(t, isDefaultExist) + assert.EqualValues(t, "'2006-01-02 15:04:05'", defaultVal) +} + +func TestTagDefault6(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type DefaultStruct6 struct { + Id int64 + IsMan bool `xorm:"default(true)"` + } + + assertSync(t, new(DefaultStruct6)) + + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + + var defaultVal string + var isDefaultExist bool + tableName := testEngine.GetColumnMapper().Obj2Table("DefaultStruct6") + for _, table := range tables { + if table.Name == tableName { + col := table.GetColumn("is_man") + assert.NotNil(t, col) + defaultVal = col.Default + isDefaultExist = !col.DefaultIsEmpty + break + } + } + assert.True(t, isDefaultExist) + if defaultVal == "1" { + defaultVal = "true" + } else if defaultVal == "0" { + defaultVal = "false" + } + assert.EqualValues(t, "true", defaultVal) +} + func TestTagsDirection(t *testing.T) { assert.NoError(t, prepareEngine()) @@ -407,3 +530,22 @@ func TestTagTime(t *testing.T) { assert.EqualValues(t, s.Created.UTC().Format("2006-01-02 15:04:05"), strings.Replace(strings.Replace(tm, "T", " ", -1), "Z", "", -1)) } + +func TestSplitTag(t *testing.T) { + var cases = []struct { + tag string + tags []string + }{ + {"not null default '2000-01-01 00:00:00' TIMESTAMP", []string{"not", "null", "default", "'2000-01-01 00:00:00'", "TIMESTAMP"}}, + {"TEXT", []string{"TEXT"}}, + {"default('2000-01-01 00:00:00')", []string{"default('2000-01-01 00:00:00')"}}, + {"json binary", []string{"json", "binary"}}, + } + + for _, kase := range cases { + tags := splitTag(kase.tag) + if !sliceEq(tags, kase.tags) { + t.Fatalf("[%d]%v is not equal [%d]%v", len(tags), tags, len(kase.tags), kase.tags) + } + } +} diff --git a/test_mssql.sh b/test_mssql.sh index e26e1641..7f060cff 100755 --- a/test_mssql.sh +++ b/test_mssql.sh @@ -1 +1 @@ -go test -db=mssql -conn_str="server=localhost;user id=sa;password=MwantsaSecurePassword1;database=xorm_test" \ No newline at end of file +go test -db=mssql -conn_str="server=localhost;user id=sa;password=yourStrong(!)Password;database=xorm_test" \ No newline at end of file