some improvement #2136

Merged
lunny merged 2 commits from lunny/fix_bug into master 2022-04-22 02:16:35 +00:00
5 changed files with 238 additions and 67 deletions

View File

@ -70,7 +70,7 @@ func TestRows(t *testing.T) {
} }
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
var tbName = testEngine.Quote(testEngine.TableName(user, true)) tbName := testEngine.Quote(testEngine.TableName(user, true))
rows2, err := testEngine.SQL("SELECT * FROM " + tbName).Rows(new(UserRows)) rows2, err := testEngine.SQL("SELECT * FROM " + tbName).Rows(new(UserRows))
assert.NoError(t, err) assert.NoError(t, err)
defer rows2.Close() defer rows2.Close()
@ -92,7 +92,7 @@ func TestRowsMyTableName(t *testing.T) {
IsMan bool IsMan bool
} }
var tableName = "user_rows_my_table_name" tableName := "user_rows_my_table_name"
assert.NoError(t, testEngine.Table(tableName).Sync(new(UserRowsMyTable))) assert.NoError(t, testEngine.Table(tableName).Sync(new(UserRowsMyTable)))
@ -206,3 +206,75 @@ func TestRowsScanVars(t *testing.T) {
assert.NoError(t, rows.Err()) assert.NoError(t, rows.Err())
assert.EqualValues(t, 2, cnt) assert.EqualValues(t, 2, cnt)
} }
func TestRowsScanBytes(t *testing.T) {
type RowsScanBytes struct {
Id int64
Bytes1 []byte
Bytes2 []byte
}
assert.NoError(t, PrepareEngine())
assert.NoError(t, testEngine.Sync(new(RowsScanBytes)))
cnt, err := testEngine.Insert(&RowsScanBytes{
Bytes1: []byte("bytes1"),
Bytes2: []byte("bytes2"),
}, &RowsScanBytes{
Bytes1: []byte("bytes1-1"),
Bytes2: []byte("bytes2-2"),
})
assert.NoError(t, err)
assert.EqualValues(t, 2, cnt)
{
rows, err := testEngine.Cols("bytes1, bytes2").Rows(new(RowsScanBytes))
assert.NoError(t, err)
defer rows.Close()
cnt = 0
var bytes1 []byte
var bytes2 []byte
for rows.Next() {
err = rows.Scan(&bytes1, &bytes2)
assert.NoError(t, err)
if cnt == 0 {
assert.EqualValues(t, []byte("bytes1"), bytes1)
assert.EqualValues(t, []byte("bytes2"), bytes2)
} else if cnt == 1 {
// bytes1 now should be `bytes1` but will be override
assert.EqualValues(t, []byte("bytes1-1"), bytes1)
assert.EqualValues(t, []byte("bytes2-2"), bytes2)
}
cnt++
}
assert.NoError(t, rows.Err())
assert.EqualValues(t, 2, cnt)
rows.Close()
}
{
rows, err := testEngine.Cols("bytes1, bytes2").Rows(new(RowsScanBytes))
assert.NoError(t, err)
defer rows.Close()
cnt = 0
var rsb RowsScanBytes
for rows.Next() {
err = rows.Scan(&rsb)
assert.NoError(t, err)
if cnt == 0 {
assert.EqualValues(t, []byte("bytes1"), rsb.Bytes1)
assert.EqualValues(t, []byte("bytes2"), rsb.Bytes2)
} else if cnt == 1 {
// bytes1 now should be `bytes1` but will be override
assert.EqualValues(t, []byte("bytes1-1"), rsb.Bytes1)
assert.EqualValues(t, []byte("bytes2-2"), rsb.Bytes2)
}
cnt++
}
assert.NoError(t, rows.Err())
assert.EqualValues(t, 2, cnt)
rows.Close()
}
}

View File

@ -40,14 +40,14 @@ func TestJoinLimit(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
var checklist = CheckList{ checklist := CheckList{
Eid: emp.Id, Eid: emp.Id,
} }
cnt, err = testEngine.Insert(&checklist) cnt, err = testEngine.Insert(&checklist)
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
var salary = Salary{ salary := Salary{
Lid: checklist.Id, Lid: checklist.Id,
} }
cnt, err = testEngine.Insert(&salary) cnt, err = testEngine.Insert(&salary)
@ -89,7 +89,7 @@ func TestFind(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
users2 := make([]Userinfo, 0) users2 := make([]Userinfo, 0)
var tbName = testEngine.Quote(testEngine.TableName(new(Userinfo), true)) tbName := testEngine.Quote(testEngine.TableName(new(Userinfo), true))
err = testEngine.SQL("select * from " + tbName).Find(&users2) err = testEngine.SQL("select * from " + tbName).Find(&users2)
assert.NoError(t, err) assert.NoError(t, err)
} }
@ -119,7 +119,7 @@ func (TeamUser) TableName() string {
} }
func TestFind3(t *testing.T) { func TestFind3(t *testing.T) {
var teamUser = new(TeamUser) teamUser := new(TeamUser)
assert.NoError(t, PrepareEngine()) assert.NoError(t, PrepareEngine())
err := testEngine.Sync(new(Team), teamUser) err := testEngine.Sync(new(Team), teamUser)
assert.NoError(t, err) assert.NoError(t, err)
@ -426,7 +426,7 @@ func TestFindBool(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 2, cnt) assert.EqualValues(t, 2, cnt)
var results = make([]FindBoolStruct, 0, 2) results := make([]FindBoolStruct, 0, 2)
err = testEngine.Find(&results) err = testEngine.Find(&results)
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 2, len(results)) assert.EqualValues(t, 2, len(results))
@ -457,7 +457,7 @@ func TestFindMark(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 2, cnt) assert.EqualValues(t, 2, cnt)
var results = make([]Mark, 0, 2) results := make([]Mark, 0, 2)
err = testEngine.Find(&results) err = testEngine.Find(&results)
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 2, len(results)) assert.EqualValues(t, 2, len(results))
@ -486,7 +486,7 @@ func TestFindAndCountOneFunc(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 2, cnt) assert.EqualValues(t, 2, cnt)
var results = make([]FindAndCountStruct, 0, 2) results := make([]FindAndCountStruct, 0, 2)
cnt, err = testEngine.Limit(1).FindAndCount(&results) cnt, err = testEngine.Limit(1).FindAndCount(&results)
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, len(results)) assert.EqualValues(t, 1, len(results))
@ -611,14 +611,14 @@ func TestFindAndCount2(t *testing.T) {
assert.NoError(t, PrepareEngine()) assert.NoError(t, PrepareEngine())
assertSync(t, new(TestFindAndCountUser), new(TestFindAndCountHotel)) assertSync(t, new(TestFindAndCountUser), new(TestFindAndCountHotel))
var u = TestFindAndCountUser{ u := TestFindAndCountUser{
Name: "myname", Name: "myname",
} }
cnt, err := testEngine.Insert(&u) cnt, err := testEngine.Insert(&u)
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
var hotel = TestFindAndCountHotel{ hotel := TestFindAndCountHotel{
Name: "myhotel", Name: "myhotel",
Code: "111", Code: "111",
Region: "222", Region: "222",
@ -1063,7 +1063,7 @@ func TestUpdateFind(t *testing.T) {
session := testEngine.NewSession() session := testEngine.NewSession()
defer session.Close() defer session.Close()
var tuf = TestUpdateFind{ tuf := TestUpdateFind{
Name: "test", Name: "test",
} }
_, err := session.Insert(&tuf) _, err := session.Insert(&tuf)
@ -1095,7 +1095,7 @@ func TestFindAnonymousStruct(t *testing.T) {
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
assert.NoError(t, err) assert.NoError(t, err)
var findRes = make([]struct { findRes := make([]struct {
Id int64 Id int64
Name string Name string
}, 0) }, 0)
@ -1115,3 +1115,47 @@ func TestFindAnonymousStruct(t *testing.T) {
assert.EqualValues(t, 1, findRes[0].Id) assert.EqualValues(t, 1, findRes[0].Id)
assert.EqualValues(t, "xlw", findRes[0].Name) assert.EqualValues(t, "xlw", findRes[0].Name)
} }
func TestFindBytesVars(t *testing.T) {
type FindBytesVars struct {
Id int64
Bytes1 []byte
Bytes2 []byte
}
assert.NoError(t, PrepareEngine())
assertSync(t, new(FindBytesVars))
_, err := testEngine.Insert([]FindBytesVars{
{
Bytes1: []byte("bytes1"),
Bytes2: []byte("bytes2"),
},
{
Bytes1: []byte("bytes1-1"),
Bytes2: []byte("bytes2-2"),
},
})
assert.NoError(t, err)
var gbv []FindBytesVars
err = testEngine.Find(&gbv)
assert.NoError(t, err)
assert.EqualValues(t, 2, len(gbv))
assert.EqualValues(t, []byte("bytes1"), gbv[0].Bytes1)
assert.EqualValues(t, []byte("bytes2"), gbv[0].Bytes2)
assert.EqualValues(t, []byte("bytes1-1"), gbv[1].Bytes1)
assert.EqualValues(t, []byte("bytes2-2"), gbv[1].Bytes2)
err = testEngine.Find(&gbv)
assert.NoError(t, err)
assert.EqualValues(t, 4, len(gbv))
assert.EqualValues(t, []byte("bytes1"), gbv[0].Bytes1)
assert.EqualValues(t, []byte("bytes2"), gbv[0].Bytes2)
assert.EqualValues(t, []byte("bytes1-1"), gbv[1].Bytes1)
assert.EqualValues(t, []byte("bytes2-2"), gbv[1].Bytes2)
assert.EqualValues(t, []byte("bytes1"), gbv[2].Bytes1)
assert.EqualValues(t, []byte("bytes2"), gbv[2].Bytes2)
assert.EqualValues(t, []byte("bytes1-1"), gbv[3].Bytes1)
assert.EqualValues(t, []byte("bytes2-2"), gbv[3].Bytes2)
}

View File

@ -35,7 +35,7 @@ func TestGetVar(t *testing.T) {
assert.NoError(t, testEngine.Sync(new(GetVar))) assert.NoError(t, testEngine.Sync(new(GetVar)))
var data = GetVar{ data := GetVar{
Msg: "hi", Msg: "hi",
Age: 28, Age: 28,
Money: 1.5, Money: 1.5,
@ -175,7 +175,7 @@ func TestGetVar(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, false, has) assert.Equal(t, false, has)
var valuesString = make(map[string]string) valuesString := make(map[string]string)
has, err = testEngine.Table("get_var").Get(&valuesString) has, err = testEngine.Table("get_var").Get(&valuesString)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, true, has) assert.Equal(t, true, has)
@ -187,7 +187,7 @@ func TestGetVar(t *testing.T) {
// for mymysql driver, interface{} will be []byte, so ignore it currently // for mymysql driver, interface{} will be []byte, so ignore it currently
if testEngine.DriverName() != "mymysql" { if testEngine.DriverName() != "mymysql" {
var valuesInter = make(map[string]interface{}) valuesInter := make(map[string]interface{})
has, err = testEngine.Table("get_var").Where("`id` = ?", 1).Select("*").Get(&valuesInter) has, err = testEngine.Table("get_var").Where("`id` = ?", 1).Select("*").Get(&valuesInter)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, true, has) assert.Equal(t, true, has)
@ -198,7 +198,7 @@ func TestGetVar(t *testing.T) {
assert.Equal(t, "1.5", fmt.Sprintf("%v", valuesInter["money"])) assert.Equal(t, "1.5", fmt.Sprintf("%v", valuesInter["money"]))
} }
var valuesSliceString = make([]string, 5) valuesSliceString := make([]string, 5)
has, err = testEngine.Table("get_var").Get(&valuesSliceString) has, err = testEngine.Table("get_var").Get(&valuesSliceString)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, true, has) assert.Equal(t, true, has)
@ -207,7 +207,7 @@ func TestGetVar(t *testing.T) {
assert.Equal(t, "28", valuesSliceString[2]) assert.Equal(t, "28", valuesSliceString[2])
assert.Equal(t, "1.5", valuesSliceString[3]) assert.Equal(t, "1.5", valuesSliceString[3])
var valuesSliceInter = make([]interface{}, 5) valuesSliceInter := make([]interface{}, 5)
has, err = testEngine.Table("get_var").Get(&valuesSliceInter) has, err = testEngine.Table("get_var").Get(&valuesSliceInter)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, true, has) assert.Equal(t, true, has)
@ -317,7 +317,7 @@ func TestGetMap(t *testing.T) {
_, err := testEngine.Exec(fmt.Sprintf("INSERT INTO %s (`is_man`) VALUES (NULL)", tableName)) _, err := testEngine.Exec(fmt.Sprintf("INSERT INTO %s (`is_man`) VALUES (NULL)", tableName))
assert.NoError(t, err) assert.NoError(t, err)
var valuesString = make(map[string]string) valuesString := make(map[string]string)
has, err := testEngine.Table("userinfo_map").Get(&valuesString) has, err := testEngine.Table("userinfo_map").Get(&valuesString)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, true, has) assert.Equal(t, true, has)
@ -336,7 +336,7 @@ func TestGetError(t *testing.T) {
assertSync(t, new(GetError)) assertSync(t, new(GetError))
var info = new(GetError) info := new(GetError)
has, err := testEngine.Get(&info) has, err := testEngine.Get(&info)
assert.False(t, has) assert.False(t, has)
assert.Error(t, err) assert.Error(t, err)
@ -456,7 +456,7 @@ func TestGetActionMapping(t *testing.T) {
}) })
assert.NoError(t, err) assert.NoError(t, err)
var valuesSlice = make([]string, 2) valuesSlice := make([]string, 2)
has, err := testEngine.Table(new(ActionMapping)). has, err := testEngine.Table(new(ActionMapping)).
Cols("script_id", "rollback_id"). Cols("script_id", "rollback_id").
ID("1").Get(&valuesSlice) ID("1").Get(&valuesSlice)
@ -483,7 +483,7 @@ func TestGetStructId(t *testing.T) {
Id int64 Id int64
} }
//var id int64 // var id int64
var maxid maxidst var maxid maxidst
sql := "select max(`id`) as id from " + testEngine.Quote(testEngine.TableName(&TestGetStruct{}, true)) sql := "select max(`id`) as id from " + testEngine.Quote(testEngine.TableName(&TestGetStruct{}, true))
has, err := testEngine.SQL(sql).Get(&maxid) has, err := testEngine.SQL(sql).Get(&maxid)
@ -693,7 +693,7 @@ func TestCustomTypes(t *testing.T) {
assert.NoError(t, PrepareEngine()) assert.NoError(t, PrepareEngine())
assertSync(t, new(TestCustomizeStruct)) assertSync(t, new(TestCustomizeStruct))
var s = TestCustomizeStruct{ s := TestCustomizeStruct{
Name: "test", Name: "test",
Age: 32, Age: 32,
} }
@ -763,7 +763,7 @@ func TestGetBigFloat(t *testing.T) {
assertSync(t, new(GetBigFloat)) assertSync(t, new(GetBigFloat))
{ {
var gf = GetBigFloat{ gf := GetBigFloat{
Money: big.NewFloat(999999.99), Money: big.NewFloat(999999.99),
} }
_, err := testEngine.Insert(&gf) _, err := testEngine.Insert(&gf)
@ -774,8 +774,8 @@ func TestGetBigFloat(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, has) assert.True(t, has)
assert.True(t, m.String() == gf.Money.String(), "%v != %v", m.String(), gf.Money.String()) assert.True(t, m.String() == gf.Money.String(), "%v != %v", m.String(), gf.Money.String())
//fmt.Println(m.Cmp(gf.Money)) // fmt.Println(m.Cmp(gf.Money))
//assert.True(t, m.Cmp(gf.Money) == 0, "%v != %v", m.String(), gf.Money.String()) // assert.True(t, m.Cmp(gf.Money) == 0, "%v != %v", m.String(), gf.Money.String())
} }
type GetBigFloat2 struct { type GetBigFloat2 struct {
@ -788,7 +788,7 @@ func TestGetBigFloat(t *testing.T) {
assertSync(t, new(GetBigFloat2)) assertSync(t, new(GetBigFloat2))
{ {
var gf2 = GetBigFloat2{ gf2 := GetBigFloat2{
Money: big.NewFloat(9999999.99), Money: big.NewFloat(9999999.99),
Money2: *big.NewFloat(99.99), Money2: *big.NewFloat(99.99),
} }
@ -800,8 +800,8 @@ func TestGetBigFloat(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, has) assert.True(t, has)
assert.True(t, m2.String() == gf2.Money.String(), "%v != %v", m2.String(), gf2.Money.String()) assert.True(t, m2.String() == gf2.Money.String(), "%v != %v", m2.String(), gf2.Money.String())
//fmt.Println(m.Cmp(gf.Money)) // fmt.Println(m.Cmp(gf.Money))
//assert.True(t, m.Cmp(gf.Money) == 0, "%v != %v", m.String(), gf.Money.String()) // assert.True(t, m.Cmp(gf.Money) == 0, "%v != %v", m.String(), gf.Money.String())
var gf3 GetBigFloat2 var gf3 GetBigFloat2
has, err = testEngine.ID(gf2.Id).Get(&gf3) has, err = testEngine.ID(gf2.Id).Get(&gf3)
@ -829,7 +829,7 @@ func TestGetDecimal(t *testing.T) {
assertSync(t, new(GetDecimal)) assertSync(t, new(GetDecimal))
{ {
var gf = GetDecimal{ gf := GetDecimal{
Money: decimal.NewFromFloat(999999.99), Money: decimal.NewFromFloat(999999.99),
} }
_, err := testEngine.Insert(&gf) _, err := testEngine.Insert(&gf)
@ -840,8 +840,8 @@ func TestGetDecimal(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, has) assert.True(t, has)
assert.True(t, m.String() == gf.Money.String(), "%v != %v", m.String(), gf.Money.String()) assert.True(t, m.String() == gf.Money.String(), "%v != %v", m.String(), gf.Money.String())
//fmt.Println(m.Cmp(gf.Money)) // fmt.Println(m.Cmp(gf.Money))
//assert.True(t, m.Cmp(gf.Money) == 0, "%v != %v", m.String(), gf.Money.String()) // assert.True(t, m.Cmp(gf.Money) == 0, "%v != %v", m.String(), gf.Money.String())
} }
type GetDecimal2 struct { type GetDecimal2 struct {
@ -854,7 +854,7 @@ func TestGetDecimal(t *testing.T) {
{ {
v := decimal.NewFromFloat(999999.99) v := decimal.NewFromFloat(999999.99)
var gf = GetDecimal2{ gf := GetDecimal2{
Money: &v, Money: &v,
} }
_, err := testEngine.Insert(&gf) _, err := testEngine.Insert(&gf)
@ -865,10 +865,11 @@ func TestGetDecimal(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, has) assert.True(t, has)
assert.True(t, m.String() == gf.Money.String(), "%v != %v", m.String(), gf.Money.String()) assert.True(t, m.String() == gf.Money.String(), "%v != %v", m.String(), gf.Money.String())
//fmt.Println(m.Cmp(gf.Money)) // fmt.Println(m.Cmp(gf.Money))
//assert.True(t, m.Cmp(gf.Money) == 0, "%v != %v", m.String(), gf.Money.String()) // assert.True(t, m.Cmp(gf.Money) == 0, "%v != %v", m.String(), gf.Money.String())
} }
} }
func TestGetTime(t *testing.T) { func TestGetTime(t *testing.T) {
type GetTimeStruct struct { type GetTimeStruct struct {
Id int64 Id int64
@ -878,7 +879,7 @@ func TestGetTime(t *testing.T) {
assert.NoError(t, PrepareEngine()) assert.NoError(t, PrepareEngine())
assertSync(t, new(GetTimeStruct)) assertSync(t, new(GetTimeStruct))
var gts = GetTimeStruct{ gts := GetTimeStruct{
CreateTime: time.Now().In(testEngine.GetTZLocation()), CreateTime: time.Now().In(testEngine.GetTZLocation()),
} }
_, err := testEngine.Insert(&gts) _, err := testEngine.Insert(&gts)
@ -976,3 +977,39 @@ func TestGetWithPrepare(t *testing.T) {
err = sess.Commit() err = sess.Commit()
assert.NoError(t, err) assert.NoError(t, err)
} }
func TestGetBytesVars(t *testing.T) {
type GetBytesVars struct {
Id int64
Bytes1 []byte
Bytes2 []byte
}
assert.NoError(t, PrepareEngine())
assertSync(t, new(GetBytesVars))
_, err := testEngine.Insert([]GetBytesVars{
{
Bytes1: []byte("bytes1"),
Bytes2: []byte("bytes2"),
},
{
Bytes1: []byte("bytes1-1"),
Bytes2: []byte("bytes2-2"),
},
})
assert.NoError(t, err)
var gbv GetBytesVars
has, err := testEngine.Asc("id").Get(&gbv)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, []byte("bytes1"), gbv.Bytes1)
assert.EqualValues(t, []byte("bytes2"), gbv.Bytes2)
has, err = testEngine.Desc("id").NoAutoCondition().Get(&gbv)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, []byte("bytes1-1"), gbv.Bytes1)
assert.EqualValues(t, []byte("bytes2-2"), gbv.Bytes2)
}

24
scan.go
View File

@ -22,7 +22,7 @@ func genScanResultsByBeanNullable(bean interface{}) (interface{}, bool, error) {
switch t := bean.(type) { switch t := bean.(type) {
case *interface{}: case *interface{}:
return t, false, nil return t, false, nil
case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString, *sql.RawBytes: case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString, *sql.RawBytes, *[]byte:
return t, false, nil return t, false, nil
case *time.Time: case *time.Time:
return &sql.NullString{}, true, nil return &sql.NullString{}, true, nil
@ -67,7 +67,7 @@ func genScanResultsByBeanNullable(bean interface{}) (interface{}, bool, error) {
case reflect.Uint32, reflect.Uint, reflect.Uint16, reflect.Uint8: case reflect.Uint32, reflect.Uint, reflect.Uint16, reflect.Uint8:
return &convert.NullUint32{}, true, nil return &convert.NullUint32{}, true, nil
default: default:
return nil, false, fmt.Errorf("unsupported type: %#v", bean) return nil, false, fmt.Errorf("genScanResultsByBeanNullable: unsupported type: %#v", bean)
} }
} }
@ -125,12 +125,12 @@ func genScanResultsByBean(bean interface{}) (interface{}, bool, error) {
case reflect.Float64: case reflect.Float64:
return new(float64), true, nil return new(float64), true, nil
default: default:
return nil, false, fmt.Errorf("unsupported type: %#v", bean) return nil, false, fmt.Errorf("genScanResultsByBean: unsupported type: %#v", bean)
} }
} }
func (engine *Engine) scanStringInterface(rows *core.Rows, fields []string, types []*sql.ColumnType) ([]interface{}, error) { func (engine *Engine) scanStringInterface(rows *core.Rows, fields []string, types []*sql.ColumnType) ([]interface{}, error) {
var scanResults = make([]interface{}, len(types)) scanResults := make([]interface{}, len(types))
for i := 0; i < len(types); i++ { for i := 0; i < len(types); i++ {
var s sql.NullString var s sql.NullString
scanResults[i] = &s scanResults[i] = &s
@ -144,8 +144,8 @@ func (engine *Engine) scanStringInterface(rows *core.Rows, fields []string, type
// scan is a wrap of driver.Scan but will automatically change the input values according requirements // scan is a wrap of driver.Scan but will automatically change the input values according requirements
func (engine *Engine) scan(rows *core.Rows, fields []string, types []*sql.ColumnType, vv ...interface{}) error { func (engine *Engine) scan(rows *core.Rows, fields []string, types []*sql.ColumnType, vv ...interface{}) error {
var scanResults = make([]interface{}, 0, len(types)) scanResults := make([]interface{}, 0, len(types))
var replaces = make([]bool, 0, len(types)) replaces := make([]bool, 0, len(types))
var err error var err error
for _, v := range vv { for _, v := range vv {
var replaced bool var replaced bool
@ -194,7 +194,7 @@ func (engine *Engine) scan(rows *core.Rows, fields []string, types []*sql.Column
} }
func (engine *Engine) scanInterfaces(rows *core.Rows, fields []string, types []*sql.ColumnType) ([]interface{}, error) { func (engine *Engine) scanInterfaces(rows *core.Rows, fields []string, types []*sql.ColumnType) ([]interface{}, error) {
var scanResultContainers = make([]interface{}, len(types)) scanResultContainers := make([]interface{}, len(types))
for i := 0; i < len(types); i++ { for i := 0; i < len(types); i++ {
scanResult, err := engine.driver.GenScanResult(types[i].DatabaseTypeName()) scanResult, err := engine.driver.GenScanResult(types[i].DatabaseTypeName())
if err != nil { if err != nil {
@ -212,8 +212,8 @@ func (engine *Engine) scanInterfaces(rows *core.Rows, fields []string, types []*
// row -> map[string]interface{} // row -> map[string]interface{}
func (engine *Engine) row2mapInterface(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[string]interface{}, error) { func (engine *Engine) row2mapInterface(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[string]interface{}, error) {
var resultsMap = make(map[string]interface{}, len(fields)) resultsMap := make(map[string]interface{}, len(fields))
var scanResultContainers = make([]interface{}, len(fields)) scanResultContainers := make([]interface{}, len(fields))
for i := 0; i < len(fields); i++ { for i := 0; i < len(fields); i++ {
scanResult, err := engine.driver.GenScanResult(types[i].DatabaseTypeName()) scanResult, err := engine.driver.GenScanResult(types[i].DatabaseTypeName())
if err != nil { if err != nil {
@ -277,7 +277,7 @@ func (engine *Engine) ScanInterfaceMaps(rows *core.Rows) (resultsSlice []map[str
// row -> map[string]string // row -> map[string]string
func (engine *Engine) row2mapStr(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[string]string, error) { func (engine *Engine) row2mapStr(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[string]string, error) {
var scanResults = make([]interface{}, len(fields)) scanResults := make([]interface{}, len(fields))
for i := 0; i < len(fields); i++ { for i := 0; i < len(fields); i++ {
var s sql.NullString var s sql.NullString
scanResults[i] = &s scanResults[i] = &s
@ -353,7 +353,7 @@ func (engine *Engine) ScanStringMaps(rows *core.Rows) (resultsSlice []map[string
// row -> map[string][]byte // row -> map[string][]byte
func convertMapStr2Bytes(m map[string]string) map[string][]byte { func convertMapStr2Bytes(m map[string]string) map[string][]byte {
var r = make(map[string][]byte, len(m)) r := make(map[string][]byte, len(m))
for k, v := range m { for k, v := range m {
r[k] = []byte(v) r[k] = []byte(v)
} }
@ -392,7 +392,7 @@ func (engine *Engine) row2sliceStr(rows *core.Rows, types []*sql.ColumnType, fie
return nil, err return nil, err
} }
var results = make([]string, 0, len(fields)) results := make([]string, 0, len(fields))
for i := 0; i < len(fields); i++ { for i := 0; i < len(fields); i++ {
results = append(results, scanResults[i].(*sql.NullString).String) results = append(results, scanResults[i].(*sql.NullString).String)
} }

View File

@ -79,7 +79,7 @@ type Session struct {
afterClosures []func(interface{}) afterClosures []func(interface{})
afterProcessors []executedProcessor afterProcessors []executedProcessor
stmtCache map[uint32]*core.Stmt //key: hash.Hash32 of (queryStr, len(queryStr)) stmtCache map[uint32]*core.Stmt // key: hash.Hash32 of (queryStr, len(queryStr))
txStmtCache map[uint32]*core.Stmt // for tx statement txStmtCache map[uint32]*core.Stmt // for tx statement
lastSQL string lastSQL string
@ -314,7 +314,7 @@ func (session *Session) Cascade(trueOrFalse ...bool) *Session {
// MustLogSQL means record SQL or not and don't follow engine's setting // MustLogSQL means record SQL or not and don't follow engine's setting
func (session *Session) MustLogSQL(logs ...bool) *Session { func (session *Session) MustLogSQL(logs ...bool) *Session {
var showSQL = true showSQL := true
if len(logs) > 0 { if len(logs) > 0 {
showSQL = logs[0] showSQL = logs[0]
} }
@ -396,7 +396,7 @@ func (session *Session) doPrepareTx(sqlStr string) (stmt *core.Stmt, err error)
} }
func getField(dataStruct *reflect.Value, table *schemas.Table, colName string, idx int) (*schemas.Column, *reflect.Value, error) { func getField(dataStruct *reflect.Value, table *schemas.Table, colName string, idx int) (*schemas.Column, *reflect.Value, error) {
var col = table.GetColumnIdx(colName, idx) col := table.GetColumnIdx(colName, idx)
if col == nil { if col == nil {
return nil, nil, ErrFieldIsNotExist{colName, table.Name} return nil, nil, ErrFieldIsNotExist{colName, table.Name}
} }
@ -420,9 +420,10 @@ type Cell *interface{}
func (session *Session) rows2Beans(rows *core.Rows, fields []string, types []*sql.ColumnType, func (session *Session) rows2Beans(rows *core.Rows, fields []string, types []*sql.ColumnType,
table *schemas.Table, newElemFunc func([]string) reflect.Value, table *schemas.Table, newElemFunc func([]string) reflect.Value,
sliceValueSetFunc func(*reflect.Value, schemas.PK) error) error { sliceValueSetFunc func(*reflect.Value, schemas.PK) error,
) error {
for rows.Next() { for rows.Next() {
var newValue = newElemFunc(fields) newValue := newElemFunc(fields)
bean := newValue.Interface() bean := newValue.Interface()
dataStruct := newValue.Elem() dataStruct := newValue.Elem()
@ -533,8 +534,11 @@ func asKind(vv reflect.Value, tp reflect.Type) (interface{}, error) {
return nil, fmt.Errorf("unsupported primary key type: %v, %v", tp, vv) return nil, fmt.Errorf("unsupported primary key type: %v, %v", tp, vv)
} }
var uint8ZeroValue = reflect.ValueOf(uint8(0))
func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflect.Value, func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflect.Value,
scanResult interface{}, table *schemas.Table) error { scanResult interface{}, table *schemas.Table,
) error {
v, ok := scanResult.(*interface{}) v, ok := scanResult.(*interface{})
if ok { if ok {
scanResult = *v scanResult = *v
@ -596,7 +600,7 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
return nil return nil
case reflect.Complex64, reflect.Complex128: case reflect.Complex64, reflect.Complex128:
return setJSON(fieldValue, fieldType, scanResult) return setJSON(fieldValue, fieldType, scanResult)
case reflect.Slice, reflect.Array: case reflect.Slice:
bs, ok := convert.AsBytes(scanResult) bs, ok := convert.AsBytes(scanResult)
if ok && fieldType.Elem().Kind() == reflect.Uint8 { if ok && fieldType.Elem().Kind() == reflect.Uint8 {
if col.SQLType.IsText() { if col.SQLType.IsText() {
@ -607,15 +611,29 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
} }
fieldValue.Set(x.Elem()) fieldValue.Set(x.Elem())
} else { } else {
if fieldValue.Len() > 0 { fieldValue.Set(reflect.ValueOf(bs))
for i := 0; i < fieldValue.Len(); i++ { }
if i < vv.Len() { return nil
fieldValue.Index(i).Set(vv.Index(i)) }
} case reflect.Array:
} bs, ok := convert.AsBytes(scanResult)
} else { if ok && fieldType.Elem().Kind() == reflect.Uint8 {
for i := 0; i < vv.Len(); i++ { if col.SQLType.IsText() {
fieldValue.Set(reflect.Append(*fieldValue, vv.Index(i))) x := reflect.New(fieldType)
err := json.DefaultJSONHandler.Unmarshal(bs, x.Interface())
if err != nil {
return err
}
fieldValue.Set(x.Elem())
} else {
if fieldValue.Len() < vv.Len() {
return fmt.Errorf("Set field %s[Array] failed because of data too long", col.Name)
}
for i := 0; i < fieldValue.Len(); i++ {
if i < vv.Len() {
fieldValue.Index(i).Set(vv.Index(i))
} else {
fieldValue.Index(i).Set(uint8ZeroValue)
} }
} }
} }
@ -659,7 +677,7 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
if len(table.PrimaryKeys) != 1 { if len(table.PrimaryKeys) != 1 {
return errors.New("unsupported non or composited primary key cascade") return errors.New("unsupported non or composited primary key cascade")
} }
var pk = make(schemas.PK, len(table.PrimaryKeys)) pk := make(schemas.PK, len(table.PrimaryKeys))
pk[0], err = asKind(vv, reflect.TypeOf(scanResult)) pk[0], err = asKind(vv, reflect.TypeOf(scanResult))
if err != nil { if err != nil {
return err return err
@ -694,11 +712,11 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
buildAfterProcessors(session, bean) buildAfterProcessors(session, bean)
var tempMap = make(map[string]int) tempMap := make(map[string]int)
var pk schemas.PK var pk schemas.PK
for i, colName := range fields { for i, colName := range fields {
var idx int var idx int
var lKey = strings.ToLower(colName) lKey := strings.ToLower(colName)
var ok bool var ok bool
if idx, ok = tempMap[lKey]; !ok { if idx, ok = tempMap[lKey]; !ok {