From 6f142ad620719abf09f417e43209d0a7e1264e08 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Wed, 15 Nov 2017 11:10:56 +0800 Subject: [PATCH] Query now could work with Where, In, SQL and other condition methods --- engine.go | 4 +-- interface.go | 2 +- session_query.go | 62 ++++++++++++++++++++++++++++++++++++++++++- session_query_test.go | 49 ++++++++++++++++++++++++++++++++++ session_raw_test.go | 2 +- 5 files changed, 114 insertions(+), 5 deletions(-) diff --git a/engine.go b/engine.go index 2b986966..4cf98b95 100644 --- a/engine.go +++ b/engine.go @@ -1369,10 +1369,10 @@ func (engine *Engine) Exec(sql string, args ...interface{}) (sql.Result, error) } // Query a raw sql and return records as []map[string][]byte -func (engine *Engine) Query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { +func (engine *Engine) Query(sqlorArgs ...interface{}) (resultsSlice []map[string][]byte, err error) { session := engine.NewSession() defer session.Close() - return session.Query(sql, paramStr...) + return session.Query(sqlorArgs...) } // QueryString runs a raw sql and return records as []map[string]string diff --git a/interface.go b/interface.go index 4f94750b..70907b20 100644 --- a/interface.go +++ b/interface.go @@ -47,7 +47,7 @@ type Interface interface { Omit(columns ...string) *Session OrderBy(order string) *Session Ping() error - Query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) + Query(sqlOrAgrs ...interface{}) (resultsSlice []map[string][]byte, err error) QueryInterface(sqlStr string, args ...interface{}) ([]map[string]interface{}, error) QueryString(sqlStr string, args ...interface{}) ([]map[string]string, error) Rows(bean interface{}) (*Rows, error) diff --git a/session_query.go b/session_query.go index a693bace..260b52d8 100644 --- a/session_query.go +++ b/session_query.go @@ -8,17 +8,77 @@ import ( "fmt" "reflect" "strconv" + "strings" "time" + "github.com/go-xorm/builder" "github.com/go-xorm/core" ) // Query runs a raw sql and return records as []map[string][]byte -func (session *Session) Query(sqlStr string, args ...interface{}) ([]map[string][]byte, error) { +func (session *Session) Query(sqlorArgs ...interface{}) ([]map[string][]byte, error) { if session.isAutoClose { defer session.Close() } + var sqlStr string + var args []interface{} + if len(sqlorArgs) == 0 { + if session.statement.RawSQL != "" { + sqlStr = session.statement.RawSQL + args = session.statement.RawParams + } else { + if len(session.statement.TableName()) <= 0 { + return nil, ErrTableNotFound + } + + var columnStr = session.statement.ColumnStr + if len(session.statement.selectStr) > 0 { + columnStr = session.statement.selectStr + } else { + if session.statement.JoinStr == "" { + if columnStr == "" { + if session.statement.GroupByStr != "" { + columnStr = session.statement.Engine.Quote(strings.Replace(session.statement.GroupByStr, ",", session.engine.Quote(","), -1)) + } else { + columnStr = session.statement.genColumnStr() + } + } + } else { + if columnStr == "" { + if session.statement.GroupByStr != "" { + columnStr = session.statement.Engine.Quote(strings.Replace(session.statement.GroupByStr, ",", session.engine.Quote(","), -1)) + } else { + columnStr = "*" + } + } + } + if columnStr == "" { + columnStr = "*" + } + } + + condSQL, condArgs, err := builder.ToSQL(session.statement.cond) + if err != nil { + return nil, err + } + + args = append(session.statement.joinArgs, condArgs...) + sqlStr, err = session.statement.genSelectSQL(columnStr, condSQL) + if err != nil { + return nil, err + } + // for mssql and use limit + qs := strings.Count(sqlStr, "?") + if len(args)*2 == qs { + args = append(args, args...) + } + } + } else { + sqlStr = sqlorArgs[0].(string) + args = sqlorArgs[1:] + } + return session.queryBytes(sqlStr, args...) } diff --git a/session_query_test.go b/session_query_test.go index 4bb4598b..e84a7142 100644 --- a/session_query_test.go +++ b/session_query_test.go @@ -134,3 +134,52 @@ func TestQueryInterface(t *testing.T) { assert.EqualValues(t, 28, toInt64(records[0]["age"])) assert.EqualValues(t, 1.5, toFloat64(records[0]["money"])) } + +func TestQueryNoParams(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type QueryNoParams struct { + Id int64 `xorm:"autoincr pk"` + Msg string `xorm:"varchar(255)"` + Age int + Money float32 + Created time.Time `xorm:"created"` + } + + testEngine.ShowSQL(true) + + assert.NoError(t, testEngine.Sync2(new(QueryNoParams))) + + var q = QueryNoParams{ + Msg: "message", + Age: 20, + Money: 3000, + } + cnt, err := testEngine.Insert(&q) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + assertResult := func(t *testing.T, results []map[string][]byte) { + assert.EqualValues(t, 1, len(results)) + id, err := strconv.ParseInt(string(results[0]["id"]), 10, 64) + assert.NoError(t, err) + assert.EqualValues(t, 1, id) + assert.Equal(t, "message", string(results[0]["msg"])) + + age, err := strconv.Atoi(string(results[0]["age"])) + assert.NoError(t, err) + assert.EqualValues(t, 20, age) + + money, err := strconv.ParseFloat(string(results[0]["money"]), 32) + assert.NoError(t, err) + assert.EqualValues(t, 3000, money) + } + + results, err := testEngine.Table("query_no_params").Limit(10).Query() + assert.NoError(t, err) + assertResult(t, results) + + results, err = testEngine.SQL("select * from query_no_params").Query() + assert.NoError(t, err) + assertResult(t, results) +} diff --git a/session_raw_test.go b/session_raw_test.go index f52db7d3..32e8037c 100644 --- a/session_raw_test.go +++ b/session_raw_test.go @@ -11,7 +11,7 @@ import ( "github.com/stretchr/testify/assert" ) -func TestQuery(t *testing.T) { +func TestExecAndQuery(t *testing.T) { assert.NoError(t, prepareEngine()) type UserinfoQuery struct { -- 2.40.1