From c434bb80559f0b1e22dd89b72f2bcab8ef89a050 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Tue, 19 Oct 2021 20:33:31 +0800 Subject: [PATCH 1/2] New Prepare useage --- integrations/session_get_test.go | 41 ++++++++++++++++++++++++++++++++ interface.go | 1 + session.go | 19 ++++++++++++++- session_raw.go | 21 +++++++--------- 4 files changed, 69 insertions(+), 13 deletions(-) diff --git a/integrations/session_get_test.go b/integrations/session_get_test.go index 601d4a26..c73f9ea6 100644 --- a/integrations/session_get_test.go +++ b/integrations/session_get_test.go @@ -915,3 +915,44 @@ func TestGetVars(t *testing.T) { assert.EqualValues(t, "xlw", name) assert.EqualValues(t, 42, age) } + +func TestGetWithPrepare(t *testing.T) { + type GetVarsWithPrepare struct { + Id int64 + Name string + Age int + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(GetVarsWithPrepare)) + + _, err := testEngine.Insert(&GetVarsWithPrepare{ + Name: "xlw", + Age: 42, + }) + assert.NoError(t, err) + + var v1 GetVarsWithPrepare + has, err := testEngine.Prepare().Get(&v1) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, "xlw", v1.Name) + assert.EqualValues(t, 42, v1.Age) + + sess := testEngine.NewSession() + defer sess.Close() + + var v2 GetVarsWithPrepare + has, err = sess.Prepare().Get(&v2) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, "xlw", v2.Name) + assert.EqualValues(t, 42, v2.Age) + + var v3 GetVarsWithPrepare + has, err = sess.Prepare().Get(&v3) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, "xlw", v3.Name) + assert.EqualValues(t, 42, v3.Age) +} diff --git a/interface.go b/interface.go index 42dc9a0a..b9e88505 100644 --- a/interface.go +++ b/interface.go @@ -99,6 +99,7 @@ type EngineInterface interface { MapCacher(interface{}, caches.Cacher) error NewSession() *Session NoAutoTime() *Session + Prepare() *Session Quote(string) string SetCacher(string, caches.Cacher) SetConnMaxLifetime(time.Duration) diff --git a/session.go b/session.go index f51fd41b..da4576c8 100644 --- a/session.go +++ b/session.go @@ -79,7 +79,8 @@ type Session struct { afterClosures []func(interface{}) afterProcessors []executedProcessor - stmtCache map[uint32]*core.Stmt //key: hash.Hash32 of (queryStr, len(queryStr)) + stmtCache map[uint32]*core.Stmt //key: hash.Hash32 of (queryStr, len(queryStr)) + txStmtCache map[uint32]*core.Stmt // for tx statement lastSQL string lastSQLArgs []interface{} @@ -200,6 +201,7 @@ func (session *Session) IsClosed() bool { func (session *Session) resetStatement() { if session.autoResetStatement { session.statement.Reset() + session.prepareStmt = false } } @@ -370,6 +372,21 @@ func (session *Session) doPrepare(db *core.DB, sqlStr string) (stmt *core.Stmt, return } +func (session *Session) doPrepareTx(sqlStr string) (stmt *core.Stmt, err error) { + crc := crc32.ChecksumIEEE([]byte(sqlStr)) + // TODO try hash(sqlStr+len(sqlStr)) + var has bool + stmt, has = session.txStmtCache[crc] + if !has { + stmt, err = session.tx.PrepareContext(session.ctx, sqlStr) + if err != nil { + return nil, err + } + session.txStmtCache[crc] = stmt + } + return +} + func getField(dataStruct *reflect.Value, table *schemas.Table, colName string, idx int) (*schemas.Column, *reflect.Value, error) { var col = table.GetColumnIdx(colName, idx) if col == nil { diff --git a/session_raw.go b/session_raw.go index cee29fc7..bce1f575 100644 --- a/session_raw.go +++ b/session_raw.go @@ -46,25 +46,22 @@ func (session *Session) queryRows(sqlStr string, args ...interface{}) (*core.Row return nil, err } - rows, err := stmt.QueryContext(session.ctx, args...) - if err != nil { - return nil, err - } - return rows, nil + return stmt.QueryContext(session.ctx, args...) } - rows, err := db.QueryContext(session.ctx, sqlStr, args...) + return db.QueryContext(session.ctx, sqlStr, args...) + } + + if session.prepareStmt { + stmt, err := session.doPrepareTx(sqlStr) if err != nil { return nil, err } - return rows, nil + + return stmt.QueryContext(session.ctx, args...) } - rows, err := session.tx.QueryContext(session.ctx, sqlStr, args...) - if err != nil { - return nil, err - } - return rows, nil + return session.tx.QueryContext(session.ctx, sqlStr, args...) } func (session *Session) queryRow(sqlStr string, args ...interface{}) *core.Row { -- 2.40.1 From ccc5c0abd42df230609f0faa5b15aea358fb02db Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Tue, 19 Oct 2021 21:35:04 +0800 Subject: [PATCH 2/2] Also fix prepare with exec --- core/stmt.go | 2 +- integrations/session_get_test.go | 26 +++++++++++++++++++++++--- session.go | 8 ++++++++ session_raw.go | 14 ++++++++------ 4 files changed, 40 insertions(+), 10 deletions(-) diff --git a/core/stmt.go b/core/stmt.go index 260843d5..3247efed 100644 --- a/core/stmt.go +++ b/core/stmt.go @@ -93,7 +93,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result if err != nil { return nil, err } - res, err := s.Stmt.ExecContext(ctx, args) + res, err := s.Stmt.ExecContext(ctx, args...) hookCtx.End(ctx, res, err) if err := s.db.afterProcess(hookCtx); err != nil { return nil, err diff --git a/integrations/session_get_test.go b/integrations/session_get_test.go index c73f9ea6..5d1558f4 100644 --- a/integrations/session_get_test.go +++ b/integrations/session_get_test.go @@ -933,7 +933,7 @@ func TestGetWithPrepare(t *testing.T) { assert.NoError(t, err) var v1 GetVarsWithPrepare - has, err := testEngine.Prepare().Get(&v1) + has, err := testEngine.Prepare().ID(1).Get(&v1) assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, "xlw", v1.Name) @@ -943,16 +943,36 @@ func TestGetWithPrepare(t *testing.T) { defer sess.Close() var v2 GetVarsWithPrepare - has, err = sess.Prepare().Get(&v2) + has, err = sess.Prepare().ID(1).Get(&v2) assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, "xlw", v2.Name) assert.EqualValues(t, 42, v2.Age) var v3 GetVarsWithPrepare - has, err = sess.Prepare().Get(&v3) + has, err = sess.Prepare().ID(1).Get(&v3) assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, "xlw", v3.Name) assert.EqualValues(t, 42, v3.Age) + + err = sess.Begin() + assert.NoError(t, err) + + cnt, err := sess.Prepare().Insert(&GetVarsWithPrepare{ + Name: "xlw2", + Age: 12, + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + cnt, err = sess.Prepare().Insert(&GetVarsWithPrepare{ + Name: "xlw3", + Age: 13, + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + err = sess.Commit() + assert.NoError(t, err) } diff --git a/session.go b/session.go index da4576c8..2c916335 100644 --- a/session.go +++ b/session.go @@ -131,6 +131,7 @@ func newSession(engine *Engine) *Session { afterClosures: make([]func(interface{}), 0), afterProcessors: make([]executedProcessor, 0), stmtCache: make(map[uint32]*core.Stmt), + txStmtCache: make(map[uint32]*core.Stmt), lastSQL: "", lastSQLArgs: make([]interface{}, 0), @@ -151,6 +152,12 @@ func (session *Session) Close() error { } } + for _, v := range session.txStmtCache { + if err := v.Close(); err != nil { + return err + } + } + if !session.isClosed { // When Close be called, if session is a transaction and do not call // Commit or Rollback, then call Rollback. @@ -161,6 +168,7 @@ func (session *Session) Close() error { } session.tx = nil session.stmtCache = nil + session.txStmtCache = nil session.isClosed = true } return nil diff --git a/session_raw.go b/session_raw.go index bce1f575..acb106a5 100644 --- a/session_raw.go +++ b/session_raw.go @@ -157,6 +157,13 @@ func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, er session.lastSQLArgs = args if !session.isAutoCommit { + if session.prepareStmt { + stmt, err := session.doPrepareTx(sqlStr) + if err != nil { + return nil, err + } + return stmt.ExecContext(session.ctx, args...) + } return session.tx.ExecContext(session.ctx, sqlStr, args...) } @@ -165,12 +172,7 @@ func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, er if err != nil { return nil, err } - - res, err := stmt.ExecContext(session.ctx, args...) - if err != nil { - return nil, err - } - return res, nil + return stmt.ExecContext(session.ctx, args...) } return session.DB().ExecContext(session.ctx, sqlStr, args...) -- 2.40.1