├── .github └── workflows │ └── ci.yml ├── LICENSE ├── README.md ├── _example └── main.go ├── cors.go ├── cors_test.go ├── go.mod ├── utils.go └── utils_test.go /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | on: [push, pull_request] 2 | name: Test 3 | jobs: 4 | test: 5 | env: 6 | GOPATH: ${{ github.workspace }} 7 | 8 | defaults: 9 | run: 10 | working-directory: ${{ env.GOPATH }}/src/github.com/${{ github.repository }} 11 | 12 | strategy: 13 | matrix: 14 | go-version: [1.16.x, 1.19.x, 1.22.x] 15 | os: [ubuntu-latest, macos-latest, windows-latest] 16 | 17 | runs-on: ${{ matrix.os }} 18 | 19 | steps: 20 | - name: Install Go 21 | uses: actions/setup-go@v5 22 | with: 23 | go-version: ${{ matrix.go-version }} 24 | - name: Checkout code 25 | uses: actions/checkout@v4 26 | with: 27 | path: ${{ env.GOPATH }}/src/github.com/${{ github.repository }} 28 | - name: Test 29 | run: | 30 | go get -d -t ./... 31 | go test -v ./... 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2014 Olivier Poitrey 2 | Copyright (c) 2016-Present https://github.com/go-chi authors 3 | 4 | MIT License 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy of 7 | this software and associated documentation files (the "Software"), to deal in 8 | the Software without restriction, including without limitation the rights to 9 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 10 | the Software, and to permit persons to whom the Software is furnished to do so, 11 | subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 
15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 18 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 19 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 20 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 21 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CORS net/http middleware 2 | 3 | [go-chi/cors](https://github.com/go-chi/cors) is a fork of [github.com/rs/cors](https://github.com/rs/cors) that 4 | provides a `net/http` compatible middleware for performing preflight CORS checks on the server side. CORS headers 5 | are required for cross-origin requests made with the browser's native [Fetch API](https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API). 6 | 7 | This middleware is designed to be used as a top-level middleware on the [chi](https://github.com/go-chi/chi) router. 8 | Applying it within a `r.Group()` or via `With()` will not work unless routes matching `OPTIONS` are also added.
9 | 10 | ## Install 11 | 12 | `go get github.com/go-chi/cors` 13 | 14 | ## Usage 15 | 16 | ```go 17 | func main() { 18 | r := chi.NewRouter() 19 | 20 | // Basic CORS 21 | // for more ideas, see: https://developer.github.com/v3/#cross-origin-resource-sharing 22 | r.Use(cors.Handler(cors.Options{ 23 | // AllowedOrigins: []string{"https://foo.com"}, // Use this to allow specific origin hosts 24 | AllowedOrigins: []string{"https://*", "http://*"}, 25 | // AllowOriginFunc: func(r *http.Request, origin string) bool { return true }, 26 | AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, 27 | AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"}, 28 | ExposedHeaders: []string{"Link"}, 29 | AllowCredentials: false, 30 | MaxAge: 300, // Maximum value not ignored by any of major browsers 31 | })) 32 | 33 | r.Get("/", func(w http.ResponseWriter, r *http.Request) { 34 | w.Write([]byte("welcome")) 35 | }) 36 | 37 | http.ListenAndServe(":3000", r) 38 | } 39 | ``` 40 | 41 | ## Credits 42 | 43 | All credit for the original work of this middleware goes out to [github.com/rs](https://github.com/rs). 44 | -------------------------------------------------------------------------------- /_example/main.go: -------------------------------------------------------------------------------- 1 | // cors example 2 | // 3 | // ie. 
4 | // 5 | // Unsuccessful Preflight request: 6 | // =============================== 7 | // $ curl -i http://localhost:3000/ -H "Origin: http://no.com" -H "Access-Control-Request-Method: GET" -X OPTIONS 8 | // HTTP/1.1 200 OK 9 | // Vary: Origin 10 | // Vary: Access-Control-Request-Method 11 | // Vary: Access-Control-Request-Headers 12 | // Date: Fri, 28 Jul 2017 17:55:47 GMT 13 | // Content-Length: 0 14 | // Content-Type: text/plain; charset=utf-8 15 | // 16 | // 17 | // Successful Preflight request: 18 | // ============================= 19 | // $ curl -i http://localhost:3000/ -H "Origin: http://example.com" -H "Access-Control-Request-Method: GET" -X OPTIONS 20 | // HTTP/1.1 200 OK 21 | // Access-Control-Allow-Credentials: true 22 | // Access-Control-Allow-Methods: GET 23 | // Access-Control-Allow-Origin: http://example.com 24 | // Access-Control-Max-Age: 300 25 | // Vary: Origin 26 | // Vary: Access-Control-Request-Method 27 | // Vary: Access-Control-Request-Headers 28 | // Date: Fri, 28 Jul 2017 17:56:44 GMT 29 | // Content-Length: 0 30 | // Content-Type: text/plain; charset=utf-8 31 | // 32 | // 33 | // Content request (after a successful preflight): 34 | // =============================================== 35 | // $ curl -i http://localhost:3000/ -H "Origin: http://example.com" 36 | // HTTP/1.1 200 OK 37 | // Access-Control-Allow-Credentials: true 38 | // Access-Control-Allow-Origin: http://example.com 39 | // Access-Control-Expose-Headers: Link 40 | // Vary: Origin 41 | // Date: Fri, 28 Jul 2017 17:57:52 GMT 42 | // Content-Length: 7 43 | // Content-Type: text/plain; charset=utf-8 44 | // 45 | // welcome% 46 | // 47 | package main 48 | 49 | import ( 50 | "net/http" 51 | 52 | "github.com/go-chi/chi" 53 | "github.com/go-chi/chi/middleware" 54 | "github.com/go-chi/cors" 55 | ) 56 | 57 | func main() { 58 | r := chi.NewRouter() 59 | r.Use(middleware.Logger) 60 | 61 | // Basic CORS 62 | // for more ideas, see:
https://developer.github.com/v3/#cross-origin-resource-sharing 63 | r.Use(cors.Handler(cors.Options{ 64 | AllowOriginFunc: AllowOriginFunc, 65 | AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, 66 | AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"}, 67 | ExposedHeaders: []string{"Link"}, 68 | AllowCredentials: true, 69 | MaxAge: 300, // Maximum value not ignored by any of major browsers 70 | })) 71 | 72 | r.Get("/", func(w http.ResponseWriter, r *http.Request) { 73 | w.Write([]byte("welcome")) 74 | }) 75 | 76 | http.ListenAndServe(":3000", r) 77 | } 78 | // AllowOriginFunc reports whether origin is exactly "http://example.com"; every other origin is rejected. 79 | func AllowOriginFunc(r *http.Request, origin string) bool { 80 | if origin == "http://example.com" { 81 | return true 82 | } 83 | return false 84 | } 85 | -------------------------------------------------------------------------------- /cors.go: -------------------------------------------------------------------------------- 1 | // Package cors provides a net/http handler to handle CORS-related requests 2 | // as defined by the CORS protocol of the Fetch standard (https://fetch.spec.whatwg.org/#http-cors-protocol). 3 | // 4 | // You can configure it by passing an option struct to cors.New: 5 | // 6 | // c := cors.New(cors.Options{ 7 | // AllowedOrigins: []string{"https://foo.com"}, 8 | // AllowedMethods: []string{"GET", "POST", "DELETE"}, 9 | // AllowCredentials: true, 10 | // }) 11 | // 12 | // Then insert the handler in the chain: 13 | // 14 | // handler = c.Handler(handler) 15 | // 16 | // See Options documentation for more options. 17 | // 18 | // The resulting handler is a standard net/http handler. 19 | package cors 20 | 21 | import ( 22 | "log" 23 | "net/http" 24 | "os" 25 | "strconv" 26 | "strings" 27 | ) 28 | 29 | // Options is a configuration container to set up the CORS middleware. 30 | type Options struct { 31 | // AllowedOrigins is a list of origins a cross-domain request can be executed from. 32 | // If the special "*" value is present in the list, all origins will be allowed.
33 | // An origin may contain a wildcard (*) to replace 0 or more characters 34 | // (e.g. http://*.domain.com). Usage of wildcards implies a small performance penalty. 35 | // Only one wildcard can be used per origin. Origins are compared case-insensitively. 36 | // Default value is ["*"] when AllowOriginFunc is also unset. 37 | AllowedOrigins []string 38 | 39 | // AllowOriginFunc is a custom function to validate the origin. It receives the request 40 | // and its Origin header value and returns true if allowed or false otherwise. If this option is 41 | // set, AllowedOrigins is not consulted when deciding whether an origin is allowed. 42 | AllowOriginFunc func(r *http.Request, origin string) bool 43 | 44 | // AllowedMethods is a list of methods the client is allowed to use with 45 | // cross-domain requests. Default value is simple methods (HEAD, GET and POST); OPTIONS is always accepted. 46 | AllowedMethods []string 47 | 48 | // AllowedHeaders is a list of non-simple headers the client is allowed to use with 49 | // cross-domain requests. 50 | // If the special "*" value is present in the list, all headers will be allowed. 51 | // Default value is ["Origin", "Accept", "Content-Type"]; when a list is given, "Origin" is always appended to it. 52 | AllowedHeaders []string 53 | 54 | // ExposedHeaders lists the response headers that are safe to expose to client code; they are 55 | // sent in Access-Control-Expose-Headers on actual (non-preflight) responses. 56 | ExposedHeaders []string 57 | 58 | // AllowCredentials indicates whether the request can include user credentials like 59 | // cookies, HTTP authentication or client side SSL certificates. Browsers reject credentialed responses carrying "Access-Control-Allow-Origin: *", which is what this package sends when all origins are allowed; list explicit origins or use AllowOriginFunc when credentials are needed. 60 | AllowCredentials bool 61 | 62 | // MaxAge indicates how long (in seconds) the results of a preflight request 63 | // can be cached. Values <= 0 omit the Access-Control-Max-Age header. 64 | MaxAge int 65 | 66 | // OptionsPassthrough passes preflight OPTIONS requests on to the next handler instead of 67 | // answering them with 200 OK. Turn this on if your application handles OPTIONS.
68 | OptionsPassthrough bool 69 | 70 | // Debug enables logging of CORS decisions to stdout (prefix "[cors] ") to help debug server-side CORS issues. 71 | Debug bool 72 | } 73 | 74 | // Logger is the minimal logging interface used for debug output; *log.Logger satisfies it. 75 | type Logger interface { 76 | Printf(string, ...interface{}) 77 | } 78 | 79 | // Cors holds a normalized CORS configuration; its Handler method is the middleware. 80 | type Cors struct { 81 | // Log receives debug output when non-nil; New sets it to a stdout logger when Options.Debug is true. 82 | Log Logger 83 | 84 | // Normalized (lower-cased) list of plain allowed origins 85 | allowedOrigins []string 86 | 87 | // List of allowed origins containing wildcards 88 | allowedWOrigins []wildcard 89 | 90 | // Optional origin validator function 91 | allowOriginFunc func(r *http.Request, origin string) bool 92 | 93 | // Normalized (canonical-case) list of allowed headers 94 | allowedHeaders []string 95 | 96 | // Normalized (upper-case) list of allowed methods 97 | allowedMethods []string 98 | 99 | // Normalized (canonical-case) list of exposed headers 100 | exposedHeaders []string 101 | maxAge int 102 | 103 | // Set to true when allowed origins contains a "*", or when neither AllowedOrigins nor AllowOriginFunc is configured 104 | allowedOriginsAll bool 105 | 106 | // Set to true when allowed headers contains a "*" 107 | allowedHeadersAll bool 108 | 109 | allowCredentials bool 110 | optionPassthrough bool 111 | } 112 | 113 | // New creates a new Cors handler with the provided options, applying defaults and normalizing origins, headers and methods. 114 | func New(options Options) *Cors { 115 | c := &Cors{ 116 | exposedHeaders: convert(options.ExposedHeaders, http.CanonicalHeaderKey), 117 | allowOriginFunc: options.AllowOriginFunc, 118 | allowCredentials: options.AllowCredentials, 119 | maxAge: options.MaxAge, 120 | optionPassthrough: options.OptionsPassthrough, 121 | } 122 | if options.Debug && c.Log == nil { 123 | c.Log = log.New(os.Stdout, "[cors] ", log.LstdFlags) 124 | } 125 | 126 | // Normalize options 127 | // Note: for origins and methods matching, the spec requires a case-sensitive matching. 128 | // As it may be error prone, we chose to ignore the spec here.
129 | 130 | // Allowed Origins 131 | if len(options.AllowedOrigins) == 0 { 132 | if options.AllowOriginFunc == nil { 133 | // Default is all origins (only when no AllowOriginFunc is set) 134 | c.allowedOriginsAll = true 135 | } 136 | } else { 137 | c.allowedOrigins = []string{} 138 | c.allowedWOrigins = []wildcard{} 139 | for _, origin := range options.AllowedOrigins { 140 | // Normalize 141 | origin = strings.ToLower(origin) 142 | if origin == "*" { 143 | // If "*" is present in the list, turn the whole list into a match all 144 | c.allowedOriginsAll = true 145 | c.allowedOrigins = nil 146 | c.allowedWOrigins = nil 147 | break 148 | } else if i := strings.IndexByte(origin, '*'); i >= 0 { 149 | // Split the origin in two: start and end string without the *. Only the first * acts as a wildcard; any later * stays literal in the suffix. 150 | w := wildcard{origin[0:i], origin[i+1:]} 151 | c.allowedWOrigins = append(c.allowedWOrigins, w) 152 | } else { 153 | c.allowedOrigins = append(c.allowedOrigins, origin) 154 | } 155 | } 156 | } 157 | 158 | // Allowed Headers 159 | if len(options.AllowedHeaders) == 0 { 160 | // Use sensible defaults 161 | c.allowedHeaders = []string{"Origin", "Accept", "Content-Type"} 162 | } else { 163 | // Origin is always appended as some browsers will always request for this header at preflight. NOTE(review): append may write "Origin" into spare capacity of the caller's AllowedHeaders backing array; copying the slice first would avoid that. 164 | c.allowedHeaders = convert(append(options.AllowedHeaders, "Origin"), http.CanonicalHeaderKey) 165 | for _, h := range options.AllowedHeaders { 166 | if h == "*" { 167 | c.allowedHeadersAll = true 168 | c.allowedHeaders = nil 169 | break 170 | } 171 | } 172 | } 173 | 174 | // Allowed Methods 175 | if len(options.AllowedMethods) == 0 { 176 | // Default is spec's "simple" methods 177 | c.allowedMethods = []string{http.MethodGet, http.MethodPost, http.MethodHead} 178 | } else { 179 | c.allowedMethods = convert(options.AllowedMethods, strings.ToUpper) 180 | } 181 | 182 | return c 183 | } 184 | 185 | // Handler creates a Cors from options and returns its Handler method, ready to be used as router middleware (e.g. chi's r.Use).
186 | func Handler(options Options) func(next http.Handler) http.Handler { 187 | c := New(options) 188 | return c.Handler 189 | } 190 | 191 | // AllowAll creates a new Cors handler with a permissive configuration allowing all 192 | // origins with all standard methods and any header. Credentials are not allowed. 193 | func AllowAll() *Cors { 194 | return New(Options{ 195 | AllowedOrigins: []string{"*"}, 196 | AllowedMethods: []string{ 197 | http.MethodHead, 198 | http.MethodGet, 199 | http.MethodPost, 200 | http.MethodPut, 201 | http.MethodPatch, 202 | http.MethodDelete, 203 | }, 204 | AllowedHeaders: []string{"*"}, 205 | AllowCredentials: false, 206 | }) 207 | } 208 | 209 | // Handler applies the CORS specification to the request and adds relevant CORS headers 210 | // as necessary. Preflight requests are answered directly unless OptionsPassthrough is set. 211 | func (c *Cors) Handler(next http.Handler) http.Handler { 212 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 213 | if r.Method == http.MethodOptions && r.Header.Get("Access-Control-Request-Method") != "" { 214 | c.logf("Handler: Preflight request") 215 | c.handlePreflight(w, r) 216 | // Preflight requests are standalone and should stop the chain as some other 217 | // middleware may not handle OPTIONS requests correctly.
One typical example 218 | // is authentication middleware ; OPTIONS requests won't carry authentication 219 | // headers (see #1) 220 | if c.optionPassthrough { 221 | next.ServeHTTP(w, r) 222 | } else { 223 | w.WriteHeader(http.StatusOK) 224 | } 225 | } else { 226 | c.logf("Handler: Actual request") 227 | c.handleActualRequest(w, r) 228 | next.ServeHTTP(w, r) 229 | } 230 | }) 231 | } 232 | 233 | // handlePreflight handles pre-flight CORS requests 234 | func (c *Cors) handlePreflight(w http.ResponseWriter, r *http.Request) { 235 | headers := w.Header() 236 | origin := r.Header.Get("Origin") 237 | 238 | if r.Method != http.MethodOptions { 239 | c.logf("Preflight aborted: %s!=OPTIONS", r.Method) 240 | return 241 | } 242 | // Always set Vary headers 243 | // see https://github.com/rs/cors/issues/10, 244 | // https://github.com/rs/cors/commit/dbdca4d95feaa7511a46e6f1efb3b3aa505bc43f#commitcomment-12352001 245 | headers.Add("Vary", "Origin") 246 | headers.Add("Vary", "Access-Control-Request-Method") 247 | headers.Add("Vary", "Access-Control-Request-Headers") 248 | 249 | if origin == "" { 250 | c.logf("Preflight aborted: empty origin") 251 | return 252 | } 253 | if !c.isOriginAllowed(r, origin) { 254 | c.logf("Preflight aborted: origin '%s' not allowed", origin) 255 | return 256 | } 257 | 258 | reqMethod := r.Header.Get("Access-Control-Request-Method") 259 | if !c.isMethodAllowed(reqMethod) { 260 | c.logf("Preflight aborted: method '%s' not allowed", reqMethod) 261 | return 262 | } 263 | reqHeaders := parseHeaderList(r.Header.Get("Access-Control-Request-Headers")) 264 | if !c.areHeadersAllowed(reqHeaders) { 265 | c.logf("Preflight aborted: headers '%v' not allowed", reqHeaders) 266 | return 267 | } 268 | if c.allowedOriginsAll { 269 | headers.Set("Access-Control-Allow-Origin", "*") 270 | } else { 271 | headers.Set("Access-Control-Allow-Origin", origin) 272 | } 273 | // Spec says: Since the list of methods can be unbounded, simply returning the method indicated 274 | // by 
Access-Control-Request-Method (if supported) can be enough 275 | headers.Set("Access-Control-Allow-Methods", strings.ToUpper(reqMethod)) 276 | if len(reqHeaders) > 0 { 277 | 278 | // Spec says: Since the list of headers can be unbounded, simply returning supported headers 279 | // from Access-Control-Request-Headers can be enough 280 | headers.Set("Access-Control-Allow-Headers", strings.Join(reqHeaders, ", ")) 281 | } 282 | if c.allowCredentials { 283 | headers.Set("Access-Control-Allow-Credentials", "true") 284 | } 285 | if c.maxAge > 0 { 286 | headers.Set("Access-Control-Max-Age", strconv.Itoa(c.maxAge)) 287 | } 288 | c.logf("Preflight response headers: %v", headers) 289 | } 290 | 291 | // handleActualRequest adds CORS headers to non-preflight cross-origin requests (simple or actual requests, including redirects) 292 | func (c *Cors) handleActualRequest(w http.ResponseWriter, r *http.Request) { 293 | headers := w.Header() 294 | origin := r.Header.Get("Origin") 295 | 296 | // Always set Vary, see https://github.com/rs/cors/issues/10 297 | headers.Add("Vary", "Origin") 298 | if origin == "" { 299 | c.logf("Actual request no headers added: missing origin") 300 | return 301 | } 302 | if !c.isOriginAllowed(r, origin) { 303 | c.logf("Actual request no headers added: origin '%s' not allowed", origin) 304 | return 305 | } 306 | 307 | // Note that the spec doesn't define a way to specifically disallow a simple method like GET or 308 | // POST. Access-Control-Allow-Methods is only used for pre-flight requests and the 309 | // spec doesn't instruct to check the allowed methods for simple cross-origin requests. 310 | // We think it's a nice feature to be able to have control on those methods though.
311 | if !c.isMethodAllowed(r.Method) { 312 | c.logf("Actual request no headers added: method '%s' not allowed", r.Method) 313 | 314 | return 315 | } 316 | if c.allowedOriginsAll { 317 | headers.Set("Access-Control-Allow-Origin", "*") 318 | } else { 319 | headers.Set("Access-Control-Allow-Origin", origin) 320 | } 321 | if len(c.exposedHeaders) > 0 { 322 | headers.Set("Access-Control-Expose-Headers", strings.Join(c.exposedHeaders, ", ")) 323 | } 324 | if c.allowCredentials { 325 | headers.Set("Access-Control-Allow-Credentials", "true") 326 | } 327 | c.logf("Actual response added headers: %v", headers) 328 | } 329 | 330 | // convenience method. checks if a logger is set. 331 | func (c *Cors) logf(format string, a ...interface{}) { 332 | if c.Log != nil { 333 | c.Log.Printf(format, a...) 334 | } 335 | } 336 | 337 | // isOriginAllowed checks if a given origin is allowed to perform cross-domain requests 338 | // on the endpoint 339 | func (c *Cors) isOriginAllowed(r *http.Request, origin string) bool { 340 | if c.allowOriginFunc != nil { 341 | return c.allowOriginFunc(r, origin) 342 | } 343 | if c.allowedOriginsAll { 344 | return true 345 | } 346 | origin = strings.ToLower(origin) 347 | for _, o := range c.allowedOrigins { 348 | if o == origin { 349 | return true 350 | } 351 | } 352 | for _, w := range c.allowedWOrigins { 353 | if w.match(origin) { 354 | return true 355 | } 356 | } 357 | return false 358 | } 359 | 360 | // isMethodAllowed checks if a given method can be used as part of a cross-domain request 361 | // on the endpoint 362 | func (c *Cors) isMethodAllowed(method string) bool { 363 | if len(c.allowedMethods) == 0 { 364 | // If no method allowed, always return false, even for preflight request 365 | return false 366 | } 367 | method = strings.ToUpper(method) 368 | if method == http.MethodOptions { 369 | // Always allow preflight requests 370 | return true 371 | } 372 | for _, m := range c.allowedMethods { 373 | if m == method { 374 | return true 375 | } 376 
| } 377 | return false 378 | } 379 | 380 | // areHeadersAllowed checks if a given list of headers are allowed to used within 381 | // a cross-domain request. 382 | func (c *Cors) areHeadersAllowed(requestedHeaders []string) bool { 383 | if c.allowedHeadersAll || len(requestedHeaders) == 0 { 384 | return true 385 | } 386 | for _, header := range requestedHeaders { 387 | header = http.CanonicalHeaderKey(header) 388 | found := false 389 | for _, h := range c.allowedHeaders { 390 | if h == header { 391 | found = true 392 | break 393 | } 394 | } 395 | if !found { 396 | return false 397 | } 398 | } 399 | return true 400 | } 401 | -------------------------------------------------------------------------------- /cors_test.go: -------------------------------------------------------------------------------- 1 | package cors 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "regexp" 7 | "strings" 8 | "testing" 9 | ) 10 | 11 | var testHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 12 | w.Write([]byte("bar")) 13 | }) 14 | 15 | var allHeaders = []string{ 16 | "Vary", 17 | "Access-Control-Allow-Origin", 18 | "Access-Control-Allow-Methods", 19 | "Access-Control-Allow-Headers", 20 | "Access-Control-Allow-Credentials", 21 | "Access-Control-Max-Age", 22 | "Access-Control-Expose-Headers", 23 | } 24 | 25 | func assertHeaders(t *testing.T, resHeaders http.Header, expHeaders map[string]string) { 26 | for _, name := range allHeaders { 27 | got := strings.Join(resHeaders[name], ", ") 28 | want := expHeaders[name] 29 | if got != want { 30 | t.Errorf("Response header %q = %q, want %q", name, got, want) 31 | } 32 | } 33 | } 34 | 35 | func assertResponse(t *testing.T, res *httptest.ResponseRecorder, responseCode int) { 36 | if responseCode != res.Code { 37 | t.Errorf("assertResponse: expected response code to be %d but got %d. 
", responseCode, res.Code) 38 | } 39 | } 40 | 41 | func TestSpec(t *testing.T) { 42 | cases := []struct { 43 | name string 44 | options Options 45 | method string 46 | reqHeaders map[string]string 47 | resHeaders map[string]string 48 | }{ 49 | { 50 | "NoConfig", 51 | Options{ 52 | // Intentionally left blank. 53 | }, 54 | "GET", 55 | map[string]string{}, 56 | map[string]string{ 57 | "Vary": "Origin", 58 | }, 59 | }, 60 | { 61 | "MatchAllOrigin", 62 | Options{ 63 | AllowedOrigins: []string{"*"}, 64 | }, 65 | "GET", 66 | map[string]string{ 67 | "Origin": "http://foobar.com", 68 | }, 69 | map[string]string{ 70 | "Vary": "Origin", 71 | "Access-Control-Allow-Origin": "*", 72 | }, 73 | }, 74 | { 75 | "MatchAllOriginWithCredentials", 76 | Options{ 77 | AllowedOrigins: []string{"*"}, 78 | AllowCredentials: true, 79 | }, 80 | "GET", 81 | map[string]string{ 82 | "Origin": "http://foobar.com", 83 | }, 84 | map[string]string{ 85 | "Vary": "Origin", 86 | "Access-Control-Allow-Origin": "*", 87 | "Access-Control-Allow-Credentials": "true", 88 | }, 89 | }, 90 | { 91 | "AllowedOrigin", 92 | Options{ 93 | AllowedOrigins: []string{"http://foobar.com"}, 94 | }, 95 | "GET", 96 | map[string]string{ 97 | "Origin": "http://foobar.com", 98 | }, 99 | map[string]string{ 100 | "Vary": "Origin", 101 | "Access-Control-Allow-Origin": "http://foobar.com", 102 | }, 103 | }, 104 | { 105 | "WildcardOrigin", 106 | Options{ 107 | AllowedOrigins: []string{"http://*.bar.com"}, 108 | }, 109 | "GET", 110 | map[string]string{ 111 | "Origin": "http://foo.bar.com", 112 | }, 113 | map[string]string{ 114 | "Vary": "Origin", 115 | "Access-Control-Allow-Origin": "http://foo.bar.com", 116 | }, 117 | }, 118 | { 119 | "DisallowedOrigin", 120 | Options{ 121 | AllowedOrigins: []string{"http://foobar.com"}, 122 | }, 123 | "GET", 124 | map[string]string{ 125 | "Origin": "http://barbaz.com", 126 | }, 127 | map[string]string{ 128 | "Vary": "Origin", 129 | }, 130 | }, 131 | { 132 | "DisallowedWildcardOrigin", 133 | 
Options{ 134 | AllowedOrigins: []string{"http://*.bar.com"}, 135 | }, 136 | "GET", 137 | map[string]string{ 138 | "Origin": "http://foo.baz.com", 139 | }, 140 | map[string]string{ 141 | "Vary": "Origin", 142 | }, 143 | }, 144 | { 145 | "AllowedOriginFuncMatch", 146 | Options{ 147 | AllowOriginFunc: func(r *http.Request, o string) bool { 148 | return regexp.MustCompile("^http://foo").MatchString(o) && r.Header.Get("Authorization") == "secret" 149 | }, 150 | }, 151 | "GET", 152 | map[string]string{ 153 | "Origin": "http://foobar.com", 154 | "Authorization": "secret", 155 | }, 156 | map[string]string{ 157 | "Vary": "Origin", 158 | "Access-Control-Allow-Origin": "http://foobar.com", 159 | }, 160 | }, 161 | { 162 | "AllowOriginFuncNotMatch", 163 | Options{ 164 | AllowOriginFunc: func(r *http.Request, o string) bool { 165 | return regexp.MustCompile("^http://foo").MatchString(o) && r.Header.Get("Authorization") == "secret" 166 | }, 167 | }, 168 | "GET", 169 | map[string]string{ 170 | "Origin": "http://foobar.com", 171 | "Authorization": "not-secret", 172 | }, 173 | map[string]string{ 174 | "Vary": "Origin", 175 | }, 176 | }, 177 | { 178 | "MaxAge", 179 | Options{ 180 | AllowedOrigins: []string{"http://example.com/"}, 181 | AllowedMethods: []string{"GET"}, 182 | MaxAge: 10, 183 | }, 184 | "OPTIONS", 185 | map[string]string{ 186 | "Origin": "http://example.com/", 187 | "Access-Control-Request-Method": "GET", 188 | }, 189 | map[string]string{ 190 | "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", 191 | "Access-Control-Allow-Origin": "http://example.com/", 192 | "Access-Control-Allow-Methods": "GET", 193 | "Access-Control-Max-Age": "10", 194 | }, 195 | }, 196 | { 197 | "AllowedMethod", 198 | Options{ 199 | AllowedOrigins: []string{"http://foobar.com"}, 200 | AllowedMethods: []string{"PUT", "DELETE"}, 201 | }, 202 | "OPTIONS", 203 | map[string]string{ 204 | "Origin": "http://foobar.com", 205 | "Access-Control-Request-Method": "PUT", 206 | }, 
207 | map[string]string{ 208 | "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", 209 | "Access-Control-Allow-Origin": "http://foobar.com", 210 | "Access-Control-Allow-Methods": "PUT", 211 | }, 212 | }, 213 | { 214 | "DisallowedMethod", 215 | Options{ 216 | AllowedOrigins: []string{"http://foobar.com"}, 217 | AllowedMethods: []string{"PUT", "DELETE"}, 218 | }, 219 | "OPTIONS", 220 | map[string]string{ 221 | "Origin": "http://foobar.com", 222 | "Access-Control-Request-Method": "PATCH", 223 | }, 224 | map[string]string{ 225 | "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", 226 | }, 227 | }, 228 | { 229 | "AllowedHeaders", 230 | Options{ 231 | AllowedOrigins: []string{"http://foobar.com"}, 232 | AllowedHeaders: []string{"X-Header-1", "x-header-2"}, 233 | }, 234 | "OPTIONS", 235 | map[string]string{ 236 | "Origin": "http://foobar.com", 237 | "Access-Control-Request-Method": "GET", 238 | "Access-Control-Request-Headers": "X-Header-2, X-HEADER-1", 239 | }, 240 | map[string]string{ 241 | "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", 242 | "Access-Control-Allow-Origin": "http://foobar.com", 243 | "Access-Control-Allow-Methods": "GET", 244 | "Access-Control-Allow-Headers": "X-Header-2, X-Header-1", 245 | }, 246 | }, 247 | { 248 | "DefaultAllowedHeaders", 249 | Options{ 250 | AllowedOrigins: []string{"http://foobar.com"}, 251 | AllowedHeaders: []string{}, 252 | }, 253 | "OPTIONS", 254 | map[string]string{ 255 | "Origin": "http://foobar.com", 256 | "Access-Control-Request-Method": "GET", 257 | "Access-Control-Request-Headers": "Content-Type", 258 | }, 259 | map[string]string{ 260 | "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", 261 | "Access-Control-Allow-Origin": "http://foobar.com", 262 | "Access-Control-Allow-Methods": "GET", 263 | "Access-Control-Allow-Headers": "Content-Type", 264 | }, 265 | }, 266 | { 267 | "AllowedWildcardHeader", 268 | 
Options{ 269 | AllowedOrigins: []string{"http://foobar.com"}, 270 | AllowedHeaders: []string{"*"}, 271 | }, 272 | "OPTIONS", 273 | map[string]string{ 274 | "Origin": "http://foobar.com", 275 | "Access-Control-Request-Method": "GET", 276 | "Access-Control-Request-Headers": "X-Header-2, X-HEADER-1", 277 | }, 278 | map[string]string{ 279 | "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", 280 | "Access-Control-Allow-Origin": "http://foobar.com", 281 | "Access-Control-Allow-Methods": "GET", 282 | "Access-Control-Allow-Headers": "X-Header-2, X-Header-1", 283 | }, 284 | }, 285 | { 286 | "DisallowedHeader", 287 | Options{ 288 | AllowedOrigins: []string{"http://foobar.com"}, 289 | AllowedHeaders: []string{"X-Header-1", "x-header-2"}, 290 | }, 291 | "OPTIONS", 292 | map[string]string{ 293 | "Origin": "http://foobar.com", 294 | "Access-Control-Request-Method": "GET", 295 | "Access-Control-Request-Headers": "X-Header-3, X-Header-1", 296 | }, 297 | map[string]string{ 298 | "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", 299 | }, 300 | }, 301 | { 302 | "OriginHeader", 303 | Options{ 304 | AllowedOrigins: []string{"http://foobar.com"}, 305 | }, 306 | "OPTIONS", 307 | map[string]string{ 308 | "Origin": "http://foobar.com", 309 | "Access-Control-Request-Method": "GET", 310 | "Access-Control-Request-Headers": "origin", 311 | }, 312 | map[string]string{ 313 | "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", 314 | "Access-Control-Allow-Origin": "http://foobar.com", 315 | "Access-Control-Allow-Methods": "GET", 316 | "Access-Control-Allow-Headers": "Origin", 317 | }, 318 | }, 319 | { 320 | "ExposedHeader", 321 | Options{ 322 | AllowedOrigins: []string{"http://foobar.com"}, 323 | ExposedHeaders: []string{"X-Header-1", "x-header-2"}, 324 | }, 325 | "GET", 326 | map[string]string{ 327 | "Origin": "http://foobar.com", 328 | }, 329 | map[string]string{ 330 | "Vary": "Origin", 331 | 
"Access-Control-Allow-Origin": "http://foobar.com", 332 | "Access-Control-Expose-Headers": "X-Header-1, X-Header-2", 333 | }, 334 | }, 335 | { 336 | "AllowedCredentials", 337 | Options{ 338 | AllowedOrigins: []string{"http://foobar.com"}, 339 | AllowCredentials: true, 340 | }, 341 | "OPTIONS", 342 | map[string]string{ 343 | "Origin": "http://foobar.com", 344 | "Access-Control-Request-Method": "GET", 345 | }, 346 | map[string]string{ 347 | "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", 348 | "Access-Control-Allow-Origin": "http://foobar.com", 349 | "Access-Control-Allow-Methods": "GET", 350 | "Access-Control-Allow-Credentials": "true", 351 | }, 352 | }, 353 | { 354 | "OptionPassthrough", 355 | Options{ 356 | OptionsPassthrough: true, 357 | }, 358 | "OPTIONS", 359 | map[string]string{ 360 | "Origin": "http://foobar.com", 361 | "Access-Control-Request-Method": "GET", 362 | }, 363 | map[string]string{ 364 | "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", 365 | "Access-Control-Allow-Origin": "*", 366 | "Access-Control-Allow-Methods": "GET", 367 | }, 368 | }, 369 | { 370 | "NonPreflightOptions", 371 | Options{ 372 | AllowedOrigins: []string{"http://foobar.com"}, 373 | }, 374 | "OPTIONS", 375 | map[string]string{ 376 | "Origin": "http://foobar.com", 377 | }, 378 | map[string]string{ 379 | "Vary": "Origin", 380 | "Access-Control-Allow-Origin": "http://foobar.com", 381 | }, 382 | }, 383 | } 384 | for i := range cases { 385 | tc := cases[i] 386 | t.Run(tc.name, func(t *testing.T) { 387 | s := New(tc.options) 388 | 389 | req, _ := http.NewRequest(tc.method, "http://example.com/foo", nil) 390 | for name, value := range tc.reqHeaders { 391 | req.Header.Add(name, value) 392 | } 393 | 394 | t.Run("Handler", func(t *testing.T) { 395 | res := httptest.NewRecorder() 396 | s.Handler(testHandler).ServeHTTP(res, req) 397 | assertHeaders(t, res.Header(), tc.resHeaders) 398 | }) 399 | }) 400 | } 401 | } 402 | 403 | func 
TestDebug(t *testing.T) { 404 | s := New(Options{ 405 | Debug: true, 406 | }) 407 | 408 | if s.Log == nil { 409 | t.Error("Logger not created when debug=true") 410 | } 411 | } 412 | 413 | func TestDefault(t *testing.T) { 414 | s := New(Options{}) 415 | if s.Log != nil { 416 | t.Error("c.Log should be nil when Default") 417 | } 418 | if !s.allowedOriginsAll { 419 | t.Error("c.allowedOriginsAll should be true when Default") 420 | } 421 | if s.allowedHeaders == nil { 422 | t.Error("c.allowedHeaders must not be nil when Default") 423 | } 424 | if s.allowedMethods == nil { 425 | t.Error("c.allowedMethods must not be nil when Default") 426 | } 427 | } 428 | 429 | func TestHandlePreflightInvalidOriginAbortion(t *testing.T) { 430 | s := New(Options{ 431 | AllowedOrigins: []string{"http://foo.com"}, 432 | }) 433 | res := httptest.NewRecorder() 434 | req, _ := http.NewRequest("OPTIONS", "http://example.com/foo", nil) 435 | req.Header.Add("Origin", "http://example.com/") 436 | 437 | s.handlePreflight(res, req) 438 | 439 | assertHeaders(t, res.Header(), map[string]string{ 440 | "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", 441 | }) 442 | } 443 | 444 | func TestHandlePreflightNoOptionsAbortion(t *testing.T) { 445 | s := New(Options{ 446 | // Intentionally left blank.
447 | }) 448 | res := httptest.NewRecorder() 449 | req, _ := http.NewRequest("GET", "http://example.com/foo", nil) 450 | 451 | s.handlePreflight(res, req) 452 | 453 | assertHeaders(t, res.Header(), map[string]string{}) 454 | } 455 | 456 | func TestHandleActualRequestInvalidOriginAbortion(t *testing.T) { 457 | s := New(Options{ 458 | AllowedOrigins: []string{"http://foo.com"}, 459 | }) 460 | res := httptest.NewRecorder() 461 | req, _ := http.NewRequest("GET", "http://example.com/foo", nil) 462 | req.Header.Add("Origin", "http://example.com/") 463 | 464 | s.handleActualRequest(res, req) 465 | 466 | assertHeaders(t, res.Header(), map[string]string{ 467 | "Vary": "Origin", 468 | }) 469 | } 470 | 471 | func TestHandleActualRequestInvalidMethodAbortion(t *testing.T) { 472 | s := New(Options{ 473 | AllowedMethods: []string{"POST"}, 474 | AllowCredentials: true, 475 | }) 476 | res := httptest.NewRecorder() 477 | req, _ := http.NewRequest("GET", "http://example.com/foo", nil) 478 | req.Header.Add("Origin", "http://example.com/") 479 | 480 | s.handleActualRequest(res, req) 481 | 482 | assertHeaders(t, res.Header(), map[string]string{ 483 | "Vary": "Origin", 484 | }) 485 | } 486 | 487 | func TestIsMethodAllowedReturnsFalseWithNoMethods(t *testing.T) { 488 | s := New(Options{ 489 | // Intentionally left blank. 490 | }) 491 | s.allowedMethods = []string{} 492 | if s.isMethodAllowed("") { 493 | t.Error("IsMethodAllowed should return false when c.allowedMethods is empty.") 494 | } 495 | } 496 | 497 | func TestIsMethodAllowedReturnsTrueWithOptions(t *testing.T) { 498 | s := New(Options{ 499 | // Intentionally left blank.
500 | }) 501 | if !s.isMethodAllowed("OPTIONS") { 502 | t.Error("IsMethodAllowed should return true for OPTIONS with the default configuration.") 503 | } 504 | } 505 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/go-chi/cors 2 | 3 | go 1.14 4 | -------------------------------------------------------------------------------- /utils.go: -------------------------------------------------------------------------------- 1 | package cors 2 | 3 | import "strings" 4 | 5 | const toLower = 'a' - 'A' 6 | 7 | type converter func(string) string 8 | // wildcard is a pattern of the form prefix*suffix; see match. 9 | type wildcard struct { 10 | prefix string 11 | suffix string 12 | } 13 | // match reports whether s starts with w.prefix and ends with w.suffix, without the prefix and suffix overlapping. 14 | func (w wildcard) match(s string) bool { 15 | return len(s) >= len(w.prefix)+len(w.suffix) && strings.HasPrefix(s, w.prefix) && strings.HasSuffix(s, w.suffix) 16 | } 17 | 18 | // convert applies c to every element of s and returns the results in a new slice. 19 | func convert(s []string, c converter) []string { 20 | out := make([]string, 0, len(s)) 21 | for _, i := range s { 22 | out = append(out, c(i)) 23 | } 24 | return out 25 | } 26 | 27 | // parseHeaderList splits a comma- or space-separated header list into canonicalized names (e.g. "x-header-1" -> "X-Header-1"), skipping empty tokens and dropping any byte that is not an ASCII letter, digit, '-', '_' or '.'. 28 | func parseHeaderList(headerList string) []string { 29 | l := len(headerList) 30 | h := make([]byte, 0, l) 31 | upper := true 32 | // Estimate the number of headers (commas + 1) to pre-size the result slice. 33 | t := 1 34 | for i := 0; i < l; i++ { 35 | if headerList[i] == ',' { 36 | t++ 37 | } 38 | } 39 | headers := make([]string, 0, t) 40 | for i := 0; i < l; i++ { 41 | b := headerList[i] 42 | if b >= 'a' && b <= 'z' { 43 | if upper { 44 | h = append(h, b-toLower) 45 | } else { 46 | h = append(h, b) 47 | } 48 | } else if b >= 'A' && b <= 'Z' { 49 | if !upper { 50 | h = append(h, b+toLower) 51 | } else { 52 | h = append(h, b) 53 | } 54 | } else if b == '-' || b == '_' || b == '.'
|| (b >= '0' && b <= '9') { 55 | h = append(h, b) 56 | } 57 | // A space, a comma or the last byte of input ends the current token. 58 | if b == ' ' || b == ',' || i == l-1 { 59 | if len(h) > 0 { 60 | // Flush the found header 61 | headers = append(headers, string(h)) 62 | h = h[:0] 63 | upper = true 64 | } 65 | } else { 66 | upper = b == '-' // within a token, capitalize a letter only when it directly follows a '-' 67 | } 68 | } 69 | return headers 70 | } 71 | -------------------------------------------------------------------------------- /utils_test.go: -------------------------------------------------------------------------------- 1 | package cors 2 | 3 | import ( 4 | "strings" 5 | "testing" 6 | ) 7 | 8 | func TestWildcard(t *testing.T) { 9 | w := wildcard{"foo", "bar"} 10 | if !w.match("foobar") { 11 | t.Error("foo*bar should match foobar") 12 | } 13 | if !w.match("foobazbar") { 14 | t.Error("foo*bar should match foobazbar") 15 | } 16 | if w.match("foobaz") { 17 | t.Error("foo*bar should not match foobaz") 18 | } 19 | 20 | w = wildcard{"foo", "oof"} 21 | if w.match("foof") { 22 | t.Error("foo*oof should not match foof") 23 | } 24 | } 25 | 26 | func TestConvert(t *testing.T) { 27 | s := convert([]string{"A", "b", "C"}, strings.ToLower) 28 | e := []string{"a", "b", "c"} 29 | if len(s) != len(e) || s[0] != e[0] || s[1] != e[1] || s[2] != e[2] { 30 | t.Errorf("%v != %v", s, e) 31 | } 32 | } 33 | 34 | func TestParseHeaderList(t *testing.T) { 35 | h := parseHeaderList("header, second-header, THIRD-HEADER, Numb3r3d-H34d3r, Header_with_underscore Header.with.full.stop") 36 | e := []string{"Header", "Second-Header", "Third-Header", "Numb3r3d-H34d3r", "Header_with_underscore", "Header.with.full.stop"} 37 | if len(h) != len(e) || h[0] != e[0] || h[1] != e[1] || h[2] != e[2] || h[3] != e[3] || h[4] != e[4] || h[5] != e[5] { 38 | t.Errorf("%v != %v", h, e) 39 | } 40 | } 41 | 42 | func TestParseHeaderListEmpty(t *testing.T) { 43 | if len(parseHeaderList("")) != 0 { 44 | t.Error("should be empty slice") 45 | } 46 | if len(parseHeaderList(" , ")) != 0 { 47 | t.Error("should be empty slice") 48 | } 49 | } 50 | 51 | func BenchmarkParseHeaderList(b
*testing.B) { 52 | b.ReportAllocs() 53 | for i := 0; i < b.N; i++ { 54 | parseHeaderList("header, second-header, THIRD-HEADER") 55 | } 56 | } 57 | 58 | func BenchmarkParseHeaderListSingle(b *testing.B) { 59 | b.ReportAllocs() 60 | for i := 0; i < b.N; i++ { 61 | parseHeaderList("header") 62 | } 63 | } 64 | 65 | func BenchmarkParseHeaderListNormalized(b *testing.B) { 66 | b.ReportAllocs() 67 | for i := 0; i < b.N; i++ { 68 | parseHeaderList("Header1, Header2, Third-Header") 69 | } 70 | } 71 | 72 | func BenchmarkWildcard(b *testing.B) { 73 | w := wildcard{"foo", "bar"} 74 | b.Run("match", func(b *testing.B) { 75 | b.ReportAllocs() 76 | for i := 0; i < b.N; i++ { 77 | w.match("foobazbar") 78 | } 79 | }) 80 | b.Run("too short", func(b *testing.B) { 81 | b.ReportAllocs() 82 | for i := 0; i < b.N; i++ { 83 | w.match("fobar") 84 | } 85 | }) 86 | } 87 | --------------------------------------------------------------------------------