From efb8f9292e8a178cc0bae307a5225256b0f126b4 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Sun, 6 Jun 2021 21:13:21 +0800 Subject: [PATCH] Fix exist --- integrations/session_exist_test.go | 4 +++ internal/statements/query.go | 48 ++++++++++++++++-------------- 2 files changed, 29 insertions(+), 23 deletions(-) diff --git a/integrations/session_exist_test.go b/integrations/session_exist_test.go index 6247c91a..29546376 100644 --- a/integrations/session_exist_test.go +++ b/integrations/session_exist_test.go @@ -75,6 +75,10 @@ func TestExistStruct(t *testing.T) { has, err = testEngine.Table("record_exist").Where("name = ?", "test2").Exist() assert.NoError(t, err) assert.False(t, has) + + has, err = testEngine.Table(new(RecordExist)).ID(1).Cols("id").Exist() + assert.NoError(t, err) + assert.True(t, has) } func TestExistStructForJoin(t *testing.T) { diff --git a/internal/statements/query.go b/internal/statements/query.go index 8b4cd919..e1091e9f 100644 --- a/internal/statements/query.go +++ b/internal/statements/query.go @@ -106,10 +106,13 @@ func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (stri // GenGetSQL generates Get SQL func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{}, error) { - v := rValue(bean) - isStruct := v.Kind() == reflect.Struct - if isStruct { - statement.SetRefBean(bean) + var isStruct bool + if bean != nil { + v := rValue(bean) + isStruct = v.Kind() == reflect.Struct + if isStruct { + statement.SetRefBean(bean) + } } var columnStr = statement.ColumnStr() @@ -340,12 +343,25 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac var args []interface{} var joinStr string var err error - if len(bean) == 0 { - tableName := statement.TableName() - if len(tableName) <= 0 { - return "", nil, ErrTableNotFound + var b interface{} = nil + if len(bean) > 0 { + b = bean[0] + beanValue := reflect.ValueOf(bean[0]) + if beanValue.Kind() != reflect.Ptr { + return "", nil, errors.New("needs a pointer") } + if beanValue.Elem().Kind() == reflect.Struct { + if err := statement.SetRefBean(bean[0]); err != nil { + return "", nil, err + } + } + } + tableName := statement.TableName() + if len(tableName) <= 0 { + return "", nil, ErrTableNotFound + } + if statement.RefTable == nil { tableName = statement.quote(tableName) if len(statement.JoinStr) > 0 { joinStr = statement.JoinStr @@ -376,22 +392,8 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac args = []interface{}{} } } else { - beanValue := reflect.ValueOf(bean[0]) - if beanValue.Kind() != reflect.Ptr { - return "", nil, errors.New("needs a pointer") - } - - if beanValue.Elem().Kind() == reflect.Struct { - if err := statement.SetRefBean(bean[0]); err != nil { - return "", nil, err - } - } - - if len(statement.TableName()) <= 0 { - return "", nil, ErrTableNotFound - } statement.Limit(1) - sqlStr, args, err = statement.GenGetSQL(bean[0]) + sqlStr, args, err = statement.GenGetSQL(b) if err != nil { return "", nil, err } -- 2.40.1