├── .github ├── FUNDING.yml └── workflows │ └── ci.yml ├── .gitignore ├── CHANGELOG.md ├── CONTRIBUTING.md ├── LICENSE ├── Makefile ├── README.md ├── SECURITY.md ├── _examples ├── README.md ├── chi.svg ├── custom-handler │ └── main.go ├── custom-method │ └── main.go ├── fileserver │ ├── data │ │ └── notes.txt │ └── main.go ├── graceful │ └── main.go ├── hello-world │ └── main.go ├── limits │ └── main.go ├── logging │ └── main.go ├── rest │ ├── go.mod │ ├── go.sum │ ├── main.go │ ├── routes.json │ └── routes.md ├── router-walk │ └── main.go ├── todos-resource │ ├── main.go │ ├── todos.go │ └── users.go └── versions │ ├── data │ ├── article.go │ └── errors.go │ ├── go.mod │ ├── go.sum │ ├── main.go │ └── presenter │ ├── v1 │ └── article.go │ ├── v2 │ └── article.go │ └── v3 │ └── article.go ├── chain.go ├── chi.go ├── context.go ├── context_test.go ├── go.mod ├── middleware ├── basic_auth.go ├── clean_path.go ├── compress.go ├── compress_test.go ├── content_charset.go ├── content_charset_test.go ├── content_encoding.go ├── content_encoding_test.go ├── content_type.go ├── content_type_test.go ├── get_head.go ├── get_head_test.go ├── heartbeat.go ├── logger.go ├── logger_test.go ├── maybe.go ├── middleware.go ├── middleware_test.go ├── nocache.go ├── page_route.go ├── path_rewrite.go ├── profiler.go ├── realip.go ├── realip_test.go ├── recoverer.go ├── recoverer_test.go ├── request_id.go ├── request_id_test.go ├── request_size.go ├── route_headers.go ├── strip.go ├── strip_test.go ├── sunset.go ├── sunset_test.go ├── supress_notfound.go ├── terminal.go ├── throttle.go ├── throttle_test.go ├── timeout.go ├── url_format.go ├── url_format_test.go ├── value.go ├── wrap_writer.go └── wrap_writer_test.go ├── mux.go ├── mux_test.go ├── path_value.go ├── path_value_fallback.go ├── path_value_test.go ├── testdata ├── cert.pem └── key.pem ├── tree.go └── tree_test.go /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # 
These are supported funding model platforms 2 | 3 | github: [pkieltyka] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # Replace with a single Otechie username 12 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 13 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | branches: "**" 4 | paths-ignore: 5 | - "docs/**" 6 | pull_request: 7 | branches: "**" 8 | paths-ignore: 9 | - "docs/**" 10 | 11 | name: Test 12 | jobs: 13 | test: 14 | env: 15 | GOPATH: ${{ github.workspace }} 16 | 17 | defaults: 18 | run: 19 | working-directory: ${{ env.GOPATH }}/src/github.com/${{ github.repository }} 20 | 21 | strategy: 22 | matrix: 23 | go-version: [1.20.x, 1.21.x, 1.22.x, 1.23.x, 1.24.x] 24 | os: [ubuntu-latest, windows-latest] 25 | 26 | runs-on: ${{ matrix.os }} 27 | 28 | steps: 29 | - name: Install Go 30 | uses: actions/setup-go@v5 31 | with: 32 | go-version: ${{ matrix.go-version }} 33 | check-latest: true 34 | cache: false 35 | - name: Checkout code 36 | uses: actions/checkout@v4 37 | with: 38 | path: ${{ env.GOPATH }}/src/github.com/${{ github.repository }} 39 | - name: Test 40 | run: | 41 | go get -d -t ./... 
42 | make test 43 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | *.sw? 3 | .vscode 4 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | ## v5.0.12 (2024-02-16) 4 | 5 | - History of changes: see https://github.com/go-chi/chi/compare/v5.0.11...v5.0.12 6 | 7 | 8 | ## v5.0.11 (2023-12-19) 9 | 10 | - History of changes: see https://github.com/go-chi/chi/compare/v5.0.10...v5.0.11 11 | 12 | 13 | ## v5.0.10 (2023-07-13) 14 | 15 | - Fixed small edge case in tests of v5.0.9 for older Go versions 16 | - History of changes: see https://github.com/go-chi/chi/compare/v5.0.9...v5.0.10 17 | 18 | 19 | ## v5.0.9 (2023-07-13) 20 | 21 | - History of changes: see https://github.com/go-chi/chi/compare/v5.0.8...v5.0.9 22 | 23 | 24 | ## v5.0.8 (2022-12-07) 25 | 26 | - History of changes: see https://github.com/go-chi/chi/compare/v5.0.7...v5.0.8 27 | 28 | 29 | ## v5.0.7 (2021-11-18) 30 | 31 | - History of changes: see https://github.com/go-chi/chi/compare/v5.0.6...v5.0.7 32 | 33 | 34 | ## v5.0.6 (2021-11-15) 35 | 36 | - History of changes: see https://github.com/go-chi/chi/compare/v5.0.5...v5.0.6 37 | 38 | 39 | ## v5.0.5 (2021-10-27) 40 | 41 | - History of changes: see https://github.com/go-chi/chi/compare/v5.0.4...v5.0.5 42 | 43 | 44 | ## v5.0.4 (2021-08-29) 45 | 46 | - History of changes: see https://github.com/go-chi/chi/compare/v5.0.3...v5.0.4 47 | 48 | 49 | ## v5.0.3 (2021-04-29) 50 | 51 | - History of changes: see https://github.com/go-chi/chi/compare/v5.0.2...v5.0.3 52 | 53 | 54 | ## v5.0.2 (2021-03-25) 55 | 56 | - History of changes: see https://github.com/go-chi/chi/compare/v5.0.1...v5.0.2 57 | 58 | 59 | ## v5.0.1 (2021-03-10) 60 | 61 | - Small improvements 62 | - History of changes: see 
https://github.com/go-chi/chi/compare/v5.0.0...v5.0.1 63 | 64 | 65 | ## v5.0.0 (2021-02-27) 66 | 67 | - chi v5, `github.com/go-chi/chi/v5` introduces the adoption of Go's SIV to adhere to the current state-of-the-tools in Go. 68 | - chi v1.5.x did not work out as planned, as the Go tooling is too powerful and chi's adoption is too wide. 69 | The most responsible thing to do for everyone's benefit is to just release v5 with SIV, so I present to you all, 70 | chi v5 at `github.com/go-chi/chi/v5`. I hope someday the developer experience and ergonomics I've been seeking 71 | will still come to fruition in some form, see https://github.com/golang/go/issues/44550 72 | - History of changes: see https://github.com/go-chi/chi/compare/v1.5.4...v5.0.0 73 | 74 | 75 | ## v1.5.4 (2021-02-27) 76 | 77 | - Undo prior retraction in v1.5.3 as we prepare for v5.0.0 release 78 | - History of changes: see https://github.com/go-chi/chi/compare/v1.5.3...v1.5.4 79 | 80 | 81 | ## v1.5.3 (2021-02-21) 82 | 83 | - Update go.mod to go 1.16 with new retract directive marking all versions without prior go.mod support 84 | - History of changes: see https://github.com/go-chi/chi/compare/v1.5.2...v1.5.3 85 | 86 | 87 | ## v1.5.2 (2021-02-10) 88 | 89 | - Reverting allocation optimization as a precaution as go test -race fails. 90 | - Minor improvements, see history below 91 | - History of changes: see https://github.com/go-chi/chi/compare/v1.5.1...v1.5.2 92 | 93 | 94 | ## v1.5.1 (2020-12-06) 95 | 96 | - Performance improvement: removing 1 allocation by foregoing context.WithValue, thank you @bouk for 97 | your contribution (https://github.com/go-chi/chi/pull/555). Note: new benchmarks posted in README. 
98 | - `middleware.CleanPath`: new middleware that cleans the request path of double slashes 99 | - deprecate & remove `chi.ServerBaseContext` in favour of stdlib `http.Server#BaseContext` 100 | - plus other tiny improvements, see full commit history below 101 | - History of changes: see https://github.com/go-chi/chi/compare/v4.1.2...v1.5.1 102 | 103 | 104 | ## v1.5.0 (2020-11-12) - now with go.mod support 105 | 106 | `chi` dates back to 2016 with its original implementation as one of the first routers to adopt the newly introduced 107 | context.Context api to the stdlib -- set out to design a router that is faster, more modular and simpler than anything 108 | else out there -- while not introducing any custom handler types or dependencies. Today, `chi` still has zero dependencies, 109 | and in many ways is future-proofed from changes, given its minimal nature. Between versions, chi's iterations have been very 110 | incremental, with the architecture and api being the same today as it was originally designed in 2016. For this reason it 111 | makes chi a pretty easy project to maintain, also thanks to the many amazing community contributions over the years 112 | from everyone who helps make chi better (total of 86 contributors to date -- thanks all!). 113 | 114 | Chi has been a labour of love, art and engineering, with the goals to offer beautiful ergonomics, flexibility, performance 115 | and simplicity when building HTTP services with Go. I've strived to keep the router very minimal in surface area / code size, 116 | and always improving the code wherever possible -- and as of today the `chi` package is just 1082 lines of code (not counting 117 | middlewares, which are all optional).
As well, I don't have the exact metrics, but from my analysis and email exchanges from 118 | companies and developers, chi is used by thousands of projects around the world -- thank you all as there is no better form of 119 | joy for me than to have art I had started be helpful and enjoyed by others. And of course I use chi in all of my own projects too :) 120 | 121 | For me, the aesthetics of chi's code and usage are very important. With the introduction of Go's module support 122 | (which I'm a big fan of), chi's past versioning scheme choice to v2, v3 and v4 would mean I'd require the import path 123 | of "github.com/go-chi/chi/v4", leading to the lengthy discussion at https://github.com/go-chi/chi/issues/462. 124 | Haha, to some, you may be scratching your head why I've spent > 1 year stalling to adopt "/vXX" convention in the import 125 | path -- which isn't horrible in general -- but for chi, I'm unable to accept it as I strive for perfection in its API design, 126 | aesthetics and simplicity. It just doesn't feel good to me given chi's simple nature -- I do not foresee a "v5" or "v6", 127 | and upgrading between versions in the future will also be just incremental. 128 | 129 | I do understand versioning is a part of the API design as well, which is why the solution for a while has been to "do nothing", 130 | as Go supports both old and new import paths with/out go.mod. However, now that Go module support has had time to iron out kinks and 131 | is adopted everywhere, it's time for chi to get with the times. Luckily, I've discovered a path forward that will make me happy, 132 | while also not breaking anyone's app who adopted a prior versioning from tags in v2/v3/v4. I've made an experimental release of 133 | v1.5.0 with go.mod silently, and tested it with new and old projects, to ensure the developer experience is preserved, and it's 134 | largely unnoticed.
Fortunately, Go's toolchain will check the tags of a repo and consider the "latest" tag the one with go.mod. 135 | However, you can still request a specific older tag such as v4.1.2, and everything will "just work". But new users can just 136 | `go get github.com/go-chi/chi` or `go get github.com/go-chi/chi@latest` and they will get the latest version which contains 137 | go.mod support, which is v1.5.0+. `chi` will not change very much over the years, just like it hasn't changed much from 4 years ago. 138 | Therefore, we will stay on v1.x from here on, starting from v1.5.0. Any breaking changes will bump a "minor" release and 139 | backwards-compatible improvements/fixes will bump a "tiny" release. 140 | 141 | For existing projects who want to upgrade to the latest go.mod version, run: `go get -u github.com/go-chi/chi@v1.5.0`, 142 | which will get you on the go.mod version line (as Go's mod cache may still remember v4.x). Brand new systems can run 143 | `go get -u github.com/go-chi/chi` or `go get -u github.com/go-chi/chi@latest` to install chi, which will install v1.5.0+ 144 | built with go.mod support. 145 | 146 | My apologies to the developers who will disagree with the decisions above, but, hope you'll try it and see it's a very 147 | minor request which is backwards compatible and won't break your existing installations. 148 | 149 | Cheers all, happy coding! 
150 | 151 | 152 | --- 153 | 154 | 155 | ## v4.1.2 (2020-06-02) 156 | 157 | - fix that handles MethodNotAllowed with path variables, thank you @caseyhadden for your contribution 158 | - fix to replace nested wildcards correctly in RoutePattern, thank you @unmultimedio for your contribution 159 | - History of changes: see https://github.com/go-chi/chi/compare/v4.1.1...v4.1.2 160 | 161 | 162 | ## v4.1.1 (2020-04-16) 163 | 164 | - fix for issue https://github.com/go-chi/chi/issues/411 which allows for overlapping regexp 165 | route to the correct handler through a recursive tree search, thanks to @Jahaja for the PR/fix! 166 | - new middleware.RouteHeaders as a simple router for request headers with wildcard support 167 | - History of changes: see https://github.com/go-chi/chi/compare/v4.1.0...v4.1.1 168 | 169 | 170 | ## v4.1.0 (2020-04-01) 171 | 172 | - middleware.LogEntry: Write method on interface now passes the response header 173 | and an extra interface type useful for custom logger implementations.
174 | - middleware.WrapResponseWriter: minor fix 175 | - middleware.Recoverer: a bit prettier 176 | - History of changes: see https://github.com/go-chi/chi/compare/v4.0.4...v4.1.0 177 | 178 | ## v4.0.4 (2020-03-24) 179 | 180 | - middleware.Recoverer: new pretty stack trace printing (https://github.com/go-chi/chi/pull/496) 181 | - a few minor improvements and fixes 182 | - History of changes: see https://github.com/go-chi/chi/compare/v4.0.3...v4.0.4 183 | 184 | 185 | ## v4.0.3 (2020-01-09) 186 | 187 | - core: fix regexp routing to include default value when param is not matched 188 | - middleware: rewrite of middleware.Compress 189 | - middleware: suppress http.ErrAbortHandler in middleware.Recoverer 190 | - History of changes: see https://github.com/go-chi/chi/compare/v4.0.2...v4.0.3 191 | 192 | 193 | ## v4.0.2 (2019-02-26) 194 | 195 | - Minor fixes 196 | - History of changes: see https://github.com/go-chi/chi/compare/v4.0.1...v4.0.2 197 | 198 | 199 | ## v4.0.1 (2019-01-21) 200 | 201 | - Fixes issue with compress middleware: #382 #385 202 | - History of changes: see https://github.com/go-chi/chi/compare/v4.0.0...v4.0.1 203 | 204 | 205 | ## v4.0.0 (2019-01-10) 206 | 207 | - chi v4 requires Go 1.10.3+ (or Go 1.9.7+) - we have deprecated support for Go 1.7 and 1.8 208 | - router: respond with 404 on router with no routes (#362) 209 | - router: additional check to ensure wildcard is at the end of a url pattern (#333) 210 | - middleware: deprecate use of http.CloseNotifier (#347) 211 | - middleware: fix RedirectSlashes to include query params on redirect (#334) 212 | - History of changes: see https://github.com/go-chi/chi/compare/v3.3.4...v4.0.0 213 | 214 | 215 | ## v3.3.4 (2019-01-07) 216 | 217 | - Minor middleware improvements. No changes to core library/router. 
Moving v3 into its 218 | own branch as a version of chi for Go 1.7, 1.8, 1.9, 1.10, 1.11 219 | - History of changes: see https://github.com/go-chi/chi/compare/v3.3.3...v3.3.4 220 | 221 | 222 | ## v3.3.3 (2018-08-27) 223 | 224 | - Minor release 225 | - See https://github.com/go-chi/chi/compare/v3.3.2...v3.3.3 226 | 227 | 228 | ## v3.3.2 (2017-12-22) 229 | 230 | - Support for routing trailing slashes on mounted sub-routers (#281) 231 | - middleware: new `ContentCharset` to check matching charsets. Thank you 232 | @csucu for your community contribution! 233 | 234 | 235 | ## v3.3.1 (2017-11-20) 236 | 237 | - middleware: new `AllowContentType` handler for explicit whitelist of accepted request Content-Types 238 | - middleware: new `SetHeader` handler for shorthand middleware to set a response header key/value 239 | - Minor bug fixes 240 | 241 | 242 | ## v3.3.0 (2017-10-10) 243 | 244 | - New chi.RegisterMethod(method) to add support for custom HTTP methods, see _examples/custom-method for usage 245 | - Deprecated LINK and UNLINK methods from the default list, please use `chi.RegisterMethod("LINK")` and `chi.RegisterMethod("UNLINK")` in an `init()` function 246 | 247 | 248 | ## v3.2.1 (2017-08-31) 249 | 250 | - Add new `Match(rctx *Context, method, path string) bool` method to `Routes` interface 251 | and `Mux`.
Match searches the mux's routing tree for a handler that matches the method/path 252 | - Add new `RouteMethod` to `*Context` 253 | - Add new `Routes` pointer to `*Context` 254 | - Add new `middleware.GetHead` to route missing HEAD requests to GET handler 255 | - Updated benchmarks (see README) 256 | 257 | 258 | ## v3.1.5 (2017-08-02) 259 | 260 | - Set up golint and go vet for the project 261 | - As per golint, we've redefined `func ServerBaseContext(h http.Handler, baseCtx context.Context) http.Handler` 262 | to `func ServerBaseContext(baseCtx context.Context, h http.Handler) http.Handler` 263 | 264 | 265 | ## v3.1.0 (2017-07-10) 266 | 267 | - Fix a few minor issues after v3 release 268 | - Move `docgen` sub-pkg to https://github.com/go-chi/docgen 269 | - Move `render` sub-pkg to https://github.com/go-chi/render 270 | - Add new `URLFormat` handler to chi/middleware sub-pkg to make working with url mime 271 | suffixes easier, i.e. parsing `/articles/1.json` and `/articles/1.xml`. See comments in 272 | https://github.com/go-chi/chi/blob/master/middleware/url_format.go for example usage. 273 | 274 | 275 | ## v3.0.0 (2017-06-21) 276 | 277 | - Major update to chi library with many exciting updates, but also some *breaking changes* 278 | - URL parameter syntax changed from `/:id` to `/{id}` for even more flexible routing, such as 279 | `/articles/{month}-{day}-{year}-{slug}`, `/articles/{id}`, and `/articles/{id}.{ext}` on the 280 | same router 281 | - Support for regexp for routing patterns, in the form of `/{paramKey:regExp}` for example: 282 | `r.Get("/articles/{name:[a-z]+}", h)` and `chi.URLParam(r, "name")` 283 | - Add `Method` and `MethodFunc` to `chi.Router` to allow routing definitions such as 284 | `r.Method("GET", "/", h)` which provides a cleaner interface for custom handlers like 285 | in `_examples/custom-handler` 286 | - Deprecating `mux#FileServer` helper function.
Instead, we encourage users to create their 287 | own file handler using the stdlib, see `_examples/fileserver` for an example 288 | - Add support for LINK/UNLINK http methods via `r.Method()` and `r.MethodFunc()` 289 | - Moved the chi project to its own organization, to allow chi-related community packages to 290 | be easily discovered and supported, at: https://github.com/go-chi 291 | - *NOTE:* please update your import paths to `"github.com/go-chi/chi"` 292 | - *NOTE:* chi v2 is still available at https://github.com/go-chi/chi/tree/v2 293 | 294 | 295 | ## v2.1.0 (2017-03-30) 296 | 297 | - Minor improvements and update to the chi core library 298 | - Introduced a brand new `chi/render` sub-package to complete the story of building 299 | APIs to offer a pattern for managing well-defined request / response payloads. Please 300 | check out the updated `_examples/rest` example for how it works. 301 | - Added `MethodNotAllowed(h http.HandlerFunc)` to chi.Router interface 302 | 303 | 304 | ## v2.0.0 (2017-01-06) 305 | 306 | - After many months of v2 being in an RC state with many companies and users running it in 307 | production, and with the inclusion of some improvements to the middlewares, we are very pleased to 308 | announce v2.0.0 of chi. 309 | 310 | 311 | ## v2.0.0-rc1 (2016-07-26) 312 | 313 | - Huge update! chi v2 is a large refactor targeting Go 1.7+. As of Go 1.7, the popular 314 | community `"net/context"` package has been included in the standard library as `"context"` and 315 | utilized by `"net/http"` and `http.Request` to manage deadlines, cancelation signals and other 316 | request-scoped values. We're very excited about the new context addition and are proud to 317 | introduce chi v2, a minimal and powerful routing package for building large HTTP services, 318 | with zero external dependencies. Chi focuses on idiomatic design and encourages the use of 319 | stdlib HTTP handlers and middlewares.
320 | - chi v2 deprecates its `chi.Handler` interface and requires `http.Handler` or `http.HandlerFunc` 321 | - chi v2 stores URL routing parameters and patterns in the standard request context: `r.Context()` 322 | - chi v2 lower-level routing context is accessible by `chi.RouteContext(r.Context()) *chi.Context`, 323 | which provides direct access to URL routing parameters, the routing path and the matching 324 | routing patterns. 325 | - Users upgrading from chi v1 to v2, need to: 326 | 1. Update the old chi.Handler signature, `func(ctx context.Context, w http.ResponseWriter, r *http.Request)` to 327 | the standard http.Handler: `func(w http.ResponseWriter, r *http.Request)` 328 | 2. Use `chi.URLParam(r *http.Request, paramKey string) string` 329 | or `URLParamFromCtx(ctx context.Context, paramKey string) string` to access a url parameter value 330 | 331 | 332 | ## v1.0.0 (2016-07-01) 333 | 334 | - Released chi v1 stable https://github.com/go-chi/chi/tree/v1.0.0 for Go 1.6 and older. 335 | 336 | 337 | ## v0.9.0 (2016-03-31) 338 | 339 | - Reuse context objects via sync.Pool for zero-allocation routing [#33](https://github.com/go-chi/chi/pull/33) 340 | - BREAKING NOTE: due to subtle API changes, previously `chi.URLParams(ctx)["id"]` used to access url parameters 341 | has changed to: `chi.URLParam(ctx, "id")` 342 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | ## Prerequisites 4 | 5 | 1. [Install Go][go-install]. 6 | 2. Download the sources and switch the working directory: 7 | 8 | ```bash 9 | go get -u -d github.com/go-chi/chi 10 | cd $GOPATH/src/github.com/go-chi/chi 11 | ``` 12 | 13 | ## Submitting a Pull Request 14 | 15 | A typical workflow is: 16 | 17 | 1. [Fork the repository.][fork] 18 | 2. [Create a topic branch.][branch] 19 | 3. Add tests for your change. 20 | 4. Run `go test`. 
If your tests pass, return to step 3 (new tests should fail until the change is implemented). 21 | 5. Implement the change and ensure the tests from the previous step pass. 22 | 6. Run `goimports -w .` to ensure the new code conforms to Go formatting guidelines. 23 | 7. [Add, commit and push your changes.][git-help] 24 | 8. [Submit a pull request.][pull-req] 25 | 26 | [go-install]: https://golang.org/doc/install 27 | [fork]: https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/fork-a-repo 28 | [branch]: https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/about-branches 29 | [git-help]: https://docs.github.com/en 30 | [pull-req]: https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/about-pull-requests 31 | 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015-present Peter Kieltyka (https://github.com/pkieltyka), Google Inc. 2 | 3 | MIT License 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so, 10 | subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR 18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: all 2 | all: 3 | @echo "**********************************************************" 4 | @echo "** chi build tool **" 5 | @echo "**********************************************************" 6 | 7 | 8 | .PHONY: test 9 | test: 10 | go clean -testcache && $(MAKE) test-router && $(MAKE) test-middleware 11 | 12 | .PHONY: test-router 13 | test-router: 14 | go test -race -v . 15 | 16 | .PHONY: test-middleware 17 | test-middleware: 18 | go test -race -v ./middleware 19 | 20 | .PHONY: docs 21 | docs: 22 | npx docsify-cli serve ./docs 23 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | # Reporting Security Issues 2 | 3 | We appreciate your efforts to responsibly disclose your findings, and will make every effort to acknowledge your contributions. 4 | 5 | To report a security issue, please use the GitHub Security Advisory ["Report a Vulnerability"](https://github.com/go-chi/chi/security/advisories/new) tab. 
6 | -------------------------------------------------------------------------------- /_examples/README.md: -------------------------------------------------------------------------------- 1 | chi examples 2 | ============ 3 | 4 | * [custom-handler](https://github.com/go-chi/chi/blob/master/_examples/custom-handler/main.go) - Use a custom handler function signature 5 | * [custom-method](https://github.com/go-chi/chi/blob/master/_examples/custom-method/main.go) - Add a custom HTTP method 6 | * [fileserver](https://github.com/go-chi/chi/blob/master/_examples/fileserver/main.go) - Easily serve static files 7 | * [graceful](https://github.com/go-chi/chi/blob/master/_examples/graceful/main.go) - Graceful context signaling and server shutdown 8 | * [hello-world](https://github.com/go-chi/chi/blob/master/_examples/hello-world/main.go) - Hello World! 9 | * [limits](https://github.com/go-chi/chi/blob/master/_examples/limits/main.go) - Timeouts and Throttling 10 | * [logging](https://github.com/go-chi/chi/blob/master/_examples/logging/main.go) - Easy structured logging for any backend 11 | * [rest](https://github.com/go-chi/chi/blob/master/_examples/rest/main.go) - REST APIs made easy, productive and maintainable 12 | * [router-walk](https://github.com/go-chi/chi/blob/master/_examples/router-walk/main.go) - Print a router's routes to stdout 13 | * [todos-resource](https://github.com/go-chi/chi/blob/master/_examples/todos-resource/main.go) - Struct routers/handlers, an example of another code layout style 14 | * [versions](https://github.com/go-chi/chi/blob/master/_examples/versions/main.go) - Demo of `chi/render` subpkg 15 | 16 | 17 | ## Usage 18 | 19 | 1. `go get -v -d -u ./...` - fetch example deps 20 | 2. `cd <example>/`, e.g. `cd rest/` 21 | 3. `go run *.go` - note, example services run on port 3333 22 | 4. Open another terminal and use curl to send some requests to your example service, 23 | `curl -v http://localhost:3333/` 24 | 5.
Read /main.go source to learn how service works and read comments for usage 25 | -------------------------------------------------------------------------------- /_examples/chi.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /_examples/custom-handler/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "errors" 5 | "net/http" 6 | 7 | "github.com/go-chi/chi/v5" 8 | ) 9 | 10 | type Handler func(w http.ResponseWriter, r *http.Request) error // Handler is a handler func that may return an error; its ServeHTTP method lets it be registered as an http.Handler. 11 | 12 | func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 13 | if err := h(w, r); err != nil { 14 | // handle returned error here: this example replies with status 503 and body "bad" for any non-nil error. 15 | w.WriteHeader(503) 16 | w.Write([]byte("bad")) 17 | } 18 | } 19 | 20 | func main() { 21 | r := chi.NewRouter() 22 | r.Method("GET", "/", Handler(customHandler)) 23 | http.ListenAndServe(":3333", r) 24 | } 25 | 26 | func customHandler(w http.ResponseWriter, r *http.Request) error { // returns an error built from the "err" query parameter when it is set; otherwise writes "foo". 27 | q := r.URL.Query().Get("err") 28 | 29 | if q != "" { 30 | return errors.New(q) 31 | } 32 | 33 | w.Write([]byte("foo")) 34 | return nil 35 | } 36 | -------------------------------------------------------------------------------- /_examples/custom-method/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/go-chi/chi/v5" 7 | "github.com/go-chi/chi/v5/middleware" 8 | ) 9 | 10 | func init() { // register the non-standard methods in init so they exist before main builds the router 11 | chi.RegisterMethod("LINK") 12 | chi.RegisterMethod("UNLINK") 13 | chi.RegisterMethod("WOOHOO") 14 | } 15 | 16 | func main() { 17 | r := chi.NewRouter() 18 | r.Use(middleware.RequestID) 19 | r.Use(middleware.Logger) 20 | r.Get("/", func(w http.ResponseWriter, r *http.Request) { 21 | w.Write([]byte("hello world")) 22 | }) 23 | r.MethodFunc("LINK", "/link", func(w http.ResponseWriter, r
*http.Request) { 24 | w.Write([]byte("custom link method")) 25 | }) 26 | r.MethodFunc("WOOHOO", "/woo", func(w http.ResponseWriter, r *http.Request) { 27 | w.Write([]byte("custom woohoo method")) 28 | }) 29 | r.HandleFunc("/everything", func(w http.ResponseWriter, r *http.Request) { 30 | w.Write([]byte("capturing all standard http methods, as well as LINK, UNLINK and WOOHOO")) 31 | }) 32 | http.ListenAndServe(":3333", r) 33 | } 34 | -------------------------------------------------------------------------------- /_examples/fileserver/data/notes.txt: -------------------------------------------------------------------------------- 1 | Notessszzz 2 | -------------------------------------------------------------------------------- /_examples/fileserver/main.go: -------------------------------------------------------------------------------- 1 | // This example demonstrates how to serve static files from your filesystem. 2 | // 3 | // Boot the server: 4 | // 5 | // $ go run main.go 6 | // 7 | // Client requests: 8 | // 9 | // $ curl http://localhost:3333/files/ 10 | //	<pre>
11 | //	<a href="notes.txt">notes.txt</a>
12 | //	</pre>
13 | // 14 | // $ curl http://localhost:3333/files/notes.txt 15 | // Notessszzz 16 | package main 17 | 18 | import ( 19 | "net/http" 20 | "os" 21 | "path/filepath" 22 | "strings" 23 | 24 | "github.com/go-chi/chi/v5" 25 | "github.com/go-chi/chi/v5/middleware" 26 | ) 27 | 28 | func main() { 29 | r := chi.NewRouter() 30 | r.Use(middleware.Logger) 31 | 32 | // Index handler 33 | r.Get("/", func(w http.ResponseWriter, r *http.Request) { 34 | w.Write([]byte("hi")) 35 | }) 36 | 37 | // Create a route along /files that will serve contents from 38 | // the ./data/ folder. 39 | workDir, _ := os.Getwd() 40 | filesDir := http.Dir(filepath.Join(workDir, "data")) 41 | FileServer(r, "/files", filesDir) 42 | 43 | http.ListenAndServe(":3333", r) 44 | } 45 | 46 | // FileServer conveniently sets up an http.FileServer handler to serve 47 | // static files from an http.FileSystem. 48 | func FileServer(r chi.Router, path string, root http.FileSystem) { 49 | if strings.ContainsAny(path, "{}*") { 50 | panic("FileServer does not permit any URL parameters.") 51 | } 52 | 53 | if path != "/" && !strings.HasSuffix(path, "/") { // HasSuffix is safe for "", unlike path[len(path)-1] 54 | r.Get(path, http.RedirectHandler(path+"/", 301).ServeHTTP) 55 | path += "/" 56 | } 57 | path += "*" 58 | 59 | r.Get(path, func(w http.ResponseWriter, r *http.Request) { 60 | rctx := chi.RouteContext(r.Context()) 61 | pathPrefix := strings.TrimSuffix(rctx.RoutePattern(), "/*") 62 | fs := http.StripPrefix(pathPrefix, http.FileServer(root)) 63 | fs.ServeHTTP(w, r) 64 | }) 65 | } 66 | -------------------------------------------------------------------------------- /_examples/graceful/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "log" 7 | "net/http" 8 | "os" 9 | "os/signal" 10 | "syscall" 11 | "time" 12 | 13 | "github.com/go-chi/chi/v5" 14 | "github.com/go-chi/chi/v5/middleware" 15 | ) 16 | 17 | func main() { 18 | // The HTTP Server 19 | server := &http.Server{Addr:
"0.0.0.0:3333", Handler: service()} 20 | 21 | // Server run context 22 | serverCtx, serverStopCtx := context.WithCancel(context.Background()) 23 | 24 | // Listen for syscall signals for process to interrupt/quit 25 | sig := make(chan os.Signal, 1) 26 | signal.Notify(sig, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) 27 | go func() { 28 | <-sig 29 | 30 | // Shutdown signal with grace period of 30 seconds 31 | shutdownCtx, cancelShutdown := context.WithTimeout(serverCtx, 30*time.Second) 32 | defer cancelShutdown() // always release the timeout context's resources (go vet: lostcancel) 33 | go func() { 34 | <-shutdownCtx.Done() 35 | if shutdownCtx.Err() == context.DeadlineExceeded { 36 | log.Fatal("graceful shutdown timed out.. forcing exit.") 37 | } 38 | }() 39 | 40 | // Trigger graceful shutdown 41 | err := server.Shutdown(shutdownCtx) 42 | if err != nil { 43 | log.Fatal(err) 44 | } 45 | serverStopCtx() 46 | }() 47 | 48 | // Run the server 49 | err := server.ListenAndServe() 50 | if err != nil && err != http.ErrServerClosed { 51 | log.Fatal(err) 52 | } 53 | 54 | // Wait for server context to be stopped 55 | <-serverCtx.Done() 56 | } 57 | 58 | func service() http.Handler { 59 | r := chi.NewRouter() 60 | 61 | r.Use(middleware.RequestID) 62 | r.Use(middleware.Logger) 63 | 64 | r.Get("/", func(w http.ResponseWriter, r *http.Request) { 65 | w.Write([]byte("sup")) 66 | }) 67 | 68 | r.Get("/slow", func(w http.ResponseWriter, r *http.Request) { 69 | // Simulates some hard work. 70 | // 71 | // We want this handler to complete successfully during a shutdown signal, 72 | // so consider the work here as some background routine to fetch a long-running 73 | // search query to find as many results as possible, but, instead we cut it short 74 | // and respond with what we have so far. How a shutdown is handled is entirely 75 | // up to the developer, as some code blocks are preemptible, and others are not.
76 | time.Sleep(5 * time.Second) 77 | 78 | fmt.Fprintln(w, "all done.") 79 | }) 80 | 81 | return r 82 | } 83 | -------------------------------------------------------------------------------- /_examples/hello-world/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/go-chi/chi/v5" 7 | "github.com/go-chi/chi/v5/middleware" 8 | ) 9 | 10 | func main() { 11 | r := chi.NewRouter() 12 | r.Use(middleware.RequestID) 13 | r.Use(middleware.Logger) 14 | r.Use(middleware.Recoverer) 15 | 16 | r.Get("/", func(w http.ResponseWriter, r *http.Request) { 17 | w.Write([]byte("hello world")) 18 | }) 19 | 20 | http.ListenAndServe(":3333", r) 21 | } 22 | -------------------------------------------------------------------------------- /_examples/limits/main.go: -------------------------------------------------------------------------------- 1 | // This example demonstrates the use of Timeout, and Throttle middlewares. 2 | // 3 | // Timeout: cancel a request if processing takes longer than 2.5 seconds, 4 | // server will respond with a http.StatusGatewayTimeout. 5 | // 6 | // Throttle: limit the number of in-flight requests along a particular 7 | // routing path and backlog the others.
8 | package main 9 | 10 | import ( 11 | "context" 12 | "fmt" 13 | "math/rand" 14 | "net/http" 15 | "time" 16 | 17 | "github.com/go-chi/chi/v5" 18 | "github.com/go-chi/chi/v5/middleware" 19 | ) 20 | 21 | func main() { 22 | r := chi.NewRouter() 23 | 24 | r.Use(middleware.RequestID) 25 | r.Use(middleware.Logger) 26 | r.Use(middleware.Recoverer) 27 | 28 | r.Get("/", func(w http.ResponseWriter, r *http.Request) { 29 | w.Write([]byte("root.")) 30 | }) 31 | 32 | r.Get("/ping", func(w http.ResponseWriter, r *http.Request) { 33 | w.Write([]byte("pong")) 34 | }) 35 | 36 | r.Get("/panic", func(w http.ResponseWriter, r *http.Request) { 37 | panic("test") 38 | }) 39 | 40 | // Slow handlers/operations. 41 | r.Group(func(r chi.Router) { 42 | // Stop processing after 2.5 seconds. 43 | r.Use(middleware.Timeout(2500 * time.Millisecond)) 44 | 45 | r.Get("/slow", func(w http.ResponseWriter, r *http.Request) { 46 | rand.Seed(time.Now().Unix()) 47 | 48 | // Processing will take 1-5 seconds. 49 | processTime := time.Duration(rand.Intn(4)+1) * time.Second 50 | 51 | select { 52 | case <-r.Context().Done(): 53 | return 54 | 55 | case <-time.After(processTime): 56 | // The above channel simulates some hard work. 57 | } 58 | 59 | w.Write([]byte(fmt.Sprintf("Processed in %v seconds\n", processTime))) 60 | }) 61 | }) 62 | 63 | // Throttle very expensive handlers/operations. 64 | r.Group(func(r chi.Router) { 65 | // Stop processing after 30 seconds. 66 | r.Use(middleware.Timeout(30 * time.Second)) 67 | 68 | // Only one request will be processed at a time. 
69 | r.Use(middleware.Throttle(1)) 70 | 71 | r.Get("/throttled", func(w http.ResponseWriter, r *http.Request) { 72 | select { 73 | case <-r.Context().Done(): 74 | switch r.Context().Err() { 75 | case context.DeadlineExceeded: 76 | w.WriteHeader(504) 77 | w.Write([]byte("Processing too slow\n")) 78 | default: 79 | w.Write([]byte("Canceled\n")) 80 | } 81 | return 82 | 83 | case <-time.After(5 * time.Second): 84 | // The above channel simulates some hard work. 85 | } 86 | 87 | w.Write([]byte("Processed\n")) 88 | }) 89 | }) 90 | 91 | http.ListenAndServe(":3333", r) 92 | } 93 | -------------------------------------------------------------------------------- /_examples/logging/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | // Please see https://github.com/go-chi/httplog for a complete package 4 | // and example for writing a structured logger on chi built on 5 | // the Go 1.21+ "log/slog" package. 6 | 7 | func main() { 8 | // See https://github.com/go-chi/httplog/blob/master/_example/main.go 9 | } 10 | -------------------------------------------------------------------------------- /_examples/rest/go.mod: -------------------------------------------------------------------------------- 1 | module rest-example 2 | 3 | go 1.16 4 | 5 | require ( 6 | github.com/go-chi/chi/v5 v5.0.1 7 | github.com/go-chi/docgen v1.2.0 8 | github.com/go-chi/render v1.0.1 9 | ) 10 | -------------------------------------------------------------------------------- /_examples/rest/go.sum: -------------------------------------------------------------------------------- 1 | github.com/go-chi/chi/v5 v5.0.1 h1:ALxjCrTf1aflOlkhMnCUP86MubbWFrzB3gkRPReLpTo= 2 | github.com/go-chi/chi/v5 v5.0.1/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= 3 | github.com/go-chi/docgen v1.2.0 h1:da0Nq2PKU9W9pSOTUfVrKI1vIgTGpauo9cfh4Iwivek= 4 | github.com/go-chi/docgen v1.2.0/go.mod h1:G9W0G551cs2BFMSn/cnGwX+JBHEloAgo17MBhyrnhPI= 5 | 
github.com/go-chi/render v1.0.1 h1:4/5tis2cKaNdnv9zFLfXzcquC9HbeZgCnxGnKrltBS8= 6 | github.com/go-chi/render v1.0.1/go.mod h1:pq4Rr7HbnsdaeHagklXub+p6Wd16Af5l9koip1OvJns= 7 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 8 | gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= 9 | -------------------------------------------------------------------------------- /_examples/rest/routes.json: -------------------------------------------------------------------------------- 1 | { 2 | "router": { 3 | "middlewares": [ 4 | { 5 | "pkg": "github.com/go-chi/chi/v5/middleware", 6 | "func": "RequestID", 7 | "comment": "RequestID is a middleware that injects a request ID into the context of each\nrequest. A request ID is a string of the form \"host.example.com/random-0001\",\nwhere \"random\" is a base62 random string that uniquely identifies this go\nprocess, and where the last number is an atomically incremented request\ncounter.\n", 8 | "file": "github.com/go-chi/chi/middleware/request_id.go", 9 | "line": 63 10 | }, 11 | { 12 | "pkg": "github.com/go-chi/chi/v5/middleware", 13 | "func": "Logger", 14 | "comment": "Logger is a middleware that logs the start and end of each request, along\nwith some useful data about what was requested, what the response status was,\nand how long it took to return. When standard output is a TTY, Logger will\nprint in color, otherwise it will print in black and white. Logger prints a\nrequest ID if one is provided.\n\nAlternatively, look at https://github.com/pressly/lg and the `lg.RequestLogger`\nmiddleware pkg.\n", 15 | "file": "github.com/go-chi/chi/middleware/logger.go", 16 | "line": 26 17 | }, 18 | { 19 | "pkg": "github.com/go-chi/chi/v5/middleware", 20 | "func": "Recoverer", 21 | "comment": "Recoverer is a middleware that recovers from panics, logs the panic (and a\nbacktrace), and returns a HTTP 500 (Internal Server Error) status if\npossible. 
Recoverer prints a request ID if one is provided.\n\nAlternatively, look at https://github.com/pressly/lg middleware pkgs.\n", 22 | "file": "github.com/go-chi/chi/middleware/recoverer.go", 23 | "line": 18 24 | }, 25 | { 26 | "pkg": "github.com/go-chi/chi/v5/middleware", 27 | "func": "URLFormat", 28 | "comment": "URLFormat is a middleware that parses the url extension from a request path and stores it\non the context as a string under the key `middleware.URLFormatCtxKey`. The middleware will\ntrim the suffix from the routing path and continue routing.\n\nRouters should not include a url parameter for the suffix when using this middleware.\n\nSample usage.. for url paths: `/articles/1`, `/articles/1.json` and `/articles/1.xml`\n\n func routes() http.Handler {\n r := chi.NewRouter()\n r.Use(middleware.URLFormat)\n\n r.Get(\"/articles/{id}\", ListArticles)\n\n return r\n }\n\n func ListArticles(w http.ResponseWriter, r *http.Request) {\n\t urlFormat, _ := r.Context().Value(middleware.URLFormatCtxKey).(string)\n\n\t switch urlFormat {\n\t case \"json\":\n\t \trender.JSON(w, r, articles)\n\t case \"xml:\"\n\t \trender.XML(w, r, articles)\n\t default:\n\t \trender.JSON(w, r, articles)\n\t }\n}\n", 29 | "file": "github.com/go-chi/chi/middleware/url_format.go", 30 | "line": 45 31 | }, 32 | { 33 | "pkg": "github.com/go-chi/render", 34 | "func": "SetContentType.func1", 35 | "comment": "", 36 | "file": "github.com/go-chi/render/content_type.go", 37 | "line": 49, 38 | "anonymous": true 39 | } 40 | ], 41 | "routes": { 42 | "/": { 43 | "handlers": { 44 | "GET": { 45 | "middlewares": [], 46 | "method": "GET", 47 | "pkg": "", 48 | "func": "main.main.func1", 49 | "comment": "", 50 | "file": "github.com/go-chi/chi/_examples/rest/main.go", 51 | "line": 69, 52 | "anonymous": true 53 | } 54 | } 55 | }, 56 | "/admin/*": { 57 | "router": { 58 | "middlewares": [ 59 | { 60 | "pkg": "", 61 | "func": "main.AdminOnly", 62 | "comment": "AdminOnly middleware restricts access to just 
administrators.\n", 63 | "file": "github.com/go-chi/chi/_examples/rest/main.go", 64 | "line": 238 65 | } 66 | ], 67 | "routes": { 68 | "/": { 69 | "handlers": { 70 | "GET": { 71 | "middlewares": [], 72 | "method": "GET", 73 | "pkg": "", 74 | "func": "main.adminRouter.func1", 75 | "comment": "", 76 | "file": "github.com/go-chi/chi/_examples/rest/main.go", 77 | "line": 225, 78 | "anonymous": true 79 | } 80 | } 81 | }, 82 | "/accounts": { 83 | "handlers": { 84 | "GET": { 85 | "middlewares": [], 86 | "method": "GET", 87 | "pkg": "", 88 | "func": "main.adminRouter.func2", 89 | "comment": "", 90 | "file": "github.com/go-chi/chi/_examples/rest/main.go", 91 | "line": 228, 92 | "anonymous": true 93 | } 94 | } 95 | }, 96 | "/users/{userId}": { 97 | "handlers": { 98 | "GET": { 99 | "middlewares": [], 100 | "method": "GET", 101 | "pkg": "", 102 | "func": "main.adminRouter.func3", 103 | "comment": "", 104 | "file": "github.com/go-chi/chi/_examples/rest/main.go", 105 | "line": 231, 106 | "anonymous": true 107 | } 108 | } 109 | } 110 | } 111 | } 112 | }, 113 | "/articles/*": { 114 | "router": { 115 | "middlewares": [], 116 | "routes": { 117 | "/": { 118 | "handlers": { 119 | "GET": { 120 | "middlewares": [ 121 | { 122 | "pkg": "", 123 | "func": "main.paginate", 124 | "comment": "paginate is a stub, but very possible to implement middleware logic\nto handle the request params for handling a paginated request.\n", 125 | "file": "github.com/go-chi/chi/_examples/rest/main.go", 126 | "line": 251 127 | } 128 | ], 129 | "method": "GET", 130 | "pkg": "", 131 | "func": "main.ListArticles", 132 | "comment": "", 133 | "file": "github.com/go-chi/chi/_examples/rest/main.go", 134 | "line": 117 135 | }, 136 | "POST": { 137 | "middlewares": [], 138 | "method": "POST", 139 | "pkg": "", 140 | "func": "main.CreateArticle", 141 | "comment": "CreateArticle persists the posted Article and returns it\nback to the client as an acknowledgement.\n", 142 | "file": 
"github.com/go-chi/chi/_examples/rest/main.go", 143 | "line": 158 144 | } 145 | } 146 | }, 147 | "/search": { 148 | "handlers": { 149 | "GET": { 150 | "middlewares": [], 151 | "method": "GET", 152 | "pkg": "", 153 | "func": "main.SearchArticles", 154 | "comment": "SearchArticles searches the Articles data for a matching article.\nIt's just a stub, but you get the idea.\n", 155 | "file": "github.com/go-chi/chi/_examples/rest/main.go", 156 | "line": 152 157 | } 158 | } 159 | }, 160 | "/{articleID}/*": { 161 | "router": { 162 | "middlewares": [ 163 | { 164 | "pkg": "", 165 | "func": "main.ArticleCtx", 166 | "comment": "ArticleCtx middleware is used to load an Article object from\nthe URL parameters passed through as the request. In case\nthe Article could not be found, we stop here and return a 404.\n", 167 | "file": "github.com/go-chi/chi/_examples/rest/main.go", 168 | "line": 127 169 | } 170 | ], 171 | "routes": { 172 | "/": { 173 | "handlers": { 174 | "DELETE": { 175 | "middlewares": [], 176 | "method": "DELETE", 177 | "pkg": "", 178 | "func": "main.DeleteArticle", 179 | "comment": "DeleteArticle removes an existing Article from our persistent store.\n", 180 | "file": "github.com/go-chi/chi/_examples/rest/main.go", 181 | "line": 204 182 | }, 183 | "GET": { 184 | "middlewares": [], 185 | "method": "GET", 186 | "pkg": "", 187 | "func": "main.GetArticle", 188 | "comment": "GetArticle returns the specific Article. You'll notice it just\nfetches the Article right off the context, as its understood that\nif we made it this far, the Article must be on the context. 
In case\nits not due to a bug, then it will panic, and our Recoverer will save us.\n", 189 | "file": "github.com/go-chi/chi/_examples/rest/main.go", 190 | "line": 176 191 | }, 192 | "PUT": { 193 | "middlewares": [], 194 | "method": "PUT", 195 | "pkg": "", 196 | "func": "main.UpdateArticle", 197 | "comment": "UpdateArticle updates an existing Article in our persistent store.\n", 198 | "file": "github.com/go-chi/chi/_examples/rest/main.go", 199 | "line": 189 200 | } 201 | } 202 | } 203 | } 204 | } 205 | }, 206 | "/{articleSlug:[a-z-]+}": { 207 | "handlers": { 208 | "GET": { 209 | "middlewares": [ 210 | { 211 | "pkg": "", 212 | "func": "main.ArticleCtx", 213 | "comment": "ArticleCtx middleware is used to load an Article object from\nthe URL parameters passed through as the request. In case\nthe Article could not be found, we stop here and return a 404.\n", 214 | "file": "github.com/go-chi/chi/_examples/rest/main.go", 215 | "line": 127 216 | } 217 | ], 218 | "method": "GET", 219 | "pkg": "", 220 | "func": "main.GetArticle", 221 | "comment": "GetArticle returns the specific Article. You'll notice it just\nfetches the Article right off the context, as its understood that\nif we made it this far, the Article must be on the context. 
In case\nits not due to a bug, then it will panic, and our Recoverer will save us.\n", 222 | "file": "github.com/go-chi/chi/_examples/rest/main.go", 223 | "line": 176 224 | } 225 | } 226 | } 227 | } 228 | } 229 | }, 230 | "/panic": { 231 | "handlers": { 232 | "GET": { 233 | "middlewares": [], 234 | "method": "GET", 235 | "pkg": "", 236 | "func": "main.main.func3", 237 | "comment": "", 238 | "file": "github.com/go-chi/chi/_examples/rest/main.go", 239 | "line": 77, 240 | "anonymous": true 241 | } 242 | } 243 | }, 244 | "/ping": { 245 | "handlers": { 246 | "GET": { 247 | "middlewares": [], 248 | "method": "GET", 249 | "pkg": "", 250 | "func": "main.main.func2", 251 | "comment": "", 252 | "file": "github.com/go-chi/chi/_examples/rest/main.go", 253 | "line": 73, 254 | "anonymous": true 255 | } 256 | } 257 | } 258 | } 259 | } 260 | } 261 | -------------------------------------------------------------------------------- /_examples/rest/routes.md: -------------------------------------------------------------------------------- 1 | # github.com/go-chi/chi 2 | 3 | Welcome to the chi/_examples/rest generated docs. 4 | 5 | ## Routes 6 | 7 |
8 | `/` 9 | 10 | - [RequestID](/middleware/request_id.go#L63) 11 | - [Logger](/middleware/logger.go#L26) 12 | - [Recoverer](/middleware/recoverer.go#L18) 13 | - [URLFormat](/middleware/url_format.go#L45) 14 | - [SetContentType.func1](https://github.com/go-chi/render/content_type.go#L49) 15 | - **/** 16 | - _GET_ 17 | - [main.main.func1](/_examples/rest/main.go#L69) 18 | 19 |
20 |
21 | `/admin/*` 22 | 23 | - [RequestID](/middleware/request_id.go#L63) 24 | - [Logger](/middleware/logger.go#L26) 25 | - [Recoverer](/middleware/recoverer.go#L18) 26 | - [URLFormat](/middleware/url_format.go#L45) 27 | - [SetContentType.func1](https://github.com/go-chi/render/content_type.go#L49) 28 | - **/admin/*** 29 | - [main.AdminOnly](/_examples/rest/main.go#L238) 30 | - **/** 31 | - _GET_ 32 | - [main.adminRouter.func1](/_examples/rest/main.go#L225) 33 | 34 |
35 |
36 | `/admin/*/accounts` 37 | 38 | - [RequestID](/middleware/request_id.go#L63) 39 | - [Logger](/middleware/logger.go#L26) 40 | - [Recoverer](/middleware/recoverer.go#L18) 41 | - [URLFormat](/middleware/url_format.go#L45) 42 | - [SetContentType.func1](https://github.com/go-chi/render/content_type.go#L49) 43 | - **/admin/*** 44 | - [main.AdminOnly](/_examples/rest/main.go#L238) 45 | - **/accounts** 46 | - _GET_ 47 | - [main.adminRouter.func2](/_examples/rest/main.go#L228) 48 | 49 |
50 |
51 | `/admin/*/users/{userId}` 52 | 53 | - [RequestID](/middleware/request_id.go#L63) 54 | - [Logger](/middleware/logger.go#L26) 55 | - [Recoverer](/middleware/recoverer.go#L18) 56 | - [URLFormat](/middleware/url_format.go#L45) 57 | - [SetContentType.func1](https://github.com/go-chi/render/content_type.go#L49) 58 | - **/admin/*** 59 | - [main.AdminOnly](/_examples/rest/main.go#L238) 60 | - **/users/{userId}** 61 | - _GET_ 62 | - [main.adminRouter.func3](/_examples/rest/main.go#L231) 63 | 64 |
65 |
66 | `/articles/*` 67 | 68 | - [RequestID](/middleware/request_id.go#L63) 69 | - [Logger](/middleware/logger.go#L26) 70 | - [Recoverer](/middleware/recoverer.go#L18) 71 | - [URLFormat](/middleware/url_format.go#L45) 72 | - [SetContentType.func1](https://github.com/go-chi/render/content_type.go#L49) 73 | - **/articles/*** 74 | - **/** 75 | - _GET_ 76 | - [main.paginate](/_examples/rest/main.go#L251) 77 | - [main.ListArticles](/_examples/rest/main.go#L117) 78 | - _POST_ 79 | - [main.CreateArticle](/_examples/rest/main.go#L158) 80 | 81 |
82 |
83 | `/articles/*/search` 84 | 85 | - [RequestID](/middleware/request_id.go#L63) 86 | - [Logger](/middleware/logger.go#L26) 87 | - [Recoverer](/middleware/recoverer.go#L18) 88 | - [URLFormat](/middleware/url_format.go#L45) 89 | - [SetContentType.func1](https://github.com/go-chi/render/content_type.go#L49) 90 | - **/articles/*** 91 | - **/search** 92 | - _GET_ 93 | - [main.SearchArticles](/_examples/rest/main.go#L152) 94 | 95 |
96 |
97 | `/articles/*/{articleID}/*` 98 | 99 | - [RequestID](/middleware/request_id.go#L63) 100 | - [Logger](/middleware/logger.go#L26) 101 | - [Recoverer](/middleware/recoverer.go#L18) 102 | - [URLFormat](/middleware/url_format.go#L45) 103 | - [SetContentType.func1](https://github.com/go-chi/render/content_type.go#L49) 104 | - **/articles/*** 105 | - **/{articleID}/*** 106 | - [main.ArticleCtx](/_examples/rest/main.go#L127) 107 | - **/** 108 | - _DELETE_ 109 | - [main.DeleteArticle](/_examples/rest/main.go#L204) 110 | - _GET_ 111 | - [main.GetArticle](/_examples/rest/main.go#L176) 112 | - _PUT_ 113 | - [main.UpdateArticle](/_examples/rest/main.go#L189) 114 | 115 |
116 |
117 | `/articles/*/{articleSlug:[a-z-]+}` 118 | 119 | - [RequestID](/middleware/request_id.go#L63) 120 | - [Logger](/middleware/logger.go#L26) 121 | - [Recoverer](/middleware/recoverer.go#L18) 122 | - [URLFormat](/middleware/url_format.go#L45) 123 | - [SetContentType.func1](https://github.com/go-chi/render/content_type.go#L49) 124 | - **/articles/*** 125 | - **/{articleSlug:[a-z-]+}** 126 | - _GET_ 127 | - [main.ArticleCtx](/_examples/rest/main.go#L127) 128 | - [main.GetArticle](/_examples/rest/main.go#L176) 129 | 130 |
131 |
132 | `/panic` 133 | 134 | - [RequestID](/middleware/request_id.go#L63) 135 | - [Logger](/middleware/logger.go#L26) 136 | - [Recoverer](/middleware/recoverer.go#L18) 137 | - [URLFormat](/middleware/url_format.go#L45) 138 | - [SetContentType.func1](https://github.com/go-chi/render/content_type.go#L49) 139 | - **/panic** 140 | - _GET_ 141 | - [main.main.func3](/_examples/rest/main.go#L77) 142 | 143 |
144 |
145 | `/ping` 146 | 147 | - [RequestID](/middleware/request_id.go#L63) 148 | - [Logger](/middleware/logger.go#L26) 149 | - [Recoverer](/middleware/recoverer.go#L18) 150 | - [URLFormat](/middleware/url_format.go#L45) 151 | - [SetContentType.func1](https://github.com/go-chi/render/content_type.go#L49) 152 | - **/ping** 153 | - _GET_ 154 | - [main.main.func2](/_examples/rest/main.go#L73) 155 | 156 |
157 | 158 | Total # of routes: 10 159 | 160 | -------------------------------------------------------------------------------- /_examples/router-walk/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "strings" 7 | 8 | "github.com/go-chi/chi/v5" 9 | ) 10 | 11 | func main() { 12 | r := chi.NewRouter() 13 | r.Get("/", func(w http.ResponseWriter, r *http.Request) { 14 | w.Write([]byte("root.")) 15 | }) 16 | 17 | r.Route("/road", func(r chi.Router) { 18 | r.Get("/left", func(w http.ResponseWriter, r *http.Request) { 19 | w.Write([]byte("left road")) 20 | }) 21 | r.Post("/right", func(w http.ResponseWriter, r *http.Request) { 22 | w.Write([]byte("right road")) 23 | }) 24 | }) 25 | 26 | r.Put("/ping", Ping) 27 | 28 | walkFunc := func(method string, route string, handler http.Handler, middlewares ...func(http.Handler) http.Handler) error { 29 | route = strings.Replace(route, "/*/", "/", -1) 30 | fmt.Printf("%s %s\n", method, route) 31 | return nil 32 | } 33 | 34 | if err := chi.Walk(r, walkFunc); err != nil { 35 | fmt.Printf("Logging err: %s\n", err.Error()) 36 | } 37 | } 38 | 39 | // Ping returns pong 40 | func Ping(w http.ResponseWriter, r *http.Request) { 41 | w.Write([]byte("pong")) 42 | } 43 | -------------------------------------------------------------------------------- /_examples/todos-resource/main.go: -------------------------------------------------------------------------------- 1 | // This example demonstrates a project structure that defines a subrouter and its 2 | // handlers on a struct, and mounting them as subrouters to a parent router. 3 | // See also _examples/rest for an in-depth example of a REST service, and apply 4 | // those same patterns to this structure. 
5 | package main 6 | 7 | import ( 8 | "net/http" 9 | 10 | "github.com/go-chi/chi/v5" 11 | "github.com/go-chi/chi/v5/middleware" 12 | ) 13 | 14 | func main() { 15 | r := chi.NewRouter() 16 | 17 | r.Use(middleware.RequestID) 18 | r.Use(middleware.RealIP) 19 | r.Use(middleware.Logger) 20 | r.Use(middleware.Recoverer) 21 | 22 | r.Get("/", func(w http.ResponseWriter, r *http.Request) { 23 | w.Write([]byte(".")) 24 | }) 25 | 26 | r.Mount("/users", usersResource{}.Routes()) 27 | r.Mount("/todos", todosResource{}.Routes()) 28 | 29 | http.ListenAndServe(":3333", r) 30 | } 31 | -------------------------------------------------------------------------------- /_examples/todos-resource/todos.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/go-chi/chi/v5" 7 | ) 8 | 9 | type todosResource struct{} 10 | 11 | // Routes creates a REST router for the todos resource 12 | func (rs todosResource) Routes() chi.Router { 13 | r := chi.NewRouter() 14 | // r.Use() // some middleware.. 
15 | 16 | r.Get("/", rs.List) // GET /todos - read a list of todos 17 | r.Post("/", rs.Create) // POST /todos - create a new todo and persist it 18 | r.Put("/", rs.Delete) 19 | 20 | r.Route("/{id}", func(r chi.Router) { 21 | // r.Use(rs.TodoCtx) // lets have a todos map, and lets actually load/manipulate 22 | r.Get("/", rs.Get) // GET /todos/{id} - read a single todo by :id 23 | r.Put("/", rs.Update) // PUT /todos/{id} - update a single todo by :id 24 | r.Delete("/", rs.Delete) // DELETE /todos/{id} - delete a single todo by :id 25 | r.Get("/sync", rs.Sync) 26 | }) 27 | 28 | return r 29 | } 30 | 31 | func (rs todosResource) List(w http.ResponseWriter, r *http.Request) { 32 | w.Write([]byte("todos list of stuff..")) 33 | } 34 | 35 | func (rs todosResource) Create(w http.ResponseWriter, r *http.Request) { 36 | w.Write([]byte("todos create")) 37 | } 38 | 39 | func (rs todosResource) Get(w http.ResponseWriter, r *http.Request) { 40 | w.Write([]byte("todo get")) 41 | } 42 | 43 | func (rs todosResource) Update(w http.ResponseWriter, r *http.Request) { 44 | w.Write([]byte("todo update")) 45 | } 46 | 47 | func (rs todosResource) Delete(w http.ResponseWriter, r *http.Request) { 48 | w.Write([]byte("todo delete")) 49 | } 50 | 51 | func (rs todosResource) Sync(w http.ResponseWriter, r *http.Request) { 52 | w.Write([]byte("todo sync")) 53 | } 54 | -------------------------------------------------------------------------------- /_examples/todos-resource/users.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/go-chi/chi/v5" 7 | ) 8 | 9 | type usersResource struct{} 10 | 11 | // Routes creates a REST router for the todos resource 12 | func (rs usersResource) Routes() chi.Router { 13 | r := chi.NewRouter() 14 | // r.Use() // some middleware.. 
15 | 16 | r.Get("/", rs.List) // GET /users - read a list of users 17 | r.Post("/", rs.Create) // POST /users - create a new user and persist it 18 | r.Put("/", rs.Delete) 19 | 20 | r.Route("/{id}", func(r chi.Router) { 21 | // r.Use(rs.TodoCtx) // lets have a users map, and lets actually load/manipulate 22 | r.Get("/", rs.Get) // GET /users/{id} - read a single user by :id 23 | r.Put("/", rs.Update) // PUT /users/{id} - update a single user by :id 24 | r.Delete("/", rs.Delete) // DELETE /users/{id} - delete a single user by :id 25 | }) 26 | 27 | return r 28 | } 29 | 30 | func (rs usersResource) List(w http.ResponseWriter, r *http.Request) { 31 | w.Write([]byte("users list of stuff..")) 32 | } 33 | 34 | func (rs usersResource) Create(w http.ResponseWriter, r *http.Request) { 35 | w.Write([]byte("users create")) 36 | } 37 | 38 | func (rs usersResource) Get(w http.ResponseWriter, r *http.Request) { 39 | w.Write([]byte("user get")) 40 | } 41 | 42 | func (rs usersResource) Update(w http.ResponseWriter, r *http.Request) { 43 | w.Write([]byte("user update")) 44 | } 45 | 46 | func (rs usersResource) Delete(w http.ResponseWriter, r *http.Request) { 47 | w.Write([]byte("user delete")) 48 | } 49 | -------------------------------------------------------------------------------- /_examples/versions/data/article.go: -------------------------------------------------------------------------------- 1 | package data 2 | 3 | // Article is runtime object, that's not meant to be sent via REST. 
4 | type Article struct { 5 | ID int `db:"id" json:"id" xml:"id"` 6 | Title string `db:"title" json:"title" xml:"title"` 7 | Data []string `db:"data,stringarray" json:"data" xml:"data"` 8 | CustomDataForAuthUsers string `db:"custom_data" json:"-" xml:"-"` 9 | } 10 | -------------------------------------------------------------------------------- /_examples/versions/data/errors.go: -------------------------------------------------------------------------------- 1 | package data 2 | 3 | import ( 4 | "errors" 5 | "net/http" 6 | 7 | "github.com/go-chi/render" 8 | ) 9 | 10 | var ( 11 | ErrUnauthorized = errors.New("Unauthorized") 12 | ErrForbidden = errors.New("Forbidden") 13 | ErrNotFound = errors.New("Resource not found") 14 | ) 15 | 16 | func PresentError(r *http.Request, err error) (*http.Request, interface{}) { 17 | switch err { 18 | case ErrUnauthorized: 19 | render.Status(r, 401) 20 | case ErrForbidden: 21 | render.Status(r, 403) 22 | case ErrNotFound: 23 | render.Status(r, 404) 24 | default: 25 | render.Status(r, 500) 26 | } 27 | return r, map[string]string{"error": err.Error()} 28 | } 29 | -------------------------------------------------------------------------------- /_examples/versions/go.mod: -------------------------------------------------------------------------------- 1 | module versions 2 | 3 | go 1.16 4 | 5 | require ( 6 | github.com/go-chi/chi/v5 v5.1.0 7 | github.com/go-chi/render v1.0.3 8 | ) 9 | 10 | require github.com/ajg/form v1.5.1 // indirect 11 | -------------------------------------------------------------------------------- /_examples/versions/go.sum: -------------------------------------------------------------------------------- 1 | github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU= 2 | github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY= 3 | github.com/go-chi/chi/v5 v5.1.0 h1:acVI1TYaD+hhedDJ3r54HyA6sExp3HfXq7QWEEY/xMw= 4 | github.com/go-chi/chi/v5 v5.1.0/go.mod 
h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= 5 | github.com/go-chi/render v1.0.3 h1:AsXqd2a1/INaIfUSKq3G5uA8weYx20FOsM7uSoCyyt4= 6 | github.com/go-chi/render v1.0.3/go.mod h1:/gr3hVkmYR0YlEy3LxCuVRFzEu9Ruok+gFqbIofjao0= 7 | -------------------------------------------------------------------------------- /_examples/versions/main.go: -------------------------------------------------------------------------------- 1 | // This example demonstrates the use of the render subpackage, with 2 | // a quick concept for how to support multiple api versions. 3 | package main 4 | 5 | import ( 6 | "context" 7 | "errors" 8 | "fmt" 9 | "math/rand" 10 | "net/http" 11 | "time" 12 | 13 | "github.com/go-chi/chi/v5" 14 | "github.com/go-chi/chi/v5/_examples/versions/data" 15 | v1 "github.com/go-chi/chi/v5/_examples/versions/presenter/v1" 16 | v2 "github.com/go-chi/chi/v5/_examples/versions/presenter/v2" 17 | v3 "github.com/go-chi/chi/v5/_examples/versions/presenter/v3" 18 | "github.com/go-chi/chi/v5/middleware" 19 | "github.com/go-chi/render" 20 | ) 21 | 22 | func main() { 23 | r := chi.NewRouter() 24 | 25 | r.Use(middleware.RequestID) 26 | r.Use(middleware.Logger) 27 | r.Use(middleware.Recoverer) 28 | 29 | // API version 3. 30 | r.Route("/v3", func(r chi.Router) { 31 | r.Use(apiVersionCtx("v3")) 32 | r.Mount("/articles", articleRouter()) 33 | }) 34 | 35 | // API version 2. 36 | r.Route("/v2", func(r chi.Router) { 37 | r.Use(apiVersionCtx("v2")) 38 | r.Mount("/articles", articleRouter()) 39 | }) 40 | 41 | // API version 1. 42 | r.Route("/v1", func(r chi.Router) { 43 | r.Use(randomErrorMiddleware) // Simulate random error, ie. version 1 is buggy. 
44 | r.Use(apiVersionCtx("v1")) 45 | r.Mount("/articles", articleRouter()) 46 | }) 47 | 48 | http.ListenAndServe(":3333", r) 49 | } 50 | 51 | func apiVersionCtx(version string) func(next http.Handler) http.Handler { 52 | return func(next http.Handler) http.Handler { 53 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 54 | r = r.WithContext(context.WithValue(r.Context(), "api.version", version)) 55 | next.ServeHTTP(w, r) 56 | }) 57 | } 58 | } 59 | 60 | func articleRouter() http.Handler { 61 | r := chi.NewRouter() 62 | r.Get("/", listArticles) 63 | r.Route("/{articleID}", func(r chi.Router) { 64 | r.Get("/", getArticle) 65 | // r.Put("/", updateArticle) 66 | // r.Delete("/", deleteArticle) 67 | }) 68 | return r 69 | } 70 | 71 | func listArticles(w http.ResponseWriter, r *http.Request) { 72 | articles := make(chan render.Renderer, 5) 73 | 74 | // Load data asynchronously into the channel (simulate slow storage): 75 | go func() { 76 | for i := 1; i <= 10; i++ { 77 | article := &data.Article{ 78 | ID: i, 79 | Title: fmt.Sprintf("Article #%v", i), 80 | Data: []string{"one", "two", "three", "four"}, 81 | CustomDataForAuthUsers: "secret data for auth'd users only", 82 | } 83 | 84 | apiVersion := r.Context().Value("api.version").(string) 85 | switch apiVersion { 86 | case "v1": 87 | articles <- v1.NewArticleResponse(article) 88 | case "v2": 89 | articles <- v2.NewArticleResponse(article) 90 | default: 91 | articles <- v3.NewArticleResponse(article) 92 | } 93 | 94 | time.Sleep(100 * time.Millisecond) 95 | } 96 | close(articles) 97 | }() 98 | 99 | // Start streaming data from the channel. 100 | render.Respond(w, r, articles) 101 | } 102 | 103 | func getArticle(w http.ResponseWriter, r *http.Request) { 104 | // Load article. 
105 | if chi.URLParam(r, "articleID") != "1" { 106 | render.Respond(w, r, data.ErrNotFound) 107 | return 108 | } 109 | article := &data.Article{ 110 | ID: 1, 111 | Title: "Article #1", 112 | Data: []string{"one", "two", "three", "four"}, 113 | CustomDataForAuthUsers: "secret data for auth'd users only", 114 | } 115 | 116 | // Simulate some context values: 117 | // 1. ?auth=true simulates authenticated session/user. 118 | // 2. ?error=true simulates random error. 119 | if r.URL.Query().Get("auth") != "" { 120 | r = r.WithContext(context.WithValue(r.Context(), "auth", true)) 121 | } 122 | if r.URL.Query().Get("error") != "" { 123 | render.Respond(w, r, errors.New("error")) 124 | return 125 | } 126 | 127 | var payload render.Renderer 128 | 129 | apiVersion := r.Context().Value("api.version").(string) 130 | switch apiVersion { 131 | case "v1": 132 | payload = v1.NewArticleResponse(article) 133 | case "v2": 134 | payload = v2.NewArticleResponse(article) 135 | default: 136 | payload = v3.NewArticleResponse(article) 137 | } 138 | 139 | render.Render(w, r, payload) 140 | } 141 | 142 | func randomErrorMiddleware(next http.Handler) http.Handler { 143 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 144 | rand.Seed(time.Now().Unix()) 145 | 146 | // One in three chance of random error. 147 | if rand.Int31n(3) == 0 { 148 | errors := []error{data.ErrUnauthorized, data.ErrForbidden, data.ErrNotFound} 149 | render.Respond(w, r, errors[rand.Intn(len(errors))]) 150 | return 151 | } 152 | next.ServeHTTP(w, r) 153 | }) 154 | } 155 | -------------------------------------------------------------------------------- /_examples/versions/presenter/v1/article.go: -------------------------------------------------------------------------------- 1 | package v1 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/go-chi/chi/v5/_examples/versions/data" 7 | ) 8 | 9 | // Article presented in API version 1. 
10 | type Article struct { 11 | *data.Article 12 | 13 | Data map[string]bool `json:"data" xml:"data"` 14 | } 15 | 16 | func (a *Article) Render(w http.ResponseWriter, r *http.Request) error { 17 | return nil 18 | } 19 | 20 | func NewArticleResponse(article *data.Article) *Article { 21 | return &Article{Article: article} 22 | } 23 | -------------------------------------------------------------------------------- /_examples/versions/presenter/v2/article.go: -------------------------------------------------------------------------------- 1 | package v2 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | 7 | "github.com/go-chi/chi/v5/_examples/versions/data" 8 | ) 9 | 10 | // Article presented in API version 2. 11 | type Article struct { 12 | // *v3.Article `json:",inline" xml:",inline"` 13 | 14 | *data.Article 15 | 16 | // Additional fields. 17 | SelfURL string `json:"self_url" xml:"self_url"` 18 | 19 | // Omitted fields. 20 | URL interface{} `json:"url,omitempty" xml:"url,omitempty"` 21 | } 22 | 23 | func (a *Article) Render(w http.ResponseWriter, r *http.Request) error { 24 | a.SelfURL = fmt.Sprintf("http://localhost:3333/v2?id=%v", a.ID) 25 | return nil 26 | } 27 | 28 | func NewArticleResponse(article *data.Article) *Article { 29 | return &Article{Article: article} 30 | } 31 | -------------------------------------------------------------------------------- /_examples/versions/presenter/v3/article.go: -------------------------------------------------------------------------------- 1 | package v3 2 | 3 | import ( 4 | "fmt" 5 | "math/rand" 6 | "net/http" 7 | 8 | "github.com/go-chi/chi/v5/_examples/versions/data" 9 | ) 10 | 11 | // Article presented in API version 2. 12 | type Article struct { 13 | *data.Article `json:",inline" xml:",inline"` 14 | 15 | // Additional fields. 16 | URL string `json:"url" xml:"url"` 17 | ViewsCount int64 `json:"views_count" xml:"views_count"` 18 | APIVersion string `json:"api_version" xml:"api_version"` 19 | 20 | // Omitted fields. 
21 | // Show custom_data explicitly for auth'd users only. 22 | CustomDataForAuthUsers interface{} `json:"custom_data,omitempty" xml:"custom_data,omitempty"` 23 | } 24 | 25 | func (a *Article) Render(w http.ResponseWriter, r *http.Request) error { 26 | a.ViewsCount = rand.Int63n(100000) 27 | a.URL = fmt.Sprintf("http://localhost:3333/v3/?id=%v", a.ID) 28 | 29 | // Only show to auth'd user. 30 | if _, ok := r.Context().Value("auth").(bool); ok { 31 | a.CustomDataForAuthUsers = a.Article.CustomDataForAuthUsers 32 | } 33 | 34 | return nil 35 | } 36 | 37 | func NewArticleResponse(article *data.Article) *Article { 38 | return &Article{Article: article} 39 | } 40 | -------------------------------------------------------------------------------- /chain.go: -------------------------------------------------------------------------------- 1 | package chi 2 | 3 | import "net/http" 4 | 5 | // Chain returns a Middlewares type from a slice of middleware handlers. 6 | func Chain(middlewares ...func(http.Handler) http.Handler) Middlewares { 7 | return Middlewares(middlewares) 8 | } 9 | 10 | // Handler builds and returns a http.Handler from the chain of middlewares, 11 | // with `h http.Handler` as the final handler. 12 | func (mws Middlewares) Handler(h http.Handler) http.Handler { 13 | return &ChainHandler{h, chain(mws, h), mws} 14 | } 15 | 16 | // HandlerFunc builds and returns a http.Handler from the chain of middlewares, 17 | // with `h http.Handler` as the final handler. 18 | func (mws Middlewares) HandlerFunc(h http.HandlerFunc) http.Handler { 19 | return &ChainHandler{h, chain(mws, h), mws} 20 | } 21 | 22 | // ChainHandler is a http.Handler with support for handler composition and 23 | // execution. 
24 | type ChainHandler struct { 25 | Endpoint http.Handler 26 | chain http.Handler 27 | Middlewares Middlewares 28 | } 29 | 30 | func (c *ChainHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 31 | c.chain.ServeHTTP(w, r) 32 | } 33 | 34 | // chain builds a http.Handler composed of an inline middleware stack and endpoint 35 | // handler in the order they are passed. 36 | func chain(middlewares []func(http.Handler) http.Handler, endpoint http.Handler) http.Handler { 37 | // Return ahead of time if there aren't any middlewares for the chain 38 | if len(middlewares) == 0 { 39 | return endpoint 40 | } 41 | 42 | // Wrap the end handler with the middleware chain 43 | h := middlewares[len(middlewares)-1](endpoint) 44 | for i := len(middlewares) - 2; i >= 0; i-- { 45 | h = middlewares[i](h) 46 | } 47 | 48 | return h 49 | } 50 | -------------------------------------------------------------------------------- /chi.go: -------------------------------------------------------------------------------- 1 | // Package chi is a small, idiomatic and composable router for building HTTP services. 2 | // 3 | // chi requires Go 1.14 or newer. 4 | // 5 | // Example: 6 | // 7 | // package main 8 | // 9 | // import ( 10 | // "net/http" 11 | // 12 | // "github.com/go-chi/chi/v5" 13 | // "github.com/go-chi/chi/v5/middleware" 14 | // ) 15 | // 16 | // func main() { 17 | // r := chi.NewRouter() 18 | // r.Use(middleware.Logger) 19 | // r.Use(middleware.Recoverer) 20 | // 21 | // r.Get("/", func(w http.ResponseWriter, r *http.Request) { 22 | // w.Write([]byte("root.")) 23 | // }) 24 | // 25 | // http.ListenAndServe(":3333", r) 26 | // } 27 | // 28 | // See github.com/go-chi/chi/_examples/ for more in-depth examples. 29 | // 30 | // URL patterns allow for easy matching of path components in HTTP 31 | // requests. The matching components can then be accessed using 32 | // chi.URLParam(). All patterns must begin with a slash. 
33 | // 34 | // A simple named placeholder {name} matches any sequence of characters 35 | // up to the next / or the end of the URL. Trailing slashes on paths must 36 | // be handled explicitly. 37 | // 38 | // A placeholder with a name followed by a colon allows a regular 39 | // expression match, for example {number:\\d+}. The regular expression 40 | // syntax is Go's normal regexp RE2 syntax, except that regular expressions 41 | // including { or } are not supported, and / will never be 42 | // matched. An anonymous regexp pattern is allowed, using an empty string 43 | // before the colon in the placeholder, such as {:\\d+} 44 | // 45 | // The special placeholder of asterisk matches the rest of the requested 46 | // URL. Any trailing characters in the pattern are ignored. This is the only 47 | // placeholder which will match / characters. 48 | // 49 | // Examples: 50 | // 51 | // "/user/{name}" matches "/user/jsmith" but not "/user/jsmith/info" or "/user/jsmith/" 52 | // "/user/{name}/info" matches "/user/jsmith/info" 53 | // "/page/*" matches "/page/intro/latest" 54 | // "/page/{other}/latest" also matches "/page/intro/latest" 55 | // "/date/{yyyy:\\d\\d\\d\\d}/{mm:\\d\\d}/{dd:\\d\\d}" matches "/date/2017/04/01" 56 | package chi 57 | 58 | import "net/http" 59 | 60 | // NewRouter returns a new Mux object that implements the Router interface. 61 | func NewRouter() *Mux { 62 | return NewMux() 63 | } 64 | 65 | // Router consisting of the core routing methods used by chi's Mux, 66 | // using only the standard net/http. 67 | type Router interface { 68 | http.Handler 69 | Routes 70 | 71 | // Use appends one or more middlewares onto the Router stack. 72 | Use(middlewares ...func(http.Handler) http.Handler) 73 | 74 | // With adds inline middlewares for an endpoint handler. 
75 | With(middlewares ...func(http.Handler) http.Handler) Router 76 | 77 | // Group adds a new inline-Router along the current routing 78 | // path, with a fresh middleware stack for the inline-Router. 79 | Group(fn func(r Router)) Router 80 | 81 | // Route mounts a sub-Router along a `pattern`` string. 82 | Route(pattern string, fn func(r Router)) Router 83 | 84 | // Mount attaches another http.Handler along ./pattern/* 85 | Mount(pattern string, h http.Handler) 86 | 87 | // Handle and HandleFunc adds routes for `pattern` that matches 88 | // all HTTP methods. 89 | Handle(pattern string, h http.Handler) 90 | HandleFunc(pattern string, h http.HandlerFunc) 91 | 92 | // Method and MethodFunc adds routes for `pattern` that matches 93 | // the `method` HTTP method. 94 | Method(method, pattern string, h http.Handler) 95 | MethodFunc(method, pattern string, h http.HandlerFunc) 96 | 97 | // HTTP-method routing along `pattern` 98 | Connect(pattern string, h http.HandlerFunc) 99 | Delete(pattern string, h http.HandlerFunc) 100 | Get(pattern string, h http.HandlerFunc) 101 | Head(pattern string, h http.HandlerFunc) 102 | Options(pattern string, h http.HandlerFunc) 103 | Patch(pattern string, h http.HandlerFunc) 104 | Post(pattern string, h http.HandlerFunc) 105 | Put(pattern string, h http.HandlerFunc) 106 | Trace(pattern string, h http.HandlerFunc) 107 | 108 | // NotFound defines a handler to respond whenever a route could 109 | // not be found. 110 | NotFound(h http.HandlerFunc) 111 | 112 | // MethodNotAllowed defines a handler to respond whenever a method is 113 | // not allowed. 114 | MethodNotAllowed(h http.HandlerFunc) 115 | } 116 | 117 | // Routes interface adds two methods for router traversal, which is also 118 | // used by the `docgen` subpackage to generation documentation for Routers. 119 | type Routes interface { 120 | // Routes returns the routing tree in an easily traversable structure. 
121 | Routes() []Route 122 | 123 | // Middlewares returns the list of middlewares in use by the router. 124 | Middlewares() Middlewares 125 | 126 | // Match searches the routing tree for a handler that matches 127 | // the method/path - similar to routing a http request, but without 128 | // executing the handler thereafter. 129 | Match(rctx *Context, method, path string) bool 130 | 131 | // Find searches the routing tree for the pattern that matches 132 | // the method/path. 133 | Find(rctx *Context, method, path string) string 134 | } 135 | 136 | // Middlewares type is a slice of standard middleware handlers with methods 137 | // to compose middleware chains and http.Handler's. 138 | type Middlewares []func(http.Handler) http.Handler 139 | -------------------------------------------------------------------------------- /context.go: -------------------------------------------------------------------------------- 1 | package chi 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | "strings" 7 | ) 8 | 9 | // URLParam returns the url parameter from a http.Request object. 10 | func URLParam(r *http.Request, key string) string { 11 | if rctx := RouteContext(r.Context()); rctx != nil { 12 | return rctx.URLParam(key) 13 | } 14 | return "" 15 | } 16 | 17 | // URLParamFromCtx returns the url parameter from a http.Request Context. 18 | func URLParamFromCtx(ctx context.Context, key string) string { 19 | if rctx := RouteContext(ctx); rctx != nil { 20 | return rctx.URLParam(key) 21 | } 22 | return "" 23 | } 24 | 25 | // RouteContext returns chi's routing Context object from a 26 | // http.Request Context. 27 | func RouteContext(ctx context.Context) *Context { 28 | val, _ := ctx.Value(RouteCtxKey).(*Context) 29 | return val 30 | } 31 | 32 | // NewRouteContext returns a new routing Context object. 33 | func NewRouteContext() *Context { 34 | return &Context{} 35 | } 36 | 37 | var ( 38 | // RouteCtxKey is the context.Context key to store the request context. 
39 | RouteCtxKey = &contextKey{"RouteContext"} 40 | ) 41 | 42 | // Context is the default routing context set on the root node of a 43 | // request context to track route patterns, URL parameters and 44 | // an optional routing path. 45 | type Context struct { 46 | Routes Routes 47 | 48 | // parentCtx is the parent of this one, for using Context as a 49 | // context.Context directly. This is an optimization that saves 50 | // 1 allocation. 51 | parentCtx context.Context 52 | 53 | // Routing path/method override used during the route search. 54 | // See Mux#routeHTTP method. 55 | RoutePath string 56 | RouteMethod string 57 | 58 | // URLParams are the stack of routeParams captured during the 59 | // routing lifecycle across a stack of sub-routers. 60 | URLParams RouteParams 61 | 62 | // Route parameters matched for the current sub-router. It is 63 | // intentionally unexported so it can't be tampered. 64 | routeParams RouteParams 65 | 66 | // The endpoint routing pattern that matched the request URI path 67 | // or `RoutePath` of the current sub-router. This value will update 68 | // during the lifecycle of a request passing through a stack of 69 | // sub-routers. 70 | routePattern string 71 | 72 | // Routing pattern stack throughout the lifecycle of the request, 73 | // across all connected routers. It is a record of all matching 74 | // patterns across a stack of sub-routers. 75 | RoutePatterns []string 76 | 77 | methodsAllowed []methodTyp // allowed methods in case of a 405 78 | methodNotAllowed bool 79 | } 80 | 81 | // Reset a routing context to its initial state. 
82 | func (x *Context) Reset() { 83 | x.Routes = nil 84 | x.RoutePath = "" 85 | x.RouteMethod = "" 86 | x.RoutePatterns = x.RoutePatterns[:0] 87 | x.URLParams.Keys = x.URLParams.Keys[:0] 88 | x.URLParams.Values = x.URLParams.Values[:0] 89 | 90 | x.routePattern = "" 91 | x.routeParams.Keys = x.routeParams.Keys[:0] 92 | x.routeParams.Values = x.routeParams.Values[:0] 93 | x.methodNotAllowed = false 94 | x.methodsAllowed = x.methodsAllowed[:0] 95 | x.parentCtx = nil 96 | } 97 | 98 | // URLParam returns the corresponding URL parameter value from the request 99 | // routing context. 100 | func (x *Context) URLParam(key string) string { 101 | for k := len(x.URLParams.Keys) - 1; k >= 0; k-- { 102 | if x.URLParams.Keys[k] == key { 103 | return x.URLParams.Values[k] 104 | } 105 | } 106 | return "" 107 | } 108 | 109 | // RoutePattern builds the routing pattern string for the particular 110 | // request, at the particular point during routing. This means, the value 111 | // will change throughout the execution of a request in a router. That is 112 | // why it's advised to only use this value after calling the next handler. 113 | // 114 | // For example, 115 | // 116 | // func Instrument(next http.Handler) http.Handler { 117 | // return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 118 | // next.ServeHTTP(w, r) 119 | // routePattern := chi.RouteContext(r.Context()).RoutePattern() 120 | // measure(w, r, routePattern) 121 | // }) 122 | // } 123 | func (x *Context) RoutePattern() string { 124 | if x == nil { 125 | return "" 126 | } 127 | routePattern := strings.Join(x.RoutePatterns, "") 128 | routePattern = replaceWildcards(routePattern) 129 | if routePattern != "/" { 130 | routePattern = strings.TrimSuffix(routePattern, "//") 131 | routePattern = strings.TrimSuffix(routePattern, "/") 132 | } 133 | return routePattern 134 | } 135 | 136 | // replaceWildcards takes a route pattern and recursively replaces all 137 | // occurrences of "/*/" to "/". 
138 | func replaceWildcards(p string) string { 139 | if strings.Contains(p, "/*/") { 140 | return replaceWildcards(strings.Replace(p, "/*/", "/", -1)) 141 | } 142 | return p 143 | } 144 | 145 | // RouteParams is a structure to track URL routing parameters efficiently. 146 | type RouteParams struct { 147 | Keys, Values []string 148 | } 149 | 150 | // Add will append a URL parameter to the end of the route param 151 | func (s *RouteParams) Add(key, value string) { 152 | s.Keys = append(s.Keys, key) 153 | s.Values = append(s.Values, value) 154 | } 155 | 156 | // contextKey is a value for use with context.WithValue. It's used as 157 | // a pointer so it fits in an interface{} without allocation. This technique 158 | // for defining context keys was copied from Go 1.7's new use of context in net/http. 159 | type contextKey struct { 160 | name string 161 | } 162 | 163 | func (k *contextKey) String() string { 164 | return "chi context value " + k.name 165 | } 166 | -------------------------------------------------------------------------------- /context_test.go: -------------------------------------------------------------------------------- 1 | package chi 2 | 3 | import "testing" 4 | 5 | // TestRoutePattern tests correct in-the-middle wildcard removals. 6 | // If user organizes a router like this: 7 | // 8 | // (router.go) 9 | // 10 | // r.Route("/v1", func(r chi.Router) { 11 | // r.Mount("/resources", resourcesController{}.Router()) 12 | // } 13 | // 14 | // (resources_controller.go) 15 | // 16 | // r.Route("/", func(r chi.Router) { 17 | // r.Get("/{resource_id}", getResource()) 18 | // // other routes... 
19 | // } 20 | // 21 | // This test checks how the route pattern is calculated 22 | // "/v1/resources/{resource_id}" (right) 23 | // "/v1/resources/*/{resource_id}" (wrong) 24 | func TestRoutePattern(t *testing.T) { 25 | routePatterns := []string{ 26 | "/v1/*", 27 | "/resources/*", 28 | "/{resource_id}", 29 | } 30 | 31 | x := &Context{ 32 | RoutePatterns: routePatterns, 33 | } 34 | 35 | if p := x.RoutePattern(); p != "/v1/resources/{resource_id}" { 36 | t.Fatal("unexpected route pattern: " + p) 37 | } 38 | 39 | x.RoutePatterns = []string{ 40 | "/v1/*", 41 | "/resources/*", 42 | // Additional wildcard, depending on the router structure of the user 43 | "/*", 44 | "/{resource_id}", 45 | } 46 | 47 | // Correctly removes in-the-middle wildcards instead of "/v1/resources/*/{resource_id}" 48 | if p := x.RoutePattern(); p != "/v1/resources/{resource_id}" { 49 | t.Fatal("unexpected route pattern: " + p) 50 | } 51 | 52 | x.RoutePatterns = []string{ 53 | "/v1/*", 54 | "/resources/*", 55 | // Even with many wildcards 56 | "/*", 57 | "/*", 58 | "/*", 59 | "/{resource_id}/*", // Keeping trailing wildcard 60 | } 61 | 62 | // Correctly removes in-the-middle wildcards instead of "/v1/resources/*/*/{resource_id}/*" 63 | if p := x.RoutePattern(); p != "/v1/resources/{resource_id}/*" { 64 | t.Fatal("unexpected route pattern: " + p) 65 | } 66 | 67 | x.RoutePatterns = []string{ 68 | "/v1/*", 69 | "/resources/*", 70 | // And respects asterisks as part of the paths 71 | "/*special_path/*", 72 | "/with_asterisks*/*", 73 | "/{resource_id}", 74 | } 75 | 76 | // Correctly removes in-the-middle wildcards instead of "/v1/resourcesspecial_path/with_asterisks{resource_id}" 77 | if p := x.RoutePattern(); p != "/v1/resources/*special_path/with_asterisks*/{resource_id}" { 78 | t.Fatal("unexpected route pattern: " + p) 79 | } 80 | 81 | // Testing for the root route pattern 82 | x.RoutePatterns = []string{"/"} 83 | // It should just return "/" as the pattern 84 | if p := x.RoutePattern(); p != "/" { 
85 | t.Fatal("unexpected route pattern for root: " + p) 86 | } 87 | 88 | // Testing empty route pattern for nil context 89 | var nilContext *Context 90 | if p := nilContext.RoutePattern(); p != "" { 91 | t.Fatalf("unexpected non-empty route pattern for nil context: %q", p) 92 | } 93 | } 94 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/go-chi/chi/v5 2 | 3 | // Chi supports the four most recent major versions of Go. 4 | // See https://github.com/go-chi/chi/issues/963. 5 | go 1.20 6 | -------------------------------------------------------------------------------- /middleware/basic_auth.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "crypto/subtle" 5 | "fmt" 6 | "net/http" 7 | ) 8 | 9 | // BasicAuth implements a simple middleware handler for adding basic http auth to a route. 
10 | func BasicAuth(realm string, creds map[string]string) func(next http.Handler) http.Handler { 11 | return func(next http.Handler) http.Handler { 12 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 13 | user, pass, ok := r.BasicAuth() 14 | if !ok { 15 | basicAuthFailed(w, realm) 16 | return 17 | } 18 | 19 | credPass, credUserOk := creds[user] 20 | if !credUserOk || subtle.ConstantTimeCompare([]byte(pass), []byte(credPass)) != 1 { 21 | basicAuthFailed(w, realm) 22 | return 23 | } 24 | 25 | next.ServeHTTP(w, r) 26 | }) 27 | } 28 | } 29 | 30 | func basicAuthFailed(w http.ResponseWriter, realm string) { 31 | w.Header().Add("WWW-Authenticate", fmt.Sprintf(`Basic realm="%s"`, realm)) 32 | w.WriteHeader(http.StatusUnauthorized) 33 | } 34 | -------------------------------------------------------------------------------- /middleware/clean_path.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "net/http" 5 | "path" 6 | 7 | "github.com/go-chi/chi/v5" 8 | ) 9 | 10 | // CleanPath middleware will clean out double slash mistakes from a user's request path. 
11 | // For example, if a user requests /users//1 or //users////1 will both be treated as: /users/1 12 | func CleanPath(next http.Handler) http.Handler { 13 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 14 | rctx := chi.RouteContext(r.Context()) 15 | 16 | routePath := rctx.RoutePath 17 | if routePath == "" { 18 | if r.URL.RawPath != "" { 19 | routePath = r.URL.RawPath 20 | } else { 21 | routePath = r.URL.Path 22 | } 23 | rctx.RoutePath = path.Clean(routePath) 24 | } 25 | 26 | next.ServeHTTP(w, r) 27 | }) 28 | } 29 | -------------------------------------------------------------------------------- /middleware/compress.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "bufio" 5 | "compress/flate" 6 | "compress/gzip" 7 | "errors" 8 | "fmt" 9 | "io" 10 | "net" 11 | "net/http" 12 | "strings" 13 | "sync" 14 | ) 15 | 16 | var defaultCompressibleContentTypes = []string{ 17 | "text/html", 18 | "text/css", 19 | "text/plain", 20 | "text/javascript", 21 | "application/javascript", 22 | "application/x-javascript", 23 | "application/json", 24 | "application/atom+xml", 25 | "application/rss+xml", 26 | "image/svg+xml", 27 | } 28 | 29 | // Compress is a middleware that compresses response 30 | // body of a given content types to a data format based 31 | // on Accept-Encoding request header. It uses a given 32 | // compression level. 33 | // 34 | // NOTE: make sure to set the Content-Type header on your response 35 | // otherwise this middleware will not compress the response body. For ex, in 36 | // your handler you should set w.Header().Set("Content-Type", http.DetectContentType(yourBody)) 37 | // or set it manually. 38 | // 39 | // Passing a compression level of 5 is sensible value 40 | func Compress(level int, types ...string) func(next http.Handler) http.Handler { 41 | compressor := NewCompressor(level, types...) 
42 | return compressor.Handler 43 | } 44 | 45 | // Compressor represents a set of encoding configurations. 46 | type Compressor struct { 47 | // The mapping of encoder names to encoder functions. 48 | encoders map[string]EncoderFunc 49 | // The mapping of pooled encoders to pools. 50 | pooledEncoders map[string]*sync.Pool 51 | // The set of content types allowed to be compressed. 52 | allowedTypes map[string]struct{} 53 | allowedWildcards map[string]struct{} 54 | // The list of encoders in order of decreasing precedence. 55 | encodingPrecedence []string 56 | level int // The compression level. 57 | } 58 | 59 | // NewCompressor creates a new Compressor that will handle encoding responses. 60 | // 61 | // The level should be one of the ones defined in the flate package. 62 | // The types are the content types that are allowed to be compressed. 63 | func NewCompressor(level int, types ...string) *Compressor { 64 | // If types are provided, set those as the allowed types. If none are 65 | // provided, use the default list. 66 | allowedTypes := make(map[string]struct{}) 67 | allowedWildcards := make(map[string]struct{}) 68 | if len(types) > 0 { 69 | for _, t := range types { 70 | if strings.Contains(strings.TrimSuffix(t, "/*"), "*") { 71 | panic(fmt.Sprintf("middleware/compress: Unsupported content-type wildcard pattern '%s'. Only '/*' supported", t)) 72 | } 73 | if strings.HasSuffix(t, "/*") { 74 | allowedWildcards[strings.TrimSuffix(t, "/*")] = struct{}{} 75 | } else { 76 | allowedTypes[t] = struct{}{} 77 | } 78 | } 79 | } else { 80 | for _, t := range defaultCompressibleContentTypes { 81 | allowedTypes[t] = struct{}{} 82 | } 83 | } 84 | 85 | c := &Compressor{ 86 | level: level, 87 | encoders: make(map[string]EncoderFunc), 88 | pooledEncoders: make(map[string]*sync.Pool), 89 | allowedTypes: allowedTypes, 90 | allowedWildcards: allowedWildcards, 91 | } 92 | 93 | // Set the default encoders. 
The precedence order uses the reverse 94 | // ordering that the encoders were added. This means adding new encoders 95 | // will move them to the front of the order. 96 | // 97 | // TODO: 98 | // lzma: Opera. 99 | // sdch: Chrome, Android. Gzip output + dictionary header. 100 | // br: Brotli, see https://github.com/go-chi/chi/pull/326 101 | 102 | // HTTP 1.1 "deflate" (RFC 2616) stands for DEFLATE data (RFC 1951) 103 | // wrapped with zlib (RFC 1950). The zlib wrapper uses Adler-32 104 | // checksum compared to CRC-32 used in "gzip" and thus is faster. 105 | // 106 | // But.. some old browsers (MSIE, Safari 5.1) incorrectly expect 107 | // raw DEFLATE data only, without the mentioned zlib wrapper. 108 | // Because of this major confusion, most modern browsers try it 109 | // both ways, first looking for zlib headers. 110 | // Quote by Mark Adler: http://stackoverflow.com/a/9186091/385548 111 | // 112 | // The list of browsers having problems is quite big, see: 113 | // http://zoompf.com/blog/2012/02/lose-the-wait-http-compression 114 | // https://web.archive.org/web/20120321182910/http://www.vervestudios.co/projects/compression-tests/results 115 | // 116 | // That's why we prefer gzip over deflate. It's just more reliable 117 | // and not significantly slower than deflate. 118 | c.SetEncoder("deflate", encoderDeflate) 119 | 120 | // TODO: Exception for old MSIE browsers that can't handle non-HTML? 121 | // https://zoompf.com/blog/2012/02/lose-the-wait-http-compression 122 | c.SetEncoder("gzip", encoderGzip) 123 | 124 | // NOTE: Not implemented, intentionally: 125 | // case "compress": // LZW. Deprecated. 126 | // case "bzip2": // Too slow on-the-fly. 127 | // case "zopfli": // Too slow on-the-fly. 128 | // case "xz": // Too slow on-the-fly. 129 | return c 130 | } 131 | 132 | // SetEncoder can be used to set the implementation of a compression algorithm. 133 | // 134 | // The encoding should be a standardised identifier. 
See: 135 | // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Encoding 136 | // 137 | // For example, add the Brotli algorithm: 138 | // 139 | // import brotli_enc "gopkg.in/kothar/brotli-go.v0/enc" 140 | // 141 | // compressor := middleware.NewCompressor(5, "text/html") 142 | // compressor.SetEncoder("br", func(w io.Writer, level int) io.Writer { 143 | // params := brotli_enc.NewBrotliParams() 144 | // params.SetQuality(level) 145 | // return brotli_enc.NewBrotliWriter(params, w) 146 | // }) 147 | func (c *Compressor) SetEncoder(encoding string, fn EncoderFunc) { 148 | encoding = strings.ToLower(encoding) 149 | if encoding == "" { 150 | panic("the encoding can not be empty") 151 | } 152 | if fn == nil { 153 | panic("attempted to set a nil encoder function") 154 | } 155 | 156 | // If we are adding a new encoder that is already registered, we have to 157 | // clear that one out first. 158 | delete(c.pooledEncoders, encoding) 159 | delete(c.encoders, encoding) 160 | 161 | // If the encoder supports Resetting (IoReseterWriter), then it can be pooled. 162 | encoder := fn(io.Discard, c.level) 163 | if _, ok := encoder.(ioResetterWriter); ok { 164 | pool := &sync.Pool{ 165 | New: func() interface{} { 166 | return fn(io.Discard, c.level) 167 | }, 168 | } 169 | c.pooledEncoders[encoding] = pool 170 | } 171 | // If the encoder is not in the pooledEncoders, add it to the normal encoders. 172 | if _, ok := c.pooledEncoders[encoding]; !ok { 173 | c.encoders[encoding] = fn 174 | } 175 | 176 | for i, v := range c.encodingPrecedence { 177 | if v == encoding { 178 | c.encodingPrecedence = append(c.encodingPrecedence[:i], c.encodingPrecedence[i+1:]...) 179 | } 180 | } 181 | 182 | c.encodingPrecedence = append([]string{encoding}, c.encodingPrecedence...) 183 | } 184 | 185 | // Handler returns a new middleware that will compress the response based on the 186 | // current Compressor. 
187 | func (c *Compressor) Handler(next http.Handler) http.Handler {
188 | 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
189 | 		encoder, encoding, cleanup := c.selectEncoder(r.Header, w)
190 |
191 | 		cw := &compressResponseWriter{
192 | 			ResponseWriter:   w,
193 | 			w:                w,
194 | 			contentTypes:     c.allowedTypes,
195 | 			contentWildcards: c.allowedWildcards,
196 | 			encoding:         encoding,
197 | 			compressible:     false, // determined in post-handler
198 | 		}
199 | 		if encoder != nil {
200 | 			cw.w = encoder
201 | 		}
202 | 		// Deferred calls run in reverse order: cw.Close runs first (closing
203 | 		// the encoder if it was used), then cleanup re-adds a pooled encoder
204 | 		// to its pool.
205 | 		defer cleanup()
206 | 		defer cw.Close()
207 |
208 | 		next.ServeHTTP(cw, r)
209 | 	})
210 | }
211 |
212 | // selectEncoder returns the encoder, the name of the encoder, and a cleanup
213 | // function for the first entry of c.encodingPrecedence that the request's
214 | // Accept-Encoding header accepts and that produces a usable writer around w.
215 | // The cleanup function re-adds pooled encoders to their pool and must be
216 | // called once the response is finished. When nothing matches it returns a
217 | // nil writer, an empty name and a no-op cleanup.
218 | func (c *Compressor) selectEncoder(h http.Header, w io.Writer) (io.Writer, string, func()) {
219 | 	header := h.Get("Accept-Encoding")
220 |
221 | 	// Parse the names of all accepted algorithms from the header.
222 | 	accepted := strings.Split(strings.ToLower(header), ",")
223 |
224 | 	// Find supported encoder by accepted list by precedence
225 | 	for _, name := range c.encodingPrecedence {
226 | 		if !matchAcceptEncoding(accepted, name) {
227 | 			continue
228 | 		}
229 | 		if pool, ok := c.pooledEncoders[name]; ok {
230 | 			encoder := pool.Get().(ioResetterWriter)
231 | 			cleanup := func() {
232 | 				pool.Put(encoder)
233 | 			}
234 | 			encoder.Reset(w)
235 | 			return encoder, name, cleanup
236 | 		}
237 | 		if fn, ok := c.encoders[name]; ok {
238 | 			// An EncoderFunc returns nil on failure. Returning the name
239 | 			// without an encoder would make WriteHeader advertise this
240 | 			// Content-Encoding over an uncompressed body, so try the next
241 | 			// accepted encoding instead.
242 | 			if encoder := fn(w, c.level); encoder != nil {
243 | 				return encoder, name, func() {}
244 | 			}
245 | 		}
246 | 	}
247 |
248 | 	// No encoder found to match the accepted encoding
249 | 	return nil, "", func() {}
250 | }
251 |
252 | // matchAcceptEncoding reports whether encoding is listed in accepted, the
253 | // comma-separated, lower-cased elements of an Accept-Encoding header. An
254 | // element's coding name must equal encoding exactly (surrounding whitespace
255 | // ignored); an element that refuses the coding with a zero quality value,
256 | // e.g. "gzip;q=0", does not count as a match.
257 | func matchAcceptEncoding(accepted []string, encoding string) bool {
258 | 	for _, v := range accepted {
259 | 		name, params, _ := strings.Cut(v, ";")
260 | 		if strings.TrimSpace(name) != encoding {
261 | 			continue
262 | 		}
263 | 		if !hasZeroQValue(params) {
264 | 			return true
265 | 		}
266 | 	}
267 | 	return false
268 | }
269 |
270 | // hasZeroQValue reports whether the ";"-separated parameter list params
271 | // contains a "q" parameter whose value is zero ("0", "0.", "0.0", ...).
272 | func hasZeroQValue(params string) bool {
273 | 	for _, p := range strings.Split(params, ";") {
274 | 		key, value, ok := strings.Cut(p, "=")
275 | 		if !ok || strings.TrimSpace(key) != "q" {
276 | 			continue
277 | 		}
278 | 		value = strings.TrimSpace(value)
279 | 		if value == "" || value[0] != '0' {
280 | 			return false
281 | 		}
282 | 		// "0", optionally followed by "." and nothing but zero digits.
283 | 		return strings.Trim(strings.TrimPrefix(value[1:], "."), "0") == ""
284 | 	}
285 | 	return false
286 | }
287 |
288 | // An EncoderFunc is a function that wraps the provided io.Writer with a
289 | // streaming compression algorithm and returns it.
290 | //
291 | // In case of failure, the function should return nil.
292 | type EncoderFunc func(w io.Writer, level int) io.Writer
293 |
294 | // Interface for types that allow resetting io.Writers.
295 | type ioResetterWriter interface {
296 | 	io.Writer
297 | 	Reset(w io.Writer)
298 | }
299 |
300 | // compressResponseWriter wraps an http.ResponseWriter and, once WriteHeader
301 | // has found the response compressible, routes the body through the selected
302 | // encoder.
303 | type compressResponseWriter struct {
304 | 	http.ResponseWriter
305 |
306 | 	// The streaming encoder writer to be used if there is one. Otherwise,
307 | 	// this is just the normal writer.
308 | 	w                io.Writer
309 | 	contentTypes     map[string]struct{}
310 | 	contentWildcards map[string]struct{}
311 | 	encoding         string
312 | 	wroteHeader      bool
313 | 	compressible     bool
314 | }
315 |
316 | // isCompressible reports whether the response Content-Type, with any
317 | // parameters (";charset=...") and surrounding whitespace removed, is one of
318 | // contentTypes, or whether its part before "/" is one of contentWildcards.
319 | func (cw *compressResponseWriter) isCompressible() bool {
320 | 	// Parse the first part of the Content-Type response header.
321 | 	contentType := cw.Header().Get("Content-Type")
322 | 	contentType, _, _ = strings.Cut(contentType, ";")
323 | 	contentType = strings.TrimSpace(contentType)
324 |
325 | 	// Is the content type compressible?
326 | 	if _, ok := cw.contentTypes[contentType]; ok {
327 | 		return true
328 | 	}
329 | 	if contentType, _, hadSlash := strings.Cut(contentType, "/"); hadSlash {
330 | 		_, ok := cw.contentWildcards[contentType]
331 | 		return ok
332 | 	}
333 | 	return false
334 | }
335 |
336 | // WriteHeader decides, on its first call, whether the body is compressed. It
337 | // does not compress when a Content-Encoding is already set, the Content-Type
338 | // is not compressible, or no encoding was negotiated. When compressing, it
339 | // sets Content-Encoding, adds "Vary: Accept-Encoding" and drops
340 | // Content-Length. The status code is always passed on to the wrapped writer.
341 | func (cw *compressResponseWriter) WriteHeader(code int) {
342 | 	if cw.wroteHeader {
343 | 		cw.ResponseWriter.WriteHeader(code) // Allow multiple calls to propagate.
344 | 		return
345 | 	}
346 | 	cw.wroteHeader = true
347 | 	defer cw.ResponseWriter.WriteHeader(code)
348 |
349 | 	// Already compressed data?
350 | 	if cw.Header().Get("Content-Encoding") != "" {
351 | 		return
352 | 	}
353 |
354 | 	if !cw.isCompressible() {
355 | 		cw.compressible = false
356 | 		return
357 | 	}
358 |
359 | 	if cw.encoding != "" {
360 | 		cw.compressible = true
361 | 		cw.Header().Set("Content-Encoding", cw.encoding)
362 | 		cw.Header().Add("Vary", "Accept-Encoding")
363 |
364 | 		// The content-length after compression is unknown
365 | 		cw.Header().Del("Content-Length")
366 | 	}
367 | }
368 |
369 | // Write writes p through the active writer, first calling
370 | // WriteHeader(http.StatusOK) if no header has been written yet.
371 | func (cw *compressResponseWriter) Write(p []byte) (int, error) {
372 | 	if !cw.wroteHeader {
373 | 		cw.WriteHeader(http.StatusOK)
374 | 	}
375 |
376 | 	return cw.writer().Write(p)
377 | }
378 |
379 | // writer returns the encoder while the response is being compressed and the
380 | // wrapped ResponseWriter otherwise.
381 | func (cw *compressResponseWriter) writer() io.Writer {
382 | 	if cw.compressible {
383 | 		return cw.w
384 | 	}
385 | 	return cw.ResponseWriter
386 | }
387 |
388 | // compressFlusher matches encoders (such as *gzip.Writer and *flate.Writer)
389 | // whose Flush method returns an error.
390 | type compressFlusher interface {
391 | 	Flush() error
392 | }
393 |
394 | // Flush sends any buffered data to the client. If no header has been written
395 | // yet, it first calls WriteHeader(http.StatusOK). Flushing commits the
396 | // headers, so the compression decision (and Content-Encoding) must be made
397 | // first; otherwise later Writes could send compressed bytes under headers
398 | // that do not announce them.
399 | func (cw *compressResponseWriter) Flush() {
400 | 	if !cw.wroteHeader {
401 | 		cw.WriteHeader(http.StatusOK)
402 | 	}
403 | 	if f, ok := cw.writer().(http.Flusher); ok {
404 | 		f.Flush()
405 | 	}
406 | 	// If the underlying writer has a compression flush signature,
407 | 	// call this Flush() method instead
408 | 	if f, ok := cw.writer().(compressFlusher); ok {
409 | 		f.Flush()
410 |
411 | 		// Also flush the underlying response writer
412 | 		if f, ok := cw.ResponseWriter.(http.Flusher); ok {
413 | 			f.Flush()
414 | 		}
415 | 	}
416 | }
417 |
418 | // Hijack implements http.Hijacker when the active writer supports it.
419 | func (cw *compressResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
420 | 	if hj, ok := cw.writer().(http.Hijacker); ok {
421 | 		return hj.Hijack()
422 | 	}
423 | 	return nil, nil, errors.New("chi/middleware: http.Hijacker is unavailable on the writer")
424 | }
425 |
426 | // Push implements http.Pusher when the active writer supports it.
427 | func (cw *compressResponseWriter) Push(target string, opts *http.PushOptions) error {
428 | 	if ps, ok := cw.writer().(http.Pusher); ok {
429 | 		return ps.Push(target, opts)
430 | 	}
431 | 	return errors.New("chi/middleware: http.Pusher is unavailable on the writer")
432 | }
433 |
434 | // Close closes the active writer when it is an io.WriteCloser (the encoder,
435 | // while compressing), which writes out any remaining compressed data.
436 | func (cw *compressResponseWriter) Close() error {
437 | 	if c, ok := cw.writer().(io.WriteCloser); ok {
438 | 		return c.Close()
439 | 	}
440 | 	return errors.New("chi/middleware: io.WriteCloser is unavailable on the writer")
441 | }
442 |
443 | // Unwrap returns the wrapped http.ResponseWriter.
444 | func (cw *compressResponseWriter) Unwrap() http.ResponseWriter {
445 | 	return cw.ResponseWriter
446 | }
447 |
448 | // encoderGzip is an EncoderFunc returning a gzip writer at the given level,
449 | // or nil when the level is invalid.
450 | func encoderGzip(w io.Writer, level int) io.Writer {
451 | 	gw, err := gzip.NewWriterLevel(w, level)
452 | 	if err != nil {
453 | 		return nil
454 | 	}
455 | 	return gw
456 | }
457 |
458 | // encoderDeflate is an EncoderFunc returning a DEFLATE writer at the given
459 | // level, or nil when the level is invalid.
460 | func encoderDeflate(w io.Writer, level int) io.Writer {
461 | 	dw, err := flate.NewWriter(w, level)
462 | 	if err != nil {
463 | 		return nil
464 | 	}
465 | 	return dw
466 | }
467 |
--------------------------------------------------------------------------------
/middleware/compress_test.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | 	"compress/flate"
5 | 	"compress/gzip"
6 | 	"fmt"
7 | 	"io"
8 | 	"io/ioutil"
9 | 	"net/http"
10 | 	"net/http/httptest"
11 | 	"strings"
12 | 	"testing"
13 |
14 | 	"github.com/go-chi/chi/v5"
15 | )
16 |
17 | func TestCompressor(t *testing.T) {
18 | 	r := chi.NewRouter()
19 |
20 | 	compressor := NewCompressor(5, "text/html", "text/css")
21 | 	if len(compressor.encoders) != 0 || len(compressor.pooledEncoders) != 2 {
22 | 		t.Errorf("gzip and deflate should be pooled")
23 | 	}
24 |
25 | 	compressor.SetEncoder("nop", func(w io.Writer, _ int) io.Writer {
26 | 		return w
27 | 	})
28 |
29 | 	if len(compressor.encoders) != 1 {
30 | 		t.Errorf("nop encoder should be stored in the encoders map")
31 | 	}
32 |
33 | 	// An EncoderFunc that always fails: selecting it must fall back to the
34 | 	// next accepted encoding instead of labelling an uncompressed body.
35 | 	compressor.SetEncoder("broken", func(w io.Writer, _ int) io.Writer {
36 | 		return nil
37 | 	})
38 |
39 | 	r.Use(compressor.Handler)
40 |
41 | 	r.Get("/gethtml", func(w http.ResponseWriter, r *http.Request) {
42 | 		w.Header().Set("Content-Type", "text/html")
43 | 		w.Write([]byte("textstring"))
44 | 	})
45 |
46 | 	r.Get("/getcss", func(w http.ResponseWriter, r *http.Request) {
47 | 		w.Header().Set("Content-Type", "text/css")
48 | 		w.Write([]byte("textstring"))
49 | 	})
50 |
51 | 	r.Get("/getplain", func(w http.ResponseWriter, r *http.Request) {
52 | 		w.Header().Set("Content-Type", "text/plain")
53 | 		w.Write([]byte("textstring"))
54 | 	})
55 |
56 | 	// Flushing before the first Write must not commit the headers before the
57 | 	// Content-Encoding decision has been made.
58 | 	r.Get("/flushfirst", func(w http.ResponseWriter, r *http.Request) {
59 | 		w.Header().Set("Content-Type", "text/html")
60 | 		w.(http.Flusher).Flush()
61 | 		w.Write([]byte("textstring"))
62 | 	})
63 |
64 | 	ts := httptest.NewServer(r)
65 | 	defer ts.Close()
66 |
67 | 	tests := []struct {
68 | 		name              string
69 | 		path              string
70 | 		expectedEncoding  string
71 | 		acceptedEncodings []string
72 | 	}{
73 | 		{
74 | 			name:              "no expected encodings due to no accepted encodings",
75 | 			path:              "/gethtml",
76 | 			acceptedEncodings: nil,
77 | 			expectedEncoding:  "",
78 | 		},
79 | 		{
80 | 			name:              "no expected encodings due to content type",
81 | 			path:              "/getplain",
82 | 			acceptedEncodings: []string{"gzip", "deflate"},
83 | 			expectedEncoding:  "",
84 | 		},
85 | 		{
86 | 			name:              "gzip is only encoding",
87 | 			path:              "/gethtml",
88 | 			acceptedEncodings: []string{"gzip"},
89 | 			expectedEncoding:  "gzip",
90 | 		},
91 | 		{
92 | 			name:              "gzip is preferred over deflate",
93 | 			path:              "/getcss",
94 | 			acceptedEncodings: []string{"gzip", "deflate"},
95 | 			expectedEncoding:  "gzip",
96 | 		},
97 | 		{
98 | 			name:              "deflate is used",
99 | 			path:              "/getcss",
100 | 			acceptedEncodings: []string{"deflate"},
101 | 			expectedEncoding:  "deflate",
102 | 		},
103 | 		{
104 | 			name:              "nop is preferred",
105 | 			path:              "/getcss",
106 | 			acceptedEncodings: []string{"nop, gzip, deflate"},
107 | 			expectedEncoding:  "nop",
108 | 		},
109 | 		{
110 | 			name:              "gzip refused with q=0",
111 | 			path:              "/gethtml",
112 | 			acceptedEncodings: []string{"gzip;q=0", "deflate"},
113 | 			expectedEncoding:  "deflate",
114 | 		},
115 | 		{
116 | 			name:              "failing encoder falls back to the next accepted one",
117 | 			path:              "/gethtml",
118 | 			acceptedEncodings: []string{"broken", "gzip"},
119 | 			expectedEncoding:  "gzip",
120 | 		},
121 | 		{
122 | 			name:              "flush before write keeps the encoding",
123 | 			path:              "/flushfirst",
124 | 			acceptedEncodings: []string{"gzip"},
125 | 			expectedEncoding:  "gzip",
126 | 		},
127 | 	}
128 |
129 | 	for _, tc := range tests {
130 | 		tc := tc
131 | 		t.Run(tc.name, func(t *testing.T) {
132 | 			resp, respString := testRequestWithAcceptedEncodings(t, ts, "GET", tc.path, tc.acceptedEncodings...)
133 | 			if respString != "textstring" {
134 | 				t.Errorf("response text doesn't match; expected:%q, got:%q", "textstring", respString)
135 | 			}
136 | 			if got := resp.Header.Get("Content-Encoding"); got != tc.expectedEncoding {
137 | 				t.Errorf("expected encoding %q but got %q", tc.expectedEncoding, got)
138 | 			}
139 | 		})
140 | 	}
141 | }
142 |
143 | func TestCompressorWildcards(t *testing.T) {
144 | 	tests := []struct {
145 | 		name       string
146 | 		recover    string
147 | 		types      []string
148 | 		typesCount int
149 | 		wcCount    int
150 | 	}{
151 | 		{
152 | 			name:       "defaults",
153 | 			typesCount: 10,
154 | 		},
155 | 		{
156 | 			name:       "no wildcard",
157 | 			types:      []string{"text/plain", "text/html"},
158 | 			typesCount: 2,
159 | 		},
160 | 		{
161 | 			name:    "invalid wildcard #1",
162 | 			types:   []string{"audio/*wav"},
163 | 			recover: "middleware/compress: Unsupported content-type wildcard pattern 'audio/*wav'. Only '/*' supported",
164 | 		},
165 | 		{
166 | 			name:    "invalid wildcard #2",
167 | 			types:   []string{"application*/*"},
168 | 			recover: "middleware/compress: Unsupported content-type wildcard pattern 'application*/*'. Only '/*' supported",
169 | 		},
170 | 		{
171 | 			name:    "valid wildcard",
172 | 			types:   []string{"text/*"},
173 | 			wcCount: 1,
174 | 		},
175 | 		{
176 | 			name:       "mixed",
177 | 			types:      []string{"audio/wav", "text/*"},
178 | 			typesCount: 1,
179 | 			wcCount:    1,
180 | 		},
181 | 	}
182 | 	for _, tt := range tests {
183 | 		t.Run(tt.name, func(t *testing.T) {
184 | 			defer func() {
185 | 				// recover() yields nil, which %v formats as "<nil>", when nothing panicked.
186 | 				if tt.recover == "" {
187 | 					tt.recover = "<nil>"
188 | 				}
189 | 				if r := recover(); tt.recover != fmt.Sprintf("%v", r) {
190 | 					t.Errorf("Unexpected value recovered: %v", r)
191 | 				}
192 | 			}()
193 | 			compressor := NewCompressor(5, tt.types...)
194 | 			if len(compressor.allowedTypes) != tt.typesCount {
195 | 				t.Errorf("expected %d allowedTypes, got %d", tt.typesCount, len(compressor.allowedTypes))
196 | 			}
197 | 			if len(compressor.allowedWildcards) != tt.wcCount {
198 | 				t.Errorf("expected %d allowedWildcards, got %d", tt.wcCount, len(compressor.allowedWildcards))
199 | 			}
200 | 		})
201 | 	}
202 | }
203 |
204 | func testRequestWithAcceptedEncodings(t *testing.T, ts *httptest.Server, method, path string, encodings ...string) (*http.Response, string) {
205 | 	req, err := http.NewRequest(method, ts.URL+path, nil)
206 | 	if err != nil {
207 | 		t.Fatal(err)
208 | 		return nil, ""
209 | 	}
210 | 	if len(encodings) > 0 {
211 | 		encodingsString := strings.Join(encodings, ",")
212 | 		req.Header.Set("Accept-Encoding", encodingsString)
213 | 	}
214 |
215 | 	resp, err := http.DefaultClient.Do(req)
216 | 	if err != nil {
217 | 		t.Fatal(err)
218 | 		return nil, ""
219 | 	}
220 | 	// Deferred before decoding so the body is closed even if decoding calls t.Fatal.
221 | 	defer resp.Body.Close()
222 |
223 | 	respBody := decodeResponseBody(t, resp)
224 |
225 | 	return resp, respBody
226 | }
227 |
228 | func decodeResponseBody(t *testing.T, resp *http.Response) string {
229 | 	var reader io.ReadCloser
230 | 	switch resp.Header.Get("Content-Encoding") {
231 | 	case "gzip":
232 | 		var err error
233 | 		reader, err = gzip.NewReader(resp.Body)
234 | 		if err != nil {
235 | 			t.Fatal(err)
236 | 		}
237 | 	case "deflate":
238 | 		reader = flate.NewReader(resp.Body)
239 | 	default:
240 | 		reader = resp.Body
241 | 	}
242 | 	respBody, err := ioutil.ReadAll(reader)
243 | 	if err != nil {
244 | 		t.Fatal(err)
245 | 		return ""
246 | 	}
247 | 	reader.Close()
248 |
249 | 	return string(respBody)
250 | }
251 |
--------------------------------------------------------------------------------
/middleware/content_charset.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | 	"net/http"
5 | 	"strings"
6 | )
7 |
8 | // ContentCharset generates a handler that writes a 415 Unsupported Media Type response if none of the charsets match.
9 | // An empty charset will allow requests with no Content-Type header or no specified charset. Matching is case-insensitive.
10 | func ContentCharset(charsets ...string) func(next http.Handler) http.Handler {
11 | 	// Lower-case into a private copy so the caller's slice (charsets...) is never mutated.
12 | 	allowed := make([]string, len(charsets))
13 | 	for i, c := range charsets {
14 | 		allowed[i] = strings.ToLower(c)
15 | 	}
16 | 	return func(next http.Handler) http.Handler {
17 | 		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
18 | 			if !contentEncoding(r.Header.Get("Content-Type"), allowed...) {
19 | 				w.WriteHeader(http.StatusUnsupportedMediaType)
20 | 				return
21 | 			}
22 | 			next.ServeHTTP(w, r)
23 | 		})
24 | 	}
25 | }
26 |
27 | // contentEncoding reports whether the charset parameter of the Content-Type value ce
28 | // (compared lower-cased; "" when absent) equals one of charsets, expected lower-case.
29 | func contentEncoding(ce string, charsets ...string) bool {
30 | 	_, ce = split(strings.ToLower(ce), ";")
31 | 	_, ce = split(ce, "charset=")
32 | 	ce, _ = split(ce, ";")
33 | 	for _, c := range charsets {
34 | 		if ce == c {
35 | 			return true
36 | 		}
37 | 	}
38 | 	return false
39 | }
40 |
41 | // Split a string in two parts, cleaning any whitespace.
42 | func split(str, sep string) (string, string) {
43 | 	var a, b string
44 | 	var parts = strings.SplitN(str, sep, 2)
45 | 	a = strings.TrimSpace(parts[0])
46 | 	// b stays empty when sep does not occur in str.
47 | 	if len(parts) == 2 {
48 | 		b = strings.TrimSpace(parts[1])
49 | 	}
50 |
51 | 	return a, b
52 | }
53 |
--------------------------------------------------------------------------------
/middleware/content_charset_test.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | 	"net/http"
5 | 	"net/http/httptest"
6 | 	"testing"
7 |
8 | 	"github.com/go-chi/chi/v5"
9 | )
10 |
11 | func TestContentCharset(t *testing.T) {
12 | 	t.Parallel()
13 |
14 | 	var tests = []struct {
15 | 		name                string
16 | 		inputValue          string
17 | 		inputContentCharset []string
18 | 		want                int
19 | 	}{
20 | 		{
21 | 			"should accept requests with a matching charset",
22 | 			"application/json; charset=UTF-8",
23 | 			[]string{"UTF-8"},
24 | 			http.StatusOK,
25 | 		},
26 | 		{
27 | 			"should be case-insensitive",
28 | 			"application/json; charset=utf-8",
29 | 			[]string{"UTF-8"},
30 | 			http.StatusOK,
31 | 		},
32 | 		{
33 | 			"should accept requests with a matching charset with extra values",
34 | 			"application/json; foo=bar; charset=UTF-8; spam=eggs",
35 | 			[]string{"UTF-8"},
36 | 			http.StatusOK,
37 | 		},
38 | 		{
39 | 			"should accept requests with a matching charset when multiple charsets are supported",
40 | 			"text/xml; charset=UTF-8",
41 | 			[]string{"UTF-8", "Latin-1"},
42 | 			http.StatusOK,
43 | 		},
44 | 		{
45 | 			"should accept requests with no charset if empty charset headers are allowed",
46 | 			"text/xml",
47 | 			[]string{"UTF-8", ""},
48 | 			http.StatusOK,
49 | 		},
50 | 		{
51 | 			"should not accept requests with no charset if empty charset headers are not allowed",
52 | 			"text/xml",
53 | 			[]string{"UTF-8"},
54 | 			http.StatusUnsupportedMediaType,
55 | 		},
56 | 		{
57 | 			"should not accept requests with a mismatching charset",
58 | 			"text/plain; charset=Latin-1",
59 | 			[]string{"UTF-8"},
60 | 			http.StatusUnsupportedMediaType,
61 | 		},
62 | 		{
63 | 			"should not accept requests with a mismatching charset even if empty charsets are allowed",
64 | 			"text/plain; charset=Latin-1",
65 | 			[]string{"UTF-8", ""},
66 | 			http.StatusUnsupportedMediaType,
67 | 		},
68 | 	}
69 |
70 | 	for _, tt := range tests {
71 | 		var tt = tt
72 | 		t.Run(tt.name, func(t *testing.T) {
73 | 			t.Parallel()
74 |
75 | 			var recorder = httptest.NewRecorder()
76 |
77 | 			var r = chi.NewRouter()
78 | 			r.Use(ContentCharset(tt.inputContentCharset...))
79 | 			r.Get("/", func(w http.ResponseWriter, r *http.Request) {})
80 |
81 | 			var req, _ = http.NewRequest("GET", "/", nil)
82 | 			req.Header.Set("Content-Type", tt.inputValue)
83 |
84 | 			r.ServeHTTP(recorder, req)
85 | 			var res = recorder.Result()
86 |
87 | 			if res.StatusCode != tt.want {
88 | 				t.Errorf("response is incorrect, got %d, want %d", recorder.Code, tt.want)
89 | 			}
90 | 		})
91 | 	}
92 | }
93 |
94 | func TestSplit(t *testing.T) {
95 | 	t.Parallel()
96 |
97 | 	var s1, s2 = split(" type1;type2 ", ";")
98 |
99 | 	if s1 != "type1" || s2 != "type2" {
100 | 		t.Errorf("Want type1, type2 got %s, %s", s1, s2)
101 | 	}
102 |
103 | 	s1, s2 = split("type1 ", ";")
104 |
105 | 	if s1 != "type1" {
106 | 		t.Errorf("Want \"type1\" got \"%s\"", s1)
107 | 	}
108 | 	if s2 != "" {
109 | 		t.Errorf("Want empty string got \"%s\"", s2)
110 | 	}
111 | }
112 |
113 | func TestContentEncoding(t *testing.T) {
114 | 	t.Parallel()
115 |
116 | 	if !contentEncoding("application/json; foo=bar; charset=utf-8; spam=eggs", []string{"utf-8"}...) {
117 | 		t.Error("Want true, got false")
118 | 	}
119 |
120 | 	if contentEncoding("text/plain; charset=latin-1", []string{"utf-8"}...) {
121 | 		t.Error("Want false, got true")
122 | 	}
123 |
124 | 	if !contentEncoding("text/xml; charset=UTF-8", []string{"latin-1", "utf-8"}...) {
125 | 		t.Error("Want true, got false")
126 | 	}
127 | }
128 |
--------------------------------------------------------------------------------
/middleware/content_encoding.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | 	"net/http"
5 | 	"strings"
6 | )
7 |
8 | // AllowContentEncoding enforces a whitelist of request Content-Encoding otherwise responds
9 | // with a 415 Unsupported Media Type status.
10 | //
11 | // Every coding named by the request must be allowed, whether the codings are
12 | // sent as repeated Content-Encoding headers or as one comma-separated value
13 | // such as "gzip, deflate". Names are compared case-insensitively, ignoring
14 | // surrounding whitespace. Requests with a ContentLength of 0 are not checked.
15 | func AllowContentEncoding(contentEncoding ...string) func(next http.Handler) http.Handler {
16 | 	allowedEncodings := make(map[string]struct{}, len(contentEncoding))
17 | 	for _, encoding := range contentEncoding {
18 | 		allowedEncodings[strings.TrimSpace(strings.ToLower(encoding))] = struct{}{}
19 | 	}
20 | 	return func(next http.Handler) http.Handler {
21 | 		fn := func(w http.ResponseWriter, r *http.Request) {
22 | 			// skip check for empty content body
23 | 			if r.ContentLength == 0 {
24 | 				next.ServeHTTP(w, r)
25 | 				return
26 | 			}
27 | 			// All encodings in the request must be allowed. A single header
28 | 			// line may list several codings separated by commas.
29 | 			for _, value := range r.Header["Content-Encoding"] {
30 | 				for _, encoding := range strings.Split(value, ",") {
31 | 					if _, ok := allowedEncodings[strings.TrimSpace(strings.ToLower(encoding))]; !ok {
32 | 						w.WriteHeader(http.StatusUnsupportedMediaType)
33 | 						return
34 | 					}
35 | 				}
36 | 			}
37 | 			next.ServeHTTP(w, r)
38 | 		}
39 | 		return http.HandlerFunc(fn)
40 | 	}
41 | }
42 |
--------------------------------------------------------------------------------
/middleware/content_encoding_test.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | 	"bytes"
5 | 	"net/http"
6 | 	"net/http/httptest"
7 | 	"testing"
8 |
9 | 	"github.com/go-chi/chi/v5"
10 | )
11 |
12 | func TestContentEncodingMiddleware(t *testing.T) {
13 | 	t.Parallel()
14 |
15 | 	// support for:
16 | 	// Content-Encoding: gzip
17 | 	// Content-Encoding: deflate
18 | 	// Content-Encoding: gzip, deflate
19 | 	// Content-Encoding: deflate, gzip
20 | 	middleware := AllowContentEncoding("deflate", "gzip")
21 |
22 | 	tests := []struct {
23 | 		name string
24 | 		// Each element is sent as its own Content-Encoding header line.
25 | 		encodings      []string
26 | 		expectedStatus int
27 | 	}{
28 | 		{
29 | 			name:           "Support no encoding",
30 | 			encodings:      []string{},
31 | 			expectedStatus: 200,
32 | 		},
33 | 		{
34 | 			name:           "Support gzip encoding",
35 | 			encodings:      []string{"gzip"},
36 | 			expectedStatus: 200,
37 | 		},
38 | 		{
39 | 			name:           "No support for br encoding",
40 | 			encodings:      []string{"br"},
41 | 			expectedStatus: 415,
42 | 		},
43 | 		{
44 | 			name:           "Support for gzip and deflate encoding",
45 | 			encodings:      []string{"gzip", "deflate"},
46 | 			expectedStatus: 200,
47 | 		},
48 | 		{
49 | 			name:           "Support for deflate and gzip encoding",
50 | 			encodings:      []string{"deflate", "gzip"},
51 | 			expectedStatus: 200,
52 | 		},
53 | 		{
54 | 			name:           "No support for deflate and br encoding",
55 | 			encodings:      []string{"deflate", "br"},
56 | 			expectedStatus: 415,
57 | 		},
58 | 		{
59 | 			name:           "Support for gzip and deflate in a single header",
60 | 			encodings:      []string{"gzip, deflate"},
61 | 			expectedStatus: 200,
62 | 		},
63 | 		{
64 | 			name:           "No support for deflate and br in a single header",
65 | 			encodings:      []string{"deflate, br"},
66 | 			expectedStatus: 415,
67 | 		},
68 | 	}
69 |
70 | 	for _, tt := range tests {
71 | 		var tt = tt
72 | 		t.Run(tt.name, func(t *testing.T) {
73 | 			t.Parallel()
74 |
75 | 			body := []byte("This is my content. There are many like this but this one is mine")
76 | 			r := httptest.NewRequest("POST", "/", bytes.NewReader(body))
77 | 			// Add (not Set) so that every listed encoding is actually sent.
78 | 			for _, encoding := range tt.encodings {
79 | 				r.Header.Add("Content-Encoding", encoding)
80 | 			}
81 |
82 | 			w := httptest.NewRecorder()
83 | 			router := chi.NewRouter()
84 | 			router.Use(middleware)
85 | 			router.Post("/", func(w http.ResponseWriter, r *http.Request) {})
86 |
87 | 			router.ServeHTTP(w, r)
88 | 			res := w.Result()
89 | 			if res.StatusCode != tt.expectedStatus {
90 | 				t.Errorf("response is incorrect, got %d, want %d", w.Code, tt.expectedStatus)
91 | 			}
92 | 		})
93 | 	}
94 | }
95 |
--------------------------------------------------------------------------------
/middleware/content_type.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | 	"net/http"
5 | 	"strings"
6 | )
7 |
8 | // SetHeader is a convenience middleware that sets response header key to value, then calls the next handler.
9 | func SetHeader(key, value string) func(http.Handler) http.Handler {
10 | 	return func(next http.Handler) http.Handler {
11 | 		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
12 | 			w.Header().Set(key, value)
13 | 			next.ServeHTTP(w, r)
14 | 		})
15 | 	}
16 | }
17 |
18 | // AllowContentType enforces a whitelist of request Content-Types otherwise responds
19 | // with a 415 Unsupported Media Type status.
20 | func AllowContentType(contentTypes ...string) func(http.Handler) http.Handler {
21 | 	// Normalize the allow-list once, when the middleware is built.
22 | 	allowed := make(map[string]struct{}, len(contentTypes))
23 | 	for _, ct := range contentTypes {
24 | 		allowed[strings.ToLower(strings.TrimSpace(ct))] = struct{}{}
25 | 	}
26 |
27 | 	return func(next http.Handler) http.Handler {
28 | 		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
29 | 			// Requests with an empty body (ContentLength == 0) skip the check.
30 | 			if r.ContentLength != 0 {
31 | 				// Compare only the media type, ignoring ";"-separated parameters.
32 | 				mediaType, _, _ := strings.Cut(r.Header.Get("Content-Type"), ";")
33 | 				if _, ok := allowed[strings.ToLower(strings.TrimSpace(mediaType))]; !ok {
34 | 					w.WriteHeader(http.StatusUnsupportedMediaType)
35 | 					return
36 | 				}
37 | 			}
38 | 			next.ServeHTTP(w, r)
39 | 		})
40 | 	}
41 | }
42 |
--------------------------------------------------------------------------------
/middleware/content_type_test.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | 	"bytes"
5 | 	"net/http"
6 | 	"net/http/httptest"
7 | 	"testing"
8 |
9 | 	"github.com/go-chi/chi/v5"
10 | )
11 |
12 | func TestContentType(t *testing.T) {
13 | 	t.Parallel()
14 |
15 | 	cases := []struct {
16 | 		name                string
17 | 		inputValue          string
18 | 		allowedContentTypes []string
19 | 		want                int
20 | 	}{
21 | 		{
22 | 			name:                "should accept requests with a matching content type",
23 | 			inputValue:          "application/json; charset=UTF-8",
24 | 			allowedContentTypes: []string{"application/json"},
25 | 			want:                http.StatusOK,
26 | 		},
27 | 		{
28 | 			name:                "should accept requests with a matching content type no charset",
29 | 			inputValue:          "application/json",
30 | 			allowedContentTypes: []string{"application/json"},
31 | 			want:                http.StatusOK,
32 | 		},
33 | 		{
34 | 			name:                "should accept requests with a matching content-type with extra values",
35 | 			inputValue:          "application/json; foo=bar; charset=UTF-8; spam=eggs",
36 | 			allowedContentTypes: []string{"application/json"},
37 | 			want:                http.StatusOK,
38 | 		},
39 | 		{
40 | 			name:                "should accept requests with a matching content type when multiple content types are supported",
41 | 			inputValue:          "text/xml; charset=UTF-8",
42 | 			allowedContentTypes: []string{"application/json", "text/xml"},
43 | 			want:                http.StatusOK,
44 | 		},
45 | 		{
46 | 			name:                "should not accept requests with a mismatching content type",
47 | 			inputValue:          "text/plain; charset=latin-1",
48 | 			allowedContentTypes: []string{"application/json"},
49 | 			want:                http.StatusUnsupportedMediaType,
50 | 		},
51 | 		{
52 | 			name:                "should not accept requests with a mismatching content type even if multiple content types are allowed",
53 | 			inputValue:          "text/plain; charset=Latin-1",
54 | 			allowedContentTypes: []string{"application/json", "text/xml"},
55 | 			want:                http.StatusUnsupportedMediaType,
56 | 		},
57 | 	}
58 |
59 | 	const payload = "This is my content. There are many like this but this one is mine"
60 |
61 | 	for _, tc := range cases {
62 | 		tc := tc
63 | 		t.Run(tc.name, func(t *testing.T) {
64 | 			t.Parallel()
65 |
66 | 			router := chi.NewRouter()
67 | 			router.Use(AllowContentType(tc.allowedContentTypes...))
68 | 			router.Post("/", func(w http.ResponseWriter, r *http.Request) {})
69 |
70 | 			req := httptest.NewRequest("POST", "/", bytes.NewReader([]byte(payload)))
71 | 			req.Header.Set("Content-Type", tc.inputValue)
72 |
73 | 			rec := httptest.NewRecorder()
74 | 			router.ServeHTTP(rec, req)
75 |
76 | 			if got := rec.Result().StatusCode; got != tc.want {
77 | 				t.Errorf("response is incorrect, got %d, want %d", rec.Code, tc.want)
78 | 			}
79 | 		})
80 | 	}
81 | }
82 |
--------------------------------------------------------------------------------
/middleware/get_head.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | 	"net/http"
5 |
6 | 	"github.com/go-chi/chi/v5"
7 | )
8 |
9 | // GetHead automatically routes undefined HEAD requests to GET handlers.
10 | func GetHead(next http.Handler) http.Handler {
11 | 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
12 | 		if r.Method == http.MethodHead {
13 | 			rctx := chi.RouteContext(r.Context())
14 | 			// The routing context may be absent (e.g. when this middleware wraps
15 | 			// a handler outside a chi router). There is nothing to look ahead
16 | 			// into then, so pass the request through instead of dereferencing nil.
17 | 			if rctx == nil || rctx.Routes == nil {
18 | 				next.ServeHTTP(w, r)
19 | 				return
20 | 			}
21 | 			routePath := rctx.RoutePath
22 | 			if routePath == "" {
23 | 				if r.URL.RawPath != "" {
24 | 					routePath = r.URL.RawPath
25 | 				} else {
26 | 					routePath = r.URL.Path
27 | 				}
28 | 			}
29 |
30 | 			// Temporary routing context to look-ahead before routing the request
31 | 			tctx := chi.NewRouteContext()
32 |
33 | 			// Attempt to find a HEAD handler for the routing path. If none is
34 | 			// found, traverse the router as though this were a GET route, but
35 | 			// proceed with the request using the HEAD method.
36 | 			if !rctx.Routes.Match(tctx, http.MethodHead, routePath) {
37 | 				rctx.RouteMethod = http.MethodGet
38 | 				rctx.RoutePath = routePath
39 | 				next.ServeHTTP(w, r)
40 | 				return
41 | 			}
42 | 		}
43 |
44 | 		next.ServeHTTP(w, r)
45 | 	})
46 | }
47 |
--------------------------------------------------------------------------------
/middleware/get_head_test.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | 	"net/http"
5 | 	"net/http/httptest"
6 | 	"testing"
7 |
8 | 	"github.com/go-chi/chi/v5"
9 | )
10 |
11 | func TestGetHead(t *testing.T) {
12 | 	r := chi.NewRouter()
13 | 	r.Use(GetHead)
14 | 	r.Get("/hi", func(w http.ResponseWriter, r *http.Request) {
15 | 		w.Header().Set("X-Test", "yes")
16 | 		w.Write([]byte("bye"))
17 | 	})
18 | 	r.Route("/articles", func(r chi.Router) {
19 | 		r.Get("/{id}", func(w http.ResponseWriter, r *http.Request) {
20 | 			id := chi.URLParam(r, "id")
21 | 			w.Header().Set("X-Article", id)
22 | 			w.Write([]byte("article:" + id))
23 | 		})
24 | 	})
25 | 	r.Route("/users", func(r chi.Router) {
26 | 		r.Head("/{id}", func(w http.ResponseWriter, r *http.Request) {
27 | 			w.Header().Set("X-User", "-")
28 | 			w.Write([]byte("user"))
29 | 		})
30 | 		r.Get("/{id}", func(w http.ResponseWriter, r *http.Request) {
31 | 			id := chi.URLParam(r, "id")
32 | 			w.Header().Set("X-User", id)
33 | 			w.Write([]byte("user:" + id))
34 | 		})
35 | 	})
36 |
37 | 	ts := httptest.NewServer(r)
38 | 	defer ts.Close()
39 |
40 | 	if _, body := testRequest(t, ts, "GET", "/hi", nil); body != "bye" {
41 | 		t.Fatal(body)
42 | 	}
43 | 	if resp, body := testRequest(t, ts, "HEAD", "/hi", nil); body != "" || resp.Header.Get("X-Test") != "yes" {
44 | 		t.Fatal(body)
45 | 	}
46 | 	if _, body := testRequest(t, ts, "GET", "/", nil); body != "404 page not found\n" {
47 | 		t.Fatal(body)
48 | 	}
49 | 	if resp, body := testRequest(t, ts, "HEAD", "/", nil); body != "" || resp.StatusCode != 404 {
50 | 		t.Fatal(body)
51 | 	}
52 |
53 | 	if _, body := testRequest(t, ts, "GET", "/articles/5", nil); body != "article:5" {
54 | 		t.Fatal(body)
55 | 	}
56 | 	if resp, body := testRequest(t, ts, "HEAD", "/articles/5", nil); body != "" || resp.Header.Get("X-Article") != "5" {
57 | 		t.Fatalf("expecting X-Article header '5' but got '%s'", resp.Header.Get("X-Article"))
58 | 	}
59 |
60 | 	if _, body := testRequest(t, ts, "GET", "/users/1", nil); body != "user:1" {
61 | 		t.Fatal(body)
62 | 	}
63 | 	if resp, body := testRequest(t, ts, "HEAD", "/users/1", nil); body != "" || resp.Header.Get("X-User") != "-" {
64 | 		t.Fatalf("expecting X-User header '-' but got '%s'", resp.Header.Get("X-User"))
65 | 	}
66 | }
67 |
68 | // GetHead used outside a chi router must pass HEAD requests through rather than panic.
69 | func TestGetHeadWithoutRouteContext(t *testing.T) {
70 | 	h := GetHead(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
71 | 		w.Header().Set("X-Test", "yes")
72 | 	}))
73 |
74 | 	w := httptest.NewRecorder()
75 | 	h.ServeHTTP(w, httptest.NewRequest("HEAD", "/hi", nil))
76 |
77 | 	if w.Header().Get("X-Test") != "yes" {
78 | 		t.Fatal("expected the next handler to be called")
79 | 	}
80 | }
81 |
--------------------------------------------------------------------------------
/middleware/heartbeat.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | 	"net/http"
5 | 	"strings"
6 | )
7 |
8 | // Heartbeat endpoint middleware useful for setting up a path like
9 | // `/ping` that load balancers or external uptime-testing services
10 | // can request before hitting any routes. It's also convenient
11 | // to place this above ACL middlewares.
12 | func Heartbeat(endpoint string) func(http.Handler) http.Handler {
13 | 	return func(next http.Handler) http.Handler {
14 | 		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
15 | 			// Only GET/HEAD requests whose path equals endpoint (case-insensitively) are answered here.
16 | 			isProbe := r.Method == http.MethodGet || r.Method == http.MethodHead
17 | 			if !isProbe || !strings.EqualFold(r.URL.Path, endpoint) {
18 | 				next.ServeHTTP(w, r)
19 | 				return
20 | 			}
21 | 			w.Header().Set("Content-Type", "text/plain")
22 | 			w.WriteHeader(http.StatusOK)
23 | 			w.Write([]byte("."))
24 | 		})
25 | 	}
26 | }
27 |
--------------------------------------------------------------------------------
/middleware/logger.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | 	"bytes"
5 | 	"context"
6 | 	"log"
7 | 	"net/http"
8 | 	"os"
9 | 	"runtime"
10 | 	"time"
11 | )
12 |
13 | var (
14 | 	// LogEntryCtxKey is the context.Context key under which the request log entry is stored.
15 | 	LogEntryCtxKey = &contextKey{"LogEntry"}
16 |
17 | 	// DefaultLogger is called by the Logger middleware handler to log each request.
18 | 	// It is a package-level variable so that it can be replaced with a custom
19 | 	// logging configuration.
20 | 	DefaultLogger func(next http.Handler) http.Handler
21 | )
22 |
23 | // Logger is a middleware that logs the start and end of each request, along
24 | // with some useful data about what was requested, what the response status was,
25 | // and how long it took to return. When standard output is a TTY, Logger will
26 | // print in color, otherwise it will print in black and white. Logger prints a
27 | // request ID if one is provided.
28 | //
29 | // Alternatively, look at https://github.com/goware/httplog for a more in-depth
30 | // http logger with structured logging support.
31 | //
32 | // IMPORTANT NOTE: Logger should go before any other middleware that may change
33 | // the response, such as middleware.Recoverer. Example:
34 | //
35 | //	r := chi.NewRouter()
36 | //	r.Use(middleware.Logger) // <--<< Logger should come before Recoverer
37 | //	r.Use(middleware.Recoverer)
38 | //	r.Get("/", handler)
39 | func Logger(next http.Handler) http.Handler {
40 | 	return DefaultLogger(next)
41 | }
42 |
43 | // RequestLogger returns a middleware that logs every request through the LogFormatter f.
44 | func RequestLogger(f LogFormatter) func(next http.Handler) http.Handler {
45 | 	return func(next http.Handler) http.Handler {
46 | 		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
47 | 			entry := f.NewLogEntry(r)
48 | 			wrapped := NewWrapResponseWriter(w, r.ProtoMajor)
49 | 			start := time.Now()
50 |
51 | 			// Deferred, so the entry is still written if next panics.
52 | 			defer func() {
53 | 				entry.Write(wrapped.Status(), wrapped.BytesWritten(), wrapped.Header(), time.Since(start), nil)
54 | 			}()
55 |
56 | 			next.ServeHTTP(wrapped, WithLogEntry(r, entry))
57 | 		})
58 | 	}
59 | }
60 |
61 | // LogFormatter initiates the beginning of a new LogEntry per request.
62 | // See DefaultLogFormatter for an example implementation.
63 | type LogFormatter interface {
64 | 	NewLogEntry(r *http.Request) LogEntry
65 | }
66 |
67 | // LogEntry records the final log when a request completes.
68 | // See defaultLogEntry for an example implementation.
69 | type LogEntry interface {
70 | 	Write(status, bytes int, header http.Header, elapsed time.Duration, extra interface{})
71 | 	Panic(v interface{}, stack []byte)
72 | }
73 |
74 | // GetLogEntry returns the LogEntry stored in the request context, or nil if there is none.
75 | func GetLogEntry(r *http.Request) LogEntry {
76 | 	entry, _ := r.Context().Value(LogEntryCtxKey).(LogEntry)
77 | 	return entry
78 | }
79 |
80 | // WithLogEntry returns a shallow copy of r whose context carries entry.
81 | func WithLogEntry(r *http.Request, entry LogEntry) *http.Request {
82 | 	ctx := context.WithValue(r.Context(), LogEntryCtxKey, entry)
83 | 	return r.WithContext(ctx)
84 | }
85 |
86 | // LoggerInterface accepts printing to stdlib logger or compatible logger.
87 | type LoggerInterface interface {
88 | 	Print(v ...interface{})
89 | }
90 |
91 | // DefaultLogFormatter is a simple logger that implements a LogFormatter.
92 | type DefaultLogFormatter struct {
93 | 	Logger  LoggerInterface
94 | 	NoColor bool
95 | }
96 |
97 | // NewLogEntry creates a new LogEntry for the request. It pre-renders the
98 | // request ID (when one is set), the method, scheme://host+URI, the protocol
99 | // and the remote address; Write later appends the response details.
100 | func (l *DefaultLogFormatter) NewLogEntry(r *http.Request) LogEntry {
101 | 	useColor := !l.NoColor
102 | 	entry := &defaultLogEntry{
103 | 		DefaultLogFormatter: l,
104 | 		request:             r,
105 | 		buf:                 &bytes.Buffer{},
106 | 		useColor:            useColor,
107 | 	}
108 |
109 | 	reqID := GetReqID(r.Context())
110 | 	if reqID != "" {
111 | 		cW(entry.buf, useColor, nYellow, "[%s] ", reqID)
112 | 	}
113 | 	cW(entry.buf, useColor, nCyan, "\"")
114 | 	cW(entry.buf, useColor, bMagenta, "%s ", r.Method)
115 |
116 | 	scheme := "http"
117 | 	if r.TLS != nil {
118 | 		scheme = "https"
119 | 	}
120 | 	cW(entry.buf, useColor, nCyan, "%s://%s%s %s\" ", scheme, r.Host, r.RequestURI, r.Proto)
121 |
122 | 	entry.buf.WriteString("from ")
123 | 	entry.buf.WriteString(r.RemoteAddr)
124 | 	entry.buf.WriteString(" - ")
125 |
126 | 	return entry
127 | }
128 |
129 | // defaultLogEntry accumulates one request's log line in buf.
130 | type defaultLogEntry struct {
131 | 	*DefaultLogFormatter
132 | 	request  *http.Request
133 | 	buf      *bytes.Buffer
134 | 	useColor bool
135 | }
136 |
137 | // Write appends the status (colored by class), the byte count and the elapsed
138 | // time (green under 500ms, yellow under 5s, red otherwise) to the request line,
139 | // then prints the whole line through the formatter's Logger. header and extra
140 | // are not used.
141 | func (l *defaultLogEntry) Write(status, bytes int, header http.Header, elapsed time.Duration, extra interface{}) {
142 | 	switch {
143 | 	case status < 200:
144 | 		cW(l.buf, l.useColor, bBlue, "%03d", status)
145 | 	case status < 300:
146 | 		cW(l.buf, l.useColor, bGreen, "%03d", status)
147 | 	case status < 400:
148 | 		cW(l.buf, l.useColor, bCyan, "%03d", status)
149 | 	case status < 500:
150 | 		cW(l.buf, l.useColor, bYellow, "%03d", status)
151 | 	default:
152 | 		cW(l.buf, l.useColor, bRed, "%03d", status)
153 | 	}
154 |
155 | 	cW(l.buf, l.useColor, bBlue, " %dB", bytes)
156 |
157 | 	l.buf.WriteString(" in ")
158 | 	if elapsed < 500*time.Millisecond {
159 | 		cW(l.buf, l.useColor, nGreen, "%s", elapsed)
160 | 	} else if elapsed < 5*time.Second {
161 | 		cW(l.buf, l.useColor, nYellow, "%s", elapsed)
162 | 	} else {
163 | 		cW(l.buf, l.useColor, nRed, "%s", elapsed)
164 | 	}
165 |
166 | 	l.Logger.Print(l.buf.String())
167 | }
168 |
169 | // Panic prints v via PrintPrettyStack; the stack argument is not used.
170 | func (l *defaultLogEntry) Panic(v interface{}, stack []byte) {
171 | 	PrintPrettyStack(v)
172 | }
173 |
174 | // init installs the default request logger: it writes to stdout and disables color on Windows.
175 | func init() {
176 | 	color := true
177 | 	if runtime.GOOS == "windows" {
178 | 		color = false
179 | 	}
180 | 	DefaultLogger = RequestLogger(&DefaultLogFormatter{Logger: log.New(os.Stdout, "", log.LstdFlags), NoColor: !color})
181 | }
182 |
--------------------------------------------------------------------------------
/middleware/logger_test.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | 	"bufio"
5 | 	"bytes"
6 | 	"net"
7 | 	"net/http"
8 | 	"net/http/httptest"
9 | 	"testing"
10 | 	"time"
11 | )
12 |
13 | type testLoggerWriter struct {
14 | 	*httptest.ResponseRecorder
15 | }
16 |
17 | func (cw testLoggerWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
18 | 	return nil, nil, nil
19 | }
20 |
21 | func TestRequestLogger(t *testing.T) {
22 | 	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
23 | 		_, ok := w.(http.Hijacker)
24 | 		if !ok {
25 | 			t.Errorf("http.Hijacker is unavailable on the writer. add the interface methods.")
26 | 		}
27 | 	})
28 |
29 | 	r := httptest.NewRequest("GET", "/", nil)
30 | 	w := testLoggerWriter{
31 | 		ResponseRecorder: httptest.NewRecorder(),
32 | 	}
33 |
34 | 	handler := DefaultLogger(testHandler)
35 | 	handler.ServeHTTP(w, r)
36 | }
37 |
38 | func TestRequestLoggerReadFrom(t *testing.T) {
39 | 	data := []byte("file data")
40 | 	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
41 | 		http.ServeContent(w, r, "file", time.Time{}, bytes.NewReader(data))
42 | 	})
43 |
44 | 	r := httptest.NewRequest("GET", "/", nil)
45 | 	w := httptest.NewRecorder()
46 |
47 | 	handler := DefaultLogger(testHandler)
48 | 	handler.ServeHTTP(w, r)
49 |
50 | 	assertEqual(t, data, w.Body.Bytes())
51 | }
52 |
--------------------------------------------------------------------------------
/middleware/maybe.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import "net/http"
4 |
5 | // Maybe middleware will allow you to change the flow of the middleware stack execution depending on return
6 | // value of maybeFn(request). This is useful for example if you'd like to skip a middleware handler if
7 | // a request does not satisfy the maybeFn logic.
8 | //
9 | // mw is applied to next once, when the returned middleware wraps next, so any
10 | // state mw creates while wrapping a handler is shared across requests instead
11 | // of being rebuilt (and discarded) on every request that satisfies maybeFn.
12 | func Maybe(mw func(http.Handler) http.Handler, maybeFn func(r *http.Request) bool) func(http.Handler) http.Handler {
13 | 	return func(next http.Handler) http.Handler {
14 | 		wrapped := mw(next)
15 | 		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
16 | 			if maybeFn(r) {
17 | 				wrapped.ServeHTTP(w, r)
18 | 			} else {
19 | 				next.ServeHTTP(w, r)
20 | 			}
21 | 		})
22 | 	}
23 | }
24 |
--------------------------------------------------------------------------------
/middleware/middleware.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import "net/http"
4 |
5 | // New will create a new middleware handler from a http.Handler.
6 | func New(h http.Handler) func(next http.Handler) http.Handler { 7 | return func(next http.Handler) http.Handler { 8 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 9 | h.ServeHTTP(w, r) 10 | }) 11 | } 12 | } 13 | 14 | // contextKey is a value for use with context.WithValue. It's used as 15 | // a pointer so it fits in an interface{} without allocation. This technique 16 | // for defining context keys was copied from Go 1.7's new use of context in net/http. 17 | type contextKey struct { 18 | name string 19 | } 20 | 21 | func (k *contextKey) String() string { 22 | return "chi/middleware context value " + k.name 23 | } 24 | -------------------------------------------------------------------------------- /middleware/middleware_test.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "crypto/tls" 5 | "io" 6 | "net/http" 7 | "net/http/httptest" 8 | "path" 9 | "reflect" 10 | "runtime" 11 | "testing" 12 | "time" 13 | ) 14 | 15 | var testdataDir string 16 | 17 | func init() { 18 | _, filename, _, _ := runtime.Caller(0) 19 | testdataDir = path.Join(path.Dir(filename), "/../testdata") 20 | } 21 | 22 | func TestWrapWriterHTTP2(t *testing.T) { 23 | handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 24 | if r.Proto != "HTTP/2.0" { 25 | t.Fatalf("request proto should be HTTP/2.0 but was %s", r.Proto) 26 | } 27 | _, fl := w.(http.Flusher) 28 | if !fl { 29 | t.Fatal("request should have been a http.Flusher") 30 | } 31 | _, hj := w.(http.Hijacker) 32 | if hj { 33 | t.Fatal("request should not have been a http.Hijacker") 34 | } 35 | _, rf := w.(io.ReaderFrom) 36 | if rf { 37 | t.Fatal("request should not have been an io.ReaderFrom") 38 | } 39 | _, ps := w.(http.Pusher) 40 | if !ps { 41 | t.Fatal("request should have been a http.Pusher") 42 | } 43 | 44 | w.Write([]byte("OK")) 45 | }) 46 | 47 | wmw := func(next http.Handler) http.Handler { 48 | return 
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 49 | next.ServeHTTP(NewWrapResponseWriter(w, r.ProtoMajor), r) 50 | }) 51 | } 52 | 53 | server := http.Server{ 54 | Addr: ":7072", 55 | Handler: wmw(handler), 56 | } 57 | // By serving over TLS, we get HTTP2 requests 58 | go server.ListenAndServeTLS(testdataDir+"/cert.pem", testdataDir+"/key.pem") 59 | defer server.Close() 60 | // We need the server to start before making the request 61 | time.Sleep(100 * time.Millisecond) 62 | 63 | client := &http.Client{ 64 | Transport: &http.Transport{ 65 | TLSClientConfig: &tls.Config{ 66 | // The certificates we are using are self signed 67 | InsecureSkipVerify: true, 68 | }, 69 | ForceAttemptHTTP2: true, 70 | }, 71 | } 72 | 73 | resp, err := client.Get("https://localhost:7072") 74 | if err != nil { 75 | t.Fatalf("could not get server: %v", err) 76 | } 77 | if resp.StatusCode != 200 { 78 | t.Fatalf("non 200 response: %v", resp.StatusCode) 79 | } 80 | } 81 | 82 | func testRequest(t *testing.T, ts *httptest.Server, method, path string, body io.Reader) (*http.Response, string) { 83 | req, err := http.NewRequest(method, ts.URL+path, body) 84 | if err != nil { 85 | t.Fatal(err) 86 | return nil, "" 87 | } 88 | 89 | resp, err := http.DefaultClient.Do(req) 90 | if err != nil { 91 | t.Fatal(err) 92 | return nil, "" 93 | } 94 | 95 | respBody, err := io.ReadAll(resp.Body) 96 | if err != nil { 97 | t.Fatal(err) 98 | return nil, "" 99 | } 100 | defer resp.Body.Close() 101 | 102 | return resp, string(respBody) 103 | } 104 | 105 | func testRequestNoRedirect(t *testing.T, ts *httptest.Server, method, path string, body io.Reader) (*http.Response, string) { 106 | req, err := http.NewRequest(method, ts.URL+path, body) 107 | if err != nil { 108 | t.Fatal(err) 109 | return nil, "" 110 | } 111 | 112 | // http client that doesn't redirect 113 | httpClient := &http.Client{ 114 | CheckRedirect: func(req *http.Request, via []*http.Request) error { 115 | return http.ErrUseLastResponse 116 | 
}, 117 | } 118 | 119 | resp, err := httpClient.Do(req) 120 | if err != nil { 121 | t.Fatal(err) 122 | return nil, "" 123 | } 124 | 125 | respBody, err := io.ReadAll(resp.Body) 126 | if err != nil { 127 | t.Fatal(err) 128 | return nil, "" 129 | } 130 | defer resp.Body.Close() 131 | 132 | return resp, string(respBody) 133 | } 134 | 135 | func assertNoError(t *testing.T, err error) { 136 | t.Helper() 137 | if err != nil { 138 | t.Fatalf("expecting no error") 139 | } 140 | } 141 | 142 | func assertError(t *testing.T, err error) { 143 | t.Helper() 144 | if err == nil { 145 | t.Fatalf("expecting error") 146 | } 147 | } 148 | 149 | func assertEqual(t *testing.T, a, b interface{}) { 150 | t.Helper() 151 | if !reflect.DeepEqual(a, b) { 152 | t.Fatalf("expecting values to be equal but got: '%v' and '%v'", a, b) 153 | } 154 | } 155 | -------------------------------------------------------------------------------- /middleware/nocache.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | // Ported from Goji's middleware, source: 4 | // https://github.com/zenazn/goji/tree/master/web/middleware 5 | 6 | import ( 7 | "net/http" 8 | "time" 9 | ) 10 | 11 | // Unix epoch time 12 | var epoch = time.Unix(0, 0).UTC().Format(http.TimeFormat) 13 | 14 | // Taken from https://github.com/mytrile/nocache 15 | var noCacheHeaders = map[string]string{ 16 | "Expires": epoch, 17 | "Cache-Control": "no-cache, no-store, no-transform, must-revalidate, private, max-age=0", 18 | "Pragma": "no-cache", 19 | "X-Accel-Expires": "0", 20 | } 21 | 22 | var etagHeaders = []string{ 23 | "ETag", 24 | "If-Modified-Since", 25 | "If-Match", 26 | "If-None-Match", 27 | "If-Range", 28 | "If-Unmodified-Since", 29 | } 30 | 31 | // NoCache is a simple piece of middleware that sets a number of HTTP headers to prevent 32 | // a router (or subrouter) from being cached by an upstream proxy and/or client. 
33 | // 34 | // As per http://wiki.nginx.org/HttpProxyModule - NoCache sets: 35 | // 36 | // Expires: Thu, 01 Jan 1970 00:00:00 UTC 37 | // Cache-Control: no-cache, private, max-age=0 38 | // X-Accel-Expires: 0 39 | // Pragma: no-cache (for HTTP/1.0 proxies/clients) 40 | func NoCache(h http.Handler) http.Handler { 41 | fn := func(w http.ResponseWriter, r *http.Request) { 42 | 43 | // Delete any ETag headers that may have been set 44 | for _, v := range etagHeaders { 45 | if r.Header.Get(v) != "" { 46 | r.Header.Del(v) 47 | } 48 | } 49 | 50 | // Set our NoCache headers 51 | for k, v := range noCacheHeaders { 52 | w.Header().Set(k, v) 53 | } 54 | 55 | h.ServeHTTP(w, r) 56 | } 57 | 58 | return http.HandlerFunc(fn) 59 | } 60 | -------------------------------------------------------------------------------- /middleware/page_route.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "net/http" 5 | "strings" 6 | ) 7 | 8 | // PageRoute is a simple middleware which allows you to route a static GET request 9 | // at the middleware stack level. 10 | func PageRoute(path string, handler http.Handler) func(http.Handler) http.Handler { 11 | return func(next http.Handler) http.Handler { 12 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 13 | if r.Method == "GET" && strings.EqualFold(r.URL.Path, path) { 14 | handler.ServeHTTP(w, r) 15 | return 16 | } 17 | next.ServeHTTP(w, r) 18 | }) 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /middleware/path_rewrite.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "net/http" 5 | "strings" 6 | ) 7 | 8 | // PathRewrite is a simple middleware which allows you to rewrite the request URL path. 
9 | func PathRewrite(old, new string) func(http.Handler) http.Handler { 10 | return func(next http.Handler) http.Handler { 11 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 12 | r.URL.Path = strings.Replace(r.URL.Path, old, new, 1) 13 | next.ServeHTTP(w, r) 14 | }) 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /middleware/profiler.go: -------------------------------------------------------------------------------- 1 | //go:build !tinygo 2 | // +build !tinygo 3 | 4 | package middleware 5 | 6 | import ( 7 | "expvar" 8 | "net/http" 9 | "net/http/pprof" 10 | 11 | "github.com/go-chi/chi/v5" 12 | ) 13 | 14 | // Profiler is a convenient subrouter used for mounting net/http/pprof. ie. 15 | // 16 | // func MyService() http.Handler { 17 | // r := chi.NewRouter() 18 | // // ..middlewares 19 | // r.Mount("/debug", middleware.Profiler()) 20 | // // ..routes 21 | // return r 22 | // } 23 | func Profiler() http.Handler { 24 | r := chi.NewRouter() 25 | r.Use(NoCache) 26 | 27 | r.Get("/", func(w http.ResponseWriter, r *http.Request) { 28 | http.Redirect(w, r, r.RequestURI+"/pprof/", http.StatusMovedPermanently) 29 | }) 30 | r.HandleFunc("/pprof", func(w http.ResponseWriter, r *http.Request) { 31 | http.Redirect(w, r, r.RequestURI+"/", http.StatusMovedPermanently) 32 | }) 33 | 34 | r.HandleFunc("/pprof/*", pprof.Index) 35 | r.HandleFunc("/pprof/cmdline", pprof.Cmdline) 36 | r.HandleFunc("/pprof/profile", pprof.Profile) 37 | r.HandleFunc("/pprof/symbol", pprof.Symbol) 38 | r.HandleFunc("/pprof/trace", pprof.Trace) 39 | r.Handle("/vars", expvar.Handler()) 40 | 41 | r.Handle("/pprof/goroutine", pprof.Handler("goroutine")) 42 | r.Handle("/pprof/threadcreate", pprof.Handler("threadcreate")) 43 | r.Handle("/pprof/mutex", pprof.Handler("mutex")) 44 | r.Handle("/pprof/heap", pprof.Handler("heap")) 45 | r.Handle("/pprof/block", pprof.Handler("block")) 46 | r.Handle("/pprof/allocs", pprof.Handler("allocs")) 47 | 
48 | return r 49 | } 50 | -------------------------------------------------------------------------------- /middleware/realip.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | // Ported from Goji's middleware, source: 4 | // https://github.com/zenazn/goji/tree/master/web/middleware 5 | 6 | import ( 7 | "net" 8 | "net/http" 9 | "strings" 10 | ) 11 | 12 | var trueClientIP = http.CanonicalHeaderKey("True-Client-IP") 13 | var xForwardedFor = http.CanonicalHeaderKey("X-Forwarded-For") 14 | var xRealIP = http.CanonicalHeaderKey("X-Real-IP") 15 | 16 | // RealIP is a middleware that sets a http.Request's RemoteAddr to the results 17 | // of parsing either the True-Client-IP, X-Real-IP or the X-Forwarded-For headers 18 | // (in that order). 19 | // 20 | // This middleware should be inserted fairly early in the middleware stack to 21 | // ensure that subsequent layers (e.g., request loggers) which examine the 22 | // RemoteAddr will see the intended value. 23 | // 24 | // You should only use this middleware if you can trust the headers passed to 25 | // you (in particular, the three headers this middleware uses), for example 26 | // because you have placed a reverse proxy like HAProxy or nginx in front of 27 | // chi. If your reverse proxies are configured to pass along arbitrary header 28 | // values from the client, or if you use this middleware without a reverse 29 | // proxy, malicious clients will be able to make you very sad (or, depending on 30 | // how you're using RemoteAddr, vulnerable to an attack of some sort). 
31 | func RealIP(h http.Handler) http.Handler { 32 | fn := func(w http.ResponseWriter, r *http.Request) { 33 | if rip := realIP(r); rip != "" { 34 | r.RemoteAddr = rip 35 | } 36 | h.ServeHTTP(w, r) 37 | } 38 | 39 | return http.HandlerFunc(fn) 40 | } 41 | 42 | func realIP(r *http.Request) string { 43 | var ip string 44 | 45 | if tcip := r.Header.Get(trueClientIP); tcip != "" { 46 | ip = tcip 47 | } else if xrip := r.Header.Get(xRealIP); xrip != "" { 48 | ip = xrip 49 | } else if xff := r.Header.Get(xForwardedFor); xff != "" { 50 | ip, _, _ = strings.Cut(xff, ",") 51 | } 52 | if ip == "" || net.ParseIP(ip) == nil { 53 | return "" 54 | } 55 | return ip 56 | } 57 | -------------------------------------------------------------------------------- /middleware/realip_test.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "testing" 7 | 8 | "github.com/go-chi/chi/v5" 9 | ) 10 | 11 | func TestXRealIP(t *testing.T) { 12 | req, _ := http.NewRequest("GET", "/", nil) 13 | req.Header.Add("X-Real-IP", "100.100.100.100") 14 | w := httptest.NewRecorder() 15 | 16 | r := chi.NewRouter() 17 | r.Use(RealIP) 18 | 19 | realIP := "" 20 | r.Get("/", func(w http.ResponseWriter, r *http.Request) { 21 | realIP = r.RemoteAddr 22 | w.Write([]byte("Hello World")) 23 | }) 24 | r.ServeHTTP(w, req) 25 | 26 | if w.Code != 200 { 27 | t.Fatal("Response Code should be 200") 28 | } 29 | 30 | if realIP != "100.100.100.100" { 31 | t.Fatal("Test get real IP error.") 32 | } 33 | } 34 | 35 | func TestXForwardForIP(t *testing.T) { 36 | xForwardedForIPs := []string{ 37 | "100.100.100.100", 38 | "100.100.100.100, 200.200.200.200", 39 | "100.100.100.100,200.200.200.200", 40 | } 41 | 42 | r := chi.NewRouter() 43 | r.Use(RealIP) 44 | 45 | for _, v := range xForwardedForIPs { 46 | req, _ := http.NewRequest("GET", "/", nil) 47 | req.Header.Add("X-Forwarded-For", v) 48 | 49 | w := httptest.NewRecorder() 
50 | 51 | realIP := "" 52 | r.Get("/", func(w http.ResponseWriter, r *http.Request) { 53 | realIP = r.RemoteAddr 54 | w.Write([]byte("Hello World")) 55 | }) 56 | r.ServeHTTP(w, req) 57 | 58 | if w.Code != 200 { 59 | t.Fatal("Response Code should be 200") 60 | } 61 | 62 | if realIP != "100.100.100.100" { 63 | t.Fatal("Test get real IP error.") 64 | } 65 | } 66 | } 67 | 68 | func TestXForwardForXRealIPPrecedence(t *testing.T) { 69 | req, _ := http.NewRequest("GET", "/", nil) 70 | req.Header.Add("X-Forwarded-For", "0.0.0.0") 71 | req.Header.Add("X-Real-IP", "100.100.100.100") 72 | w := httptest.NewRecorder() 73 | 74 | r := chi.NewRouter() 75 | r.Use(RealIP) 76 | 77 | realIP := "" 78 | r.Get("/", func(w http.ResponseWriter, r *http.Request) { 79 | realIP = r.RemoteAddr 80 | w.Write([]byte("Hello World")) 81 | }) 82 | r.ServeHTTP(w, req) 83 | 84 | if w.Code != 200 { 85 | t.Fatal("Response Code should be 200") 86 | } 87 | 88 | if realIP != "100.100.100.100" { 89 | t.Fatal("Test get real IP precedence error.") 90 | } 91 | } 92 | 93 | func TestInvalidIP(t *testing.T) { 94 | req, _ := http.NewRequest("GET", "/", nil) 95 | req.Header.Add("X-Real-IP", "100.100.100.1000") 96 | w := httptest.NewRecorder() 97 | 98 | r := chi.NewRouter() 99 | r.Use(RealIP) 100 | 101 | realIP := "" 102 | r.Get("/", func(w http.ResponseWriter, r *http.Request) { 103 | realIP = r.RemoteAddr 104 | w.Write([]byte("Hello World")) 105 | }) 106 | r.ServeHTTP(w, req) 107 | 108 | if w.Code != 200 { 109 | t.Fatal("Response Code should be 200") 110 | } 111 | 112 | if realIP != "" { 113 | t.Fatal("Invalid IP used.") 114 | } 115 | } 116 | -------------------------------------------------------------------------------- /middleware/recoverer.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | // The original work was derived from Goji's middleware, source: 4 | // https://github.com/zenazn/goji/tree/master/web/middleware 5 | 6 | import ( 7 | "bytes" 
8 | "errors" 9 | "fmt" 10 | "io" 11 | "net/http" 12 | "os" 13 | "runtime/debug" 14 | "strings" 15 | ) 16 | 17 | // Recoverer is a middleware that recovers from panics, logs the panic (and a 18 | // backtrace), and returns a HTTP 500 (Internal Server Error) status if 19 | // possible. Recoverer prints a request ID if one is provided. 20 | // 21 | // Alternatively, look at https://github.com/go-chi/httplog middleware pkgs. 22 | func Recoverer(next http.Handler) http.Handler { 23 | fn := func(w http.ResponseWriter, r *http.Request) { 24 | defer func() { 25 | if rvr := recover(); rvr != nil { 26 | if rvr == http.ErrAbortHandler { 27 | // we don't recover http.ErrAbortHandler so the response 28 | // to the client is aborted, this should not be logged 29 | panic(rvr) 30 | } 31 | 32 | logEntry := GetLogEntry(r) 33 | if logEntry != nil { 34 | logEntry.Panic(rvr, debug.Stack()) 35 | } else { 36 | PrintPrettyStack(rvr) 37 | } 38 | 39 | if r.Header.Get("Connection") != "Upgrade" { 40 | w.WriteHeader(http.StatusInternalServerError) 41 | } 42 | } 43 | }() 44 | 45 | next.ServeHTTP(w, r) 46 | } 47 | 48 | return http.HandlerFunc(fn) 49 | } 50 | 51 | // for ability to test the PrintPrettyStack function 52 | var recovererErrorWriter io.Writer = os.Stderr 53 | 54 | func PrintPrettyStack(rvr interface{}) { 55 | debugStack := debug.Stack() 56 | s := prettyStack{} 57 | out, err := s.parse(debugStack, rvr) 58 | if err == nil { 59 | recovererErrorWriter.Write(out) 60 | } else { 61 | // print stdlib output as a fallback 62 | os.Stderr.Write(debugStack) 63 | } 64 | } 65 | 66 | type prettyStack struct { 67 | } 68 | 69 | func (s prettyStack) parse(debugStack []byte, rvr interface{}) ([]byte, error) { 70 | var err error 71 | useColor := true 72 | buf := &bytes.Buffer{} 73 | 74 | cW(buf, false, bRed, "\n") 75 | cW(buf, useColor, bCyan, " panic: ") 76 | cW(buf, useColor, bBlue, "%v", rvr) 77 | cW(buf, false, bWhite, "\n \n") 78 | 79 | // process debug stack info 80 | stack := 
strings.Split(string(debugStack), "\n") 81 | lines := []string{} 82 | 83 | // locate panic line, as we may have nested panics 84 | for i := len(stack) - 1; i > 0; i-- { 85 | lines = append(lines, stack[i]) 86 | if strings.HasPrefix(stack[i], "panic(") { 87 | lines = lines[0 : len(lines)-2] // remove boilerplate 88 | break 89 | } 90 | } 91 | 92 | // reverse 93 | for i := len(lines)/2 - 1; i >= 0; i-- { 94 | opp := len(lines) - 1 - i 95 | lines[i], lines[opp] = lines[opp], lines[i] 96 | } 97 | 98 | // decorate 99 | for i, line := range lines { 100 | lines[i], err = s.decorateLine(line, useColor, i) 101 | if err != nil { 102 | return nil, err 103 | } 104 | } 105 | 106 | for _, l := range lines { 107 | fmt.Fprintf(buf, "%s", l) 108 | } 109 | return buf.Bytes(), nil 110 | } 111 | 112 | func (s prettyStack) decorateLine(line string, useColor bool, num int) (string, error) { 113 | line = strings.TrimSpace(line) 114 | if strings.HasPrefix(line, "\t") || strings.Contains(line, ".go:") { 115 | return s.decorateSourceLine(line, useColor, num) 116 | } 117 | if strings.HasSuffix(line, ")") { 118 | return s.decorateFuncCallLine(line, useColor, num) 119 | } 120 | if strings.HasPrefix(line, "\t") { 121 | return strings.Replace(line, "\t", " ", 1), nil 122 | } 123 | return fmt.Sprintf(" %s\n", line), nil 124 | } 125 | 126 | func (s prettyStack) decorateFuncCallLine(line string, useColor bool, num int) (string, error) { 127 | idx := strings.LastIndex(line, "(") 128 | if idx < 0 { 129 | return "", errors.New("not a func call line") 130 | } 131 | 132 | buf := &bytes.Buffer{} 133 | pkg := line[0:idx] 134 | // addr := line[idx:] 135 | method := "" 136 | 137 | if idx := strings.LastIndex(pkg, string(os.PathSeparator)); idx < 0 { 138 | if idx := strings.Index(pkg, "."); idx > 0 { 139 | method = pkg[idx:] 140 | pkg = pkg[0:idx] 141 | } 142 | } else { 143 | method = pkg[idx+1:] 144 | pkg = pkg[0 : idx+1] 145 | if idx := strings.Index(method, "."); idx > 0 { 146 | pkg += method[0:idx] 147 | 
method = method[idx:] 148 | } 149 | } 150 | pkgColor := nYellow 151 | methodColor := bGreen 152 | 153 | if num == 0 { 154 | cW(buf, useColor, bRed, " -> ") 155 | pkgColor = bMagenta 156 | methodColor = bRed 157 | } else { 158 | cW(buf, useColor, bWhite, " ") 159 | } 160 | cW(buf, useColor, pkgColor, "%s", pkg) 161 | cW(buf, useColor, methodColor, "%s\n", method) 162 | // cW(buf, useColor, nBlack, "%s", addr) 163 | return buf.String(), nil 164 | } 165 | 166 | func (s prettyStack) decorateSourceLine(line string, useColor bool, num int) (string, error) { 167 | idx := strings.LastIndex(line, ".go:") 168 | if idx < 0 { 169 | return "", errors.New("not a source line") 170 | } 171 | 172 | buf := &bytes.Buffer{} 173 | path := line[0 : idx+3] 174 | lineno := line[idx+3:] 175 | 176 | idx = strings.LastIndex(path, string(os.PathSeparator)) 177 | dir := path[0 : idx+1] 178 | file := path[idx+1:] 179 | 180 | idx = strings.Index(lineno, " ") 181 | if idx > 0 { 182 | lineno = lineno[0:idx] 183 | } 184 | fileColor := bCyan 185 | lineColor := bGreen 186 | 187 | if num == 1 { 188 | cW(buf, useColor, bRed, " -> ") 189 | fileColor = bRed 190 | lineColor = bMagenta 191 | } else { 192 | cW(buf, false, bWhite, " ") 193 | } 194 | cW(buf, useColor, bWhite, "%s", dir) 195 | cW(buf, useColor, fileColor, "%s", file) 196 | cW(buf, useColor, lineColor, "%s", lineno) 197 | if num == 1 { 198 | cW(buf, false, bWhite, "\n") 199 | } 200 | cW(buf, false, bWhite, "\n") 201 | 202 | return buf.String(), nil 203 | } 204 | -------------------------------------------------------------------------------- /middleware/recoverer_test.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "bytes" 5 | "net/http" 6 | "net/http/httptest" 7 | "strings" 8 | "testing" 9 | 10 | "github.com/go-chi/chi/v5" 11 | ) 12 | 13 | func panickingHandler(http.ResponseWriter, *http.Request) { panic("foo") } 14 | 15 | func TestRecoverer(t *testing.T) { 16 | r 
:= chi.NewRouter() 17 | 18 | oldRecovererErrorWriter := recovererErrorWriter 19 | defer func() { recovererErrorWriter = oldRecovererErrorWriter }() 20 | buf := &bytes.Buffer{} 21 | recovererErrorWriter = buf 22 | 23 | r.Use(Recoverer) 24 | r.Get("/", panickingHandler) 25 | 26 | ts := httptest.NewServer(r) 27 | defer ts.Close() 28 | 29 | res, _ := testRequest(t, ts, "GET", "/", nil) 30 | assertEqual(t, res.StatusCode, http.StatusInternalServerError) 31 | 32 | lines := strings.Split(buf.String(), "\n") 33 | for _, line := range lines { 34 | if strings.HasPrefix(strings.TrimSpace(line), "->") { 35 | if !strings.Contains(line, "panickingHandler") { 36 | t.Fatalf("First func call line should refer to panickingHandler, but actual line:\n%v\n", line) 37 | } 38 | return 39 | } 40 | } 41 | t.Fatal("First func call line should start with ->.") 42 | } 43 | 44 | func TestRecovererAbortHandler(t *testing.T) { 45 | defer func() { 46 | rcv := recover() 47 | if rcv != http.ErrAbortHandler { 48 | t.Fatalf("http.ErrAbortHandler should not be recovered") 49 | } 50 | }() 51 | 52 | w := httptest.NewRecorder() 53 | 54 | r := chi.NewRouter() 55 | r.Use(Recoverer) 56 | 57 | r.Get("/", func(w http.ResponseWriter, r *http.Request) { 58 | panic(http.ErrAbortHandler) 59 | }) 60 | 61 | req, err := http.NewRequest("GET", "/", nil) 62 | if err != nil { 63 | t.Fatal(err) 64 | } 65 | 66 | r.ServeHTTP(w, req) 67 | } 68 | -------------------------------------------------------------------------------- /middleware/request_id.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | // Ported from Goji's middleware, source: 4 | // https://github.com/zenazn/goji/tree/master/web/middleware 5 | 6 | import ( 7 | "context" 8 | "crypto/rand" 9 | "encoding/base64" 10 | "fmt" 11 | "net/http" 12 | "os" 13 | "strings" 14 | "sync/atomic" 15 | ) 16 | 17 | // Key to use when setting the request ID. 
18 | type ctxKeyRequestID int 19 | 20 | // RequestIDKey is the key that holds the unique request ID in a request context. 21 | const RequestIDKey ctxKeyRequestID = 0 22 | 23 | // RequestIDHeader is the name of the HTTP Header which contains the request id. 24 | // Exported so that it can be changed by developers 25 | var RequestIDHeader = "X-Request-Id" 26 | 27 | var prefix string 28 | var reqid uint64 29 | 30 | // A quick note on the statistics here: we're trying to calculate the chance that 31 | // two randomly generated base62 prefixes will collide. We use the formula from 32 | // http://en.wikipedia.org/wiki/Birthday_problem 33 | // 34 | // P[m, n] \approx 1 - e^{-m^2/2n} 35 | // 36 | // We ballpark an upper bound for $m$ by imagining (for whatever reason) a server 37 | // that restarts every second over 10 years, for $m = 86400 * 365 * 10 = 315360000$ 38 | // 39 | // For a $k$ character base-62 identifier, we have $n(k) = 62^k$ 40 | // 41 | // Plugging this in, we find $P[m, n(10)] \approx 5.75%$, which is good enough for 42 | // our purposes, and is surely more than anyone would ever need in practice -- a 43 | // process that is rebooted a handful of times a day for a hundred years has less 44 | // than a millionth of a percent chance of generating two colliding IDs. 45 | 46 | func init() { 47 | hostname, err := os.Hostname() 48 | if hostname == "" || err != nil { 49 | hostname = "localhost" 50 | } 51 | var buf [12]byte 52 | var b64 string 53 | for len(b64) < 10 { 54 | rand.Read(buf[:]) 55 | b64 = base64.StdEncoding.EncodeToString(buf[:]) 56 | b64 = strings.NewReplacer("+", "", "/", "").Replace(b64) 57 | } 58 | 59 | prefix = fmt.Sprintf("%s/%s", hostname, b64[0:10]) 60 | } 61 | 62 | // RequestID is a middleware that injects a request ID into the context of each 63 | // request. 
A request ID is a string of the form "host.example.com/random-000001",
// where "random" is a base62 random string that uniquely identifies this go
// process, and where the last number is an atomically incremented request
// counter (zero-padded to at least six digits).
//
// When the incoming request already has a non-empty RequestIDHeader value,
// that value is stored verbatim instead of generating a new ID.
// NOTE(review): the client-supplied value is neither validated nor
// length-limited before it reaches logs and downstream handlers.
func RequestID(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		id := r.Header.Get(RequestIDHeader)
		if id == "" {
			id = fmt.Sprintf("%s-%06d", prefix, atomic.AddUint64(&reqid, 1))
		}
		scoped := context.WithValue(r.Context(), RequestIDKey, id)
		next.ServeHTTP(w, r.WithContext(scoped))
	})
}

// GetReqID returns the request ID stored in ctx by RequestID, or the empty
// string when ctx is nil or holds no string under RequestIDKey.
func GetReqID(ctx context.Context) string {
	if ctx != nil {
		if id, ok := ctx.Value(RequestIDKey).(string); ok {
			return id
		}
	}
	return ""
}

// NextRequestID generates the next request ID in the sequence. It returns
// only the numeric counter value (the same counter RequestID advances), not
// the formatted "prefix-NNNNNN" string.
94 | func NextRequestID() uint64 { 95 | return atomic.AddUint64(&reqid, 1) 96 | } 97 | -------------------------------------------------------------------------------- /middleware/request_id_test.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "net/http/httptest" 7 | "testing" 8 | 9 | "github.com/go-chi/chi/v5" 10 | ) 11 | 12 | func maintainDefaultRequestID() func() { 13 | original := RequestIDHeader 14 | 15 | return func() { 16 | RequestIDHeader = original 17 | } 18 | } 19 | 20 | func TestRequestID(t *testing.T) { 21 | tests := map[string]struct { 22 | requestIDHeader string 23 | request func() *http.Request 24 | expectedResponse string 25 | }{ 26 | "Retrieves Request Id from default header": { 27 | "X-Request-Id", 28 | func() *http.Request { 29 | req, _ := http.NewRequest("GET", "/", nil) 30 | req.Header.Add("X-Request-Id", "req-123456") 31 | 32 | return req 33 | }, 34 | "RequestID: req-123456", 35 | }, 36 | "Retrieves Request Id from custom header": { 37 | "X-Trace-Id", 38 | func() *http.Request { 39 | req, _ := http.NewRequest("GET", "/", nil) 40 | req.Header.Add("X-Trace-Id", "trace:abc123") 41 | 42 | return req 43 | }, 44 | "RequestID: trace:abc123", 45 | }, 46 | } 47 | 48 | defer maintainDefaultRequestID()() 49 | 50 | for _, test := range tests { 51 | w := httptest.NewRecorder() 52 | 53 | r := chi.NewRouter() 54 | 55 | RequestIDHeader = test.requestIDHeader 56 | 57 | r.Use(RequestID) 58 | 59 | r.Get("/", func(w http.ResponseWriter, r *http.Request) { 60 | requestID := GetReqID(r.Context()) 61 | response := fmt.Sprintf("RequestID: %s", requestID) 62 | 63 | w.Write([]byte(response)) 64 | }) 65 | r.ServeHTTP(w, test.request()) 66 | 67 | if w.Body.String() != test.expectedResponse { 68 | t.Fatalf("RequestID was not the expected value") 69 | } 70 | } 71 | } 72 | -------------------------------------------------------------------------------- 
/middleware/request_size.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "net/http" 5 | ) 6 | 7 | // RequestSize is a middleware that will limit request sizes to a specified 8 | // number of bytes. It uses MaxBytesReader to do so. 9 | func RequestSize(bytes int64) func(http.Handler) http.Handler { 10 | f := func(h http.Handler) http.Handler { 11 | fn := func(w http.ResponseWriter, r *http.Request) { 12 | r.Body = http.MaxBytesReader(w, r.Body, bytes) 13 | h.ServeHTTP(w, r) 14 | } 15 | return http.HandlerFunc(fn) 16 | } 17 | return f 18 | } 19 | -------------------------------------------------------------------------------- /middleware/route_headers.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "net/http" 5 | "strings" 6 | ) 7 | 8 | // RouteHeaders is a neat little header-based router that allows you to direct 9 | // the flow of a request through a middleware stack based on a request header. 10 | // 11 | // For example, lets say you'd like to setup multiple routers depending on the 12 | // request Host header, you could then do something as so: 13 | // 14 | // r := chi.NewRouter() 15 | // rSubdomain := chi.NewRouter() 16 | // r.Use(middleware.RouteHeaders(). 17 | // Route("Host", "example.com", middleware.New(r)). 18 | // Route("Host", "*.example.com", middleware.New(rSubdomain)). 19 | // Handler) 20 | // r.Get("/", h) 21 | // rSubdomain.Get("/", h2) 22 | // 23 | // Another example, imagine you want to setup multiple CORS handlers, where for 24 | // your origin servers you allow authorized requests, but for third-party public 25 | // requests, authorization is disabled. 26 | // 27 | // r := chi.NewRouter() 28 | // r.Use(middleware.RouteHeaders(). 
29 | // Route("Origin", "https://app.skyweaver.net", cors.Handler(cors.Options{ 30 | // AllowedOrigins: []string{"https://api.skyweaver.net"}, 31 | // AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, 32 | // AllowedHeaders: []string{"Accept", "Authorization", "Content-Type"}, 33 | // AllowCredentials: true, // <----------<<< allow credentials 34 | // })). 35 | // Route("Origin", "*", cors.Handler(cors.Options{ 36 | // AllowedOrigins: []string{"*"}, 37 | // AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, 38 | // AllowedHeaders: []string{"Accept", "Content-Type"}, 39 | // AllowCredentials: false, // <----------<<< do not allow credentials 40 | // })). 41 | // Handler) 42 | func RouteHeaders() HeaderRouter { // returns an empty router: chain Route/RouteAny/RouteDefault, then mount its Handler method 43 | return HeaderRouter{} 44 | } 45 | 46 | type HeaderRouter map[string][]HeaderRoute // lower-cased header name -> routes tried in registration order; key "*" holds the RouteDefault entry 47 | 48 | func (hr HeaderRouter) Route(header, match string, middlewareHandler func(next http.Handler) http.Handler) HeaderRouter { // header and match are compared case-insensitively; only the first '*' in match is a wildcard 49 | header = strings.ToLower(header) 50 | k := hr[header] 51 | if k == nil { 52 | hr[header] = []HeaderRoute{} 53 | } 54 | hr[header] = append(hr[header], HeaderRoute{MatchOne: NewPattern(strings.ToLower(match)), Middleware: middlewareHandler}) // lower-cased because Handler lower-cases the request value before matching 55 | return hr 56 | } 57 | 58 | func (hr HeaderRouter) RouteAny(header string, match []string, middlewareHandler func(next http.Handler) http.Handler) HeaderRouter { // like Route, but matches when any one of the patterns in match does 59 | header = strings.ToLower(header) 60 | k := hr[header] 61 | if k == nil { 62 | hr[header] = []HeaderRoute{} 63 | } 64 | patterns := []Pattern{} 65 | for _, m := range match { 66 | patterns = append(patterns, NewPattern(strings.ToLower(m))) // lower-cased for the same reason as in Route 67 | } 68 | hr[header] = append(hr[header], HeaderRoute{MatchAny: patterns, Middleware: middlewareHandler}) 69 | return hr 70 | } 71 | 72 | func (hr HeaderRouter) RouteDefault(handler func(next http.Handler) http.Handler) HeaderRouter { // sets, replacing any previous one, the middleware used when no header route matches 73 | hr["*"] = []HeaderRoute{{Middleware: handler}} 74 | return hr 75 | } 76 | 77 | func (hr HeaderRouter) Handler(next http.Handler) http.Handler { // the middleware itself: first matching route, else the default, else next unchanged 78 | return
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 79 | if len(hr) == 0 { 80 | next.ServeHTTP(w, r) // skip if no routes set 81 | return // without this, the "*" fallback below would call next a second time 82 | } 83 | 84 | // find first matching header route, and continue (map order is unspecified, so when several headers have routes, which one is tried first is not deterministic) 85 | for header, matchers := range hr { 86 | headerValue := r.Header.Get(header) 87 | if headerValue == "" { 88 | continue 89 | } 90 | headerValue = strings.ToLower(headerValue) 91 | for _, matcher := range matchers { 92 | if matcher.IsMatch(headerValue) { 93 | matcher.Middleware(next).ServeHTTP(w, r) 94 | return 95 | } 96 | } 97 | } 98 | 99 | // if no match, check for "*" default route 100 | matcher, ok := hr["*"] 101 | if !ok || matcher[0].Middleware == nil { 102 | next.ServeHTTP(w, r) 103 | return 104 | } 105 | matcher[0].Middleware(next).ServeHTTP(w, r) 106 | }) 107 | } 108 | 109 | type HeaderRoute struct { // one route: a middleware plus either a single pattern (MatchOne) or a set of them (MatchAny) 110 | Middleware func(next http.Handler) http.Handler 111 | MatchOne Pattern 112 | MatchAny []Pattern 113 | } 114 | 115 | func (r HeaderRoute) IsMatch(value string) bool { // checks MatchAny when it is non-empty, otherwise MatchOne 116 | if len(r.MatchAny) > 0 { 117 | for _, m := range r.MatchAny { 118 | if m.Match(value) { 119 | return true 120 | } 121 | } 122 | } else if r.MatchOne.Match(value) { 123 | return true 124 | } 125 | return false 126 | } 127 | 128 | type Pattern struct { // an exact value, or prefix*suffix when wildcard is set 129 | prefix string 130 | suffix string 131 | wildcard bool 132 | } 133 | 134 | func NewPattern(value string) Pattern { // splits value at its first '*'; without one the pattern matches value exactly (no case folding here) 135 | p := Pattern{} 136 | p.prefix, p.suffix, p.wildcard = strings.Cut(value, "*") 137 | return p 138 | } 139 | 140 | func (p Pattern) Match(v string) bool { // for wildcards, the length check keeps prefix and suffix from overlapping in v 141 | if !p.wildcard { 142 | return p.prefix == v 143 | } 144 | return len(v) >= len(p.prefix+p.suffix) && strings.HasPrefix(v, p.prefix) && strings.HasSuffix(v, p.suffix) 145 | } 146 | -------------------------------------------------------------------------------- /middleware/strip.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | 7 |
"github.com/go-chi/chi/v5" 8 | ) 9 | 10 | // StripSlashes is a middleware that will match request paths with a trailing 11 | // slash, strip it from the path and continue routing through the mux, if a route 12 | // matches, then it will serve the handler. 13 | func StripSlashes(next http.Handler) http.Handler { 14 | fn := func(w http.ResponseWriter, r *http.Request) { 15 | var path string 16 | rctx := chi.RouteContext(r.Context()) 17 | if rctx != nil && rctx.RoutePath != "" { 18 | path = rctx.RoutePath 19 | } else { 20 | path = r.URL.Path 21 | } 22 | if len(path) > 1 && path[len(path)-1] == '/' { // only a single trailing '/' is removed 23 | newPath := path[:len(path)-1] 24 | if rctx == nil { // no chi route context (plain http.Handler): rewrite the request URL itself 25 | r.URL.Path = newPath 26 | } else { 27 | rctx.RoutePath = newPath 28 | } 29 | } 30 | next.ServeHTTP(w, r) 31 | } 32 | return http.HandlerFunc(fn) 33 | } 34 | 35 | // RedirectSlashes is a middleware that will match request paths with a trailing 36 | // slash and redirect (301) to the same path, less the trailing slash, using a path-only URL. 37 | // 38 | // NOTE: RedirectSlashes middleware is *incompatible* with http.FileServer, 39 | // see https://github.com/go-chi/chi/issues/343 40 | func RedirectSlashes(next http.Handler) http.Handler { 41 | fn := func(w http.ResponseWriter, r *http.Request) { 42 | var path string 43 | rctx := chi.RouteContext(r.Context()) 44 | if rctx != nil && rctx.RoutePath != "" { 45 | path = rctx.RoutePath 46 | } else { 47 | path = r.URL.Path 48 | } 49 | if len(path) > 1 && path[len(path)-1] == '/' { 50 | // Redirect to a path-only URL on the current host. Never build it from r.Host: 51 | // that header is client-controlled, so "//"+r.Host is an open redirect. 52 | // Trim every leading/trailing '/' or '\' and re-add exactly one leading '/', 53 | // so "//evil.com/" or "/\evil.com/" cannot yield a protocol-relative Location. 54 | start, end := 0, len(path) 55 | for start < end && (path[start] == '/' || path[start] == '\\') { 56 | start++ 57 | } 58 | for end > start && (path[end-1] == '/' || path[end-1] == '\\') { 59 | end-- 60 | } 61 | redirectURL := "/" + path[start:end] 62 | if r.URL.RawQuery != "" { 63 | redirectURL = fmt.Sprintf("%s?%s", redirectURL, r.URL.RawQuery) 64 | } 65 | http.Redirect(w, r, redirectURL, http.StatusMovedPermanently) 66 | return 67 | } 68 | next.ServeHTTP(w, r) 69 | } 70 | return http.HandlerFunc(fn) 71 | } 72 | 73 | // StripPrefix is a middleware that will strip the provided prefix from the 74 | // request path before handing the request over to the next handler.
75 | func StripPrefix(prefix string) func(http.Handler) http.Handler { 76 | return func(next http.Handler) http.Handler { 77 | return http.StripPrefix(prefix, next) 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /middleware/strip_test.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "net/url" 7 | "strings" 8 | "testing" 9 | 10 | "github.com/go-chi/chi/v5" 11 | ) 12 | 13 | func TestStripSlashes(t *testing.T) { 14 | r := chi.NewRouter() 15 | 16 | // This middleware must be mounted at the top level of the router, not at the end-handler 17 | // because then it'll be too late and will end up in a 404 18 | r.Use(StripSlashes) 19 | 20 | r.NotFound(func(w http.ResponseWriter, r *http.Request) { 21 | w.WriteHeader(404) 22 | w.Write([]byte("nothing here")) 23 | }) 24 | 25 | r.Get("/", func(w http.ResponseWriter, r *http.Request) { 26 | w.Write([]byte("root")) 27 | }) 28 | 29 | r.Route("/accounts/{accountID}", func(r chi.Router) { 30 | r.Get("/", func(w http.ResponseWriter, r *http.Request) { 31 | accountID := chi.URLParam(r, "accountID") 32 | w.Write([]byte(accountID)) 33 | }) 34 | }) 35 | 36 | ts := httptest.NewServer(r) 37 | defer ts.Close() 38 | 39 | if _, resp := testRequest(t, ts, "GET", "/", nil); resp != "root" { 40 | t.Fatal(resp) 41 | } 42 | if _, resp := testRequest(t, ts, "GET", "//", nil); resp != "root" { 43 | t.Fatal(resp) 44 | } 45 | if _, resp := testRequest(t, ts, "GET", "/accounts/admin", nil); resp != "admin" { 46 | t.Fatal(resp) 47 | } 48 | if _, resp := testRequest(t, ts, "GET", "/accounts/admin/", nil); resp != "admin" { 49 | t.Fatal(resp) 50 | } 51 | if _, resp := testRequest(t, ts, "GET", "/nothing-here", nil); resp != "nothing here" { 52 | t.Fatal(resp) 53 | } 54 | } 55 | 56 | func TestStripSlashesInRoute(t *testing.T) { 57 | r := chi.NewRouter() 58 | 59 | r.NotFound(func(w
http.ResponseWriter, r *http.Request) { 60 | w.WriteHeader(404) 61 | w.Write([]byte("nothing here")) 62 | }) 63 | 64 | r.Get("/hi", func(w http.ResponseWriter, r *http.Request) { 65 | w.Write([]byte("hi")) 66 | }) 67 | 68 | r.Route("/accounts/{accountID}", func(r chi.Router) { 69 | r.Use(StripSlashes) 70 | r.Get("/", func(w http.ResponseWriter, r *http.Request) { 71 | w.Write([]byte("accounts index")) 72 | }) 73 | r.Get("/query", func(w http.ResponseWriter, r *http.Request) { 74 | accountID := chi.URLParam(r, "accountID") 75 | w.Write([]byte(accountID)) 76 | }) 77 | }) 78 | 79 | ts := httptest.NewServer(r) 80 | defer ts.Close() 81 | 82 | if _, resp := testRequest(t, ts, "GET", "/hi", nil); resp != "hi" { 83 | t.Fatal(resp) 84 | } 85 | if _, resp := testRequest(t, ts, "GET", "/hi/", nil); resp != "nothing here" { 86 | t.Fatal(resp) 87 | } 88 | if _, resp := testRequest(t, ts, "GET", "/accounts/admin", nil); resp != "accounts index" { 89 | t.Fatal(resp) 90 | } 91 | if _, resp := testRequest(t, ts, "GET", "/accounts/admin/", nil); resp != "accounts index" { 92 | t.Fatal(resp) 93 | } 94 | if _, resp := testRequest(t, ts, "GET", "/accounts/admin/query", nil); resp != "admin" { 95 | t.Fatal(resp) 96 | } 97 | if _, resp := testRequest(t, ts, "GET", "/accounts/admin/query/", nil); resp != "admin" { 98 | t.Fatal(resp) 99 | } 100 | } 101 | 102 | func TestRedirectSlashes(t *testing.T) { 103 | r := chi.NewRouter() 104 | 105 | // This middleware must be mounted at the top level of the router, not at the end-handler 106 | // because then it'll be too late and will end up in a 404 107 | r.Use(RedirectSlashes) 108 | 109 | r.NotFound(func(w http.ResponseWriter, r *http.Request) { 110 | w.WriteHeader(404) 111 | w.Write([]byte("nothing here")) 112 | }) 113 | 114 | r.Get("/", func(w http.ResponseWriter, r *http.Request) { 115 | w.Write([]byte("root")) 116 | }) 117 | 118 | r.Route("/accounts/{accountID}", func(r chi.Router) { 119 | r.Get("/", func(w http.ResponseWriter, r
*http.Request) { 120 | accountID := chi.URLParam(r, "accountID") 121 | w.Write([]byte(accountID)) 122 | }) 123 | }) 124 | 125 | ts := httptest.NewServer(r) 126 | defer ts.Close() 127 | 128 | if resp, body := testRequest(t, ts, "GET", "/", nil); body != "root" || resp.StatusCode != 200 { 129 | t.Fatal(body, resp.StatusCode) 130 | } 131 | 132 | // NOTE: the testRequest client will follow the redirection.. 133 | if resp, body := testRequest(t, ts, "GET", "//", nil); body != "root" || resp.StatusCode != 200 { 134 | t.Fatal(body, resp.StatusCode) 135 | } 136 | 137 | if resp, body := testRequest(t, ts, "GET", "/accounts/admin", nil); body != "admin" || resp.StatusCode != 200 { 138 | t.Fatal(body, resp.StatusCode) 139 | } 140 | 141 | // NOTE: the testRequest client will follow the redirection.. 142 | if resp, body := testRequest(t, ts, "GET", "/accounts/admin/", nil); body != "admin" || resp.StatusCode != 200 { 143 | t.Fatal(body, resp.StatusCode) 144 | } 145 | 146 | if resp, body := testRequest(t, ts, "GET", "/nothing-here", nil); body != "nothing here" || resp.StatusCode != 404 { 147 | t.Fatal(body, resp.StatusCode) 148 | } 149 | 150 | // Ensure redirect Location url is correct 151 | { 152 | resp, body := testRequestNoRedirect(t, ts, "GET", "/accounts/someuser/", nil) 153 | if resp.StatusCode != 301 { 154 | t.Fatal(body, resp.StatusCode) 155 | } 156 | location := resp.Header.Get("Location") 157 | if location != "/accounts/someuser" { 158 | t.Fatalf("invalid redirection, got %q, should be /accounts/someuser", location) 159 | } 160 | } 161 | 162 | // Ensure query params are kept intact upon redirecting a slash 163 | { 164 | resp, body := testRequestNoRedirect(t, ts, "GET", "/accounts/someuser/?a=1&b=2", nil) 165 | if resp.StatusCode != 301 { 166 | t.Fatal(body, resp.StatusCode) 167 | } 168 | location := resp.Header.Get("Location") 169 | if location != "/accounts/someuser?a=1&b=2" {
170 | t.Fatalf("invalid redirection, got %q, should be /accounts/someuser?a=1&b=2", location) 171 | } 172 | } 173 | 174 | // Ensure that we don't redirect to 'evil.com', but rather to '/evil.com' on the test server 175 | { 176 | paths := []string{"//evil.com/", "///evil.com/"} 177 | 178 | for _, p := range paths { 179 | resp, body := testRequest(t, ts, "GET", p, nil) 180 | if u, err := url.Parse(ts.URL); err == nil && resp.Request.URL.Host != u.Host { 181 | t.Fatalf("host should remain the same. got: %q, want: %q", resp.Request.URL.Host, u.Host) 182 | } 183 | if body != "nothing here" || resp.StatusCode != 404 { 184 | t.Fatal(body, resp.StatusCode) 185 | } 186 | } 187 | } 188 | 189 | // Ensure a spoofed Host header cannot turn the redirect into an open redirect 190 | { 191 | req := httptest.NewRequest("GET", "/accounts/someuser/", nil) 192 | req.Host = "evil.com" 193 | rec := httptest.NewRecorder() 194 | r.ServeHTTP(rec, req) 195 | if location := rec.Header().Get("Location"); rec.Code != 301 || strings.Contains(location, "evil.com") { 196 | t.Fatalf("got %d %q, want a 301 to a path on the original host", rec.Code, location) 197 | } 198 | } 199 | } 200 | 201 | // This tests a http.Handler that is not chi.Router 202 | // In these cases, the routeContext is nil 203 | func TestStripSlashesWithNilContext(t *testing.T) { 204 | r := http.NewServeMux() 205 | 206 | r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 207 | w.Write([]byte("root")) 208 | }) 209 | 210 | r.HandleFunc("/accounts", func(w http.ResponseWriter, r *http.Request) { 211 | w.Write([]byte("accounts")) 212 | }) 213 | 214 | r.HandleFunc("/accounts/admin", func(w http.ResponseWriter, r *http.Request) { 215 | w.Write([]byte("admin")) 216 | }) 217 | 218 | ts := httptest.NewServer(StripSlashes(r)) 219 | defer ts.Close() 220 | 221 | if _, resp := testRequest(t, ts, "GET", "/", nil); resp != "root" { 222 | t.Fatal(resp) 223 | } 224 | if _, resp := testRequest(t, ts, "GET", "//", nil); resp !=
"root" { 225 | t.Fatal(resp) 226 | } 227 | if _, resp := testRequest(t, ts, "GET", "/accounts", nil); resp != "accounts" { 228 | t.Fatal(resp) 229 | } 230 | if _, resp := testRequest(t, ts, "GET", "/accounts/", nil); resp != "accounts" { 231 | t.Fatal(resp) 232 | } 233 | if _, resp := testRequest(t, ts, "GET", "/accounts/admin", nil); resp != "admin" { 234 | t.Fatal(resp) 235 | } 236 | if _, resp := testRequest(t, ts, "GET", "/accounts/admin/", nil); resp != "admin" { 237 | t.Fatal(resp) 238 | } 239 | } 240 | 241 | func TestStripPrefix(t *testing.T) { 242 | r := chi.NewRouter() 243 | 244 | r.Use(StripPrefix("/api")) 245 | 246 | r.Get("/", func(w http.ResponseWriter, r *http.Request) { 247 | w.Write([]byte("api root")) 248 | }) 249 | 250 | r.Get("/accounts", func(w http.ResponseWriter, r *http.Request) { 251 | w.Write([]byte("api accounts")) 252 | }) 253 | 254 | r.Get("/accounts/{accountID}", func(w http.ResponseWriter, r *http.Request) { 255 | accountID := chi.URLParam(r, "accountID") 256 | w.Write([]byte(accountID)) 257 | }) 258 | 259 | ts := httptest.NewServer(r) 260 | defer ts.Close() 261 | 262 | if _, resp := testRequest(t, ts, "GET", "/api/", nil); resp != "api root" { 263 | t.Fatalf("got: %q, want: %q", resp, "api root") 264 | } 265 | if _, resp := testRequest(t, ts, "GET", "/api/accounts", nil); resp != "api accounts" { 266 | t.Fatalf("got: %q, want: %q", resp, "api accounts") 267 | } 268 | if _, resp := testRequest(t, ts, "GET", "/api/accounts/admin", nil); resp != "admin" { 269 | t.Fatalf("got: %q, want: %q", resp, "admin") 270 | } 271 | if _, resp := testRequest(t, ts, "GET", "/api-nope/", nil); resp != "404 page not found\n" { 272 | t.Fatalf("got: %q, want: %q", resp, "404 page not found\n") 273 | } 274 | } 275 | -------------------------------------------------------------------------------- /middleware/sunset.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "net/http" 5 |
"time" 6 | ) 7 | 8 | // Sunset sets the Sunset and Deprecation headers (both to sunsetAt) on the response. 9 | // This can be used to enable Sunset in a route or a route group 10 | // For more: https://www.rfc-editor.org/rfc/rfc8594.html 11 | func Sunset(sunsetAt time.Time, links ...string) func(http.Handler) http.Handler { 12 | return func(next http.Handler) http.Handler { 13 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 14 | if !sunsetAt.IsZero() { // a zero sunsetAt sets no Sunset, Deprecation or Link headers 15 | w.Header().Set("Sunset", sunsetAt.UTC().Format(http.TimeFormat)) // http.TimeFormat expects UTC; a non-UTC time would print local wall-clock time labeled "GMT" 16 | w.Header().Set("Deprecation", sunsetAt.UTC().Format(http.TimeFormat)) // NOTE(review): RFC 9745 defines Deprecation as a structured date ("@<unix-seconds>"); confirm this HTTP-date form is intended 17 | 18 | for _, link := range links { 19 | w.Header().Add("Link", link) 20 | } 21 | } 22 | next.ServeHTTP(w, r) 23 | }) 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /middleware/sunset_test.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "testing" 7 | "time" 8 | 9 | "github.com/go-chi/chi/v5" 10 | ) 11 | 12 | func TestSunset(t *testing.T) { 13 | 14 | t.Run("Sunset without link", func(t *testing.T) { 15 | req, _ := http.NewRequest("GET", "/", nil) 16 | w := httptest.NewRecorder() 17 | 18 | r := chi.NewRouter() 19 | 20 | sunsetAt := time.Date(2025, 12, 24, 10, 20, 0, 0, time.UTC) 21 | r.Use(Sunset(sunsetAt)) 22 | 23 | var sunset, deprecation string 24 | r.Get("/", func(w http.ResponseWriter, r *http.Request) { 25 | clonedHeader := w.Header().Clone() 26 | sunset = clonedHeader.Get("Sunset") 27 | deprecation = clonedHeader.Get("Deprecation") 28 | w.Write([]byte("I'll be unavailable soon")) 29 | }) 30 | r.ServeHTTP(w, req) 31 | 32 | if w.Code != 200 { 33 | t.Fatal("Response Code should be 200") 34 | } 35 | 36 | if sunset != "Wed, 24 Dec 2025 10:20:00 GMT" { 37 | t.Fatal("Test get sunset error.", sunset) 38 | } 39 | 40 | if deprecation != "Wed, 24 Dec 2025 10:20:00 GMT" { 41 | t.Fatal("Test get deprecation error.") 42 | } 43 | })
44 | 45 | t.Run("Sunset with link", func(t *testing.T) { 46 | req, _ := http.NewRequest("GET", "/", nil) 47 | w := httptest.NewRecorder() 48 | 49 | r := chi.NewRouter() 50 | 51 | sunsetAt := time.Date(2025, 12, 24, 11, 20, 0, 0, time.FixedZone("CET", 3600)) // same instant as 10:20 UTC in a non-UTC zone: headers must still be rendered in GMT 52 | deprecationLink := "https://example.com/v1/deprecation-details" 53 | r.Use(Sunset(sunsetAt, deprecationLink)) 54 | 55 | var sunset, deprecation, link string 56 | r.Get("/", func(w http.ResponseWriter, r *http.Request) { 57 | clonedHeader := w.Header().Clone() 58 | sunset = clonedHeader.Get("Sunset") 59 | deprecation = clonedHeader.Get("Deprecation") 60 | link = clonedHeader.Get("Link") 61 | 62 | w.Write([]byte("I'll be unavailable soon")) 63 | }) 64 | 65 | r.ServeHTTP(w, req) 66 | 67 | if w.Code != 200 { 68 | t.Fatal("Response Code should be 200") 69 | } 70 | 71 | if sunset != "Wed, 24 Dec 2025 10:20:00 GMT" { 72 | t.Fatal("Test get sunset error.", sunset) 73 | } 74 | 75 | if deprecation != "Wed, 24 Dec 2025 10:20:00 GMT" { 76 | t.Fatal("Test get deprecation error.") 77 | } 78 | 79 | if link != deprecationLink { 80 | t.Fatal("Test get deprecation link error.") 81 | } 82 | }) 83 | 84 | } 85 | 86 | /** 87 | EXAMPLE USAGES 88 | func main() { 89 | r := chi.NewRouter() 90 | 91 | sunsetAt := time.Date(2025, 12, 24, 10, 20, 0, 0, time.UTC) 92 | r.Use(middleware.Sunset(sunsetAt)) 93 | 94 | // can provide additional link for updated resource 95 | // r.Use(middleware.Sunset(sunsetAt, "https://example.com/v1/deprecation-details")) 96 | 97 | r.Get("/", func(w http.ResponseWriter, r *http.Request) { 98 | w.Write([]byte("This endpoint will be removed soon")) 99 | }) 100 | 101 | log.Println("Listening on port: 3000") 102 | http.ListenAndServe(":3000", r) 103 | } 104 | **/ 105 | -------------------------------------------------------------------------------- /middleware/supress_notfound.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "net/http" 5 | 6 |
"github.com/go-chi/chi/v5" 7 | ) 8 | 9 | // SupressNotFound will quickly respond with a 404 if the route is not found 10 | // and will not continue to the next middleware handler. 11 | // 12 | // This is handy to put at the top of your middleware stack to avoid unnecessary 13 | // processing of requests that are not going to match any routes anyway. For 14 | // example it's super annoying to see a bunch of 404's in your logs from bots. 15 | func SupressNotFound(router *chi.Mux) func(next http.Handler) http.Handler { 16 | return func(next http.Handler) http.Handler { 17 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 18 | rctx := chi.RouteContext(r.Context()) 19 | if rctx == nil || rctx.Routes == nil { 20 | // Not inside a chi routing context (e.g. this middleware wraps the mux from 21 | // outside), so there is no live context to reuse: match router with a fresh one. 22 | rctx = chi.NewRouteContext() 23 | rctx.Routes = router 24 | } 25 | match := rctx.Routes.Match(rctx, r.Method, r.URL.Path) 26 | if !match { 27 | router.NotFoundHandler().ServeHTTP(w, r) 28 | return 29 | } 30 | next.ServeHTTP(w, r) 31 | }) 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /middleware/terminal.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | // Ported from Goji's middleware, source: 4 | // https://github.com/zenazn/goji/tree/master/web/middleware 5 | 6 | import ( 7 | "fmt" 8 | "io" 9 | "os" 10 | ) 11 | 12 | var ( 13 | // Normal colors 14 | nBlack = []byte{'\033', '[', '3', '0', 'm'} 15 | nRed = []byte{'\033', '[', '3', '1', 'm'} 16 | nGreen = []byte{'\033', '[', '3', '2', 'm'} 17 | nYellow = []byte{'\033', '[', '3', '3', 'm'} 18 | nBlue = []byte{'\033', '[', '3', '4', 'm'} 19 | nMagenta = []byte{'\033', '[', '3', '5', 'm'} 20 | nCyan = []byte{'\033', '[', '3', '6', 'm'} 21 | nWhite = []byte{'\033', '[', '3', '7', 'm'} 22 | // Bright colors 23 | bBlack = []byte{'\033', '[', '3', '0', ';', '1', 'm'} 24 | bRed = []byte{'\033', '[', '3', '1', ';', '1', 'm'} 25 | bGreen = []byte{'\033', '[', '3', '2', ';', '1', 'm'} 26 | bYellow = []byte{'\033', '[', '3', '3', ';', '1', 'm'} 27 | bBlue = []byte{'\033', '[', '3', '4', ';',
'1', 'm'} 28 | bMagenta = []byte{'\033', '[', '3', '5', ';', '1', 'm'} 29 | bCyan = []byte{'\033', '[', '3', '6', ';', '1', 'm'} 30 | bWhite = []byte{'\033', '[', '3', '7', ';', '1', 'm'} 31 | 32 | reset = []byte{'\033', '[', '0', 'm'} 33 | ) 34 | 35 | var IsTTY bool // set by init when stdout appears to be a character device; cW emits color codes only when true 36 | 37 | func init() { 38 | // This is sort of cheating: if stdout is a character device, we assume 39 | // that means it's a TTY. Unfortunately, there are many non-TTY 40 | // character devices, but fortunately stdout is rarely set to any of 41 | // them. 42 | // 43 | // We could solve this properly by pulling in a dependency on 44 | // code.google.com/p/go.crypto/ssh/terminal, for instance, but as a 45 | // heuristic for whether to print in color or in black-and-white, I'd 46 | // really rather not. 47 | fi, err := os.Stdout.Stat() 48 | if err == nil { 49 | m := os.ModeDevice | os.ModeCharDevice 50 | IsTTY = fi.Mode()&m == m 51 | } 52 | } 53 | 54 | // cW ("colorWrite") formats s with args into w, wrapped in the color and reset escape codes when IsTTY and useColor are both true 55 | func cW(w io.Writer, useColor bool, color []byte, s string, args ...interface{}) { 56 | if IsTTY && useColor { 57 | w.Write(color) 58 | } 59 | fmt.Fprintf(w, s, args...) 60 | if IsTTY && useColor { 61 | w.Write(reset) 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /middleware/throttle.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "net/http" 5 | "strconv" 6 | "time" 7 | ) 8 | 9 | const ( 10 | errCapacityExceeded = "Server capacity exceeded." 11 | errTimedOut = "Timed out while waiting for a pending request to complete." 12 | errContextCanceled = "Context was canceled." 13 | ) 14 | 15 | var ( 16 | defaultBacklogTimeout = time.Second * 60 17 | ) 18 | 19 | // ThrottleOpts represents a set of throttling options.
20 | type ThrottleOpts struct { 21 | RetryAfterFn func(ctxDone bool) time.Duration // optional; when set, its result (truncated to whole seconds) is sent as Retry-After on rejected requests 22 | Limit int // max requests processed at once; must be > 0 23 | BacklogLimit int // extra requests that may wait for a free slot; must be >= 0 24 | BacklogTimeout time.Duration // how long a backlogged request waits for a slot; zero means effectively no waiting 25 | StatusCode int // status for rejected requests; 0 means http.StatusTooManyRequests 26 | } 27 | 28 | // Throttle is a middleware that limits the number of currently processed requests 29 | // at a time across all users. Note: Throttle is not a rate-limiter per user, 30 | // instead it just puts a ceiling on the number of current in-flight requests 31 | // being processed from the point from where the Throttle middleware is mounted. 32 | func Throttle(limit int) func(http.Handler) http.Handler { 33 | return ThrottleWithOpts(ThrottleOpts{Limit: limit, BacklogTimeout: defaultBacklogTimeout}) 34 | } 35 | 36 | // ThrottleBacklog is a middleware that limits the number of currently processed 37 | // requests at a time and provides a backlog for holding a finite number of 38 | // pending requests. 39 | func ThrottleBacklog(limit, backlogLimit int, backlogTimeout time.Duration) func(http.Handler) http.Handler { 40 | return ThrottleWithOpts(ThrottleOpts{Limit: limit, BacklogLimit: backlogLimit, BacklogTimeout: backlogTimeout}) 41 | } 42 | 43 | // ThrottleWithOpts is a middleware that limits number of currently processed requests using passed ThrottleOpts. 44 | func ThrottleWithOpts(opts ThrottleOpts) func(http.Handler) http.Handler { 45 | if opts.Limit < 1 { 46 | panic("chi/middleware: Throttle expects limit > 0") 47 | } 48 | 49 | if opts.BacklogLimit < 0 { 50 | panic("chi/middleware: Throttle expects backlogLimit to be non-negative") 51 | } 52 | 53 | statusCode := opts.StatusCode 54 | if statusCode == 0 { 55 | statusCode = http.StatusTooManyRequests 56 | } 57 | 58 | t := throttler{ 59 | tokens: make(chan token, opts.Limit), 60 | backlogTokens: make(chan token, opts.Limit+opts.BacklogLimit), 61 | backlogTimeout: opts.BacklogTimeout, 62 | statusCode: statusCode, 63 | retryAfterFn: opts.RetryAfterFn, 64 | } 65 | 66 | // Filling tokens: every admitted request holds a backlog token; at most Limit of them also hold a processing token.
67 | for i := 0; i < opts.Limit+opts.BacklogLimit; i++ { 68 | if i < opts.Limit { 69 | t.tokens <- token{} 70 | } 71 | t.backlogTokens <- token{} 72 | } 73 | 74 | return func(next http.Handler) http.Handler { 75 | fn := func(w http.ResponseWriter, r *http.Request) { 76 | ctx := r.Context() 77 | 78 | select { 79 | 80 | case <-ctx.Done(): 81 | t.setRetryAfterHeaderIfNeeded(w, true) 82 | http.Error(w, errContextCanceled, t.statusCode) 83 | return 84 | 85 | case btok := <-t.backlogTokens: 86 | defer func() { 87 | t.backlogTokens <- btok 88 | }() 89 | 90 | // Fast path: take a processing slot that is free right now, before arming the 91 | // timer. Otherwise a zero/short BacklogTimeout (the ThrottleOpts zero value) races 92 | // timer.C against an available token and can reject requests despite free capacity. 93 | select { 94 | case tok := <-t.tokens: 95 | defer func() { 96 | t.tokens <- tok 97 | }() 98 | next.ServeHTTP(w, r) 99 | return 100 | default: 101 | } 102 | 103 | timer := time.NewTimer(t.backlogTimeout) // no slot free yet: wait in the backlog for up to backlogTimeout 104 | 105 | select { 106 | case <-timer.C: 107 | t.setRetryAfterHeaderIfNeeded(w, false) 108 | http.Error(w, errTimedOut, t.statusCode) 109 | return 110 | case <-ctx.Done(): 111 | timer.Stop() 112 | t.setRetryAfterHeaderIfNeeded(w, true) 113 | http.Error(w, errContextCanceled, t.statusCode) 114 | return 115 | case tok := <-t.tokens: 116 | defer func() { 117 | timer.Stop() 118 | t.tokens <- tok 119 | }() 120 | next.ServeHTTP(w, r) 121 | } 122 | return 123 | 124 | default: // backlog is full: reject immediately 125 | t.setRetryAfterHeaderIfNeeded(w, false) 126 | http.Error(w, errCapacityExceeded, t.statusCode) 127 | return 128 | } 129 | } 130 | 131 | return http.HandlerFunc(fn) 132 | } 133 | } 134 | 135 | // token represents a request that is being processed. 136 | type token struct{} 137 | 138 | // throttler limits number of currently processed requests at a time. 139 | type throttler struct { 140 | tokens chan token 141 | backlogTokens chan token 142 | retryAfterFn func(ctxDone bool) time.Duration 143 | backlogTimeout time.Duration 144 | statusCode int 145 | } 146 | 147 | // setRetryAfterHeaderIfNeeded sets Retry-After HTTP header if corresponding retryAfterFn option of throttler is initialized.
148 | func (t throttler) setRetryAfterHeaderIfNeeded(w http.ResponseWriter, ctxDone bool) { 149 | if t.retryAfterFn == nil { 150 | return 151 | } 152 | w.Header().Set("Retry-After", strconv.Itoa(int(t.retryAfterFn(ctxDone).Seconds()))) 153 | } 154 | -------------------------------------------------------------------------------- /middleware/throttle_test.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "io/ioutil" 5 | "net/http" 6 | "net/http/httptest" 7 | "strings" 8 | "sync" 9 | "testing" 10 | "time" 11 | 12 | "github.com/go-chi/chi/v5" 13 | ) 14 | 15 | var testContent = []byte("Hello world!") 16 | 17 | func TestThrottleBacklog(t *testing.T) { 18 | r := chi.NewRouter() 19 | 20 | r.Use(ThrottleBacklog(10, 50, time.Second*10)) 21 | 22 | r.Get("/", func(w http.ResponseWriter, r *http.Request) { 23 | w.WriteHeader(http.StatusOK) 24 | time.Sleep(time.Second * 1) // Expensive operation. 25 | w.Write(testContent) 26 | }) 27 | 28 | server := httptest.NewServer(r) 29 | defer server.Close() 30 | 31 | client := http.Client{ 32 | Timeout: time.Second * 5, // Maximum waiting time. 33 | } 34 | 35 | var wg sync.WaitGroup 36 | 37 | // The throttler processes 10 consecutive requests, each one of those 38 | // requests lasts 1s. The maximum number of requests this can possibly serve 39 | // before the clients time out (5s) is 40.
40 | for i := 0; i < 40; i++ { 41 | wg.Add(1) 42 | go func(i int) { 43 | defer wg.Done() 44 | 45 | res, err := client.Get(server.URL) 46 | assertNoError(t, err) 47 | 48 | assertEqual(t, http.StatusOK, res.StatusCode) 49 | buf, err := ioutil.ReadAll(res.Body) 50 | assertNoError(t, err) 51 | assertEqual(t, testContent, buf) 52 | }(i) 53 | } 54 | 55 | wg.Wait() 56 | } 57 | 58 | func TestThrottleClientTimeout(t *testing.T) { 59 | r := chi.NewRouter() 60 | 61 | r.Use(ThrottleBacklog(10, 50, time.Second*10)) 62 | 63 | r.Get("/", func(w http.ResponseWriter, r *http.Request) { 64 | w.WriteHeader(http.StatusOK) 65 | time.Sleep(time.Second * 5) // Expensive operation. 66 | w.Write(testContent) 67 | }) 68 | 69 | server := httptest.NewServer(r) 70 | defer server.Close() 71 | 72 | client := http.Client{ 73 | Timeout: time.Second * 3, // Maximum waiting time. 74 | } 75 | 76 | var wg sync.WaitGroup 77 | 78 | for i := 0; i < 10; i++ { 79 | wg.Add(1) 80 | go func(i int) { 81 | defer wg.Done() 82 | _, err := client.Get(server.URL) 83 | assertError(t, err) 84 | }(i) 85 | } 86 | 87 | wg.Wait() 88 | } 89 | 90 | func TestThrottleTriggerGatewayTimeout(t *testing.T) { // despite the name, the second batch expects 429 with errTimedOut (the default rejection status) 91 | r := chi.NewRouter() 92 | 93 | r.Use(ThrottleBacklog(50, 100, time.Second*5)) 94 | 95 | r.Get("/", func(w http.ResponseWriter, r *http.Request) { 96 | w.WriteHeader(http.StatusOK) 97 | time.Sleep(time.Second * 10) // Expensive operation. 98 | w.Write(testContent) 99 | }) 100 | 101 | server := httptest.NewServer(r) 102 | defer server.Close() 103 | 104 | client := http.Client{ 105 | Timeout: time.Second * 60, // Maximum waiting time. 106 | } 107 | 108 | var wg sync.WaitGroup 109 | 110 | // These requests will be processed normally until they finish.
111 | for i := 0; i < 50; i++ { 112 | wg.Add(1) 113 | go func(i int) { 114 | defer wg.Done() 115 | 116 | res, err := client.Get(server.URL) 117 | assertNoError(t, err) 118 | assertEqual(t, http.StatusOK, res.StatusCode) 119 | }(i) 120 | } 121 | 122 | time.Sleep(time.Second * 1) 123 | 124 | // These requests will wait for the first batch to complete but it will take 125 | // too much time, so they will eventually receive a timeout error. 126 | for i := 0; i < 50; i++ { 127 | wg.Add(1) 128 | go func(i int) { 129 | defer wg.Done() 130 | 131 | res, err := client.Get(server.URL) 132 | assertNoError(t, err) 133 | 134 | buf, err := ioutil.ReadAll(res.Body) 135 | assertNoError(t, err) 136 | assertEqual(t, http.StatusTooManyRequests, res.StatusCode) 137 | assertEqual(t, errTimedOut, strings.TrimSpace(string(buf))) 138 | }(i) 139 | } 140 | 141 | wg.Wait() 142 | } 143 | 144 | func TestThrottleMaximum(t *testing.T) { 145 | r := chi.NewRouter() 146 | 147 | r.Use(ThrottleBacklog(10, 10, time.Second*5)) 148 | 149 | r.Get("/", func(w http.ResponseWriter, r *http.Request) { 150 | w.WriteHeader(http.StatusOK) 151 | time.Sleep(time.Second * 3) // Expensive operation. 152 | w.Write(testContent) 153 | }) 154 | 155 | server := httptest.NewServer(r) 156 | defer server.Close() 157 | 158 | client := http.Client{ 159 | Timeout: time.Second * 60, // Maximum waiting time. 160 | } 161 | 162 | var wg sync.WaitGroup 163 | 164 | for i := 0; i < 20; i++ { 165 | wg.Add(1) 166 | go func(i int) { 167 | defer wg.Done() 168 | 169 | res, err := client.Get(server.URL) 170 | assertNoError(t, err) 171 | assertEqual(t, http.StatusOK, res.StatusCode) 172 | 173 | buf, err := ioutil.ReadAll(res.Body) 174 | assertNoError(t, err) 175 | assertEqual(t, testContent, buf) 176 | }(i) 177 | } 178 | 179 | // Wait less time than what the server takes to reply.
180 | time.Sleep(time.Second * 2) 181 | 182 | // At this point the server is still processing, all the following requests 183 | // will be beyond the server capacity. 184 | for i := 0; i < 20; i++ { 185 | wg.Add(1) 186 | go func(i int) { 187 | defer wg.Done() 188 | 189 | res, err := client.Get(server.URL) 190 | assertNoError(t, err) 191 | 192 | buf, err := ioutil.ReadAll(res.Body) 193 | assertNoError(t, err) 194 | assertEqual(t, http.StatusTooManyRequests, res.StatusCode) 195 | assertEqual(t, errCapacityExceeded, strings.TrimSpace(string(buf))) 196 | }(i) 197 | } 198 | 199 | wg.Wait() 200 | } 201 | 202 | // NOTE: test is disabled as it requires some refactoring. It is prone to intermittent failure. NOTE(review): it builds ThrottleOpts without a BacklogTimeout, i.e. a zero timer that used to race free tokens in ThrottleWithOpts; that race may have been the flakiness — consider re-enabling. 203 | /*func TestThrottleRetryAfter(t *testing.T) { 204 | r := chi.NewRouter() 205 | 206 | retryAfterFn := func(ctxDone bool) time.Duration { return time.Hour * 1 } 207 | r.Use(ThrottleWithOpts(ThrottleOpts{Limit: 10, RetryAfterFn: retryAfterFn})) 208 | 209 | r.Get("/", func(w http.ResponseWriter, r *http.Request) { 210 | w.WriteHeader(http.StatusOK) 211 | time.Sleep(time.Second * 4) // Expensive operation. 212 | w.Write(testContent) 213 | }) 214 | 215 | server := httptest.NewServer(r) 216 | defer server.Close() 217 | 218 | client := http.Client{ 219 | Timeout: time.Second * 60, // Maximum waiting time.
220 | } 221 | 222 | var wg sync.WaitGroup 223 | 224 | for i := 0; i < 10; i++ { 225 | wg.Add(1) 226 | go func(i int) { 227 | defer wg.Done() 228 | 229 | res, err := client.Get(server.URL) 230 | assertNoError(t, err) 231 | assertEqual(t, http.StatusOK, res.StatusCode) 232 | }(i) 233 | } 234 | 235 | time.Sleep(time.Second * 1) 236 | 237 | for i := 0; i < 10; i++ { 238 | wg.Add(1) 239 | go func(i int) { 240 | defer wg.Done() 241 | 242 | res, err := client.Get(server.URL) 243 | assertNoError(t, err) 244 | assertEqual(t, http.StatusTooManyRequests, res.StatusCode) 245 | assertEqual(t, res.Header.Get("Retry-After"), "3600") 246 | }(i) 247 | } 248 | 249 | wg.Wait() 250 | }*/ 251 | 252 | func TestThrottleCustomStatusCode(t *testing.T) { // With Limit 1 and a custom StatusCode, 4 of 5 concurrent requests get 503; the one held in the handler gets 200 once released. 253 | const timeout = time.Second * 3 254 | 255 | wait := make(chan struct{}) 256 | 257 | r := chi.NewRouter() 258 | r.Use(ThrottleWithOpts(ThrottleOpts{Limit: 1, StatusCode: http.StatusServiceUnavailable})) 259 | r.Get("/", func(w http.ResponseWriter, r *http.Request) { 260 | select { 261 | case <-wait: 262 | case <-time.After(timeout): 263 | } 264 | w.WriteHeader(http.StatusOK) 265 | }) 266 | server := httptest.NewServer(r) 267 | defer server.Close() 268 | 269 | const totalRequestCount = 5 270 | 271 | codes := make(chan int, totalRequestCount) 272 | errs := make(chan error, totalRequestCount) 273 | client := &http.Client{Timeout: timeout} 274 | for i := 0; i < totalRequestCount; i++ { 275 | go func() { 276 | if resp, err := client.Get(server.URL); err != nil { 277 | errs <- err 278 | } else { 279 | resp.Body.Close() // only the status is needed; release the connection 280 | codes <- resp.StatusCode 281 | } 282 | }() 283 | } 284 | 285 | waitResponse := func(wantCode int) { 286 | select { 287 | case err := <-errs: 288 | t.Fatal(err) 289 | case code := <-codes: 290 | assertEqual(t, wantCode, code) 291 | case <-time.After(timeout): 292 | t.Fatalf("waiting %d code, timeout exceeded", wantCode) 293 | } 294 | } 295 | 296 | for i := 0; i < totalRequestCount-1; i++ { 297 | waitResponse(http.StatusServiceUnavailable) 298
| } 299 | close(wait) // Allow the last request to proceed. 300 | waitResponse(http.StatusOK) 301 | } 302 | -------------------------------------------------------------------------------- /middleware/timeout.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | "time" 7 | ) 8 | 9 | // Timeout is a middleware that cancels ctx after a given timeout and, once the 10 | // handler returns, sends a 504 Gateway Timeout status to the client. 11 | // 12 | // Handlers must select on the ctx.Done() channel to detect that the context 13 | // has reached its deadline and then return; otherwise the timeout 14 | // signal is simply ignored. 15 | // 16 | // e.g. a route/handler may look like: 17 | // 18 | // r.Get("/long", func(w http.ResponseWriter, r *http.Request) { 19 | // ctx := r.Context() 20 | // processTime := time.Duration(rand.Intn(4)+1) * time.Second 21 | // 22 | // select { 23 | // case <-ctx.Done(): 24 | // return 25 | // 26 | // case <-time.After(processTime): 27 | // // The above channel simulates some hard work.
28 | // } 29 | // 30 | // w.Write([]byte("done")) 31 | // }) 32 | func Timeout(timeout time.Duration) func(next http.Handler) http.Handler { 33 | return func(next http.Handler) http.Handler { 34 | fn := func(w http.ResponseWriter, r *http.Request) { 35 | ctx, cancel := context.WithTimeout(r.Context(), timeout) 36 | defer func() { // runs when fn exits, i.e. after next.ServeHTTP returns 37 | cancel() // releases the timer; the check below relies on ctx.Err() still reporting DeadlineExceeded afterwards 38 | if ctx.Err() == context.DeadlineExceeded { // NOTE(review): not guarded against next having already written a status, in which case this WriteHeader presumably has no effect 39 | w.WriteHeader(http.StatusGatewayTimeout) 40 | } 41 | }() 42 | 43 | r = r.WithContext(ctx) // downstream handlers observe the deadline via r.Context() 44 | next.ServeHTTP(w, r) 45 | } 46 | return http.HandlerFunc(fn) 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /middleware/url_format.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | "strings" 7 | 8 | "github.com/go-chi/chi/v5" 9 | ) 10 | 11 | var ( 12 | // URLFormatCtxKey is the context.Context key to store the URL format data 13 | // for a request. 14 | URLFormatCtxKey = &contextKey{"URLFormat"} 15 | ) 16 | 17 | // URLFormat is a middleware that parses the URL extension from a request path and stores it 18 | // on the context as a string under the key `middleware.URLFormatCtxKey`. The middleware will 19 | // trim the suffix from the routing path and continue routing. 20 | // 21 | // Routers should not include a url parameter for the suffix when using this middleware.
22 | // 23 | // Sample usage for url paths `/articles/1`, `/articles/1.json` and `/articles/1.xml`: 24 | // 25 | // func routes() http.Handler { 26 | // r := chi.NewRouter() 27 | // r.Use(middleware.URLFormat) 28 | // 29 | // r.Get("/articles/{id}", ListArticles) 30 | // 31 | // return r 32 | // } 33 | // 34 | // func ListArticles(w http.ResponseWriter, r *http.Request) { 35 | // urlFormat, _ := r.Context().Value(middleware.URLFormatCtxKey).(string) 36 | // 37 | // switch urlFormat { 38 | // case "json": 39 | // render.JSON(w, r, articles) 40 | // case "xml:" 41 | // render.XML(w, r, articles) 42 | // default: 43 | // render.JSON(w, r, articles) 44 | // } 45 | // } 46 | func URLFormat(next http.Handler) http.Handler { 47 | fn := func(w http.ResponseWriter, r *http.Request) { 48 | ctx := r.Context() 49 | 50 | var format string 51 | path := r.URL.Path 52 | 53 | rctx := chi.RouteContext(r.Context()) 54 | if rctx != nil && rctx.RoutePath != "" { 55 | path = rctx.RoutePath 56 | } 57 | 58 | if strings.Index(path, ".") > 0 { 59 | base := strings.LastIndex(path, "/") 60 | idx := strings.LastIndex(path[base:], ".") 61 | 62 | if idx > 0 { 63 | idx += base 64 | format = path[idx+1:] 65 | 66 | rctx.RoutePath = path[:idx] 67 | } 68 | } 69 | 70 | r = r.WithContext(context.WithValue(ctx, URLFormatCtxKey, format)) 71 | 72 | next.ServeHTTP(w, r) 73 | } 74 | return http.HandlerFunc(fn) 75 | } 76 | -------------------------------------------------------------------------------- /middleware/url_format_test.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "testing" 7 | 8 | "github.com/go-chi/chi/v5" 9 | ) 10 | 11 | func TestURLFormat(t *testing.T) { 12 | r := chi.NewRouter() 13 | 14 | r.Use(URLFormat) 15 | 16 | r.NotFound(func(w http.ResponseWriter, r *http.Request) { 17 | w.WriteHeader(404) 18 | w.Write([]byte("nothing here")) 19 | }) 20 | 21 | 
r.Route("/samples/articles/samples.{articleID}", func(r chi.Router) { 22 | r.Get("/", func(w http.ResponseWriter, r *http.Request) { 23 | articleID := chi.URLParam(r, "articleID") 24 | w.Write([]byte(articleID)) 25 | }) 26 | }) 27 | 28 | r.Route("/articles/{articleID}", func(r chi.Router) { 29 | r.Get("/", func(w http.ResponseWriter, r *http.Request) { 30 | articleID := chi.URLParam(r, "articleID") 31 | w.Write([]byte(articleID)) 32 | }) 33 | }) 34 | 35 | ts := httptest.NewServer(r) 36 | defer ts.Close() 37 | 38 | if _, resp := testRequest(t, ts, "GET", "/articles/1.json", nil); resp != "1" { 39 | t.Fatal(resp) 40 | } 41 | if _, resp := testRequest(t, ts, "GET", "/articles/1.xml", nil); resp != "1" { 42 | t.Fatal(resp) 43 | } 44 | if _, resp := testRequest(t, ts, "GET", "/samples/articles/samples.1.json", nil); resp != "1" { 45 | t.Fatal(resp) 46 | } 47 | if _, resp := testRequest(t, ts, "GET", "/samples/articles/samples.1.xml", nil); resp != "1" { 48 | t.Fatal(resp) 49 | } 50 | } 51 | 52 | func TestURLFormatInSubRouter(t *testing.T) { 53 | r := chi.NewRouter() 54 | 55 | r.Route("/articles/{articleID}", func(r chi.Router) { 56 | r.Use(URLFormat) 57 | r.Get("/subroute", func(w http.ResponseWriter, r *http.Request) { 58 | articleID := chi.URLParam(r, "articleID") 59 | w.Write([]byte(articleID)) 60 | }) 61 | }) 62 | 63 | ts := httptest.NewServer(r) 64 | defer ts.Close() 65 | 66 | if _, resp := testRequest(t, ts, "GET", "/articles/1/subroute.json", nil); resp != "1" { 67 | t.Fatal(resp) 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /middleware/value.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | ) 7 | 8 | // WithValue is a middleware that sets a given key/value in a context chain. 
9 | func WithValue(key, val interface{}) func(next http.Handler) http.Handler { 10 | return func(next http.Handler) http.Handler { 11 | fn := func(w http.ResponseWriter, r *http.Request) { 12 | r = r.WithContext(context.WithValue(r.Context(), key, val)) // every request passing through gets the same key/val attached 13 | next.ServeHTTP(w, r) 14 | } 15 | return http.HandlerFunc(fn) 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /middleware/wrap_writer.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | // The original work was derived from Goji's middleware, source: 4 | // https://github.com/zenazn/goji/tree/master/web/middleware 5 | 6 | import ( 7 | "bufio" 8 | "io" 9 | "net" 10 | "net/http" 11 | ) 12 | 13 | // NewWrapResponseWriter wraps an http.ResponseWriter, returning a proxy that allows you to 14 | // hook into various parts of the response process. 15 | func NewWrapResponseWriter(w http.ResponseWriter, protoMajor int) WrapResponseWriter { 16 | _, fl := w.(http.Flusher) 17 | 18 | bw := basicWriter{ResponseWriter: w} 19 | 20 | if protoMajor == 2 { // HTTP/2: the fancy wrapper requires both Flusher and Pusher 21 | _, ps := w.(http.Pusher) 22 | if fl && ps { 23 | return &http2FancyWriter{bw} 24 | } 25 | } else { // other protocol versions: choose the most capable wrapper w supports 26 | _, hj := w.(http.Hijacker) 27 | _, rf := w.(io.ReaderFrom) 28 | if fl && hj && rf { 29 | return &httpFancyWriter{bw} 30 | } 31 | if fl && hj { 32 | return &flushHijackWriter{bw} 33 | } 34 | if hj { 35 | return &hijackWriter{bw} 36 | } 37 | } 38 | 39 | if fl { // no richer wrapper matched, but w can still flush 40 | return &flushWriter{bw} 41 | } 42 | 43 | return &bw // no optional interfaces detected 44 | } 45 | 46 | // WrapResponseWriter is a proxy around an http.ResponseWriter that allows you to hook 47 | // into various parts of the response process. 48 | type WrapResponseWriter interface { 49 | http.ResponseWriter 50 | // Status returns the HTTP status of the request, or 0 if one has not 51 | // yet been sent. 52 | Status() int 53 | // BytesWritten returns the total number of bytes sent to the client.
54 | BytesWritten() int 55 | // Tee causes the response body to be written to the given io.Writer in 56 | // addition to proxying the writes through. Only one io.Writer can be 57 | // tee'd to at once: setting a second one will overwrite the first. 58 | // Writes will be sent to the proxy before being written to this 59 | // io.Writer. It is illegal for the tee'd writer to be modified 60 | // concurrently with writes. 61 | Tee(io.Writer) 62 | // Unwrap returns the original proxied target. 63 | Unwrap() http.ResponseWriter 64 | // Discard causes all writes to the original ResponseWriter to be discarded, 65 | // instead writing only to the tee'd writer if it's set. 66 | // The caller is responsible for calling WriteHeader and Write on the 67 | // original ResponseWriter once the processing is done. 68 | Discard() 69 | } 70 | 71 | // basicWriter wraps an http.ResponseWriter that implements the minimal 72 | // http.ResponseWriter interface. 73 | type basicWriter struct { 74 | http.ResponseWriter 75 | tee io.Writer // optional copy of every body write (see Tee) 76 | code int // first status recorded by WriteHeader (1xx other than 101 is not recorded); 0 until then 77 | bytes int // number of body bytes written so far 78 | wroteHeader bool // set once a status is recorded (the Flush wrappers also set it) 79 | discard bool // when true, Write/WriteHeader do not forward to the wrapped writer 80 | } 81 | 82 | func (b *basicWriter) WriteHeader(code int) { // records and forwards only the first final status 83 | if code >= 100 && code <= 199 && code != http.StatusSwitchingProtocols { // 1xx other than 101: forwarded but not recorded 84 | if !b.discard { 85 | b.ResponseWriter.WriteHeader(code) 86 | } 87 | } else if !b.wroteHeader { // later final statuses are dropped 88 | b.code = code 89 | b.wroteHeader = true 90 | if !b.discard { 91 | b.ResponseWriter.WriteHeader(code) 92 | } 93 | } 94 | } 95 | 96 | func (b *basicWriter) Write(buf []byte) (n int, err error) { // forwards buf (unless discarding) and/or tees it, counting bytes 97 | b.maybeWriteHeader() // implicit 200; still recorded when discarding 98 | if !b.discard { 99 | n, err = b.ResponseWriter.Write(buf) 100 | if b.tee != nil { 101 | _, err2 := b.tee.Write(buf[:n]) 102 | // Prefer errors generated by the proxied writer.
103 | if err == nil { 104 | err = err2 105 | } 106 | } 107 | } else if b.tee != nil { 108 | n, err = b.tee.Write(buf) 109 | } else { 110 | n, err = io.Discard.Write(buf) // report the bytes as written even though they go nowhere 111 | } 112 | b.bytes += n 113 | return n, err 114 | } 115 | 116 | func (b *basicWriter) maybeWriteHeader() { // records (and, unless discarding, sends) an implicit 200 if no status is recorded yet 117 | if !b.wroteHeader { 118 | b.WriteHeader(http.StatusOK) 119 | } 120 | } 121 | 122 | func (b *basicWriter) Status() int { 123 | return b.code 124 | } 125 | 126 | func (b *basicWriter) BytesWritten() int { 127 | return b.bytes 128 | } 129 | 130 | func (b *basicWriter) Tee(w io.Writer) { 131 | b.tee = w 132 | } 133 | 134 | func (b *basicWriter) Unwrap() http.ResponseWriter { 135 | return b.ResponseWriter 136 | } 137 | 138 | func (b *basicWriter) Discard() { 139 | b.discard = true 140 | } 141 | 142 | // flushWriter adds http.Flusher to basicWriter. 143 | type flushWriter struct { 144 | basicWriter 145 | } 146 | 147 | func (f *flushWriter) Flush() { 148 | f.wroteHeader = true // NOTE(review): code stays 0 here, so Status() reports 0 after a bare Flush 149 | fl := f.basicWriter.ResponseWriter.(http.Flusher) 150 | fl.Flush() 151 | } 152 | 153 | var _ http.Flusher = &flushWriter{} 154 | 155 | // hijackWriter adds http.Hijacker to basicWriter. 156 | type hijackWriter struct { 157 | basicWriter 158 | } 159 | 160 | func (f *hijackWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { 161 | hj := f.basicWriter.ResponseWriter.(http.Hijacker) 162 | return hj.Hijack() 163 | } 164 | 165 | var _ http.Hijacker = &hijackWriter{} 166 | 167 | // flushHijackWriter adds http.Flusher and http.Hijacker to basicWriter.
168 | type flushHijackWriter struct { 169 | basicWriter 170 | } 171 | 172 | func (f *flushHijackWriter) Flush() { 173 | f.wroteHeader = true 174 | fl := f.basicWriter.ResponseWriter.(http.Flusher) 175 | fl.Flush() 176 | } 177 | 178 | func (f *flushHijackWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { 179 | hj := f.basicWriter.ResponseWriter.(http.Hijacker) 180 | return hj.Hijack() 181 | } 182 | 183 | var _ http.Flusher = &flushHijackWriter{} 184 | var _ http.Hijacker = &flushHijackWriter{} 185 | 186 | // httpFancyWriter is an HTTP writer that additionally satisfies 187 | // http.Flusher, http.Hijacker, and io.ReaderFrom. It exists for the common case 188 | // of wrapping the http.ResponseWriter that package http gives you, in order to 189 | // make the proxy support the full method set of the original object. 190 | type httpFancyWriter struct { 191 | basicWriter 192 | } 193 | 194 | func (f *httpFancyWriter) Flush() { 195 | f.wroteHeader = true 196 | fl := f.basicWriter.ResponseWriter.(http.Flusher) 197 | fl.Flush() 198 | } 199 | 200 | func (f *httpFancyWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { 201 | hj := f.basicWriter.ResponseWriter.(http.Hijacker) 202 | return hj.Hijack() 203 | } 204 | 205 | func (f *http2FancyWriter) Push(target string, opts *http.PushOptions) error { // http2FancyWriter's Push, declared here among httpFancyWriter's methods 206 | return f.basicWriter.ResponseWriter.(http.Pusher).Push(target, opts) 207 | } 208 | 209 | func (f *httpFancyWriter) ReadFrom(r io.Reader) (int64, error) { 210 | // Tee and Discard modes go through basicWriter.Write, which honours both 211 | // and already updates the byte counter, so no extra accounting is needed. 212 | if f.basicWriter.tee != nil || f.basicWriter.discard { 213 | return io.Copy(&f.basicWriter, r) 214 | } 215 | rf := f.basicWriter.ResponseWriter.(io.ReaderFrom) 216 | f.basicWriter.maybeWriteHeader() 217 | n, err := rf.ReadFrom(r) 218 | f.basicWriter.bytes += int(n) 219 | return n, err 220 | } 221 | 222 | var _ http.Flusher = &httpFancyWriter{} 223 | var _ http.Hijacker = &httpFancyWriter{} 224 | var _ http.Pusher = &http2FancyWriter{} 225 | var _ io.ReaderFrom =
&httpFancyWriter{} 226 | 227 | // http2FancyWriter is an HTTP/2 writer that additionally satisfies 228 | // http.Flusher and http.Pusher. It exists for the common case 229 | // of wrapping the http.ResponseWriter that package http gives you, in order to 230 | // make the proxy support the full method set of the original object. 231 | type http2FancyWriter struct { 232 | basicWriter 233 | } 234 | 235 | func (f *http2FancyWriter) Flush() { 236 | f.wroteHeader = true 237 | fl := f.basicWriter.ResponseWriter.(http.Flusher) 238 | fl.Flush() 239 | } 240 | 241 | var _ http.Flusher = &http2FancyWriter{} 242 | -------------------------------------------------------------------------------- /middleware/wrap_writer_test.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "bytes" 5 | "net/http" 6 | "net/http/httptest" 7 | "testing" 8 | ) 9 | 10 | func TestHttpFancyWriterRemembersWroteHeaderWhenFlushed(t *testing.T) { 11 | f := &httpFancyWriter{basicWriter: basicWriter{ResponseWriter: httptest.NewRecorder()}} 12 | f.Flush() 13 | 14 | if !f.wroteHeader { 15 | t.Fatal("want Flush to have set wroteHeader=true") 16 | } 17 | } 18 | 19 | func TestHttp2FancyWriterRemembersWroteHeaderWhenFlushed(t *testing.T) { 20 | f := &http2FancyWriter{basicWriter{ResponseWriter: httptest.NewRecorder()}} 21 | f.Flush() 22 | 23 | if !f.wroteHeader { 24 | t.Fatal("want Flush to have set wroteHeader=true") 25 | } 26 | } 27 | 28 | func TestBasicWritesTeesWritesWithoutDiscard(t *testing.T) { 29 | // explicitly create the struct instead of NewRecorder to control the value of Code 30 | original := &httptest.ResponseRecorder{ 31 | HeaderMap: make(http.Header), 32 | Body: new(bytes.Buffer), 33 | } 34 | wrap := &basicWriter{ResponseWriter: original} 35 | 36 | var buf bytes.Buffer 37 | wrap.Tee(&buf) 38 | 39 | _, err := wrap.Write([]byte("hello world")) 40 | assertNoError(t, err) 41 | 42 | assertEqual(t, 200, original.Code)
43 | assertEqual(t, []byte("hello world"), original.Body.Bytes()) 44 | assertEqual(t, []byte("hello world"), buf.Bytes()) 45 | assertEqual(t, 11, wrap.BytesWritten()) 46 | } 47 | 48 | func TestBasicWriterDiscardsWritesToOriginalResponseWriter(t *testing.T) { 49 | t.Run("With Tee", func(t *testing.T) { 50 | // explicitly create the struct instead of NewRecorder to control the value of Code 51 | original := &httptest.ResponseRecorder{ 52 | HeaderMap: make(http.Header), 53 | Body: new(bytes.Buffer), 54 | } 55 | wrap := &basicWriter{ResponseWriter: original} 56 | 57 | var buf bytes.Buffer 58 | wrap.Tee(&buf) 59 | wrap.Discard() 60 | 61 | _, err := wrap.Write([]byte("hello world")) 62 | assertNoError(t, err) 63 | 64 | assertEqual(t, 0, original.Code) // wrapper shouldn't call WriteHeader implicitly 65 | assertEqual(t, 0, original.Body.Len()) 66 | assertEqual(t, []byte("hello world"), buf.Bytes()) 67 | assertEqual(t, 11, wrap.BytesWritten()) 68 | }) 69 | 70 | t.Run("Without Tee", func(t *testing.T) { 71 | // explicitly create the struct instead of NewRecorder to control the value of Code 72 | original := &httptest.ResponseRecorder{ 73 | HeaderMap: make(http.Header), 74 | Body: new(bytes.Buffer), 75 | } 76 | wrap := &basicWriter{ResponseWriter: original} 77 | wrap.Discard() 78 | 79 | _, err := wrap.Write([]byte("hello world")) 80 | assertNoError(t, err) 81 | 82 | assertEqual(t, 0, original.Code) // wrapper shouldn't call WriteHeader implicitly 83 | assertEqual(t, 0, original.Body.Len()) 84 | assertEqual(t, 11, wrap.BytesWritten()) 85 | }) 86 | } 87 | 88 | func TestHttpFancyWriterReadFromTeeCountsBytesOnce(t *testing.T) { 89 | original := httptest.NewRecorder() 90 | f := &httpFancyWriter{basicWriter: basicWriter{ResponseWriter: original}} 91 | 92 | var buf bytes.Buffer 93 | f.Tee(&buf) 94 | 95 | n, err := f.ReadFrom(bytes.NewReader([]byte("hello world"))) 96 | assertNoError(t, err) 97 | 98 | assertEqual(t, int64(11), n) 99 | assertEqual(t, []byte("hello world"), original.Body.Bytes()) 100 | assertEqual(t, []byte("hello world"), buf.Bytes()) 101 | assertEqual(t, 11, f.BytesWritten()) 102 | } 103 | 104 | func TestHttpFancyWriterReadFromHonoursDiscard(t *testing.T) { 105 | // Discard must keep ReadFrom from reaching the wrapped writer, 106 | // while the bytes are still counted. 107 | original := httptest.NewRecorder() 108 | f := &httpFancyWriter{basicWriter: basicWriter{ResponseWriter: original}} 109 | f.Discard() 110 | 111 | n, err := f.ReadFrom(bytes.NewReader([]byte("hello world"))) 112 | assertNoError(t, err) 113 | 114 | assertEqual(t, int64(11), n) 115 | assertEqual(t, 0, original.Body.Len()) 116 | assertEqual(t, 11, f.BytesWritten()) 117 | } 118 | -------------------------------------------------------------------------------- /path_value.go: -------------------------------------------------------------------------------- 1 | //go:build go1.22 && !tinygo 2 | // +build go1.22,!tinygo 3 | 4 | 5 | package chi 6 | 7 | import "net/http" 8 | 9 | // supportsPathValue is true if the Go version is 1.22 and above.
10 | // 11 | // If this is true, `net/http.Request` has methods `SetPathValue` and `PathValue`. 12 | const supportsPathValue = true 13 | 14 | // setPathValue copies the chi URL parameters held in rctx.URLParams onto 15 | // the request via r.SetPathValue, so handlers can read them with r.PathValue. 16 | func setPathValue(rctx *Context, r *http.Request) { 17 | for i, key := range rctx.URLParams.Keys { // assumes Keys and Values are parallel slices (Values[i] belongs to Keys[i]) 18 | value := rctx.URLParams.Values[i] 19 | r.SetPathValue(key, value) 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /path_value_fallback.go: -------------------------------------------------------------------------------- 1 | //go:build !go1.22 || tinygo 2 | // +build !go1.22 tinygo 3 | 4 | package chi 5 | 6 | import "net/http" 7 | 8 | // supportsPathValue is true if the Go version is 1.22 and above; it is always false in this build (Go < 1.22 or TinyGo). 9 | // 10 | // If this is true, `net/http.Request` has methods `SetPathValue` and `PathValue`. 11 | const supportsPathValue = false 12 | 13 | // setPathValue is the counterpart of the Go 1.22+ version, which sets the path 14 | // values in the Request based on the provided request context. 15 | // 16 | // setPathValue is only supported in Go 1.22 and above, so 17 | // this is just a no-op stub that keeps the package compiling.
18 | func setPathValue(rctx *Context, r *http.Request) { // intentionally empty; this file only builds for Go < 1.22 or TinyGo 19 | } 20 | -------------------------------------------------------------------------------- /path_value_test.go: -------------------------------------------------------------------------------- 1 | //go:build go1.22 && !tinygo 2 | // +build go1.22,!tinygo 3 | 4 | package chi 5 | 6 | import ( 7 | "net/http" 8 | "net/http/httptest" 9 | "strings" 10 | "testing" 11 | ) 12 | 13 | func TestPathValue(t *testing.T) { // chi route params must be readable through http.Request.PathValue for "METHOD /pattern" routes 14 | testCases := []struct { 15 | name string 16 | pattern string 17 | method string 18 | requestPath string 19 | expectedBody string 20 | pathKeys []string 21 | }{ 22 | { 23 | name: "Basic path value", 24 | pattern: "/hubs/{hubID}", 25 | method: "GET", 26 | pathKeys: []string{"hubID"}, 27 | requestPath: "/hubs/392", 28 | expectedBody: "392", 29 | }, 30 | { 31 | name: "Two path values", 32 | pattern: "/users/{userID}/conversations/{conversationID}", 33 | method: "POST", 34 | pathKeys: []string{"userID", "conversationID"}, 35 | requestPath: "/users/Gojo/conversations/2948", 36 | expectedBody: "Gojo 2948", 37 | }, 38 | { 39 | name: "Wildcard path", 40 | pattern: "/users/{userID}/friends/*", 41 | method: "POST", 42 | pathKeys: []string{"userID", "*"}, 43 | requestPath: "/users/Gojo/friends/all-of-them/and/more", 44 | expectedBody: "Gojo all-of-them/and/more", 45 | }, 46 | } 47 | 48 | for _, tc := range testCases { 49 | t.Run(tc.name, func(t *testing.T) { 50 | r := NewRouter() 51 | 52 | r.Handle(tc.method+" "+tc.pattern, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 53 | pathValues := []string{} 54 | for _, pathKey := range tc.pathKeys { 55 | pathValue := r.PathValue(pathKey) 56 | if pathValue == "" { 57 | pathValue = "NOT_FOUND:" + pathKey // a missing key shows up in the body instead of silently matching 58 | } 59 | 60 | pathValues = append(pathValues, pathValue) 61 | } 62 | 63 | body := strings.Join(pathValues, " ") 64 | 65 | w.Write([]byte(body)) 66 | })) 67 | 68 | ts := httptest.NewServer(r) 69 | defer ts.Close() 70 | 71 | _, body := testRequest(t,
ts, tc.method, tc.requestPath, nil) 72 | if body != tc.expectedBody { 73 | t.Fatalf("expecting %q, got %q", tc.expectedBody, body) 74 | } 75 | }) 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /testdata/cert.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIC/zCCAeegAwIBAgIRANioW0Re7DtpT4qZpJU1iK8wDQYJKoZIhvcNAQELBQAw 3 | EjEQMA4GA1UEChMHQWNtZSBDbzAeFw0xNjEyMzExNDU0MzBaFw0xNzEyMzExNDU0 4 | MzBaMBIxEDAOBgNVBAoTB0FjbWUgQ28wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAw 5 | ggEKAoIBAQDpFfOsaXDYlL+ektfsqGYrSAsoTbe7zqjpow9nqUU4PmLRu2YMaaW8 6 | fAoneUnJxsJw7ql38+VMpphZUOmOWvsO7uV/lfnTIQfTwllHDdgAR5A11d84Zy/y 7 | TiNIFJduuaPtEhQs1dxPhU7TG8sEfFRhBoUDPv473akeGPNkVU756RVBYM6rUc3b 8 | YygD0PXGsQ2obrImbYUyyHH5YClCvGl1No57n3ugLqSSfwbgR3/Gw7kkGKy0PMOu 9 | TuHuJnTEmofJPkqEyFRVMlIAtfqFqJUfDHTOuQGWIUPnjDg+fqTI9EPJ+pElBqDQ 10 | IqW93BY5XePMdrTQc1h6xkduDfuLeA7TAgMBAAGjUDBOMA4GA1UdDwEB/wQEAwIF 11 | oDATBgNVHSUEDDAKBggrBgEFBQcDATAMBgNVHRMBAf8EAjAAMBkGA1UdEQQSMBCC 12 | DmxvY2FsaG9zdDo3MDcyMA0GCSqGSIb3DQEBCwUAA4IBAQDnsWmZdf7209A/XHUe 13 | xoONCbU8jaYFVoA+CN9J+3CASzrzTQ4fh9RJdm2FZuv4sWnb5c5hDN7H/M/nLcb0 14 | +uu7ACBGhd7yACYCQm/z3Pm3CY2BRIo0vCCRioGx+6J3CPGWFm0vHwNBge0iBOKC 15 | Wn+/YOlTDth/M3auHYlr7hdFmf57U4V/5iTr4wiKxwM9yMPcVRQF/1XpPd7A0VqM 16 | nFSEfDpFjrA7MvT3DrRqQGqF/ZXxDbro2nyki3YG8FwgKlFNVN9w55zNiriQ+WNA 17 | uz86lKg1FTc+m/R/0CD//7+7mme28N813EPVdV83TgxWNrfvAIRazkHE7YxETry0 18 | BJDg 19 | -----END CERTIFICATE----- -------------------------------------------------------------------------------- /testdata/key.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN RSA PRIVATE KEY----- 2 | MIIEpAIBAAKCAQEA6RXzrGlw2JS/npLX7KhmK0gLKE23u86o6aMPZ6lFOD5i0btm 3 | DGmlvHwKJ3lJycbCcO6pd/PlTKaYWVDpjlr7Du7lf5X50yEH08JZRw3YAEeQNdXf 4 | OGcv8k4jSBSXbrmj7RIULNXcT4VO0xvLBHxUYQaFAz7+O92pHhjzZFVO+ekVQWDO 5 | 
q1HN22MoA9D1xrENqG6yJm2FMshx+WApQrxpdTaOe597oC6kkn8G4Ed/xsO5JBis 6 | tDzDrk7h7iZ0xJqHyT5KhMhUVTJSALX6haiVHwx0zrkBliFD54w4Pn6kyPRDyfqR 7 | JQag0CKlvdwWOV3jzHa00HNYesZHbg37i3gO0wIDAQABAoIBAFvqYDE5U1rVLctm 8 | tOeKcN/YhS3bl/zjvhCEUOrcAYPwdh+m+tMiRk1RzN9MISEE1GCcfQ/kiiPz/lga 9 | ZD/S+PYmlzH8/ouXlvKWzYYLm4ZgsinIsUIYzvuKfLdMB3uOkWpHmtUjcMGbHD57 10 | 009tiAjK/WEOUkthWfOYe0KxsXczBn3PTAWZuiIkuA3RVWa7pCCFHUENkViP58wl 11 | Ky1hYKnunKPApRwuiC6qIT5ZOCSukdCCbkmRnj/x+P8+nsosu+1d85MNZb8uLRi0 12 | RzMmuOfOK2poDsrNHQX7itKlu7rzMJQc3+RauqIZovNe/BmSq+tYBLboXvUp18g/ 13 | +VqKeEECgYEA/LaD1tJepzD/1lhgunFcnDjxsDJqLUpfR5eDMX1qhGJphuPBLOXS 14 | ushmVVjbVIn25Wxeoe4RYrZ6Tuu0FEJJgV44Lt42OOFgK2gyrCJpYmlxpRaw+7jc 15 | Dbp1Sh3/9VqMZjR/mQIzTnfOtS2n4Fk1Q53hdJn5Pn+uPMmMO4hF87sCgYEA7B4V 16 | BACsd6eqVxKkEMc72VLeYb0Ri0bl0FwbvIKXImppwA0tbMDmeA+6yhcRm23dhd5v 17 | cfNhJepRIzkM2CkhnazlsAbDoJPqb7/sbNzodtW1P0op7YIFYbrkcX4yOu9O1DNI 18 | Ij4PR8H1WcpPjhvr3q+iNO5agQX7bMQ1BnnJg8kCgYBA1tdm090DSrgpl81hqNpZ 19 | HucsDRNfAXkG1mIL3aDpzJJE0MTsrx7tW6Od/ElyHF/jp3V0WK/PQwCIpUMz+3n+ 20 | nl0N8We6GmFhYb+2mLGvVVyaPgM04s5bG18ioCXfHtdtFcUzTfQ6CtVXeRpcnqbi 21 | 7Ww+TY88sOfUouW/FIzWJwKBgQCsLauJhaw+fOc8I328NmywJzu+7g5TD9oZvHEF 22 | X/0xvYNr5rAPNANb3ayKHZRbURxOuEtwPtfCvEF6e+mf3y6COkgrumMBP5ue7cdM 23 | AzMJJQHMKxqz9TJTd+OJ10ptq4BCQTsCrVqbKxbs6RhmOnofoteX3Y/lsiULxXAd 24 | TsXh8QKBgQDQHosH8VoL7vIK+SqY5uoHAhMytSVNx4IaZZg4ho8oyjw12QXcidgV 25 | QJZQMdPEv8cAK78WcQdSthop+O/tu2cKLHyAmWmO3oU7gIQECui0aMXSqraO6Vde 26 | C5tqYlyLa7bHZS3AqrjRv9BRfwPKVkmBoYdA652rN/tE/K4UWsghnA== 27 | -----END RSA PRIVATE KEY----- --------------------------------------------------------------------------------