Provide function to regenerate RawStore within session #4

Merged
zeripath merged 2 commits from zeripath/session:regenerate-session-function into master 2021-12-18 22:16:15 +00:00
2 changed files with 36 additions and 5 deletions

View File

@ -260,7 +260,7 @@ func Sessioner(options ...Options) func(next http.Handler) http.Handler {
return return
} }
if err = sess.Release(); err != nil { if err = s.RawStore.Release(); err != nil {
panic("session(release): " + err.Error()) panic("session(release): " + err.Error())
} }
}) })
@ -274,6 +274,26 @@ func GetSession(req *http.Request) Store {
return sess return sess
} }
// RegenerateSession
func RegenerateSession(resp http.ResponseWriter, req *http.Request) (Store, error) {
sess, ok := GetSession(req).(*store)
if !ok {
return nil, fmt.Errorf("no session in request context")
}
oldRawStore := sess.RawStore
if err := oldRawStore.Release(); err != nil {
return nil, err
}
store, err := sess.RegenerateID(resp, req)
if err != nil {
return nil, err
}
sess.RawStore = store
return sess, nil
}
// Provider is the interface that provides session manipulations. // Provider is the interface that provides session manipulations.
type Provider interface { type Provider interface {
// Init initializes session provider. // Init initializes session provider.

View File

@ -71,20 +71,31 @@ func testProvider(opt Options) {
Convey("Basic operation", func() { Convey("Basic operation", func() {
c := chi.NewRouter() c := chi.NewRouter()
c.Use(Sessioner(opt)) c.Use(Sessioner(opt))
var initialSid string
c.Get("/", func(resp http.ResponseWriter, req *http.Request) { c.Get("/", func(resp http.ResponseWriter, req *http.Request) {
sess := GetSession(req) sess := GetSession(req)
sess.Set("uname", "unknwon") sess.Set("uname", "unknwon")
initialSid = sess.ID()
}) })
c.Get("/reg", func(resp http.ResponseWriter, req *http.Request) { c.Get("/reg", func(resp http.ResponseWriter, req *http.Request) {
sess := GetSession(req) sess := GetSession(req)
raw, err := sess.RegenerateID(resp, req) So(initialSid, ShouldEqual, sess.ID())
raw, err := RegenerateSession(resp, req)
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(raw, ShouldNotBeNil) So(sess, ShouldNotBeNil)
So(sess, ShouldEqual, raw)
uname := raw.Get("uname") So(initialSid, ShouldNotEqual, sess.ID())
uname := sess.Get("uname")
So(uname, ShouldNotBeNil) So(uname, ShouldNotBeNil)
So(uname, ShouldEqual, "unknwon") So(uname, ShouldEqual, "unknwon")
sess.Set("uname", "lunny")
uname = sess.Get("uname")
So(uname, ShouldNotBeNil)
So(uname, ShouldEqual, "lunny")
}) })
c.Get("/get", func(resp http.ResponseWriter, req *http.Request) { c.Get("/get", func(resp http.ResponseWriter, req *http.Request) {
sess := GetSession(req) sess := GetSession(req)
@ -97,7 +108,7 @@ func testProvider(opt Options) {
uname := sess.Get("uname") uname := sess.Get("uname")
So(uname, ShouldNotBeNil) So(uname, ShouldNotBeNil)
So(uname, ShouldEqual, "unknwon") So(uname, ShouldEqual, "lunny")
So(sess.Delete("uname"), ShouldBeNil) So(sess.Delete("uname"), ShouldBeNil)
So(sess.Get("uname"), ShouldBeNil) So(sess.Get("uname"), ShouldBeNil)