├── handlers_pre18.go ├── .travis.yml ├── doc.go ├── handlers_go18.go ├── handlers_go18_test.go ├── LICENSE ├── recovery_test.go ├── canonical.go ├── recovery.go ├── README.md ├── canonical_test.go ├── proxy_headers_test.go ├── compress.go ├── proxy_headers.go ├── compress_test.go ├── cors.go ├── cors_test.go ├── handlers_test.go └── handlers.go /handlers_pre18.go: -------------------------------------------------------------------------------- 1 | // +build !go1.8 2 | 3 | package handlers 4 | 5 | type loggingResponseWriter interface { 6 | commonLoggingResponseWriter 7 | } 8 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | sudo: false 3 | 4 | matrix: 5 | include: 6 | - go: 1.4 7 | - go: 1.5 8 | - go: 1.6 9 | - go: 1.7 10 | - go: 1.8 11 | - go: tip 12 | allow_failures: 13 | - go: tip 14 | 15 | script: 16 | - go get -t -v ./... 17 | - diff -u <(echo -n) <(gofmt -d .) 18 | - go vet $(go list ./... | grep -v /vendor/) 19 | - go test -v -race ./... 20 | -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | /* 2 | Package handlers is a collection of handlers (aka "HTTP middleware") for use 3 | with Go's net/http package (or any framework supporting http.Handler). 4 | 5 | The package includes handlers for logging in standardised formats, compressing 6 | HTTP responses, validating content types and other useful tools for manipulating 7 | requests and responses.
8 | */ 9 | package handlers 10 | -------------------------------------------------------------------------------- /handlers_go18.go: -------------------------------------------------------------------------------- 1 | // +build go1.8 2 | 3 | package handlers 4 | 5 | import "net/http" 6 | 7 | type loggingResponseWriter interface { 8 | commonLoggingResponseWriter 9 | http.Pusher 10 | } 11 | 12 | // Push implements http.Pusher by delegating to the wrapped ResponseWriter. 13 | // It returns http.ErrNotSupported when that writer does not support push. 14 | func (l *responseLogger) Push(target string, opts *http.PushOptions) error { 15 | p, ok := l.w.(http.Pusher) 16 | if !ok { 17 | return http.ErrNotSupported 18 | } 19 | return p.Push(target, opts) 20 | } 21 | -------------------------------------------------------------------------------- /handlers_go18_test.go: -------------------------------------------------------------------------------- 1 | // +build go1.8 2 | 3 | package handlers 4 | 5 | import ( 6 | "io/ioutil" 7 | "net/http" 8 | "net/http/httptest" 9 | "testing" 10 | ) 11 | 12 | func TestLoggingHandlerWithPush(t *testing.T) { 13 | handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 14 | if _, ok := w.(http.Pusher); !ok { 15 | t.Fatalf("%T from LoggingHandler does not satisfy http.Pusher interface when built with Go >=1.8", w) 16 | } 17 | w.WriteHeader(200) 18 | }) 19 | 20 | logger := LoggingHandler(ioutil.Discard, handler) 21 | logger.ServeHTTP(httptest.NewRecorder(), newRequest("GET", "/")) 22 | } 23 | 24 | func TestCombinedLoggingHandlerWithPush(t *testing.T) { 25 | handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 26 | if _, ok := w.(http.Pusher); !ok { 27 | t.Fatalf("%T from CombinedLoggingHandler does not satisfy http.Pusher interface when built with Go >=1.8", w) 28 | } 29 | w.WriteHeader(200) 30 | }) 31 | 32 | logger := CombinedLoggingHandler(ioutil.Discard, handler) 33 | logger.ServeHTTP(httptest.NewRecorder(), newRequest("GET", "/")) 34 | } 35 | 36 | func TestLoggingHandlerPushNotSupported(t *testing.T) { 37 | handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 38 | p, ok := w.(http.Pusher) 39 | if !ok { 40 | t.Fatalf("%T from LoggingHandler does not satisfy http.Pusher interface when built with Go >=1.8", w) 41 | } 42 | // The underlying httptest.ResponseRecorder cannot push, so Push must report ErrNotSupported. 43 | if err := p.Push("/static/style.css", nil); err != http.ErrNotSupported { 44 | t.Errorf("Push: got error %v, want %v", err, http.ErrNotSupported) 45 | } 46 | w.WriteHeader(200) 47 | }) 48 | 49 | logger := LoggingHandler(ioutil.Discard, handler) 50 | logger.ServeHTTP(httptest.NewRecorder(), newRequest("GET", "/")) 51 | } 52 |
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2013 The Gorilla Handlers Authors. All rights reserved. 2 | 3 | Redistribution and use in source and binary forms, with or without 4 | modification, are permitted provided that the following conditions are met: 5 | 6 | Redistributions of source code must retain the above copyright notice, this 7 | list of conditions and the following disclaimer. 8 | 9 | Redistributions in binary form must reproduce the above copyright notice, 10 | this list of conditions and the following disclaimer in the documentation 11 | and/or other materials provided with the distribution. 12 | 13 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 14 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 15 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 16 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 17 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 18 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 19 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 20 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 21 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 22 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
23 | -------------------------------------------------------------------------------- /recovery_test.go: -------------------------------------------------------------------------------- 1 | package handlers 2 | 3 | import ( 4 | "bytes" 5 | "log" 6 | "net/http" 7 | "net/http/httptest" 8 | "os" 9 | "strings" 10 | "testing" 11 | ) 12 | 13 | func TestRecoveryLoggerWithDefaultOptions(t *testing.T) { 14 | var buf bytes.Buffer 15 | log.SetOutput(&buf) 16 | // Restore the standard logger's default output (os.Stderr) so later tests are unaffected. 17 | defer log.SetOutput(os.Stderr) 18 | 19 | handler := RecoveryHandler() 20 | handlerFunc := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 21 | panic("Unexpected error!") 22 | }) 23 | 24 | recovery := handler(handlerFunc) 25 | recovery.ServeHTTP(httptest.NewRecorder(), newRequest("GET", "/subdir/asdf")) 26 | 27 | if !strings.Contains(buf.String(), "Unexpected error!") { 28 | t.Fatalf("Got log %#v, wanted substring %#v", buf.String(), "Unexpected error!") 29 | } 30 | } 31 | 32 | func TestRecoveryLoggerWithCustomLogger(t *testing.T) { 33 | var buf bytes.Buffer 34 | var logger = log.New(&buf, "", log.LstdFlags) 35 | 36 | handler := RecoveryHandler(RecoveryLogger(logger), PrintRecoveryStack(false)) 37 | handlerFunc := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 38 | panic("Unexpected error!") 39 | }) 40 | 41 | recovery := handler(handlerFunc) 42 | recovery.ServeHTTP(httptest.NewRecorder(), newRequest("GET", "/subdir/asdf")) 43 | 44 | if !strings.Contains(buf.String(), "Unexpected error!") { 45 | t.Fatalf("Got log %#v, wanted substring %#v", buf.String(), "Unexpected error!") 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /canonical.go: -------------------------------------------------------------------------------- 1 | package handlers 2 | 3 | import ( 4 | "net/http" 5 | "net/url" 6 | "strings" 7 | ) 8 | 9 | type canonical struct { 10 | h http.Handler 11 | domain string 12 | code int 13 | } 14 | 15 | // CanonicalHost is HTTP middleware that re-directs requests to the canonical
16 | // domain. It accepts a domain and a status code (e.g. 301 or 302) and 17 | // re-directs clients to this domain. The existing request path and query string are maintained. 18 | // 19 | // Note: If the provided domain is considered invalid by url.Parse or otherwise 20 | // returns an empty scheme or host, clients are not re-directed. 21 | // 22 | // Example: 23 | // 24 | // r := mux.NewRouter() 25 | // canonical := handlers.CanonicalHost("http://www.gorillatoolkit.org", 302) 26 | // r.HandleFunc("/route", YourHandler) 27 | // 28 | // log.Fatal(http.ListenAndServe(":7000", canonical(r))) 29 | // 30 | func CanonicalHost(domain string, code int) func(h http.Handler) http.Handler { 31 | fn := func(h http.Handler) http.Handler { 32 | return canonical{h, domain, code} 33 | } 34 | 35 | return fn 36 | } 37 | 38 | func (c canonical) ServeHTTP(w http.ResponseWriter, r *http.Request) { 39 | dest, err := url.Parse(c.domain) 40 | if err != nil { 41 | // Call the next handler if the provided domain fails to parse. 42 | c.h.ServeHTTP(w, r) 43 | return 44 | } 45 | 46 | if dest.Scheme == "" || dest.Host == "" { 47 | // Call the next handler if the scheme or host are empty. 48 | // Note that url.Parse won't fail in this case. 49 | c.h.ServeHTTP(w, r) 50 | return 51 | } 52 | 53 | if !strings.EqualFold(cleanHost(r.Host), dest.Host) { 54 | // Re-build the destination URL, keeping the request path and raw query. 55 | dest := dest.Scheme + "://" + dest.Host + r.URL.Path 56 | if r.URL.RawQuery != "" { 57 | dest += "?" + r.URL.RawQuery 58 | } 59 | http.Redirect(w, r, dest, c.code) 60 | return 61 | } 62 | 63 | c.h.ServeHTTP(w, r) 64 | } 65 | 66 | // cleanHost cleans invalid Host headers by stripping anything after '/' or ' '. 67 | // This is backported from Go 1.5 (in response to issue #11206) and attempts to 68 | // mitigate malformed Host headers that do not match the format in RFC7230.
69 | func cleanHost(in string) string { 70 | if i := strings.IndexAny(in, " /"); i != -1 { 71 | return in[:i] 72 | } 73 | return in 74 | } 75 | -------------------------------------------------------------------------------- /recovery.go: -------------------------------------------------------------------------------- 1 | package handlers 2 | 3 | import ( 4 | "log" 5 | "net/http" 6 | "runtime/debug" 7 | ) 8 | 9 | // RecoveryHandlerLogger is an interface used by the recovering handler to print logs. 10 | type RecoveryHandlerLogger interface { 11 | Println(...interface{}) 12 | } 13 | 14 | type recoveryHandler struct { 15 | handler http.Handler 16 | logger RecoveryHandlerLogger 17 | printStack bool 18 | } 19 | 20 | // RecoveryOption provides a functional approach to define 21 | // configuration for a handler, such as setting the logger or 22 | // choosing whether to print stack traces on panic. 23 | type RecoveryOption func(http.Handler) 24 | 25 | func parseRecoveryOptions(h http.Handler, opts ...RecoveryOption) http.Handler { 26 | for _, option := range opts { 27 | option(h) 28 | } 29 | 30 | return h 31 | } 32 | 33 | // RecoveryHandler is HTTP middleware that recovers from a panic, 34 | // logs the panic, writes http.StatusInternalServerError, and 35 | // returns instead of letting the panic propagate further up the stack. 36 | // 37 | // Example: 38 | // 39 | // r := mux.NewRouter() 40 | // r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 41 | // panic("Unexpected error!") 42 | // }) 43 | // 44 | // http.ListenAndServe(":1123", handlers.RecoveryHandler()(r)) 45 | func RecoveryHandler(opts ...RecoveryOption) func(h http.Handler) http.Handler { 46 | return func(h http.Handler) http.Handler { 47 | r := &recoveryHandler{handler: h} 48 | return parseRecoveryOptions(r, opts...)
49 | } 50 | } 51 | 52 | // RecoveryLogger is a functional option to override 53 | // the default logger 54 | func RecoveryLogger(logger RecoveryHandlerLogger) RecoveryOption { 55 | return func(h http.Handler) { 56 | r := h.(*recoveryHandler) 57 | r.logger = logger 58 | } 59 | } 60 | 61 | // PrintRecoveryStack is a functional option to enable 62 | // or disable printing stack traces on panic. 63 | func PrintRecoveryStack(print bool) RecoveryOption { 64 | return func(h http.Handler) { 65 | r := h.(*recoveryHandler) 66 | r.printStack = print 67 | } 68 | } 69 | 70 | func (h recoveryHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { 71 | defer func() { 72 | if err := recover(); err != nil { 73 | w.WriteHeader(http.StatusInternalServerError) 74 | h.log(err) 75 | } 76 | }() 77 | 78 | h.handler.ServeHTTP(w, req) 79 | } 80 | 81 | func (h recoveryHandler) log(v ...interface{}) { 82 | if h.logger != nil { 83 | h.logger.Println(v...) 84 | } else { 85 | log.Println(v...) 86 | } 87 | 88 | if h.printStack { 89 | debug.PrintStack() 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | gorilla/handlers 2 | ================ 3 | [![GoDoc](https://godoc.org/github.com/gorilla/handlers?status.svg)](https://godoc.org/github.com/gorilla/handlers) [![Build Status](https://travis-ci.org/gorilla/handlers.svg?branch=master)](https://travis-ci.org/gorilla/handlers) 4 | [![Sourcegraph](https://sourcegraph.com/github.com/gorilla/handlers/-/badge.svg)](https://sourcegraph.com/github.com/gorilla/handlers?badge) 5 | 6 | 7 | Package handlers is a collection of handlers (aka "HTTP middleware") for use 8 | with Go's `net/http` package (or any framework supporting `http.Handler`), including: 9 | 10 | * [**LoggingHandler**](https://godoc.org/github.com/gorilla/handlers#LoggingHandler) for logging HTTP requests in the Apache [Common Log 11 | 
Format](http://httpd.apache.org/docs/2.2/logs.html#common). 12 | * [**CombinedLoggingHandler**](https://godoc.org/github.com/gorilla/handlers#CombinedLoggingHandler) for logging HTTP requests in the Apache [Combined Log 13 | Format](http://httpd.apache.org/docs/2.2/logs.html#combined) commonly used by 14 | both Apache and nginx. 15 | * [**CompressHandler**](https://godoc.org/github.com/gorilla/handlers#CompressHandler) for gzipping responses. 16 | * [**ContentTypeHandler**](https://godoc.org/github.com/gorilla/handlers#ContentTypeHandler) for validating requests against a list of accepted 17 | content types. 18 | * [**MethodHandler**](https://godoc.org/github.com/gorilla/handlers#MethodHandler) for matching HTTP methods against handlers in a 19 | `map[string]http.Handler` 20 | * [**ProxyHeaders**](https://godoc.org/github.com/gorilla/handlers#ProxyHeaders) for populating `r.RemoteAddr` and `r.URL.Scheme` based on the 21 | `X-Forwarded-For`, `X-Real-IP`, `X-Forwarded-Proto` and RFC7239 `Forwarded` 22 | headers when running a Go server behind a HTTP reverse proxy. 23 | * [**CanonicalHost**](https://godoc.org/github.com/gorilla/handlers#CanonicalHost) for re-directing to the preferred host when handling multiple 24 | domains (i.e. multiple CNAME aliases). 25 | * [**RecoveryHandler**](https://godoc.org/github.com/gorilla/handlers#RecoveryHandler) for recovering from unexpected panics. 26 | 27 | Other handlers are documented [on the Gorilla 28 | website](http://www.gorillatoolkit.org/pkg/handlers). 
29 | 30 | ## Example 31 | 32 | A simple example using `handlers.LoggingHandler` and `handlers.CompressHandler`: 33 | 34 | ```go 35 | import ( 36 | "net/http" 37 | "os" 38 | 39 | "github.com/gorilla/handlers" 40 | ) 41 | 42 | func main() { 43 | r := http.NewServeMux() 44 | 45 | // Only log requests to our admin dashboard to stdout 46 | r.Handle("/admin", handlers.LoggingHandler(os.Stdout, http.HandlerFunc(ShowAdminDashboard))) 47 | r.HandleFunc("/", ShowIndex) 48 | 49 | // Wrap our server with our gzip handler to gzip compress all responses. 50 | http.ListenAndServe(":8000", handlers.CompressHandler(r)) 51 | } 52 | ``` 53 | 54 | ## License 55 | 56 | BSD licensed. See the included LICENSE file for details. 57 | 58 | -------------------------------------------------------------------------------- /canonical_test.go: -------------------------------------------------------------------------------- 1 | package handlers 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "log" 7 | "net/http" 8 | "net/http/httptest" 9 | "net/url" 10 | "strings" 11 | "testing" 12 | ) 13 | 14 | func TestCleanHost(t *testing.T) { 15 | tests := []struct { 16 | in, want string 17 | }{ 18 | {"www.google.com", "www.google.com"}, 19 | {"www.google.com foo", "www.google.com"}, 20 | {"www.google.com/foo", "www.google.com"}, 21 | {" first character is a space", ""}, 22 | } 23 | for _, tt := range tests { 24 | got := cleanHost(tt.in) 25 | if tt.want != got { 26 | t.Errorf("cleanHost(%q) = %q, want %q", tt.in, got, tt.want) 27 | } 28 | } 29 | } 30 | 31 | func TestCanonicalHost(t *testing.T) { 32 | gorilla := "http://www.gorillatoolkit.org" 33 | 34 | rr := httptest.NewRecorder() 35 | r := newRequest("GET", "http://www.example.com/") 36 | 37 | testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 38 | 39 | // Test a re-direct: should return a 302 Found.
40 | CanonicalHost(gorilla, http.StatusFound)(testHandler).ServeHTTP(rr, r) 41 | 42 | if rr.Code != http.StatusFound { 43 | t.Fatalf("bad status: got %v want %v", rr.Code, http.StatusFound) 44 | } 45 | 46 | if rr.Header().Get("Location") != gorilla+r.URL.Path { 47 | t.Fatalf("bad re-direct: got %q want %q", rr.Header().Get("Location"), gorilla+r.URL.Path) 48 | } 49 | 50 | } 51 | 52 | func TestKeepsQueryString(t *testing.T) { 53 | google := "https://www.google.com" 54 | 55 | rr := httptest.NewRecorder() 56 | querystring := url.Values{"q": {"golang"}, "format": {"json"}}.Encode() 57 | r := newRequest("GET", "http://www.example.com/search?"+querystring) 58 | 59 | testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 60 | CanonicalHost(google, http.StatusFound)(testHandler).ServeHTTP(rr, r) 61 | 62 | want := google + r.URL.Path + "?" + querystring 63 | if rr.Header().Get("Location") != want { 64 | t.Fatalf("bad re-direct: got %q want %q", rr.Header().Get("Location"), want) 65 | } 66 | } 67 | 68 | func TestBadDomain(t *testing.T) { 69 | rr := httptest.NewRecorder() 70 | r := newRequest("GET", "http://www.example.com/") 71 | 72 | testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 73 | 74 | // Test a bad domain - should return 200 OK. 75 | CanonicalHost("%", http.StatusFound)(testHandler).ServeHTTP(rr, r) 76 | 77 | if rr.Code != http.StatusOK { 78 | t.Fatalf("bad status: got %v want %v", rr.Code, http.StatusOK) 79 | } 80 | } 81 | 82 | func TestEmptyHost(t *testing.T) { 83 | rr := httptest.NewRecorder() 84 | r := newRequest("GET", "http://www.example.com/") 85 | 86 | testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 87 | 88 | // Test a domain that returns an empty url.Host from url.Parse. 
89 | CanonicalHost("hello.com", http.StatusFound)(testHandler).ServeHTTP(rr, r) 90 | 91 | if rr.Code != http.StatusOK { 92 | t.Fatalf("bad status: got %v want %v", rr.Code, http.StatusOK) 93 | } 94 | } 95 | 96 | func TestHeaderWrites(t *testing.T) { 97 | gorilla := "http://www.gorillatoolkit.org" 98 | 99 | testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 100 | w.WriteHeader(200) 101 | }) 102 | 103 | // Catch the log output to ensure we don't write multiple headers. 104 | var b bytes.Buffer 105 | buf := bufio.NewWriter(&b) 106 | tl := log.New(buf, "test: ", log.Lshortfile) 107 | 108 | srv := httptest.NewServer( 109 | CanonicalHost(gorilla, http.StatusFound)(testHandler)) 110 | defer srv.Close() 111 | srv.Config.ErrorLog = tl 112 | 113 | _, err := http.Get(srv.URL) 114 | if err != nil { 115 | t.Fatal(err) 116 | } 117 | 118 | err = buf.Flush() 119 | if err != nil { 120 | t.Fatal(err) 121 | } 122 | 123 | // We rely on the error not changing: net/http does not export it. 
124 | if strings.Contains(b.String(), "multiple response.WriteHeader calls") { 125 | t.Fatalf("re-direct did not return early: multiple header writes") 126 | } 127 | } 128 | -------------------------------------------------------------------------------- /proxy_headers_test.go: -------------------------------------------------------------------------------- 1 | package handlers 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "testing" 7 | ) 8 | 9 | type headerTable struct { 10 | key string // header key 11 | val string // header val 12 | expected string // expected result 13 | } 14 | 15 | func TestGetIP(t *testing.T) { 16 | headers := []headerTable{ 17 | {xForwardedFor, "8.8.8.8", "8.8.8.8"}, // Single address 18 | {xForwardedFor, "8.8.8.8, 8.8.4.4", "8.8.8.8"}, // Multiple 19 | {xForwardedFor, "[2001:db8:cafe::17]:4711", "[2001:db8:cafe::17]:4711"}, // IPv6 address 20 | {xForwardedFor, "", ""}, // None 21 | {xRealIP, "8.8.8.8", "8.8.8.8"}, // Single address 22 | {xRealIP, "8.8.8.8, 8.8.4.4", "8.8.8.8, 8.8.4.4"}, // Multiple 23 | {xRealIP, "[2001:db8:cafe::17]:4711", "[2001:db8:cafe::17]:4711"}, // IPv6 address 24 | {xRealIP, "", ""}, // None 25 | {forwarded, `for="_gazonk"`, "_gazonk"}, // Hostname 26 | {forwarded, `For="[2001:db8:cafe::17]:4711`, `[2001:db8:cafe::17]:4711`}, // IPv6 address 27 | {forwarded, `for=192.0.2.60;proto=http;by=203.0.113.43`, `192.0.2.60`}, // Multiple params 28 | {forwarded, `for=192.0.2.43, for=198.51.100.17`, "192.0.2.43"}, // Multiple params 29 | {forwarded, `for="workstation.local",for=198.51.100.17`, "workstation.local"}, // Hostname 30 | } 31 | 32 | for _, v := range headers { 33 | req := &http.Request{ 34 | Header: http.Header{ 35 | v.key: []string{v.val}, 36 | }} 37 | res := getIP(req) 38 | if res != v.expected { 39 | t.Fatalf("wrong header for %s: got %s want %s", v.key, res, 40 | v.expected) 41 | } 42 | } 43 | } 44 | 45 | func TestGetScheme(t *testing.T) { 46 | headers := []headerTable{ 47 | {xForwardedProto, 
"https", "https"}, 48 | {xForwardedProto, "http", "http"}, 49 | {xForwardedProto, "HTTP", "http"}, 50 | {xForwardedScheme, "https", "https"}, 51 | {xForwardedScheme, "http", "http"}, 52 | {xForwardedScheme, "HTTP", "http"}, 53 | {forwarded, `For="[2001:db8:cafe::17]:4711`, ""}, // No proto 54 | {forwarded, `for=192.0.2.43, for=198.51.100.17;proto=https`, "https"}, // Multiple params before proto 55 | {forwarded, `for=172.32.10.15; proto=https;by=127.0.0.1`, "https"}, // Space before proto 56 | {forwarded, `for=192.0.2.60;proto=http;by=203.0.113.43`, "http"}, // Multiple params 57 | } 58 | 59 | for _, v := range headers { 60 | req := &http.Request{ 61 | Header: http.Header{ 62 | v.key: []string{v.val}, 63 | }, 64 | } 65 | res := getScheme(req) 66 | if res != v.expected { 67 | t.Fatalf("wrong header for %s: got %s want %s", v.key, res, 68 | v.expected) 69 | } 70 | } 71 | } 72 | 73 | // Test the middleware end-to-end 74 | func TestProxyHeaders(t *testing.T) { 75 | rr := httptest.NewRecorder() 76 | r := newRequest("GET", "/") 77 | 78 | r.Header.Set(xForwardedFor, "8.8.8.8") 79 | r.Header.Set(xForwardedProto, "https") 80 | r.Header.Set(xForwardedHost, "google.com") 81 | var ( 82 | addr string 83 | proto string 84 | host string 85 | ) 86 | ProxyHeaders(http.HandlerFunc( 87 | func(w http.ResponseWriter, r *http.Request) { 88 | addr = r.RemoteAddr 89 | proto = r.URL.Scheme 90 | host = r.Host 91 | })).ServeHTTP(rr, r) 92 | 93 | if rr.Code != http.StatusOK { 94 | t.Fatalf("bad status: got %d want %d", rr.Code, http.StatusOK) 95 | } 96 | 97 | if addr != r.Header.Get(xForwardedFor) { 98 | t.Fatalf("wrong address: got %s want %s", addr, 99 | r.Header.Get(xForwardedFor)) 100 | } 101 | 102 | if proto != r.Header.Get(xForwardedProto) { 103 | t.Fatalf("wrong address: got %s want %s", proto, 104 | r.Header.Get(xForwardedProto)) 105 | } 106 | if host != r.Header.Get(xForwardedHost) { 107 | t.Fatalf("wrong address: got %s want %s", host, 108 | r.Header.Get(xForwardedHost)) 109 | } 110 
| 111 | } 112 | -------------------------------------------------------------------------------- /compress.go: -------------------------------------------------------------------------------- 1 | // Copyright 2013 The Gorilla Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package handlers 6 | 7 | import ( 8 | "compress/flate" 9 | "compress/gzip" 10 | "io" 11 | "net/http" 12 | "strings" 13 | ) 14 | 15 | type compressResponseWriter struct { 16 | io.Writer 17 | http.ResponseWriter 18 | http.Hijacker 19 | http.Flusher 20 | http.CloseNotifier 21 | } 22 | 23 | func (w *compressResponseWriter) WriteHeader(c int) { 24 | w.ResponseWriter.Header().Del("Content-Length") 25 | w.ResponseWriter.WriteHeader(c) 26 | } 27 | 28 | func (w *compressResponseWriter) Header() http.Header { 29 | return w.ResponseWriter.Header() 30 | } 31 | 32 | func (w *compressResponseWriter) Write(b []byte) (int, error) { 33 | h := w.ResponseWriter.Header() 34 | if h.Get("Content-Type") == "" { 35 | h.Set("Content-Type", http.DetectContentType(b)) 36 | } 37 | h.Del("Content-Length") 38 | 39 | return w.Writer.Write(b) 40 | } 41 | 42 | type flusher interface { 43 | Flush() error 44 | } 45 | 46 | func (w *compressResponseWriter) Flush() { 47 | // Flush compressed data if compressor supports it. 48 | if f, ok := w.Writer.(flusher); ok { 49 | f.Flush() 50 | } 51 | // Flush HTTP response. 52 | if w.Flusher != nil { 53 | w.Flusher.Flush() 54 | } 55 | } 56 | 57 | // CompressHandler gzip compresses HTTP responses for clients that support it 58 | // via the 'Accept-Encoding' header. 
59 | // 60 | // Compressing TLS traffic may leak the page contents to an attacker if the 61 | // page contains user input: http://security.stackexchange.com/a/102015/12208 62 | func CompressHandler(h http.Handler) http.Handler { 63 | return CompressHandlerLevel(h, gzip.DefaultCompression) 64 | } 65 | 66 | // CompressHandlerLevel gzip compresses HTTP responses with specified compression level 67 | // for clients that support it via the 'Accept-Encoding' header. 68 | // 69 | // The compression level should be gzip.DefaultCompression, gzip.NoCompression, 70 | // or any integer value between gzip.BestSpeed and gzip.BestCompression inclusive. 71 | // gzip.DefaultCompression is used in case of invalid compression level. 72 | func CompressHandlerLevel(h http.Handler, level int) http.Handler { 73 | if level < gzip.DefaultCompression || level > gzip.BestCompression { 74 | level = gzip.DefaultCompression 75 | } 76 | 77 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 78 | L: 79 | for _, enc := range strings.Split(r.Header.Get("Accept-Encoding"), ",") { 80 | switch strings.TrimSpace(enc) { 81 | case "gzip": 82 | w.Header().Set("Content-Encoding", "gzip") 83 | w.Header().Add("Vary", "Accept-Encoding") 84 | 85 | gw, _ := gzip.NewWriterLevel(w, level) 86 | defer gw.Close() 87 | 88 | h, hok := w.(http.Hijacker) 89 | if !hok { /* w is not Hijacker... oh well... 
*/ 90 | h = nil 91 | } 92 | 93 | f, fok := w.(http.Flusher) 94 | if !fok { 95 | f = nil 96 | } 97 | 98 | cn, cnok := w.(http.CloseNotifier) 99 | if !cnok { 100 | cn = nil 101 | } 102 | 103 | w = &compressResponseWriter{ 104 | Writer: gw, 105 | ResponseWriter: w, 106 | Hijacker: h, 107 | Flusher: f, 108 | CloseNotifier: cn, 109 | } 110 | 111 | break L 112 | case "deflate": 113 | w.Header().Set("Content-Encoding", "deflate") 114 | w.Header().Add("Vary", "Accept-Encoding") 115 | 116 | fw, _ := flate.NewWriter(w, level) 117 | defer fw.Close() 118 | 119 | h, hok := w.(http.Hijacker) 120 | if !hok { /* w is not Hijacker... oh well... */ 121 | h = nil 122 | } 123 | 124 | f, fok := w.(http.Flusher) 125 | if !fok { 126 | f = nil 127 | } 128 | 129 | cn, cnok := w.(http.CloseNotifier) 130 | if !cnok { 131 | cn = nil 132 | } 133 | 134 | w = &compressResponseWriter{ 135 | Writer: fw, 136 | ResponseWriter: w, 137 | Hijacker: h, 138 | Flusher: f, 139 | CloseNotifier: cn, 140 | } 141 | 142 | break L 143 | } 144 | } 145 | 146 | h.ServeHTTP(w, r) 147 | }) 148 | } 149 | -------------------------------------------------------------------------------- /proxy_headers.go: -------------------------------------------------------------------------------- 1 | package handlers 2 | 3 | import ( 4 | "net/http" 5 | "regexp" 6 | "strings" 7 | ) 8 | 9 | var ( 10 | // De-facto standard header keys. 11 | xForwardedFor = http.CanonicalHeaderKey("X-Forwarded-For") 12 | xForwardedHost = http.CanonicalHeaderKey("X-Forwarded-Host") 13 | xForwardedProto = http.CanonicalHeaderKey("X-Forwarded-Proto") 14 | xForwardedScheme = http.CanonicalHeaderKey("X-Forwarded-Scheme") 15 | xRealIP = http.CanonicalHeaderKey("X-Real-IP") 16 | ) 17 | 18 | var ( 19 | // RFC7239 defines a new "Forwarded: " header designed to replace the 20 | // existing use of X-Forwarded-* headers. 21 | // e.g. 
Forwarded: for=192.0.2.60;proto=https;by=203.0.113.43 22 | forwarded = http.CanonicalHeaderKey("Forwarded") 23 | // Allows for a sub-match of the first value after 'for=' to the next 24 | // comma, semi-colon or space. The match is case-insensitive. 25 | forRegex = regexp.MustCompile(`(?i)(?:for=)([^;, ]+)`) 26 | // Allows for a sub-match for the first instance of scheme (http|https) 27 | // prefixed by 'proto='. The match is case-insensitive. 28 | protoRegex = regexp.MustCompile(`(?i)(?:proto=)(https|http)`) 29 | ) 30 | 31 | // ProxyHeaders inspects common reverse proxy headers and sets the corresponding 32 | // fields in the HTTP request struct. These are X-Forwarded-For and X-Real-IP 33 | // for the remote (client) IP address, X-Forwarded-Proto or X-Forwarded-Scheme 34 | // for the scheme (http|https), X-Forwarded-Host for the host, and the RFC7239 35 | // Forwarded header, which may include both client IPs and schemes. 36 | // 37 | // NOTE: This middleware should only be used when behind a reverse 38 | // proxy like nginx, HAProxy or Apache. Reverse proxies that don't (or are 39 | // configured not to) strip these headers from client requests, or where these 40 | // headers are accepted "as is" from a remote client (e.g. when Go is not behind 41 | // a proxy), can manifest as a vulnerability if your application uses these 42 | // headers for validating the 'trustworthiness' of a request. 43 | func ProxyHeaders(h http.Handler) http.Handler { 44 | fn := func(w http.ResponseWriter, r *http.Request) { 45 | // Set the remote IP with the value passed from the proxy. 46 | if fwd := getIP(r); fwd != "" { 47 | r.RemoteAddr = fwd 48 | } 49 | 50 | // Set the scheme (proto) with the value passed from the proxy. 51 | if scheme := getScheme(r); scheme != "" { 52 | r.URL.Scheme = scheme 53 | } 54 | // Set the host with the value passed by the proxy. 55 | if host := r.Header.Get(xForwardedHost); host != "" { 56 | r.Host = host 57 | } 58 | // Call the next handler in the chain.
59 | h.ServeHTTP(w, r) 60 | } 61 | 62 | return http.HandlerFunc(fn) 63 | } 64 | 65 | // getIP retrieves the IP from the X-Forwarded-For, X-Real-IP and RFC7239 66 | // Forwarded headers (in that order). 67 | func getIP(r *http.Request) string { 68 | var addr string 69 | 70 | if fwd := r.Header.Get(xForwardedFor); fwd != "" { 71 | // Only grab the first (client) address. Note that '192.168.0.1, 72 | // 10.1.1.1' is a valid value for X-Forwarded-For, where addresses after 73 | // the first may represent forwarding proxies; whitespace after the comma is optional. 74 | s := strings.Index(fwd, ",") 75 | if s == -1 { 76 | s = len(fwd) 77 | } 78 | addr = strings.TrimSpace(fwd[:s]) 79 | } else if fwd := r.Header.Get(xRealIP); fwd != "" { 80 | // X-Real-IP should only contain one IP address (the client making the 81 | // request). 82 | addr = fwd 83 | } else if fwd := r.Header.Get(forwarded); fwd != "" { 84 | // match should contain at least two elements if a 'for=' parameter was 85 | // present in the Forwarded header. The first element is the whole 86 | // match, which we ignore. In the case of multiple forwarded elements 87 | // (e.g. for=192.0.2.43, for=198.51.100.17) we only extract the 88 | // first, which should be the client IP. 89 | if match := forRegex.FindStringSubmatch(fwd); len(match) > 1 { 90 | // IPv6 addresses in Forwarded headers are quoted-strings. We strip 91 | // these quotes. 92 | addr = strings.Trim(match[1], `"`) 93 | } 94 | } 95 | 96 | return addr 97 | } 98 | 99 | // getScheme retrieves the scheme from the X-Forwarded-Proto, X-Forwarded-Scheme 100 | // and RFC7239 Forwarded headers (in that order). 101 | func getScheme(r *http.Request) string { 102 | var scheme string 103 | 104 | // Retrieve the scheme from X-Forwarded-Proto.
105 | if proto := r.Header.Get(xForwardedProto); proto != "" { 106 | scheme = strings.ToLower(proto) 107 | } else if proto = r.Header.Get(xForwardedScheme); proto != "" { 108 | scheme = strings.ToLower(proto) 109 | } else if proto = r.Header.Get(forwarded); proto != "" { 110 | // match should contain at least two elements if the protocol was 111 | // specified in the Forwarded header. The first element will always be 112 | // the 'proto=' capture, which we ignore. In the case of multiple proto 113 | // parameters (invalid) we only extract the first. 114 | if match := protoRegex.FindStringSubmatch(proto); len(match) > 1 { 115 | scheme = strings.ToLower(match[1]) 116 | } 117 | } 118 | 119 | return scheme 120 | } 121 | -------------------------------------------------------------------------------- /compress_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2013 The Gorilla Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package handlers 6 | 7 | import ( 8 | "bufio" 9 | "io" 10 | "net" 11 | "net/http" 12 | "net/http/httptest" 13 | "strconv" 14 | "testing" 15 | ) 16 | 17 | var contentType = "text/plain; charset=utf-8" 18 | 19 | func compressedRequest(w *httptest.ResponseRecorder, compression string) { 20 | CompressHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 21 | w.Header().Set("Content-Length", strconv.Itoa(9*1024)) 22 | w.Header().Set("Content-Type", contentType) 23 | for i := 0; i < 1024; i++ { 24 | io.WriteString(w, "Gorilla!\n") 25 | } 26 | })).ServeHTTP(w, &http.Request{ 27 | Method: "GET", 28 | Header: http.Header{ 29 | "Accept-Encoding": []string{compression}, 30 | }, 31 | }) 32 | 33 | } 34 | 35 | func TestCompressHandlerNoCompression(t *testing.T) { 36 | w := httptest.NewRecorder() 37 | compressedRequest(w, "") 38 | if enc := w.HeaderMap.Get("Content-Encoding"); enc != "" { 39 | t.Errorf("wrong content encoding, got %q want %q", enc, "") 40 | } 41 | if ct := w.HeaderMap.Get("Content-Type"); ct != contentType { 42 | t.Errorf("wrong content type, got %q want %q", ct, contentType) 43 | } 44 | if w.Body.Len() != 1024*9 { 45 | t.Errorf("wrong len, got %d want %d", w.Body.Len(), 1024*9) 46 | } 47 | if l := w.HeaderMap.Get("Content-Length"); l != "9216" { 48 | t.Errorf("wrong content-length. 
got %q expected %d", l, 1024*9) 49 | } 50 | } 51 | 52 | func TestCompressHandlerGzip(t *testing.T) { 53 | w := httptest.NewRecorder() 54 | compressedRequest(w, "gzip") 55 | if w.HeaderMap.Get("Content-Encoding") != "gzip" { 56 | t.Errorf("wrong content encoding, got %q want %q", w.HeaderMap.Get("Content-Encoding"), "gzip") 57 | } 58 | if w.HeaderMap.Get("Content-Type") != "text/plain; charset=utf-8" { 59 | t.Errorf("wrong content type, got %s want %s", w.HeaderMap.Get("Content-Type"), "text/plain; charset=utf-8") 60 | } 61 | if w.Body.Len() != 72 { 62 | t.Errorf("wrong len, got %d want %d", w.Body.Len(), 72) 63 | } 64 | if l := w.HeaderMap.Get("Content-Length"); l != "" { 65 | t.Errorf("wrong content-length. got %q expected %q", l, "") 66 | } 67 | } 68 | 69 | func TestCompressHandlerDeflate(t *testing.T) { 70 | w := httptest.NewRecorder() 71 | compressedRequest(w, "deflate") 72 | if w.HeaderMap.Get("Content-Encoding") != "deflate" { 73 | t.Fatalf("wrong content encoding, got %q want %q", w.HeaderMap.Get("Content-Encoding"), "deflate") 74 | } 75 | if w.HeaderMap.Get("Content-Type") != "text/plain; charset=utf-8" { 76 | t.Fatalf("wrong content type, got %s want %s", w.HeaderMap.Get("Content-Type"), "text/plain; charset=utf-8") 77 | } 78 | if w.Body.Len() != 54 { 79 | t.Fatalf("wrong len, got %d want %d", w.Body.Len(), 54) 80 | } 81 | } 82 | 83 | func TestCompressHandlerGzipDeflate(t *testing.T) { 84 | w := httptest.NewRecorder() 85 | compressedRequest(w, "gzip, deflate ") 86 | if w.HeaderMap.Get("Content-Encoding") != "gzip" { 87 | t.Fatalf("wrong content encoding, got %q want %q", w.HeaderMap.Get("Content-Encoding"), "gzip") 88 | } 89 | if w.HeaderMap.Get("Content-Type") != "text/plain; charset=utf-8" { 90 | t.Fatalf("wrong content type, got %s want %s", w.HeaderMap.Get("Content-Type"), "text/plain; charset=utf-8") 91 | } 92 | } 93 | 94 | type fullyFeaturedResponseWriter struct{} 95 | 96 | // Header/Write/WriteHeader implement the http.ResponseWriter interface. 
97 | func (fullyFeaturedResponseWriter) Header() http.Header { 98 | return http.Header{} 99 | } 100 | func (fullyFeaturedResponseWriter) Write([]byte) (int, error) { 101 | return 0, nil 102 | } 103 | func (fullyFeaturedResponseWriter) WriteHeader(int) {} 104 | 105 | // Flush implements the http.Flusher interface. 106 | func (fullyFeaturedResponseWriter) Flush() {} 107 | 108 | // Hijack implements the http.Hijacker interface. 109 | func (fullyFeaturedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { 110 | return nil, nil, nil 111 | } 112 | 113 | // CloseNotify implements the http.CloseNotifier interface. 114 | func (fullyFeaturedResponseWriter) CloseNotify() <-chan bool { 115 | return nil 116 | } 117 | 118 | func TestCompressHandlerPreserveInterfaces(t *testing.T) { 119 | // Compile time validation fullyFeaturedResponseWriter implements all the 120 | // interfaces we're asserting in the test case below. 121 | var ( 122 | _ http.Flusher = fullyFeaturedResponseWriter{} 123 | _ http.CloseNotifier = fullyFeaturedResponseWriter{} 124 | _ http.Hijacker = fullyFeaturedResponseWriter{} 125 | ) 126 | var h http.Handler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { 127 | comp := r.Header.Get("Accept-Encoding") 128 | if _, ok := rw.(*compressResponseWriter); !ok { 129 | t.Fatalf("ResponseWriter wasn't wrapped by compressResponseWriter, got %T type", rw) 130 | } 131 | if _, ok := rw.(http.Flusher); !ok { 132 | t.Errorf("ResponseWriter lost http.Flusher interface for %q", comp) 133 | } 134 | if _, ok := rw.(http.CloseNotifier); !ok { 135 | t.Errorf("ResponseWriter lost http.CloseNotifier interface for %q", comp) 136 | } 137 | if _, ok := rw.(http.Hijacker); !ok { 138 | t.Errorf("ResponseWriter lost http.Hijacker interface for %q", comp) 139 | } 140 | }) 141 | h = CompressHandler(h) 142 | var ( 143 | rw fullyFeaturedResponseWriter 144 | ) 145 | r, err := http.NewRequest("GET", "/", nil) 146 | if err != nil { 147 | t.Fatalf("Failed to create test 
request: %v", err) 148 | } 149 | r.Header.Set("Accept-Encoding", "gzip") 150 | h.ServeHTTP(rw, r) 151 | 152 | r.Header.Set("Accept-Encoding", "deflate") 153 | h.ServeHTTP(rw, r) 154 | } 155 | -------------------------------------------------------------------------------- /cors.go: -------------------------------------------------------------------------------- 1 | package handlers 2 | 3 | import ( 4 | "net/http" 5 | "strconv" 6 | "strings" 7 | ) 8 | 9 | // CORSOption represents a functional option for configuring the CORS middleware. 10 | type CORSOption func(*cors) error 11 | 12 | type cors struct { 13 | h http.Handler 14 | allowedHeaders []string 15 | allowedMethods []string 16 | allowedOrigins []string 17 | allowedOriginValidator OriginValidator 18 | exposedHeaders []string 19 | maxAge int 20 | ignoreOptions bool 21 | allowCredentials bool 22 | } 23 | 24 | // OriginValidator takes an origin string and returns whether or not that origin is allowed. 25 | type OriginValidator func(string) bool 26 | 27 | var ( 28 | defaultCorsMethods = []string{"GET", "HEAD", "POST"} 29 | defaultCorsHeaders = []string{"Accept", "Accept-Language", "Content-Language", "Origin"} 30 | // (WebKit/Safari v9 sends the Origin header by default in AJAX requests) 31 | ) 32 | 33 | const ( 34 | corsOptionMethod string = "OPTIONS" 35 | corsAllowOriginHeader string = "Access-Control-Allow-Origin" 36 | corsExposeHeadersHeader string = "Access-Control-Expose-Headers" 37 | corsMaxAgeHeader string = "Access-Control-Max-Age" 38 | corsAllowMethodsHeader string = "Access-Control-Allow-Methods" 39 | corsAllowHeadersHeader string = "Access-Control-Allow-Headers" 40 | corsAllowCredentialsHeader string = "Access-Control-Allow-Credentials" 41 | corsRequestMethodHeader string = "Access-Control-Request-Method" 42 | corsRequestHeadersHeader string = "Access-Control-Request-Headers" 43 | corsOriginHeader string = "Origin" 44 | corsVaryHeader string = "Vary" 45 | corsOriginMatchAll string = "*" 46 | ) 47 | 48 | 
func (ch *cors) ServeHTTP(w http.ResponseWriter, r *http.Request) { 49 | origin := r.Header.Get(corsOriginHeader) 50 | if !ch.isOriginAllowed(origin) { 51 | ch.h.ServeHTTP(w, r) 52 | return 53 | } 54 | 55 | if r.Method == corsOptionMethod { 56 | if ch.ignoreOptions { 57 | ch.h.ServeHTTP(w, r) 58 | return 59 | } 60 | 61 | if _, ok := r.Header[corsRequestMethodHeader]; !ok { 62 | w.WriteHeader(http.StatusBadRequest) 63 | return 64 | } 65 | 66 | method := r.Header.Get(corsRequestMethodHeader) 67 | if !ch.isMatch(method, ch.allowedMethods) { 68 | w.WriteHeader(http.StatusMethodNotAllowed) 69 | return 70 | } 71 | 72 | requestHeaders := strings.Split(r.Header.Get(corsRequestHeadersHeader), ",") 73 | allowedHeaders := []string{} 74 | for _, v := range requestHeaders { 75 | canonicalHeader := http.CanonicalHeaderKey(strings.TrimSpace(v)) 76 | if canonicalHeader == "" || ch.isMatch(canonicalHeader, defaultCorsHeaders) { 77 | continue 78 | } 79 | 80 | if !ch.isMatch(canonicalHeader, ch.allowedHeaders) { 81 | w.WriteHeader(http.StatusForbidden) 82 | return 83 | } 84 | 85 | allowedHeaders = append(allowedHeaders, canonicalHeader) 86 | } 87 | 88 | if len(allowedHeaders) > 0 { 89 | w.Header().Set(corsAllowHeadersHeader, strings.Join(allowedHeaders, ",")) 90 | } 91 | 92 | if ch.maxAge > 0 { 93 | w.Header().Set(corsMaxAgeHeader, strconv.Itoa(ch.maxAge)) 94 | } 95 | 96 | if !ch.isMatch(method, defaultCorsMethods) { 97 | w.Header().Set(corsAllowMethodsHeader, method) 98 | } 99 | } else { 100 | if len(ch.exposedHeaders) > 0 { 101 | w.Header().Set(corsExposeHeadersHeader, strings.Join(ch.exposedHeaders, ",")) 102 | } 103 | } 104 | 105 | if ch.allowCredentials { 106 | w.Header().Set(corsAllowCredentialsHeader, "true") 107 | } 108 | 109 | if len(ch.allowedOrigins) > 1 { 110 | w.Header().Set(corsVaryHeader, corsOriginHeader) 111 | } 112 | 113 | w.Header().Set(corsAllowOriginHeader, origin) 114 | 115 | if r.Method == corsOptionMethod { 116 | return 117 | } 118 | ch.h.ServeHTTP(w, r) 119 | 
} 120 | 121 | // CORS provides Cross-Origin Resource Sharing middleware. 122 | // Example: 123 | // 124 | // import ( 125 | // "net/http" 126 | // 127 | // "github.com/gorilla/handlers" 128 | // "github.com/gorilla/mux" 129 | // ) 130 | // 131 | // func main() { 132 | // r := mux.NewRouter() 133 | // r.HandleFunc("/users", UserEndpoint) 134 | // r.HandleFunc("/projects", ProjectEndpoint) 135 | // 136 | // // Apply the CORS middleware to our top-level router, with the defaults. 137 | // http.ListenAndServe(":8000", handlers.CORS()(r)) 138 | // } 139 | // 140 | func CORS(opts ...CORSOption) func(http.Handler) http.Handler { 141 | return func(h http.Handler) http.Handler { 142 | ch := parseCORSOptions(opts...) 143 | ch.h = h 144 | return ch 145 | } 146 | } 147 | 148 | func parseCORSOptions(opts ...CORSOption) *cors { 149 | ch := &cors{ 150 | allowedMethods: defaultCorsMethods, 151 | allowedHeaders: defaultCorsHeaders, 152 | allowedOrigins: []string{corsOriginMatchAll}, 153 | } 154 | 155 | for _, option := range opts { 156 | option(ch) 157 | } 158 | 159 | return ch 160 | } 161 | 162 | // 163 | // Functional options for configuring CORS. 164 | // 165 | 166 | // AllowedHeaders adds the provided headers to the list of allowed headers in a 167 | // CORS request. 168 | // This is an append operation so the headers Accept, Accept-Language, 169 | // and Content-Language are always allowed. 170 | // Content-Type must be explicitly declared if accepting Content-Types other than 171 | // application/x-www-form-urlencoded, multipart/form-data, or text/plain. 
172 | func AllowedHeaders(headers []string) CORSOption { 173 | return func(ch *cors) error { 174 | for _, v := range headers { 175 | normalizedHeader := http.CanonicalHeaderKey(strings.TrimSpace(v)) 176 | if normalizedHeader == "" { 177 | continue 178 | } 179 | 180 | if !ch.isMatch(normalizedHeader, ch.allowedHeaders) { 181 | ch.allowedHeaders = append(ch.allowedHeaders, normalizedHeader) 182 | } 183 | } 184 | 185 | return nil 186 | } 187 | } 188 | 189 | // AllowedMethods can be used to explicitly allow methods in the 190 | // Access-Control-Allow-Methods header. 191 | // This is a replacement operation so you must also 192 | // pass GET, HEAD, and POST if you wish to support those methods. 193 | func AllowedMethods(methods []string) CORSOption { 194 | return func(ch *cors) error { 195 | ch.allowedMethods = []string{} 196 | for _, v := range methods { 197 | normalizedMethod := strings.ToUpper(strings.TrimSpace(v)) 198 | if normalizedMethod == "" { 199 | continue 200 | } 201 | 202 | if !ch.isMatch(normalizedMethod, ch.allowedMethods) { 203 | ch.allowedMethods = append(ch.allowedMethods, normalizedMethod) 204 | } 205 | } 206 | 207 | return nil 208 | } 209 | } 210 | 211 | // AllowedOrigins sets the allowed origins for CORS requests, as used in the 212 | // 'Allow-Access-Control-Origin' HTTP header. 213 | // Note: Passing in a []string{"*"} will allow any domain. 214 | func AllowedOrigins(origins []string) CORSOption { 215 | return func(ch *cors) error { 216 | for _, v := range origins { 217 | if v == corsOriginMatchAll { 218 | ch.allowedOrigins = []string{corsOriginMatchAll} 219 | return nil 220 | } 221 | } 222 | 223 | ch.allowedOrigins = origins 224 | return nil 225 | } 226 | } 227 | 228 | // AllowedOriginValidator sets a function for evaluating allowed origins in CORS requests, represented by the 229 | // 'Allow-Access-Control-Origin' HTTP header. 
230 | func AllowedOriginValidator(fn OriginValidator) CORSOption { 231 | return func(ch *cors) error { 232 | ch.allowedOriginValidator = fn 233 | return nil 234 | } 235 | } 236 | 237 | // ExposeHeaders can be used to specify headers that are available 238 | // and will not be stripped out by the user-agent. 239 | func ExposedHeaders(headers []string) CORSOption { 240 | return func(ch *cors) error { 241 | ch.exposedHeaders = []string{} 242 | for _, v := range headers { 243 | normalizedHeader := http.CanonicalHeaderKey(strings.TrimSpace(v)) 244 | if normalizedHeader == "" { 245 | continue 246 | } 247 | 248 | if !ch.isMatch(normalizedHeader, ch.exposedHeaders) { 249 | ch.exposedHeaders = append(ch.exposedHeaders, normalizedHeader) 250 | } 251 | } 252 | 253 | return nil 254 | } 255 | } 256 | 257 | // MaxAge determines the maximum age (in seconds) between preflight requests. A 258 | // maximum of 10 minutes is allowed. An age above this value will default to 10 259 | // minutes. 260 | func MaxAge(age int) CORSOption { 261 | return func(ch *cors) error { 262 | // Maximum of 10 minutes. 263 | if age > 600 { 264 | age = 600 265 | } 266 | 267 | ch.maxAge = age 268 | return nil 269 | } 270 | } 271 | 272 | // IgnoreOptions causes the CORS middleware to ignore OPTIONS requests, instead 273 | // passing them through to the next handler. This is useful when your application 274 | // or framework has a pre-existing mechanism for responding to OPTIONS requests. 275 | func IgnoreOptions() CORSOption { 276 | return func(ch *cors) error { 277 | ch.ignoreOptions = true 278 | return nil 279 | } 280 | } 281 | 282 | // AllowCredentials can be used to specify that the user agent may pass 283 | // authentication details along with the request. 
284 | func AllowCredentials() CORSOption { 285 | return func(ch *cors) error { 286 | ch.allowCredentials = true 287 | return nil 288 | } 289 | } 290 | 291 | func (ch *cors) isOriginAllowed(origin string) bool { 292 | if origin == "" { 293 | return false 294 | } 295 | 296 | if ch.allowedOriginValidator != nil { 297 | return ch.allowedOriginValidator(origin) 298 | } 299 | 300 | for _, allowedOrigin := range ch.allowedOrigins { 301 | if allowedOrigin == origin || allowedOrigin == corsOriginMatchAll { 302 | return true 303 | } 304 | } 305 | 306 | return false 307 | } 308 | 309 | func (ch *cors) isMatch(needle string, haystack []string) bool { 310 | for _, v := range haystack { 311 | if v == needle { 312 | return true 313 | } 314 | } 315 | 316 | return false 317 | } 318 | -------------------------------------------------------------------------------- /cors_test.go: -------------------------------------------------------------------------------- 1 | package handlers 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "strings" 7 | "testing" 8 | ) 9 | 10 | func TestDefaultCORSHandlerReturnsOk(t *testing.T) { 11 | r := newRequest("GET", "http://www.example.com/") 12 | rr := httptest.NewRecorder() 13 | 14 | testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 15 | 16 | CORS()(testHandler).ServeHTTP(rr, r) 17 | 18 | if status := rr.Code; status != http.StatusOK { 19 | t.Fatalf("bad status: got %v want %v", status, http.StatusFound) 20 | } 21 | } 22 | 23 | func TestDefaultCORSHandlerReturnsOkWithOrigin(t *testing.T) { 24 | r := newRequest("GET", "http://www.example.com/") 25 | r.Header.Set("Origin", r.URL.String()) 26 | 27 | rr := httptest.NewRecorder() 28 | 29 | testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 30 | 31 | CORS()(testHandler).ServeHTTP(rr, r) 32 | 33 | if status := rr.Code; status != http.StatusOK { 34 | t.Fatalf("bad status: got %v want %v", status, http.StatusFound) 35 | } 36 | } 37 | 38 | 
func TestCORSHandlerIgnoreOptionsFallsThrough(t *testing.T) { 39 | r := newRequest("OPTIONS", "http://www.example.com/") 40 | r.Header.Set("Origin", r.URL.String()) 41 | 42 | rr := httptest.NewRecorder() 43 | 44 | testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 45 | w.WriteHeader(http.StatusTeapot) 46 | }) 47 | 48 | CORS(IgnoreOptions())(testHandler).ServeHTTP(rr, r) 49 | 50 | if status := rr.Code; status != http.StatusTeapot { 51 | t.Fatalf("bad status: got %v want %v", status, http.StatusTeapot) 52 | } 53 | } 54 | 55 | func TestCORSHandlerSetsExposedHeaders(t *testing.T) { 56 | // Test default configuration. 57 | r := newRequest("GET", "http://www.example.com/") 58 | r.Header.Set("Origin", r.URL.String()) 59 | 60 | rr := httptest.NewRecorder() 61 | 62 | testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 63 | 64 | CORS(ExposedHeaders([]string{"X-CORS-TEST"}))(testHandler).ServeHTTP(rr, r) 65 | 66 | if status := rr.Code; status != http.StatusOK { 67 | t.Fatalf("bad status: got %v want %v", status, http.StatusOK) 68 | } 69 | 70 | header := rr.HeaderMap.Get(corsExposeHeadersHeader) 71 | if header != "X-Cors-Test" { 72 | t.Fatal("bad header: expected X-Cors-Test header, got empty header for method.") 73 | } 74 | } 75 | 76 | func TestCORSHandlerUnsetRequestMethodForPreflightBadRequest(t *testing.T) { 77 | r := newRequest("OPTIONS", "http://www.example.com/") 78 | r.Header.Set("Origin", r.URL.String()) 79 | 80 | rr := httptest.NewRecorder() 81 | 82 | testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 83 | 84 | CORS(AllowedMethods([]string{"DELETE"}))(testHandler).ServeHTTP(rr, r) 85 | 86 | if status := rr.Code; status != http.StatusBadRequest { 87 | t.Fatalf("bad status: got %v want %v", status, http.StatusBadRequest) 88 | } 89 | } 90 | 91 | func TestCORSHandlerInvalidRequestMethodForPreflightMethodNotAllowed(t *testing.T) { 92 | r := newRequest("OPTIONS", 
"http://www.example.com/") 93 | r.Header.Set("Origin", r.URL.String()) 94 | r.Header.Set(corsRequestMethodHeader, "DELETE") 95 | 96 | rr := httptest.NewRecorder() 97 | 98 | testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 99 | 100 | CORS()(testHandler).ServeHTTP(rr, r) 101 | 102 | if status := rr.Code; status != http.StatusMethodNotAllowed { 103 | t.Fatalf("bad status: got %v want %v", status, http.StatusMethodNotAllowed) 104 | } 105 | } 106 | 107 | func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandler(t *testing.T) { 108 | r := newRequest("OPTIONS", "http://www.example.com/") 109 | r.Header.Set("Origin", r.URL.String()) 110 | r.Header.Set(corsRequestMethodHeader, "GET") 111 | 112 | rr := httptest.NewRecorder() 113 | 114 | testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 115 | t.Fatal("Options request must not be passed to next handler") 116 | }) 117 | 118 | CORS()(testHandler).ServeHTTP(rr, r) 119 | 120 | if status := rr.Code; status != http.StatusOK { 121 | t.Fatalf("bad status: got %v want %v", status, http.StatusOK) 122 | } 123 | } 124 | 125 | func TestCORSHandlerAllowedMethodForPreflight(t *testing.T) { 126 | r := newRequest("OPTIONS", "http://www.example.com/") 127 | r.Header.Set("Origin", r.URL.String()) 128 | r.Header.Set(corsRequestMethodHeader, "DELETE") 129 | 130 | rr := httptest.NewRecorder() 131 | 132 | testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 133 | 134 | CORS(AllowedMethods([]string{"DELETE"}))(testHandler).ServeHTTP(rr, r) 135 | 136 | if status := rr.Code; status != http.StatusOK { 137 | t.Fatalf("bad status: got %v want %v", status, http.StatusOK) 138 | } 139 | 140 | header := rr.HeaderMap.Get(corsAllowMethodsHeader) 141 | if header != "DELETE" { 142 | t.Fatalf("bad header: expected DELETE method header, got empty header.") 143 | } 144 | } 145 | 146 | func TestCORSHandlerAllowMethodsNotSetForSimpleRequestPreflight(t *testing.T) { 147 | for 
_, method := range defaultCorsMethods { 148 | r := newRequest("OPTIONS", "http://www.example.com/") 149 | r.Header.Set("Origin", r.URL.String()) 150 | r.Header.Set(corsRequestMethodHeader, method) 151 | 152 | rr := httptest.NewRecorder() 153 | 154 | testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 155 | 156 | CORS()(testHandler).ServeHTTP(rr, r) 157 | 158 | if status := rr.Code; status != http.StatusOK { 159 | t.Fatalf("bad status: got %v want %v", status, http.StatusOK) 160 | } 161 | 162 | header := rr.HeaderMap.Get(corsAllowMethodsHeader) 163 | if header != "" { 164 | t.Fatalf("bad header: expected empty method header, got %s.", header) 165 | } 166 | } 167 | } 168 | 169 | func TestCORSHandlerAllowedHeaderNotSetForSimpleRequestPreflight(t *testing.T) { 170 | for _, simpleHeader := range defaultCorsHeaders { 171 | r := newRequest("OPTIONS", "http://www.example.com/") 172 | r.Header.Set("Origin", r.URL.String()) 173 | r.Header.Set(corsRequestMethodHeader, "GET") 174 | r.Header.Set(corsRequestHeadersHeader, simpleHeader) 175 | 176 | rr := httptest.NewRecorder() 177 | 178 | testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 179 | 180 | CORS()(testHandler).ServeHTTP(rr, r) 181 | 182 | if status := rr.Code; status != http.StatusOK { 183 | t.Fatalf("bad status: got %v want %v", status, http.StatusOK) 184 | } 185 | 186 | header := rr.HeaderMap.Get(corsAllowHeadersHeader) 187 | if header != "" { 188 | t.Fatalf("bad header: expected empty header, got %s.", header) 189 | } 190 | } 191 | } 192 | 193 | func TestCORSHandlerAllowedHeaderForPreflight(t *testing.T) { 194 | r := newRequest("OPTIONS", "http://www.example.com/") 195 | r.Header.Set("Origin", r.URL.String()) 196 | r.Header.Set(corsRequestMethodHeader, "POST") 197 | r.Header.Set(corsRequestHeadersHeader, "Content-Type") 198 | 199 | rr := httptest.NewRecorder() 200 | 201 | testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 202 | 
203 | CORS(AllowedHeaders([]string{"Content-Type"}))(testHandler).ServeHTTP(rr, r) 204 | 205 | if status := rr.Code; status != http.StatusOK { 206 | t.Fatalf("bad status: got %v want %v", status, http.StatusOK) 207 | } 208 | 209 | header := rr.HeaderMap.Get(corsAllowHeadersHeader) 210 | if header != "Content-Type" { 211 | t.Fatalf("bad header: expected Content-Type header, got empty header.") 212 | } 213 | } 214 | 215 | func TestCORSHandlerInvalidHeaderForPreflightForbidden(t *testing.T) { 216 | r := newRequest("OPTIONS", "http://www.example.com/") 217 | r.Header.Set("Origin", r.URL.String()) 218 | r.Header.Set(corsRequestMethodHeader, "POST") 219 | r.Header.Set(corsRequestHeadersHeader, "Content-Type") 220 | 221 | rr := httptest.NewRecorder() 222 | 223 | testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 224 | 225 | CORS()(testHandler).ServeHTTP(rr, r) 226 | 227 | if status := rr.Code; status != http.StatusForbidden { 228 | t.Fatalf("bad status: got %v want %v", status, http.StatusForbidden) 229 | } 230 | } 231 | 232 | func TestCORSHandlerMaxAgeForPreflight(t *testing.T) { 233 | r := newRequest("OPTIONS", "http://www.example.com/") 234 | r.Header.Set("Origin", r.URL.String()) 235 | r.Header.Set(corsRequestMethodHeader, "POST") 236 | 237 | rr := httptest.NewRecorder() 238 | 239 | testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 240 | 241 | CORS(MaxAge(3500))(testHandler).ServeHTTP(rr, r) 242 | 243 | if status := rr.Code; status != http.StatusOK { 244 | t.Fatalf("bad status: got %v want %v", status, http.StatusOK) 245 | } 246 | 247 | header := rr.HeaderMap.Get(corsMaxAgeHeader) 248 | if header != "600" { 249 | t.Fatalf("bad header: expected %s to be %s, got %s.", corsMaxAgeHeader, "600", header) 250 | } 251 | } 252 | 253 | func TestCORSHandlerAllowedCredentials(t *testing.T) { 254 | r := newRequest("GET", "http://www.example.com/") 255 | r.Header.Set("Origin", r.URL.String()) 256 | 257 | rr := 
httptest.NewRecorder() 258 | 259 | testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 260 | 261 | CORS(AllowCredentials())(testHandler).ServeHTTP(rr, r) 262 | 263 | if status := rr.Code; status != http.StatusOK { 264 | t.Fatalf("bad status: got %v want %v", status, http.StatusOK) 265 | } 266 | 267 | header := rr.HeaderMap.Get(corsAllowCredentialsHeader) 268 | if header != "true" { 269 | t.Fatalf("bad header: expected %s to be %s, got %s.", corsAllowCredentialsHeader, "true", header) 270 | } 271 | } 272 | 273 | func TestCORSHandlerMultipleAllowOriginsSetsVaryHeader(t *testing.T) { 274 | r := newRequest("GET", "http://www.example.com/") 275 | r.Header.Set("Origin", r.URL.String()) 276 | 277 | rr := httptest.NewRecorder() 278 | 279 | testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 280 | 281 | CORS(AllowedOrigins([]string{r.URL.String(), "http://google.com"}))(testHandler).ServeHTTP(rr, r) 282 | 283 | if status := rr.Code; status != http.StatusOK { 284 | t.Fatalf("bad status: got %v want %v", status, http.StatusOK) 285 | } 286 | 287 | header := rr.HeaderMap.Get(corsVaryHeader) 288 | if header != corsOriginHeader { 289 | t.Fatalf("bad header: expected %s to be %s, got %s.", corsVaryHeader, corsOriginHeader, header) 290 | } 291 | } 292 | 293 | func TestCORSWithMultipleHandlers(t *testing.T) { 294 | var lastHandledBy string 295 | corsMiddleware := CORS() 296 | 297 | testHandler1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 298 | lastHandledBy = "testHandler1" 299 | }) 300 | testHandler2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 301 | lastHandledBy = "testHandler2" 302 | }) 303 | 304 | r1 := newRequest("GET", "http://www.example.com/") 305 | rr1 := httptest.NewRecorder() 306 | handler1 := corsMiddleware(testHandler1) 307 | 308 | corsMiddleware(testHandler2) 309 | 310 | handler1.ServeHTTP(rr1, r1) 311 | if lastHandledBy != "testHandler1" { 312 | t.Fatalf("bad CORS() 
registration: Handler served should be Handler registered") 313 | } 314 | } 315 | 316 | func TestCORSHandlerWithCustomValidator(t *testing.T) { 317 | r := newRequest("GET", "http://a.example.com") 318 | r.Header.Set("Origin", r.URL.String()) 319 | rr := httptest.NewRecorder() 320 | 321 | testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 322 | 323 | originValidator := func(origin string) bool { 324 | if strings.HasSuffix(origin, ".example.com") { 325 | return true 326 | } 327 | return false 328 | } 329 | 330 | CORS(AllowedOriginValidator(originValidator))(testHandler).ServeHTTP(rr, r) 331 | header := rr.HeaderMap.Get(corsAllowOriginHeader) 332 | if header != r.URL.String() { 333 | t.Fatalf("bad header: expected %s to be %s, got %s.", corsAllowOriginHeader, r.URL.String(), header) 334 | } 335 | 336 | } 337 | -------------------------------------------------------------------------------- /handlers_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2013 The Gorilla Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package handlers 6 | 7 | import ( 8 | "bytes" 9 | "net" 10 | "net/http" 11 | "net/http/httptest" 12 | "net/url" 13 | "strings" 14 | "testing" 15 | "time" 16 | ) 17 | 18 | const ( 19 | ok = "ok\n" 20 | notAllowed = "Method not allowed\n" 21 | ) 22 | 23 | var okHandler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 24 | w.Write([]byte(ok)) 25 | }) 26 | 27 | func newRequest(method, url string) *http.Request { 28 | req, err := http.NewRequest(method, url, nil) 29 | if err != nil { 30 | panic(err) 31 | } 32 | return req 33 | } 34 | 35 | func TestMethodHandler(t *testing.T) { 36 | tests := []struct { 37 | req *http.Request 38 | handler http.Handler 39 | code int 40 | allow string // Contents of the Allow header 41 | body string 42 | }{ 43 | // No handlers 44 | {newRequest("GET", "/foo"), MethodHandler{}, http.StatusMethodNotAllowed, "", notAllowed}, 45 | {newRequest("OPTIONS", "/foo"), MethodHandler{}, http.StatusOK, "", ""}, 46 | 47 | // A single handler 48 | {newRequest("GET", "/foo"), MethodHandler{"GET": okHandler}, http.StatusOK, "", ok}, 49 | {newRequest("POST", "/foo"), MethodHandler{"GET": okHandler}, http.StatusMethodNotAllowed, "GET", notAllowed}, 50 | 51 | // Multiple handlers 52 | {newRequest("GET", "/foo"), MethodHandler{"GET": okHandler, "POST": okHandler}, http.StatusOK, "", ok}, 53 | {newRequest("POST", "/foo"), MethodHandler{"GET": okHandler, "POST": okHandler}, http.StatusOK, "", ok}, 54 | {newRequest("DELETE", "/foo"), MethodHandler{"GET": okHandler, "POST": okHandler}, http.StatusMethodNotAllowed, "GET, POST", notAllowed}, 55 | {newRequest("OPTIONS", "/foo"), MethodHandler{"GET": okHandler, "POST": okHandler}, http.StatusOK, "GET, POST", ""}, 56 | 57 | // Override OPTIONS 58 | {newRequest("OPTIONS", "/foo"), MethodHandler{"OPTIONS": okHandler}, http.StatusOK, "", ok}, 59 | } 60 | 61 | for i, test := range tests { 62 | rec := httptest.NewRecorder() 63 | test.handler.ServeHTTP(rec, test.req) 64 | if rec.Code != test.code { 
65 | t.Fatalf("%d: wrong code, got %d want %d", i, rec.Code, test.code) 66 | } 67 | if allow := rec.HeaderMap.Get("Allow"); allow != test.allow { 68 | t.Fatalf("%d: wrong Allow, got %s want %s", i, allow, test.allow) 69 | } 70 | if body := rec.Body.String(); body != test.body { 71 | t.Fatalf("%d: wrong body, got %q want %q", i, body, test.body) 72 | } 73 | } 74 | } 75 | 76 | func TestWriteLog(t *testing.T) { 77 | loc, err := time.LoadLocation("Europe/Warsaw") 78 | if err != nil { 79 | panic(err) 80 | } 81 | ts := time.Date(1983, 05, 26, 3, 30, 45, 0, loc) 82 | 83 | // A typical request with an OK response 84 | req := newRequest("GET", "http://example.com") 85 | req.RemoteAddr = "192.168.100.5" 86 | 87 | buf := new(bytes.Buffer) 88 | writeLog(buf, req, *req.URL, ts, http.StatusOK, 100) 89 | log := buf.String() 90 | 91 | expected := "192.168.100.5 - - [26/May/1983:03:30:45 +0200] \"GET / HTTP/1.1\" 200 100\n" 92 | if log != expected { 93 | t.Fatalf("wrong log, got %q want %q", log, expected) 94 | } 95 | 96 | // CONNECT request over http/2.0 97 | req = &http.Request{ 98 | Method: "CONNECT", 99 | Proto: "HTTP/2.0", 100 | ProtoMajor: 2, 101 | ProtoMinor: 0, 102 | URL: &url.URL{Host: "www.example.com:443"}, 103 | Host: "www.example.com:443", 104 | RemoteAddr: "192.168.100.5", 105 | } 106 | 107 | buf = new(bytes.Buffer) 108 | writeLog(buf, req, *req.URL, ts, http.StatusOK, 100) 109 | log = buf.String() 110 | 111 | expected = "192.168.100.5 - - [26/May/1983:03:30:45 +0200] \"CONNECT www.example.com:443 HTTP/2.0\" 200 100\n" 112 | if log != expected { 113 | t.Fatalf("wrong log, got %q want %q", log, expected) 114 | } 115 | 116 | // Request with an unauthorized user 117 | req = newRequest("GET", "http://example.com") 118 | req.RemoteAddr = "192.168.100.5" 119 | req.URL.User = url.User("kamil") 120 | 121 | buf.Reset() 122 | writeLog(buf, req, *req.URL, ts, http.StatusUnauthorized, 500) 123 | log = buf.String() 124 | 125 | expected = "192.168.100.5 - kamil 
[26/May/1983:03:30:45 +0200] \"GET / HTTP/1.1\" 401 500\n" 126 | if log != expected { 127 | t.Fatalf("wrong log, got %q want %q", log, expected) 128 | } 129 | 130 | // Request with url encoded parameters 131 | req = newRequest("GET", "http://example.com/test?abc=hello%20world&a=b%3F") 132 | req.RemoteAddr = "192.168.100.5" 133 | 134 | buf.Reset() 135 | writeLog(buf, req, *req.URL, ts, http.StatusOK, 100) 136 | log = buf.String() 137 | 138 | expected = "192.168.100.5 - - [26/May/1983:03:30:45 +0200] \"GET /test?abc=hello%20world&a=b%3F HTTP/1.1\" 200 100\n" 139 | if log != expected { 140 | t.Fatalf("wrong log, got %q want %q", log, expected) 141 | } 142 | } 143 | 144 | func TestWriteCombinedLog(t *testing.T) { 145 | loc, err := time.LoadLocation("Europe/Warsaw") 146 | if err != nil { 147 | panic(err) 148 | } 149 | ts := time.Date(1983, 05, 26, 3, 30, 45, 0, loc) 150 | 151 | // A typical request with an OK response 152 | req := newRequest("GET", "http://example.com") 153 | req.RemoteAddr = "192.168.100.5" 154 | req.Header.Set("Referer", "http://example.com") 155 | req.Header.Set( 156 | "User-Agent", 157 | "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.33 "+ 158 | "(KHTML, like Gecko) Chrome/27.0.1430.0 Safari/537.33", 159 | ) 160 | 161 | buf := new(bytes.Buffer) 162 | writeCombinedLog(buf, req, *req.URL, ts, http.StatusOK, 100) 163 | log := buf.String() 164 | 165 | expected := "192.168.100.5 - - [26/May/1983:03:30:45 +0200] \"GET / HTTP/1.1\" 200 100 \"http://example.com\" " + 166 | "\"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) " + 167 | "AppleWebKit/537.33 (KHTML, like Gecko) Chrome/27.0.1430.0 Safari/537.33\"\n" 168 | if log != expected { 169 | t.Fatalf("wrong log, got %q want %q", log, expected) 170 | } 171 | 172 | // CONNECT request over http/2.0 173 | req1 := &http.Request{ 174 | Method: "CONNECT", 175 | Host: "www.example.com:443", 176 | Proto: "HTTP/2.0", 177 | ProtoMajor: 2, 178 | ProtoMinor: 0, 179 | RemoteAddr: "192.168.100.5", 180 | 
Header: http.Header{}, 181 | URL: &url.URL{Host: "www.example.com:443"}, 182 | } 183 | req1.Header.Set("Referer", "http://example.com") 184 | req1.Header.Set( 185 | "User-Agent", 186 | "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.33 "+ 187 | "(KHTML, like Gecko) Chrome/27.0.1430.0 Safari/537.33", 188 | ) 189 | 190 | buf = new(bytes.Buffer) 191 | writeCombinedLog(buf, req1, *req1.URL, ts, http.StatusOK, 100) 192 | log = buf.String() 193 | 194 | expected = "192.168.100.5 - - [26/May/1983:03:30:45 +0200] \"CONNECT www.example.com:443 HTTP/2.0\" 200 100 \"http://example.com\" " + 195 | "\"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) " + 196 | "AppleWebKit/537.33 (KHTML, like Gecko) Chrome/27.0.1430.0 Safari/537.33\"\n" 197 | if log != expected { 198 | t.Fatalf("wrong log, got %q want %q", log, expected) 199 | } 200 | 201 | // Request with an unauthorized user 202 | req.URL.User = url.User("kamil") 203 | 204 | buf.Reset() 205 | writeCombinedLog(buf, req, *req.URL, ts, http.StatusUnauthorized, 500) 206 | log = buf.String() 207 | 208 | expected = "192.168.100.5 - kamil [26/May/1983:03:30:45 +0200] \"GET / HTTP/1.1\" 401 500 \"http://example.com\" " + 209 | "\"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) " + 210 | "AppleWebKit/537.33 (KHTML, like Gecko) Chrome/27.0.1430.0 Safari/537.33\"\n" 211 | if log != expected { 212 | t.Fatalf("wrong log, got %q want %q", log, expected) 213 | } 214 | 215 | // Test with remote ipv6 address 216 | req.RemoteAddr = "::1" 217 | 218 | buf.Reset() 219 | writeCombinedLog(buf, req, *req.URL, ts, http.StatusOK, 100) 220 | log = buf.String() 221 | 222 | expected = "::1 - kamil [26/May/1983:03:30:45 +0200] \"GET / HTTP/1.1\" 200 100 \"http://example.com\" " + 223 | "\"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) " + 224 | "AppleWebKit/537.33 (KHTML, like Gecko) Chrome/27.0.1430.0 Safari/537.33\"\n" 225 | if log != expected { 226 | t.Fatalf("wrong log, got %q want %q", log, expected) 227 | } 228 | 229 | // Test remote ipv6 
addr, with port 230 | req.RemoteAddr = net.JoinHostPort("::1", "65000") 231 | 232 | buf.Reset() 233 | writeCombinedLog(buf, req, *req.URL, ts, http.StatusOK, 100) 234 | log = buf.String() 235 | 236 | expected = "::1 - kamil [26/May/1983:03:30:45 +0200] \"GET / HTTP/1.1\" 200 100 \"http://example.com\" " + 237 | "\"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) " + 238 | "AppleWebKit/537.33 (KHTML, like Gecko) Chrome/27.0.1430.0 Safari/537.33\"\n" 239 | if log != expected { 240 | t.Fatalf("wrong log, got %q want %q", log, expected) 241 | } 242 | } 243 | 244 | func TestLogPathRewrites(t *testing.T) { 245 | var buf bytes.Buffer 246 | 247 | handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 248 | req.URL.Path = "/" // simulate http.StripPrefix and friends 249 | w.WriteHeader(200) 250 | }) 251 | logger := LoggingHandler(&buf, handler) 252 | 253 | logger.ServeHTTP(httptest.NewRecorder(), newRequest("GET", "/subdir/asdf")) 254 | 255 | if !strings.Contains(buf.String(), "GET /subdir/asdf HTTP") { 256 | t.Fatalf("Got log %#v, wanted substring %#v", buf.String(), "GET /subdir/asdf HTTP") 257 | } 258 | } 259 | 260 | func BenchmarkWriteLog(b *testing.B) { 261 | loc, err := time.LoadLocation("Europe/Warsaw") 262 | if err != nil { 263 | b.Fatalf(err.Error()) 264 | } 265 | ts := time.Date(1983, 05, 26, 3, 30, 45, 0, loc) 266 | 267 | req := newRequest("GET", "http://example.com") 268 | req.RemoteAddr = "192.168.100.5" 269 | 270 | b.ResetTimer() 271 | 272 | buf := &bytes.Buffer{} 273 | for i := 0; i < b.N; i++ { 274 | buf.Reset() 275 | writeLog(buf, req, *req.URL, ts, http.StatusUnauthorized, 500) 276 | } 277 | } 278 | 279 | func TestContentTypeHandler(t *testing.T) { 280 | tests := []struct { 281 | Method string 282 | AllowContentTypes []string 283 | ContentType string 284 | Code int 285 | }{ 286 | {"POST", []string{"application/json"}, "application/json", http.StatusOK}, 287 | {"POST", []string{"application/json", "application/xml"}, "application/json", 
http.StatusOK}, 288 | {"POST", []string{"application/json"}, "application/json; charset=utf-8", http.StatusOK}, 289 | {"POST", []string{"application/json"}, "application/json+xxx", http.StatusUnsupportedMediaType}, 290 | {"POST", []string{"application/json"}, "text/plain", http.StatusUnsupportedMediaType}, 291 | {"GET", []string{"application/json"}, "", http.StatusOK}, 292 | {"GET", []string{}, "", http.StatusOK}, 293 | } 294 | for _, test := range tests { 295 | r, err := http.NewRequest(test.Method, "/", nil) 296 | if err != nil { 297 | t.Error(err) 298 | continue 299 | } 300 | 301 | h := ContentTypeHandler(okHandler, test.AllowContentTypes...) 302 | r.Header.Set("Content-Type", test.ContentType) 303 | w := httptest.NewRecorder() 304 | h.ServeHTTP(w, r) 305 | if w.Code != test.Code { 306 | t.Errorf("expected %d, got %d", test.Code, w.Code) 307 | } 308 | } 309 | } 310 | 311 | func TestHTTPMethodOverride(t *testing.T) { 312 | var tests = []struct { 313 | Method string 314 | OverrideMethod string 315 | ExpectedMethod string 316 | }{ 317 | {"POST", "PUT", "PUT"}, 318 | {"POST", "PATCH", "PATCH"}, 319 | {"POST", "DELETE", "DELETE"}, 320 | {"PUT", "DELETE", "PUT"}, 321 | {"GET", "GET", "GET"}, 322 | {"HEAD", "HEAD", "HEAD"}, 323 | {"GET", "PUT", "GET"}, 324 | {"HEAD", "DELETE", "HEAD"}, 325 | } 326 | 327 | for _, test := range tests { 328 | h := HTTPMethodOverrideHandler(okHandler) 329 | reqs := make([]*http.Request, 0, 2) 330 | 331 | rHeader, err := http.NewRequest(test.Method, "/", nil) 332 | if err != nil { 333 | t.Error(err) 334 | } 335 | rHeader.Header.Set(HTTPMethodOverrideHeader, test.OverrideMethod) 336 | reqs = append(reqs, rHeader) 337 | 338 | f := url.Values{HTTPMethodOverrideFormKey: []string{test.OverrideMethod}} 339 | rForm, err := http.NewRequest(test.Method, "/", strings.NewReader(f.Encode())) 340 | if err != nil { 341 | t.Error(err) 342 | } 343 | rForm.Header.Set("Content-Type", "application/x-www-form-urlencoded") 344 | reqs = append(reqs, rForm) 345 | 
346 | for _, r := range reqs { 347 | w := httptest.NewRecorder() 348 | h.ServeHTTP(w, r) 349 | if r.Method != test.ExpectedMethod { 350 | t.Errorf("Expected %s, got %s", test.ExpectedMethod, r.Method) 351 | } 352 | } 353 | } 354 | } 355 | -------------------------------------------------------------------------------- /handlers.go: -------------------------------------------------------------------------------- 1 | // Copyright 2013 The Gorilla Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package handlers 6 | 7 | import ( 8 | "bufio" 9 | "fmt" 10 | "io" 11 | "net" 12 | "net/http" 13 | "net/url" 14 | "sort" 15 | "strconv" 16 | "strings" 17 | "time" 18 | "unicode/utf8" 19 | ) 20 | 21 | // MethodHandler is an http.Handler that dispatches to a handler whose key in the 22 | // MethodHandler's map matches the name of the HTTP request's method, eg: GET 23 | // 24 | // If the request's method is OPTIONS and OPTIONS is not a key in the map then 25 | // the handler responds with a status of 200 and sets the Allow header to a 26 | // comma-separated list of available methods. 27 | // 28 | // If the request's method doesn't match any of its keys the handler responds 29 | // with a status of HTTP 405 "Method Not Allowed" and sets the Allow header to a 30 | // comma-separated list of available methods. 
31 | type MethodHandler map[string]http.Handler 32 | 33 | func (h MethodHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { 34 | if handler, ok := h[req.Method]; ok { 35 | handler.ServeHTTP(w, req) 36 | } else { 37 | allow := []string{} 38 | for k := range h { 39 | allow = append(allow, k) 40 | } 41 | sort.Strings(allow) 42 | w.Header().Set("Allow", strings.Join(allow, ", ")) 43 | if req.Method == "OPTIONS" { 44 | w.WriteHeader(http.StatusOK) 45 | } else { 46 | http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) 47 | } 48 | } 49 | } 50 | 51 | // loggingHandler is the http.Handler implementation for LoggingHandlerTo and its 52 | // friends 53 | type loggingHandler struct { 54 | writer io.Writer 55 | handler http.Handler 56 | } 57 | 58 | // combinedLoggingHandler is the http.Handler implementation for LoggingHandlerTo 59 | // and its friends 60 | type combinedLoggingHandler struct { 61 | writer io.Writer 62 | handler http.Handler 63 | } 64 | 65 | func (h loggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { 66 | t := time.Now() 67 | logger := makeLogger(w) 68 | url := *req.URL 69 | h.handler.ServeHTTP(logger, req) 70 | writeLog(h.writer, req, url, t, logger.Status(), logger.Size()) 71 | } 72 | 73 | func (h combinedLoggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { 74 | t := time.Now() 75 | logger := makeLogger(w) 76 | url := *req.URL 77 | h.handler.ServeHTTP(logger, req) 78 | writeCombinedLog(h.writer, req, url, t, logger.Status(), logger.Size()) 79 | } 80 | 81 | func makeLogger(w http.ResponseWriter) loggingResponseWriter { 82 | var logger loggingResponseWriter = &responseLogger{w: w} 83 | if _, ok := w.(http.Hijacker); ok { 84 | logger = &hijackLogger{responseLogger{w: w}} 85 | } 86 | h, ok1 := logger.(http.Hijacker) 87 | c, ok2 := w.(http.CloseNotifier) 88 | if ok1 && ok2 { 89 | return hijackCloseNotifier{logger, h, c} 90 | } 91 | if ok2 { 92 | return &closeNotifyWriter{logger, c} 93 | } 94 | return logger 
95 | } 96 | 97 | type commonLoggingResponseWriter interface { 98 | http.ResponseWriter 99 | http.Flusher 100 | Status() int 101 | Size() int 102 | } 103 | 104 | // responseLogger is wrapper of http.ResponseWriter that keeps track of its HTTP 105 | // status code and body size 106 | type responseLogger struct { 107 | w http.ResponseWriter 108 | status int 109 | size int 110 | } 111 | 112 | func (l *responseLogger) Header() http.Header { 113 | return l.w.Header() 114 | } 115 | 116 | func (l *responseLogger) Write(b []byte) (int, error) { 117 | if l.status == 0 { 118 | // The status will be StatusOK if WriteHeader has not been called yet 119 | l.status = http.StatusOK 120 | } 121 | size, err := l.w.Write(b) 122 | l.size += size 123 | return size, err 124 | } 125 | 126 | func (l *responseLogger) WriteHeader(s int) { 127 | l.w.WriteHeader(s) 128 | l.status = s 129 | } 130 | 131 | func (l *responseLogger) Status() int { 132 | return l.status 133 | } 134 | 135 | func (l *responseLogger) Size() int { 136 | return l.size 137 | } 138 | 139 | func (l *responseLogger) Flush() { 140 | f, ok := l.w.(http.Flusher) 141 | if ok { 142 | f.Flush() 143 | } 144 | } 145 | 146 | type hijackLogger struct { 147 | responseLogger 148 | } 149 | 150 | func (l *hijackLogger) Hijack() (net.Conn, *bufio.ReadWriter, error) { 151 | h := l.responseLogger.w.(http.Hijacker) 152 | conn, rw, err := h.Hijack() 153 | if err == nil && l.responseLogger.status == 0 { 154 | // The status will be StatusSwitchingProtocols if there was no error and 155 | // WriteHeader has not been called yet 156 | l.responseLogger.status = http.StatusSwitchingProtocols 157 | } 158 | return conn, rw, err 159 | } 160 | 161 | type closeNotifyWriter struct { 162 | loggingResponseWriter 163 | http.CloseNotifier 164 | } 165 | 166 | type hijackCloseNotifier struct { 167 | loggingResponseWriter 168 | http.Hijacker 169 | http.CloseNotifier 170 | } 171 | 172 | const lowerhex = "0123456789abcdef" 173 | 174 | func appendQuoted(buf []byte, 
s string) []byte { 175 | var runeTmp [utf8.UTFMax]byte 176 | for width := 0; len(s) > 0; s = s[width:] { 177 | r := rune(s[0]) 178 | width = 1 179 | if r >= utf8.RuneSelf { 180 | r, width = utf8.DecodeRuneInString(s) 181 | } 182 | if width == 1 && r == utf8.RuneError { 183 | buf = append(buf, `\x`...) 184 | buf = append(buf, lowerhex[s[0]>>4]) 185 | buf = append(buf, lowerhex[s[0]&0xF]) 186 | continue 187 | } 188 | if r == rune('"') || r == '\\' { // always backslashed 189 | buf = append(buf, '\\') 190 | buf = append(buf, byte(r)) 191 | continue 192 | } 193 | if strconv.IsPrint(r) { 194 | n := utf8.EncodeRune(runeTmp[:], r) 195 | buf = append(buf, runeTmp[:n]...) 196 | continue 197 | } 198 | switch r { 199 | case '\a': 200 | buf = append(buf, `\a`...) 201 | case '\b': 202 | buf = append(buf, `\b`...) 203 | case '\f': 204 | buf = append(buf, `\f`...) 205 | case '\n': 206 | buf = append(buf, `\n`...) 207 | case '\r': 208 | buf = append(buf, `\r`...) 209 | case '\t': 210 | buf = append(buf, `\t`...) 211 | case '\v': 212 | buf = append(buf, `\v`...) 213 | default: 214 | switch { 215 | case r < ' ': 216 | buf = append(buf, `\x`...) 217 | buf = append(buf, lowerhex[s[0]>>4]) 218 | buf = append(buf, lowerhex[s[0]&0xF]) 219 | case r > utf8.MaxRune: 220 | r = 0xFFFD 221 | fallthrough 222 | case r < 0x10000: 223 | buf = append(buf, `\u`...) 224 | for s := 12; s >= 0; s -= 4 { 225 | buf = append(buf, lowerhex[r>>uint(s)&0xF]) 226 | } 227 | default: 228 | buf = append(buf, `\U`...) 229 | for s := 28; s >= 0; s -= 4 { 230 | buf = append(buf, lowerhex[r>>uint(s)&0xF]) 231 | } 232 | } 233 | } 234 | } 235 | return buf 236 | 237 | } 238 | 239 | // buildCommonLogLine builds a log entry for req in Apache Common Log Format. 240 | // ts is the timestamp with which the entry should be logged. 241 | // status and size are used to provide the response HTTP status and size. 
242 | func buildCommonLogLine(req *http.Request, url url.URL, ts time.Time, status int, size int) []byte { 243 | username := "-" 244 | if url.User != nil { 245 | if name := url.User.Username(); name != "" { 246 | username = name 247 | } 248 | } 249 | 250 | host, _, err := net.SplitHostPort(req.RemoteAddr) 251 | 252 | if err != nil { 253 | host = req.RemoteAddr 254 | } 255 | 256 | uri := req.RequestURI 257 | 258 | // Requests using the CONNECT method over HTTP/2.0 must use 259 | // the authority field (aka r.Host) to identify the target. 260 | // Refer: https://httpwg.github.io/specs/rfc7540.html#CONNECT 261 | if req.ProtoMajor == 2 && req.Method == "CONNECT" { 262 | uri = req.Host 263 | } 264 | if uri == "" { 265 | uri = url.RequestURI() 266 | } 267 | 268 | buf := make([]byte, 0, 3*(len(host)+len(username)+len(req.Method)+len(uri)+len(req.Proto)+50)/2) 269 | buf = append(buf, host...) 270 | buf = append(buf, " - "...) 271 | buf = append(buf, username...) 272 | buf = append(buf, " ["...) 273 | buf = append(buf, ts.Format("02/Jan/2006:15:04:05 -0700")...) 274 | buf = append(buf, `] "`...) 275 | buf = append(buf, req.Method...) 276 | buf = append(buf, " "...) 277 | buf = appendQuoted(buf, uri) 278 | buf = append(buf, " "...) 279 | buf = append(buf, req.Proto...) 280 | buf = append(buf, `" `...) 281 | buf = append(buf, strconv.Itoa(status)...) 282 | buf = append(buf, " "...) 283 | buf = append(buf, strconv.Itoa(size)...) 284 | return buf 285 | } 286 | 287 | // writeLog writes a log entry for req to w in Apache Common Log Format. 288 | // ts is the timestamp with which the entry should be logged. 289 | // status and size are used to provide the response HTTP status and size. 
290 | func writeLog(w io.Writer, req *http.Request, url url.URL, ts time.Time, status, size int) { 291 | buf := buildCommonLogLine(req, url, ts, status, size) 292 | buf = append(buf, '\n') 293 | w.Write(buf) 294 | } 295 | 296 | // writeCombinedLog writes a log entry for req to w in Apache Combined Log Format. 297 | // ts is the timestamp with which the entry should be logged. 298 | // status and size are used to provide the response HTTP status and size. 299 | func writeCombinedLog(w io.Writer, req *http.Request, url url.URL, ts time.Time, status, size int) { 300 | buf := buildCommonLogLine(req, url, ts, status, size) 301 | buf = append(buf, ` "`...) 302 | buf = appendQuoted(buf, req.Referer()) 303 | buf = append(buf, `" "`...) 304 | buf = appendQuoted(buf, req.UserAgent()) 305 | buf = append(buf, '"', '\n') 306 | w.Write(buf) 307 | } 308 | 309 | // CombinedLoggingHandler return a http.Handler that wraps h and logs requests to out in 310 | // Apache Combined Log Format. 311 | // 312 | // See http://httpd.apache.org/docs/2.2/logs.html#combined for a description of this format. 313 | // 314 | // LoggingHandler always sets the ident field of the log to - 315 | func CombinedLoggingHandler(out io.Writer, h http.Handler) http.Handler { 316 | return combinedLoggingHandler{out, h} 317 | } 318 | 319 | // LoggingHandler return a http.Handler that wraps h and logs requests to out in 320 | // Apache Common Log Format (CLF). 321 | // 322 | // See http://httpd.apache.org/docs/2.2/logs.html#common for a description of this format. 
323 | // 324 | // LoggingHandler always sets the ident field of the log to - 325 | // 326 | // Example: 327 | // 328 | // r := mux.NewRouter() 329 | // r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 330 | // w.Write([]byte("This is a catch-all route")) 331 | // }) 332 | // loggedRouter := handlers.LoggingHandler(os.Stdout, r) 333 | // http.ListenAndServe(":1123", loggedRouter) 334 | // 335 | func LoggingHandler(out io.Writer, h http.Handler) http.Handler { 336 | return loggingHandler{out, h} 337 | } 338 | 339 | // isContentType validates the Content-Type header matches the supplied 340 | // contentType. That is, its type and subtype match. 341 | func isContentType(h http.Header, contentType string) bool { 342 | ct := h.Get("Content-Type") 343 | if i := strings.IndexRune(ct, ';'); i != -1 { 344 | ct = ct[0:i] 345 | } 346 | return ct == contentType 347 | } 348 | 349 | // ContentTypeHandler wraps and returns a http.Handler, validating the request 350 | // content type is compatible with the contentTypes list. It writes a HTTP 415 351 | // error if that fails. 352 | // 353 | // Only PUT, POST, and PATCH requests are considered. 354 | func ContentTypeHandler(h http.Handler, contentTypes ...string) http.Handler { 355 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 356 | if !(r.Method == "PUT" || r.Method == "POST" || r.Method == "PATCH") { 357 | h.ServeHTTP(w, r) 358 | return 359 | } 360 | 361 | for _, ct := range contentTypes { 362 | if isContentType(r.Header, ct) { 363 | h.ServeHTTP(w, r) 364 | return 365 | } 366 | } 367 | http.Error(w, fmt.Sprintf("Unsupported content type %q; expected one of %q", r.Header.Get("Content-Type"), contentTypes), http.StatusUnsupportedMediaType) 368 | }) 369 | } 370 | 371 | const ( 372 | // HTTPMethodOverrideHeader is a commonly used 373 | // http header to override a request method. 
374 | HTTPMethodOverrideHeader = "X-HTTP-Method-Override" 375 | // HTTPMethodOverrideFormKey is a commonly used 376 | // HTML form key to override a request method. 377 | HTTPMethodOverrideFormKey = "_method" 378 | ) 379 | 380 | // HTTPMethodOverrideHandler wraps and returns a http.Handler which checks for 381 | // the X-HTTP-Method-Override header or the _method form key, and overrides (if 382 | // valid) request.Method with its value. 383 | // 384 | // This is especially useful for HTTP clients that don't support many http verbs. 385 | // It isn't secure to override e.g a GET to a POST, so only POST requests are 386 | // considered. Likewise, the override method can only be a "write" method: PUT, 387 | // PATCH or DELETE. 388 | // 389 | // Form method takes precedence over header method. 390 | func HTTPMethodOverrideHandler(h http.Handler) http.Handler { 391 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 392 | if r.Method == "POST" { 393 | om := r.FormValue(HTTPMethodOverrideFormKey) 394 | if om == "" { 395 | om = r.Header.Get(HTTPMethodOverrideHeader) 396 | } 397 | if om == "PUT" || om == "PATCH" || om == "DELETE" { 398 | r.Method = om 399 | } 400 | } 401 | h.ServeHTTP(w, r) 402 | }) 403 | } 404 | --------------------------------------------------------------------------------