Fix exist #1921

Merged
lunny merged 2 commits from lunny/fix_exist2 into master 2021-06-07 05:45:35 +00:00
2 changed files with 29 additions and 23 deletions

View File

@ -75,6 +75,10 @@ func TestExistStruct(t *testing.T) {
has, err = testEngine.Table("record_exist").Where("name = ?", "test2").Exist() has, err = testEngine.Table("record_exist").Where("name = ?", "test2").Exist()
assert.NoError(t, err) assert.NoError(t, err)
assert.False(t, has) assert.False(t, has)
has, err = testEngine.Table(new(RecordExist)).ID(1).Cols("id").Exist()
assert.NoError(t, err)
assert.True(t, has)
} }
func TestExistStructForJoin(t *testing.T) { func TestExistStructForJoin(t *testing.T) {

View File

@ -106,10 +106,13 @@ func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (stri
// GenGetSQL generates Get SQL // GenGetSQL generates Get SQL
func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{}, error) { func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{}, error) {
v := rValue(bean) var isStruct bool
isStruct := v.Kind() == reflect.Struct if bean != nil {
if isStruct { v := rValue(bean)
statement.SetRefBean(bean) isStruct = v.Kind() == reflect.Struct
if isStruct {
statement.SetRefBean(bean)
}
} }
var columnStr = statement.ColumnStr() var columnStr = statement.ColumnStr()
@ -340,12 +343,25 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac
var args []interface{} var args []interface{}
var joinStr string var joinStr string
var err error var err error
if len(bean) == 0 { var b interface{} = nil
tableName := statement.TableName() if len(bean) > 0 {
if len(tableName) <= 0 { b = bean[0]
return "", nil, ErrTableNotFound beanValue := reflect.ValueOf(bean[0])
if beanValue.Kind() != reflect.Ptr {
return "", nil, errors.New("needs a pointer")
} }
if beanValue.Elem().Kind() == reflect.Struct {
if err := statement.SetRefBean(bean[0]); err != nil {
return "", nil, err
}
}
}
tableName := statement.TableName()
if len(tableName) <= 0 {
return "", nil, ErrTableNotFound
}
if statement.RefTable == nil {
tableName = statement.quote(tableName) tableName = statement.quote(tableName)
if len(statement.JoinStr) > 0 { if len(statement.JoinStr) > 0 {
joinStr = statement.JoinStr joinStr = statement.JoinStr
@ -376,22 +392,8 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac
args = []interface{}{} args = []interface{}{}
} }
} else { } else {
beanValue := reflect.ValueOf(bean[0])
if beanValue.Kind() != reflect.Ptr {
return "", nil, errors.New("needs a pointer")
}
if beanValue.Elem().Kind() == reflect.Struct {
if err := statement.SetRefBean(bean[0]); err != nil {
return "", nil, err
}
}
if len(statement.TableName()) <= 0 {
return "", nil, ErrTableNotFound
}
statement.Limit(1) statement.Limit(1)
sqlStr, args, err = statement.GenGetSQL(bean[0]) sqlStr, args, err = statement.GenGetSQL(b)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }