join support condition #2201

Merged
lunny merged 1 commits from lunny/join_condition into master 2022-12-09 15:37:26 +00:00
5 changed files with 39 additions and 16 deletions

View File

@ -330,7 +330,7 @@ func (engine *Engine) Ping() error {
// SQL method let's you manually write raw SQL and operate // SQL method let's you manually write raw SQL and operate
// For example: // For example:
// //
// engine.SQL("select * from user").Find(&users) // engine.SQL("select * from user").Find(&users)
// //
// This code will execute "select * from user" and set the records to users // This code will execute "select * from user" and set the records to users
func (engine *Engine) SQL(query interface{}, args ...interface{}) *Session { func (engine *Engine) SQL(query interface{}, args ...interface{}) *Session {
@ -996,9 +996,8 @@ func (engine *Engine) Desc(colNames ...string) *Session {
// Asc will generate "ORDER BY column1,column2 Asc" // Asc will generate "ORDER BY column1,column2 Asc"
// This method can chainable use. // This method can chainable use.
// //
// engine.Desc("name").Asc("age").Find(&users) // engine.Desc("name").Asc("age").Find(&users)
// // SELECT * FROM user ORDER BY name DESC, age ASC // // SELECT * FROM user ORDER BY name DESC, age ASC
//
func (engine *Engine) Asc(colNames ...string) *Session { func (engine *Engine) Asc(colNames ...string) *Session {
session := engine.NewSession() session := engine.NewSession()
session.isAutoClose = true session.isAutoClose = true
@ -1020,7 +1019,7 @@ func (engine *Engine) Prepare() *Session {
} }
// Join the join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN // Join the join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
func (engine *Engine) Join(joinOperator string, tablename interface{}, condition string, args ...interface{}) *Session { func (engine *Engine) Join(joinOperator string, tablename interface{}, condition interface{}, args ...interface{}) *Session {
session := engine.NewSession() session := engine.NewSession()
session.isAutoClose = true session.isAutoClose = true
return session.Join(joinOperator, tablename, condition, args...) return session.Join(joinOperator, tablename, condition, args...)
@ -1220,9 +1219,10 @@ func (engine *Engine) InsertOne(bean interface{}) (int64, error) {
// Update records, bean's non-empty fields are updated contents, // Update records, bean's non-empty fields are updated contents,
// condiBean' non-empty filds are conditions // condiBean' non-empty filds are conditions
// CAUTION: // CAUTION:
// 1.bool will defaultly be updated content nor conditions //
// You should call UseBool if you have bool to use. // 1.bool will defaultly be updated content nor conditions
// 2.float32 & float64 may be not inexact as conditions // You should call UseBool if you have bool to use.
// 2.float32 & float64 may be not inexact as conditions
func (engine *Engine) Update(bean interface{}, condiBeans ...interface{}) (int64, error) { func (engine *Engine) Update(bean interface{}, condiBeans ...interface{}) (int64, error) {
session := engine.NewSession() session := engine.NewSession()
defer session.Close() defer session.Close()

View File

@ -8,6 +8,7 @@ import (
"testing" "testing"
"time" "time"
"xorm.io/builder"
"xorm.io/xorm" "xorm.io/xorm"
"xorm.io/xorm/internal/utils" "xorm.io/xorm/internal/utils"
"xorm.io/xorm/names" "xorm.io/xorm/names"
@ -965,6 +966,10 @@ func TestFindJoin(t *testing.T) {
scenes = make([]SceneItem, 0) scenes = make([]SceneItem, 0)
err = testEngine.Join("INNER", "order", "`scene_item`.`device_id`=`order`.`id`").Find(&scenes) err = testEngine.Join("INNER", "order", "`scene_item`.`device_id`=`order`.`id`").Find(&scenes)
assert.NoError(t, err) assert.NoError(t, err)
scenes = make([]SceneItem, 0)
err = testEngine.Join("INNER", "order", builder.Expr("`scene_item`.`device_id`=`order`.`id`")).Find(&scenes)
assert.NoError(t, err)
} }
func TestJoinFindLimit(t *testing.T) { func TestJoinFindLimit(t *testing.T) {

View File

@ -52,7 +52,7 @@ type Interface interface {
NoAutoCondition(...bool) *Session NoAutoCondition(...bool) *Session
NotIn(string, ...interface{}) *Session NotIn(string, ...interface{}) *Session
Nullable(...string) *Session Nullable(...string) *Session
Join(joinOperator string, tablename interface{}, condition string, args ...interface{}) *Session Join(joinOperator string, tablename interface{}, condition interface{}, args ...interface{}) *Session
Omit(columns ...string) *Session Omit(columns ...string) *Session
OrderBy(order interface{}, args ...interface{}) *Session OrderBy(order interface{}, args ...interface{}) *Session
Ping() error Ping() error

View File

@ -15,7 +15,7 @@ import (
) )
// Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN // Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement { func (statement *Statement) Join(joinOP string, tablename interface{}, condition interface{}, args ...interface{}) *Statement {
var buf strings.Builder var buf strings.Builder
if len(statement.JoinStr) > 0 { if len(statement.JoinStr) > 0 {
fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, joinOP) fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, joinOP)
@ -23,6 +23,23 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
fmt.Fprintf(&buf, "%v JOIN ", joinOP) fmt.Fprintf(&buf, "%v JOIN ", joinOP)
} }
condStr := ""
condArgs := []interface{}{}
switch condTp := condition.(type) {
case string:
condStr = condTp
case builder.Cond:
var err error
condStr, condArgs, err = builder.ToSQL(condTp)
if err != nil {
statement.LastError = err
return statement
}
default:
statement.LastError = fmt.Errorf("unsupported join condition type: %v", condTp)
return statement
}
switch tp := tablename.(type) { switch tp := tablename.(type) {
case builder.Builder: case builder.Builder:
subSQL, subQueryArgs, err := tp.ToSQL() subSQL, subQueryArgs, err := tp.ToSQL()
@ -35,8 +52,8 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1]) aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1])
aliasName = schemas.CommonQuoter.Trim(aliasName) aliasName = schemas.CommonQuoter.Trim(aliasName)
fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), statement.quote(aliasName), statement.ReplaceQuote(condition)) fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), statement.quote(aliasName), statement.ReplaceQuote(condStr))
statement.joinArgs = append(statement.joinArgs, subQueryArgs...) statement.joinArgs = append(append(statement.joinArgs, subQueryArgs...), condArgs...)
case *builder.Builder: case *builder.Builder:
subSQL, subQueryArgs, err := tp.ToSQL() subSQL, subQueryArgs, err := tp.ToSQL()
if err != nil { if err != nil {
@ -48,8 +65,8 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1]) aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1])
aliasName = schemas.CommonQuoter.Trim(aliasName) aliasName = schemas.CommonQuoter.Trim(aliasName)
fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), statement.quote(aliasName), statement.ReplaceQuote(condition)) fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), statement.quote(aliasName), statement.ReplaceQuote(condStr))
statement.joinArgs = append(statement.joinArgs, subQueryArgs...) statement.joinArgs = append(append(statement.joinArgs, subQueryArgs...), condArgs...)
default: default:
tbName := dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), tablename, true) tbName := dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), tablename, true)
if !utils.IsSubQuery(tbName) { if !utils.IsSubQuery(tbName) {
@ -59,7 +76,8 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
} else { } else {
tbName = statement.ReplaceQuote(tbName) tbName = statement.ReplaceQuote(tbName)
} }
fmt.Fprintf(&buf, "%s ON %v", tbName, statement.ReplaceQuote(condition)) fmt.Fprintf(&buf, "%s ON %v", tbName, statement.ReplaceQuote(condStr))
statement.joinArgs = append(statement.joinArgs, condArgs...)
} }
statement.JoinStr = buf.String() statement.JoinStr = buf.String()

View File

@ -330,7 +330,7 @@ func (session *Session) NoCache() *Session {
} }
// Join join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN // Join join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
func (session *Session) Join(joinOperator string, tablename interface{}, condition string, args ...interface{}) *Session { func (session *Session) Join(joinOperator string, tablename interface{}, condition interface{}, args ...interface{}) *Session {
session.statement.Join(joinOperator, tablename, condition, args...) session.statement.Join(joinOperator, tablename, condition, args...)
return session return session
} }