├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── csrf.go ├── csrf_test.go ├── helpers.go ├── helpers_test.go ├── options.go ├── options_test.go ├── store.go └── store_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | 6 | # Folders 7 | _obj 8 | _test 9 | 10 | # Architecture specific extensions/prefixes 11 | *.[568vq] 12 | [568vq].out 13 | 14 | *.cgo1.go 15 | *.cgo2.c 16 | _cgo_defun.c 17 | _cgo_gotypes.go 18 | _cgo_export.* 19 | 20 | _testmain.go 21 | 22 | *.exe 23 | *.test 24 | *.prof 25 | 26 | *.DS_Store 27 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | sudo: false 3 | 4 | matrix: 5 | include: 6 | - go: 1.2 7 | - go: 1.3 8 | - go: 1.4 9 | - go: 1.5 10 | - go: tip 11 | 12 | install: 13 | - go get golang.org/x/tools/cmd/vet 14 | 15 | script: 16 | - go get -t -v ./... 17 | - diff -u <(echo -n) <(gofmt -d .) 18 | - go tool vet . 19 | - go test -v -race ./... 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015, Matt Silverlock (matt@eatsleeprepeat.net) All rights 2 | reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, 5 | are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation and/or 12 | other materials provided with the distribution. 13 | 14 | 3. 
Neither the name of the copyright holder nor the names of its contributors 15 | may be used to endorse or promote products derived from this software without 16 | specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 19 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 20 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 22 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 23 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 24 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 25 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 26 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 27 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # goji/csrf 2 | [](https://godoc.org/github.com/goji/csrf) [](https://travis-ci.org/goji/csrf) 3 | 4 | **Goji v2 users**: A new version with support for [Goji 5 | v2](https://github.com/goji/goji) and any other `context.Context` aware 6 | muxes/applications (i.e. not just Goji!) is available in the 7 | [goji/ctx-csrf](https://github.com/goji/ctx-csrf) repo. 8 | 9 | goji/csrf is a HTTP middleware library that provides [cross-site request 10 | forgery](http://blog.codinghorror.com/preventing-csrf-and-xsrf-attacks/) (CSRF) 11 | protection. It includes: 12 | 13 | * The `csrf.Protect` middleware/handler that can be used with `goji.Use` to 14 | provide CSRF protection on routes attached to a router or a sub-router. 
15 | * A `csrf.Token` function that provides the token to pass into your response, 16 | whether that be a HTML form or a JSON response body. 17 | * ... and a `csrf.TemplateField` helper that you can pass into your `html/template` 18 | templates to replace a `{{ .csrfField }}` template tag with a hidden input 19 | field. 20 | 21 | This library is designed to work with the [Goji](https://github.com/zenazn/goji) 22 | micro-framework, which is a simple web framework for Go that is broadly 23 | compatible with other parts of the Go ecosystem. It makes use of Goji's `web.C` 24 | request context, which doesn't rely on a global map, and is therefore safe to 25 | attach to your top-level router (if you so wish). 26 | 27 | The library also assumes HTTPS by default: sending cookies over vanilla HTTP 28 | is risky and you're likely to get hurt. 29 | 30 | ## Examples 31 | 32 | goji/csrf is easy to use: add the middleware to your stack with the below: 33 | 34 | ```go 35 | goji.Use(csrf.Protect([]byte("32-byte-long-auth-key"))) 36 | ``` 37 | 38 | ... and then collect the token with `csrf.Token(c, r)` before passing it to the 39 | template, JSON body or HTTP header (you pick!). goji/csrf inspects HTTP headers 40 | (first) and the form body (second) on subsequent POST/PUT/PATCH/DELETE/etc. 41 | requests for the token. 42 | 43 | ### HTML Forms 44 | 45 | Here's the common use-case: HTML forms you want to provide CSRF protection for, 46 | in order to protect against malicious POST requests: 47 | 48 | ```go 49 | package main 50 | 51 | import ( 52 | "html/template" 53 | "net/http" 54 | 55 | "github.com/goji/csrf" 56 | "github.com/zenazn/goji" 57 | ) 58 | 59 | func main() { 60 | // Add the middleware to your router. 61 | goji.Use(csrf.Protect([]byte("32-byte-long-auth-key"))) 62 | goji.Get("/signup", ShowSignupForm) 63 | // POST requests without a valid token will return a HTTP 403 Forbidden.
64 | goji.Post("/signup/post", SubmitSignupForm) 65 | 66 | goji.Serve() 67 | } 68 | 69 | func ShowSignupForm(c web.C, w http.ResponseWriter, r *http.Request) { 70 | // signup_form.tmpl just needs a {{ .csrfField }} template tag for 71 | // csrf.TemplateField to inject the CSRF token into. Easy! 72 | t.ExecuteTemplate(w, "signup_form.tmpl", map[string]interface{}{ 73 | csrf.TemplateTag: csrf.TemplateField(c, r), 74 | }) 75 | // We could also retrieve the token directly from csrf.Token(c, r) and 76 | // set it in the response header - w.Header().Set("X-CSRF-Token", token) 77 | // This is useful if you're sending JSON to clients or a front-end JavaScript 78 | // framework. 79 | } 80 | 81 | func SubmitSignupForm(c web.C, w http.ResponseWriter, r *http.Request) { 82 | // We can trust that requests making it this far have satisfied 83 | // our CSRF protection requirements. 84 | } 85 | ``` 86 | 87 | ### JSON Responses 88 | 89 | This approach is useful if you're using a front-end JavaScript framework like 90 | Ember or Angular, or are providing a JSON API. 91 | 92 | We'll also look at applying selective CSRF protection using Goji's sub-routers, 93 | as we don't handle any POST/PUT/DELETE requests with our top-level router. 94 | 95 | ```go 96 | package main 97 | 98 | import ( 99 | "github.com/goji/csrf" 100 | "github.com/zenazn/goji/graceful" 101 | "github.com/zenazn/goji/web" 102 | ) 103 | 104 | func main() { 105 | r := web.New() 106 | // Our top-level router doesn't need CSRF protection: it's simple. 107 | r.Get("/", ShowIndex) 108 | 109 | s := web.New() 110 | r.Handle("/api/*", s) 111 | // ... but our /api/* routes do, so we add it to the sub-router only.
112 | s.Use(csrf.Protect([]byte("32-byte-long-auth-key"))) 113 | 114 | s.Get("/api/user/:id", GetUser) 115 | s.Post("/api/user", PostUser) 116 | 117 | graceful.ListenAndServe(":8000", r) 118 | } 119 | 120 | func GetUser(c web.C, w http.ResponseWriter, r *http.Request) { 121 | // Authenticate the request, get the :id from the route params, 122 | // and fetch the user from the DB, etc. 123 | 124 | // Get the token and pass it in the CSRF header. Our JSON-speaking client 125 | // or JavaScript framework can now read the header and return the token 126 | // in its own "X-CSRF-Token" request header on the subsequent POST. 127 | w.Header().Set("X-CSRF-Token", csrf.Token(c, r)) 128 | b, err := json.Marshal(user) 129 | if err != nil { 130 | http.Error(...) 131 | return 132 | } 133 | 134 | w.Write(b) 135 | } 136 | ``` 137 | 138 | ### Setting Options 139 | 140 | What about providing your own error handler and changing the HTTP header the 141 | package inspects on requests? (e.g. an existing API you're porting to Go). Well, 142 | goji/csrf provides options for changing these as you see fit: 143 | 144 | ```go 145 | func main() { 146 | CSRF := csrf.Protect( 147 | []byte("a-32-byte-long-key-goes-here"), 148 | csrf.RequestHeader("Authenticity-Token"), 149 | csrf.FieldName("authenticity_token"), 150 | // Note that csrf.ErrorHandler takes a Goji web.Handler type, else 151 | // your error handler can't retrieve the error reason from the context. 152 | // The signature `func UnauthHandler(c web.C, w http.ResponseWriter, r *http.Request)` 153 | // is a web.Handler, and the simplest to use if you'd like to serve 154 | // "pretty" error pages (who doesn't?). 155 | csrf.ErrorHandler(web.HandlerFunc(serverError(403))), 156 | ) 157 | 158 | goji.Use(CSRF) 159 | goji.Get("/signup", GetSignupForm) 160 | goji.Post("/signup", PostSignupForm) 161 | 162 | goji.Serve() 163 | } 164 | ``` 165 | 166 | Not too bad, right?
167 | 168 | If there's something you're confused about or a feature you would like to see 169 | added, open an issue with your code so far. 170 | 171 | ## Design Notes 172 | 173 | Getting CSRF protection right is important, so here's some background: 174 | 175 | * This library generates unique-per-request (masked) tokens as a mitigation 176 | against the [BREACH attack](http://breachattack.com/). 177 | * The 'base' (unmasked) token is stored in the session, which means that 178 | multiple browser tabs won't cause a user problems as their per-request token 179 | is compared with the base token. 180 | * Operates on a "whitelist only" approach where safe (non-mutating) HTTP methods 181 | (GET, HEAD, OPTIONS, TRACE) are the *only* methods where token validation is not 182 | enforced. 183 | * The design is based on the battle-tested 184 | [Django](https://docs.djangoproject.com/en/1.8/ref/csrf/) and [Ruby on 185 | Rails](http://api.rubyonrails.org/classes/ActionController/RequestForgeryProtection.html) 186 | approaches. 187 | * Cookies are authenticated and based on the [securecookie](https://github.com/gorilla/securecookie) 188 | library. They're also Secure (issued over HTTPS only) and are HttpOnly 189 | by default, because sane defaults are important. 190 | * Go's `crypto/rand` library is used to generate the 32 byte (256 bit) tokens 191 | and the one-time-pad used for masking them. 192 | 193 | This library does not seek to be adventurous. 194 | 195 | ## License 196 | 197 | BSD licensed. See the LICENSE file for details. 198 | 199 | -------------------------------------------------------------------------------- /csrf.go: -------------------------------------------------------------------------------- 1 | // Package csrf (goji/csrf) provides Cross Site Request Forgery 2 | // protection middleware for the Goji microframework (https://goji.io). 
3 | package csrf 4 | 5 | import ( 6 | "errors" 7 | "fmt" 8 | "net/http" 9 | "net/url" 10 | 11 | "github.com/gorilla/securecookie" 12 | "github.com/zenazn/goji/web" 13 | ) 14 | 15 | // CSRF token length in bytes. 16 | const tokenLength = 32 17 | 18 | // Context/session keys & prefixes 19 | const ( 20 | tokenKey string = "goji.csrf.Token" 21 | formKey string = "goji.csrf.Form" 22 | errorKey string = "goji.csrf.Error" 23 | cookieName string = "_goji_csrf" 24 | errorPrefix string = "goji/csrf: " 25 | ) 26 | 27 | var ( 28 | // The name value used in form fields. 29 | fieldName = tokenKey 30 | // The default HTTP request header to inspect 31 | headerName = "X-CSRF-Token" 32 | // Idempotent (safe) methods as defined by RFC7231 section 4.2.2. 33 | safeMethods = []string{"GET", "HEAD", "OPTIONS", "TRACE"} 34 | ) 35 | 36 | // TemplateTag provides a default template tag - e.g. {{ .csrfField }} - for use 37 | // with the TemplateField function. 38 | var TemplateTag = "csrfField" 39 | 40 | var ( 41 | // ErrNoReferer is returned when a HTTPS request provides an empty Referer 42 | // header. 43 | ErrNoReferer = errors.New("referer not supplied") 44 | // ErrBadReferer is returned when the scheme & host in the URL do not match 45 | // the supplied Referer header. 46 | ErrBadReferer = errors.New("referer invalid") 47 | // ErrNoToken is returned if no CSRF token is supplied in the request. 48 | ErrNoToken = errors.New("CSRF token not found in request") 49 | // ErrBadToken is returned if the CSRF token in the request does not match 50 | // the token in the session, or is otherwise malformed. 51 | ErrBadToken = errors.New("CSRF token invalid") 52 | ) 53 | 54 | type csrf struct { 55 | c *web.C 56 | h http.Handler 57 | sc *securecookie.SecureCookie 58 | st store 59 | opts options 60 | } 61 | 62 | // options contains the optional settings for the CSRF middleware. 
63 | type options struct { 64 | MaxAge int 65 | Domain string 66 | Path string 67 | // Note that the function and field names match the case of the associated 68 | // http.Cookie field instead of the "correct" HTTPOnly name that golint suggests. 69 | HttpOnly bool 70 | Secure bool 71 | RequestHeader string 72 | FieldName string 73 | ErrorHandler web.Handler 74 | CookieName string 75 | } 76 | 77 | // Protect is HTTP middleware that provides Cross-Site Request Forgery 78 | // protection. 79 | // 80 | // It securely generates a masked (unique-per-request) token that 81 | // can be embedded in the HTTP response (e.g. form field or HTTP header). 82 | // The original (unmasked) token is stored in the session, which is inaccessible 83 | // by an attacker (provided you are using HTTPS). Subsequent requests are 84 | // expected to include this token, which is compared against the session token. 85 | // Requests that do not provide a matching token are served with a HTTP 403 86 | // 'Forbidden' error response. 87 | // 88 | // Example: 89 | // package main 90 | // 91 | // import ( 92 | // "github.com/goji/csrf" 93 | // "github.com/zenazn/goji" 94 | // ) 95 | // 96 | // func main() { 97 | // // Add the middleware to your router. 98 | // goji.Use(csrf.Protect([]byte("32-byte-long-auth-key"))) 99 | // goji.Get("/signup", GetSignupForm) 100 | // // POST requests without a valid token will return a HTTP 403 Forbidden. 101 | // goji.Post("/signup/post", PostSignupForm) 102 | // 103 | // goji.Serve() 104 | // } 105 | // 106 | // func GetSignupForm(c web.C, w http.ResponseWriter, r *http.Request) { 107 | // // signup_form.tmpl just needs a {{ .csrfField }} template tag for 108 | // // csrf.TemplateField to inject the CSRF token into. Easy! 
109 | // t.ExecuteTemplate(w, "signup_form.tmpl", map[string]interface{ 110 | // csrf.TemplateTag: csrf.TemplateField(c, r), 111 | // }) 112 | // // We could also retrieve the token directly from csrf.Token(c, r) and 113 | // // set it in the request header - w.Header.Set("X-CSRF-Token", token) 114 | // // This is useful if your sending JSON to clients or a front-end JavaScript 115 | // // framework. 116 | // } 117 | // 118 | func Protect(authKey []byte, opts ...Option) func(*web.C, http.Handler) http.Handler { 119 | return func(c *web.C, h http.Handler) http.Handler { 120 | cs := parseOptions(h, opts...) 121 | 122 | // Set the defaults if no options have been specified 123 | if cs.opts.ErrorHandler == nil { 124 | cs.opts.ErrorHandler = web.HandlerFunc(unauthorizedHandler) 125 | } 126 | 127 | if cs.opts.MaxAge < 1 { 128 | // Default of 12 hours 129 | cs.opts.MaxAge = 3600 * 12 130 | } 131 | 132 | if cs.opts.FieldName == "" { 133 | cs.opts.FieldName = fieldName 134 | } 135 | 136 | if cs.opts.CookieName == "" { 137 | cs.opts.CookieName = cookieName 138 | } 139 | 140 | if cs.opts.RequestHeader == "" { 141 | cs.opts.RequestHeader = headerName 142 | } 143 | 144 | // Create an authenticated securecookie instance. 145 | if cs.sc == nil { 146 | cs.sc = securecookie.New(authKey, nil) 147 | // Use JSON serialization (faster than one-off gob encoding) 148 | cs.sc.SetSerializer(securecookie.JSONEncoder{}) 149 | // Set the MaxAge of the underlying securecookie. 150 | cs.sc.MaxAge(cs.opts.MaxAge) 151 | } 152 | 153 | if cs.st == nil { 154 | // Default to the cookieStore 155 | cs.st = &cookieStore{ 156 | name: cs.opts.CookieName, 157 | maxAge: cs.opts.MaxAge, 158 | secure: cs.opts.Secure, 159 | httpOnly: cs.opts.HttpOnly, 160 | path: cs.opts.Path, 161 | domain: cs.opts.Domain, 162 | sc: cs.sc, 163 | } 164 | } 165 | 166 | // Initialize Goji's request context 167 | cs.c = c 168 | 169 | return *cs 170 | } 171 | } 172 | 173 | // Implements http.Handler for the csrf type. 
174 | func (cs csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) { 175 | // Create our request context if it does not already exist. 176 | if cs.c.Env == nil { 177 | cs.c.Env = make(map[interface{}]interface{}) 178 | } 179 | 180 | // Retrieve the token from the session. 181 | // An error represents either a cookie that failed HMAC validation 182 | // or that doesn't exist. 183 | realToken, err := cs.st.Get(cs.c, r) 184 | if err != nil || len(realToken) != tokenLength { 185 | // If there was an error retrieving the token, the token doesn't exist 186 | // yet, or it's the wrong length, generate a new token. 187 | // Note that the new token will (correctly) fail validation downstream 188 | // as it will no longer match the request token. 189 | realToken, err = generateRandomBytes(tokenLength) 190 | if err != nil { 191 | envError(cs.c, err) 192 | cs.opts.ErrorHandler.ServeHTTPC(*cs.c, w, r) 193 | return 194 | } 195 | 196 | // Save the new (real) token in the session store. 197 | err = cs.st.Save(realToken, w) 198 | if err != nil { 199 | envError(cs.c, err) 200 | cs.opts.ErrorHandler.ServeHTTPC(*cs.c, w, r) 201 | return 202 | } 203 | } 204 | 205 | // Save the masked token to the request context 206 | cs.c.Env[tokenKey] = mask(realToken, cs.c, r) 207 | // Save the field name to the request context 208 | cs.c.Env[formKey] = cs.opts.FieldName 209 | 210 | // HTTP methods not defined as idempotent ("safe") under RFC7231 require 211 | // inspection. 212 | if !contains(safeMethods, r.Method) { 213 | // Enforce an origin check for HTTPS connections. As per the Django CSRF 214 | // implementation (https://goo.gl/vKA7GE) the Referer header is almost 215 | // always present for same-domain HTTP requests. 216 | if r.URL.Scheme == "https" { 217 | // Fetch the Referer value. Call the error handler if it's empty or 218 | // otherwise fails to parse. 
219 | referer, err := url.Parse(r.Referer()) 220 | if err != nil || referer.String() == "" { 221 | envError(cs.c, ErrNoReferer) 222 | cs.opts.ErrorHandler.ServeHTTPC(*cs.c, w, r) 223 | return 224 | } 225 | 226 | if sameOrigin(r.URL, referer) == false { 227 | envError(cs.c, ErrBadReferer) 228 | cs.opts.ErrorHandler.ServeHTTPC(*cs.c, w, r) 229 | return 230 | } 231 | } 232 | 233 | // If the token returned from the session store is nil for non-idempotent 234 | // ("unsafe") methods, call the error handler. 235 | if realToken == nil { 236 | envError(cs.c, ErrNoToken) 237 | cs.opts.ErrorHandler.ServeHTTPC(*cs.c, w, r) 238 | return 239 | } 240 | 241 | // Retrieve the combined token (pad + masked) token and unmask it. 242 | requestToken := unmask(cs.requestToken(r)) 243 | 244 | // Compare the request token against the real token 245 | if !compareTokens(requestToken, realToken) { 246 | envError(cs.c, ErrBadToken) 247 | cs.opts.ErrorHandler.ServeHTTPC(*cs.c, w, r) 248 | return 249 | } 250 | 251 | } 252 | 253 | // Set the Vary: Cookie header to protect clients from caching the response. 254 | w.Header().Add("Vary", "Cookie") 255 | 256 | // Call the wrapped handler/router on success 257 | cs.h.ServeHTTP(w, r) 258 | } 259 | 260 | // unauthorizedhandler sets a HTTP 403 Forbidden status and writes the 261 | // CSRF failure reason to the response. 
262 | func unauthorizedHandler(c web.C, w http.ResponseWriter, r *http.Request) { 263 | http.Error(w, fmt.Sprintf("%s - %s", 264 | http.StatusText(http.StatusForbidden), FailureReason(c, r)), 265 | http.StatusForbidden) 266 | return 267 | } 268 | -------------------------------------------------------------------------------- /csrf_test.go: -------------------------------------------------------------------------------- 1 | package csrf 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "strings" 7 | "testing" 8 | 9 | "github.com/zenazn/goji/web" 10 | ) 11 | 12 | var testKey = []byte("keep-it-secret-keep-it-safe-----") 13 | var testHandler = web.HandlerFunc(func(c web.C, w http.ResponseWriter, r *http.Request) {}) 14 | 15 | // TestProtect is a high-level test to make sure the middleware returns the 16 | // wrapped handler with a 200 OK status. 17 | func TestProtect(t *testing.T) { 18 | s := web.New() 19 | s.Use(Protect(testKey)) 20 | 21 | s.Get("/", testHandler) 22 | 23 | r, err := http.NewRequest("GET", "/", nil) 24 | if err != nil { 25 | t.Fatal(err) 26 | } 27 | 28 | rr := httptest.NewRecorder() 29 | s.ServeHTTP(rr, r) 30 | 31 | if rr.Code != http.StatusOK { 32 | t.Fatalf("middleware failed to pass to the next handler: got %v want %v", 33 | rr.Code, http.StatusOK) 34 | } 35 | 36 | if rr.Header().Get("Set-Cookie") == "" { 37 | t.Fatalf("cookie not set: got %q", rr.Header().Get("Set-Cookie")) 38 | } 39 | } 40 | 41 | // TestMethods checks that safe methods (RFC 7231 section 4.2.1) return a 200 OK status and that unsafe 42 | // methods return a 403 Forbidden status when a CSRF cookie is not present.
43 | func TestMethods(t *testing.T) { 44 | s := web.New() 45 | s.Use(Protect(testKey)) 46 | 47 | s.Handle("/", web.HandlerFunc(func(c web.C, w http.ResponseWriter, r *http.Request) { 48 | })) 49 | 50 | // Test idempontent ("safe") methods 51 | for _, method := range safeMethods { 52 | r, err := http.NewRequest(method, "/", nil) 53 | if err != nil { 54 | t.Fatal(err) 55 | } 56 | 57 | rr := httptest.NewRecorder() 58 | s.ServeHTTP(rr, r) 59 | 60 | if rr.Code != http.StatusOK { 61 | t.Fatalf("middleware failed to pass to the next handler: got %v want %v", 62 | rr.Code, http.StatusOK) 63 | } 64 | 65 | if rr.Header().Get("Set-Cookie") == "" { 66 | t.Fatalf("cookie not set: got %q", rr.Header().Get("Set-Cookie")) 67 | } 68 | } 69 | 70 | // Test non-idempotent methods (should return a 403 without a cookie set) 71 | nonIdempotent := []string{"POST", "PUT", "DELETE", "PATCH"} 72 | for _, method := range nonIdempotent { 73 | r, err := http.NewRequest(method, "/", nil) 74 | if err != nil { 75 | t.Fatal(err) 76 | } 77 | 78 | rr := httptest.NewRecorder() 79 | s.ServeHTTP(rr, r) 80 | 81 | if rr.Code != http.StatusForbidden { 82 | t.Fatalf("middleware failed to pass to the next handler: got %v want %v", 83 | rr.Code, http.StatusOK) 84 | } 85 | 86 | if rr.Header().Get("Set-Cookie") == "" { 87 | t.Fatalf("cookie not set: got %q", rr.Header().Get("Set-Cookie")) 88 | } 89 | } 90 | 91 | } 92 | 93 | // Tests for failure if the cookie containing the session is removed from the 94 | // request. 95 | func TestNoCookie(t *testing.T) { 96 | 97 | } 98 | 99 | // TestBadCookie tests for failure when a cookie header is modified (malformed). 100 | func TestBadCookie(t *testing.T) { 101 | s := web.New() 102 | CSRF := Protect(testKey) 103 | s.Use(CSRF) 104 | 105 | var token string 106 | s.Handle("/", web.HandlerFunc(func(c web.C, w http.ResponseWriter, r *http.Request) { 107 | token = Token(c, r) 108 | })) 109 | 110 | // Obtain a CSRF cookie via a GET request. 
111 | r, err := http.NewRequest("GET", "http://www.gorillatoolkit.org/", nil) 112 | if err != nil { 113 | t.Fatal(err) 114 | } 115 | 116 | rr := httptest.NewRecorder() 117 | s.ServeHTTP(rr, r) 118 | 119 | // POST the token back in the header. 120 | r, err = http.NewRequest("POST", "http://www.gorillatoolkit.org/", nil) 121 | if err != nil { 122 | t.Fatal(err) 123 | } 124 | 125 | // Replace the cookie prefix 126 | badHeader := strings.Replace("_csrfToken=", rr.Header().Get("Set-Cookie"), "_badCookie", -1) 127 | r.Header.Set("Cookie", badHeader) 128 | r.Header.Set("X-CSRF-Token", token) 129 | r.Header.Set("Referer", "http://www.gorillatoolkit.org/") 130 | 131 | rr = httptest.NewRecorder() 132 | s.ServeHTTP(rr, r) 133 | 134 | if rr.Code != http.StatusForbidden { 135 | t.Fatalf("middleware failed to reject a bad cookie: got %v want %v", 136 | rr.Code, http.StatusForbidden) 137 | } 138 | 139 | } 140 | 141 | // Responses should set a "Vary: Cookie" header to protect client/proxy caching. 142 | func TestVaryHeader(t *testing.T) { 143 | 144 | s := web.New() 145 | s.Use(Protect(testKey)) 146 | s.Get("/", testHandler) 147 | 148 | r, err := http.NewRequest("HEAD", "https://www.golang.org/", nil) 149 | if err != nil { 150 | t.Fatal(err) 151 | } 152 | 153 | rr := httptest.NewRecorder() 154 | s.ServeHTTP(rr, r) 155 | 156 | if rr.Code != http.StatusOK { 157 | t.Fatalf("middleware failed to pass to the next handler: got %v want %v", 158 | rr.Code, http.StatusOK) 159 | } 160 | 161 | if rr.Header().Get("Vary") != "Cookie" { 162 | t.Fatalf("vary header not set: got %q want %q", rr.Header().Get("Vary"), "Cookie") 163 | } 164 | } 165 | 166 | // Requests with no Referer header should fail. 
167 | func TestNoReferer(t *testing.T) { 168 | 169 | s := web.New() 170 | s.Use(Protect(testKey)) 171 | s.Handle("/", web.HandlerFunc(func(c web.C, w http.ResponseWriter, r *http.Request) {})) 172 | 173 | r, err := http.NewRequest("POST", "https://golang.org/", nil) 174 | if err != nil { 175 | t.Fatal(err) 176 | } 177 | 178 | rr := httptest.NewRecorder() 179 | s.ServeHTTP(rr, r) 180 | 181 | if rr.Code != http.StatusForbidden { 182 | t.Fatalf("middleware failed to pass to the next handler: got %v want %v", 183 | rr.Code, http.StatusForbidden) 184 | } 185 | } 186 | 187 | // TestBadReferer checks that HTTPS requests with a Referer that does not 188 | // match the request URL correctly fail CSRF validation. 189 | func TestBadReferer(t *testing.T) { 190 | 191 | s := web.New() 192 | CSRF := Protect(testKey) 193 | s.Use(CSRF) 194 | 195 | var token string 196 | s.Handle("/", web.HandlerFunc(func(c web.C, w http.ResponseWriter, r *http.Request) { 197 | token = Token(c, r) 198 | })) 199 | 200 | // Obtain a CSRF cookie via a GET request. 201 | r, err := http.NewRequest("GET", "https://www.gorillatoolkit.org/", nil) 202 | if err != nil { 203 | t.Fatal(err) 204 | } 205 | 206 | rr := httptest.NewRecorder() 207 | s.ServeHTTP(rr, r) 208 | 209 | // POST the token back in the header. 210 | r, err = http.NewRequest("POST", "https://www.gorillatoolkit.org/", nil) 211 | if err != nil { 212 | t.Fatal(err) 213 | } 214 | 215 | setCookie(rr, r) 216 | r.Header.Set("X-CSRF-Token", token) 217 | 218 | // Set a non-matching Referer header. 219 | r.Header.Set("Referer", "http://goji.io") 220 | 221 | rr = httptest.NewRecorder() 222 | s.ServeHTTP(rr, r) 223 | 224 | if rr.Code != http.StatusForbidden { 225 | t.Fatalf("middleware failed to pass to the next handler: got %v want %v", 226 | rr.Code, http.StatusForbidden) 227 | } 228 | } 229 | 230 | // Requests with a valid Referer should pass. 
231 | func TestWithReferer(t *testing.T) { 232 | s := web.New() 233 | CSRF := Protect(testKey) 234 | s.Use(CSRF) 235 | 236 | var token string 237 | s.Handle("/", web.HandlerFunc(func(c web.C, w http.ResponseWriter, r *http.Request) { 238 | token = Token(c, r) 239 | })) 240 | 241 | // Obtain a CSRF cookie via a GET request. 242 | r, err := http.NewRequest("GET", "http://www.gorillatoolkit.org/", nil) 243 | if err != nil { 244 | t.Fatal(err) 245 | } 246 | 247 | rr := httptest.NewRecorder() 248 | s.ServeHTTP(rr, r) 249 | 250 | // POST the token back in the header. 251 | r, err = http.NewRequest("POST", "http://www.gorillatoolkit.org/", nil) 252 | if err != nil { 253 | t.Fatal(err) 254 | } 255 | 256 | setCookie(rr, r) 257 | r.Header.Set("X-CSRF-Token", token) 258 | r.Header.Set("Referer", "http://www.gorillatoolkit.org/") 259 | 260 | rr = httptest.NewRecorder() 261 | s.ServeHTTP(rr, r) 262 | 263 | if rr.Code != http.StatusOK { 264 | t.Fatalf("middleware failed to pass to the next handler: got %v want %v", 265 | rr.Code, http.StatusOK) 266 | } 267 | } 268 | 269 | // TestFormField should test that a token in the HTTP header takes precedence over a 270 | // token in the form field (requestToken checks the header first). TODO: not yet implemented. 271 | func TestFormField(t *testing.T) { 272 | 273 | } 274 | // setCookie copies the Set-Cookie header recorded in rr into r's Cookie header. 275 | func setCookie(rr *httptest.ResponseRecorder, r *http.Request) { 276 | r.Header.Set("Cookie", rr.Header().Get("Set-Cookie")) 277 | } 278 | -------------------------------------------------------------------------------- /helpers.go: -------------------------------------------------------------------------------- 1 | package csrf 2 | 3 | import ( 4 | "crypto/rand" 5 | "crypto/subtle" 6 | "encoding/base64" 7 | "fmt" 8 | "html/template" 9 | "net/http" 10 | "net/url" 11 | 12 | "github.com/zenazn/goji/web" 13 | ) 14 | 15 | // Token returns a masked CSRF token ready for passing into HTML template or 16 | // a JSON response body.
An empty token will be returned if the middleware 17 | // has not been applied (which will fail subsequent validation). 18 | func Token(c web.C, r *http.Request) string { 19 | if maskedToken, ok := c.Env[tokenKey].(string); ok { 20 | return maskedToken 21 | } 22 | 23 | return "" 24 | } 25 | 26 | // FailureReason makes CSRF validation errors available in Goji's request 27 | // context. 28 | // This is useful when you want to log the cause of the error or report it to 29 | // client. 30 | func FailureReason(c web.C, r *http.Request) error { 31 | if err, ok := c.Env[errorKey].(error); ok { 32 | return err 33 | } 34 | 35 | return nil 36 | } 37 | 38 | // TemplateField is a template helper for html/template that provides an field 39 | // populated with a CSRF token. 40 | // 41 | // Example: 42 | // 43 | // // The following tag in our form.tmpl template: 44 | // {{ .csrfField }} 45 | // 46 | // // ... becomes: 47 | // 48 | // 49 | func TemplateField(c web.C, r *http.Request) template.HTML { 50 | fragment := fmt.Sprintf(``, 51 | c.Env[formKey], Token(c, r)) 52 | 53 | return template.HTML(fragment) 54 | } 55 | 56 | // mask returns a unique-per-request token to mitigate the BREACH attack 57 | // as per http://breachattack.com/#mitigations 58 | // 59 | // The token is generated by XOR'ing a one-time-pad and the base (session) CSRF 60 | // token and returning them together as a 64-byte slice. This effectively 61 | // randomises the token on a per-request basis without breaking multiple browser 62 | // tabs/windows. 63 | func mask(realToken []byte, c *web.C, r *http.Request) string { 64 | otp, err := generateRandomBytes(tokenLength) 65 | if err != nil { 66 | return "" 67 | } 68 | 69 | // XOR the OTP with the real token to generate a masked token. Append the 70 | // OTP to the front of the masked token to allow unmasking in the subsequent 71 | // request. 
72 | return base64.StdEncoding.EncodeToString(append(otp, xorToken(otp, realToken)...)) 73 | } 74 | 75 | // unmask splits the issued token (one-time-pad + masked token) and returns the 76 | // unmasked request token for comparison. 77 | func unmask(issued []byte) []byte { 78 | // Issued tokens are always masked and combined with the pad. 79 | if len(issued) != tokenLength*2 { 80 | return nil 81 | } 82 | 83 | // We now know the length of the byte slice. 84 | otp := issued[tokenLength:] 85 | masked := issued[:tokenLength] 86 | 87 | // Unmask the token by XOR'ing it against the OTP used to mask it. 88 | return xorToken(otp, masked) 89 | } 90 | 91 | // requestToken returns the issued token (pad + masked token) from the HTTP POST 92 | // body or HTTP header. It will return nil if the token fails to decode. 93 | func (cs *csrf) requestToken(r *http.Request) []byte { 94 | // 1. Check the HTTP header first. 95 | issued := r.Header.Get(cs.opts.RequestHeader) 96 | 97 | // 2. Fall back to the POST (form) value. 98 | if issued == "" { 99 | issued = r.PostFormValue(cs.opts.FieldName) 100 | } 101 | 102 | // 3. Finally, fall back to the multipart form (if set). 103 | if issued == "" && r.MultipartForm != nil { 104 | vals := r.MultipartForm.Value[cs.opts.FieldName] 105 | 106 | if len(vals) > 0 { 107 | issued = vals[0] 108 | } 109 | } 110 | 111 | // Decode the "issued" (pad + masked) token sent in the request. Return a 112 | // nil byte slice on a decoding error (this will fail upstream). 113 | decoded, err := base64.StdEncoding.DecodeString(issued) 114 | if err != nil { 115 | return nil 116 | } 117 | 118 | return decoded 119 | } 120 | 121 | // generateRandomBytes returns securely generated random bytes. 122 | // It will return an error if the system's secure random number generator 123 | // fails to function correctly. 
func generateRandomBytes(n int) ([]byte, error) {
	b := make([]byte, n)
	// rand.Read returns a nil error only if len(b) bytes were read.
	if _, err := rand.Read(b); err != nil {
		return nil, err
	}

	return b, nil
}

// sameOrigin returns true if URLs a and b share the same origin. The same
// origin is defined as host (which includes the port) and scheme.
func sameOrigin(a, b *url.URL) bool {
	return a.Scheme == b.Scheme && a.Host == b.Host
}

// compareTokens securely (constant-time) compares the unmasked token from the
// request against the real token from the session. Tokens of different
// lengths never match.
func compareTokens(a, b []byte) bool {
	// This is required as subtle.ConstantTimeCompare does not check for equal
	// lengths in Go versions prior to 1.3.
	if len(a) != len(b) {
		return false
	}

	return subtle.ConstantTimeCompare(a, b) == 1
}

// xorToken XORs tokens ([]byte) to provide unique-per-request CSRF tokens. It
// will return a masked token if the base token is XOR'ed with a one-time-pad.
// An unmasked token will be returned if a masked token is XOR'ed with the
// one-time-pad used to mask it.
//
// The result has the length of the shorter input; trailing bytes of the longer
// input are ignored. The result is never nil, though it may be empty.
func xorToken(a, b []byte) []byte {
	n := len(a)
	if len(b) < n {
		n = len(b)
	}

	res := make([]byte, n)
	for i := 0; i < n; i++ {
		res[i] = a[i] ^ b[i]
	}

	return res
}

// contains reports whether s exists in vals - e.g. whether a HTTP method
// exists in a list of safe methods.
func contains(vals []string, s string) bool {
	for _, v := range vals {
		if v == s {
			return true
		}
	}

	return false
}

// envError stores a CSRF error in the request context.
186 | func envError(c *web.C, err error) { 187 | c.Env[errorKey] = err 188 | } 189 | -------------------------------------------------------------------------------- /helpers_test.go: -------------------------------------------------------------------------------- 1 | package csrf 2 | 3 | import ( 4 | "bytes" 5 | "crypto/rand" 6 | "encoding/base64" 7 | "fmt" 8 | "io" 9 | "mime/multipart" 10 | "net/http" 11 | "net/http/httptest" 12 | "net/url" 13 | "strings" 14 | "testing" 15 | "text/template" 16 | 17 | "github.com/zenazn/goji/web" 18 | ) 19 | 20 | var testTemplate = ` 21 | 22 |
23 | 26 | 27 | 28 | ` 29 | var testFieldName = "custom_csrf_field_name" 30 | var testTemplateField = `` 31 | 32 | // Test that our form helpers correctly inject a token into the response body. NOTE(review): in this copy the testTemplate raw string above contains only line-number residue and testTemplateField is an empty literal, so the HTML these fixtures held appears to have been lost; confirm against the upstream helpers_test.go before relying on these tests. Inside the handler closure below, t is re-declared as the *template.Template (shadowing the *testing.T) and the error returned by t.Execute is discarded. 33 | func TestFormToken(t *testing.T) { 34 | s := web.New() 35 | s.Use(Protect(testKey)) 36 | 37 | // Make the token available outside of the handler for comparison. 38 | var token string 39 | s.Get("/", web.HandlerFunc(func(c web.C, w http.ResponseWriter, r *http.Request) { 40 | token = Token(c, r) 41 | t := template.Must((template.New("base").Parse(testTemplate))) 42 | t.Execute(w, map[string]interface{}{ 43 | TemplateTag: TemplateField(c, r), 44 | }) 45 | })) 46 | 47 | r, err := http.NewRequest("GET", "/", nil) 48 | if err != nil { 49 | t.Fatal(err) 50 | } 51 | 52 | rr := httptest.NewRecorder() 53 | s.ServeHTTP(rr, r) 54 | 55 | if rr.Code != http.StatusOK { 56 | t.Fatalf("middleware failed to pass to the next handler: got %v want %v", 57 | rr.Code, http.StatusOK) 58 | } 59 | 60 | if len(token) != base64.StdEncoding.EncodedLen(tokenLength*2) { 61 | t.Fatalf("token length invalid: got %v want %v", len(token), base64.StdEncoding.EncodedLen(tokenLength*2)) 62 | } 63 | 64 | if !strings.Contains(rr.Body.String(), token) { 65 | t.Fatalf("token not in response body: got %v want %v", rr.Body.String(), token) 66 | } 67 | } 68 | 69 | // Test that we can extract a CSRF token from a multipart form. 70 | func TestMultipartFormToken(t *testing.T) { 71 | s := web.New() 72 | s.Use(Protect(testKey)) 73 | 74 | // Make the token available outside of the handler for comparison.
75 | var token string 76 | s.Handle("/", web.HandlerFunc(func(c web.C, w http.ResponseWriter, r *http.Request) { 77 | token = Token(c, r) 78 | t := template.Must((template.New("base").Parse(testTemplate))) 79 | t.Execute(w, map[string]interface{}{ 80 | TemplateTag: TemplateField(c, r), 81 | }) 82 | })) 83 | 84 | r, err := http.NewRequest("GET", "/", nil) 85 | if err != nil { 86 | t.Fatal(err) 87 | } 88 | 89 | rr := httptest.NewRecorder() 90 | s.ServeHTTP(rr, r) 91 | 92 | // Set up our multipart form. NOTE(review): the handler above also shadows t with the template and discards t.Execute's error, and the error returned by wr.Write below is not checked. 93 | var b bytes.Buffer 94 | mp := multipart.NewWriter(&b) 95 | wr, err := mp.CreateFormField(fieldName) 96 | if err != nil { 97 | t.Fatal(err) 98 | } 99 | 100 | wr.Write([]byte(token)) 101 | mp.Close() 102 | 103 | r, err = http.NewRequest("POST", "/", &b) 104 | if err != nil { 105 | t.Fatal(err) 106 | } 107 | 108 | // Add the multipart header. 109 | r.Header.Set("Content-Type", mp.FormDataContentType()) 110 | 111 | // Send back the issued cookie. 112 | setCookie(rr, r) 113 | 114 | rr = httptest.NewRecorder() 115 | s.ServeHTTP(rr, r) 116 | 117 | if rr.Code != http.StatusOK { 118 | t.Fatalf("middleware failed to pass to the next handler: got %v want %v", 119 | rr.Code, http.StatusOK) 120 | } 121 | 122 | if body := rr.Body.String(); !strings.Contains(body, token) { 123 | t.Fatalf("token not in response body: got %v want %v", body, token) 124 | } 125 | } 126 | 127 | // TestMaskUnmaskTokens tests that a token traversing the mask -> unmask process 128 | // is correctly unmasked to the original 'real' token.
129 | func TestMaskUnmaskTokens(t *testing.T) { 130 | t.Parallel() 131 | 132 | realToken, err := generateRandomBytes(tokenLength) 133 | if err != nil { 134 | t.Fatal(err) 135 | } 136 | 137 | issued := mask(realToken, nil, nil) 138 | decoded, err := base64.StdEncoding.DecodeString(issued) 139 | if err != nil { 140 | t.Fatal(err) 141 | } 142 | 143 | unmasked := unmask(decoded) 144 | if !compareTokens(unmasked, realToken) { 145 | t.Fatalf("tokens do not match: got %x want %x", unmasked, realToken) 146 | } 147 | } 148 | 149 | // Tests domains that should (or should not) return true for a 150 | // same-origin check. 151 | func TestSameOrigin(t *testing.T) { 152 | var originTests = []struct { 153 | o1 string 154 | o2 string 155 | expected bool 156 | }{ 157 | {"https://goji.io/", "https://goji.io", true}, 158 | {"http://golang.org/", "http://golang.org/pkg/net/http", true}, 159 | {"https://goji.io/", "http://goji.io", false}, 160 | {"https://goji.io:3333/", "http://goji.io:4444", false}, 161 | } 162 | 163 | for _, origins := range originTests { 164 | a, err := url.Parse(origins.o1) 165 | if err != nil { 166 | t.Fatal(err) 167 | } 168 | 169 | b, err := url.Parse(origins.o2) 170 | if err != nil { 171 | t.Fatal(err) 172 | } 173 | 174 | if sameOrigin(a, b) != origins.expected { 175 | t.Fatalf("origin checking failed: %v and %v, expected %v", 176 | origins.o1, origins.o2, origins.expected) 177 | } 178 | } 179 | } 180 | 181 | func TestXOR(t *testing.T) { 182 | testTokens := []struct { 183 | a []byte 184 | b []byte 185 | expected []byte 186 | }{ 187 | {[]byte("goodbye"), []byte("hello"), []byte{15, 10, 3, 8, 13}}, 188 | {[]byte("gophers"), []byte("clojure"), []byte{4, 3, 31, 2, 16, 0, 22}}, 189 | {nil, []byte("requestToken"), nil}, 190 | } 191 | 192 | for _, token := range testTokens { 193 | if res := xorToken(token.a, token.b); res != nil { 194 | if bytes.Compare(res, token.expected) != 0 { 195 | t.Fatalf("xorBytes failed to return the expected result: got %v want %v", 196 | 
res, token.expected) 197 | } 198 | } 199 | } 200 | 201 | } 202 | 203 | // shortReader provides a broken implementation of io.Reader for testing. 204 | type shortReader struct{} 205 | 206 | func (sr shortReader) Read(p []byte) (int, error) { 207 | return len(p) % 2, io.ErrUnexpectedEOF 208 | } 209 | 210 | // TestGenerateRandomBytes tests the (extremely rare) case that crypto/rand does 211 | // not return the expected number of bytes. 212 | func TestGenerateRandomBytes(t *testing.T) { 213 | // Pioneered from https://github.com/justinas/nosurf 214 | original := rand.Reader 215 | rand.Reader = shortReader{} 216 | defer func() { 217 | rand.Reader = original 218 | }() 219 | 220 | b, err := generateRandomBytes(tokenLength) 221 | if err == nil { 222 | t.Fatalf("generateRandomBytes did not report a short read: only read %d bytes", len(b)) 223 | } 224 | } 225 | 226 | func TestTemplateField(t *testing.T) { 227 | s := web.New() 228 | CSRF := Protect( 229 | testKey, 230 | FieldName(testFieldName), 231 | ) 232 | s.Use(CSRF) 233 | 234 | var token string 235 | var customTemplateField string 236 | s.Handle("/", web.HandlerFunc(func(c web.C, w http.ResponseWriter, r *http.Request) { 237 | token = Token(c, r) 238 | customTemplateField = string(TemplateField(c, r)) 239 | })) 240 | 241 | r, err := http.NewRequest("GET", "/", nil) 242 | if err != nil { 243 | t.Fatal(err) 244 | } 245 | 246 | rr := httptest.NewRecorder() 247 | s.ServeHTTP(rr, r) 248 | 249 | expectedTemplateField := fmt.Sprintf(testTemplateField, testFieldName, token) 250 | 251 | if customTemplateField != expectedTemplateField { 252 | t.Fatalf("templateField not set correctly: got %v want %v", 253 | customTemplateField, expectedTemplateField) 254 | } 255 | } 256 | 257 | func TestCompareTokens(t *testing.T) { 258 | // Go's subtle.ConstantTimeCompare prior to 1.3 did not check for matching 259 | // lengths. 
260 | a := []byte("") 261 | b := []byte("an-actual-token") 262 | 263 | if v := compareTokens(a, b); v == true { 264 | t.Fatalf("compareTokens failed on different tokens: got %v want %v", v, !v) 265 | } 266 | } 267 | -------------------------------------------------------------------------------- /options.go: -------------------------------------------------------------------------------- 1 | package csrf 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/zenazn/goji/web" 7 | ) 8 | 9 | // Option describes a functional option for configuring the CSRF handler. 10 | type Option func(*csrf) error 11 | 12 | // MaxAge sets the maximum age (in seconds) of a CSRF token's underlying cookie. 13 | // Defaults to 12 hours. 14 | func MaxAge(age int) Option { 15 | return func(cs *csrf) error { 16 | cs.opts.MaxAge = age 17 | return nil 18 | } 19 | } 20 | 21 | // Domain sets the cookie domain. Defaults to the current domain of the request 22 | // only (recommended). 23 | // 24 | // This should be a hostname and not a URL. If set, the domain is treated as 25 | // being prefixed with a '.' - e.g. "example.com" becomes ".example.com" and 26 | // matches "www.example.com" and "secure.example.com". 27 | func Domain(domain string) Option { 28 | return func(cs *csrf) error { 29 | cs.opts.Domain = domain 30 | return nil 31 | } 32 | } 33 | 34 | // Path sets the cookie path. Defaults to the path the cookie was issued from 35 | // (recommended). 36 | // 37 | // This instructs clients to only respond with cookie for that path and its 38 | // subpaths - i.e. a cookie issued from "/register" would be included in requests 39 | // to "/register/step2" and "/register/submit". 40 | func Path(p string) Option { 41 | return func(cs *csrf) error { 42 | cs.opts.Path = p 43 | return nil 44 | } 45 | } 46 | 47 | // Secure sets the 'Secure' flag on the cookie. Defaults to true (recommended). 
48 | func Secure(s bool) Option { 49 | return func(cs *csrf) error { 50 | cs.opts.Secure = s 51 | return nil 52 | } 53 | } 54 | 55 | // HttpOnly sets the 'HttpOnly' flag on the cookie. Defaults to true (recommended). 56 | func HttpOnly(h bool) Option { 57 | return func(cs *csrf) error { 58 | // Note that the function and field names match the case of the 59 | // related http.Cookie field instead of the "correct" HTTPOnly name 60 | // that golint suggests. 61 | cs.opts.HttpOnly = h 62 | return nil 63 | } 64 | } 65 | 66 | // ErrorHandler allows you to change the handler called when CSRF request 67 | // processing encounters an invalid token or request. A typical use would be to 68 | // provide a handler that returns a static HTML file with a HTTP 403 status. By 69 | // default a HTTP 403 status and a plain text CSRF failure reason are served. 70 | // 71 | // Note that a custom error handler can also access the csrf.Failure(c, r) 72 | // function to retrieve the CSRF validation reason from Goji's request context. 73 | func ErrorHandler(h web.Handler) Option { 74 | return func(cs *csrf) error { 75 | cs.opts.ErrorHandler = h 76 | return nil 77 | } 78 | } 79 | 80 | // RequestHeader allows you to change the request header the CSRF middleware 81 | // inspects. The default is X-CSRF-Token. 82 | func RequestHeader(header string) Option { 83 | return func(cs *csrf) error { 84 | cs.opts.RequestHeader = header 85 | return nil 86 | } 87 | } 88 | 89 | // FieldName allows you to change the name value of the hidden field 90 | // generated by csrf.TemplateField. The default is {{ .csrfToken }} 91 | func FieldName(name string) Option { 92 | return func(cs *csrf) error { 93 | cs.opts.FieldName = name 94 | return nil 95 | } 96 | } 97 | 98 | // CookieName changes the name of the CSRF cookie issued to clients. 99 | // 100 | // Note that cookie names should not contain whitespace, commas, semicolons, 101 | // backslashes or control characters as per RFC6265. 
102 | func CookieName(name string) Option { 103 | return func(cs *csrf) error { 104 | cs.opts.CookieName = name 105 | return nil 106 | } 107 | } 108 | 109 | // setStore sets the store used by the CSRF middleware. 110 | // Note: this is private (for now) to allow for internal API changes. 111 | func setStore(s store) Option { 112 | return func(cs *csrf) error { 113 | cs.st = s 114 | return nil 115 | } 116 | } 117 | 118 | // parseOptions parses the supplied options functions and returns a configured 119 | // csrf handler. 120 | func parseOptions(h http.Handler, opts ...Option) *csrf { 121 | // Set the handler to call after processing. 122 | cs := &csrf{ 123 | h: h, 124 | } 125 | 126 | // Default to true. See Secure & HttpOnly function comments for rationale. 127 | // Set here to allow package users to override the default. 128 | cs.opts.Secure = true 129 | cs.opts.HttpOnly = true 130 | 131 | // Range over each options function and apply it 132 | // to our csrf type to configure it. Options functions are 133 | // applied in order, with any conflicting options overriding 134 | // earlier calls. 135 | for _, option := range opts { 136 | option(cs) 137 | } 138 | 139 | return cs 140 | } 141 | -------------------------------------------------------------------------------- /options_test.go: -------------------------------------------------------------------------------- 1 | package csrf 2 | 3 | import ( 4 | "net/http" 5 | "reflect" 6 | "testing" 7 | 8 | "github.com/zenazn/goji/web" 9 | ) 10 | 11 | // Tests that options functions are applied to the middleware. 
12 | func TestOptions(t *testing.T) { 13 | var h http.Handler 14 | 15 | age := 86400 16 | domain := "goji.io" 17 | path := "/forms/" 18 | header := "X-AUTH-TOKEN" 19 | field := "authenticity_token" 20 | errorHandler := unauthorizedHandler 21 | name := "_goji_goji_goji" 22 | 23 | testOpts := []Option{ 24 | MaxAge(age), 25 | Domain(domain), 26 | Path(path), 27 | HttpOnly(false), 28 | Secure(false), 29 | RequestHeader(header), 30 | FieldName(field), 31 | ErrorHandler(web.HandlerFunc(errorHandler)), 32 | CookieName(name), 33 | } 34 | 35 | // Parse our test options and check that they set the related struct fields. 36 | cs := parseOptions(h, testOpts...) 37 | 38 | if cs.opts.MaxAge != age { 39 | t.Errorf("MaxAge not set correctly: got %v want %v", cs.opts.MaxAge, age) 40 | } 41 | 42 | if cs.opts.Domain != domain { 43 | t.Errorf("Domain not set correctly: got %v want %v", cs.opts.Domain, domain) 44 | } 45 | 46 | if cs.opts.Path != path { 47 | t.Errorf("Path not set correctly: got %v want %v", cs.opts.Path, path) 48 | } 49 | 50 | if cs.opts.HttpOnly != false { 51 | t.Errorf("HttpOnly not set correctly: got %v want %v", cs.opts.HttpOnly, false) 52 | } 53 | 54 | if cs.opts.Secure != false { 55 | t.Errorf("Secure not set correctly: got %v want %v", cs.opts.Secure, false) 56 | } 57 | 58 | if cs.opts.RequestHeader != header { 59 | t.Errorf("RequestHeader not set correctly: got %v want %v", cs.opts.RequestHeader, header) 60 | } 61 | 62 | if cs.opts.FieldName != field { 63 | t.Errorf("FieldName not set correctly: got %v want %v", cs.opts.FieldName, field) 64 | } 65 | 66 | if !reflect.ValueOf(cs.opts.ErrorHandler).IsValid() { 67 | t.Errorf("ErrorHandler not set correctly: got %v want %v", 68 | reflect.ValueOf(cs.opts.ErrorHandler).IsValid(), reflect.ValueOf(errorHandler).IsValid()) 69 | } 70 | 71 | if cs.opts.CookieName != name { 72 | t.Errorf("CookieName not set correctly: got %v want %v", 73 | cs.opts.CookieName, name) 74 | } 75 | } 76 | 
-------------------------------------------------------------------------------- /store.go: -------------------------------------------------------------------------------- 1 | package csrf 2 | 3 | import ( 4 | "net/http" 5 | "time" 6 | 7 | "github.com/gorilla/securecookie" 8 | "github.com/zenazn/goji/web" 9 | ) 10 | 11 | // store represents the session storage used for CSRF tokens. 12 | type store interface { 13 | // Get returns the real CSRF token from the store. 14 | Get(c *web.C, r *http.Request) ([]byte, error) 15 | // Save stores the real CSRF token in the store and writes a 16 | // cookie to the http.ResponseWriter. 17 | // For non-cookie stores, the cookie should contain a unique (256 bit) ID 18 | // or key that references the token in the backend store. 19 | // csrf.GenerateRandomBytes is a helper function for generating secure IDs. 20 | Save(token []byte, w http.ResponseWriter) error 21 | } 22 | 23 | // cookieStore is a signed cookie session store for CSRF tokens. 24 | type cookieStore struct { 25 | name string 26 | maxAge int 27 | secure bool 28 | httpOnly bool 29 | path string 30 | domain string 31 | sc *securecookie.SecureCookie 32 | } 33 | 34 | // Get retrieves a CSRF token from the session cookie. It returns an empty token 35 | // if decoding fails (e.g. HMAC validation fails or the named cookie doesn't exist). 36 | func (cs *cookieStore) Get(c *web.C, r *http.Request) ([]byte, error) { 37 | // Retrieve the cookie from the request 38 | cookie, err := r.Cookie(cs.name) 39 | if err != nil { 40 | return nil, err 41 | } 42 | 43 | token := make([]byte, tokenLength) 44 | // Decode the HMAC authenticated cookie. 45 | err = cs.sc.Decode(cs.name, cookie.Value, &token) 46 | if err != nil { 47 | return nil, err 48 | } 49 | 50 | return token, nil 51 | } 52 | 53 | // Save stores the CSRF token in the session cookie. 54 | func (cs *cookieStore) Save(token []byte, w http.ResponseWriter) error { 55 | // Generate an encoded cookie value with the CSRF token. 
56 | encoded, err := cs.sc.Encode(cs.name, token) 57 | if err != nil { 58 | return err 59 | } 60 | 61 | cookie := &http.Cookie{ 62 | Name: cs.name, 63 | Value: encoded, 64 | MaxAge: cs.maxAge, 65 | HttpOnly: cs.httpOnly, 66 | Secure: cs.secure, 67 | Path: cs.path, 68 | Domain: cs.domain, 69 | } 70 | 71 | // Set the Expires field on the cookie based on the MaxAge 72 | if cs.maxAge > 0 { 73 | cookie.Expires = time.Now().Add( 74 | time.Duration(cs.maxAge) * time.Second) 75 | } else { 76 | cookie.Expires = time.Unix(1, 0) 77 | } 78 | 79 | // Write the authenticated cookie to the response. 80 | http.SetCookie(w, cookie) 81 | 82 | return nil 83 | } 84 | -------------------------------------------------------------------------------- /store_test.go: -------------------------------------------------------------------------------- 1 | package csrf 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "net/http" 7 | "net/http/httptest" 8 | "testing" 9 | 10 | "github.com/gorilla/securecookie" 11 | "github.com/zenazn/goji/web" 12 | ) 13 | 14 | // Check Store implementations 15 | var _ store = &cookieStore{} 16 | 17 | // brokenSaveStore is a CSRF store that cannot, well, save. 18 | type brokenSaveStore struct { 19 | store 20 | } 21 | 22 | func (bs *brokenSaveStore) Get(*web.C, *http.Request) ([]byte, error) { 23 | // Generate an invalid token so we can progress to our Save method 24 | return generateRandomBytes(24) 25 | } 26 | 27 | func (bs *brokenSaveStore) Save(realToken []byte, w http.ResponseWriter) error { 28 | return errors.New("test error") 29 | } 30 | 31 | // Tests for failure if the middleware can't save to the Store. 
32 | func TestStoreCannotSave(t *testing.T) { 33 | s := web.New() 34 | bs := &brokenSaveStore{} 35 | s.Use(Protect(testKey, setStore(bs))) 36 | s.Get("/", testHandler) 37 | 38 | r, err := http.NewRequest("GET", "/", nil) 39 | if err != nil { 40 | t.Fatal(err) 41 | } 42 | 43 | rr := httptest.NewRecorder() 44 | s.ServeHTTP(rr, r) 45 | 46 | if rr.Code != http.StatusForbidden { 47 | t.Fatalf("broken store did not set an error status: got %v want %v", 48 | rr.Code, http.StatusForbidden) 49 | } 50 | 51 | if c := rr.Header().Get("Set-Cookie"); c != "" { 52 | t.Fatalf("broken store incorrectly set a cookie: got %v want %v", 53 | c, "") 54 | } 55 | 56 | } 57 | 58 | // TestCookieDecode tests that an invalid cookie store returns a decoding error. 59 | func TestCookieDecode(t *testing.T) { 60 | r, err := http.NewRequest("GET", "/", nil) 61 | if err != nil { 62 | t.Fatal(err) 63 | } 64 | 65 | var age = 3600 66 | 67 | // Test with a nil hash key 68 | sc := securecookie.New(nil, nil) 69 | sc.MaxAge(age) 70 | st := &cookieStore{cookieName, age, true, true, "", "", sc} 71 | 72 | // Set a fake cookie value so r.Cookie passes. 73 | r.Header.Set("Cookie", fmt.Sprintf("%s=%s", cookieName, "notacookie")) 74 | 75 | _, err = st.Get(&web.C{}, r) 76 | if err == nil { 77 | t.Fatal("cookiestore did not report an invalid hashkey on decode") 78 | } 79 | } 80 | 81 | // TestCookieEncode tests that an invalid cookie store returns an encoding error. 82 | func TestCookieEncode(t *testing.T) { 83 | var age = 3600 84 | 85 | // Test with a nil hash key 86 | sc := securecookie.New(nil, nil) 87 | sc.MaxAge(age) 88 | st := &cookieStore{cookieName, age, true, true, "", "", sc} 89 | 90 | rr := httptest.NewRecorder() 91 | 92 | err := st.Save(nil, rr) 93 | if err == nil { 94 | t.Fatal("cookiestore did not report an invalid hashkey on encode") 95 | } 96 | } 97 | --------------------------------------------------------------------------------