├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── context.go ├── context_test.go ├── default.go ├── default_test.go ├── doc.go ├── go.mod ├── go.sum ├── session.go ├── session_test.go ├── store.go ├── store_test.go └── util.go /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.dll 4 | *.so 5 | *.dylib 6 | 7 | # Test binary, build with `go test -c` 8 | *.test 9 | 10 | # Output of the go coverage tool, specifically when used with LiteIDE 11 | *.out 12 | 13 | # Project-local glide cache, RE: https://github.com/Masterminds/glide/issues/736 14 | .glide/ 15 | 16 | .vscode 17 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | sudo: false 3 | go_import_path: github.com/go-session/session 4 | go: # must be >= the go directive in go.mod (1.17: net.IP.IsPrivate, io.ReadAll) 5 | - 1.17 6 | before_install: 7 | - go get -t -v ./... 8 | 9 | script: 10 | - go test -race -coverprofile=coverage.txt -covermode=atomic 11 | 12 | after_success: 13 | - bash <(curl -s https://codecov.io/bash) 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Lyric 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software.
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # session 2 | 3 | > An efficient, safe and easy-to-use session library for Go. 4 | 5 | [![Build][build-status-image]][build-status-url] [![Codecov][codecov-image]][codecov-url] [![ReportCard][reportcard-image]][reportcard-url] [![GoDoc][godoc-image]][godoc-url] [![License][license-image]][license-url] 6 | 7 | ## Quick Start 8 | 9 | ### Download and install 10 | 11 | ```bash 12 | go get -v github.com/go-session/session/v3 13 | ``` 14 | 15 | ### Create file `server.go` 16 | 17 | ```go 18 | package main 19 | 20 | import ( 21 | "context" 22 | "fmt" 23 | "net/http" 24 | 25 | session "github.com/go-session/session/v3" 26 | ) 27 | 28 | func main() { 29 | http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 30 | store, err := session.Start(context.Background(), w, r) 31 | if err != nil { 32 | fmt.Fprint(w, err) 33 | return 34 | } 35 | 36 | store.Set("foo", "bar") 37 | err = store.Save() 38 | if err != nil { 39 | fmt.Fprint(w, err) 40 | return 41 | } 42 | 43 | http.Redirect(w, r, "/foo", 302) 44 | }) 45 | 46 | http.HandleFunc("/foo", func(w http.ResponseWriter, r *http.Request) { 47 | store, err := session.Start(context.Background(), w, r) 48 | if err != nil { 49 | fmt.Fprint(w, err) 50 | return 51 | } 52 | 53 | foo, ok := store.Get("foo") 54 | if ok { 55 | fmt.Fprintf(w, "foo:%s", foo)
56 | return 57 | } 58 | fmt.Fprint(w, "does not exist") 59 | }) 60 | 61 | http.ListenAndServe(":8080", nil) 62 | } 63 | ``` 64 | 65 | ### Build and run 66 | 67 | ```bash 68 | go build server.go 69 | ./server 70 | ``` 71 | 72 | ### Open in your web browser 73 | 74 | [http://localhost:8080](http://localhost:8080) 75 | 76 | ```text 77 | foo:bar 78 | ``` 79 | 80 | ## Features 81 | 82 | - Easy to use 83 | - Multi-storage support 84 | - Multi-middleware support 85 | - More secure, with signature-based tamper-proof session IDs 86 | - Context support 87 | - Session IDs can also be passed via request headers or URL query parameters 88 | 89 | ## Store Implementations 90 | 91 | - [https://github.com/go-session/redis](https://github.com/go-session/redis) - Redis 92 | - [https://github.com/go-session/mongo](https://github.com/go-session/mongo) - MongoDB 93 | - [https://github.com/go-session/gorm](https://github.com/go-session/gorm) - [GORM](https://github.com/jinzhu/gorm) 94 | - [https://github.com/go-session/mysql](https://github.com/go-session/mysql) - MySQL 95 | - [https://github.com/go-session/buntdb](https://github.com/go-session/buntdb) - [BuntDB](https://github.com/tidwall/buntdb) 96 | - [https://github.com/go-session/cookie](https://github.com/go-session/cookie) - Cookie 97 | 98 | ## Middlewares 99 | 100 | - [https://github.com/go-session/gin-session](https://github.com/go-session/gin-session) - [Gin](https://github.com/gin-gonic/gin) 101 | - [https://github.com/go-session/beego-session](https://github.com/go-session/beego-session) - [Beego](https://github.com/astaxie/beego) 102 | - [https://github.com/go-session/gear-session](https://github.com/go-session/gear-session) - [Gear](https://github.com/teambition/gear) 103 | - [https://github.com/go-session/echo-session](https://github.com/go-session/echo-session) - [Echo](https://github.com/labstack/echo) 104 | 105 | ## MIT License 106 | 107 | Copyright (c) 2021 Lyric 108 | 109 | [build-status-url]: https://travis-ci.org/go-session/session 110 | [build-status-image]:
https://travis-ci.org/go-session/session.svg?branch=master 111 | [codecov-url]: https://codecov.io/gh/go-session/session 112 | [codecov-image]: https://codecov.io/gh/go-session/session/branch/master/graph/badge.svg 113 | [reportcard-url]: https://goreportcard.com/report/github.com/go-session/session 114 | [reportcard-image]: https://goreportcard.com/badge/github.com/go-session/session 115 | [godoc-url]: https://godoc.org/github.com/go-session/session 116 | [godoc-image]: https://godoc.org/github.com/go-session/session?status.svg 117 | [license-url]: http://opensource.org/licenses/MIT 118 | [license-image]: https://img.shields.io/npm/l/express.svg 119 | -------------------------------------------------------------------------------- /context.go: -------------------------------------------------------------------------------- 1 | package session 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | ) 7 | 8 | // Unexported key types for the response writer and request stored in the context; being private types they cannot collide with keys set by other packages. 9 | type ( 10 | ctxResKey struct{} 11 | ctxReqKey struct{} 12 | ) 13 | 14 | // newResContext returns a copy of ctx that carries the ResponseWriter res. 15 | func newResContext(ctx context.Context, res http.ResponseWriter) context.Context { 16 | return context.WithValue(ctx, ctxResKey{}, res) 17 | } 18 | 19 | // FromResContext returns the ResponseWriter stored in ctx by the session manager, if any. 20 | func FromResContext(ctx context.Context) (http.ResponseWriter, bool) { 21 | res, ok := ctx.Value(ctxResKey{}).(http.ResponseWriter) 22 | return res, ok 23 | } 24 | 25 | // newReqContext returns a copy of ctx that carries the Request req. 26 | func newReqContext(ctx context.Context, req *http.Request) context.Context { 27 | return context.WithValue(ctx, ctxReqKey{}, req) 28 | } 29 | 30 | // FromReqContext returns the Request value stored in ctx, if any.
31 | func FromReqContext(ctx context.Context) (*http.Request, bool) { 32 | req, ok := ctx.Value(ctxReqKey{}).(*http.Request) 33 | return req, ok 34 | } 35 | -------------------------------------------------------------------------------- /context_test.go: -------------------------------------------------------------------------------- 1 | package session 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io" 7 | "net/http" 8 | "net/http/httptest" 9 | "reflect" 10 | "testing" 11 | ) 12 | // TestContext checks that the caller's context values, the request and the response writer are all reachable through store.Context(). 13 | func TestContext(t *testing.T) { 14 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 15 | type ctxKey struct{} 16 | ctx := context.WithValue(context.Background(), ctxKey{}, "bar") 17 | store, err := Start(ctx, w, r) 18 | if err != nil { 19 | t.Error(err) 20 | return 21 | } 22 | 23 | ctxValue := store.Context().Value(ctxKey{}) 24 | if !reflect.DeepEqual(ctxValue, "bar") { 25 | t.Error("Not expected value:", ctxValue) 26 | return 27 | } 28 | 29 | req, ok := FromReqContext(store.Context()) 30 | if !ok { 31 | t.Error("Not expected value: request missing from context") 32 | return 33 | } 34 | if foo := req.URL.Query().Get("foo"); foo != "bar" { 35 | t.Error("Not expected value:", foo) 36 | return 37 | } 38 | 39 | res, ok := FromResContext(store.Context()) 40 | if !ok { 41 | t.Error("Not expected value") 42 | return 43 | } 44 | 45 | fmt.Fprint(res, "ok") 46 | })) 47 | defer ts.Close() 48 | 49 | res, err := http.Get(ts.URL + "?foo=bar") 50 | if err != nil { 51 | t.Error(err) 52 | return 53 | } 54 | 55 | buf, _ := io.ReadAll(res.Body) 56 | res.Body.Close() 57 | if string(buf) != "ok" { 58 | t.Error("Not expected value:", string(buf)) 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /default.go: -------------------------------------------------------------------------------- 1 | package session 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | "sync" 7 | ) 8 | 9 | var ( 10 | internalManager *Manager 11 | once sync.Once 12 | ) 13 | // manager lazily creates the package-level Manager; opt is applied only by the first call, later calls ignore it. 14 | func manager(opt ...Option) *Manager {
15 | once.Do(func() { // only the first call's opt take effect 16 | internalManager = NewManager(opt...) 17 | }) 18 | return internalManager 19 | } 20 | 21 | // InitManager initializes the global session manager with opt. It only has an effect if the global manager has not been created yet; once InitManager, Start, Destroy or Refresh has run, later calls are silently ignored. 22 | func InitManager(opt ...Option) { 23 | manager(opt...) 24 | } 25 | 26 | // Start starts a session for the request using the global manager and returns its session store. 27 | func Start(ctx context.Context, w http.ResponseWriter, r *http.Request) (Store, error) { 28 | return manager().Start(ctx, w, r) 29 | } 30 | 31 | // Destroy destroys the request's session using the global manager. 32 | func Destroy(ctx context.Context, w http.ResponseWriter, r *http.Request) error { 33 | return manager().Destroy(ctx, w, r) 34 | } 35 | 36 | // Refresh issues a new session ID for the request's session using the global manager and returns the resulting session store. 37 | func Refresh(ctx context.Context, w http.ResponseWriter, r *http.Request) (Store, error) { 38 | return manager().Refresh(ctx, w, r) 39 | } 40 | -------------------------------------------------------------------------------- /default_test.go: -------------------------------------------------------------------------------- 1 | package session 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "net/http" 7 | "net/http/httptest" 8 | "testing" 9 | 10 | .
"github.com/smartystreets/goconvey/convey" 11 | ) 12 | 13 | var defaultCookieName = "test_default_start" 14 | // init configures the global manager before any test runs, so every test in this package that uses Start, Destroy or Refresh shares this cookie name. 15 | func init() { 16 | InitManager( 17 | SetCookieName(defaultCookieName), 18 | ) 19 | } 20 | // TestDefaultStart checks that a value saved through the global Start is readable on a follow-up request that carries the session cookie. 21 | func TestDefaultStart(t *testing.T) { 22 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 23 | store, err := Start(r.Context(), w, r) 24 | if err != nil { 25 | t.Error(err) 26 | return 27 | } 28 | 29 | if r.URL.Query().Get("login") == "1" { 30 | foo, ok := store.Get("foo") 31 | fmt.Fprintf(w, "%v:%v", foo, ok) 32 | return 33 | } 34 | 35 | store.Set("foo", "bar") 36 | err = store.Save() 37 | if err != nil { 38 | t.Error(err) 39 | return 40 | } 41 | fmt.Fprint(w, "ok") 42 | })) 43 | defer ts.Close() 44 | 45 | Convey("Test default start", t, func() { 46 | res, err := http.Get(ts.URL) // NOTE(review): this first response body is never closed 47 | So(err, ShouldBeNil) 48 | So(res, ShouldNotBeNil) 49 | So(len(res.Cookies()), ShouldBeGreaterThan, 0) 50 | 51 | cookie := res.Cookies()[0] 52 | So(cookie.Name, ShouldEqual, defaultCookieName) 53 | 54 | req, err := http.NewRequest("GET", fmt.Sprintf("%s?login=1", ts.URL), nil) 55 | So(err, ShouldBeNil) 56 | req.AddCookie(cookie) 57 | 58 | res, err = http.DefaultClient.Do(req) 59 | So(err, ShouldBeNil) 60 | So(res, ShouldNotBeNil) 61 | 62 | buf, err := io.ReadAll(res.Body) 63 | So(err, ShouldBeNil) 64 | res.Body.Close() 65 | So(string(buf), ShouldEqual, "bar:true") 66 | }) 67 | } 68 | // TestDefaultDestroy checks that after the global Destroy, the old session cookie no longer yields the stored value. 69 | func TestDefaultDestroy(t *testing.T) { 70 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 71 | if r.URL.Query().Get("logout") == "1" { 72 | err := Destroy(r.Context(), w, r) 73 | if err != nil { 74 | t.Error(err) 75 | return 76 | } 77 | fmt.Fprint(w, "ok") 78 | return 79 | } 80 | 81 | store, err := Start(r.Context(), w, r) 82 | if err != nil { 83 | t.Error(err) 84 | return 85 | } 86 | 87 | if r.URL.Query().Get("check") == "1" { 88 | foo, ok := store.Get("foo") 89 | fmt.Fprintf(w, "%v:%v", foo, ok) 90 | return
91 | } 92 | 93 | store.Set("foo", "bar") 94 | err = store.Save() 95 | if err != nil { 96 | t.Error(err) 97 | return 98 | } 99 | fmt.Fprint(w, "ok") 100 | })) 101 | defer ts.Close() 102 | 103 | Convey("Test default destroy", t, func() { 104 | res, err := http.Get(ts.URL) 105 | So(err, ShouldBeNil) 106 | So(res, ShouldNotBeNil) 107 | So(len(res.Cookies()), ShouldBeGreaterThan, 0) 108 | 109 | cookie := res.Cookies()[0] 110 | So(cookie.Name, ShouldEqual, defaultCookieName) 111 | 112 | req, err := http.NewRequest("GET", fmt.Sprintf("%s?logout=1", ts.URL), nil) 113 | So(err, ShouldBeNil) 114 | 115 | req.AddCookie(cookie) 116 | res, err = http.DefaultClient.Do(req) 117 | So(err, ShouldBeNil) 118 | So(res, ShouldNotBeNil) 119 | 120 | req, err = http.NewRequest("GET", fmt.Sprintf("%s?check=1", ts.URL), nil) 121 | So(err, ShouldBeNil) 122 | req.AddCookie(cookie) 123 | res, err = http.DefaultClient.Do(req) 124 | So(err, ShouldBeNil) 125 | So(res, ShouldNotBeNil) 126 | 127 | buf, err := io.ReadAll(res.Body) 128 | So(err, ShouldBeNil) 129 | res.Body.Close() 130 | So(string(buf), ShouldEqual, ":false") 131 | }) 132 | } 133 | // TestDefaultRefresh checks that the global Refresh issues a new session ID that still carries the previously stored value. 134 | func TestDefaultRefresh(t *testing.T) { 135 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 136 | store, err := Start(r.Context(), w, r) 137 | if err != nil { 138 | t.Error(err) 139 | return 140 | } 141 | 142 | if r.URL.Query().Get("refresh") == "1" { 143 | vstore, verr := Refresh(r.Context(), w, r) 144 | if verr != nil { 145 | t.Error(verr) // report Refresh's error; err belongs to the earlier Start and is nil here 146 | return 147 | } 148 | 149 | if vstore.SessionID() == store.SessionID() { 150 | t.Errorf("Not expected value") 151 | return 152 | } 153 | 154 | foo, ok := vstore.Get("foo") 155 | fmt.Fprintf(w, "%s:%v", foo, ok) 156 | return 157 | } 158 | 159 | store.Set("foo", "bar") 160 | err = store.Save() 161 | if err != nil { 162 | t.Error(err) 163 | return 164 | } 165 | fmt.Fprint(w, "ok") 166 | })) 167 | defer ts.Close() 168 | 169 | Convey("Test default refresh", t, func() { 170
| res, err := http.Get(ts.URL) // NOTE(review): this first response body is never closed 171 | So(err, ShouldBeNil) 172 | So(res, ShouldNotBeNil) 173 | So(len(res.Cookies()), ShouldBeGreaterThan, 0) 174 | 175 | cookie := res.Cookies()[0] 176 | So(cookie.Name, ShouldEqual, defaultCookieName) 177 | 178 | req, err := http.NewRequest("GET", fmt.Sprintf("%s?refresh=1", ts.URL), nil) 179 | So(err, ShouldBeNil) 180 | 181 | req.AddCookie(cookie) 182 | res, err = http.DefaultClient.Do(req) 183 | So(err, ShouldBeNil) 184 | So(res, ShouldNotBeNil) 185 | So(len(res.Cookies()), ShouldBeGreaterThan, 0) 186 | So(res.Cookies()[0].Value, ShouldNotEqual, cookie.Value) 187 | 188 | buf, err := io.ReadAll(res.Body) 189 | So(err, ShouldBeNil) 190 | res.Body.Close() 191 | So(string(buf), ShouldEqual, "bar:true") 192 | }) 193 | } 194 | -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | /* 2 | Package session implements an efficient, safe and easy-to-use session library for Go.
3 | 4 | 5 | Example: 6 | 7 | package main 8 | 9 | import ( 10 | "context" 11 | "fmt" 12 | "net/http" 13 | 14 | session "github.com/go-session/session/v3" 15 | ) 16 | 17 | func main() { 18 | http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 19 | store, err := session.Start(context.Background(), w, r) 20 | if err != nil { 21 | fmt.Fprint(w, err) 22 | return 23 | } 24 | 25 | store.Set("foo", "bar") 26 | err = store.Save() 27 | if err != nil { 28 | fmt.Fprint(w, err) 29 | return 30 | } 31 | 32 | http.Redirect(w, r, "/foo", 302) 33 | }) 34 | 35 | http.HandleFunc("/foo", func(w http.ResponseWriter, r *http.Request) { 36 | store, err := session.Start(context.Background(), w, r) 37 | if err != nil { 38 | fmt.Fprint(w, err) 39 | return 40 | } 41 | 42 | foo, ok := store.Get("foo") 43 | if ok { 44 | fmt.Fprintf(w, "foo:%s", foo) 45 | return 46 | } 47 | fmt.Fprint(w, "does not exist") 48 | }) 49 | 50 | http.ListenAndServe(":8080", nil) 51 | } 52 | 53 | Open in your web browser at http://localhost:8080 54 | 55 | Output: 56 | foo:bar 57 | 58 | Learn more at https://github.com/go-session/session 59 | */ 60 | package session 61 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/go-session/session/v3 2 | 3 | go 1.17 4 | 5 | require ( 6 | github.com/bytedance/gopkg v0.0.0-20221122125632-68358b8ecec6 7 | github.com/smartystreets/goconvey v1.6.4 8 | ) 9 | 10 | require ( 11 | github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00 // indirect 12 | github.com/jtolds/gls v4.20.0+incompatible // indirect 13 | github.com/smartystreets/assertions v1.1.0 // indirect 14 | ) 15 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/bytedance/gopkg v0.0.0-20221122125632-68358b8ecec6 
h1:FCLDGi1EmB7JzjVVYNZiqc/zAJj2BQ5M0lfkVOxbfs8= 2 | github.com/bytedance/gopkg v0.0.0-20221122125632-68358b8ecec6/go.mod h1:5FoAH5xUHHCMDvQPy1rnj8moqLkLHFaDVBjHhcFwEi0= 3 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= 5 | github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00 h1:l5lAOZEym3oK3SQ2HBHWsJUfbNBiTXJDeW2QDxw9AQ0= 6 | github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= 7 | github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= 8 | github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= 9 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 10 | github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= 11 | github.com/smartystreets/assertions v1.1.0 h1:MkTeG1DMwsrdH7QtLXy5W+fUxWq+vmb6cLmyJ7aRtF0= 12 | github.com/smartystreets/assertions v1.1.0/go.mod h1:tcbTF8ujkAEcZ8TElKY+i30BzYlVhC/LOxJk7iOWnoo= 13 | github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s= 14 | github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= 15 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 16 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 17 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 18 | golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 19 | golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 20 | golang.org/x/sys 
v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 21 | golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 22 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 23 | golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= 24 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 25 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 26 | -------------------------------------------------------------------------------- /session.go: -------------------------------------------------------------------------------- 1 | package session 2 | 3 | import ( 4 | "context" 5 | "crypto/hmac" 6 | "crypto/sha1" 7 | "encoding/base64" 8 | "errors" 9 | "fmt" 10 | "net" 11 | "net/http" 12 | "net/url" 13 | "strings" 14 | "time" 15 | ) 16 | 17 | // Version is the version of this session package. 18 | const Version = "3.1.4" 19 | // ErrInvalidSessionID is returned when a session ID taken from the request does not have the "base64(sid).signature" form or its signature does not match. 20 | var ( 21 | ErrInvalidSessionID = errors.New("Invalid session id") 22 | ) 23 | 24 | // IDHandlerFunc generates a new session ID; it receives the context carrying the current request and response writer. 25 | type IDHandlerFunc func(context.Context) string 26 | 27 | // defaultOptions are the settings NewManager starts from before applying any Option (expired and cookieLifeTime are in seconds). 28 | var defaultOptions = options{ 29 | cookieName: "go_session_id", 30 | cookieLifeTime: 3600 * 24 * 7, 31 | expired: 7200, 32 | secure: true, 33 | sameSite: http.SameSiteDefaultMode, 34 | sessionID: func(_ context.Context) string { 35 | return newUUID() 36 | }, 37 | enableSetCookie: true, 38 | enableSIDInURLQuery: true, 39 | } 40 | // options holds a Manager's configuration; it is populated through Option functions. 41 | type options struct { 42 | sign []byte 43 | cookieName string 44 | cookieLifeTime int 45 | secure bool 46 | domain string 47 | sameSite http.SameSite 48 | expired int64 49 | sessionID IDHandlerFunc 50 | enableSetCookie bool 51 | enableSIDInURLQuery bool 52 | enableSIDInHTTPHeader bool 53 | sessionNameInHTTPHeader string 54 | store ManagerStore 55 | } 56
| // Option configures a Manager; pass it to NewManager or InitManager. 57 | type Option func(*options) 58 | 59 | // SetSign sets the key used to HMAC-sign session IDs. 60 | func SetSign(sign []byte) Option { 61 | return func(o *options) { 62 | o.sign = sign 63 | } 64 | } 65 | 66 | // SetCookieName sets the session cookie name, which is also the form/query parameter name and the default header name. 67 | func SetCookieName(cookieName string) Option { 68 | return func(o *options) { 69 | o.cookieName = cookieName 70 | } 71 | } 72 | 73 | // SetCookieLifeTime sets the cookie Max-Age in seconds; values <= 0 leave Max-Age/Expires unset (a browser-session cookie). 74 | func SetCookieLifeTime(cookieLifeTime int) Option { 75 | return func(o *options) { 76 | o.cookieLifeTime = cookieLifeTime 77 | } 78 | } 79 | 80 | // SetDomain sets the Domain attribute of the session cookie. 81 | func SetDomain(domain string) Option { 82 | return func(o *options) { 83 | o.domain = domain 84 | } 85 | } 86 | 87 | // SetSecure controls whether the Secure cookie attribute may be set; when enabled it is still applied only to requests the manager considers secure. 88 | func SetSecure(secure bool) Option { 89 | return func(o *options) { 90 | o.secure = secure 91 | } 92 | } 93 | 94 | // SetSameSite sets the SameSite attribute of the session cookie. 95 | func SetSameSite(sameSite http.SameSite) Option { 96 | return func(o *options) { 97 | o.sameSite = sameSite 98 | } 99 | } 100 | 101 | // SetExpired sets the session expiration time in seconds, passed to the store when a session is created, updated or refreshed. 102 | func SetExpired(expired int64) Option { 103 | return func(o *options) { 104 | o.expired = expired 105 | } 106 | } 107 | 108 | // SetSessionID sets the callback used to generate new session IDs. 109 | func SetSessionID(handler IDHandlerFunc) Option { 110 | return func(o *options) { 111 | o.sessionID = handler 112 | } 113 | } 114 | 115 | // SetEnableSetCookie controls whether the session ID is read from and written to a cookie 116 | // (enabled by default; turn it off when the session ID travels only via query parameters or headers). 117 | func SetEnableSetCookie(enableSetCookie bool) Option { 118 | return func(o *options) { 119 | o.enableSetCookie = enableSetCookie 120 | } 121 | } 122 | 123 | // SetEnableSIDInURLQuery allows reading the session ID from request form values (URL query or form body) when no cookie supplied it (enabled by default). 124 | func SetEnableSIDInURLQuery(enableSIDInURLQuery bool) Option { 125 | return func(o *options) { 126 | o.enableSIDInURLQuery = enableSIDInURLQuery 127 | } 128 | } 129 | 130 | // SetEnableSIDInHTTPHeader allows reading the session ID from a request header and also writes it to the response header. 131 | func
SetEnableSIDInHTTPHeader(enableSIDInHTTPHeader bool) Option { 132 | return func(o *options) { 133 | o.enableSIDInHTTPHeader = enableSIDInHTTPHeader 134 | } 135 | } 136 | 137 | // SetSessionNameInHTTPHeader sets the header name that carries the session ID 138 | // (if empty, NewManager falls back to the cookie name when header support is enabled). 139 | func SetSessionNameInHTTPHeader(sessionNameInHTTPHeader string) Option { 140 | return func(o *options) { 141 | o.sessionNameInHTTPHeader = sessionNameInHTTPHeader 142 | } 143 | } 144 | 145 | // SetStore sets the backing ManagerStore (NewManager uses NewMemoryStore when none is set). 146 | func SetStore(store ManagerStore) Option { 147 | return func(o *options) { 148 | o.store = store 149 | } 150 | } 151 | 152 | // NewManager creates a Manager from defaultOptions with each opt applied in order. NOTE(review): without SetSign the HMAC key is empty, so anyone can compute valid signatures — confirm callers always set a secret. 153 | func NewManager(opt ...Option) *Manager { 154 | opts := defaultOptions // struct copy, so Options never mutate the shared defaults 155 | for _, o := range opt { 156 | o(&opts) 157 | } 158 | 159 | if opts.enableSIDInHTTPHeader && opts.sessionNameInHTTPHeader == "" { 160 | opts.sessionNameInHTTPHeader = opts.cookieName 161 | } 162 | 163 | if opts.store == nil { 164 | opts.store = NewMemoryStore() 165 | } 166 | return &Manager{opts: &opts} 167 | } 168 | 169 | // Manager starts, refreshes and destroys sessions according to its options. 170 | type Manager struct { 171 | opts *options 172 | } 173 | // getContext derives a context carrying r and w from ctx, falling back to context.Background when ctx is nil. 174 | func (m *Manager) getContext(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 175 | if ctx == nil { 176 | ctx = context.Background() 177 | } 178 | ctx = newReqContext(ctx, r) 179 | ctx = newResContext(ctx, w) 180 | return ctx 181 | } 182 | // signature returns the lowercase hex HMAC-SHA1 of sid keyed with the configured sign value. 183 | func (m *Manager) signature(sid string) string { 184 | h := hmac.New(sha1.New, m.opts.sign) 185 | h.Write([]byte(sid)) 186 | return fmt.Sprintf("%x", h.Sum(nil)) 187 | } 188 | // decodeSessionID unescapes value, splits it into base64(sid) and signature, and returns sid only if the signature matches. 189 | func (m *Manager) decodeSessionID(value string) (string, error) { 190 | value, err := url.QueryUnescape(value) 191 | if err != nil { 192 | return "", err 193 | } 194 | 195 | vals := strings.Split(value, ".") 196 | if len(vals) != 2 { 197 | return "", ErrInvalidSessionID 198 | } 199 | 200
| bsid, err := base64.StdEncoding.DecodeString(vals[0]) 201 | if err != nil { 202 | return "", err 203 | } 204 | sid := string(bsid) 205 | // Compare in constant time so response timing does not reveal how much of a forged signature matched. 206 | sign := m.signature(sid) 207 | if !hmac.Equal([]byte(sign), []byte(vals[1])) { 208 | return "", ErrInvalidSessionID 209 | } 210 | return sid, nil 211 | } 212 | // sessionID returns the verified session ID taken from the cookie, then form values, then the configured header (each only if enabled), or "" when none is present. 213 | func (m *Manager) sessionID(r *http.Request) (string, error) { 214 | var cookieValue string 215 | 216 | if m.opts.enableSetCookie { 217 | cookie, err := r.Cookie(m.opts.cookieName) 218 | if err == nil && cookie.Value != "" { 219 | cookieValue = cookie.Value 220 | } 221 | } 222 | 223 | if m.opts.enableSIDInURLQuery && cookieValue == "" { 224 | err := r.ParseForm() // NOTE(review): also parses urlencoded POST bodies, so the ID may come from the body too 225 | if err != nil { 226 | return "", err 227 | } 228 | cookieValue = r.FormValue(m.opts.cookieName) 229 | } 230 | 231 | if m.opts.enableSIDInHTTPHeader && cookieValue == "" { 232 | cookieValue = r.Header.Get(m.opts.sessionNameInHTTPHeader) 233 | } 234 | 235 | if cookieValue != "" { 236 | return m.decodeSessionID(cookieValue) 237 | } 238 | 239 | return "", nil 240 | } 241 | // encodeSessionID builds the URL-escaped "base64(sid).signature" token stored in cookies and headers. 242 | func (m *Manager) encodeSessionID(sid string) string { 243 | b := base64.StdEncoding.EncodeToString([]byte(sid)) 244 | s := fmt.Sprintf("%s.%s", b, m.signature(sid)) 245 | return url.QueryEscape(s) 246 | } 247 | // isSecure reports whether the session cookie should get the Secure attribute. NOTE(review): loopback and private-network peers are always treated as secure, presumably for a TLS-terminating proxy — confirm this is intended for plain-HTTP LAN clients. 248 | func (m *Manager) isSecure(r *http.Request) bool { 249 | if !m.opts.secure { 250 | return false 251 | } 252 | host, _, _ := net.SplitHostPort(r.RemoteAddr) 253 | ip := net.ParseIP(host) 254 | if ip.IsLoopback() || ip.IsPrivate() { 255 | return true 256 | } 257 | if r.URL.Scheme != "" { 258 | return r.URL.Scheme == "https" 259 | } 260 | if r.TLS == nil { 261 | return false 262 | } 263 | return true 264 | } 265 | // setCookie writes the encoded session ID to the response (cookie and/or header, as enabled) and mirrors it onto r. 266 | func (m *Manager) setCookie(sessionID string, w http.ResponseWriter, r *http.Request) { 267 | cookieValue := m.encodeSessionID(sessionID) 268 | 269 | if m.opts.enableSetCookie { 270 | cookie := &http.Cookie{ 271 | Name: m.opts.cookieName, 272 | Value: cookieValue, 273 | Path: "/", 274 | HttpOnly: true, 275 | Secure:
m.isSecure(r), 276 | Domain: m.opts.domain, 277 | SameSite: m.opts.sameSite, 278 | } 279 | 280 | if v := m.opts.cookieLifeTime; v > 0 { 281 | cookie.MaxAge = v 282 | cookie.Expires = time.Now().Add(time.Duration(v) * time.Second) 283 | } 284 | 285 | http.SetCookie(w, cookie) 286 | r.AddCookie(cookie) // NOTE(review): appends; if r already carried a cookie with this name, r.Cookie still returns that older one first 287 | } 288 | 289 | if m.opts.enableSIDInHTTPHeader { 290 | key := m.opts.sessionNameInHTTPHeader 291 | r.Header.Set(key, cookieValue) 292 | w.Header().Set(key, cookieValue) 293 | } 294 | } 295 | // Start returns the store of the session identified by the request (calling the store's Update with the configured expiration), 296 | // or creates a new session and writes its ID to the response. A malformed or tampered session ID is returned as an error rather than replaced. 297 | func (m *Manager) Start(ctx context.Context, w http.ResponseWriter, r *http.Request) (Store, error) { 298 | ctx = m.getContext(ctx, w, r) 299 | 300 | sid, err := m.sessionID(r) 301 | if err != nil { 302 | return nil, err 303 | } 304 | 305 | if sid != "" { 306 | if exists, err := m.opts.store.Check(ctx, sid); err != nil { 307 | return nil, err 308 | } else if exists { 309 | return m.opts.store.Update(ctx, sid, m.opts.expired) 310 | } 311 | } 312 | 313 | sid = m.opts.sessionID(ctx) 314 | store, err := m.opts.store.Create(ctx, sid, m.opts.expired) 315 | if err != nil { 316 | return nil, err 317 | } 318 | 319 | m.setCookie(store.SessionID(), w, r) 320 | return store, nil 321 | } 322 | // Refresh moves the request's session to a newly generated ID via the store's Refresh and writes the new ID to the response; 323 | // if the request carries no session ID, a freshly generated ID is passed to the store as the old one. 324 | func (m *Manager) Refresh(ctx context.Context, w http.ResponseWriter, r *http.Request) (Store, error) { 325 | ctx = m.getContext(ctx, w, r) 326 | 327 | oldSID, err := m.sessionID(r) 328 | if err != nil { 329 | return nil, err 330 | } else if oldSID == "" { 331 | oldSID = m.opts.sessionID(ctx) 332 | } 333 | 334 | sid := m.opts.sessionID(ctx) 335 | store, err := m.opts.store.Refresh(ctx, oldSID, sid, m.opts.expired) 336 | if err != nil { 337 | return nil, err 338 | } 339 | 340 | m.setCookie(store.SessionID(), w, r) 341 | return store, nil 342 | } 343 | // Destroy deletes the request's session from the store when it exists there, then clears the session cookie and header (as enabled); 344 | // it is a no-op when the request carries no session ID or the session is unknown to the store. 345 | func (m *Manager) Destroy(ctx context.Context, w http.ResponseWriter, r *http.Request) error { 346 | ctx =
m.getContext(ctx, w, r) 347 | 348 | sid, err := m.sessionID(r) 349 | if err != nil { 350 | return err 351 | } else if sid == "" { 352 | return nil 353 | } 354 | 355 | if exists, err := m.opts.store.Check(ctx, sid); err != nil { 356 | return err 357 | } else if !exists { 358 | return nil 359 | } 360 | 361 | err = m.opts.store.Delete(ctx, sid) 362 | if err != nil { 363 | return err 364 | } 365 | 366 | if m.opts.enableSetCookie { 367 | cookie := &http.Cookie{ // Domain/Path must match setCookie's, otherwise browsers keep the original cookie; SameSite=None also requires Secure 368 | Name: m.opts.cookieName, 369 | Path: "/", 370 | Domain: m.opts.domain, 371 | HttpOnly: true, 372 | Secure: m.isSecure(r), 373 | SameSite: m.opts.sameSite, 374 | Expires: time.Now(), 375 | MaxAge: -1, 376 | } 377 | 378 | http.SetCookie(w, cookie) 379 | } 380 | 381 | if m.opts.enableSIDInHTTPHeader { 382 | key := m.opts.sessionNameInHTTPHeader 383 | r.Header.Del(key) 384 | w.Header().Del(key) 385 | } 386 | 387 | return nil 388 | } 389 | -------------------------------------------------------------------------------- /session_test.go: -------------------------------------------------------------------------------- 1 | package session 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "net/http" 7 | "net/http/httptest" 8 | "testing" 9 | 10 | .
"github.com/smartystreets/goconvey/convey" 11 | ) 12 | // TestSessionStart checks that a value saved through Manager.Start is readable on a follow-up request that carries the session cookie. 13 | func TestSessionStart(t *testing.T) { 14 | cookieName := "test_session_start" 15 | manager := NewManager( 16 | SetCookieName(cookieName), 17 | ) 18 | 19 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 20 | store, err := manager.Start(r.Context(), w, r) 21 | if err != nil { 22 | t.Error(err) 23 | return 24 | } 25 | 26 | if r.URL.Query().Get("login") == "1" { 27 | foo, ok := store.Get("foo") 28 | fmt.Fprintf(w, "%v:%v", foo, ok) 29 | return 30 | } 31 | 32 | store.Set("foo", "bar") 33 | err = store.Save() 34 | if err != nil { 35 | t.Error(err) 36 | return 37 | } 38 | fmt.Fprint(w, "ok") 39 | })) 40 | defer ts.Close() 41 | 42 | Convey("Test session start", t, func() { 43 | res, err := http.Get(ts.URL) // NOTE(review): this first response body is never closed 44 | So(err, ShouldBeNil) 45 | So(res, ShouldNotBeNil) 46 | So(len(res.Cookies()), ShouldBeGreaterThan, 0) 47 | 48 | cookie := res.Cookies()[0] 49 | So(cookie.Name, ShouldEqual, cookieName) 50 | 51 | req, err := http.NewRequest("GET", fmt.Sprintf("%s?login=1", ts.URL), nil) 52 | So(err, ShouldBeNil) 53 | req.AddCookie(cookie) 54 | 55 | res, err = http.DefaultClient.Do(req) 56 | So(err, ShouldBeNil) 57 | So(res, ShouldNotBeNil) 58 | 59 | buf, err := io.ReadAll(res.Body) 60 | So(err, ShouldBeNil) 61 | res.Body.Close() 62 | So(string(buf), ShouldEqual, "bar:true") 63 | }) 64 | } 65 | // TestSessionDestroy checks that after Manager.Destroy the old session cookie no longer yields the stored value. 66 | func TestSessionDestroy(t *testing.T) { 67 | cookieName := "test_session_destroy" 68 | 69 | manager := NewManager( 70 | SetCookieName(cookieName), 71 | ) 72 | 73 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 74 | if r.URL.Query().Get("logout") == "1" { 75 | err := manager.Destroy(r.Context(), w, r) 76 | if err != nil { 77 | t.Error(err) 78 | return 79 | } 80 | fmt.Fprint(w, "ok") 81 | return 82 | } 83 | 84 | store, err := manager.Start(r.Context(), w, r) 85 | if err != nil { 86 | t.Error(err) 87 | return 88 | } 89 | 90 | if
r.URL.Query().Get("check") == "1" { 91 | foo, ok := store.Get("foo") 92 | fmt.Fprintf(w, "%v:%v", foo, ok) 93 | return 94 | } 95 | 96 | store.Set("foo", "bar") 97 | err = store.Save() 98 | if err != nil { 99 | t.Error(err) 100 | return 101 | } 102 | fmt.Fprint(w, "ok") 103 | })) 104 | defer ts.Close() 105 | 106 | Convey("Test session destroy", t, func() { 107 | res, err := http.Get(ts.URL) 108 | So(err, ShouldBeNil) 109 | So(res, ShouldNotBeNil) 110 | So(len(res.Cookies()), ShouldBeGreaterThan, 0) 111 | 112 | cookie := res.Cookies()[0] 113 | So(cookie.Name, ShouldEqual, cookieName) 114 | 115 | req, err := http.NewRequest("GET", fmt.Sprintf("%s?logout=1", ts.URL), nil) 116 | So(err, ShouldBeNil) 117 | 118 | req.AddCookie(cookie) 119 | res, err = http.DefaultClient.Do(req) 120 | So(err, ShouldBeNil) 121 | So(res, ShouldNotBeNil) 122 | 123 | req, err = http.NewRequest("GET", fmt.Sprintf("%s?check=1", ts.URL), nil) 124 | So(err, ShouldBeNil) 125 | req.AddCookie(cookie) 126 | res, err = http.DefaultClient.Do(req) 127 | So(err, ShouldBeNil) 128 | So(res, ShouldNotBeNil) 129 | 130 | buf, err := io.ReadAll(res.Body) 131 | So(err, ShouldBeNil) 132 | res.Body.Close() 133 | So(string(buf), ShouldEqual, ":false") 134 | }) 135 | } 136 | 137 | func TestSessionRefresh(t *testing.T) { 138 | cookieName := "test_session_refresh" 139 | 140 | manager := NewManager( 141 | SetCookieName(cookieName), 142 | ) 143 | 144 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 145 | store, err := manager.Start(r.Context(), w, r) 146 | if err != nil { 147 | t.Error(err) 148 | return 149 | } 150 | 151 | if r.URL.Query().Get("refresh") == "1" { 152 | vstore, verr := manager.Refresh(r.Context(), w, r) 153 | if verr != nil { 154 | t.Error(err) 155 | return 156 | } 157 | 158 | if vstore.SessionID() == store.SessionID() { 159 | t.Errorf("Not expected value") 160 | return 161 | } 162 | 163 | foo, ok := vstore.Get("foo") 164 | fmt.Fprintf(w, "%s:%v", foo, ok) 165 | 
return 166 | } 167 | 168 | store.Set("foo", "bar") 169 | err = store.Save() 170 | if err != nil { 171 | t.Error(err) 172 | return 173 | } 174 | fmt.Fprint(w, "ok") 175 | })) 176 | defer ts.Close() 177 | 178 | Convey("Test session refresh", t, func() { 179 | res, err := http.Get(ts.URL) 180 | So(err, ShouldBeNil) 181 | So(res, ShouldNotBeNil) 182 | So(len(res.Cookies()), ShouldBeGreaterThan, 0) 183 | 184 | cookie := res.Cookies()[0] 185 | So(cookie.Name, ShouldEqual, cookieName) 186 | 187 | req, err := http.NewRequest("GET", fmt.Sprintf("%s?refresh=1", ts.URL), nil) 188 | So(err, ShouldBeNil) 189 | 190 | req.AddCookie(cookie) 191 | res, err = http.DefaultClient.Do(req) 192 | So(err, ShouldBeNil) 193 | So(res, ShouldNotBeNil) 194 | So(len(res.Cookies()), ShouldBeGreaterThan, 0) 195 | So(res.Cookies()[0].Value, ShouldNotEqual, cookie.Value) 196 | 197 | buf, err := io.ReadAll(res.Body) 198 | So(err, ShouldBeNil) 199 | res.Body.Close() 200 | So(string(buf), ShouldEqual, "bar:true") 201 | }) 202 | } 203 | -------------------------------------------------------------------------------- /store.go: -------------------------------------------------------------------------------- 1 | package session 2 | 3 | import ( 4 | "context" 5 | "sync" 6 | "time" 7 | 8 | "github.com/bytedance/gopkg/collection/skipmap" 9 | ) 10 | 11 | var ( 12 | _ ManagerStore = &memoryStore{} 13 | _ Store = &store{} 14 | now = time.Now 15 | ) 16 | 17 | // Management of session storage, including creation, update, and delete operations 18 | type ManagerStore interface { 19 | // Check the session store exists 20 | Check(ctx context.Context, sid string) (bool, error) 21 | // Create a session store and specify the expiration time (in seconds) 22 | Create(ctx context.Context, sid string, expired int64) (Store, error) 23 | // Update a session store and specify the expiration time (in seconds) 24 | Update(ctx context.Context, sid string, expired int64) (Store, error) 25 | // Delete a session store 26 | 
Delete(ctx context.Context, sid string) error 27 | // Use sid to replace old sid and return session store 28 | Refresh(ctx context.Context, oldsid, sid string, expired int64) (Store, error) 29 | // Close storage, release resources 30 | Close() error 31 | } 32 | 33 | // A session id storage operation 34 | type Store interface { 35 | // Get a session storage context 36 | Context() context.Context 37 | // Get the current session id 38 | SessionID() string 39 | // Set session value, call save function to take effect 40 | Set(key string, value interface{}) 41 | // Get session value 42 | Get(key string) (interface{}, bool) 43 | // Delete session value, call save function to take effect 44 | Delete(key string) interface{} 45 | // Save session data 46 | Save() error 47 | // Clear all session data 48 | Flush() error 49 | } 50 | 51 | // Create a new session storage (memory) 52 | func NewMemoryStore() ManagerStore { 53 | mstore := &memoryStore{ 54 | ticker: time.NewTicker(time.Second), 55 | data: skipmap.NewString(), 56 | } 57 | 58 | go mstore.gc() 59 | return mstore 60 | } 61 | 62 | type dataItem struct { 63 | sid string 64 | expiredAt time.Time 65 | values map[string]interface{} 66 | } 67 | 68 | func newDataItem(sid string, values map[string]interface{}, expired int64) *dataItem { 69 | return &dataItem{ 70 | sid: sid, 71 | expiredAt: now().Add(time.Duration(expired) * time.Second), 72 | values: values, 73 | } 74 | } 75 | 76 | type memoryStore struct { 77 | ticker *time.Ticker 78 | data *skipmap.StringMap 79 | } 80 | 81 | func (s *memoryStore) gc() { 82 | for range s.ticker.C { 83 | s.data.Range(func(key string, value interface{}) bool { 84 | if item, ok := value.(*dataItem); ok && item.expiredAt.Before(now()) { 85 | s.data.Delete(key) 86 | } 87 | return true 88 | }) 89 | } 90 | } 91 | 92 | func (s *memoryStore) save(sid string, values map[string]interface{}, expired int64) { 93 | if dt, ok := s.data.Load(sid); ok { 94 | dt.(*dataItem).values = values 95 | return 96 | } 97 | 
98 | s.data.Store(sid, newDataItem(sid, values, expired)) 99 | } 100 | 101 | func (s *memoryStore) Check(ctx context.Context, sid string) (bool, error) { 102 | dt, ok := s.data.Load(sid) 103 | if !ok { 104 | return false, nil 105 | } 106 | 107 | if item, ok := dt.(*dataItem); ok && item.expiredAt.After(now()) { 108 | return true, nil 109 | } 110 | return false, nil 111 | } 112 | 113 | func (s *memoryStore) Create(ctx context.Context, sid string, expired int64) (Store, error) { 114 | return newStore(ctx, s, sid, expired, nil), nil 115 | } 116 | 117 | func (s *memoryStore) Update(ctx context.Context, sid string, expired int64) (Store, error) { 118 | dt, ok := s.data.Load(sid) 119 | if !ok { 120 | return newStore(ctx, s, sid, expired, nil), nil 121 | } 122 | 123 | item := dt.(*dataItem) 124 | item.expiredAt = now().Add(time.Duration(expired) * time.Second) 125 | s.data.Store(sid, item) 126 | return newStore(ctx, s, sid, expired, item.values), nil 127 | } 128 | 129 | func (s *memoryStore) delete(sid string) { 130 | s.data.Delete(sid) 131 | } 132 | 133 | func (s *memoryStore) Delete(_ context.Context, sid string) error { 134 | s.delete(sid) 135 | return nil 136 | } 137 | 138 | func (s *memoryStore) Refresh(ctx context.Context, oldsid, sid string, expired int64) (Store, error) { 139 | dt, ok := s.data.Load(oldsid) 140 | if !ok { 141 | return newStore(ctx, s, sid, expired, nil), nil 142 | } 143 | 144 | item := dt.(*dataItem) 145 | newItem := newDataItem(sid, item.values, expired) 146 | s.data.Store(sid, newItem) 147 | s.delete(oldsid) 148 | return newStore(ctx, s, sid, expired, newItem.values), nil 149 | } 150 | 151 | func (s *memoryStore) Close() error { 152 | s.ticker.Stop() 153 | return nil 154 | } 155 | 156 | func newStore(ctx context.Context, mstore *memoryStore, sid string, expired int64, values map[string]interface{}) *store { 157 | if values == nil { 158 | values = make(map[string]interface{}) 159 | } 160 | 161 | return &store{ 162 | mstore: mstore, 163 | ctx: 
ctx, 164 | sid: sid, 165 | expired: expired, 166 | values: values, 167 | } 168 | } 169 | 170 | type store struct { 171 | sync.RWMutex 172 | mstore *memoryStore 173 | ctx context.Context 174 | sid string 175 | expired int64 176 | values map[string]interface{} 177 | } 178 | 179 | func (s *store) Context() context.Context { 180 | return s.ctx 181 | } 182 | 183 | func (s *store) SessionID() string { 184 | return s.sid 185 | } 186 | 187 | func (s *store) Set(key string, value interface{}) { 188 | s.Lock() 189 | s.values[key] = value 190 | s.Unlock() 191 | } 192 | 193 | func (s *store) Get(key string) (interface{}, bool) { 194 | s.RLock() 195 | val, ok := s.values[key] 196 | s.RUnlock() 197 | return val, ok 198 | } 199 | 200 | func (s *store) Delete(key string) interface{} { 201 | s.RLock() 202 | v, ok := s.values[key] 203 | s.RUnlock() 204 | 205 | if ok { 206 | s.Lock() 207 | delete(s.values, key) 208 | s.Unlock() 209 | } 210 | return v 211 | } 212 | 213 | func (s *store) Flush() error { 214 | s.Lock() 215 | s.values = make(map[string]interface{}) 216 | s.Unlock() 217 | 218 | return s.Save() 219 | } 220 | 221 | func (s *store) Save() error { 222 | s.RLock() 223 | values := s.values 224 | s.RUnlock() 225 | 226 | s.mstore.save(s.sid, values, s.expired) 227 | return nil 228 | } 229 | -------------------------------------------------------------------------------- /store_test.go: -------------------------------------------------------------------------------- 1 | package session 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | "time" 7 | 8 | . 
"github.com/smartystreets/goconvey/convey"
9 | )
10 | 
11 | // testStore exercises the Store contract: Set/Save/Get, Delete of one key
12 | // leaving the others intact, and Flush removing everything.
13 | func testStore(store Store) {
14 | 	foo, ok := store.Get("foo")
15 | 	So(ok, ShouldBeFalse)
16 | 	So(foo, ShouldBeNil)
17 | 
18 | 	store.Set("foo", "bar")
19 | 	store.Set("foo2", "bar2")
20 | 	err := store.Save()
21 | 	So(err, ShouldBeNil)
22 | 
23 | 	foo, ok = store.Get("foo")
24 | 	So(ok, ShouldBeTrue)
25 | 	So(foo, ShouldEqual, "bar")
26 | 
27 | 	foo = store.Delete("foo")
28 | 	So(foo, ShouldEqual, "bar")
29 | 
30 | 	foo, ok = store.Get("foo")
31 | 	So(ok, ShouldBeFalse)
32 | 	So(foo, ShouldBeNil)
33 | 	// Deleting a key that is already gone yields nil.
34 | 	So(store.Delete("foo"), ShouldBeNil)
35 | 
36 | 	foo2, ok := store.Get("foo2")
37 | 	So(ok, ShouldBeTrue)
38 | 	So(foo2, ShouldEqual, "bar2")
39 | 
40 | 	err = store.Flush()
41 | 	So(err, ShouldBeNil)
42 | 
43 | 	foo2, ok = store.Get("foo2")
44 | 	So(ok, ShouldBeFalse)
45 | 	So(foo2, ShouldBeNil)
46 | }
47 | 
48 | func TestStore(t *testing.T) {
49 | 	mstore := NewMemoryStore()
50 | 	defer mstore.Close()
51 | 
52 | 	Convey("Test memory storage operation", t, func() {
53 | 		store, err := mstore.Create(context.Background(), "test_memory_store", 10)
54 | 		So(err, ShouldBeNil)
55 | 		So(store, ShouldNotBeNil)
56 | 		testStore(store)
57 | 	})
58 | }
59 | 
60 | // testManagerStore walks one session through Create/Save, Update, Refresh to
61 | // a new ID and Delete, checking that values survive and old IDs disappear.
62 | func testManagerStore(mstore ManagerStore) {
63 | 	sid := "test_manager_store"
64 | 	store, err := mstore.Create(context.Background(), sid, 10)
65 | 	So(store, ShouldNotBeNil)
66 | 	So(err, ShouldBeNil)
67 | 
68 | 	store.Set("foo", "bar")
69 | 	err = store.Save()
70 | 	So(err, ShouldBeNil)
71 | 
72 | 	store, err = mstore.Update(context.Background(), sid, 10)
73 | 	So(store, ShouldNotBeNil)
74 | 	So(err, ShouldBeNil)
75 | 
76 | 	foo, ok := store.Get("foo")
77 | 	So(ok, ShouldBeTrue)
78 | 	So(foo, ShouldEqual, "bar")
79 | 
80 | 	newsid := "test_manager_store2"
81 | 	store, err = mstore.Refresh(context.Background(), sid, newsid, 10)
82 | 	So(store, ShouldNotBeNil)
83 | 	So(err, ShouldBeNil)
84 | 	So(store.SessionID(), ShouldEqual, newsid)
85 | 
86 | 	foo, ok = store.Get("foo")
87 | 	So(ok, ShouldBeTrue)
88 | 	So(foo, ShouldEqual, "bar")
89 | 
90 | 	exists, err := mstore.Check(context.Background(), sid)
91 | 	So(exists, ShouldBeFalse)
92 | 	So(err, ShouldBeNil)
93 | 
94 | 	exists, err = mstore.Check(context.Background(), newsid)
95 | 	So(exists, ShouldBeTrue)
96 | 	So(err, ShouldBeNil)
97 | 
98 | 	err = mstore.Delete(context.Background(), newsid)
99 | 	So(err, ShouldBeNil)
100 | 
101 | 	exists, err = mstore.Check(context.Background(), newsid)
102 | 	So(exists, ShouldBeFalse)
103 | 	So(err, ShouldBeNil)
104 | }
105 | 
106 | func TestManagerMemoryStore(t *testing.T) {
107 | 	mstore := NewMemoryStore()
108 | 	defer mstore.Close()
109 | 
110 | 	Convey("Test memory-based storage management operations", t, func() {
111 | 		testManagerStore(mstore)
112 | 	})
113 | }
114 | 
115 | // testStoreWithExpired checks that a session saved with a 1s TTL is reported
116 | // as missing once that TTL has passed.
117 | func testStoreWithExpired(mstore ManagerStore) {
118 | 	sid := "test_store_expired"
119 | 	store, err := mstore.Create(context.Background(), sid, 1)
120 | 	So(store, ShouldNotBeNil)
121 | 	So(err, ShouldBeNil)
122 | 
123 | 	store.Set("foo", "bar")
124 | 	err = store.Save()
125 | 	So(err, ShouldBeNil)
126 | 
127 | 	store, err = mstore.Update(context.Background(), sid, 1)
128 | 	So(store, ShouldNotBeNil)
129 | 	So(err, ShouldBeNil)
130 | 
131 | 	foo, ok := store.Get("foo")
132 | 	So(foo, ShouldEqual, "bar")
133 | 	So(ok, ShouldBeTrue)
134 | 
135 | 	time.Sleep(time.Second * 3)
136 | 
137 | 	exists, err := mstore.Check(context.Background(), sid)
138 | 	So(err, ShouldBeNil)
139 | 	So(exists, ShouldBeFalse)
140 | }
141 | 
142 | func TestMemoryStoreWithExpired(t *testing.T) {
143 | 	mstore := NewMemoryStore()
144 | 	defer mstore.Close()
145 | 
146 | 	Convey("Test memory store expiration", t, func() {
147 | 		testStoreWithExpired(mstore)
148 | 	})
149 | }
150 | 
--------------------------------------------------------------------------------
/util.go:
--------------------------------------------------------------------------------
1 | package session
2 | 
3 | import (
4 | 	"crypto/rand"
5 | 	"encoding/hex"
6 | 	"io"
7 | )
8 | 
9 | // newUUID returns a random version-4 (RFC 4122 variant) UUID string in the
10 | // canonical 8-4-4-4-12 hex form. Reference: https://github.com/google/uuid
11 | func newUUID() string {
12 | 	var buf [16]byte
13 | 	// The read error is deliberately ignored; crypto/rand is not expected to
14 | 	// fail. If it did, the unread tail of buf would stay zero.
15 | 	_, _ = io.ReadFull(rand.Reader, buf[:])
16 | 	buf[6] = (buf[6] & 0x0f) | 0x40 // version 4
17 | 	buf[8] = (buf[8] & 0x3f) | 0x80 // variant bits 10xx (RFC 4122)
18 | 
19 | 	dst := make([]byte, 36)
20 | 	hex.Encode(dst, buf[:4])
21 | 	dst[8] = '-'
22 | 	hex.Encode(dst[9:13], buf[4:6])
23 | 	dst[13] = '-'
24 | 	hex.Encode(dst[14:18], buf[6:8])
25 | 	dst[18] = '-'
26 | 	hex.Encode(dst[19:23], buf[8:10])
27 | 	dst[23] = '-'
28 | 	hex.Encode(dst[24:], buf[10:])
29 | 
30 | 	return string(dst)
31 | }
32 | 
--------------------------------------------------------------------------------