├── handler_legacy.go ├── .travis.yml ├── handler_go17.go ├── .gitignore ├── handler_legacy_test.go ├── utils.go ├── handler_go17_test.go ├── LICENSE ├── crypto.go ├── utils_test.go ├── context.go ├── testutils_test.go ├── token_test.go ├── crypto_test.go ├── context_legacy.go ├── token.go ├── context_legacy_test.go ├── exempt.go ├── README.md ├── exempt_test.go ├── handler.go └── handler_test.go /handler_legacy.go: -------------------------------------------------------------------------------- 1 | // +build !go1.7 2 | 3 | package nosurf 4 | 5 | import "net/http" 6 | // addNosurfContext returns r unchanged: before Go 1.7 requests carry no context, so per-request CSRF state is kept in the package-level contextMap (context_legacy.go) instead. 7 | func addNosurfContext(r *http.Request) *http.Request { 8 | return r 9 | } 10 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | 3 | install: 4 | - go get . 5 | 6 | script: 7 | - go test -v . 8 | 9 | go: 10 | - 1.1 11 | - 1.2 12 | - 1.3 13 | - 1.4 14 | - tip 15 | -------------------------------------------------------------------------------- /handler_go17.go: -------------------------------------------------------------------------------- 1 | // +build go1.7 2 | 3 | package nosurf 4 | 5 | import ( 6 | "context" 7 | "net/http" 8 | ) 9 | // addNosurfContext returns a shallow copy of r whose context carries a fresh, empty *csrfContext under nosurfKey; Token, Reason, ctxSetToken and ctxSetReason (context.go) look it up there. 10 | func addNosurfContext(r *http.Request) *http.Request { 11 | return r.WithContext(context.WithValue(r.Context(), nosurfKey, &csrfContext{})) 12 | } 13 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | 6 | # Folders 7 | _obj 8 | _test 9 | 10 | # Architecture specific extensions/prefixes 11 | *.[568vq] 12 | [568vq].out 13 | 14 | *.cgo1.go 15 | *.cgo2.c 16 | _cgo_defun.c 17 | _cgo_gotypes.go 18 | _cgo_export.* 19 | 20 | _testmain.go 21 | 22 | *.exe 23 | 24 | # Vim stuff 25 | *.s[a-w][a-z] 26 |
*.un~ 27 | Session.vim 28 | .netrwhist 29 | *~ 30 | -------------------------------------------------------------------------------- /handler_legacy_test.go: -------------------------------------------------------------------------------- 1 | // +build !go1.7 2 | 3 | package nosurf 4 | 5 | import ( 6 | "net/http" 7 | "net/http/httptest" 8 | "testing" 9 | ) 10 | 11 | // TestClearsContextAfterTheRequest checks that serving a request 12 | // leaves no entry for it in the legacy contextMap. 13 | func TestClearsContextAfterTheRequest(t *testing.T) { 14 | hand := New(http.HandlerFunc(succHand)) 15 | writer := httptest.NewRecorder() 16 | req := dummyGet() 17 | 18 | hand.ServeHTTP(writer, req) 19 | 20 | if contextMap[req] != nil { 21 | t.Error("The context entry should have been cleared after the request.") 22 | t.Errorf("Instead, the context entry remains: %v", contextMap[req]) 23 | } 24 | }
25 | -------------------------------------------------------------------------------- /utils.go: -------------------------------------------------------------------------------- 1 | package nosurf 2 | 3 | import ( 4 | "net/url" 5 | ) 6 | 7 | // sContains reports whether slice contains the string s. 8 | func sContains(slice []string, s string) bool { 9 | for _, v := range slice { 10 | if v == s { 11 | return true 12 | } 13 | } 14 | return false 15 | } 16 | 17 | // sameOrigin reports whether the given URLs have the same origin 18 | // (that is, they share the host, the port and the scheme). 19 | // 20 | // Host is either "host" or "host:port" and is compared as a plain string, 21 | // so an explicit default port (e.g. "dummy.us:80") differs from an omitted one. 22 | // Pointers are taken because url.Parse() returns a pointer 23 | // and http.Request.URL is a pointer as well. 24 | func sameOrigin(u1, u2 *url.URL) bool { 25 | return u1.Scheme == u2.Scheme && u1.Host == u2.Host 26 | }
27 | -------------------------------------------------------------------------------- /handler_go17_test.go: -------------------------------------------------------------------------------- 1 | // +build go1.7 2 | 3 | package nosurf 4 | 5 | import ( 6 | "context" 7 | "net/http" 8 | "net/http/httptest" 9 | "testing" 10 | ) 11 | 12 | // TestContextIsAccessibleWithContext (admittedly a confusing name) checks 13 | // that nosurf's context stays reachable through Token() after the wrapped 14 | // handler derives a new request carrying an additional context value. 15 | func TestContextIsAccessibleWithContext(t *testing.T) { 16 | // An unexported key type cannot collide with keys set by other packages; 17 | // built-in types as context keys are flagged by staticcheck (SA1029). 18 | type dummyKey struct{} 19 | called := false 20 | 21 | succHand := func(w http.ResponseWriter, r *http.Request) { 22 | called = true 23 | r = r.WithContext(context.WithValue(r.Context(), dummyKey{}, "dummyval")) 24 | token := Token(r) 25 | if token == "" { 26 | t.Errorf("Token is inaccessible in the success handler") 27 | } 28 | } 29 | 30 | hand := New(http.HandlerFunc(succHand)) 31 | 32 | // we need a request that passes. Let's just use a safe method for that. 33 | req := dummyGet() 34 | writer := httptest.NewRecorder() 35 | 36 | hand.ServeHTTP(writer, req) 37 | 38 | // Guard against a vacuous pass if the success handler never ran. 39 | if !called { 40 | t.Errorf("The success handler was not called") 41 | } 42 | }
43 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2013 Justinas Stankevicius 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so, 10 | subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR 18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | -------------------------------------------------------------------------------- /crypto.go: -------------------------------------------------------------------------------- 1 | package nosurf 2 | 3 | import ( 4 | "crypto/rand" 5 | "io" 6 | ) 7 | 8 | // oneTimePad masks/unmasks the given data *in place* 9 | // by XOR-ing it with the given key. 10 | // Slices must be of the same length, or oneTimePad will panic. 11 | func oneTimePad(data, key []byte) { 12 | n := len(data) 13 | if n != len(key) { 14 | panic("Lengths of slices are not equal") 15 | } 16 | 17 | for i := 0; i < n; i++ { 18 | data[i] ^= key[i] 19 | } 20 | } 21 | 22 | // maskToken masks a raw tokenLength-byte token with a fresh random 23 | // one-time-pad key. It returns a new 2*tokenLength-byte slice holding 24 | // the key followed by the masked token, returns nil if data has the 25 | // wrong length, and panics if crypto/rand fails. 26 | func maskToken(data []byte) []byte { 27 | if len(data) != tokenLength { 28 | return nil 29 | } 30 | 31 | // tokenLength*2 == len(enckey + token) 32 | result := make([]byte, 2*tokenLength) 33 | // the first half of the result is the OTP 34 | // the second half is the masked token itself 35 | key := result[:tokenLength] 36 | token := result[tokenLength:] 37 | copy(token, data) 38 | 39 | // generate the random one-time-pad key 40 | if _, err := io.ReadFull(rand.Reader, key); err != nil { 41 | panic(err) 42 | } 43 | 44 | oneTimePad(token, key) 45 | return result 46 | } 47 | 48 | // unmaskToken reverses maskToken. It unmasks the second half of data 49 | // *in place* and returns that subslice (it aliases data), or nil if 50 | // data is not 2*tokenLength bytes long. 51 | func unmaskToken(data []byte) []byte { 52 | if len(data) != tokenLength*2 { 53 | return nil 54 | } 55 | 56 | key := data[:tokenLength] 57 | token := data[tokenLength:] 58 | oneTimePad(token, key) 59 | 60 | return token 61 | }
62 | -------------------------------------------------------------------------------- /utils_test.go: -------------------------------------------------------------------------------- 1 | package nosurf 2 | 3 | import ( 4 | "net/url" 5 | "testing" 6 | ) 7 | 8 | func TestSContains(t *testing.T) { 9 | slice := []string{"abc", "def", "ghi"} 10 | 11 | s1 := "abc" 12 | if !sContains(slice, s1) { 13 | t.Errorf("sContains said that %v doesn't contain %v, but it does.", slice, s1) 14 | } 15 | 16 | s2 := "xyz" 17 | if sContains(slice, s2) { 18 | t.Errorf("sContains said that %v contains %v, but it doesn't.", slice, s2) 19 | } 20 | } 21 | 22 | func TestSameOrigin(t *testing.T) { 23 | // a little helper that saves us time 24 | p := func(rawurl string) *url.URL { 25 | u, err := url.Parse(rawurl) 26 | if err != nil { 27 | t.Fatal(err) 28 | } 29 | return u 30 | } 31 | 32 | truthy := [][]*url.URL{ 33 | {p("http://dummy.us/"), p("http://dummy.us/faq")}, 34 | {p("https://dummy.us/some/page"), p("https://dummy.us/faq")}, 35 | } 36 | 37 | falsy := [][]*url.URL{ 38 | // different ports 39 | {p("http://dummy.us/"), p("http://dummy.us:8080")}, 40 | // different scheme 41 | {p("https://dummy.us/"), p("http://dummy.us/")}, 42 | // different host (same scheme, so only the host differs) 43 | {p("https://dummy.us/"), p("https://dummybook.us/")}, 44 | // slightly different host 45 | {p("https://beta.dummy.us/"), p("https://dummy.us/")}, 46 | } 47 | 48 | for _, v := range truthy { 49 | if !sameOrigin(v[0], v[1]) { 50 | t.Errorf("%v and %v have the same origin, but sameOrigin() said otherwise.", 51 | v[0], v[1]) 52 | } 53 | } 54 | 55 | for _, v := range falsy { 56 | if sameOrigin(v[0], v[1]) { 57 | t.Errorf("%v and %v don't have the same origin, but sameOrigin() said otherwise.", 58 | v[0], v[1]) 59 | } 60 | } 61 | 62 | }
63 | -------------------------------------------------------------------------------- /context.go: -------------------------------------------------------------------------------- 1 | // +build go1.7 2 | 3 | package nosurf 4 | 5 | import "net/http" 6 | 7 | type ctxKey int 8 | 9 | const ( 10 | nosurfKey ctxKey = iota 11 | ) 12 | 13 | type csrfContext struct { 14 | // The masked, base64 encoded token 15 | // that's suitable for use in form fields, etc. 16 | token string 17 | // reason for the failure of CSRF check 18 | reason error 19 | } 20 | 21 | // Token takes an HTTP request and returns 22 | // the CSRF token for that request 23 | // or an empty string if the token does not exist. 24 | // 25 | // Note that the token won't be available after 26 | // CSRFHandler finishes 27 | // (that is, in another handler that wraps it, 28 | // or after the request has been served) 29 | func Token(req *http.Request) string { 30 | ctx, ok := req.Context().Value(nosurfKey).(*csrfContext) 31 | if !ok || ctx == nil { 32 | // No nosurf context on this request: report "no token" (as documented, 33 | // and as the legacy build does) instead of panicking. 34 | return "" 35 | } 36 | return ctx.token 37 | } 38 | 39 | // Reason takes an HTTP request and returns 40 | // the reason of failure of the CSRF check for that request, 41 | // or nil if the request carries no nosurf context. 42 | // 43 | // Note that the same availability restrictions apply for Reason() as for Token(). 44 | func Reason(req *http.Request) error { 45 | ctx, ok := req.Context().Value(nosurfKey).(*csrfContext) 46 | if !ok || ctx == nil { 47 | return nil 48 | } 49 | return ctx.reason 50 | } 51 | 52 | // ctxClear is a no-op in this build: per-request state lives in the 53 | // request's context rather than in package-level storage. 54 | func ctxClear(_ *http.Request) { 55 | } 56 | 57 | // ctxSetToken masks and base64-encodes token and stores it in req's 58 | // nosurf context. It panics if req has no such context (see addNosurfContext). 59 | func ctxSetToken(req *http.Request, token []byte) { 60 | ctx := req.Context().Value(nosurfKey).(*csrfContext) 61 | ctx.token = b64encode(maskToken(token)) 62 | } 63 | 64 | // ctxSetReason records why the CSRF check failed for req. 65 | // It panics if no token has been set in req's nosurf context yet. 66 | func ctxSetReason(req *http.Request, reason error) { 67 | ctx := req.Context().Value(nosurfKey).(*csrfContext) 68 | if ctx.token == "" { 69 | panic("Reason should never be set when there's no token in the context yet.") 70 | } 71 | 72 | ctx.reason = reason 73 | }
74 | -------------------------------------------------------------------------------- /testutils_test.go: -------------------------------------------------------------------------------- 1 | package nosurf 2 | 3 | import ( 4 | "errors" 5 | "net/http" 6 | "net/url" 7 | "strings" 8 | "testing" 9 | ) 10 | 11 | // A reader that always fails on Read() 12 | // Suitable for testing the case of crypto/rand unavailability 13 | type failReader struct{} 14 | 15 | func (f failReader) Read(p []byte) (n int, err error) { 16 | err = errors.New("dummy error") 17 | return 18 | } 19 | 20 | func dummyGet() *http.Request { 21 | req, err := http.NewRequest("GET", "http://dum.my/", nil) 22 | if err != nil { 23 | panic(err) 24 | } 25 | return req 26 | } 27 | 28 | func succHand(w http.ResponseWriter, r *http.Request) { 29 | w.WriteHeader(200) 30 | w.Write([]byte("success")) 31 | } 32 | 33 | // Returns a HandlerFunc 34 | // that tests for the correct failure reason 35 | func correctReason(t *testing.T, reason error) http.Handler { 36 | fn := func(w http.ResponseWriter, r *http.Request) { 37 | got := Reason(r) 38 | if got != reason { 39 | t.Errorf("CSRF check should have failed with the reason %#v,"+ 40 | " but it failed with the reason %#v", reason, got) 41 | } 42 | // Writes the default failure code 43 | http.Error(w, "", FailureCode) 44 | } 45 | 46 | return http.HandlerFunc(fn) 47 | } 48 | 49 | // Gets a cookie with the specified name from the Response 50 | // Returns nil on not finding a suitable cookie 51 | func getRespCookie(resp *http.Response, name string) *http.Cookie { 52 | for _, c := range resp.Cookies() { 53 | if c.Name == name { 54 | return c 55 | } 56 | } 57 | return nil 58 | } 59 | 60 | // Encodes a slice of key-value pairs to a form value string 61 | func formBody(pairs [][]string) string { 62 | vals := url.Values{} 63 | for _, pair := range pairs { 64 | vals.Add(pair[0], pair[1]) 65 | } 66 | 67 | return vals.Encode() 68 | } 69 | 70 | // The same as formBody(), but wraps the string in a Reader 71 | func formBodyR(pairs [][]string) *strings.Reader { 72 | return strings.NewReader(formBody(pairs)) 73 | }
74 | -------------------------------------------------------------------------------- /token_test.go: -------------------------------------------------------------------------------- 1 | package nosurf 2 | 3 | import ( 4 | "crypto/rand" 5 | "testing" 6 | ) 7 | 8 | func TestChecksForPRNG(t *testing.T) { 9 | // Monkeypatch crypto/rand with an always-failing reader 10 | oldReader := rand.Reader 11 | rand.Reader = failReader{} 12 | // Restore it later for other tests 13 | defer func() { 14 | rand.Reader = oldReader 15 | }() 16 | 17 | defer func() { 18 | r := recover() 19 | if r == nil { 20 | t.Errorf("Expected checkForPRNG() to panic") 21 | } 22 | }() 23 | 24 | checkForPRNG() 25 | } 26 | 27 | func TestGeneratesAValidToken(t *testing.T) { 28 | // We can't test much with any certainty here, 29 | // since we generate tokens randomly 30 | // Basically we check that the length of the 31 | // token is what it should be 32 | 33 | token := generateToken() 34 | l := len(token) 35 | 36 | if l != tokenLength { 37 | t.Errorf("Bad decoded token length: expected %d, got %d", tokenLength, l) 38 | } 39 | } 40 | 41 | func TestVerifyTokenChecksLengthCorrectly(t *testing.T) { 42 | for i := 0; i < 64; i++ { 43 | slice := make([]byte, i) 44 | result := verifyToken(slice, slice) 45 | if result != false { 46 | t.Errorf("VerifyToken should've returned false with slices of length %d", i) 47 | } 48 | } 49 | 50 | slice := make([]byte, 64) 51 | result := verifyToken(slice[:32], slice) 52 | if result != true { 53 | t.Errorf("VerifyToken should've returned true on a zeroed slice of length 64") 54 | } 55 | } 56 | 57 | func TestVerifiesMaskedTokenCorrectly(t *testing.T) { 58 | realToken := []byte("qwertyuiopasdfghjklzxcvbnm123456") 59 | sentToken := []byte("qwertyuiopasdfghjklzxcvbnm123456" + 60 | "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + 61 | "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00") 62 | 63 | if !verifyToken(realToken, sentToken) { 64 | t.Errorf("VerifyToken returned a false negative") 65 | } 66 | 67 | realToken[0] = 'x' 68 | 69 | if verifyToken(realToken, sentToken) { 70 | t.Errorf("VerifyToken returned a false positive") 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /crypto_test.go: -------------------------------------------------------------------------------- 1 | package nosurf 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | ) 7 |
8 | func TestOtpPanicsOnLengthMismatch(t *testing.T) { 9 | data := make([]byte, 1) 10 | key := make([]byte, 2) 11 | 12 | defer func() { 13 | if r := recover(); r == nil { 14 | t.Error("One time pad should've panicked on receiving slices" + 15 | " of different length, but it didn't") 16 | } 17 | }() 18 | oneTimePad(data, key) 19 | } 20 | 21 | func TestOtpMasksCorrectly(t *testing.T) { 22 | data := []byte("Inventors of the shish-kebab") 23 | key := []byte("They stop Cthulhu eating ye.") 24 | // precalculated 25 | expected := []byte("\x1d\x06\x13\x1cN\x07\x1b\x1d\x03\x00,\x12H\x01\x04" + 26 | "\rUS\r\x08\x07\x01C\x0cE\x1b\x04L") 27 | 28 | oneTimePad(data, key) 29 | 30 | if !bytes.Equal(data, expected) { 31 | t.Errorf("oneTimePad masked the data incorrectly: expected %#v, got %#v", 32 | expected, data) 33 | } 34 | } 35 | 36 | func TestOtpUnmasksCorrectly(t *testing.T) { 37 | orig := []byte("a very secret message") 38 | data := make([]byte, len(orig)) 39 | copy(data, orig) 40 | if !bytes.Equal(orig, data) { 41 | t.Fatal("copy failed") 42 | } 43 | 44 | key := []byte("even more secret key!") 45 | 46 | oneTimePad(data, key) 47 | oneTimePad(data, key) 48 | 49 | if !bytes.Equal(orig, data) { 50 | t.Errorf("2x oneTimePad didn't return the original data:"+ 51 | " expected %#v, got %#v", orig, data) 52 | } 53 | } 54 | 55 | func TestMasksTokenCorrectly(t *testing.T) { 56 | // needs to be of tokenLength 57 | token := []byte("12345678901234567890123456789012") 58 | fullToken := maskToken(token) 59 | 60 | // Fatal, not Error: the slicing below would panic on a short result. 61 | if len(fullToken) != 2*tokenLength { 62 | t.Fatalf("len(fullToken) is not %d, but %d", 2*tokenLength, len(fullToken)) 63 | } 64 | 65 | key := fullToken[:tokenLength] 66 | encToken := fullToken[tokenLength:] 67 | 68 | // perform unmasking 69 | oneTimePad(encToken, key) 70 | 71 | if !bytes.Equal(encToken, token) { 72 | t.Errorf("Unmasked token is invalid: expected %v, got %v", token, encToken) 73 | } 74 | } 75 | 76 | func TestUnmasksTokenCorrectly(t *testing.T) { 77 | token := []byte("12345678901234567890123456789012") 78 | fullToken := maskToken(token) 79 | 80 | decToken := unmaskToken(fullToken) 81 | 82 | if !bytes.Equal(decToken, token) { 83 | t.Errorf("Unmasked token is invalid: expected %v, got %v", token, decToken) 84 | } 85 | }
86 | -------------------------------------------------------------------------------- /context_legacy.go: -------------------------------------------------------------------------------- 1 | // +build !go1.7 2 | 3 | package nosurf 4 | 5 | import ( 6 | "net/http" 7 | "sync" 8 | ) 9 | 10 | // This file implements a per-request context similar to the one found 11 | // in gorilla/context (contextMap, guarded by cmMutex), but tailored 12 | // specifically for our use case and without depending on gorilla. 13 | 14 | type csrfContext struct { 15 | // The masked, base64 encoded token 16 | // that's suitable for use in form fields, etc. 17 | token string 18 | // reason for the failure of CSRF check 19 | reason error 20 | } 21 | 22 | var ( 23 | contextMap = make(map[*http.Request]*csrfContext) 24 | cmMutex = new(sync.RWMutex) 25 | ) 26 | 27 | // Token takes an HTTP request and returns 28 | // the CSRF token for that request 29 | // or an empty string if the token does not exist. 30 | // 31 | // Note that the token won't be available after 32 | // CSRFHandler finishes 33 | // (that is, in another handler that wraps it, 34 | // or after the request has been served) 35 | func Token(req *http.Request) string { 36 | cmMutex.RLock() 37 | defer cmMutex.RUnlock() 38 | 39 | ctx, ok := contextMap[req] 40 | 41 | if !ok { 42 | return "" 43 | } 44 | 45 | return ctx.token 46 | } 47 | 48 | // Reason takes an HTTP request and returns 49 | // the reason of failure of the CSRF check for that request (or nil). 50 | // 51 | // Note that the same availability restrictions apply for Reason() as for Token().
52 | func Reason(req *http.Request) error { 53 | cmMutex.RLock() 54 | defer cmMutex.RUnlock() 55 | 56 | ctx, ok := contextMap[req] 57 | 58 | if !ok { 59 | return nil 60 | } 61 | 62 | return ctx.reason 63 | } 64 | 65 | // Takes a raw token, masks it with a per-request key, 66 | // encodes in base64 and makes it available to the wrapped handler 67 | func ctxSetToken(req *http.Request, token []byte) *http.Request { 68 | cmMutex.Lock() 69 | defer cmMutex.Unlock() 70 | 71 | ctx, ok := contextMap[req] 72 | if !ok { 73 | ctx = new(csrfContext) 74 | contextMap[req] = ctx 75 | } 76 | 77 | ctx.token = b64encode(maskToken(token)) 78 | 79 | return req 80 | } 81 | 82 | func ctxSetReason(req *http.Request, reason error) *http.Request { 83 | cmMutex.Lock() 84 | defer cmMutex.Unlock() 85 | 86 | ctx, ok := contextMap[req] 87 | if !ok { 88 | panic("Reason should never be set when there's no token" + 89 | " (context) yet.") 90 | } 91 | 92 | ctx.reason = reason 93 | return req 94 | } 95 | 96 | func ctxClear(req *http.Request) { 97 | cmMutex.Lock() 98 | defer cmMutex.Unlock() 99 | 100 | delete(contextMap, req) 101 | } 102 | -------------------------------------------------------------------------------- /token.go: -------------------------------------------------------------------------------- 1 | package nosurf 2 | 3 | import ( 4 | "crypto/rand" 5 | "crypto/subtle" 6 | "encoding/base64" 7 | "fmt" 8 | "io" 9 | ) 10 | 11 | const ( 12 | tokenLength = 32 13 | ) 14 | 15 | /* 16 | There are two types of tokens. 17 | 18 | * The unmasked "real" token consists of 32 random bytes. 19 | It is stored in a cookie (base64-encoded) and it's the 20 | "reference" value that sent tokens get compared to. 21 | 22 | * The masked "sent" token consists of 64 bytes: 23 | 32 byte key used for one-time pad masking and 24 | 32 byte "real" token masked with the said key. 25 | It is used as a value (base64-encoded as well) 26 | in forms and/or headers. 
27 | 28 | Upon processing, both tokens are base64-decoded 29 | and then treated as 32/64 byte slices. 30 | */ 31 | 32 | // A token is generated by returning tokenLength bytes 33 | // from crypto/rand 34 | func generateToken() []byte { 35 | bytes := make([]byte, tokenLength) 36 | 37 | if _, err := io.ReadFull(rand.Reader, bytes); err != nil { 38 | panic(err) 39 | } 40 | 41 | return bytes 42 | } 43 | 44 | func b64encode(data []byte) string { 45 | return base64.StdEncoding.EncodeToString(data) 46 | } 47 | 48 | func b64decode(data string) []byte { 49 | decoded, err := base64.StdEncoding.DecodeString(data) 50 | if err != nil { 51 | return nil 52 | } 53 | return decoded 54 | } 55 | 56 | // VerifyToken verifies the sent token equals the real one 57 | // and returns a bool value indicating if tokens are equal. 58 | // Supports masked tokens. realToken comes from Token(r) and 59 | // sentToken is token sent unusual way. 60 | func VerifyToken(realToken, sentToken string) bool { 61 | r := b64decode(realToken) 62 | if len(r) == 2*tokenLength { 63 | r = unmaskToken(r) 64 | } 65 | s := b64decode(sentToken) 66 | if len(s) == 2*tokenLength { 67 | s = unmaskToken(s) 68 | } 69 | return subtle.ConstantTimeCompare(r, s) == 1 70 | } 71 | 72 | func verifyToken(realToken, sentToken []byte) bool { 73 | realN := len(realToken) 74 | sentN := len(sentToken) 75 | 76 | // sentN == tokenLength means the token is unmasked 77 | // sentN == 2*tokenLength means the token is masked. 78 | 79 | if realN == tokenLength && sentN == 2*tokenLength { 80 | return verifyMasked(realToken, sentToken) 81 | } else { 82 | return false 83 | } 84 | } 85 | 86 | // Verifies the masked token 87 | func verifyMasked(realToken, sentToken []byte) bool { 88 | sentPlain := unmaskToken(sentToken) 89 | return subtle.ConstantTimeCompare(realToken, sentPlain) == 1 90 | } 91 | 92 | func checkForPRNG() { 93 | // Check that cryptographically secure PRNG is available 94 | // In case it's not, panic. 
95 | buf := make([]byte, 1) 96 | _, err := io.ReadFull(rand.Reader, buf) 97 | 98 | if err != nil { 99 | panic(fmt.Sprintf("crypto/rand is unavailable: Read() failed with %#v", err)) 100 | } 101 | } 102 | 103 | func init() { 104 | checkForPRNG() 105 | } 106 | -------------------------------------------------------------------------------- /context_legacy_test.go: -------------------------------------------------------------------------------- 1 | // +build !go1.7 2 | 3 | package nosurf 4 | 5 | import ( 6 | "bytes" 7 | "errors" 8 | "testing" 9 | ) 10 | 11 | func TestSetsReasonCorrectly(t *testing.T) { 12 | req := dummyGet() 13 | 14 | // set token first, as it's required for ctxSetReason 15 | ctxSetToken(req, []byte("abcdef")) 16 | 17 | err := errors.New("universe imploded") 18 | ctxSetReason(req, err) 19 | 20 | got := contextMap[req].reason 21 | 22 | if got != err { 23 | t.Errorf("Reason set incorrectly: expected %v, got %v", err, got) 24 | } 25 | } 26 | 27 | func TestSettingReasonFailsWithoutContext(t *testing.T) { 28 | req := dummyGet() 29 | err := errors.New("universe imploded") 30 | 31 | defer func() { 32 | r := recover() 33 | if r == nil { 34 | t.Error("ctxSetReason() didn't panic on no context") 35 | } 36 | }() 37 | 38 | ctxSetReason(req, err) 39 | } 40 | 41 | func TestSetsTokenCorrectly(t *testing.T) { 42 | req := dummyGet() 43 | token := []byte("12345678901234567890123456789012") 44 | ctxSetToken(req, token) 45 | 46 | got := contextMap[req].token 47 | 48 | if !bytes.Equal(token, unmaskToken(b64decode(got))) { 49 | t.Errorf("Token set incorrectly: expected %v, got %v", token, got) 50 | } 51 | } 52 | 53 | func TestGetsTokenCorrectly(t *testing.T) { 54 | req := dummyGet() 55 | token := Token(req) 56 | 57 | if len(token) != 0 { 58 | t.Errorf("Token hasn't been set yet, but it's not an empty slice, it's %v", token) 59 | } 60 | 61 | intended := []byte("12345678901234567890123456789012") 62 | ctxSetToken(req, intended) 63 | 64 | token = Token(req) 65 | decToken := 
unmaskToken(b64decode(token)) 66 | if !bytes.Equal(intended, decToken) { 67 | t.Errorf("Token has been set to %v, but it's %v", intended, token) 68 | } 69 | } 70 | 71 | func TestGetsReasonCorrectly(t *testing.T) { 72 | req := dummyGet() 73 | 74 | reason := Reason(req) 75 | if reason != nil { 76 | t.Errorf("Reason hasn't been set yet, but it's not nil, it's %v", reason) 77 | } 78 | 79 | // again, needed for ctxSetReason() to work 80 | ctxSetToken(req, []byte("dummy")) 81 | 82 | intended := errors.New("universe imploded") 83 | ctxSetReason(req, intended) 84 | 85 | reason = Reason(req) 86 | if reason != intended { 87 | t.Errorf("Reason has been set to %v, but it's %v", intended, reason) 88 | } 89 | } 90 | 91 | func TestClearsContextEntry(t *testing.T) { 92 | req := dummyGet() 93 | 94 | ctxSetToken(req, []byte("dummy")) 95 | ctxSetReason(req, errors.New("some error")) 96 | 97 | ctxClear(req) 98 | 99 | entry, found := contextMap[req] 100 | 101 | if found { 102 | t.Errorf("Context entry %v found for the request %v, even though"+ 103 | " it should have been cleared.", entry, req) 104 | } 105 | } 106 | 107 | func TestClearsContextEntryEvenIfNotSet(t *testing.T) { 108 | r := dummyGet() 109 | defer func() { 110 | if r := recover(); r != nil { 111 | t.Errorf("ctxClear(r) panicked with %v", r) 112 | } 113 | }() 114 | ctxClear(r) 115 | } 116 | -------------------------------------------------------------------------------- /exempt.go: -------------------------------------------------------------------------------- 1 | package nosurf 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | pathModule "path" 7 | "reflect" 8 | "regexp" 9 | ) 10 | 11 | // Checks if the given request is exempt from CSRF checks. 12 | // It checks the ExemptFunc first, then the exact paths, 13 | // then the globs and finally the regexps. 
14 | func (h *CSRFHandler) IsExempt(r *http.Request) bool { 15 | if h.exemptFunc != nil && h.exemptFunc(r) { 16 | return true 17 | } 18 | 19 | path := r.URL.Path 20 | if sContains(h.exemptPaths, path) { 21 | return true 22 | } 23 | 24 | // then the globs 25 | for _, glob := range h.exemptGlobs { 26 | matched, err := pathModule.Match(glob, path) 27 | if matched && err == nil { 28 | return true 29 | } 30 | } 31 | 32 | // finally, the regexps 33 | for _, re := range h.exemptRegexps { 34 | if re.MatchString(path) { 35 | return true 36 | } 37 | } 38 | 39 | return false 40 | } 41 | 42 | // ExemptPath exempts an exact path from CSRF checks. 43 | // With this (and other Exempt* methods) 44 | // you should take note that Go's paths 45 | // include a leading slash. 46 | func (h *CSRFHandler) ExemptPath(path string) { 47 | h.exemptPaths = append(h.exemptPaths, path) 48 | } 49 | 50 | // ExemptPaths is a variadic argument version of ExemptPath(). 51 | func (h *CSRFHandler) ExemptPaths(paths ...string) { 52 | for _, v := range paths { 53 | h.ExemptPath(v) 54 | } 55 | } 56 | 57 | // ExemptGlob exempts URLs that match the specified glob pattern 58 | // (as used by path.Match()) from CSRF checks. 59 | // 60 | // Note that ExemptGlob() is unable to detect syntax errors, 61 | // because it doesn't have a path to check it against 62 | // and path.Match() may not report an error if the path is empty. 63 | // A malformed glob simply never matches (IsExempt ignores the error). 64 | // If we find a way to check the syntax, ExemptGlob 65 | // MIGHT PANIC on a syntax error in the future. 66 | // ALWAYS check your globs for syntax errors. 67 | func (h *CSRFHandler) ExemptGlob(pattern string) { 68 | h.exemptGlobs = append(h.exemptGlobs, pattern) 69 | } 70 | 71 | // ExemptGlobs is a variadic argument version of ExemptGlob(). 72 | func (h *CSRFHandler) ExemptGlobs(patterns ...string) { 73 | for _, v := range patterns { 74 | h.ExemptGlob(v) 75 | } 76 | } 77 | 78 | // ExemptRegexp accepts a regular expression string or a compiled *regexp.Regexp 79 | // and exempts URLs that match it from CSRF checks.
80 | // 81 | // If the given argument is neither of the accepted values, 82 | // or the given string fails to compile, ExemptRegexp() panics. 83 | func (h *CSRFHandler) ExemptRegexp(re interface{}) { 84 | var compiled *regexp.Regexp 85 | 86 | switch v := re.(type) { 87 | case string: 88 | compiled = regexp.MustCompile(v) 89 | case *regexp.Regexp: 90 | compiled = v 91 | default: 92 | err := fmt.Sprintf("%v isn't a valid type for ExemptRegexp()", reflect.TypeOf(re)) 93 | panic(err) 94 | } 95 | 96 | h.exemptRegexps = append(h.exemptRegexps, compiled) 97 | } 98 | 99 | // ExemptRegexps is a variadic argument version of ExemptRegexp(). 100 | func (h *CSRFHandler) ExemptRegexps(res ...interface{}) { 101 | for _, v := range res { 102 | h.ExemptRegexp(v) 103 | } 104 | } 105 | 106 | // ExemptFunc sets a predicate that IsExempt consults before the paths, 107 | // globs and regexps: requests for which fn returns true are exempt. 108 | // Calling it again replaces the previously set function. 109 | func (h *CSRFHandler) ExemptFunc(fn func(r *http.Request) bool) { 110 | h.exemptFunc = fn 111 | } 112 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # nosurf 2 | 3 | [![Build Status](https://travis-ci.org/justinas/nosurf.svg?branch=master)](https://travis-ci.org/justinas/nosurf) 4 | [![GoDoc](http://godoc.org/github.com/justinas/nosurf?status.png)](http://godoc.org/github.com/justinas/nosurf) 5 | 6 | `nosurf` is an HTTP package for Go 7 | that helps you prevent Cross-Site Request Forgery attacks. 8 | It acts like a middleware and therefore 9 | is compatible with basically any Go HTTP application. 10 | 11 | ### Why? 12 | Even though CSRF is a prominent vulnerability, 13 | Go's web-related package infrastructure mostly consists of 14 | micro-frameworks that neither implement CSRF checks, 15 | nor should they. 16 | 17 | `nosurf` solves this problem by providing a `CSRFHandler` 18 | that wraps your `http.Handler` and checks for CSRF attacks 19 | on every non-safe (non-GET/HEAD/OPTIONS/TRACE) method. 20 | 21 | `nosurf` requires Go 1.1 or later.
22 | 23 | ### Features 24 | 25 | * Supports any `http.Handler` (frameworks, your own handlers, etc.) 26 | and acts like one itself. 27 | * Allows exempting specific endpoints from CSRF checks by 28 | an exact URL, a glob, or a regular expression. 29 | * Allows specifying your own failure handler. 30 | Want to present the hacker with an ASCII middle finger 31 | instead of the plain old `HTTP 400`? No problem. 32 | * Uses masked tokens to mitigate the BREACH attack. 33 | * Has no dependencies outside the Go standard library. 34 | 35 | ### Example 36 | ```go 37 | package main 38 | 39 | import ( 40 | "fmt" 41 | "github.com/justinas/nosurf" 42 | "html/template" 43 | "net/http" 44 | ) 45 | 46 | var templateString string = ` 47 | 48 | 49 | 50 | {{ if .name }} 51 |

Your name: {{ .name }}

52 | {{ end }} 53 |
54 | 55 | 56 | 58 | 59 | 60 |
61 | 62 | 63 | ` 64 | var templ = template.Must(template.New("t1").Parse(templateString)) 65 | 66 | func myFunc(w http.ResponseWriter, r *http.Request) { 67 | context := make(map[string]string) 68 | context["token"] = nosurf.Token(r) 69 | if r.Method == "POST" { 70 | context["name"] = r.FormValue("name") 71 | } 72 | 73 | templ.Execute(w, context) 74 | } 75 | 76 | func main() { 77 | myHandler := http.HandlerFunc(myFunc) 78 | fmt.Println("Listening on http://127.0.0.1:8000/") 79 | http.ListenAndServe(":8000", nosurf.New(myHandler)) 80 | } 81 | ``` 82 | 83 | ### Manual token verification 84 | In some cases the CSRF token may be sent in a non-standard way, 85 | e.g. the request body is a JSON-encoded message with one of its fields 86 | being the token. 87 | 88 | In such a case the handler (path) should be excluded from automatic 89 | verification by using one of the exemption methods: 90 | 91 | ```go 92 | func (h *CSRFHandler) ExemptFunc(fn func(r *http.Request) bool) 93 | func (h *CSRFHandler) ExemptGlob(pattern string) 94 | func (h *CSRFHandler) ExemptGlobs(patterns ...string) 95 | func (h *CSRFHandler) ExemptPath(path string) 96 | func (h *CSRFHandler) ExemptPaths(paths ...string) 97 | func (h *CSRFHandler) ExemptRegexp(re interface{}) 98 | func (h *CSRFHandler) ExemptRegexps(res ...interface{}) 99 | ``` 100 | 101 | Later on, the token **must** be verified manually by passing the real token (`nosurf.Token(r)`) 102 | and the token sent in the body to `VerifyToken(tkn, tkn2 string) bool`. 103 | 104 | Example: 105 | ```go 106 | func HandleJson(w http.ResponseWriter, r *http.Request) { 107 | d := struct{ 108 | X,Y int 109 | Tkn string 110 | }{} 111 | json.NewDecoder(r.Body).Decode(&d) 112 | if !nosurf.VerifyToken(nosurf.Token(r), d.Tkn) { 113 | http.Error(w, "CSRF token incorrect", http.StatusBadRequest) 114 | return 115 | } 116 | // do something cool 117 | } 118 | ``` 119 | 120 | ### Contributing 121 | 122 | 0. Find an issue that bugs you / open a new one. 123 | 1.
Discuss. 124 | 2. Branch off, commit, test. 125 | 3. Make a pull request / attach the commits to the issue. 126 | -------------------------------------------------------------------------------- /exempt_test.go: -------------------------------------------------------------------------------- 1 | package nosurf 2 | 3 | import ( 4 | "net/http" 5 | "regexp" 6 | "testing" 7 | ) 8 | 9 | func TestExemptPath(t *testing.T) { 10 | // the handler doesn't matter here, let's use nil 11 | hand := New(nil) 12 | path := "/home" 13 | exempt, _ := http.NewRequest("GET", path, nil) 14 | 15 | hand.ExemptPath(path) 16 | if !hand.IsExempt(exempt) { 17 | t.Errorf("%v is not exempt, but it should be", exempt.URL.Path) 18 | } 19 | 20 | other, _ := http.NewRequest("GET", "/faq", nil) 21 | if hand.IsExempt(other) { 22 | t.Errorf("%v is exempt, but it shouldn't be", other.URL.Path) 23 | } 24 | } 25 | 26 | func TestExemptPaths(t *testing.T) { 27 | hand := New(nil) 28 | paths := []string{"/home", "/news", "/help"} 29 | hand.ExemptPaths(paths...) 
30 | 31 | for _, v := range paths { 32 | request, _ := http.NewRequest("GET", v, nil) 33 | if !hand.IsExempt(request) { 34 | t.Errorf("%v should be exempt, but it isn't", v) 35 | } 36 | } 37 | 38 | other, _ := http.NewRequest("GET", "/accounts", nil) 39 | if hand.IsExempt(other) { 40 | t.Errorf("%v is exempt, but it shouldn't be", other) 41 | } 42 | } 43 | 44 | func TestExemptGlob(t *testing.T) { 45 | hand := New(nil) 46 | glob := "/[m-n]ail" 47 | 48 | hand.ExemptGlob(glob) 49 | 50 | test, _ := http.NewRequest("GET", "/mail", nil) 51 | if !hand.IsExempt(test) { 52 | t.Errorf("%v should be exempt, but it isn't.", test) 53 | } 54 | 55 | test, _ = http.NewRequest("GET", "/nail", nil) 56 | if !hand.IsExempt(test) { 57 | t.Errorf("%v should be exempt, but it isn't.", test) 58 | } 59 | 60 | test, _ = http.NewRequest("GET", "/snail", nil) 61 | if hand.IsExempt(test) { 62 | t.Errorf("%v should not be exempt, but it is.", test) 63 | } 64 | 65 | test, _ = http.NewRequest("GET", "/mail/outbox", nil) 66 | if hand.IsExempt(test) { 67 | t.Errorf("%v should not be exempt, but it is.", test) 68 | } 69 | } 70 | 71 | func TestExemptGlobs(t *testing.T) { 72 | slice := []string{"/", "/accounts/*", "/post/?*"} 73 | matching := []string{"/", "/accounts/", "/accounts/johndoe", "/post/1", "/post/123"} 74 | 75 | nonMatching := []string{"", "/accounts", 76 | // glob's * and ? don't match a forward slash 77 | "/accounts/johndoe/posts", 78 | "/post/", 79 | } 80 | 81 | hand := New(nil) 82 | hand.ExemptGlobs(slice...) 
83 | 84 | for _, v := range matching { 85 | test, _ := http.NewRequest("GET", v, nil) 86 | if !hand.IsExempt(test) { 87 | t.Errorf("%v should be exempt, but it isn't.", v) 88 | } 89 | } 90 | 91 | for _, v := range nonMatching { 92 | test, _ := http.NewRequest("GET", v, nil) 93 | if hand.IsExempt(test) { 94 | t.Errorf("%v shouldn't be exempt, but it is", v) 95 | } 96 | } 97 | } 98 | 99 | // This only tests that ExemptRegexp handles the argument correctly 100 | // The matching itself is tested by TestExemptRegexpMatching 101 | func TestExemptRegexpCall(t *testing.T) { 102 | pattern := "^/[rd]ope$" 103 | 104 | // case 1: a string 105 | hand := New(nil) 106 | hand.ExemptRegexp(pattern) 107 | 108 | // String() returns the original pattern 109 | got := hand.exemptRegexps[0].String() 110 | 111 | if pattern != got { 112 | t.Errorf("The compiled pattern has changed: expected %v, got %v", 113 | pattern, got) 114 | } 115 | 116 | // case 2: a compiled *Regexp 117 | hand = New(nil) 118 | re := regexp.MustCompile(pattern) 119 | hand.ExemptRegexp(re) 120 | 121 | got_re := hand.exemptRegexps[0] 122 | 123 | if re != got_re { 124 | t.Errorf("The compiled pattern is not what we gave: expected %v, got %v", 125 | re, got_re) 126 | } 127 | 128 | } 129 | 130 | func TestExemptRegexpInvalidType(t *testing.T) { 131 | arg := 123 132 | 133 | defer func() { 134 | r := recover() 135 | if r == nil { 136 | t.Error("The function didn't panic on an invalid argument type") 137 | } 138 | }() 139 | 140 | hand := New(nil) 141 | hand.ExemptRegexp(arg) 142 | } 143 | 144 | func TestExemptRegexpInvalidPattern(t *testing.T) { 145 | // an unclosed group 146 | pattern := "a(b" 147 | 148 | defer func() { 149 | r := recover() 150 | if r == nil { 151 | t.Error("The function didn't panic on an invalid regular expression") 152 | } 153 | }() 154 | 155 | hand := New(nil) 156 | hand.ExemptRegexp(pattern) 157 | } 158 | 159 | // The same as TestExemptRegexCall, but for the variadic function 160 | func 
TestExemptRegexpsCall(t *testing.T) { 161 | // case 1: a slice of strings 162 | hand := New(nil) 163 | slice1 := []interface{}{"^/$", "^/accounts$"} 164 | hand.ExemptRegexps(slice1...) 165 | 166 | for i := range slice1 { 167 | pat := hand.exemptRegexps[i].String() 168 | got := slice1[i] 169 | if pat != got { 170 | t.Errorf("The compiled pattern has changed: expected %v, got %v", pat, got) 171 | } 172 | } 173 | 174 | // case 2: a mixed slice 175 | hand = New(nil) 176 | slice2 := []interface{}{"^/$", regexp.MustCompile("^/accounts$")} 177 | hand.ExemptRegexps(slice2...) 178 | 179 | pat := slice2[0].(string) 180 | got := hand.exemptRegexps[0].String() 181 | if pat != got { 182 | t.Errorf("The compiled pattern has changed: expected %v, got %v", pat, got) 183 | } 184 | 185 | pat = slice2[1].(*regexp.Regexp).String() 186 | got = hand.exemptRegexps[1].String() 187 | if pat != got { 188 | t.Errorf("The compiled pattern has changed: expected %v, got %v", pat, got) 189 | } 190 | } 191 | 192 | func TestExemptRegexpMatching(t *testing.T) { 193 | hand := New(nil) 194 | re := `^/[mn]ail$` 195 | hand.ExemptRegexp(re) 196 | 197 | // valid 198 | test, _ := http.NewRequest("GET", "/mail", nil) 199 | if !hand.IsExempt(test) { 200 | t.Errorf("%v should be exempt, but it isn't.", test) 201 | } 202 | 203 | test, _ = http.NewRequest("GET", "/nail", nil) 204 | if !hand.IsExempt(test) { 205 | t.Errorf("%v should be exempt, but it isn't.", test) 206 | } 207 | 208 | test, _ = http.NewRequest("GET", "/mail/outbox", nil) 209 | if hand.IsExempt(test) { 210 | t.Errorf("%v shouldn't be exempt, but it is.", test) 211 | } 212 | 213 | test, _ = http.NewRequest("GET", "/snail", nil) 214 | if hand.IsExempt(test) { 215 | t.Errorf("%v shouldn't be exempt, but it is.", test) 216 | } 217 | } 218 | 219 | func TestExemptFunc(t *testing.T) { 220 | // the handler doesn't matter here, let's use nil 221 | hand := New(nil) 222 | hand.ExemptFunc(func(r *http.Request) bool { 223 | return r.Method == "GET" 224 | }) 
225 | 226 | test, _ := http.NewRequest("GET", "/path", nil) 227 | if !hand.IsExempt(test) { 228 | t.Errorf("%v is not exempt, but it should be", test) 229 | } 230 | 231 | other, _ := http.NewRequest("POST", "/path", nil) 232 | if hand.IsExempt(other) { 233 | t.Errorf("%v is exempt, but it shouldn't be", other) 234 | } 235 | } 236 | -------------------------------------------------------------------------------- /handler.go: -------------------------------------------------------------------------------- 1 | // Package nosurf implements an HTTP handler that 2 | // mitigates Cross-Site Request Forgery Attacks. 3 | package nosurf 4 | 5 | import ( 6 | "errors" 7 | "net/http" 8 | "net/url" 9 | "regexp" 10 | ) 11 | 12 | const ( 13 | // the name of CSRF cookie 14 | CookieName = "csrf_token" 15 | // the name of the form field 16 | FormFieldName = "csrf_token" 17 | // the name of CSRF header 18 | HeaderName = "X-CSRF-Token" 19 | // the HTTP status code for the default failure handler 20 | FailureCode = 400 21 | 22 | // Max-Age in seconds for the default base cookie. 365 days. 23 | MaxAge = 365 * 24 * 60 * 60 24 | ) 25 | 26 | var safeMethods = []string{"GET", "HEAD", "OPTIONS", "TRACE"} 27 | 28 | // reasons for CSRF check failures 29 | var ( 30 | ErrNoReferer = errors.New("A secure request contained no Referer or its value was malformed") 31 | ErrBadReferer = errors.New("A secure request's Referer comes from a different Origin" + 32 | " from the request's URL") 33 | ErrBadToken = errors.New("The CSRF token in the cookie doesn't match the one" + 34 | " received in a form/header.") 35 | ) 36 | 37 | type CSRFHandler struct { 38 | // Handlers that CSRFHandler wraps. 39 | successHandler http.Handler 40 | failureHandler http.Handler 41 | 42 | // The base cookie that CSRF cookies will be built upon. 43 | // This should be a better solution of customizing the options 44 | // than a bunch of methods SetCookieExpiration(), etc. 
45 | baseCookie http.Cookie 46 | 47 | // Slices of paths that are exempt from CSRF checks. 48 | // They can be specified by... 49 | // ...an exact path, 50 | exemptPaths []string 51 | // ...a regexp, 52 | exemptRegexps []*regexp.Regexp 53 | // ...or a glob (as used by path.Match()). 54 | exemptGlobs []string 55 | // ...or a custom matcher function 56 | exemptFunc func(r *http.Request) bool 57 | 58 | // All of those will be matched against Request.URL.Path, 59 | // So they should take the leading slash into account 60 | } 61 | 62 | func defaultFailureHandler(w http.ResponseWriter, r *http.Request) { 63 | http.Error(w, "", FailureCode) 64 | } 65 | 66 | // Extracts the "sent" token from the request 67 | // and returns an unmasked version of it 68 | func extractToken(r *http.Request) []byte { 69 | var sentToken string 70 | 71 | // Prefer the header over form value 72 | sentToken = r.Header.Get(HeaderName) 73 | 74 | // Then POST values 75 | if len(sentToken) == 0 { 76 | sentToken = r.PostFormValue(FormFieldName) 77 | } 78 | 79 | // If all else fails, try a multipart value. 80 | // PostFormValue() will already have called ParseMultipartForm() 81 | if len(sentToken) == 0 && r.MultipartForm != nil { 82 | vals := r.MultipartForm.Value[FormFieldName] 83 | if len(vals) != 0 { 84 | sentToken = vals[0] 85 | } 86 | } 87 | 88 | return b64decode(sentToken) 89 | } 90 | 91 | // Constructs a new CSRFHandler that calls 92 | // the specified handler if the CSRF check succeeds. 93 | func New(handler http.Handler) *CSRFHandler { 94 | baseCookie := http.Cookie{} 95 | baseCookie.MaxAge = MaxAge 96 | 97 | csrf := &CSRFHandler{successHandler: handler, 98 | failureHandler: http.HandlerFunc(defaultFailureHandler), 99 | baseCookie: baseCookie, 100 | } 101 | 102 | return csrf 103 | } 104 | 105 | // The same as New(), but has an interface return type. 
106 | func NewPure(handler http.Handler) http.Handler { 107 | return New(handler) 108 | } 109 | 110 | func (h *CSRFHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 111 | r = addNosurfContext(r) 112 | defer ctxClear(r) 113 | w.Header().Add("Vary", "Cookie") 114 | 115 | var realToken []byte 116 | 117 | tokenCookie, err := r.Cookie(CookieName) 118 | if err == nil { 119 | realToken = b64decode(tokenCookie.Value) 120 | } 121 | 122 | // If the length of the real token isn't what it should be, 123 | // it has either been tampered with, 124 | // or we're migrating onto a new algorithm for generating tokens, 125 | // or it hasn't ever been set so far. 126 | // In any case of those, we should regenerate it. 127 | // 128 | // As a consequence, CSRF check will fail when comparing the tokens later on, 129 | // so we don't have to fail it just yet. 130 | if len(realToken) != tokenLength { 131 | h.RegenerateToken(w, r) 132 | } else { 133 | ctxSetToken(r, realToken) 134 | } 135 | 136 | if sContains(safeMethods, r.Method) || h.IsExempt(r) { 137 | // short-circuit with a success for safe methods 138 | h.handleSuccess(w, r) 139 | return 140 | } 141 | 142 | // if the request is secure, we enforce origin check 143 | // for referer to prevent MITM of http->https requests 144 | if r.URL.Scheme == "https" { 145 | referer, err := url.Parse(r.Header.Get("Referer")) 146 | 147 | // if we can't parse the referer or it's empty, 148 | // we assume it's not specified 149 | if err != nil || referer.String() == "" { 150 | ctxSetReason(r, ErrNoReferer) 151 | h.handleFailure(w, r) 152 | return 153 | } 154 | 155 | // if the referer doesn't share origin with the request URL, 156 | // we have another error for that 157 | if !sameOrigin(referer, r.URL) { 158 | ctxSetReason(r, ErrBadReferer) 159 | h.handleFailure(w, r) 160 | return 161 | } 162 | } 163 | 164 | // Finally, we check the token itself. 
165 | sentToken := extractToken(r) 166 | 167 | if !verifyToken(realToken, sentToken) { 168 | ctxSetReason(r, ErrBadToken) 169 | h.handleFailure(w, r) 170 | return 171 | } 172 | 173 | // Everything else passed, handle the success. 174 | h.handleSuccess(w, r) 175 | } 176 | 177 | // handleSuccess simply calls the successHandler. 178 | // Everything else, like setting a token in the context 179 | // is taken care of by h.ServeHTTP() 180 | func (h *CSRFHandler) handleSuccess(w http.ResponseWriter, r *http.Request) { 181 | h.successHandler.ServeHTTP(w, r) 182 | } 183 | 184 | // Same applies here: h.ServeHTTP() sets the failure reason, the token, 185 | // and only then calls handleFailure() 186 | func (h *CSRFHandler) handleFailure(w http.ResponseWriter, r *http.Request) { 187 | h.failureHandler.ServeHTTP(w, r) 188 | } 189 | 190 | // Generates a new token, sets it on the given request and returns it 191 | func (h *CSRFHandler) RegenerateToken(w http.ResponseWriter, r *http.Request) string { 192 | token := generateToken() 193 | h.setTokenCookie(w, r, token) 194 | 195 | return Token(r) 196 | } 197 | 198 | func (h *CSRFHandler) setTokenCookie(w http.ResponseWriter, r *http.Request, token []byte) { 199 | // ctxSetToken() does the masking for us 200 | ctxSetToken(r, token) 201 | 202 | cookie := h.baseCookie 203 | cookie.Name = CookieName 204 | cookie.Value = b64encode(token) 205 | 206 | http.SetCookie(w, &cookie) 207 | 208 | } 209 | 210 | // Sets the handler to call in case the CSRF check 211 | // fails. By default it's defaultFailureHandler. 212 | func (h *CSRFHandler) SetFailureHandler(handler http.Handler) { 213 | h.failureHandler = handler 214 | } 215 | 216 | // Sets the base cookie to use when building a CSRF token cookie 217 | // This way you can specify the Domain, Path, HttpOnly, Secure, etc. 
218 | func (h *CSRFHandler) SetBaseCookie(cookie http.Cookie) { 219 | h.baseCookie = cookie 220 | } 221 | -------------------------------------------------------------------------------- /handler_test.go: -------------------------------------------------------------------------------- 1 | package nosurf 2 | 3 | import ( 4 | "io" 5 | "mime/multipart" 6 | "net/http" 7 | "net/http/httptest" 8 | "strings" 9 | "testing" 10 | ) 11 | 12 | func TestDefaultFailureHandler(t *testing.T) { 13 | writer := httptest.NewRecorder() 14 | req := dummyGet() 15 | 16 | defaultFailureHandler(writer, req) 17 | 18 | if writer.Code != FailureCode { 19 | t.Errorf("Wrong status code for defaultFailure Handler: "+ 20 | "expected %d, got %d", FailureCode, writer.Code) 21 | } 22 | } 23 | 24 | func TestSafeMethodsPass(t *testing.T) { 25 | handler := New(http.HandlerFunc(succHand)) 26 | 27 | for _, method := range safeMethods { 28 | req, err := http.NewRequest(method, "http://dummy.us", nil) 29 | 30 | if err != nil { 31 | t.Fatal(err) 32 | } 33 | 34 | writer := httptest.NewRecorder() 35 | handler.ServeHTTP(writer, req) 36 | 37 | expected := 200 38 | 39 | if writer.Code != expected { 40 | t.Errorf("A safe method didn't pass the CSRF check."+ 41 | "Expected HTTP status %d, got %d", expected, writer.Code) 42 | } 43 | 44 | writer.Flush() 45 | } 46 | } 47 | 48 | func TestExemptedPass(t *testing.T) { 49 | handler := New(http.HandlerFunc(succHand)) 50 | handler.ExemptPath("/faq") 51 | 52 | req, err := http.NewRequest("POST", "http://dummy.us/faq", strings.NewReader("a=b")) 53 | if err != nil { 54 | t.Fatal(err) 55 | } 56 | 57 | writer := httptest.NewRecorder() 58 | handler.ServeHTTP(writer, req) 59 | 60 | expected := 200 61 | 62 | if writer.Code != expected { 63 | t.Errorf("An exempted URL didn't pass the CSRF check."+ 64 | "Expected HTTP status %d, got %d", expected, writer.Code) 65 | } 66 | 67 | writer.Flush() 68 | } 69 | 70 | func TestManualVerify(t *testing.T) { 71 | var keepToken string 72 | hand := 
New(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 73 | if r.Method == "POST" { 74 | if !VerifyToken(Token(r), keepToken) { 75 | http.Error(w, "error", http.StatusBadRequest) 76 | } 77 | } else { 78 | keepToken = Token(r) 79 | } 80 | })) 81 | hand.ExemptPath("/") 82 | hand.SetFailureHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 83 | t.Errorf("Test failed. Reason: %v", Reason(r)) 84 | })) 85 | 86 | server := httptest.NewServer(hand) 87 | defer server.Close() 88 | 89 | // issue the first request to get the token 90 | resp, err := http.Get(server.URL) 91 | if err != nil { 92 | t.Fatal(err) 93 | } 94 | 95 | cookie := getRespCookie(resp, CookieName) 96 | if cookie == nil { 97 | t.Fatal("Cookie was not found in the response.") 98 | } 99 | 100 | // finalToken := b64encode(maskToken(b64decode(cookie.Value))) 101 | 102 | vals := [][]string{ 103 | {"name", "Jolene"}, 104 | } 105 | 106 | // Test usual POST 107 | { 108 | req, err := http.NewRequest("POST", server.URL, formBodyR(vals)) 109 | if err != nil { 110 | t.Fatal(err) 111 | } 112 | req.Header.Set("Content-Type", "application/x-www-form-urlencoded") 113 | req.AddCookie(cookie) 114 | 115 | resp, err = http.DefaultClient.Do(req) 116 | 117 | if err != nil { 118 | t.Fatal(err) 119 | } 120 | if resp.StatusCode != 200 { 121 | t.Errorf("The request should have succeeded, but it didn't. Instead, the code was %d", 122 | resp.StatusCode) 123 | } 124 | } 125 | } 126 | 127 | // Tests that the token/reason context is accessible 128 | // in the success/failure handlers 129 | func TestContextIsAccessible(t *testing.T) { 130 | // case 1: success 131 | succHand := func(w http.ResponseWriter, r *http.Request) { 132 | token := Token(r) 133 | if token == "" { 134 | t.Errorf("Token is inaccessible in the success handler") 135 | } 136 | } 137 | 138 | hand := New(http.HandlerFunc(succHand)) 139 | 140 | // we need a request that passes. Let's just use a safe method for that. 
141 | req := dummyGet() 142 | writer := httptest.NewRecorder() 143 | 144 | hand.ServeHTTP(writer, req) 145 | } 146 | 147 | func TestEmptyRefererFails(t *testing.T) { 148 | hand := New(http.HandlerFunc(succHand)) 149 | fhand := correctReason(t, ErrNoReferer) 150 | hand.SetFailureHandler(fhand) 151 | 152 | req, err := http.NewRequest("POST", "https://dummy.us/", strings.NewReader("a=b")) 153 | if err != nil { 154 | t.Fatal(err) 155 | } 156 | writer := httptest.NewRecorder() 157 | 158 | hand.ServeHTTP(writer, req) 159 | 160 | if writer.Code != FailureCode { 161 | t.Errorf("A POST request with no Referer should have failed with the code %d, but it didn't.", 162 | writer.Code) 163 | } 164 | } 165 | 166 | func TestDifferentOriginRefererFails(t *testing.T) { 167 | hand := New(http.HandlerFunc(succHand)) 168 | fhand := correctReason(t, ErrBadReferer) 169 | hand.SetFailureHandler(fhand) 170 | 171 | req, err := http.NewRequest("POST", "https://dummy.us/", strings.NewReader("a=b")) 172 | if err != nil { 173 | t.Fatal(err) 174 | } 175 | req.Header.Set("Referer", "http://attack-on-golang.com") 176 | writer := httptest.NewRecorder() 177 | 178 | hand.ServeHTTP(writer, req) 179 | 180 | if writer.Code != FailureCode { 181 | t.Errorf("A POST request with a Referer from a different origin"+ 182 | "should have failed with the code %d, but it didn't.", writer.Code) 183 | } 184 | } 185 | 186 | func TestNoTokenFails(t *testing.T) { 187 | hand := New(http.HandlerFunc(succHand)) 188 | fhand := correctReason(t, ErrBadToken) 189 | hand.SetFailureHandler(fhand) 190 | 191 | vals := [][]string{ 192 | {"name", "Jolene"}, 193 | } 194 | 195 | req, err := http.NewRequest("POST", "http://dummy.us", formBodyR(vals)) 196 | if err != nil { 197 | panic(err) 198 | } 199 | writer := httptest.NewRecorder() 200 | 201 | hand.ServeHTTP(writer, req) 202 | 203 | if writer.Code != FailureCode { 204 | t.Errorf("The check should've failed with the code %d, but instead, it"+ 205 | " returned code %d", FailureCode, 
writer.Code) 206 | } 207 | 208 | expectedContentType := "text/plain; charset=utf-8" 209 | actualContentType := writer.Header().Get("Content-Type") 210 | if actualContentType != expectedContentType { 211 | t.Errorf("The check should've failed with content type %s, but instead, it"+ 212 | " returned content type %s", expectedContentType, actualContentType) 213 | } 214 | } 215 | 216 | func TestWrongTokenFails(t *testing.T) { 217 | hand := New(http.HandlerFunc(succHand)) 218 | fhand := correctReason(t, ErrBadToken) 219 | hand.SetFailureHandler(fhand) 220 | 221 | vals := [][]string{ 222 | {"name", "Jolene"}, 223 | // this won't EVER be a valid value with the current scheme 224 | {FormFieldName, "$#%^&"}, 225 | } 226 | 227 | req, err := http.NewRequest("POST", "http://dummy.us", formBodyR(vals)) 228 | if err != nil { 229 | panic(err) 230 | } 231 | writer := httptest.NewRecorder() 232 | 233 | hand.ServeHTTP(writer, req) 234 | 235 | if writer.Code != FailureCode { 236 | t.Errorf("The check should've failed with the code %d, but instead, it"+ 237 | " returned code %d", FailureCode, writer.Code) 238 | } 239 | 240 | expectedContentType := "text/plain; charset=utf-8" 241 | actualContentType := writer.Header().Get("Content-Type") 242 | if actualContentType != expectedContentType { 243 | t.Errorf("The check should've failed with content type %s, but instead, it"+ 244 | " returned content type %s", expectedContentType, actualContentType) 245 | } 246 | } 247 | 248 | // For this and similar tests we start a test server 249 | // Since it's much easier to get the cookie 250 | // from a normal http.Response than from the recorder 251 | func TestCorrectTokenPasses(t *testing.T) { 252 | hand := New(http.HandlerFunc(succHand)) 253 | hand.SetFailureHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 254 | t.Errorf("Test failed. 
Reason: %v", Reason(r)) 255 | })) 256 | 257 | server := httptest.NewServer(hand) 258 | defer server.Close() 259 | 260 | // issue the first request to get the token 261 | resp, err := http.Get(server.URL) 262 | if err != nil { 263 | t.Fatal(err) 264 | } 265 | 266 | cookie := getRespCookie(resp, CookieName) 267 | if cookie == nil { 268 | t.Fatal("Cookie was not found in the response.") 269 | } 270 | 271 | finalToken := b64encode(maskToken(b64decode(cookie.Value))) 272 | 273 | vals := [][]string{ 274 | {"name", "Jolene"}, 275 | {FormFieldName, finalToken}, 276 | } 277 | 278 | // Test usual POST 279 | /* 280 | { 281 | req, err := http.NewRequest("POST", server.URL, formBodyR(vals)) 282 | if err != nil { 283 | t.Fatal(err) 284 | } 285 | req.Header.Set("Content-Type", "application/x-www-form-urlencoded") 286 | req.AddCookie(cookie) 287 | 288 | resp, err = http.DefaultClient.Do(req) 289 | 290 | if err != nil { 291 | t.Fatal(err) 292 | } 293 | if resp.StatusCode != 200 { 294 | t.Errorf("The request should have succeeded, but it didn't. Instead, the code was %d", 295 | resp.StatusCode) 296 | } 297 | } 298 | */ 299 | 300 | // Test multipart 301 | { 302 | prd, pwr := io.Pipe() 303 | wr := multipart.NewWriter(pwr) 304 | go func() { 305 | 306 | for _, v := range vals { 307 | wr.WriteField(v[0], v[1]) 308 | } 309 | 310 | err := wr.Close() 311 | if err != nil { 312 | t.Fatal(err) 313 | } 314 | err = pwr.Close() 315 | if err != nil { 316 | t.Fatal(err) 317 | } 318 | }() 319 | 320 | // Prepare a multipart request 321 | req, err := http.NewRequest("POST", server.URL, prd) 322 | if err != nil { 323 | t.Fatal(err) 324 | } 325 | 326 | req.Header.Add("Content-Type", wr.FormDataContentType()) 327 | req.AddCookie(cookie) 328 | 329 | resp, err := http.DefaultClient.Do(req) 330 | if err != nil { 331 | t.Fatal(err) 332 | } 333 | if resp.StatusCode != 200 { 334 | t.Errorf("The request should have succeeded, but it didn't. 
Instead, the code was %d", 335 | resp.StatusCode) 336 | } 337 | } 338 | } 339 | 340 | func TestPrefersHeaderOverFormValue(t *testing.T) { 341 | // Let's do a nice trick to find out this: 342 | // We'll set the correct token in the header 343 | // And a wrong one in the form. 344 | // That way, if it succeeds, 345 | // it will mean that it prefered the header. 346 | 347 | hand := New(http.HandlerFunc(succHand)) 348 | 349 | server := httptest.NewServer(hand) 350 | defer server.Close() 351 | 352 | resp, err := http.Get(server.URL) 353 | if err != nil { 354 | t.Fatal(err) 355 | } 356 | 357 | cookie := getRespCookie(resp, CookieName) 358 | if cookie == nil { 359 | t.Fatal("Cookie was not found in the response.") 360 | } 361 | 362 | finalToken := b64encode(maskToken(b64decode(cookie.Value))) 363 | 364 | vals := [][]string{ 365 | {"name", "Jolene"}, 366 | {FormFieldName, "a very wrong value"}, 367 | } 368 | 369 | req, err := http.NewRequest("POST", server.URL, formBodyR(vals)) 370 | if err != nil { 371 | t.Fatal(err) 372 | } 373 | req.Header.Set("Content-Type", "application/x-www-form-urlencoded") 374 | req.Header.Set(HeaderName, finalToken) 375 | req.AddCookie(cookie) 376 | 377 | resp, err = http.DefaultClient.Do(req) 378 | 379 | if err != nil { 380 | t.Fatal(err) 381 | } 382 | if resp.StatusCode != 200 { 383 | t.Errorf("The request should have succeeded, but it didn't. Instead, the code was %d", 384 | resp.StatusCode) 385 | } 386 | } 387 | 388 | func TestAddsVaryCookieHeader(t *testing.T) { 389 | hand := New(http.HandlerFunc(succHand)) 390 | writer := httptest.NewRecorder() 391 | req := dummyGet() 392 | 393 | hand.ServeHTTP(writer, req) 394 | 395 | if !sContains(writer.Header()["Vary"], "Cookie") { 396 | t.Errorf("CSRFHandler didn't add a `Vary: Cookie` header.") 397 | } 398 | } 399 | --------------------------------------------------------------------------------