From f2dfaded1e5f5344b6a71577de7db05863cf1b76 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Fri, 9 Dec 2022 17:40:23 +0800 Subject: [PATCH] join support condition --- engine.go | 16 ++++++++-------- integrations/session_find_test.go | 5 +++++ interface.go | 2 +- internal/statements/join.go | 30 ++++++++++++++++++++++++------ session.go | 2 +- 5 files changed, 39 insertions(+), 16 deletions(-) diff --git a/engine.go b/engine.go index 81cfc7a9..a42519ee 100644 --- a/engine.go +++ b/engine.go @@ -330,7 +330,7 @@ func (engine *Engine) Ping() error { // SQL method let's you manually write raw SQL and operate // For example: // -// engine.SQL("select * from user").Find(&users) +// engine.SQL("select * from user").Find(&users) // // This code will execute "select * from user" and set the records to users func (engine *Engine) SQL(query interface{}, args ...interface{}) *Session { @@ -996,9 +996,8 @@ func (engine *Engine) Desc(colNames ...string) *Session { // Asc will generate "ORDER BY column1,column2 Asc" // This method can chainable use. // -// engine.Desc("name").Asc("age").Find(&users) -// // SELECT * FROM user ORDER BY name DESC, age ASC -// +// engine.Desc("name").Asc("age").Find(&users) +// // SELECT * FROM user ORDER BY name DESC, age ASC func (engine *Engine) Asc(colNames ...string) *Session { session := engine.NewSession() session.isAutoClose = true @@ -1020,7 +1019,7 @@ func (engine *Engine) Prepare() *Session { } // Join the join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN -func (engine *Engine) Join(joinOperator string, tablename interface{}, condition string, args ...interface{}) *Session { +func (engine *Engine) Join(joinOperator string, tablename interface{}, condition interface{}, args ...interface{}) *Session { session := engine.NewSession() session.isAutoClose = true return session.Join(joinOperator, tablename, condition, args...) @@ -1220,9 +1219,10 @@ func (engine *Engine) InsertOne(bean interface{}) (int64, error) { // Update records, bean's non-empty fields are updated contents, // condiBean' non-empty filds are conditions // CAUTION: -// 1.bool will defaultly be updated content nor conditions -// You should call UseBool if you have bool to use. -// 2.float32 & float64 may be not inexact as conditions +// +// 1.bool will defaultly be updated content nor conditions +// You should call UseBool if you have bool to use. +// 2.float32 & float64 may be not inexact as conditions func (engine *Engine) Update(bean interface{}, condiBeans ...interface{}) (int64, error) { session := engine.NewSession() defer session.Close() diff --git a/integrations/session_find_test.go b/integrations/session_find_test.go index 6701b1b5..5c2a4c68 100644 --- a/integrations/session_find_test.go +++ b/integrations/session_find_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "xorm.io/builder" "xorm.io/xorm" "xorm.io/xorm/internal/utils" "xorm.io/xorm/names" @@ -965,6 +966,10 @@ func TestFindJoin(t *testing.T) { scenes = make([]SceneItem, 0) err = testEngine.Join("INNER", "order", "`scene_item`.`device_id`=`order`.`id`").Find(&scenes) assert.NoError(t, err) + + scenes = make([]SceneItem, 0) + err = testEngine.Join("INNER", "order", builder.Expr("`scene_item`.`device_id`=`order`.`id`")).Find(&scenes) + assert.NoError(t, err) } func TestJoinFindLimit(t *testing.T) { diff --git a/interface.go b/interface.go index 55ffebe4..6ad0577a 100644 --- a/interface.go +++ b/interface.go @@ -52,7 +52,7 @@ type Interface interface { NoAutoCondition(...bool) *Session NotIn(string, ...interface{}) *Session Nullable(...string) *Session - Join(joinOperator string, tablename interface{}, condition string, args ...interface{}) *Session + Join(joinOperator string, tablename interface{}, condition interface{}, args ...interface{}) *Session Omit(columns ...string) *Session OrderBy(order interface{}, args ...interface{}) *Session Ping() error diff --git a/internal/statements/join.go b/internal/statements/join.go index 45fc2441..adf349e7 100644 --- a/internal/statements/join.go +++ b/internal/statements/join.go @@ -15,7 +15,7 @@ import ( ) // Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN -func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement { +func (statement *Statement) Join(joinOP string, tablename interface{}, condition interface{}, args ...interface{}) *Statement { var buf strings.Builder if len(statement.JoinStr) > 0 { fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, joinOP) @@ -23,6 +23,23 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition fmt.Fprintf(&buf, "%v JOIN ", joinOP) } + condStr := "" + condArgs := []interface{}{} + switch condTp := condition.(type) { + case string: + condStr = condTp + case builder.Cond: + var err error + condStr, condArgs, err = builder.ToSQL(condTp) + if err != nil { + statement.LastError = err + return statement + } + default: + statement.LastError = fmt.Errorf("unsupported join condition type: %v", condTp) + return statement + } + switch tp := tablename.(type) { case builder.Builder: subSQL, subQueryArgs, err := tp.ToSQL() @@ -35,8 +52,8 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1]) aliasName = schemas.CommonQuoter.Trim(aliasName) - fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), statement.quote(aliasName), statement.ReplaceQuote(condition)) - statement.joinArgs = append(statement.joinArgs, subQueryArgs...) + fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), statement.quote(aliasName), statement.ReplaceQuote(condStr)) + statement.joinArgs = append(append(statement.joinArgs, subQueryArgs...), condArgs...) case *builder.Builder: subSQL, subQueryArgs, err := tp.ToSQL() if err != nil { @@ -48,8 +65,8 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1]) aliasName = schemas.CommonQuoter.Trim(aliasName) - fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), statement.quote(aliasName), statement.ReplaceQuote(condition)) - statement.joinArgs = append(statement.joinArgs, subQueryArgs...) + fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), statement.quote(aliasName), statement.ReplaceQuote(condStr)) + statement.joinArgs = append(append(statement.joinArgs, subQueryArgs...), condArgs...) default: tbName := dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), tablename, true) if !utils.IsSubQuery(tbName) { @@ -59,7 +76,8 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition } else { tbName = statement.ReplaceQuote(tbName) } - fmt.Fprintf(&buf, "%s ON %v", tbName, statement.ReplaceQuote(condition)) + fmt.Fprintf(&buf, "%s ON %v", tbName, statement.ReplaceQuote(condStr)) + statement.joinArgs = append(statement.joinArgs, condArgs...) } statement.JoinStr = buf.String() diff --git a/session.go b/session.go index 854a125a..e1a16e5b 100644 --- a/session.go +++ b/session.go @@ -330,7 +330,7 @@ func (session *Session) NoCache() *Session { } // Join join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN -func (session *Session) Join(joinOperator string, tablename interface{}, condition string, args ...interface{}) *Session { +func (session *Session) Join(joinOperator string, tablename interface{}, condition interface{}, args ...interface{}) *Session { session.statement.Join(joinOperator, tablename, condition, args...) return session } -- 2.40.1