diff --git a/session.go b/session.go index 5c2afed..be10674 100644 --- a/session.go +++ b/session.go @@ -260,7 +260,7 @@ func Sessioner(options ...Options) func(next http.Handler) http.Handler { return } - if err = sess.Release(); err != nil { + if err = s.RawStore.Release(); err != nil { panic("session(release): " + err.Error()) } }) @@ -274,6 +274,26 @@ func GetSession(req *http.Request) Store { return sess } +// RegenerateSession +func RegenerateSession(resp http.ResponseWriter, req *http.Request) (Store, error) { + sess, ok := GetSession(req).(*store) + if !ok { + return nil, fmt.Errorf("no session in request context") + } + + oldRawStore := sess.RawStore + if err := oldRawStore.Release(); err != nil { + return nil, err + } + + store, err := sess.RegenerateID(resp, req) + if err != nil { + return nil, err + } + sess.RawStore = store + return sess, nil +} + // Provider is the interface that provides session manipulations. type Provider interface { // Init initializes session provider. diff --git a/session_test.go b/session_test.go index c602759..13ca1d0 100644 --- a/session_test.go +++ b/session_test.go @@ -71,20 +71,31 @@ func testProvider(opt Options) { Convey("Basic operation", func() { c := chi.NewRouter() c.Use(Sessioner(opt)) + var initialSid string c.Get("/", func(resp http.ResponseWriter, req *http.Request) { sess := GetSession(req) sess.Set("uname", "unknwon") + initialSid = sess.ID() }) c.Get("/reg", func(resp http.ResponseWriter, req *http.Request) { sess := GetSession(req) - raw, err := sess.RegenerateID(resp, req) + So(initialSid, ShouldEqual, sess.ID()) + raw, err := RegenerateSession(resp, req) So(err, ShouldBeNil) - So(raw, ShouldNotBeNil) + So(sess, ShouldNotBeNil) + So(sess, ShouldEqual, raw) - uname := raw.Get("uname") + So(initialSid, ShouldNotEqual, sess.ID()) + + uname := sess.Get("uname") So(uname, ShouldNotBeNil) So(uname, ShouldEqual, "unknwon") + + sess.Set("uname", "lunny") + uname = sess.Get("uname") + So(uname, ShouldNotBeNil) + So(uname, ShouldEqual, "lunny") }) c.Get("/get", func(resp http.ResponseWriter, req *http.Request) { sess := GetSession(req) @@ -97,7 +108,7 @@ func testProvider(opt Options) { uname := sess.Get("uname") So(uname, ShouldNotBeNil) - So(uname, ShouldEqual, "unknwon") + So(uname, ShouldEqual, "lunny") So(sess.Delete("uname"), ShouldBeNil) So(sess.Get("uname"), ShouldBeNil)