diff --git a/session_exist.go b/session_exist.go index 660cc47e..bce2758d 100644 --- a/session_exist.go +++ b/session_exist.go @@ -25,8 +25,8 @@ func (session *Session) Exist(bean ...interface{}) (bool, error) { var sqlStr string var args []interface{} + var joinStr string var err error - if session.statement.RawSQL == "" { if len(bean) == 0 { tableName := session.statement.TableName() @@ -35,6 +35,9 @@ func (session *Session) Exist(bean ...interface{}) (bool, error) { } tableName = session.statement.Engine.Quote(tableName) + if len(session.statement.JoinStr) > 0 { + joinStr = session.statement.JoinStr + } if session.statement.cond.IsValid() { condSQL, condArgs, err := builder.ToSQL(session.statement.cond) @@ -43,20 +46,20 @@ func (session *Session) Exist(bean ...interface{}) (bool, error) { } if session.engine.dialect.DBType() == core.MSSQL { - sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s WHERE %s", tableName, condSQL) + sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s %s WHERE %s", tableName, joinStr, condSQL) } else if session.engine.dialect.DBType() == core.ORACLE { - sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE (%s) AND ROWNUM=1", tableName, condSQL) + sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE (%s) %s AND ROWNUM=1", tableName, joinStr, condSQL) } else { - sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE %s LIMIT 1", tableName, condSQL) + sqlStr = fmt.Sprintf("SELECT * FROM %s %s WHERE %s LIMIT 1", tableName, joinStr, condSQL) } args = condArgs } else { if session.engine.dialect.DBType() == core.MSSQL { - sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s", tableName) + sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s %s", tableName, joinStr) } else if session.engine.dialect.DBType() == core.ORACLE { - sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE ROWNUM=1", tableName) + sqlStr = fmt.Sprintf("SELECT * FROM %s %s WHERE ROWNUM=1", tableName, joinStr) } else { - sqlStr = fmt.Sprintf("SELECT * FROM %s LIMIT 1", tableName) + sqlStr = fmt.Sprintf("SELECT * FROM %s %s LIMIT 1", tableName, joinStr) } args = []interface{}{} } diff --git a/session_exist_test.go b/session_exist_test.go index 9d985771..52d39f32 100644 --- a/session_exist_test.go +++ b/session_exist_test.go @@ -74,3 +74,106 @@ func TestExistStruct(t *testing.T) { assert.NoError(t, err) assert.False(t, has) } + +func TestExistStructForJoin(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type Number struct { + Id int64 + Lid int64 + } + + type OrderList struct { + Id int64 + Eid int64 + } + + type Player struct { + Id int64 + Name string + } + + assert.NoError(t, testEngine.Sync2(new(Number), new(OrderList), new(Player))) + + var ply Player + cnt, err := testEngine.Insert(&ply) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var orderlist = OrderList{ + Eid: ply.Id, + } + cnt, err = testEngine.Insert(&orderlist) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var um = Number{ + Lid: orderlist.Id, + } + cnt, err = testEngine.Insert(&um) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + session := testEngine.NewSession() + defer session.Close() + + session.Table("number"). + Join("INNER", "order_list", "order_list.id = number.lid"). + Join("LEFT", "player", "player.id = order_list.eid"). + Where("number.lid = ?", 1) + has, err := session.Exist() + assert.NoError(t, err) + assert.True(t, has) + + session.Table("number"). + Join("INNER", "order_list", "order_list.id = number.lid"). + Join("LEFT", "player", "player.id = order_list.eid"). + Where("number.lid = ?", 2) + has, err = session.Exist() + assert.NoError(t, err) + assert.False(t, has) + + session.Table("number"). + Select("order_list.id"). + Join("INNER", "order_list", "order_list.id = number.lid"). + Join("LEFT", "player", "player.id = order_list.eid"). + Where("order_list.id = ?", 1) + has, err = session.Exist() + assert.NoError(t, err) + assert.True(t, has) + + session.Table("number"). + Select("player.id"). + Join("INNER", "order_list", "order_list.id = number.lid"). + Join("LEFT", "player", "player.id = order_list.eid"). + Where("player.id = ?", 2) + has, err = session.Exist() + assert.NoError(t, err) + assert.False(t, has) + + session.Table("number"). + Select("player.id"). + Join("INNER", "order_list", "order_list.id = number.lid"). + Join("LEFT", "player", "player.id = order_list.eid") + has, err = session.Exist() + assert.NoError(t, err) + assert.True(t, has) + + err = session.DropTable("order_list") + assert.NoError(t, err) + + session.Table("number"). + Select("player.id"). + Join("INNER", "order_list", "order_list.id = number.lid"). + Join("LEFT", "player", "player.id = order_list.eid") + has, err = session.Exist() + assert.Error(t, err) + assert.False(t, has) + + session.Table("number"). + Select("player.id"). + Join("LEFT", "player", "player.id = number.lid") + has, err = session.Exist() + assert.NoError(t, err) + assert.True(t, has) +}