New Prepare useage #2061

Merged
lunny merged 2 commits from lunny/fix_prepare_tx into master 2021-10-20 00:53:30 +00:00
4 changed files with 40 additions and 10 deletions
Showing only changes of commit ccc5c0abd4 - Show all commits

View File

@ -93,7 +93,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result
if err != nil {
return nil, err
}
res, err := s.Stmt.ExecContext(ctx, args)
res, err := s.Stmt.ExecContext(ctx, args...)
hookCtx.End(ctx, res, err)
if err := s.db.afterProcess(hookCtx); err != nil {
return nil, err

View File

@ -933,7 +933,7 @@ func TestGetWithPrepare(t *testing.T) {
assert.NoError(t, err)
var v1 GetVarsWithPrepare
has, err := testEngine.Prepare().Get(&v1)
has, err := testEngine.Prepare().ID(1).Get(&v1)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, "xlw", v1.Name)
@ -943,16 +943,36 @@ func TestGetWithPrepare(t *testing.T) {
defer sess.Close()
var v2 GetVarsWithPrepare
has, err = sess.Prepare().Get(&v2)
has, err = sess.Prepare().ID(1).Get(&v2)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, "xlw", v2.Name)
assert.EqualValues(t, 42, v2.Age)
var v3 GetVarsWithPrepare
has, err = sess.Prepare().Get(&v3)
has, err = sess.Prepare().ID(1).Get(&v3)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, "xlw", v3.Name)
assert.EqualValues(t, 42, v3.Age)
err = sess.Begin()
assert.NoError(t, err)
cnt, err := sess.Prepare().Insert(&GetVarsWithPrepare{
Name: "xlw2",
Age: 12,
})
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
cnt, err = sess.Prepare().Insert(&GetVarsWithPrepare{
Name: "xlw3",
Age: 13,
})
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
err = sess.Commit()
assert.NoError(t, err)
}

View File

@ -131,6 +131,7 @@ func newSession(engine *Engine) *Session {
afterClosures: make([]func(interface{}), 0),
afterProcessors: make([]executedProcessor, 0),
stmtCache: make(map[uint32]*core.Stmt),
txStmtCache: make(map[uint32]*core.Stmt),
lastSQL: "",
lastSQLArgs: make([]interface{}, 0),
@ -151,6 +152,12 @@ func (session *Session) Close() error {
}
}
for _, v := range session.txStmtCache {
if err := v.Close(); err != nil {
return err
}
}
if !session.isClosed {
// When Close be called, if session is a transaction and do not call
// Commit or Rollback, then call Rollback.
@ -161,6 +168,7 @@ func (session *Session) Close() error {
}
session.tx = nil
session.stmtCache = nil
session.txStmtCache = nil
session.isClosed = true
}
return nil

View File

@ -157,6 +157,13 @@ func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, er
session.lastSQLArgs = args
if !session.isAutoCommit {
if session.prepareStmt {
stmt, err := session.doPrepareTx(sqlStr)
if err != nil {
return nil, err
}
return stmt.ExecContext(session.ctx, args...)
}
return session.tx.ExecContext(session.ctx, sqlStr, args...)
}
@ -165,12 +172,7 @@ func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, er
if err != nil {
return nil, err
}
res, err := stmt.ExecContext(session.ctx, args...)
if err != nil {
return nil, err
}
return res, nil
return stmt.ExecContext(session.ctx, args...)
}
return session.DB().ExecContext(session.ctx, sqlStr, args...)