├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── access ├── logger.go └── logger_test.go ├── auth ├── handlers.go └── handlers_test.go ├── content ├── language.go ├── language_test.go ├── negotiator.go ├── negotiator_test.go ├── type.go └── type_test.go ├── context.go ├── context_test.go ├── cors ├── handler.go └── handler_test.go ├── error.go ├── error_test.go ├── example_test.go ├── fault ├── error.go ├── error_test.go ├── panic.go ├── panic_test.go ├── recovery.go └── recovery_test.go ├── file ├── server.go ├── server_test.go └── testdata │ ├── css │ ├── index.html │ └── main.css │ └── index.html ├── go.mod ├── go.sum ├── graceful.go ├── group.go ├── group_test.go ├── reader.go ├── reader_test.go ├── route.go ├── route_test.go ├── router.go ├── router_test.go ├── slash ├── remover.go └── remover_test.go ├── store.go ├── store_test.go ├── writer.go └── writer_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | 6 | # Folders 7 | _obj 8 | _test 9 | 10 | # Architecture specific extensions/prefixes 11 | *.[568vq] 12 | [568vq].out 13 | 14 | *.cgo1.go 15 | *.cgo2.c 16 | _cgo_defun.c 17 | _cgo_gotypes.go 18 | _cgo_export.* 19 | 20 | _testmain.go 21 | 22 | *.exe 23 | *.test 24 | *.prof 25 | 26 | coverage.out -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | dist: bionic 2 | 3 | language: go 4 | 5 | go: 6 | - 1.13.x 7 | 8 | install: 9 | - go get golang.org/x/tools/cmd/cover 10 | - go get github.com/mattn/goveralls 11 | - go get golang.org/x/lint/golint 12 | 13 | script: 14 | - go test -v -covermode=count -coverprofile=coverage.out ./... 
15 | - $HOME/gopath/bin/goveralls -coverprofile=coverage.out -service=travis-ci 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | Copyright (c) 2016, Qiang Xue 3 | 4 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software 5 | and associated documentation files (the "Software"), to deal in the Software without restriction, 6 | including without limitation the rights to use, copy, modify, merge, publish, distribute, 7 | sublicense, and/or sell copies of the Software, and to permit persons to whom the Software 8 | is furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all copies or 11 | substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING 14 | BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 15 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, 16 | DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 17 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
18 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ozzo-routing 2 | 3 | [![GoDoc](https://godoc.org/github.com/go-ozzo/ozzo-routing?status.png)](http://godoc.org/github.com/go-ozzo/ozzo-routing) 4 | [![Build Status](https://travis-ci.org/go-ozzo/ozzo-routing.svg?branch=master)](https://travis-ci.org/go-ozzo/ozzo-routing) 5 | [![Coverage Status](https://coveralls.io/repos/github/go-ozzo/ozzo-routing/badge.svg?branch=master)](https://coveralls.io/github/go-ozzo/ozzo-routing?branch=master) 6 | [![Go Report](https://goreportcard.com/badge/github.com/go-ozzo/ozzo-routing)](https://goreportcard.com/report/github.com/go-ozzo/ozzo-routing) 7 | 8 | **You may consider using [go-rest-api](https://github.com/qiangxue/go-rest-api) to jumpstart your new RESTful applications with ozzo-routing.** 9 | 10 | ## Description 11 | 12 | ozzo-routing is a Go package that provides high performance and powerful HTTP routing capabilities for Web applications. 13 | It has the following features: 14 | 15 | * middleware pipeline architecture, similar to that of the [Express framework](http://expressjs.com). 
16 | * extremely fast request routing with zero dynamic memory allocation (the performance is comparable to that of [httprouter](https://github.com/julienschmidt/httprouter) and 17 | [gin](https://github.com/gin-gonic/gin), see the [performance comparison below](#benchmarks)) 18 | * modular code organization through route grouping 19 | * flexible URL path matching, supporting URL parameters and regular expressions 20 | * URL creation according to the predefined routes 21 | * compatible with `http.Handler` and `http.HandlerFunc` 22 | * ready-to-use handlers sufficient for building RESTful APIs 23 | * graceful shutdown 24 | 25 | If you are using [fasthttp](https://github.com/valyala/fasthttp), you may use a similar routing package [fasthttp-routing](https://github.com/qiangxue/fasthttp-routing) which is adapted from ozzo-routing. 26 | 27 | ## Requirements 28 | 29 | Go 1.13 or above. 30 | 31 | ## Installation 32 | 33 | In your Go project using `go mod`, run the following command to install the package: 34 | 35 | ``` 36 | go get github.com/go-ozzo/ozzo-routing/v2 37 | ``` 38 | 39 | ## Getting Started 40 | 41 | For a complete RESTful application boilerplate based on ozzo-routing, please refer to the [golang-restful-starter-kit](https://github.com/qiangxue/golang-restful-starter-kit). Below we describe how to create a simple REST API using ozzo-routing. 
42 | 43 | Create a `server.go` file with the following content: 44 | 45 | ```go 46 | package main 47 | 48 | import ( 49 | "log" 50 | "net/http" 51 | "github.com/go-ozzo/ozzo-routing/v2" 52 | "github.com/go-ozzo/ozzo-routing/v2/access" 53 | "github.com/go-ozzo/ozzo-routing/v2/slash" 54 | "github.com/go-ozzo/ozzo-routing/v2/content" 55 | "github.com/go-ozzo/ozzo-routing/v2/fault" 56 | "github.com/go-ozzo/ozzo-routing/v2/file" 57 | ) 58 | 59 | func main() { 60 | router := routing.New() 61 | 62 | router.Use( 63 | // all these handlers are shared by every route 64 | access.Logger(log.Printf), 65 | slash.Remover(http.StatusMovedPermanently), 66 | fault.Recovery(log.Printf), 67 | ) 68 | 69 | // serve RESTful APIs 70 | api := router.Group("/api") 71 | api.Use( 72 | // these handlers are shared by the routes in the api group only 73 | content.TypeNegotiator(content.JSON, content.XML), 74 | ) 75 | api.Get("/users", func(c *routing.Context) error { 76 | return c.Write("user list") 77 | }) 78 | api.Post("/users", func(c *routing.Context) error { 79 | return c.Write("create a new user") 80 | }) 81 | api.Put(`/users/<id>`, func(c *routing.Context) error { 82 | return c.Write("update user " + c.Param("id")) 83 | }) 84 | 85 | // serve index file 86 | router.Get("/", file.Content("ui/index.html")) 87 | // serve files under the "ui" subdirectory 88 | router.Get("/*", file.Server(file.PathMap{ 89 | "/": "/ui/", 90 | })) 91 | 92 | http.Handle("/", router) 93 | http.ListenAndServe(":8080", nil) 94 | } 95 | ``` 96 | 97 | Create an HTML file `ui/index.html` with any content. 98 | 99 | Now run the following command to start the Web server: 100 | 101 | ``` 102 | go run server.go 103 | ``` 104 | 105 | You should be able to access URLs such as `http://localhost:8080`, `http://localhost:8080/api/users`. 106 | 107 | 108 | ### Routes 109 | 110 | ozzo-routing works by building a routing table in a router and then dispatching HTTP requests to the matching handlers 111 | found in the routing table.
An intuitive illustration of a routing table is as follows: 112 | 113 | 114 | Routes | Handlers 115 | --------------------|----------------- 116 | `GET /users` | m1, m2, h1, ... 117 | `POST /users` | m1, m2, h2, ... 118 | `PUT /users/<id>` | m1, m2, h3, ... 119 | `DELETE /users/<id>`| m1, m2, h4, ... 120 | 121 | 122 | For an incoming request `GET /users`, the first route would match and the handlers m1, m2, and h1 would be executed. 123 | If the request is `PUT /users/123`, the third route would match and the corresponding handlers would be executed. 124 | Note that the token `<id>` can match any number of non-slash characters and the matching part can be accessed as 125 | a path parameter value in the handlers. 126 | 127 | **If an incoming request matches multiple routes in the table, the route added first to the table will take precedence. 128 | All other matching routes will be ignored.** 129 | 130 | The actual implementation of the routing table uses a variant of the radix tree data structure, which makes the routing 131 | process as fast as working with a hash table, thanks to the inspiration from [httprouter](https://github.com/julienschmidt/httprouter).
132 | 133 | To add a new route and its handlers to the routing table, call the `To` method like the following: 134 | 135 | ```go 136 | router := routing.New() 137 | router.To("GET", "/users", m1, m2, h1) 138 | router.To("POST", "/users", m1, m2, h2) 139 | ``` 140 | 141 | You can also use shortcut methods, such as `Get`, `Post`, `Put`, etc., which are named after the HTTP method names: 142 | 143 | ```go 144 | router.Get("/users", m1, m2, h1) 145 | router.Post("/users", m1, m2, h2) 146 | ``` 147 | 148 | If you have multiple routes with the same URL path but different HTTP methods, like the above example, you can 149 | chain them together as follows: 150 | 151 | ```go 152 | router.Get("/users", m1, m2, h1).Post(m1, m2, h2) 153 | ``` 154 | 155 | If you want to use the same set of handlers to handle the same URL path but different HTTP methods, you can take 156 | the following shortcut: 157 | 158 | ```go 159 | router.To("GET,POST", "/users", m1, m2, h) 160 | ``` 161 | 162 | A route may contain parameter tokens which are in the format of `<name:pattern>`, where `name` stands for the parameter 163 | name, and `pattern` is a regular expression which the parameter value should match. A token `<name>` is equivalent 164 | to `<name:[^/]*>`, i.e., it matches any number of non-slash characters. At the end of a route, an asterisk character 165 | can be used to match any number of arbitrary characters.
Below are some examples: 166 | 167 | * `/users/<username>`: matches `/users/admin` 168 | * `/users/accnt-<id:\d+>`: matches `/users/accnt-123`, but not `/users/accnt-admin` 169 | * `/users/<username>/*`: matches `/users/admin/profile/address` 170 | 171 | When a URL path matches a route, the matching parameters on the URL path can be accessed via `Context.Param()`: 172 | 173 | ```go 174 | router := routing.New() 175 | 176 | router.Get("/users/<username>", func (c *routing.Context) error { 177 | fmt.Fprintf(c.Response, "Name: %v", c.Param("username")) 178 | return nil 179 | }) 180 | ``` 181 | 182 | 183 | ### Route Groups 184 | 185 | A route group is a way of grouping together the routes which have the same route prefix. The routes in a group also 186 | share the same handlers that are registered with the group via its `Use` method. For example, 187 | 188 | ```go 189 | router := routing.New() 190 | api := router.Group("/api") 191 | api.Use(m1, m2) 192 | api.Get("/users", h1).Post(h2) 193 | api.Put("/users/<id>", h3).Delete(h4) 194 | ``` 195 | 196 | The above `/api` route group establishes the following routing table: 197 | 198 | 199 | Routes | Handlers 200 | ------------------------|------------- 201 | `GET /api/users` | m1, m2, h1, ... 202 | `POST /api/users` | m1, m2, h2, ... 203 | `PUT /api/users/<id>` | m1, m2, h3, ... 204 | `DELETE /api/users/<id>`| m1, m2, h4, ... 205 | 206 | 207 | As you can see, all these routes have the same route prefix `/api` and the handlers `m1` and `m2`. In other similar 208 | routing frameworks, the handlers registered with a route group are also called *middlewares*. 209 | 210 | Route groups can be nested. That is, a route group can create a child group by calling the `Group()` method. The router 211 | serves as the top level route group. A child group inherits the handlers registered with its parent group.
For example, 212 | 213 | ```go 214 | router := routing.New() 215 | router.Use(m1) 216 | 217 | api := router.Group("/api") 218 | api.Use(m2) 219 | 220 | users := api.Group("/users") 221 | users.Use(m3) 222 | users.Put("/<id>", h1) 223 | ``` 224 | 225 | Because the router serves as the parent of the `api` group which is the parent of the `users` group, 226 | the `PUT /api/users/<id>` route is associated with the handlers `m1`, `m2`, `m3`, and `h1`. 227 | 228 | 229 | ### Router 230 | 231 | Router manages the routing table and dispatches incoming requests to appropriate handlers. A router instance is created 232 | by calling the `routing.New()` method. 233 | 234 | Because `Router` implements the `http.Handler` interface, it can be readily used to serve subtrees on existing Go servers. 235 | For example, 236 | 237 | ```go 238 | router := routing.New() 239 | http.Handle("/", router) 240 | http.ListenAndServe(":8080", nil) 241 | ``` 242 | 243 | 244 | ### Handlers 245 | 246 | A handler is a function with the signature `func(*routing.Context) error`. A handler is executed by the router if 247 | the incoming request URL path matches the route that the handler is associated with. Through the `routing.Context` 248 | parameter, you can access the request information in handlers. 249 | 250 | A route may be associated with multiple handlers. These handlers will be executed in the order that they are registered 251 | to the route. The execution sequence can be terminated in the middle using one of the following two methods: 252 | 253 | * A handler returns an error: the router will skip the rest of the handlers and handle the returned error. 254 | * A handler calls `Context.Abort()`: the router will simply skip the rest of the handlers. There is no error to be handled. 255 | 256 | A handler can call `Context.Next()` to explicitly execute the rest of the unexecuted handlers and take actions after 257 | they finish execution.
For example, a response compression handler may start the output buffer, call `Context.Next()`, 258 | and then compress and send the output to response. 259 | 260 | 261 | ### Context 262 | 263 | For each incoming request, a `routing.Context` object is populated with the request information and passed through 264 | the handlers that need to handle the request. Handlers can get the request information via `Context.Request` and 265 | send a response back via `Context.Response`. The `Context.Param()` method allows handlers to access the URL path 266 | parameters that match the current route. 267 | 268 | Using `Context.Get()` and `Context.Set()`, handlers can share data between each other. For example, an authentication 269 | handler can store the authenticated user identity by calling `Context.Set()`, and other handlers can retrieve back 270 | the identity information by calling `Context.Get()`. 271 | 272 | 273 | ### Reading Request Data 274 | 275 | Context provides a few shortcut methods to read query parameters. The `Context.Query()` method returns 276 | the named URL query parameter value; the `Context.PostForm()` method returns the named parameter value in the POST or 277 | PUT body parameters; and the `Context.Form()` method returns the value from either POST/PUT or URL query parameters. 278 | 279 | The `Context.Read()` method supports reading data from the request body and populating it into an object. 280 | The method will check the `Content-Type` HTTP header and parse the body data as the corresponding format. 281 | For example, if `Content-Type` is `application/json`, the request body will be parsed as JSON data. 282 | The public fields in the object being populated will receive the parsed data if the data contains the same named fields. 
283 | For example, 284 | 285 | ```go 286 | func foo(c *routing.Context) error { 287 | data := &struct{ 288 | A string 289 | B bool 290 | }{} 291 | 292 | // assume the body data is: {"A":"abc", "B":true} 293 | // data will be populated as: {A: "abc", B: true} 294 | if err := c.Read(data); err != nil { 295 | return err 296 | } 297 | } 298 | ``` 299 | 300 | By default, `Context` supports reading data that are in JSON, XML, form, and multipart-form data. 301 | You may modify `routing.DataReaders` to add support for other data formats. 302 | 303 | Note that when the data is read as form data, you may use a struct tag named `form` to customize 304 | the name of the corresponding field in the form data. The form data reader also supports populating 305 | data into embedded objects which are either named or anonymous. 306 | 307 | ### Writing Response Data 308 | 309 | The `Context.Write()` method can be used to write data of arbitrary type to the response. 310 | By default, if the data being written is neither a string nor a byte array, the method 311 | will call `fmt.Fprint()` to write the data into the response. 312 | 313 | You can call `Context.SetWriter()` to replace the default data writer with a customized one. 314 | For example, the `content.TypeNegotiator` will negotiate the content response type and set the data 315 | writer with an appropriate one. 316 | 317 | ### Error Handling 318 | 319 | A handler may return an error indicating some erroneous condition. Sometimes, a handler or the code it calls may cause 320 | a panic. Both should be handled properly to ensure the best user experience. It is recommended that you use 321 | the `fault.Recovery` handler or a similar error handler to handle these errors. 322 | 323 | If an error is not handled by any handler, the router will handle it by calling its `handleError()` method which 324 | simply sets an appropriate HTTP status code and writes the error message to the response.
325 | 326 | When an incoming request has no matching route, the router will call the handlers registered via the `Router.NotFound()` 327 | method. All the handlers registered via `Router.Use()` will also be called in advance. By default, the following two 328 | handlers are registered with `Router.NotFound()`: 329 | 330 | * `routing.MethodNotAllowedHandler`: a handler that sends an `Allow` HTTP header indicating the allowed HTTP methods for a requested URL 331 | * `routing.NotFoundHandler`: a handler triggering a 404 HTTP error 332 | 333 | ## Serving Static Files 334 | 335 | Static files can be served with the help of `file.Server` and `file.Content` handlers. The former serves files 336 | under the specified directories, while the latter serves the content of a single file. For example, 337 | 338 | ```go 339 | import ( 340 | "github.com/go-ozzo/ozzo-routing/v2" 341 | "github.com/go-ozzo/ozzo-routing/v2/file" 342 | ) 343 | 344 | router := routing.New() 345 | 346 | // serve index file 347 | router.Get("/", file.Content("ui/index.html")) 348 | // serve files under the "ui" subdirectory 349 | router.Get("/*", file.Server(file.PathMap{ 350 | "/": "/ui/", 351 | })) 352 | ``` 353 | 354 | ## Handlers 355 | 356 | ozzo-routing comes with a few commonly used handlers in its subpackages: 357 | 358 | Handler name | Description 359 | --------------------------------|-------------------------------------------- 360 | [access.Logger](https://godoc.org/github.com/go-ozzo/ozzo-routing/access) | records an entry for every incoming request 361 | [auth.Basic](https://godoc.org/github.com/go-ozzo/ozzo-routing/auth) | provides authentication via HTTP Basic 362 | [auth.Bearer](https://godoc.org/github.com/go-ozzo/ozzo-routing/auth) | provides authentication via HTTP Bearer 363 | [auth.Query](https://godoc.org/github.com/go-ozzo/ozzo-routing/auth) | provides authentication via token-based query parameter 364 | [auth.JWT](https://godoc.org/github.com/go-ozzo/ozzo-routing/auth) | 
provides JWT-based authentication 365 | [content.TypeNegotiator](https://godoc.org/github.com/go-ozzo/ozzo-routing/content) | supports content negotiation by response types 366 | [content.LanguageNegotiator](https://godoc.org/github.com/go-ozzo/ozzo-routing/content) | supports content negotiation by accepted languages 367 | [cors.Handler](https://godoc.org/github.com/go-ozzo/ozzo-routing/cors) | implements the CORS (Cross Origin Resource Sharing) specification from the W3C 368 | [fault.Recovery](https://godoc.org/github.com/go-ozzo/ozzo-routing/fault) | recovers from panics and handles errors returned by handlers 369 | [fault.PanicHandler](https://godoc.org/github.com/go-ozzo/ozzo-routing/fault) | recovers from panics happened in the handlers 370 | [fault.ErrorHandler](https://godoc.org/github.com/go-ozzo/ozzo-routing/fault) | handles errors returned by handlers by writing them in an appropriate format to the response 371 | [file.Server](https://godoc.org/github.com/go-ozzo/ozzo-routing/file) | serves the files under the specified folder as response content 372 | [file.Content](https://godoc.org/github.com/go-ozzo/ozzo-routing/file) | serves the content of the specified file as the response 373 | [slash.Remover](https://godoc.org/github.com/go-ozzo/ozzo-routing/slash) | removes the trailing slashes from the request URL and redirects to the proper URL 374 | 375 | The following code shows how these handlers may be used: 376 | 377 | ```go 378 | import ( 379 | "log" 380 | "net/http" 381 | "github.com/go-ozzo/ozzo-routing/v2" 382 | "github.com/go-ozzo/ozzo-routing/v2/access" 383 | "github.com/go-ozzo/ozzo-routing/v2/slash" 384 | "github.com/go-ozzo/ozzo-routing/v2/fault" 385 | ) 386 | 387 | router := routing.New() 388 | 389 | router.Use( 390 | access.Logger(log.Printf), 391 | slash.Remover(http.StatusMovedPermanently), 392 | fault.Recovery(log.Printf), 393 | ) 394 | 395 | ... 
396 | ``` 397 | 398 | ### Third-party Handlers 399 | 400 | 401 | The following third-party handlers are specifically designed for ozzo-routing: 402 | 403 | Handler name | Description 404 | --------------------------------|-------------------------------------------- 405 | [jwt.JWT](https://github.com/vvv-v13/ozzo-jwt) | supports JWT Authorization 406 | 407 | 408 | ozzo-routing also provides adapters to support using third-party `http.HandlerFunc` or `http.Handler` handlers. 409 | For example, 410 | 411 | ```go 412 | router := routing.New() 413 | 414 | // using http.HandlerFunc 415 | router.Use(routing.HTTPHandlerFunc(http.NotFound)) 416 | 417 | // using http.Handler 418 | router.Use(routing.HTTPHandler(http.NotFoundHandler())) 419 | ``` 420 | 421 | ## 3rd-Party Extensions and Code Examples 422 | 423 | * [Simple Standard Service Endpoints (SE4)](https://github.com/jdamick/ozzo-se4) 424 | * [ozzo examples](https://github.com/marshyski/go-ozzo-examples) 425 | 426 | ## Benchmarks 427 | 428 | *Last updated on Jan 6, 2017* 429 | 430 | Ozzo-routing is very fast, thanks to the radix tree data structure and the usage of `sync.Pool` (the idea was 431 | originally from HttpRouter and Gin). The following table (by running [go-http-routing-benchmark](https://github.com/qiangxue/go-http-routing-benchmark)) 432 | shows how ozzo-routing compares with Echo, Gin, HttpRouter, and Martini in performance.
433 | 434 | ``` 435 | BenchmarkOzzo_GithubAll 50000 37989 ns/op 0 B/op 0 allocs/op 436 | BenchmarkEcho_GithubAll 20000 91003 ns/op 6496 B/op 203 allocs/op 437 | BenchmarkGin_GithubAll 50000 26717 ns/op 0 B/op 0 allocs/op 438 | BenchmarkHttpRouter_GithubAll 50000 36052 ns/op 13792 B/op 167 allocs/op 439 | BenchmarkMartini_GithubAll 300 4162283 ns/op 228216 B/op 2483 allocs/op 440 | 441 | BenchmarkOzzo_GPlusAll 1000000 1732 ns/op 0 B/op 0 allocs/op 442 | BenchmarkEcho_GPlusAll 300000 4523 ns/op 416 B/op 13 allocs/op 443 | BenchmarkGin_GPlusAll 1000000 1171 ns/op 0 B/op 0 allocs/op 444 | BenchmarkHttpRouter_GPlusAll 1000000 1533 ns/op 640 B/op 11 allocs/op 445 | BenchmarkMartini_GPlusAll 20000 75634 ns/op 14448 B/op 165 allocs/op 446 | 447 | BenchmarkOzzo_ParseAll 500000 3318 ns/op 0 B/op 0 allocs/op 448 | BenchmarkEcho_ParseAll 200000 7336 ns/op 832 B/op 26 allocs/op 449 | BenchmarkGin_ParseAll 1000000 2075 ns/op 0 B/op 0 allocs/op 450 | BenchmarkHttpRouter_ParseAll 1000000 2034 ns/op 640 B/op 16 allocs/op 451 | BenchmarkMartini_ParseAll 10000 122002 ns/op 25600 B/op 276 allocs/op 452 | ``` 453 | 454 | ## Credits 455 | 456 | ozzo-routing has referenced many popular routing frameworks, including [Express](http://expressjs.com/), 457 | [Martini](https://github.com/go-martini/martini), [httprouter](https://github.com/julienschmidt/httprouter), and 458 | [gin](https://github.com/gin-gonic/gin). 459 | -------------------------------------------------------------------------------- /access/logger.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 Qiang Xue. All rights reserved. 2 | // Use of this source code is governed by a MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | // Package access provides an access logging handler for the ozzo routing package. 
6 | package access 7 | 8 | import ( 9 | "fmt" 10 | "net/http" 11 | "strings" 12 | "time" 13 | 14 | routing "github.com/go-ozzo/ozzo-routing/v2" 15 | ) 16 | 17 | // LogFunc logs a message using the given format and optional arguments. 18 | // The usage of format and arguments is similar to that for fmt.Printf(). 19 | // LogFunc should be thread safe. 20 | type LogFunc func(format string, a ...interface{}) 21 | 22 | // LogWriterFunc takes in the request and responseWriter objects as well 23 | // as a float64 containing the elapsed time since the request first passed 24 | // through this middleware and does whatever log writing it wants with that 25 | // information. 26 | // LogWriterFunc should be thread safe. 27 | type LogWriterFunc func(req *http.Request, res *LogResponseWriter, elapsed float64) 28 | 29 | // CustomLogger returns a handler that calls the LogWriterFunc passed to it for every request. 30 | // The LogWriterFunc is provided with the http.Request and LogResponseWriter objects for the 31 | // request, as well as the elapsed time since the request first came through the middleware. 32 | // LogWriterFunc can then do whatever logging it needs to do. 
33 | // 34 | // import ( 35 | // "log" 36 | // "github.com/go-ozzo/ozzo-routing/v2" 37 | // "github.com/go-ozzo/ozzo-routing/v2/access" 38 | // "net/http" 39 | // ) 40 | // 41 | // func myCustomLogger(req http.Context, res access.LogResponseWriter, elapsed int64) { 42 | // // Do something with the request, response, and elapsed time data here 43 | // } 44 | // r := routing.New() 45 | // r.Use(access.CustomLogger(myCustomLogger)) 46 | func CustomLogger(loggerFunc LogWriterFunc) routing.Handler { 47 | return func(c *routing.Context) error { 48 | startTime := time.Now() 49 | 50 | req := c.Request 51 | rw := &LogResponseWriter{c.Response, http.StatusOK, 0} 52 | c.Response = rw 53 | 54 | err := c.Next() 55 | 56 | elapsed := float64(time.Now().Sub(startTime).Nanoseconds()) / 1e6 57 | loggerFunc(req, rw, elapsed) 58 | 59 | return err 60 | } 61 | 62 | } 63 | 64 | // Logger returns a handler that logs a message for every request. 65 | // The access log messages contain information including client IPs, time used to serve each request, request line, 66 | // response status and size. 67 | // 68 | // import ( 69 | // "log" 70 | // "github.com/go-ozzo/ozzo-routing/v2" 71 | // "github.com/go-ozzo/ozzo-routing/v2/access" 72 | // ) 73 | // 74 | // r := routing.New() 75 | // r.Use(access.Logger(log.Printf)) 76 | func Logger(log LogFunc) routing.Handler { 77 | var logger = func(req *http.Request, rw *LogResponseWriter, elapsed float64) { 78 | clientIP := GetClientIP(req) 79 | requestLine := fmt.Sprintf("%s %s %s", req.Method, req.URL.String(), req.Proto) 80 | log(`[%s] [%.3fms] %s %d %d`, clientIP, elapsed, requestLine, rw.Status, rw.BytesWritten) 81 | 82 | } 83 | return CustomLogger(logger) 84 | } 85 | 86 | // LogResponseWriter wraps http.ResponseWriter in order to capture HTTP status and response length information. 
87 | type LogResponseWriter struct { 88 | http.ResponseWriter 89 | Status int 90 | BytesWritten int64 91 | } 92 | 93 | func (r *LogResponseWriter) Write(p []byte) (int, error) { 94 | written, err := r.ResponseWriter.Write(p) 95 | r.BytesWritten += int64(written) 96 | return written, err 97 | } 98 | 99 | // WriteHeader records the response status and then writes HTTP headers. 100 | func (r *LogResponseWriter) WriteHeader(status int) { 101 | r.Status = status 102 | r.ResponseWriter.WriteHeader(status) 103 | } 104 | 105 | // GetClientIP returns the client IP address from the given HTTP request. 106 | func GetClientIP(req *http.Request) string { 107 | ip := req.Header.Get("X-Real-IP") 108 | if ip == "" { 109 | ip = req.Header.Get("X-Forwarded-For") 110 | if ip == "" { 111 | ip = req.RemoteAddr 112 | } 113 | } 114 | if colon := strings.LastIndex(ip, ":"); colon != -1 { 115 | ip = ip[:colon] 116 | } 117 | return ip 118 | } 119 | -------------------------------------------------------------------------------- /access/logger_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 Qiang Xue. All rights reserved. 2 | // Use of this source code is governed by a MIT-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package access 6 | 7 | import ( 8 | "bytes" 9 | "errors" 10 | "fmt" 11 | "net/http" 12 | "net/http/httptest" 13 | "testing" 14 | 15 | routing "github.com/go-ozzo/ozzo-routing/v2" 16 | "github.com/stretchr/testify/assert" 17 | ) 18 | 19 | func TestCustomLogger(t *testing.T) { 20 | var buf bytes.Buffer 21 | var customFunc = func(req *http.Request, rw *LogResponseWriter, elapsed float64) { 22 | var logWriter = getLogger(&buf) 23 | clientIP := GetClientIP(req) 24 | requestLine := fmt.Sprintf("%s %s %s", req.Method, req.URL.String(), req.Proto) 25 | logWriter(`[%s] [%.3fms] %s %d %d`, clientIP, elapsed, requestLine, rw.Status, rw.BytesWritten) 26 | } 27 | h := CustomLogger(customFunc) 28 | 29 | res := httptest.NewRecorder() 30 | req, _ := http.NewRequest("GET", "http://127.0.0.1/users", nil) 31 | c := routing.NewContext(res, req, h, handler1) 32 | assert.NotNil(t, c.Next()) 33 | assert.Contains(t, buf.String(), "GET http://127.0.0.1/users") 34 | } 35 | 36 | func TestLogger(t *testing.T) { 37 | var buf bytes.Buffer 38 | h := Logger(getLogger(&buf)) 39 | 40 | res := httptest.NewRecorder() 41 | req, _ := http.NewRequest("GET", "http://127.0.0.1/users", nil) 42 | c := routing.NewContext(res, req, h, handler1) 43 | assert.NotNil(t, c.Next()) 44 | assert.Contains(t, buf.String(), "GET http://127.0.0.1/users") 45 | } 46 | 47 | func TestLogResponseWriter(t *testing.T) { 48 | res := httptest.NewRecorder() 49 | w := &LogResponseWriter{res, 0, 0} 50 | w.WriteHeader(http.StatusBadRequest) 51 | assert.Equal(t, http.StatusBadRequest, res.Code) 52 | assert.Equal(t, http.StatusBadRequest, w.Status) 53 | n, _ := w.Write([]byte("test")) 54 | assert.Equal(t, 4, n) 55 | assert.Equal(t, int64(4), w.BytesWritten) 56 | assert.Equal(t, "test", res.Body.String()) 57 | } 58 | 59 | func TestGetClientIP(t *testing.T) { 60 | req, _ := http.NewRequest("GET", "/users/", nil) 61 | req.Header.Set("X-Real-IP", "192.168.100.1") 62 | req.Header.Set("X-Forwarded-For", "192.168.100.2") 63 | 
req.RemoteAddr = "192.168.100.3" 64 | 65 | assert.Equal(t, "192.168.100.1", GetClientIP(req)) 66 | req.Header.Del("X-Real-IP") 67 | assert.Equal(t, "192.168.100.2", GetClientIP(req)) 68 | req.Header.Del("X-Forwarded-For") 69 | assert.Equal(t, "192.168.100.3", GetClientIP(req)) 70 | 71 | req.RemoteAddr = "192.168.100.3:8080" 72 | assert.Equal(t, "192.168.100.3", GetClientIP(req)) 73 | } 74 | 75 | func getLogger(buf *bytes.Buffer) LogFunc { 76 | return func(format string, a ...interface{}) { 77 | fmt.Fprintf(buf, format, a...) 78 | } 79 | } 80 | 81 | func handler1(c *routing.Context) error { 82 | return errors.New("abc") 83 | } 84 | -------------------------------------------------------------------------------- /auth/handlers.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 Qiang Xue. All rights reserved. 2 | // Use of this source code is governed by a MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | // Package auth provides a set of user authentication handlers for the ozzo routing package. 6 | package auth 7 | 8 | import ( 9 | "encoding/base64" 10 | "net/http" 11 | "strings" 12 | 13 | "github.com/golang-jwt/jwt" 14 | "github.com/go-ozzo/ozzo-routing/v2" 15 | ) 16 | 17 | // User is the key used to store and retrieve the user identity information in routing.Context 18 | const User = "User" 19 | 20 | // Identity represents an authenticated user. If a user is successfully authenticated by 21 | // an auth handler (Basic, Bearer, or Query), an Identity object will be made available for injection. 22 | type Identity interface{} 23 | 24 | // DefaultRealm is the default realm name for HTTP authentication. It is used by HTTP authentication based on 25 | // Basic and Bearer. 26 | var DefaultRealm = "API" 27 | 28 | // BasicAuthFunc is the function that does the actual user authentication according to the given username and password. 
29 | type BasicAuthFunc func(c *routing.Context, username, password string) (Identity, error)
30 |
31 | // Basic returns a routing.Handler that performs HTTP basic authentication.
32 | // It can be used like the following:
33 | //
34 | //	import (
35 | //		"errors"
36 | //		"fmt"
37 | //
38 | //		"github.com/go-ozzo/ozzo-routing/v2"
39 | //		"github.com/go-ozzo/ozzo-routing/v2/auth"
40 | //	)
41 | //	func main() {
42 | //		r := routing.New()
43 | //		r.Use(auth.Basic(func(c *routing.Context, username, password string) (auth.Identity, error) {
44 | //			if username == "demo" && password == "foo" {
45 | //				return auth.Identity(username), nil
46 | //			}
47 | //			return nil, errors.New("invalid credential")
48 | //		}))
49 | //		r.Get("/demo", func(c *routing.Context) error {
50 | //			fmt.Fprintf(c.Response, "Hello, %v", c.Get(auth.User))
51 | //			return nil
52 | //		})
53 | //	}
54 | //
55 | // By default, the auth realm is named as "API". You may customize it by specifying the realm parameter.
56 | //
57 | // When authentication fails, a "WWW-Authenticate" header will be sent, and an http.StatusUnauthorized
58 | // error will be returned.
59 | func Basic(fn BasicAuthFunc, realm ...string) routing.Handler {
60 | 	name := DefaultRealm
61 | 	if len(realm) > 0 {
62 | 		name = realm[0]
63 | 	}
64 | 	return func(c *routing.Context) error {
65 | 		username, password := parseBasicAuth(c.Request.Header.Get("Authorization"))
66 | 		identity, e := fn(c, username, password)
67 | 		if e == nil {
68 | 			c.Set(User, identity)
69 | 			return nil
70 | 		}
71 | 		c.Response.Header().Set("WWW-Authenticate", `Basic realm="`+name+`"`)
72 | 		return routing.NewHTTPError(http.StatusUnauthorized, e.Error())
73 | 	}
74 | }
75 | // parseBasicAuth splits a "Basic <base64(username:password)>" header value at its first colon; it returns empty strings for any other scheme, invalid base64, or a payload without a colon.
76 | func parseBasicAuth(auth string) (username, password string) {
77 | 	if strings.HasPrefix(auth, "Basic ") {
78 | 		if bytes, err := base64.StdEncoding.DecodeString(auth[6:]); err == nil {
79 | 			str := string(bytes)
80 | 			if i := strings.IndexByte(str, ':'); i >= 0 {
81 | 				return str[:i], str[i+1:]
82 | 			}
83 | 		}
84 | 	}
85 | 	return
86 | }
87 |
88 | // TokenAuthFunc is the function for authenticating a user based on a secret token.
89 | type TokenAuthFunc func(c *routing.Context, token string) (Identity, error)
90 |
91 | // Bearer returns a routing.Handler that performs HTTP authentication based on bearer token.
92 | // It can be used like the following:
93 | //
94 | //	import (
95 | //		"errors"
96 | //		"fmt"
97 | //
98 | //		"github.com/go-ozzo/ozzo-routing/v2"
99 | //		"github.com/go-ozzo/ozzo-routing/v2/auth"
100 | //	)
101 | //	func main() {
102 | //		r := routing.New()
103 | //		r.Use(auth.Bearer(func(c *routing.Context, token string) (auth.Identity, error) {
104 | //			if token == "secret" {
105 | //				return auth.Identity("demo"), nil
106 | //			}
107 | //			return nil, errors.New("invalid credential")
108 | //		}))
109 | //		r.Get("/demo", func(c *routing.Context) error {
110 | //			fmt.Fprintf(c.Response, "Hello, %v", c.Get(auth.User))
111 | //			return nil
112 | //		})
113 | //	}
114 | //
115 | // By default, the auth realm is named as "API". You may customize it by specifying the realm parameter.
116 | //
117 | // When authentication fails, a "WWW-Authenticate" header will be sent, and an http.StatusUnauthorized
118 | // error will be returned.
119 | func Bearer(fn TokenAuthFunc, realm ...string) routing.Handler {
120 | 	name := DefaultRealm
121 | 	if len(realm) > 0 {
122 | 		name = realm[0]
123 | 	}
124 | 	return func(c *routing.Context) error {
125 | 		token := parseBearerAuth(c.Request.Header.Get("Authorization"))
126 | 		identity, e := fn(c, token)
127 | 		if e == nil {
128 | 			c.Set(User, identity)
129 | 			return nil
130 | 		}
131 | 		c.Response.Header().Set("WWW-Authenticate", `Bearer realm="`+name+`"`)
132 | 		return routing.NewHTTPError(http.StatusUnauthorized, e.Error())
133 | 	}
134 | }
135 | // parseBearerAuth returns the token of a "Bearer <token>" header value. NOTE: the token is base64-decoded here, so a non-base64 token (or any other scheme) yields "".
136 | func parseBearerAuth(auth string) string {
137 | 	if strings.HasPrefix(auth, "Bearer ") {
138 | 		if bearer, err := base64.StdEncoding.DecodeString(auth[7:]); err == nil {
139 | 			return string(bearer)
140 | 		}
141 | 	}
142 | 	return ""
143 | }
144 |
145 | // TokenName is the query parameter name for auth token.
146 | var TokenName = "access-token"
147 |
148 | // Query returns a routing.Handler that performs authentication based on a token passed via a query parameter.
149 | // It can be used like the following:
150 | //
151 | //	import (
152 | //		"errors"
153 | //		"fmt"
154 | //
155 | //		"github.com/go-ozzo/ozzo-routing/v2"
156 | //		"github.com/go-ozzo/ozzo-routing/v2/auth"
157 | //	)
158 | //	func main() {
159 | //		r := routing.New()
160 | //		r.Use(auth.Query(func(c *routing.Context, token string) (auth.Identity, error) {
161 | //			if token == "secret" {
162 | //				return auth.Identity("demo"), nil
163 | //			}
164 | //			return nil, errors.New("invalid credential")
165 | //		}))
166 | //		r.Get("/demo", func(c *routing.Context) error {
167 | //			fmt.Fprintf(c.Response, "Hello, %v", c.Get(auth.User))
168 | //			return nil
169 | //		})
170 | //	}
171 | //
172 | // When authentication fails, an http.StatusUnauthorized error will be returned.
173 | func Query(fn TokenAuthFunc, tokenName ...string) routing.Handler { 174 | name := TokenName 175 | if len(tokenName) > 0 { 176 | name = tokenName[0] 177 | } 178 | return func(c *routing.Context) error { 179 | token := c.Request.URL.Query().Get(name) 180 | identity, err := fn(c, token) 181 | if err != nil { 182 | return routing.NewHTTPError(http.StatusUnauthorized, err.Error()) 183 | } 184 | c.Set(User, identity) 185 | return nil 186 | } 187 | } 188 | 189 | // JWTTokenHandler represents a handler function that handles the parsed JWT token. 190 | type JWTTokenHandler func(*routing.Context, *jwt.Token) error 191 | 192 | // VerificationKeyHandler represents a handler function that gets a dynamic VerificationKey 193 | type VerificationKeyHandler func(*routing.Context) string 194 | 195 | // JWTOptions represents the options that can be used with the JWT handler. 196 | type JWTOptions struct { 197 | // auth realm. Defaults to "API". 198 | Realm string 199 | // the allowed signing method. This is required and should be the actual method that you use to create JWT token. It defaults to "HS256". 200 | SigningMethod string 201 | // a function that handles the parsed JWT token. Defaults to DefaultJWTTokenHandler, which stores the token in the routing context with the key "JWT". 202 | TokenHandler JWTTokenHandler 203 | // a function to get a dynamic VerificationKey 204 | GetVerificationKey VerificationKeyHandler 205 | } 206 | 207 | // DefaultJWTTokenHandler stores the parsed JWT token in the routing context with the key named "JWT". 208 | func DefaultJWTTokenHandler(c *routing.Context, token *jwt.Token) error { 209 | c.Set("JWT", token) 210 | return nil 211 | } 212 | 213 | // JWT returns a JWT (JSON Web Token) handler which attempts to parse the Bearer header into a JWT token and validate it. 214 | // If both are successful, it will call a JWTTokenHandler to further handle the token. 
By default, the token 215 | // will be stored in the routing context with the key named "JWT". Other handlers can retrieve this token to obtain 216 | // the user identity information. 217 | // If the parsing or validation fails, a "WWW-Authenticate" header will be sent, and an http.StatusUnauthorized 218 | // error will be returned. 219 | // 220 | // JWT can be used like the following: 221 | // 222 | // import ( 223 | // "errors" 224 | // "fmt" 225 | // "net/http" 226 | // "github.com/dgrijalva/jwt-go" 227 | // "github.com/go-ozzo/ozzo-routing/v2" 228 | // "github.com/go-ozzo/ozzo-routing/v2/auth" 229 | // ) 230 | // func main() { 231 | // signingKey := "secret-key" 232 | // r := routing.New() 233 | // 234 | // r.Get("/login", func(c *routing.Context) error { 235 | // id, err := authenticate(c) 236 | // if err != nil { 237 | // return err 238 | // } 239 | // token, err := auth.NewJWT(jwt.MapClaims{ 240 | // "id": id 241 | // }, signingKey) 242 | // if err != nil { 243 | // return err 244 | // } 245 | // return c.Write(token) 246 | // }) 247 | // 248 | // r.Use(auth.JWT(signingKey)) 249 | // r.Get("/restricted", func(c *routing.Context) error { 250 | // claims := c.Get("JWT").(*jwt.Token).Claims.(jwt.MapClaims) 251 | // return c.Write(fmt.Sprint("Welcome, %v!", claims["id"])) 252 | // }) 253 | // } 254 | func JWT(verificationKey string, options ...JWTOptions) routing.Handler { 255 | var opt JWTOptions 256 | if len(options) > 0 { 257 | opt = options[0] 258 | } 259 | if opt.Realm == "" { 260 | opt.Realm = DefaultRealm 261 | } 262 | if opt.SigningMethod == "" { 263 | opt.SigningMethod = "HS256" 264 | } 265 | if opt.TokenHandler == nil { 266 | opt.TokenHandler = DefaultJWTTokenHandler 267 | } 268 | parser := &jwt.Parser{ 269 | ValidMethods: []string{opt.SigningMethod}, 270 | } 271 | return func(c *routing.Context) error { 272 | header := c.Request.Header.Get("Authorization") 273 | message := "" 274 | if opt.GetVerificationKey != nil { 275 | verificationKey = 
opt.GetVerificationKey(c) 276 | } 277 | if strings.HasPrefix(header, "Bearer ") { 278 | token, err := parser.Parse(header[7:], func(t *jwt.Token) (interface{}, error) { return []byte(verificationKey), nil }) 279 | if err == nil && token.Valid { 280 | err = opt.TokenHandler(c, token) 281 | } 282 | if err == nil { 283 | return nil 284 | } 285 | message = err.Error() 286 | } 287 | 288 | c.Response.Header().Set("WWW-Authenticate", `Bearer realm="`+opt.Realm+`"`) 289 | if message != "" { 290 | return routing.NewHTTPError(http.StatusUnauthorized, message) 291 | } 292 | return routing.NewHTTPError(http.StatusUnauthorized) 293 | } 294 | } 295 | 296 | // NewJWT creates a new JWT token and returns it as a signed string that may be sent to the client side. 297 | // The signingMethod parameter is optional. It defaults to the HS256 algorithm. 298 | func NewJWT(claims jwt.MapClaims, signingKey string, signingMethod ...jwt.SigningMethod) (string, error) { 299 | var sm jwt.SigningMethod = jwt.SigningMethodHS256 300 | if len(signingMethod) > 0 { 301 | sm = signingMethod[0] 302 | } 303 | return jwt.NewWithClaims(sm, claims).SignedString([]byte(signingKey)) 304 | } 305 | -------------------------------------------------------------------------------- /auth/handlers_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 Qiang Xue. All rights reserved. 2 | // Use of this source code is governed by a MIT-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package auth 6 | 7 | import ( 8 | "errors" 9 | "net/http" 10 | "net/http/httptest" 11 | "testing" 12 | 13 | "github.com/golang-jwt/jwt" 14 | "github.com/go-ozzo/ozzo-routing/v2" 15 | "github.com/stretchr/testify/assert" 16 | ) 17 | 18 | func TestParseBasicAuth(t *testing.T) { 19 | tests := []struct { 20 | id string 21 | header string 22 | user, pass string 23 | }{ 24 | {"t1", "", "", ""}, 25 | {"t2", "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==", "Aladdin", "open sesame"}, 26 | {"t3", "Basic xyz", "", ""}, 27 | } 28 | for _, test := range tests { 29 | user, pass := parseBasicAuth(test.header) 30 | assert.Equal(t, test.user, user, test.id) 31 | assert.Equal(t, test.pass, pass, test.id) 32 | } 33 | } 34 | 35 | func basicAuth(c *routing.Context, username, password string) (Identity, error) { 36 | if username == "Aladdin" && password == "open sesame" { 37 | return "yes", nil 38 | } 39 | return nil, errors.New("no") 40 | } 41 | 42 | func TestBasic(t *testing.T) { 43 | h := Basic(basicAuth, "App") 44 | res := httptest.NewRecorder() 45 | req, _ := http.NewRequest("GET", "/users/", nil) 46 | c := routing.NewContext(res, req) 47 | err := h(c) 48 | if assert.NotNil(t, err) { 49 | assert.Equal(t, "no", err.Error()) 50 | } 51 | assert.Equal(t, `Basic realm="App"`, res.Header().Get("WWW-Authenticate")) 52 | assert.Nil(t, c.Get(User)) 53 | 54 | req, _ = http.NewRequest("GET", "/users/", nil) 55 | req.Header.Set("Authorization", "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==") 56 | res = httptest.NewRecorder() 57 | c = routing.NewContext(res, req) 58 | err = h(c) 59 | assert.Nil(t, err) 60 | assert.Equal(t, "", res.Header().Get("WWW-Authenticate")) 61 | assert.Equal(t, "yes", c.Get(User)) 62 | } 63 | 64 | func TestParseBearerToken(t *testing.T) { 65 | tests := []struct { 66 | id string 67 | header string 68 | token string 69 | }{ 70 | {"t1", "", ""}, 71 | {"t2", "Bearer QWxhZGRpbjpvcGVuIHNlc2FtZQ==", "Aladdin:open sesame"}, 72 | {"t3", "Bearer xyz", ""}, 73 | } 74 | for _, test := range 
tests { 75 | token := parseBearerAuth(test.header) 76 | assert.Equal(t, test.token, token, test.id) 77 | } 78 | } 79 | 80 | func bearerAuth(c *routing.Context, token string) (Identity, error) { 81 | if token == "Aladdin:open sesame" { 82 | return "yes", nil 83 | } 84 | return nil, errors.New("no") 85 | } 86 | 87 | func TestBearer(t *testing.T) { 88 | h := Bearer(bearerAuth, "App") 89 | res := httptest.NewRecorder() 90 | req, _ := http.NewRequest("GET", "/users/", nil) 91 | c := routing.NewContext(res, req) 92 | err := h(c) 93 | if assert.NotNil(t, err) { 94 | assert.Equal(t, "no", err.Error()) 95 | } 96 | assert.Equal(t, `Bearer realm="App"`, res.Header().Get("WWW-Authenticate")) 97 | assert.Nil(t, c.Get(User)) 98 | 99 | req, _ = http.NewRequest("GET", "/users/", nil) 100 | req.Header.Set("Authorization", "Bearer QWxhZGRpbjpvcGVuIHNlc2FtZQ==") 101 | res = httptest.NewRecorder() 102 | c = routing.NewContext(res, req) 103 | err = h(c) 104 | assert.Nil(t, err) 105 | assert.Equal(t, "", res.Header().Get("WWW-Authenticate")) 106 | assert.Equal(t, "yes", c.Get(User)) 107 | 108 | req, _ = http.NewRequest("GET", "/users/", nil) 109 | req.Header.Set("Authorization", "Bearer QW") 110 | res = httptest.NewRecorder() 111 | c = routing.NewContext(res, req) 112 | err = h(c) 113 | if assert.NotNil(t, err) { 114 | assert.Equal(t, "no", err.Error()) 115 | } 116 | assert.Equal(t, `Bearer realm="App"`, res.Header().Get("WWW-Authenticate")) 117 | assert.Nil(t, c.Get(User)) 118 | } 119 | 120 | func TestQuery(t *testing.T) { 121 | h := Query(bearerAuth, "token") 122 | res := httptest.NewRecorder() 123 | req, _ := http.NewRequest("GET", "/users", nil) 124 | c := routing.NewContext(res, req) 125 | err := h(c) 126 | if assert.NotNil(t, err) { 127 | assert.Equal(t, "no", err.Error()) 128 | } 129 | assert.Nil(t, c.Get(User)) 130 | 131 | req, _ = http.NewRequest("GET", "/users?token=Aladdin:open sesame", nil) 132 | res = httptest.NewRecorder() 133 | c = routing.NewContext(res, req) 134 | err = 
h(c) 135 | assert.Nil(t, err) 136 | assert.Equal(t, "", res.Header().Get("WWW-Authenticate")) 137 | assert.Equal(t, "yes", c.Get(User)) 138 | } 139 | 140 | func TestJWT(t *testing.T) { 141 | secret := "secret-key" 142 | { 143 | // valid token 144 | tokenString, err := NewJWT(jwt.MapClaims{ 145 | "id": "100", 146 | }, secret) 147 | assert.Nil(t, err) 148 | 149 | h := JWT(secret) 150 | res := httptest.NewRecorder() 151 | req, _ := http.NewRequest("GET", "/users/", nil) 152 | req.Header.Set("Authorization", "Bearer "+tokenString) 153 | c := routing.NewContext(res, req) 154 | err = h(c) 155 | assert.Nil(t, err) 156 | token := c.Get("JWT") 157 | if assert.NotNil(t, token) { 158 | assert.Equal(t, "100", token.(*jwt.Token).Claims.(jwt.MapClaims)["id"]) 159 | } 160 | } 161 | 162 | { 163 | // invalid signing method 164 | token := jwt.New(jwt.SigningMethodHS256) 165 | claims := token.Claims.(jwt.MapClaims) 166 | claims["name"] = "Qiang" 167 | claims["admin"] = true 168 | bearer, _ := token.SignedString([]byte("secret")) 169 | 170 | h := JWT("secret", JWTOptions{ 171 | SigningMethod: "HS512", 172 | }) 173 | res := httptest.NewRecorder() 174 | req, _ := http.NewRequest("GET", "/users/", nil) 175 | req.Header.Set("Authorization", "Bearer "+bearer) 176 | c := routing.NewContext(res, req) 177 | err := h(c) 178 | assert.NotNil(t, err) 179 | } 180 | 181 | { 182 | // invalid token 183 | h := JWT("secret") 184 | res := httptest.NewRecorder() 185 | req, _ := http.NewRequest("GET", "/users/", nil) 186 | req.Header.Set("Authorization", "Bearer QWxhZGRpbjpvcGVuIHNlc2FtZQ==") 187 | c := routing.NewContext(res, req) 188 | err := h(c) 189 | assert.NotNil(t, err) 190 | assert.Equal(t, `Bearer realm="API"`, res.Header().Get("WWW-Authenticate")) 191 | assert.Nil(t, c.Get("JWT")) 192 | } 193 | 194 | { 195 | // invalid token with options 196 | h := JWT("secret", JWTOptions{ 197 | Realm: "App", 198 | }) 199 | res := httptest.NewRecorder() 200 | req, _ := http.NewRequest("GET", "/users/", nil) 201 
| req.Header.Set("Authorization", "Bearer QWxhZGRpbjpvcGVuIHNlc2FtZQ==") 202 | c := routing.NewContext(res, req) 203 | err := h(c) 204 | assert.NotNil(t, err) 205 | assert.Equal(t, `Bearer realm="App"`, res.Header().Get("WWW-Authenticate")) 206 | assert.Nil(t, c.Get("JWT")) 207 | } 208 | } 209 | -------------------------------------------------------------------------------- /content/language.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 Qiang Xue. All rights reserved. 2 | // Use of this source code is governed by a MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package content 6 | 7 | import ( 8 | "net/http" 9 | 10 | "github.com/go-ozzo/ozzo-routing/v2" 11 | "github.com/golang/gddo/httputil/header" 12 | ) 13 | 14 | // Language is the key used to store and retrieve the chosen language in routing.Context 15 | const Language = "Language" 16 | 17 | // LanguageNegotiator returns a content language negotiation handler. 18 | // 19 | // The method takes a list of languages (locale IDs) that are supported by the application. 20 | // The negotiator will determine the best language to use by checking the Accept-Language request header. 21 | // If no match is found, the first language will be used. 22 | // 23 | // In a handler, you can access the chosen language through routing.Context like the following: 24 | // 25 | // func(c *routing.Context) error { 26 | // language := c.Get(content.Language).(string) 27 | // } 28 | // 29 | // If you do not specify languages, the negotiator will set the language to be "en-US". 
30 | func LanguageNegotiator(languages ...string) routing.Handler { 31 | if len(languages) == 0 { 32 | languages = []string{"en-US"} 33 | } 34 | defaultLanguage := languages[0] 35 | 36 | return func(c *routing.Context) error { 37 | language := negotiateLanguage(c.Request, languages, defaultLanguage) 38 | c.Set(Language, language) 39 | return nil 40 | } 41 | } 42 | 43 | // negotiateLanguage negotiates the acceptable language according to the Accept-Language HTTP header. 44 | func negotiateLanguage(r *http.Request, offers []string, defaultOffer string) string { 45 | bestOffer := defaultOffer 46 | bestQ := -1.0 47 | specs := header.ParseAccept(r.Header, "Accept-Language") 48 | for _, offer := range offers { 49 | for _, spec := range specs { 50 | if spec.Q > bestQ && (spec.Value == "*" || spec.Value == offer) { 51 | bestQ = spec.Q 52 | bestOffer = offer 53 | } 54 | } 55 | } 56 | if bestQ == 0 { 57 | bestOffer = defaultOffer 58 | } 59 | return bestOffer 60 | } 61 | -------------------------------------------------------------------------------- /content/language_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 Qiang Xue. All rights reserved. 2 | // Use of this source code is governed by a MIT-style 3 | // license that can be found in the LICENSE file. 
4 |
5 | package content
6 |
7 | import (
8 | 	"net/http"
9 | 	"net/http/httptest"
10 | 	"testing"
11 |
12 | 	routing "github.com/go-ozzo/ozzo-routing/v2"
13 | 	"github.com/stretchr/testify/assert"
14 | )
15 |
16 | func TestLanguageNegotiator(t *testing.T) {
17 | 	req, _ := http.NewRequest("GET", "/users/", nil)
18 | 	req.Header.Set("Accept-Language", "ru-RU;q=0.6,ru;q=0.5,zh-CN;q=1.0,zh;q=0.9")
19 |
20 | 	// test no arguments
21 | 	res := httptest.NewRecorder()
22 | 	c := routing.NewContext(res, req)
23 | 	h := LanguageNegotiator()
24 | 	assert.Nil(t, h(c))
25 | 	assert.Equal(t, "en-US", c.Get(Language))
26 |
27 | 	h = LanguageNegotiator("ru-RU", "ru", "zh", "zh-CN")
28 | 	assert.Nil(t, h(c))
29 | 	assert.Equal(t, "zh-CN", c.Get(Language))
30 |
31 | 	h = LanguageNegotiator("en", "en-US")
32 | 	assert.Nil(t, h(c))
33 | 	assert.Equal(t, "en", c.Get(Language))
34 |
35 | 	req.Header.Set("Accept-Language", "ru-RU;q=0")
36 | 	h = LanguageNegotiator("en", "ru-RU")
37 | 	assert.Nil(t, h(c))
38 | 	assert.Equal(t, "en", c.Get(Language))
39 | }
40 |
--------------------------------------------------------------------------------
/content/negotiator.go:
--------------------------------------------------------------------------------
1 | // Use of this source code is governed by a MIT-style
2 | // license that can be found in the LICENSE file.
3 |
4 | package content
5 |
6 | import (
7 | 	"net/http"
8 | 	"strconv"
9 | 	"strings"
10 | )
11 |
12 | // AcceptRange represents an accept range as defined in https://tools.ietf.org/html/rfc7231#section-5.3.2
13 | //
14 | //	Accept = #( media-range [ accept-params ] )
15 | //	media-range    = ( "*/*"
16 | //	                 / ( type "/" "*" )
17 | //	                 / ( type "/" subtype )
18 | //	               ) *( OWS ";" OWS parameter )
19 | //	accept-params  = weight *( accept-ext )
20 | //	accept-ext = OWS ";" OWS token [ "=" ( token / quoted-string ) ]
21 | type AcceptRange struct {
22 | 	Type       string            // media type, e.g. "text"; may be "*"
23 | 	Subtype    string            // media subtype, e.g. "html"; may be "*"
24 | 	Weight     float64           // value of the "q" parameter; 1 when absent or unparsable
25 | 	Parameters map[string]string // all parameters, including "q"
26 | 	raw        string            // the raw string for this accept
27 | }
28 |
29 | // RawString returns the raw string in the request specifying the accept range.
30 | func (a AcceptRange) RawString() string {
31 | 	return a.raw
32 | }
33 |
34 | // AcceptMediaTypes builds a list of AcceptRange from the given HTTP request.
35 | func AcceptMediaTypes(r *http.Request) []AcceptRange {
36 | 	result := []AcceptRange{}
37 |
38 | 	for _, v := range r.Header["Accept"] {
39 | 		result = append(result, ParseAcceptRanges(v)...)
40 | 	}
41 |
42 | 	return result
43 | }
44 |
45 | // ParseAcceptRanges parses an Accept header into a list of AcceptRange.
46 | func ParseAcceptRanges(accepts string) []AcceptRange {
47 | 	result := []AcceptRange{}
48 | 	remaining := accepts
49 | 	for {
50 | 		var accept string
51 | 		accept, remaining = extractFieldAndSkipToken(remaining, ',')
52 | 		result = append(result, ParseAcceptRange(accept))
53 | 		if len(remaining) == 0 {
54 | 			break
55 | 		}
56 | 	}
57 | 	return result
58 | }
59 |
60 | // ParseAcceptRange parses a single accept string into an AcceptRange.
61 | func ParseAcceptRange(accept string) AcceptRange {
62 | 	typeAndSub, rawparams := extractFieldAndSkipToken(accept, ';')
63 |
64 | 	tp, subtp := extractFieldAndSkipToken(typeAndSub, '/')
65 | 	params := extractParams(rawparams)
66 |
67 | 	w := extractWeight(params)
68 | 	return AcceptRange{Type: tp, Subtype: subtp, Parameters: params, Weight: w, raw: accept}
69 | }
70 |
71 | func extractWeight(params map[string]string) float64 {
72 | 	if w, ok := params["q"]; ok {
73 | 		res, err := strconv.ParseFloat(w, 64)
74 | 		if err == nil {
75 | 			return res
76 | 		}
77 | 	}
78 | 	return 1 // default is 1
79 | }
80 |
81 | func extractParams(raw string) map[string]string {
82 | 	params := map[string]string{}
83 | 	rest := raw
84 | 	for {
85 | 		var p string
86 | 		p, rest = extractFieldAndSkipToken(rest, ';')
87 | 		if len(p) > 0 {
88 | 			k, v := extractFieldAndSkipToken(p, '=')
89 | 			params[k] = v
90 | 		}
91 | 		if len(rest) == 0 {
92 | 			break
93 | 		}
94 | 	}
95 |
96 | 	return params
97 | }
98 |
99 | func extractFieldAndSkipToken(s string, sep rune) (string, string) {
100 | 	f, r := extractField(s, sep)
101 | 	if len(r) > 0 {
102 | 		r = r[1:]
103 | 	}
104 | 	return f, r
105 | }
106 |
107 | func extractField(s string, sep rune) (field, rest string) {
108 | 	i := strings.IndexRune(s, sep)
109 | 	if i < 0 {
110 | 		// No separator: the whole input is the field. Trim it as well, so callers
111 | 		// never see stray whitespace depending on whether a separator was present.
112 | 		return strings.TrimSpace(s), ""
113 | 	}
114 | 	// rest deliberately keeps the separator as its first byte;
115 | 	// extractFieldAndSkipToken is responsible for dropping it.
116 | 	return strings.TrimSpace(s[:i]), strings.TrimSpace(s[i:])
117 | }
118 |
119 | func compareParams(params1 map[string]string, params2 map[string]string) (count int) {
120 | 	for k1, v1 := range params1 {
121 | 		if v2, ok := params2[k1]; ok && v1 == v2 {
122 | 			count++
123 | 		}
124 | 	}
125 | 	return count
126 | }
127 |
128 | // NegotiateContentType negotiates the content types based on the given request and allowed types.
129 | func NegotiateContentType(r *http.Request, offers []string, defaultOffer string) string { 130 | accepts := AcceptMediaTypes(r) 131 | offerRanges := []AcceptRange{} 132 | for _, off := range offers { 133 | offerRanges = append(offerRanges, ParseAcceptRange(off)) 134 | } 135 | 136 | return negotiateContentType(accepts, offerRanges, ParseAcceptRange(defaultOffer)) 137 | } 138 | 139 | func negotiateContentType(accepts []AcceptRange, offers []AcceptRange, defaultOffer AcceptRange) string { 140 | best := defaultOffer.RawString() 141 | bestWeight := defaultOffer.Weight 142 | bestParams := 0 143 | 144 | for _, offer := range offers { 145 | for _, accept := range accepts { 146 | // add a booster on the weights to prefer more exact matches to wildcards 147 | // such that: */* = 0, x/* = 1, x/x = 2 148 | booster := float64(0) 149 | if accept.Type != "*" { 150 | booster++ 151 | if accept.Subtype != "*" { 152 | booster++ 153 | } 154 | } 155 | 156 | if bestWeight > (accept.Weight + booster) { 157 | continue // we already have something better.. 158 | } else if accept.Type == "*" && accept.Subtype == "*" && ((accept.Weight + booster) > bestWeight) { 159 | best = offer.RawString() 160 | bestWeight = accept.Weight + booster 161 | } else if accept.Subtype == "*" && offer.Type == accept.Type && ((accept.Weight + booster) > bestWeight) { 162 | best = offer.RawString() 163 | bestWeight = accept.Weight + booster 164 | } else if accept.Type == offer.Type && accept.Subtype == offer.Subtype { 165 | paramCount := compareParams(accept.Parameters, offer.Parameters) 166 | if paramCount >= bestParams { // if it's equal this one must be better, since the weight was better.. 
167 | best = offer.RawString() 168 | bestWeight = accept.Weight + booster 169 | bestParams = paramCount 170 | } 171 | } 172 | } 173 | } 174 | 175 | return best 176 | } 177 | -------------------------------------------------------------------------------- /content/negotiator_test.go: -------------------------------------------------------------------------------- 1 | // Use of this source code is governed by a MIT-style 2 | // license that can be found in the LICENSE file. 3 | 4 | package content 5 | 6 | import ( 7 | "net/http" 8 | "testing" 9 | 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestContentNegotiation(t *testing.T) { 14 | header := http.Header{} 15 | header.Set("Accept", "application/json;q=1;v=1") 16 | req := &http.Request{Header: header} 17 | 18 | offers := []string{"application/json", "application/xml", "application/json;v=1", "application/json;v=2"} 19 | format := NegotiateContentType(req, offers, "text/html") 20 | assert.Equal(t, "application/json;v=1", format) 21 | } 22 | 23 | func TestContentNegotiation2(t *testing.T) { 24 | header := http.Header{} 25 | header.Set("Accept", "application/json;q=0.6;v=1,application/json;v=2") 26 | req := &http.Request{Header: header} 27 | 28 | offers := []string{"application/json", "application/xml", "application/json;v=1", "application/json;v=2"} 29 | format := NegotiateContentType(req, offers, "text/html") 30 | assert.Equal(t, "application/json;v=2", format) 31 | } 32 | 33 | func TestContentNegotiation3(t *testing.T) { 34 | header := http.Header{} 35 | header.Set("Accept", "*/*,application/xml") 36 | req := &http.Request{Header: header} 37 | 38 | offers := []string{"application/json", "application/xml", "application/json;v=1", "application/json;v=2"} 39 | format := NegotiateContentType(req, offers, "text/html") 40 | assert.Equal(t, "application/xml", format) 41 | } 42 | 43 | func TestContentNegotiation4(t *testing.T) { 44 | header := http.Header{} 45 | header.Set("Accept", "*/*") 46 | req := 
&http.Request{Header: header}
47 |
48 | 	offers := []string{"application/json", "application/xml"}
49 | 	format := NegotiateContentType(req, offers, "application/json")
50 | 	assert.Equal(t, "application/json", format)
51 | }
52 |
53 | func TestContentNegotiation5(t *testing.T) {
54 | 	header := http.Header{}
55 | 	header.Set("Accept", "*/*")
56 | 	req := &http.Request{Header: header}
57 |
58 | 	offers := []string{"application/json", "application/xml", "application/json;v=1", "application/json;v=2"}
59 | 	format := NegotiateContentType(req, offers, "text/html")
60 | 	assert.Equal(t, "text/html", format)
61 | }
62 | func TestAccept(t *testing.T) {
63 | 	header := http.Header{}
64 | 	header.Set("Accept", "application/json; q=1 ; v=1,")
65 | 	req := &http.Request{Header: header}
66 | 	mtypes := AcceptMediaTypes(req)
67 |
68 | 	assert.Equal(t, float64(1), mtypes[0].Weight)
69 | 	assert.Equal(t, "application", mtypes[0].Type)
70 | 	assert.Equal(t, "json", mtypes[0].Subtype)
71 | 	assert.Equal(t, map[string]string{"v": "1", "q": "1"}, mtypes[0].Parameters)
72 | }
73 |
74 | func TestAcceptMultiple(t *testing.T) {
75 | 	header := http.Header{}
76 | 	header.Set("Accept", "application/json;q=1;v=1, application/json;v=2, text/html")
77 | 	req := &http.Request{Header: header}
78 |
79 | 	mtypes := AcceptMediaTypes(req)
80 |
81 | 	assert.Equal(t, float64(1), mtypes[0].Weight)
82 | 	assert.Equal(t, "application", mtypes[0].Type)
83 | 	assert.Equal(t, "json", mtypes[0].Subtype)
84 | 	assert.Equal(t, map[string]string{"v": "1", "q": "1"}, mtypes[0].Parameters)
85 |
86 | 	assert.Equal(t, float64(1), mtypes[1].Weight)
87 | 	assert.Equal(t, "application", mtypes[1].Type)
88 | 	assert.Equal(t, "json", mtypes[1].Subtype)
89 | 	assert.Equal(t, map[string]string{"v": "2"}, mtypes[1].Parameters)
90 |
91 | 	assert.Equal(t, float64(1), mtypes[2].Weight)
92 | 	assert.Equal(t, "text", mtypes[2].Type)
93 | 	assert.Equal(t, "html", mtypes[2].Subtype)
94 | 	assert.Equal(t, map[string]string{}, mtypes[2].Parameters)
95 | }
96 |
97 | func TestAcceptElaborate(t *testing.T) {
98 | 	a := `text/plain; q=0.5, text/html,
99 | text/x-dvi; q=0.8, text/x-c`
100 |
101 | 	header := http.Header{}
102 | 	header.Set("Accept", a)
103 | 	req := &http.Request{Header: header}
104 | 	mtypes := AcceptMediaTypes(req)
105 |
106 | 	assert.Equal(t, float64(0.5), mtypes[0].Weight)
107 | 	assert.Equal(t, "text", mtypes[0].Type)
108 | 	assert.Equal(t, "plain", mtypes[0].Subtype)
109 |
110 | 	assert.Equal(t, float64(1), mtypes[1].Weight)
111 | 	assert.Equal(t, "text", mtypes[1].Type)
112 | 	assert.Equal(t, "html", mtypes[1].Subtype)
113 |
114 | 	assert.Equal(t, float64(0.8), mtypes[2].Weight)
115 | 	assert.Equal(t, "text", mtypes[2].Type)
116 | 	assert.Equal(t, "x-dvi", mtypes[2].Subtype)
117 |
118 | 	assert.Equal(t, float64(1), mtypes[3].Weight)
119 | 	assert.Equal(t, "text", mtypes[3].Type)
120 | 	assert.Equal(t, "x-c", mtypes[3].Subtype)
121 | }
122 |
--------------------------------------------------------------------------------
/content/type.go:
--------------------------------------------------------------------------------
1 | // Copyright 2016 Qiang Xue. All rights reserved.
2 | // Use of this source code is governed by a MIT-style
3 | // license that can be found in the LICENSE file.
4 |
5 | // Package content provides content negotiation handlers for the ozzo routing package.
6 | package content
7 |
8 | import (
9 | 	"encoding/json"
10 | 	"encoding/xml"
11 | 	"net/http"
12 |
13 | 	routing "github.com/go-ozzo/ozzo-routing/v2"
14 | )
15 |
16 | // MIME types
17 | const (
18 | 	JSON = routing.MIME_JSON
19 | 	XML  = routing.MIME_XML
20 | 	XML2 = routing.MIME_XML2
21 | 	HTML = routing.MIME_HTML
22 | )
23 |
24 | // DataWriters lists all supported content types and the corresponding data writers.
25 | // By default, JSON, XML, and HTML are supported. You may modify this variable before calling TypeNegotiator
26 | // to customize supported data writers.
27 | var DataWriters = map[string]routing.DataWriter{
28 | 	JSON: &JSONDataWriter{},
29 | 	XML:  &XMLDataWriter{},
30 | 	XML2: &XMLDataWriter{},
31 | 	HTML: &HTMLDataWriter{},
32 | }
33 |
34 | // TypeNegotiator returns a content type negotiation handler.
35 | //
36 | // The method takes a list of response MIME types that are supported by the application.
37 | // The negotiator will determine the best response MIME type to use by checking the "Accept" HTTP header.
38 | // If no match is found, the first MIME type will be used.
39 | //
40 | // The negotiator will set the "Content-Type" response header as the chosen MIME type. It will call routing.Context.SetDataWriter()
41 | // to set the appropriate data writer that can write data in the negotiated format.
42 | //
43 | // If you do not specify any supported MIME types, the negotiator will use "text/html" as the response MIME type.
44 | func TypeNegotiator(formats ...string) routing.Handler {
45 | 	if len(formats) == 0 {
46 | 		formats = []string{HTML}
47 | 	}
48 | 	for _, format := range formats {
49 | 		if _, ok := DataWriters[format]; !ok {
50 | 			panic(format + " is not supported")
51 | 		}
52 | 	}
53 |
54 | 	return func(c *routing.Context) error {
55 | 		format := NegotiateContentType(c.Request, formats, formats[0])
56 | 		c.SetDataWriter(DataWriters[format])
57 | 		return nil
58 | 	}
59 | }
60 |
61 | // JSONDataWriter sets the "Content-Type" response header as "application/json" and writes the given data in JSON format to the response.
62 | type JSONDataWriter struct{}
63 |
64 | // SetHeader sets the Content-Type response header.
65 | func (w *JSONDataWriter) SetHeader(res http.ResponseWriter) {
66 | 	res.Header().Set("Content-Type", "application/json")
67 | }
68 |
69 | // Write encodes data as JSON directly into the response. HTML-sensitive characters such as
70 | // <, > and & are written unescaped, and the encoded value is followed by a newline.
71 | func (w *JSONDataWriter) Write(res http.ResponseWriter, data interface{}) (err error) {
72 | 	enc := json.NewEncoder(res)
73 | 	enc.SetEscapeHTML(false)
74 | 	return enc.Encode(data)
75 | }
76 |
77 | // XMLDataWriter sets the "Content-Type" response header as "application/xml; charset=UTF-8" and writes the given data in XML format to the response.
78 | type XMLDataWriter struct{}
79 |
80 | // SetHeader sets the Content-Type response header.
81 | func (w *XMLDataWriter) SetHeader(res http.ResponseWriter) {
82 | 	res.Header().Set("Content-Type", "application/xml; charset=UTF-8")
83 | }
84 |
85 | // Write marshals data as XML and writes the result to the response. If marshaling fails,
86 | // nothing is written and the marshaling error is returned.
87 | func (w *XMLDataWriter) Write(res http.ResponseWriter, data interface{}) (err error) {
88 | 	var bytes []byte
89 | 	if bytes, err = xml.Marshal(data); err != nil {
90 | 		return
91 | 	}
92 | 	_, err = res.Write(bytes)
93 | 	return
94 | }
95 |
96 | // HTMLDataWriter sets the "Content-Type" response header as "text/html; charset=UTF-8" and calls routing.DefaultDataWriter to write the given data to the response.
97 | type HTMLDataWriter struct{}
98 |
99 | // SetHeader sets the Content-Type response header.
100 | func (w *HTMLDataWriter) SetHeader(res http.ResponseWriter) {
101 | 	res.Header().Set("Content-Type", "text/html; charset=UTF-8")
102 | }
103 |
104 | // Write delegates to routing.DefaultDataWriter to write data to the response.
105 | func (w *HTMLDataWriter) Write(res http.ResponseWriter, data interface{}) error {
106 | 	return routing.DefaultDataWriter.Write(res, data)
107 | }
108 |
--------------------------------------------------------------------------------
/content/type_test.go:
--------------------------------------------------------------------------------
1 | // Copyright 2016 Qiang Xue. All rights reserved.
2 | // Use of this source code is governed by a MIT-style
3 | // license that can be found in the LICENSE file.
4 | 5 | package content 6 | 7 | import ( 8 | "net/http" 9 | "net/http/httptest" 10 | "testing" 11 | 12 | "github.com/go-ozzo/ozzo-routing/v2" 13 | "github.com/stretchr/testify/assert" 14 | ) 15 | 16 | func TestJSONFormatter(t *testing.T) { 17 | res := httptest.NewRecorder() 18 | w := &JSONDataWriter{} 19 | w.SetHeader(res) 20 | err := w.Write(res, "xyz") 21 | assert.Nil(t, err) 22 | assert.Equal(t, "application/json", res.Header().Get("Content-Type")) 23 | assert.Equal(t, "\"xyz\"\n", res.Body.String()) 24 | } 25 | 26 | func TestXMLFormatter(t *testing.T) { 27 | res := httptest.NewRecorder() 28 | w := &XMLDataWriter{} 29 | w.SetHeader(res) 30 | err := w.Write(res, "xyz") 31 | assert.Nil(t, err) 32 | assert.Equal(t, "application/xml; charset=UTF-8", res.Header().Get("Content-Type")) 33 | assert.Equal(t, "xyz", res.Body.String()) 34 | } 35 | 36 | func TestHTMLFormatter(t *testing.T) { 37 | res := httptest.NewRecorder() 38 | w := &HTMLDataWriter{} 39 | w.SetHeader(res) 40 | err := w.Write(res, "xyz") 41 | assert.Nil(t, err) 42 | assert.Equal(t, "text/html; charset=UTF-8", res.Header().Get("Content-Type")) 43 | assert.Equal(t, "xyz", res.Body.String()) 44 | } 45 | 46 | func TestTypeNegotiator(t *testing.T) { 47 | req, _ := http.NewRequest("GET", "/users/", nil) 48 | req.Header.Set("Accept", "application/xml") 49 | 50 | // test no arguments 51 | res := httptest.NewRecorder() 52 | c := routing.NewContext(res, req) 53 | h := TypeNegotiator() 54 | assert.Nil(t, h(c)) 55 | c.Write("xyz") 56 | assert.Equal(t, "text/html; charset=UTF-8", res.Header().Get("Content-Type")) 57 | assert.Equal(t, "xyz", res.Body.String()) 58 | 59 | // test format chosen based on Accept 60 | res = httptest.NewRecorder() 61 | c = routing.NewContext(res, req) 62 | h = TypeNegotiator(JSON, XML) 63 | assert.Nil(t, h(c)) 64 | assert.Nil(t, c.Write("xyz")) 65 | assert.Equal(t, "application/xml; charset=UTF-8", res.Header().Get("Content-Type")) 66 | assert.Equal(t, "xyz", res.Body.String()) 67 | 68 | // 
test default format used when no match 69 | req.Header.Set("Accept", "application/pdf") 70 | res = httptest.NewRecorder() 71 | c = routing.NewContext(res, req) 72 | assert.Nil(t, h(c)) 73 | assert.Nil(t, c.Write("xyz")) 74 | assert.Equal(t, "application/json", res.Header().Get("Content-Type")) 75 | assert.Equal(t, "\"xyz\"\n", res.Body.String()) 76 | 77 | assert.Panics(t, func() { 78 | TypeNegotiator("unknown") 79 | }) 80 | } 81 | 82 | var ( 83 | v1JSON = "application/json;v=1" 84 | v2JSON = "application/json;v=2" 85 | ) 86 | 87 | type JSONDataWriter1 struct { 88 | JSONDataWriter 89 | } 90 | 91 | func (w *JSONDataWriter1) SetHeader(res http.ResponseWriter) { 92 | res.Header().Set("Content-Type", v1JSON) 93 | } 94 | 95 | type JSONDataWriter2 struct { 96 | JSONDataWriter 97 | } 98 | 99 | func (w *JSONDataWriter2) SetHeader(res http.ResponseWriter) { 100 | res.Header().Set("Content-Type", v2JSON) 101 | } 102 | 103 | func TestTypeNegotiatorWithVersion(t *testing.T) { 104 | 105 | req, _ := http.NewRequest("GET", "/users/", nil) 106 | req.Header.Set("Accept", "application/xml,"+v1JSON) 107 | 108 | // test no arguments 109 | res := httptest.NewRecorder() 110 | c := routing.NewContext(res, req) 111 | h := TypeNegotiator() 112 | assert.Nil(t, h(c)) 113 | c.Write("xyz") 114 | assert.Equal(t, "text/html; charset=UTF-8", res.Header().Get("Content-Type")) 115 | assert.Equal(t, "xyz", res.Body.String()) 116 | 117 | DataWriters[v1JSON] = &JSONDataWriter1{} 118 | DataWriters[v2JSON] = &JSONDataWriter2{} 119 | 120 | // test format chosen based on Accept 121 | res = httptest.NewRecorder() 122 | c = routing.NewContext(res, req) 123 | h = TypeNegotiator(v2JSON, v1JSON, XML) 124 | assert.Nil(t, h(c)) 125 | assert.Nil(t, c.Write("xyz")) 126 | assert.Equal(t, "application/json;v=1", res.Header().Get("Content-Type")) 127 | assert.Equal(t, `"xyz"`+"\n", res.Body.String()) 128 | 129 | // test default format used when no match 130 | req.Header.Set("Accept", "application/pdf") 131 | res = 
httptest.NewRecorder() 132 | c = routing.NewContext(res, req) 133 | assert.Nil(t, h(c)) 134 | assert.Nil(t, c.Write("xyz")) 135 | assert.Equal(t, v2JSON, res.Header().Get("Content-Type")) 136 | assert.Equal(t, "\"xyz\"\n", res.Body.String()) 137 | 138 | assert.Panics(t, func() { 139 | TypeNegotiator("unknown") 140 | }) 141 | } 142 | -------------------------------------------------------------------------------- /context.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 Qiang Xue. All rights reserved. 2 | // Use of this source code is governed by a MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package routing 6 | 7 | import ( 8 | "net/http" 9 | ) 10 | 11 | // Context represents the contextual data and environment while processing an incoming HTTP request. 12 | type Context struct { 13 | Request *http.Request // the current request 14 | Response http.ResponseWriter // the response writer 15 | router *Router 16 | pnames []string // list of route parameter names 17 | pvalues []string // list of parameter values corresponding to pnames 18 | data map[string]interface{} // data items managed by Get and Set 19 | index int // the index of the currently executing handler in handlers 20 | handlers []Handler // the handlers associated with the current route 21 | writer DataWriter 22 | } 23 | 24 | // NewContext creates a new Context object with the given response, request, and the handlers. 25 | // This method is primarily provided for writing unit tests for handlers. 26 | func NewContext(res http.ResponseWriter, req *http.Request, handlers ...Handler) *Context { 27 | c := &Context{handlers: handlers} 28 | c.init(res, req) 29 | return c 30 | } 31 | 32 | // Router returns the Router that is handling the incoming HTTP request. 
33 | func (c *Context) Router() *Router { 34 | return c.router 35 | } 36 | 37 | // Param returns the named parameter value that is found in the URL path matching the current route. 38 | // If the named parameter cannot be found, an empty string will be returned. 39 | func (c *Context) Param(name string) string { 40 | for i, n := range c.pnames { 41 | if n == name { 42 | return c.pvalues[i] 43 | } 44 | } 45 | return "" 46 | } 47 | 48 | // SetParam sets the named parameter value. 49 | // This method is primarily provided for writing unit tests. 50 | func (c *Context) SetParam(name, value string) { 51 | for i, n := range c.pnames { 52 | if n == name { 53 | c.pvalues[i] = value 54 | return 55 | } 56 | } 57 | c.pnames = append(c.pnames, name) 58 | c.pvalues = append(c.pvalues, value) 59 | } 60 | 61 | // Get returns the named data item previously registered with the context by calling Set. 62 | // If the named data item cannot be found, nil will be returned. 63 | func (c *Context) Get(name string) interface{} { 64 | return c.data[name] 65 | } 66 | 67 | // Set stores the named data item in the context so that it can be retrieved later. 68 | func (c *Context) Set(name string, value interface{}) { 69 | if c.data == nil { 70 | c.data = make(map[string]interface{}) 71 | } 72 | c.data[name] = value 73 | } 74 | 75 | // Query returns the first value for the named component of the URL query parameters. 76 | // If key is not present, it returns the specified default value or an empty string. 77 | func (c *Context) Query(name string, defaultValue ...string) string { 78 | if vs, _ := c.Request.URL.Query()[name]; len(vs) > 0 { 79 | return vs[0] 80 | } 81 | if len(defaultValue) > 0 { 82 | return defaultValue[0] 83 | } 84 | return "" 85 | } 86 | 87 | // Form returns the first value for the named component of the query. 88 | // Form reads the value from POST and PUT body parameters as well as URL query parameters. 89 | // The form takes precedence over the latter. 
90 | // If key is not present, it returns the specified default value or an empty string. 91 | func (c *Context) Form(key string, defaultValue ...string) string { 92 | r := c.Request 93 | r.ParseMultipartForm(32 << 20) 94 | if vs := r.Form[key]; len(vs) > 0 { 95 | return vs[0] 96 | } 97 | 98 | if len(defaultValue) > 0 { 99 | return defaultValue[0] 100 | } 101 | return "" 102 | } 103 | 104 | // PostForm returns the first value for the named component from POST and PUT body parameters. 105 | // If key is not present, it returns the specified default value or an empty string. 106 | func (c *Context) PostForm(key string, defaultValue ...string) string { 107 | r := c.Request 108 | r.ParseMultipartForm(32 << 20) 109 | if vs := r.PostForm[key]; len(vs) > 0 { 110 | return vs[0] 111 | } 112 | 113 | if len(defaultValue) > 0 { 114 | return defaultValue[0] 115 | } 116 | return "" 117 | } 118 | 119 | // Next calls the rest of the handlers associated with the current route. 120 | // If any of these handlers returns an error, Next will return the error and skip the following handlers. 121 | // Next is normally used when a handler needs to do some postprocessing after the rest of the handlers 122 | // are executed. 123 | func (c *Context) Next() error { 124 | c.index++ 125 | for n := len(c.handlers); c.index < n; c.index++ { 126 | if err := c.handlers[c.index](c); err != nil { 127 | return err 128 | } 129 | } 130 | return nil 131 | } 132 | 133 | // Abort skips the rest of the handlers associated with the current route. 134 | // Abort is normally used when a handler handles the request normally and wants to skip the rest of the handlers. 135 | // If a handler wants to indicate an error condition, it should simply return the error without calling Abort. 136 | func (c *Context) Abort() { 137 | c.index = len(c.handlers) 138 | } 139 | 140 | // URL creates a URL using the named route and the parameter values. 
141 | // The parameters should be given in the sequence of name1, value1, name2, value2, and so on. 142 | // If a parameter in the route is not provided a value, the parameter token will remain in the resulting URL. 143 | // Parameter values will be properly URL encoded. 144 | // The method returns an empty string if the URL creation fails. 145 | func (c *Context) URL(route string, pairs ...interface{}) string { 146 | if r := c.router.namedRoutes[route]; r != nil { 147 | return r.URL(pairs...) 148 | } 149 | return "" 150 | } 151 | 152 | // Read populates the given struct variable with the data from the current request. 153 | // If the request is NOT a GET request, it will check the "Content-Type" header 154 | // and find a matching reader from DataReaders to read the request data. 155 | // If there is no match or if the request is a GET request, it will use DefaultFormDataReader 156 | // to read the request data. 157 | func (c *Context) Read(data interface{}) error { 158 | if c.Request.Method != "GET" { 159 | t := getContentType(c.Request) 160 | if reader, ok := DataReaders[t]; ok { 161 | return reader.Read(c.Request, data) 162 | } 163 | } 164 | 165 | return DefaultFormDataReader.Read(c.Request, data) 166 | } 167 | 168 | // Write writes the given data of arbitrary type to the response. 169 | // The method calls the data writer set via SetDataWriter() to do the actual writing. 170 | // By default, the DefaultDataWriter will be used. 171 | func (c *Context) Write(data interface{}) error { 172 | return c.writer.Write(c.Response, data) 173 | } 174 | 175 | // WriteWithStatus sends the HTTP status code and writes the given data of arbitrary type to the response. 176 | // See Write() for details on how data is written to response. 177 | func (c *Context) WriteWithStatus(data interface{}, statusCode int) error { 178 | c.Response.WriteHeader(statusCode) 179 | return c.Write(data) 180 | } 181 | 182 | // SetDataWriter sets the data writer that will be used by Write(). 
183 | func (c *Context) SetDataWriter(writer DataWriter) { 184 | c.writer = writer 185 | writer.SetHeader(c.Response) 186 | } 187 | 188 | // init sets the request and response of the context and resets all other properties. 189 | func (c *Context) init(response http.ResponseWriter, request *http.Request) { 190 | c.Response = response 191 | c.Request = request 192 | c.data = nil 193 | c.index = -1 194 | c.writer = DefaultDataWriter 195 | } 196 | 197 | func getContentType(req *http.Request) string { 198 | t := req.Header.Get("Content-Type") 199 | for i, c := range t { 200 | if c == ' ' || c == ';' { 201 | return t[:i] 202 | } 203 | } 204 | return t 205 | } 206 | -------------------------------------------------------------------------------- /context_test.go: -------------------------------------------------------------------------------- 1 | package routing 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "net/http" 7 | "net/http/httptest" 8 | "strings" 9 | "testing" 10 | 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func TestContextParam(t *testing.T) { 15 | c := NewContext(nil, nil) 16 | values := []string{"a", "b", "c", "d"} 17 | 18 | c.pvalues = values 19 | c.pnames = nil 20 | assert.Equal(t, "", c.Param("")) 21 | assert.Equal(t, "", c.Param("Name")) 22 | 23 | c.pnames = []string{"Name", "Age"} 24 | assert.Equal(t, "", c.Param("")) 25 | assert.Equal(t, "a", c.Param("Name")) 26 | assert.Equal(t, "b", c.Param("Age")) 27 | assert.Equal(t, "", c.Param("Xyz")) 28 | } 29 | 30 | func TestContextSetParam(t *testing.T) { 31 | c := NewContext(nil, nil) 32 | c.pnames = []string{"Name", "Age"} 33 | c.pvalues = []string{"abc", "123"} 34 | assert.Equal(t, "abc", c.Param("Name")) 35 | c.SetParam("Name", "xyz") 36 | assert.Equal(t, "xyz", c.Param("Name")) 37 | assert.Equal(t, "", c.Param("unknown")) 38 | c.SetParam("unknown", "xyz") 39 | assert.Equal(t, "xyz", c.Param("unknown")) 40 | } 41 | 42 | func TestContextInit(t *testing.T) { 43 | c := NewContext(nil, nil) 44 | 
assert.Nil(t, c.Response) 45 | assert.Nil(t, c.Request) 46 | assert.Equal(t, 0, len(c.handlers)) 47 | req, _ := http.NewRequest("GET", "/users/", nil) 48 | c.init(httptest.NewRecorder(), req) 49 | assert.NotNil(t, c.Response) 50 | assert.NotNil(t, c.Request) 51 | assert.Equal(t, -1, c.index) 52 | assert.Nil(t, c.data) 53 | } 54 | 55 | func TestContextURL(t *testing.T) { 56 | router := New() 57 | router.Get("/users///*").Name("users") 58 | c := &Context{router: router} 59 | assert.Equal(t, "/users/123/address/", c.URL("users", "id", 123, "action", "address")) 60 | assert.Equal(t, "", c.URL("abc", "id", 123, "action", "address")) 61 | } 62 | 63 | func TestContextGetSet(t *testing.T) { 64 | c := NewContext(nil, nil) 65 | c.init(nil, nil) 66 | assert.Nil(t, c.Get("abc")) 67 | c.Set("abc", "123") 68 | c.Set("xyz", 123) 69 | assert.Equal(t, "123", c.Get("abc").(string)) 70 | assert.Equal(t, 123, c.Get("xyz").(int)) 71 | } 72 | 73 | func TestContextQueryForm(t *testing.T) { 74 | req, _ := http.NewRequest("POST", "http://www.google.com/search?q=foo&q=bar&both=x&prio=1&empty=not", 75 | strings.NewReader("z=post&both=y&prio=2&empty=")) 76 | req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") 77 | c := NewContext(nil, req) 78 | assert.Equal(t, "foo", c.Query("q")) 79 | assert.Equal(t, "", c.Query("z")) 80 | assert.Equal(t, "123", c.Query("z", "123")) 81 | assert.Equal(t, "not", c.Query("empty", "123")) 82 | assert.Equal(t, "post", c.PostForm("z")) 83 | assert.Equal(t, "", c.PostForm("x")) 84 | assert.Equal(t, "123", c.PostForm("q", "123")) 85 | assert.Equal(t, "", c.PostForm("empty", "123")) 86 | assert.Equal(t, "y", c.Form("both")) 87 | assert.Equal(t, "", c.Form("x")) 88 | assert.Equal(t, "123", c.Form("x", "123")) 89 | } 90 | 91 | func TestContextNextAbort(t *testing.T) { 92 | c, res := testNewContext( 93 | testNormalHandler("a"), 94 | testNormalHandler("b"), 95 | testNormalHandler("c"), 96 | ) 97 | assert.Nil(t, c.Next()) 98 | assert.Equal(t, 
"", res.Body.String()) 99 | 100 | c, res = testNewContext( 101 | testNextHandler("a"), 102 | testNextHandler("b"), 103 | testNextHandler("c"), 104 | ) 105 | assert.Nil(t, c.Next()) 106 | assert.Equal(t, "", res.Body.String()) 107 | 108 | c, res = testNewContext( 109 | testNextHandler("a"), 110 | testAbortHandler("b"), 111 | testNormalHandler("c"), 112 | ) 113 | assert.Nil(t, c.Next()) 114 | assert.Equal(t, "", res.Body.String()) 115 | 116 | c, res = testNewContext( 117 | testNextHandler("a"), 118 | testErrorHandler("b"), 119 | testNormalHandler("c"), 120 | ) 121 | err := c.Next() 122 | if assert.NotNil(t, err) { 123 | assert.Equal(t, "error:b", err.Error()) 124 | } 125 | assert.Equal(t, "", res.Body.String()) 126 | } 127 | 128 | func testNewContext(handlers ...Handler) (*Context, *httptest.ResponseRecorder) { 129 | res := httptest.NewRecorder() 130 | req, _ := http.NewRequest("GET", "http://127.0.0.1/users", nil) 131 | c := &Context{} 132 | c.init(res, req) 133 | c.handlers = handlers 134 | return c, res 135 | } 136 | 137 | func testNextHandler(tag string) Handler { 138 | return func(c *Context) error { 139 | fmt.Fprintf(c.Response, "<%v>", tag) 140 | err := c.Next() 141 | fmt.Fprintf(c.Response, "", tag) 142 | return err 143 | } 144 | } 145 | 146 | func testAbortHandler(tag string) Handler { 147 | return func(c *Context) error { 148 | fmt.Fprintf(c.Response, "<%v/>", tag) 149 | c.Abort() 150 | return nil 151 | } 152 | } 153 | 154 | func testErrorHandler(tag string) Handler { 155 | return func(c *Context) error { 156 | fmt.Fprintf(c.Response, "<%v/>", tag) 157 | return errors.New("error:" + tag) 158 | } 159 | } 160 | 161 | func testNormalHandler(tag string) Handler { 162 | return func(c *Context) error { 163 | fmt.Fprintf(c.Response, "<%v/>", tag) 164 | return nil 165 | } 166 | } 167 | -------------------------------------------------------------------------------- /cors/handler.go: -------------------------------------------------------------------------------- 1 
| // Copyright 2016 Qiang Xue. All rights reserved. 2 | // Use of this source code is governed by a MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | // Package cors provides a handler for handling CORS. 6 | package cors 7 | 8 | import ( 9 | "net/http" 10 | "strconv" 11 | "strings" 12 | "time" 13 | 14 | "github.com/go-ozzo/ozzo-routing/v2" 15 | ) 16 | 17 | const ( 18 | headerOrigin = "Origin" 19 | 20 | headerRequestMethod = "Access-Control-Request-Method" 21 | headerRequestHeaders = "Access-Control-Request-Headers" 22 | 23 | headerAllowOrigin = "Access-Control-Allow-Origin" 24 | headerAllowCredentials = "Access-Control-Allow-Credentials" 25 | headerAllowHeaders = "Access-Control-Allow-Headers" 26 | headerAllowMethods = "Access-Control-Allow-Methods" 27 | headerExposeHeaders = "Access-Control-Expose-Headers" 28 | headerMaxAge = "Access-Control-Max-Age" 29 | ) 30 | 31 | // Options specifies how the CORS handler should respond with appropriate CORS headers. 32 | type Options struct { 33 | // the allowed origins (separated by commas). Use an asterisk (*) to indicate allowing all origins, "null" to indicate disallowing any. 34 | AllowOrigins string 35 | // whether the response to request can be exposed when the omit credentials flag is unset, or whether the actual request can include user credentials. 36 | AllowCredentials bool 37 | // the HTTP methods (separated by commas) that can be used during the actual request. Use an asterisk (*) to indicate allowing any method. 38 | AllowMethods string 39 | // the HTTP headers (separated by commas) that can be used during the actual request. Use an asterisk (*) to indicate allowing any header. 40 | AllowHeaders string 41 | // the HTTP headers (separated by commas) that are safe to expose to the API of a CORS API specification 42 | ExposeHeaders string 43 | // Max amount of seconds that the results of a preflight request can be cached in a preflight result cache. 
44 | MaxAge time.Duration 45 | 46 | allowOriginMap map[string]bool 47 | allowMethodMap map[string]bool 48 | allowHeaderMap map[string]bool 49 | } 50 | 51 | // AllowAll is the option that allows all origins, headers, and methods. 52 | var AllowAll = Options{ 53 | AllowOrigins: "*", 54 | AllowHeaders: "*", 55 | AllowMethods: "*", 56 | } 57 | 58 | // Handler creates a routing handler that adds appropriate CORS headers according to the specified options and the request. 59 | func Handler(opts Options) routing.Handler { 60 | 61 | opts.init() 62 | 63 | return func(c *routing.Context) (err error) { 64 | origin := c.Request.Header.Get(headerOrigin) 65 | if origin == "" { 66 | // the request is outside the scope of CORS 67 | return 68 | } 69 | if c.Request.Method == "OPTIONS" { 70 | // a preflight request 71 | method := c.Request.Header.Get(headerRequestMethod) 72 | if method == "" { 73 | // the request is outside the scope of CORS 74 | return 75 | } 76 | headers := c.Request.Header.Get(headerRequestHeaders) 77 | opts.setPreflightHeaders(origin, method, headers, c.Response.Header()) 78 | c.Abort() 79 | return 80 | } 81 | opts.setActualHeaders(origin, c.Response.Header()) 82 | return 83 | } 84 | } 85 | 86 | func (o *Options) init() { 87 | o.allowHeaderMap = buildAllowMap(o.AllowHeaders, false) 88 | o.allowMethodMap = buildAllowMap(o.AllowMethods, true) 89 | o.allowOriginMap = buildAllowMap(o.AllowOrigins, true) 90 | } 91 | 92 | func (o *Options) isOriginAllowed(origin string) bool { 93 | if o.AllowOrigins == "null" { 94 | return false 95 | } 96 | return o.AllowOrigins == "*" || o.allowOriginMap[origin] 97 | } 98 | 99 | func (o *Options) setActualHeaders(origin string, headers http.Header) { 100 | if !o.isOriginAllowed(origin) { 101 | return 102 | } 103 | 104 | o.setOriginHeader(origin, headers) 105 | 106 | if o.ExposeHeaders != "" { 107 | headers.Set(headerExposeHeaders, o.ExposeHeaders) 108 | } 109 | } 110 | 111 | func (o *Options) setPreflightHeaders(origin, method, 
reqHeaders string, headers http.Header) { 112 | allowed, allowedHeaders := o.isPreflightAllowed(origin, method, reqHeaders) 113 | if !allowed { 114 | return 115 | } 116 | 117 | o.setOriginHeader(origin, headers) 118 | 119 | if o.MaxAge > time.Duration(0) { 120 | headers.Set(headerMaxAge, strconv.FormatInt(int64(o.MaxAge/time.Second), 10)) 121 | } 122 | 123 | if o.AllowMethods == "*" { 124 | headers.Set(headerAllowMethods, method) 125 | } else if o.allowMethodMap[method] { 126 | headers.Set(headerAllowMethods, o.AllowMethods) 127 | } 128 | 129 | if allowedHeaders != "" { 130 | headers.Set(headerAllowHeaders, reqHeaders) 131 | } 132 | } 133 | 134 | func (o *Options) isPreflightAllowed(origin, method, reqHeaders string) (allowed bool, allowedHeaders string) { 135 | if !o.isOriginAllowed(origin) { 136 | return 137 | } 138 | if o.AllowMethods != "*" && !o.allowMethodMap[method] { 139 | return 140 | } 141 | if o.AllowHeaders == "*" || reqHeaders == "" { 142 | return true, reqHeaders 143 | } 144 | 145 | headers := []string{} 146 | for _, header := range strings.Split(reqHeaders, ",") { 147 | header = strings.TrimSpace(header) 148 | if o.allowHeaderMap[strings.ToUpper(header)] { 149 | headers = append(headers, header) 150 | } 151 | } 152 | if len(headers) > 0 { 153 | return true, strings.Join(headers, ",") 154 | } 155 | return 156 | } 157 | 158 | func (o *Options) setOriginHeader(origin string, headers http.Header) { 159 | if o.AllowCredentials { 160 | headers.Set(headerAllowOrigin, origin) 161 | headers.Set(headerAllowCredentials, "true") 162 | } else { 163 | if o.AllowOrigins == "*" { 164 | headers.Set(headerAllowOrigin, "*") 165 | } else { 166 | headers.Set(headerAllowOrigin, origin) 167 | } 168 | } 169 | } 170 | 171 | func buildAllowMap(s string, caseSensitive bool) map[string]bool { 172 | m := make(map[string]bool) 173 | if len(s) > 0 { 174 | for _, p := range strings.Split(s, ",") { 175 | p = strings.TrimSpace(p) 176 | if caseSensitive { 177 | m[p] = true 178 | } 
else { 179 | m[strings.ToUpper(p)] = true 180 | } 181 | } 182 | } 183 | return m 184 | } 185 | -------------------------------------------------------------------------------- /cors/handler_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 Qiang Xue. All rights reserved. 2 | // Use of this source code is governed by a MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package cors 6 | 7 | import ( 8 | "net/http" 9 | "net/http/httptest" 10 | "testing" 11 | 12 | "time" 13 | 14 | "github.com/go-ozzo/ozzo-routing/v2" 15 | "github.com/stretchr/testify/assert" 16 | ) 17 | 18 | func TestBuildAllowMap(t *testing.T) { 19 | m := buildAllowMap("", false) 20 | assert.Equal(t, 0, len(m)) 21 | 22 | m = buildAllowMap("", true) 23 | assert.Equal(t, 0, len(m)) 24 | 25 | m = buildAllowMap("GET , put", false) 26 | assert.Equal(t, 2, len(m)) 27 | assert.True(t, m["GET"]) 28 | assert.True(t, m["PUT"]) 29 | assert.False(t, m["put"]) 30 | 31 | m = buildAllowMap("GET , put", true) 32 | assert.Equal(t, 2, len(m)) 33 | assert.True(t, m["GET"]) 34 | assert.False(t, m["PUT"]) 35 | assert.True(t, m["put"]) 36 | } 37 | 38 | func TestOptionsInit(t *testing.T) { 39 | opts := &Options{ 40 | AllowHeaders: "Accept, Accept-Language", 41 | AllowMethods: "PATCH, PUT", 42 | AllowOrigins: "https://example.com", 43 | } 44 | opts.init() 45 | assert.Equal(t, 2, len(opts.allowHeaderMap)) 46 | assert.Equal(t, 2, len(opts.allowMethodMap)) 47 | assert.Equal(t, 1, len(opts.allowOriginMap)) 48 | } 49 | 50 | func TestOptionsIsOriginAllowed(t *testing.T) { 51 | tests := []struct { 52 | id string 53 | allowed string 54 | origin string 55 | result bool 56 | }{ 57 | {"t1", "*", "http://example.com", true}, 58 | {"t2", "null", "http://example.com", false}, 59 | {"t3", "http://foo.com", "http://example.com", false}, 60 | {"t4", "http://example.com", "http://example.com", true}, 61 | } 62 | 63 | for _, test := range tests { 64 | opts := 
&Options{AllowOrigins: test.allowed} 65 | opts.init() 66 | assert.Equal(t, test.result, opts.isOriginAllowed(test.origin), test.id) 67 | } 68 | } 69 | 70 | func TestOptionsSetOriginHeaders(t *testing.T) { 71 | headers := http.Header{} 72 | opts := &Options{ 73 | AllowOrigins: "https://example.com, https://foo.com", 74 | AllowCredentials: false, 75 | } 76 | opts.setOriginHeader("https://example.com", headers) 77 | assert.Equal(t, "https://example.com", headers.Get(headerAllowOrigin)) 78 | assert.Equal(t, "", headers.Get(headerAllowCredentials)) 79 | 80 | headers = http.Header{} 81 | opts = &Options{ 82 | AllowOrigins: "*", 83 | AllowCredentials: false, 84 | } 85 | opts.setOriginHeader("https://example.com", headers) 86 | assert.Equal(t, "*", headers.Get(headerAllowOrigin)) 87 | assert.Equal(t, "", headers.Get(headerAllowCredentials)) 88 | 89 | headers = http.Header{} 90 | opts = &Options{ 91 | AllowOrigins: "https://example.com, https://foo.com", 92 | AllowCredentials: true, 93 | } 94 | opts.setOriginHeader("https://example.com", headers) 95 | assert.Equal(t, "https://example.com", headers.Get(headerAllowOrigin)) 96 | assert.Equal(t, "true", headers.Get(headerAllowCredentials)) 97 | 98 | headers = http.Header{} 99 | opts = &Options{ 100 | AllowOrigins: "*", 101 | AllowCredentials: true, 102 | } 103 | opts.setOriginHeader("https://example.com", headers) 104 | assert.Equal(t, "https://example.com", headers.Get(headerAllowOrigin)) 105 | assert.Equal(t, "true", headers.Get(headerAllowCredentials)) 106 | } 107 | 108 | func TestOptionsSetActualHeaders(t *testing.T) { 109 | headers := http.Header{} 110 | opts := &Options{ 111 | AllowOrigins: "https://example.com, https://foo.com", 112 | AllowCredentials: false, 113 | ExposeHeaders: "X-Ping, X-Pong", 114 | } 115 | opts.init() 116 | opts.setActualHeaders("https://example.com", headers) 117 | assert.Equal(t, "https://example.com", headers.Get(headerAllowOrigin)) 118 | assert.Equal(t, "X-Ping, X-Pong", 
headers.Get(headerExposeHeaders)) 119 | 120 | opts.ExposeHeaders = "" 121 | headers = http.Header{} 122 | opts.setActualHeaders("https://example.com", headers) 123 | assert.Equal(t, "https://example.com", headers.Get(headerAllowOrigin)) 124 | assert.Equal(t, "", headers.Get(headerExposeHeaders)) 125 | 126 | headers = http.Header{} 127 | opts.setActualHeaders("https://bar.com", headers) 128 | assert.Equal(t, "", headers.Get(headerAllowOrigin)) 129 | } 130 | 131 | func TestOptionsIsPreflightAllowed(t *testing.T) { 132 | opts := &Options{ 133 | AllowOrigins: "https://example.com, https://foo.com", 134 | AllowMethods: "PUT, PATCH", 135 | AllowCredentials: false, 136 | ExposeHeaders: "X-Ping, X-Pong", 137 | } 138 | opts.init() 139 | allowed, headers := opts.isPreflightAllowed("https://foo.com", "PUT", "") 140 | assert.True(t, allowed) 141 | assert.Equal(t, "", headers) 142 | 143 | opts = &Options{ 144 | AllowOrigins: "https://example.com, https://foo.com", 145 | AllowMethods: "PUT, PATCH", 146 | } 147 | opts.init() 148 | allowed, headers = opts.isPreflightAllowed("https://foo.com", "DELETE", "") 149 | assert.False(t, allowed) 150 | assert.Equal(t, "", headers) 151 | 152 | opts = &Options{ 153 | AllowOrigins: "https://example.com, https://foo.com", 154 | AllowMethods: "PUT, PATCH", 155 | AllowHeaders: "X-Ping, X-Pong", 156 | } 157 | opts.init() 158 | allowed, headers = opts.isPreflightAllowed("https://foo.com", "PUT", "X-Unknown") 159 | assert.False(t, allowed) 160 | assert.Equal(t, "", headers) 161 | } 162 | 163 | func TestOptionsSetPreflightHeaders(t *testing.T) { 164 | headers := http.Header{} 165 | opts := &Options{ 166 | AllowOrigins: "https://example.com, https://foo.com", 167 | AllowMethods: "PUT, PATCH", 168 | AllowHeaders: "X-Ping, X-Pong", 169 | AllowCredentials: false, 170 | ExposeHeaders: "X-Ping, X-Pong", 171 | MaxAge: time.Duration(100) * time.Second, 172 | } 173 | opts.init() 174 | opts.setPreflightHeaders("https://bar.com", "PUT", "", headers) 175 | 
assert.Zero(t, len(headers)) 176 | 177 | headers = http.Header{} 178 | opts.setPreflightHeaders("https://foo.com", "PUT", "X-Pong", headers) 179 | assert.Equal(t, "https://foo.com", headers.Get(headerAllowOrigin)) 180 | assert.Equal(t, "PUT, PATCH", headers.Get(headerAllowMethods)) 181 | assert.Equal(t, "100", headers.Get(headerMaxAge)) 182 | assert.Equal(t, "X-Pong", headers.Get(headerAllowHeaders)) 183 | 184 | headers = http.Header{} 185 | opts = &Options{ 186 | AllowOrigins: "*", 187 | AllowMethods: "*", 188 | AllowHeaders: "*", 189 | } 190 | opts.init() 191 | opts.setPreflightHeaders("https://bar.com", "PUT", "X-Pong", headers) 192 | assert.Equal(t, "*", headers.Get(headerAllowOrigin)) 193 | assert.Equal(t, "PUT", headers.Get(headerAllowMethods)) 194 | assert.Equal(t, "X-Pong", headers.Get(headerAllowHeaders)) 195 | } 196 | 197 | func TestHandlers(t *testing.T) { 198 | h := Handler(Options{ 199 | AllowOrigins: "https://example.com, https://foo.com", 200 | AllowMethods: "PUT, PATCH", 201 | }) 202 | res := httptest.NewRecorder() 203 | req, _ := http.NewRequest("OPTIONS", "/users/", nil) 204 | req.Header.Set("Origin", "https://example.com") 205 | req.Header.Set("Access-Control-Request-Method", "PATCH") 206 | c := routing.NewContext(res, req) 207 | assert.Nil(t, h(c)) 208 | assert.Equal(t, "https://example.com", res.Header().Get(headerAllowOrigin)) 209 | 210 | res = httptest.NewRecorder() 211 | req, _ = http.NewRequest("PATCH", "/users/", nil) 212 | req.Header.Set("Origin", "https://example.com") 213 | c = routing.NewContext(res, req) 214 | assert.Nil(t, h(c)) 215 | assert.Equal(t, "https://example.com", res.Header().Get(headerAllowOrigin)) 216 | 217 | res = httptest.NewRecorder() 218 | req, _ = http.NewRequest("PATCH", "/users/", nil) 219 | c = routing.NewContext(res, req) 220 | assert.Nil(t, h(c)) 221 | assert.Equal(t, "", res.Header().Get(headerAllowOrigin)) 222 | 223 | res = httptest.NewRecorder() 224 | req, _ = http.NewRequest("OPTIONS", "/users/", nil) 225 | 
req.Header.Set("Origin", "https://example.com")
226 | 	c = routing.NewContext(res, req)
227 | 	assert.Nil(t, h(c))
228 | 	assert.Equal(t, "", res.Header().Get(headerAllowOrigin))
229 | }
230 | 
--------------------------------------------------------------------------------
/error.go:
--------------------------------------------------------------------------------
1 | // Copyright 2016 Qiang Xue. All rights reserved.
2 | // Use of this source code is governed by a MIT-style
3 | // license that can be found in the LICENSE file.
4 | 
5 | package routing
6 | 
7 | import "net/http"
8 | 
9 | // HTTPError represents an HTTP error with an HTTP status code and an error message.
10 | type HTTPError interface {
11 | 	error
12 | 	// StatusCode returns the HTTP status code of the error.
13 | 	StatusCode() int
14 | }
15 | 
16 | // httpError is the HTTPError implementation returned by NewHTTPError; its fields marshal as "status" and "message".
17 | type httpError struct {
18 | 	Status  int    `json:"status" xml:"status"`
19 | 	Message string `json:"message" xml:"message"`
20 | }
21 | 
22 | // NewHTTPError creates a new HTTPError instance.
23 | // If the error message is not given, http.StatusText() will be called
24 | // to generate the message based on the status code.
25 | func NewHTTPError(status int, message ...string) HTTPError {
26 | 	if len(message) > 0 {
27 | 		return &httpError{status, message[0]}
28 | 	}
29 | 	return &httpError{status, http.StatusText(status)}
30 | }
31 | 
32 | // Error returns the error message.
33 | func (e *httpError) Error() string {
34 | 	return e.Message
35 | }
36 | 
37 | // StatusCode returns the HTTP status code.
38 | func (e *httpError) StatusCode() int {
39 | 	return e.Status
40 | }
41 | 
--------------------------------------------------------------------------------
/error_test.go:
--------------------------------------------------------------------------------
1 | // Copyright 2016 Qiang Xue. All rights reserved.
2 | // Use of this source code is governed by a MIT-style
3 | // license that can be found in the LICENSE file.
4 | 
5 | package routing
6 | 
7 | import (
8 | 	"encoding/json"
9 | 	"net/http"
10 | 	"testing"
11 | 
12 | 	"github.com/stretchr/testify/assert"
13 | )
14 | // TestNewHttpError checks the default and explicit messages of NewHTTPError and its JSON encoding.
15 | func TestNewHttpError(t *testing.T) {
16 | 	e := NewHTTPError(http.StatusNotFound)
17 | 	assert.Equal(t, http.StatusNotFound, e.StatusCode())
18 | 	assert.Equal(t, http.StatusText(http.StatusNotFound), e.Error())
19 | 
20 | 	e = NewHTTPError(http.StatusNotFound, "abc")
21 | 	assert.Equal(t, http.StatusNotFound, e.StatusCode())
22 | 	assert.Equal(t, "abc", e.Error())
23 | 
24 | 	s, _ := json.Marshal(e)
25 | 	assert.Equal(t, `{"status":404,"message":"abc"}`, string(s))
26 | }
27 | 
--------------------------------------------------------------------------------
/example_test.go:
--------------------------------------------------------------------------------
1 | package routing_test
2 | 
3 | import (
4 | 	"github.com/go-ozzo/ozzo-routing/v2"
5 | 	"github.com/go-ozzo/ozzo-routing/v2/access"
6 | 	"github.com/go-ozzo/ozzo-routing/v2/content"
7 | 	"github.com/go-ozzo/ozzo-routing/v2/fault"
8 | 	"github.com/go-ozzo/ozzo-routing/v2/file"
9 | 	"github.com/go-ozzo/ozzo-routing/v2/slash"
10 | 	"log"
11 | 	"net/http"
12 | )
13 | 
14 | func Example() {
15 | 	router := routing.New()
16 | 
17 | 	router.Use(
18 | 		// all these handlers are shared by every route
19 | 		access.Logger(log.Printf),
20 | 		slash.Remover(http.StatusMovedPermanently),
21 | 		fault.Recovery(log.Printf),
22 | 	)
23 | 
24 | 	// serve RESTful APIs
25 | 	api := router.Group("/api")
26 | 	api.Use(
27 | 		// these handlers are shared by the routes in the api group only
28 | 		content.TypeNegotiator(content.JSON, content.XML),
29 | 	)
30 | 	api.Get("/users", func(c *routing.Context) error {
31 | 		return c.Write("user list")
32 | 	})
33 | 	api.Post("/users", func(c *routing.Context) error {
34 | 		return c.Write("create a new user")
35 | 	})
36 | 	api.Put(`/users/`, func(c *routing.Context) error { // NOTE(review): c.Param("id") below needs an "id" placeholder in this pattern (ozzo wraps parameter names in angle brackets); as written it presumably never matches — restore it
37 | return c.Write("update user " + c.Param("id")) 38 | }) 39 | 40 | // serve index file 41 | router.Get("/", file.Content("ui/index.html")) 42 | // serve files under the "ui" subdirectory 43 | router.Get("/*", file.Server(file.PathMap{ 44 | "/": "/ui/", 45 | })) 46 | 47 | http.Handle("/", router) 48 | http.ListenAndServe(":8080", nil) 49 | } 50 | -------------------------------------------------------------------------------- /fault/error.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 Qiang Xue. All rights reserved. 2 | // Use of this source code is governed by a MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | // Package fault provides a panic and error handler for the ozzo routing package. 6 | package fault 7 | 8 | import ( 9 | "net/http" 10 | 11 | "github.com/go-ozzo/ozzo-routing/v2" 12 | ) 13 | 14 | // ErrorHandler returns a handler that handles errors returned by the handlers following this one. 15 | // If the error implements routing.HTTPError, the handler will set the HTTP status code accordingly. 16 | // Otherwise the HTTP status is set as http.StatusInternalServerError. The handler will also write the error 17 | // as the response body. 18 | // 19 | // A log function can be provided to log a message whenever an error is handled. If nil, no message will be logged. 20 | // 21 | // An optional error conversion function can also be provided to convert an error into a normalized one 22 | // before sending it to the response. 
23 | //
24 | //     import (
25 | //         "log"
26 | //         "github.com/go-ozzo/ozzo-routing/v2"
27 | //         "github.com/go-ozzo/ozzo-routing/v2/fault"
28 | //     )
29 | //
30 | //     r := routing.New()
31 | //     r.Use(fault.ErrorHandler(log.Printf))
32 | //     r.Use(fault.PanicHandler(log.Printf))
33 | func ErrorHandler(logf LogFunc, errorf ...ConvertErrorFunc) routing.Handler {
34 | 	return func(c *routing.Context) error {
35 | 		err := c.Next() // run the following handlers first, then inspect their error
36 | 		if err == nil {
37 | 			return nil
38 | 		}
39 | 
40 | 		if logf != nil {
41 | 			logf("%v", err) // logs the original error, before any conversion below
42 | 		}
43 | 
44 | 		if len(errorf) > 0 {
45 | 			err = errorf[0](c, err)
46 | 		}
47 | 
48 | 		writeError(c, err)
49 | 		c.Abort()
50 | 
51 | 		return nil
52 | 	}
53 | }
54 | 
55 | // writeError writes the error to the response.
56 | // If the error implements routing.HTTPError, it will set the HTTP status as the result of the StatusCode() call of the error.
57 | // Otherwise, the HTTP status will be set as http.StatusInternalServerError.
58 | func writeError(c *routing.Context, err error) {
59 | 	if httpError, ok := err.(routing.HTTPError); ok {
60 | 		c.Response.WriteHeader(httpError.StatusCode())
61 | 	} else {
62 | 		c.Response.WriteHeader(http.StatusInternalServerError)
63 | 	}
64 | 	c.Write(err) // NOTE(review): the error returned by c.Write is discarded; writeError has no way to report it
65 | }
66 | 
--------------------------------------------------------------------------------
/fault/error_test.go:
--------------------------------------------------------------------------------
1 | package fault
2 | 
3 | import (
4 | 	"bytes"
5 | 	"errors"
6 | 	"net/http"
7 | 	"net/http/httptest"
8 | 	"testing"
9 | 
10 | 	"github.com/go-ozzo/ozzo-routing/v2"
11 | 	"github.com/stretchr/testify/assert"
12 | )
13 | 
14 | func TestErrorHandler(t *testing.T) {
15 | 	var buf bytes.Buffer
16 | 	h := ErrorHandler(getLogger(&buf))
17 | 
18 | 	res := httptest.NewRecorder()
19 | 	req, _ := http.NewRequest("GET", "/users/", nil)
20 | 	c := routing.NewContext(res, req, h, handler1, handler2)
21 | 	assert.Nil(t, c.Next())
22 | 	assert.Equal(t, http.StatusInternalServerError, res.Code)
23 | 	assert.Equal(t, "abc",
res.Body.String())
24 | 	assert.Equal(t, "abc", buf.String())
25 | 
26 | 	buf.Reset()
27 | 	res = httptest.NewRecorder()
28 | 	req, _ = http.NewRequest("GET", "/users/", nil)
29 | 	c = routing.NewContext(res, req, h, handler2)
30 | 	assert.Nil(t, c.Next())
31 | 	assert.Equal(t, http.StatusOK, res.Code)
32 | 	assert.Equal(t, "test", res.Body.String())
33 | 	assert.Equal(t, "", buf.String())
34 | 
35 | 	buf.Reset()
36 | 	h = ErrorHandler(getLogger(&buf), convertError)
37 | 	res = httptest.NewRecorder()
38 | 	req, _ = http.NewRequest("GET", "/users/", nil)
39 | 	c = routing.NewContext(res, req, h, handler1, handler2)
40 | 	assert.Nil(t, c.Next())
41 | 	assert.Equal(t, http.StatusInternalServerError, res.Code)
42 | 	assert.Equal(t, "123", res.Body.String())
43 | 	assert.Equal(t, "abc", buf.String())
44 | 
45 | 	buf.Reset()
46 | 	h = ErrorHandler(nil)
47 | 	res = httptest.NewRecorder()
48 | 	req, _ = http.NewRequest("GET", "/users/", nil)
49 | 	c = routing.NewContext(res, req, h, handler1, handler2)
50 | 	assert.Nil(t, c.Next())
51 | 	assert.Equal(t, http.StatusInternalServerError, res.Code)
52 | 	assert.Equal(t, "abc", res.Body.String())
53 | 	assert.Equal(t, "", buf.String())
54 | }
55 | // Test_writeError checks the status code and body written for a plain error and for an HTTPError.
56 | func Test_writeError(t *testing.T) {
57 | 	res := httptest.NewRecorder()
58 | 	req, _ := http.NewRequest("GET", "/users/", nil)
59 | 	c := routing.NewContext(res, req)
60 | 	writeError(c, errors.New("abc"))
61 | 	assert.Equal(t, http.StatusInternalServerError, res.Code)
62 | 	assert.Equal(t, "abc", res.Body.String())
63 | 
64 | 	res = httptest.NewRecorder()
65 | 	req, _ = http.NewRequest("GET", "/users/", nil)
66 | 	c = routing.NewContext(res, req)
67 | 	writeError(c, routing.NewHTTPError(http.StatusNotFound, "xyz"))
68 | 	assert.Equal(t, http.StatusNotFound, res.Code)
69 | 	assert.Equal(t, "xyz", res.Body.String())
70 | }
71 | // convertError is the errorf hook used by the tests: it replaces any error with "123".
72 | func convertError(c *routing.Context, err error) error {
73 | 	return errors.New("123")
74 | }
75 | 
--------------------------------------------------------------------------------
/fault/panic.go: -------------------------------------------------------------------------------- 1 | package fault 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "runtime" 7 | 8 | "github.com/go-ozzo/ozzo-routing/v2" 9 | ) 10 | 11 | // PanicHandler returns a handler that recovers from panics happened in the handlers following this one. 12 | // When a panic is recovered, it will be converted into an error and returned to the parent handlers. 13 | // 14 | // A log function can be provided to log the panic call stack information. If the log function is nil, 15 | // no message will be logged. 16 | // 17 | // import ( 18 | // "log" 19 | // "github.com/go-ozzo/ozzo-routing/v2" 20 | // "github.com/go-ozzo/ozzo-routing/v2/fault" 21 | // ) 22 | // 23 | // r := routing.New() 24 | // r.Use(fault.ErrorHandler(log.Printf)) 25 | // r.Use(fault.PanicHandler(log.Printf)) 26 | func PanicHandler(logf LogFunc) routing.Handler { 27 | return func(c *routing.Context) (err error) { 28 | defer func() { 29 | if e := recover(); e != nil { 30 | if logf != nil { 31 | logf("recovered from panic:%v", getCallStack(4)) 32 | } 33 | var ok bool 34 | if err, ok = e.(error); !ok { 35 | err = fmt.Errorf("%v", e) 36 | } 37 | } 38 | }() 39 | 40 | return c.Next() 41 | } 42 | } 43 | 44 | // getCallStack returns the current call stack information as a string. 45 | // The skip parameter specifies how many top frames should be skipped. 
46 | func getCallStack(skip int) string {
47 | 	buf := new(bytes.Buffer)
48 | 	for i := skip; ; i++ {
49 | 		_, file, line, ok := runtime.Caller(i) // ok becomes false once i walks past the outermost frame
50 | 		if !ok {
51 | 			break
52 | 		}
53 | 		fmt.Fprintf(buf, "\n%s:%d", file, line) // one "\nfile:line" entry per frame
54 | 	}
55 | 	return buf.String()
56 | }
57 | 
--------------------------------------------------------------------------------
/fault/panic_test.go:
--------------------------------------------------------------------------------
1 | package fault
2 | 
3 | import (
4 | 	"bytes"
5 | 	"net/http"
6 | 	"net/http/httptest"
7 | 	"testing"
8 | 
9 | 	routing "github.com/go-ozzo/ozzo-routing/v2"
10 | 	"github.com/stretchr/testify/assert"
11 | )
12 | // TestPanicHandler checks panic-to-error conversion, the no-panic path, and PanicHandler combined with ErrorHandler.
13 | func TestPanicHandler(t *testing.T) {
14 | 	var buf bytes.Buffer
15 | 	h := PanicHandler(getLogger(&buf))
16 | 
17 | 	res := httptest.NewRecorder()
18 | 	req, _ := http.NewRequest("GET", "/users/", nil)
19 | 	c := routing.NewContext(res, req, h, handler3, handler2)
20 | 	err := c.Next()
21 | 	if assert.NotNil(t, err) {
22 | 		assert.Equal(t, "xyz", err.Error())
23 | 	}
24 | 	assert.NotEqual(t, "", buf.String())
25 | 
26 | 	buf.Reset()
27 | 	res = httptest.NewRecorder()
28 | 	req, _ = http.NewRequest("GET", "/users/", nil)
29 | 	c = routing.NewContext(res, req, h, handler2)
30 | 	assert.Nil(t, c.Next())
31 | 	assert.Equal(t, "", buf.String())
32 | 
33 | 	buf.Reset()
34 | 	h2 := ErrorHandler(getLogger(&buf))
35 | 	res = httptest.NewRecorder()
36 | 	req, _ = http.NewRequest("GET", "/users/", nil)
37 | 	c = routing.NewContext(res, req, h2, h, handler3, handler2)
38 | 	assert.Nil(t, c.Next())
39 | 	assert.Equal(t, http.StatusInternalServerError, res.Code)
40 | 	assert.Equal(t, "xyz", res.Body.String())
41 | 	assert.Contains(t, buf.String(), "panic_test.go")
42 | 	assert.Contains(t, buf.String(), "xyz")
43 | }
44 | 
--------------------------------------------------------------------------------
/fault/recovery.go:
--------------------------------------------------------------------------------
1 | // Copyright 2016 Qiang Xue. All rights reserved.
2 | // Use of this source code is governed by a MIT-style
3 | // license that can be found in the LICENSE file.
4 | 
5 | // Package fault provides a panic and error handler for the ozzo routing package.
6 | package fault
7 | 
8 | import "github.com/go-ozzo/ozzo-routing/v2"
9 | 
10 | type (
11 | 	// LogFunc logs a message using the given format and optional arguments.
12 | 	// The usage of format and arguments is similar to that for fmt.Printf().
13 | 	// LogFunc should be thread safe.
14 | 	LogFunc func(format string, a ...interface{})
15 | 
16 | 	// ConvertErrorFunc converts an error into a form more appropriate for rendering. ErrorHandler and Recovery use only the first one given.
17 | 	ConvertErrorFunc func(*routing.Context, error) error
18 | )
19 | 
20 | // Recovery returns a handler that handles both panics and errors that occur while servicing an HTTP request.
21 | // Recovery can be considered as a combination of ErrorHandler and PanicHandler.
22 | //
23 | // The handler will recover from panics and render the recovered error or the error returned by a handler.
24 | // If the error implements routing.HTTPError, the handler will set the HTTP status code accordingly.
25 | // Otherwise the HTTP status is set as http.StatusInternalServerError. The handler will also write the error
26 | // as the response body.
27 | //
28 | // A log function can be provided to log a message whenever an error is handled. If nil, no message will be logged.
29 | //
30 | // An optional error conversion function can also be provided to convert an error into a normalized one
31 | // before sending it to the response.
32 | //
33 | //     import (
34 | //         "log"
35 | //         "github.com/go-ozzo/ozzo-routing/v2"
36 | //         "github.com/go-ozzo/ozzo-routing/v2/fault"
37 | //     )
38 | //
39 | //     r := routing.New()
40 | //     r.Use(fault.Recovery(log.Printf))
41 | func Recovery(logf LogFunc, errorf ...ConvertErrorFunc) routing.Handler {
42 | 	handlePanic := PanicHandler(logf) // reuses PanicHandler so panics surface here as ordinary errors
43 | 	return func(c *routing.Context) error {
44 | 		if err := handlePanic(c); err != nil {
45 | 			if logf != nil {
46 | 				logf("%v", err) // logs the original error, before any conversion below
47 | 			}
48 | 			if len(errorf) > 0 {
49 | 				err = errorf[0](c, err)
50 | 			}
51 | 			writeError(c, err)
52 | 			c.Abort()
53 | 		}
54 | 		return nil
55 | 	}
56 | }
57 | 
--------------------------------------------------------------------------------
/fault/recovery_test.go:
--------------------------------------------------------------------------------
1 | // Copyright 2016 Qiang Xue. All rights reserved.
2 | // Use of this source code is governed by a MIT-style
3 | // license that can be found in the LICENSE file.
4 | 
5 | package fault
6 | 
7 | import (
8 | 	"bytes"
9 | 	"errors"
10 | 	"fmt"
11 | 	"net/http"
12 | 	"net/http/httptest"
13 | 	"testing"
14 | 
15 | 	"github.com/go-ozzo/ozzo-routing/v2"
16 | 	"github.com/stretchr/testify/assert"
17 | )
18 | // TestRecovery covers a returned error, the no-error path, string and HTTPError panics, and error conversion.
19 | func TestRecovery(t *testing.T) {
20 | 	var buf bytes.Buffer
21 | 	h := Recovery(getLogger(&buf))
22 | 
23 | 	res := httptest.NewRecorder()
24 | 	req, _ := http.NewRequest("GET", "/users/", nil)
25 | 	c := routing.NewContext(res, req, h, handler1, handler2)
26 | 	assert.Nil(t, c.Next())
27 | 	assert.Equal(t, http.StatusInternalServerError, res.Code)
28 | 	assert.Equal(t, "abc", res.Body.String())
29 | 	assert.Equal(t, "abc", buf.String())
30 | 
31 | 	buf.Reset()
32 | 	res = httptest.NewRecorder()
33 | 	req, _ = http.NewRequest("GET", "/users/", nil)
34 | 	c = routing.NewContext(res, req, h, handler2)
35 | 	assert.Nil(t, c.Next())
36 | 	assert.Equal(t, http.StatusOK, res.Code)
37 | 	assert.Equal(t, "test", res.Body.String())
38 | 	assert.Equal(t, "", buf.String())
39 | 
40 | 	buf.Reset()
41 | 	res =
httptest.NewRecorder()
42 | 	req, _ = http.NewRequest("GET", "/users/", nil)
43 | 	c = routing.NewContext(res, req, h, handler3, handler2)
44 | 	assert.Nil(t, c.Next())
45 | 	assert.Equal(t, http.StatusInternalServerError, res.Code)
46 | 	assert.Equal(t, "xyz", res.Body.String())
47 | 	assert.Contains(t, buf.String(), "recovery_test.go")
48 | 	assert.Contains(t, buf.String(), "xyz")
49 | 	// panicking with an HTTPError keeps its status code and message
50 | 	buf.Reset()
51 | 	res = httptest.NewRecorder()
52 | 	req, _ = http.NewRequest("GET", "/users/", nil)
53 | 	c = routing.NewContext(res, req, h, handler4, handler2)
54 | 	assert.Nil(t, c.Next())
55 | 	assert.Equal(t, http.StatusBadRequest, res.Code)
56 | 	assert.Equal(t, "123", res.Body.String())
57 | 	assert.Contains(t, buf.String(), "recovery_test.go")
58 | 	assert.Contains(t, buf.String(), "123")
59 | 	// the converter rewrites the error recovered from a panic, but the log keeps the original
60 | 	buf.Reset()
61 | 	h = Recovery(getLogger(&buf), convertError)
62 | 	res = httptest.NewRecorder()
63 | 	req, _ = http.NewRequest("GET", "/users/", nil)
64 | 	c = routing.NewContext(res, req, h, handler3, handler2)
65 | 	assert.Nil(t, c.Next())
66 | 	assert.Equal(t, http.StatusInternalServerError, res.Code)
67 | 	assert.Equal(t, "123", res.Body.String())
68 | 	assert.Contains(t, buf.String(), "recovery_test.go")
69 | 	assert.Contains(t, buf.String(), "xyz")
70 | 	// the converter also applies to plain returned errors
71 | 	buf.Reset()
72 | 	h = Recovery(getLogger(&buf), convertError)
73 | 	res = httptest.NewRecorder()
74 | 	req, _ = http.NewRequest("GET", "/users/", nil)
75 | 	c = routing.NewContext(res, req, h, handler1, handler2)
76 | 	assert.Nil(t, c.Next())
77 | 	assert.Equal(t, http.StatusInternalServerError, res.Code)
78 | 	assert.Equal(t, "123", res.Body.String())
79 | 	assert.Equal(t, "abc", buf.String())
80 | }
81 | // getLogger returns a LogFunc that appends each formatted message to buf.
82 | func getLogger(buf *bytes.Buffer) LogFunc {
83 | 	return func(format string, a ...interface{}) {
84 | 		fmt.Fprintf(buf, format, a...)
85 | 	}
86 | }
87 | // handler1 always returns the error "abc".
88 | func handler1(c *routing.Context) error {
89 | 	return errors.New("abc")
90 | }
91 | // handler2 writes "test" and succeeds.
92 | func handler2(c *routing.Context) error {
93 | 	c.Write("test")
94 | 	return nil
95 | }
96 | // handler3 panics with the string "xyz".
97 | func handler3(c *routing.Context) error {
98 | 	panic("xyz")
99 | }
100 | // handler4 panics with a 400 HTTPError whose message is "123".
101 | func handler4(c *routing.Context) error {
102 | 	panic(routing.NewHTTPError(http.StatusBadRequest, "123"))
103 | }
104 | 
--------------------------------------------------------------------------------
/file/server.go:
--------------------------------------------------------------------------------
1 | // Copyright 2016 Qiang Xue. All rights reserved.
2 | // Use of this source code is governed by a MIT-style
3 | // license that can be found in the LICENSE file.
4 | 
5 | // Package file provides handlers that serve static files for the ozzo routing package.
6 | package file
7 | 
8 | import (
9 | 	"net/http"
10 | 	"os"
11 | 	"path/filepath"
12 | 	"sort"
13 | 	"strings"
14 | 
15 | 	"github.com/go-ozzo/ozzo-routing/v2"
16 | )
17 | 
18 | // ServerOptions defines the possible options for the Server handler.
19 | type ServerOptions struct {
20 | 	// The directory that all served files must be located within. The file paths in the path map passed to Server
21 | 	// are all relative to this path. This property can be specified as an absolute file path or a path relative
22 | 	// to the current working path. If not set, this property defaults to the current working path.
23 | 	RootPath string
24 | 	// The file (e.g. index.html) to be served when the current request corresponds to a directory.
25 | 	// If not set, the handler will return a 404 HTTP error when the request corresponds to a directory.
26 | 	// This should only be a file name without the directory part.
27 | 	IndexFile string
28 | 	// The file to be served when no file or directory matches the current request.
29 | 	// If not set, the handler will return a 404 HTTP error when no file/directory matches the request.
30 | 	// The path of this file is relative to RootPath.
31 | 	CatchAllFile string
32 | 	// A function that checks if the requested file path is allowed. If allowed, the function
33 | 	// may do additional work such as setting Expires HTTP header.
34 | 	// The function should return a boolean indicating whether the file should be served or not.
35 | 	// If false, a 404 HTTP error will be returned by the handler.
36 | 	Allow func(*routing.Context, string) bool
37 | }
38 | 
39 | // PathMap specifies the mapping between URL paths (keys) and file paths (values).
40 | // The file paths are relative to ServerOptions.RootPath.
41 | type PathMap map[string]string
42 | 
43 | // RootPath stores the working directory captured at package initialization; relative paths given to Server and Content are resolved against it.
44 | var RootPath string
45 | // init records the working directory; an os.Getwd error is ignored, which leaves RootPath empty.
46 | func init() {
47 | 	RootPath, _ = os.Getwd()
48 | }
49 | 
50 | // Server returns a handler that serves the files as the response content.
51 | // The files being served are determined using the current URL path and the specified path map.
52 | // For example, if the path map is {"/css": "/www/css", "/js": "/www/js"} and the current URL path is
53 | // "/css/main.css", the file "/www/css/main.css" will be served.
54 | // If a URL path matches multiple prefixes in the path map, the most specific prefix will take precedence.
55 | // For example, if the path map contains both "/css" and "/css/img", and the URL path is "/css/img/logo.gif",
56 | // then the path mapped by "/css/img" will be used.
57 | // 58 | // import ( 59 | // "log" 60 | // "github.com/go-ozzo/ozzo-routing/v2" 61 | // "github.com/go-ozzo/ozzo-routing/v2/file" 62 | // ) 63 | // 64 | // r := routing.New() 65 | // r.Get("/*", file.Server(file.PathMap{ 66 | // "/css": "/ui/dist/css", 67 | // "/js": "/ui/dist/js", 68 | // })) 69 | func Server(pathMap PathMap, opts ...ServerOptions) routing.Handler { 70 | var options ServerOptions 71 | if len(opts) > 0 { 72 | options = opts[0] 73 | } 74 | if !filepath.IsAbs(options.RootPath) { 75 | options.RootPath = filepath.Join(RootPath, options.RootPath) 76 | } 77 | from, to := parsePathMap(pathMap) 78 | 79 | // security measure: limit the files within options.RootPath 80 | dir := http.Dir(options.RootPath) 81 | 82 | return func(c *routing.Context) error { 83 | if c.Request.Method != "GET" && c.Request.Method != "HEAD" { 84 | return routing.NewHTTPError(http.StatusMethodNotAllowed) 85 | } 86 | path, found := matchPath(c.Request.URL.Path, from, to) 87 | if !found || options.Allow != nil && !options.Allow(c, path) { 88 | return routing.NewHTTPError(http.StatusNotFound) 89 | } 90 | 91 | var ( 92 | file http.File 93 | fstat os.FileInfo 94 | err error 95 | ) 96 | 97 | if file, err = dir.Open(path); err != nil { 98 | if options.CatchAllFile != "" { 99 | return serveFile(c, dir, options.CatchAllFile) 100 | } 101 | return routing.NewHTTPError(http.StatusNotFound, err.Error()) 102 | } 103 | defer file.Close() 104 | 105 | if fstat, err = file.Stat(); err != nil { 106 | return routing.NewHTTPError(http.StatusNotFound, err.Error()) 107 | } 108 | 109 | if fstat.IsDir() { 110 | if options.IndexFile == "" { 111 | return routing.NewHTTPError(http.StatusNotFound) 112 | } 113 | return serveFile(c, dir, filepath.Join(path, options.IndexFile)) 114 | } 115 | 116 | c.Response.Header().Del("Content-Type") 117 | http.ServeContent(c.Response, c.Request, path, fstat.ModTime(), file) 118 | return nil 119 | } 120 | } 121 | 122 | func serveFile(c *routing.Context, dir http.Dir, path 
string) error {
123 | 	file, err := dir.Open(path)
124 | 	if err != nil {
125 | 		return routing.NewHTTPError(http.StatusNotFound, err.Error())
126 | 	}
127 | 	defer file.Close()
128 | 	fstat, err := file.Stat()
129 | 	if err != nil {
130 | 		return routing.NewHTTPError(http.StatusNotFound, err.Error())
131 | 	} else if fstat.IsDir() {
132 | 		return routing.NewHTTPError(http.StatusNotFound)
133 | 	}
134 | 	c.Response.Header().Del("Content-Type")
135 | 	http.ServeContent(c.Response, c.Request, path, fstat.ModTime(), file)
136 | 	return nil
137 | }
138 | 
139 | // Content returns a handler that serves the content of the specified file as the response.
140 | // The file to be served can be specified as an absolute file path or a path relative to RootPath (which
141 | // defaults to the current working path).
142 | // If the specified file cannot be opened or is a directory, the handler returns a 404 HTTP error; non-GET/HEAD requests get a 405.
143 | func Content(path string) routing.Handler {
144 | 	if !filepath.IsAbs(path) {
145 | 		path = filepath.Join(RootPath, path)
146 | 	}
147 | 	return func(c *routing.Context) error {
148 | 		if c.Request.Method != "GET" && c.Request.Method != "HEAD" {
149 | 			return routing.NewHTTPError(http.StatusMethodNotAllowed)
150 | 		}
151 | 		file, err := os.Open(path)
152 | 		if err != nil {
153 | 			return routing.NewHTTPError(http.StatusNotFound, err.Error())
154 | 		}
155 | 		defer file.Close()
156 | 		fstat, err := file.Stat()
157 | 		if err != nil {
158 | 			return routing.NewHTTPError(http.StatusNotFound, err.Error())
159 | 		} else if fstat.IsDir() {
160 | 			return routing.NewHTTPError(http.StatusNotFound)
161 | 		}
162 | 		c.Response.Header().Del("Content-Type")
163 | 		http.ServeContent(c.Response, c.Request, path, fstat.ModTime(), file)
164 | 		return nil
165 | 	}
166 | }
167 | // parsePathMap returns the URL prefixes of pathMap sorted ascending (from) and their mapped file paths in the same order (to).
168 | func parsePathMap(pathMap PathMap) (from, to []string) {
169 | 	from = make([]string, len(pathMap))
170 | 	to = make([]string, len(pathMap))
171 | 	n := 0
172 | 	for i := range pathMap {
173 | 		from[n] = i
174 | 		n++
175 | 	}
176 | 
sort.Strings(from)
177 | 	for i, s := range from {
178 | 		to[i] = pathMap[s]
179 | 	}
180 | 	return
181 | }
182 | // matchPath maps path through the last entry of from that is a string prefix of it (with from sorted ascending, as parsePathMap returns it, that is the longest one) and reports whether any matched; matching is not segment-aware, so "/css2" matches "/css".
183 | func matchPath(path string, from, to []string) (string, bool) {
184 | 	for i := len(from) - 1; i >= 0; i-- {
185 | 		prefix := from[i]
186 | 		if strings.HasPrefix(path, prefix) {
187 | 			return to[i] + path[len(prefix):], true
188 | 		}
189 | 	}
190 | 	return "", false
191 | }
192 | 
--------------------------------------------------------------------------------
/file/server_test.go:
--------------------------------------------------------------------------------
1 | // Copyright 2016 Qiang Xue. All rights reserved.
2 | // Use of this source code is governed by a MIT-style
3 | // license that can be found in the LICENSE file.
4 | 
5 | package file
6 | 
7 | import (
8 | 	"net/http"
9 | 	"net/http/httptest"
10 | 	"strings"
11 | 	"testing"
12 | 
13 | 	"github.com/go-ozzo/ozzo-routing/v2"
14 | 	"github.com/stretchr/testify/assert"
15 | )
16 | 
17 | func TestParsePathMap(t *testing.T) {
18 | 	tests := []struct {
19 | 		id       string
20 | 		pathMap  PathMap
21 | 		from, to string
22 | 	}{
23 | 		{"t1", PathMap{}, "", ""},
24 | 		{"t2", PathMap{"/": ""}, "/", ""},
25 | 		{"t3", PathMap{"/": "ui/dist"}, "/", "ui/dist"},
26 | 		{"t4", PathMap{"/abc/123": "ui123/abc", "/abc": "/ui/abc", "/abc/xyz": "/xyzui/abc"}, "/abc,/abc/123,/abc/xyz", "/ui/abc,ui123/abc,/xyzui/abc"},
27 | 	}
28 | 	for _, test := range tests {
29 | 		af, at := parsePathMap(test.pathMap)
30 | 		assert.Equal(t, test.from, strings.Join(af, ","), test.id)
31 | 		assert.Equal(t, test.to, strings.Join(at, ","), test.id)
32 | 	}
33 | }
34 | 
35 | func TestMatchPath(t *testing.T) {
36 | 	tests := []struct {
37 | 		id        string
38 | 		from, to  []string
39 | 		url, path string
40 | 		found     bool
41 | 	}{
42 | 		{"t1", []string{}, []string{}, "", "", false},
43 | 
44 | 		{"t2.1", []string{"/"}, []string{"/www"}, "", "", false},
45 | 		{"t2.2", []string{"/"}, []string{"/www"}, "/", "/www", true},
46 | 		{"t2.3", []string{"/"}, []string{"/www"}, "/index", "/wwwindex", true},
47 | 		{"t2.4", []string{"/"}, []string{"/www/"}, "/index", "/www/index", true},
48 | 		{"t2.5", []string{"/"}, []string{"/www/"}, "/index/", "/www/index/", true},
49 | 		{"t2.6", []string{"/"}, []string{"/www/"}, "index", "", false},
50 | 		{"t2.7", []string{""}, []string{""}, "/", "/", true},
51 | 		{"t2.7", []string{""}, []string{""}, "/index.html", "/index.html", true}, // NOTE(review): duplicate id "t2.7"; this case probably meant "t2.8"
52 | 
53 | 		{"t3.1", []string{"/", "/css", "/js"}, []string{"/www/others", "/www/ui/css", "/www/ui/js"}, "", "", false},
54 | 		{"t3.2", []string{"/", "/css", "/js"}, []string{"/www/others", "/www/ui/css", "/www/ui/js"}, "/", "/www/others", true},
55 | 		{"t3.3", []string{"/", "/css", "/js"}, []string{"/www/others", "/www/ui/css", "/www/ui/js"}, "/css", "/www/ui/css", true},
56 | 		{"t3.4", []string{"/", "/css", "/js"}, []string{"/www/others", "/www/ui/css", "/www/ui/js"}, "/abc", "/www/othersabc", true},
57 | 		{"t3.5", []string{"/", "/css", "/js"}, []string{"/www/others", "/www/ui/css", "/www/ui/js"}, "/css2", "/www/ui/css2", true},
58 | 
59 | 		{"t4.1", []string{"/css/abc", "/css"}, []string{"/www/abc", "/www/css"}, "/css/abc", "/www/css/abc", true},
60 | 	}
61 | 	for _, test := range tests {
62 | 		path, found := matchPath(test.url, test.from, test.to)
63 | 		assert.Equal(t, test.found, found, test.id)
64 | 		if found {
65 | 			assert.Equal(t, test.path, path, test.id)
66 | 		}
67 | 	}
68 | }
69 | 
70 | func TestContent(t *testing.T) {
71 | 	h := Content("testdata/index.html")
72 | 	req, _ := http.NewRequest("GET", "/index.html", nil)
73 | 	res := httptest.NewRecorder()
74 | 	c := routing.NewContext(res, req)
75 | 	err := h(c)
76 | 	assert.Nil(t, err)
77 | 	assert.Equal(t, "hello\n", res.Body.String())
78 | 
79 | 	h = Content("testdata/index.html")
80 | 	req, _ = http.NewRequest("POST", "/index.html", nil)
81 | 	res = httptest.NewRecorder()
82 | 	c = routing.NewContext(res, req)
83 | 	err = h(c)
84 | 	if assert.NotNil(t, err) {
85 | 		assert.Equal(t, http.StatusMethodNotAllowed, err.(routing.HTTPError).StatusCode())
86 | 	}
87 | 
88 | 	h =
Content("testdata/index.go")
89 | 	req, _ = http.NewRequest("GET", "/index.html", nil)
90 | 	res = httptest.NewRecorder()
91 | 	c = routing.NewContext(res, req)
92 | 	err = h(c)
93 | 	if assert.NotNil(t, err) {
94 | 		assert.Equal(t, http.StatusNotFound, err.(routing.HTTPError).StatusCode())
95 | 	}
96 | 	// a directory is never served by Content
97 | 	h = Content("testdata/css")
98 | 	req, _ = http.NewRequest("GET", "/index.html", nil)
99 | 	res = httptest.NewRecorder()
100 | 	c = routing.NewContext(res, req)
101 | 	err = h(c)
102 | 	if assert.NotNil(t, err) {
103 | 		assert.Equal(t, http.StatusNotFound, err.(routing.HTTPError).StatusCode())
104 | 	}
105 | }
106 | 
107 | func TestServer(t *testing.T) {
108 | 	h := Server(PathMap{"/css": "/testdata/css"})
109 | 	tests := []struct {
110 | 		id          string
111 | 		method, url string
112 | 		status      int
113 | 		body        string
114 | 	}{
115 | 		{"t1", "GET", "/css/main.css", 0, "body {}\n"},
116 | 		{"t2", "HEAD", "/css/main.css", 0, ""},
117 | 		{"t3", "GET", "/css/main2.css", http.StatusNotFound, ""},
118 | 		{"t4", "POST", "/css/main.css", http.StatusMethodNotAllowed, ""},
119 | 		{"t5", "GET", "/css", http.StatusNotFound, ""},
120 | 	}
121 | 
122 | 	for _, test := range tests {
123 | 		req, _ := http.NewRequest(test.method, test.url, nil)
124 | 		res := httptest.NewRecorder()
125 | 		c := routing.NewContext(res, req)
126 | 		err := h(c)
127 | 		if test.status == 0 {
128 | 			assert.Nil(t, err, test.id)
129 | 			assert.Equal(t, test.body, res.Body.String(), test.id)
130 | 		} else {
131 | 			if assert.NotNil(t, err, test.id) {
132 | 				assert.Equal(t, test.status, err.(routing.HTTPError).StatusCode(), test.id)
133 | 			}
134 | 		}
135 | 	}
136 | 
137 | 	h = Server(PathMap{"/css": "/testdata/css"}, ServerOptions{
138 | 		IndexFile: "index.html",
139 | 		Allow: func(c *routing.Context, path string) bool {
140 | 			return path != "/testdata/css/main.css"
141 | 		},
142 | 	})
143 | 	// Allow now rejects main.css, and directories are answered with IndexFile
144 | 	req, _ := http.NewRequest("GET", "/css/main.css", nil)
145 | 	res := httptest.NewRecorder()
146 | 	c := routing.NewContext(res, req)
147 | 	err := h(c)
148 | 
assert.NotNil(t, err)
149 | 	// "/css" is a directory, so its IndexFile is served
150 | 	req, _ = http.NewRequest("GET", "/css", nil)
151 | 	res = httptest.NewRecorder()
152 | 	c = routing.NewContext(res, req)
153 | 	err = h(c)
154 | 	assert.Nil(t, err)
155 | 	assert.Equal(t, "css.html\n", res.Body.String())
156 | 
157 | 	{
158 | 		// with CatchAll option
159 | 		h = Server(PathMap{"/css": "/testdata/css"}, ServerOptions{
160 | 			IndexFile:    "index.html",
161 | 			CatchAllFile: "testdata/index.html",
162 | 			Allow: func(c *routing.Context, path string) bool {
163 | 				return path != "/testdata/css/main.css"
164 | 			},
165 | 		})
166 | 
167 | 		req, _ := http.NewRequest("GET", "/css/main.css", nil)
168 | 		res := httptest.NewRecorder()
169 | 		c := routing.NewContext(res, req)
170 | 		err := h(c)
171 | 		assert.NotNil(t, err)
172 | 
173 | 		req, _ = http.NewRequest("GET", "/css", nil)
174 | 		res = httptest.NewRecorder()
175 | 		c = routing.NewContext(res, req)
176 | 		err = h(c)
177 | 		assert.Nil(t, err)
178 | 		assert.Equal(t, "css.html\n", res.Body.String())
179 | 		// "/css2" maps to a file that does not exist, so CatchAllFile is served instead
180 | 		req, _ = http.NewRequest("GET", "/css2", nil)
181 | 		res = httptest.NewRecorder()
182 | 		c = routing.NewContext(res, req)
183 | 		err = h(c)
184 | 		assert.Nil(t, err)
185 | 		assert.Equal(t, "hello\n", res.Body.String())
186 | 	}
187 | }
188 | 
--------------------------------------------------------------------------------
/file/testdata/css/index.html:
--------------------------------------------------------------------------------
1 | css.html
2 | 
--------------------------------------------------------------------------------
/file/testdata/css/main.css:
--------------------------------------------------------------------------------
1 | body {}
2 | 
--------------------------------------------------------------------------------
/file/testdata/index.html:
--------------------------------------------------------------------------------
1 | hello
2 | 
--------------------------------------------------------------------------------
/go.mod:
-------------------------------------------------------------------------------- 1 | module github.com/go-ozzo/ozzo-routing/v2 2 | 3 | go 1.13 4 | 5 | require ( 6 | github.com/golang-jwt/jwt v3.2.2+incompatible 7 | github.com/golang/gddo v0.0.0-20190904175337-72a348e765d2 8 | github.com/google/go-cmp v0.3.1 // indirect 9 | github.com/stretchr/testify v1.4.0 10 | ) 11 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= 2 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= 4 | github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= 5 | github.com/golang/gddo v0.0.0-20190904175337-72a348e765d2 h1:xisWqjiKEff2B0KfFYGpCqc3M3zdTz+OHQHRc09FeYk= 6 | github.com/golang/gddo v0.0.0-20190904175337-72a348e765d2/go.mod h1:xEhNfoBDX1hzLm2Nf80qUvZ2sVwoMZ8d6IE2SrsQfh4= 7 | github.com/google/go-cmp v0.3.1 h1:Xye71clBPdm5HgqGwUkwhbynsUJZhDbS20FvLhQ2izg= 8 | github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= 9 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 10 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 11 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 12 | github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= 13 | github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= 14 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 15 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod 
h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 16 | gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= 17 | gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 18 | -------------------------------------------------------------------------------- /graceful.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 Qiang Xue. All rights reserved. 2 | // Use of this source code is governed by a MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package routing 6 | 7 | import ( 8 | "context" 9 | "net/http" 10 | "os" 11 | "os/signal" 12 | "syscall" 13 | "time" 14 | ) 15 | 16 | // GracefulShutdown shuts down the given HTTP server gracefully when receiving an os.Interrupt or syscall.SIGTERM signal. 17 | // It will wait for the specified timeout to stop hanging HTTP handlers. 18 | func GracefulShutdown(hs *http.Server, timeout time.Duration, logFunc func(format string, args ...interface{})) { 19 | stop := make(chan os.Signal, 1) 20 | 21 | signal.Notify(stop, os.Interrupt, syscall.SIGTERM) 22 | 23 | <-stop 24 | 25 | ctx, cancel := context.WithTimeout(context.Background(), timeout) 26 | defer cancel() 27 | 28 | logFunc("shutting down server with %s timeout", timeout) 29 | 30 | if err := hs.Shutdown(ctx); err != nil { 31 | logFunc("error while shutting down server: %v", err) 32 | } else { 33 | logFunc("server was shut down gracefully") 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /group.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 Qiang Xue. All rights reserved. 2 | // Use of this source code is governed by a MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package routing 6 | 7 | import "strings" 8 | 9 | // RouteGroup represents a group of routes that share the same path prefix. 
10 | type RouteGroup struct { 11 | prefix string 12 | router *Router 13 | handlers []Handler 14 | } 15 | 16 | // newRouteGroup creates a new RouteGroup with the given path prefix, router, and handlers. 17 | func newRouteGroup(prefix string, router *Router, handlers []Handler) *RouteGroup { 18 | return &RouteGroup{ 19 | prefix: prefix, 20 | router: router, 21 | handlers: handlers, 22 | } 23 | } 24 | 25 | // Get adds a GET route to the router with the given route path and handlers. 26 | func (rg *RouteGroup) Get(path string, handlers ...Handler) *Route { 27 | return rg.add("GET", path, handlers) 28 | } 29 | 30 | // Post adds a POST route to the router with the given route path and handlers. 31 | func (rg *RouteGroup) Post(path string, handlers ...Handler) *Route { 32 | return rg.add("POST", path, handlers) 33 | } 34 | 35 | // Put adds a PUT route to the router with the given route path and handlers. 36 | func (rg *RouteGroup) Put(path string, handlers ...Handler) *Route { 37 | return rg.add("PUT", path, handlers) 38 | } 39 | 40 | // Patch adds a PATCH route to the router with the given route path and handlers. 41 | func (rg *RouteGroup) Patch(path string, handlers ...Handler) *Route { 42 | return rg.add("PATCH", path, handlers) 43 | } 44 | 45 | // Delete adds a DELETE route to the router with the given route path and handlers. 46 | func (rg *RouteGroup) Delete(path string, handlers ...Handler) *Route { 47 | return rg.add("DELETE", path, handlers) 48 | } 49 | 50 | // Connect adds a CONNECT route to the router with the given route path and handlers. 51 | func (rg *RouteGroup) Connect(path string, handlers ...Handler) *Route { 52 | return rg.add("CONNECT", path, handlers) 53 | } 54 | 55 | // Head adds a HEAD route to the router with the given route path and handlers. 
56 | func (rg *RouteGroup) Head(path string, handlers ...Handler) *Route { 57 | return rg.add("HEAD", path, handlers) 58 | } 59 | 60 | // Options adds an OPTIONS route to the router with the given route path and handlers. 61 | func (rg *RouteGroup) Options(path string, handlers ...Handler) *Route { 62 | return rg.add("OPTIONS", path, handlers) 63 | } 64 | 65 | // Trace adds a TRACE route to the router with the given route path and handlers. 66 | func (rg *RouteGroup) Trace(path string, handlers ...Handler) *Route { 67 | return rg.add("TRACE", path, handlers) 68 | } 69 | 70 | // Any adds a route with the given route, handlers, and the HTTP methods as listed in routing.Methods. 71 | func (rg *RouteGroup) Any(path string, handlers ...Handler) *Route { 72 | return rg.To(strings.Join(Methods, ","), path, handlers...) 73 | } 74 | 75 | // To adds a route to the router with the given HTTP methods, route path, and handlers. 76 | // Multiple HTTP methods should be separated by commas (without any surrounding spaces). 77 | func (rg *RouteGroup) To(methods, path string, handlers ...Handler) *Route { 78 | mm := strings.Split(methods, ",") 79 | if len(mm) == 1 { 80 | return rg.add(methods, path, handlers) 81 | } 82 | 83 | r := rg.newRoute(methods, path) 84 | for _, method := range mm { 85 | r.routes = append(r.routes, rg.add(method, path, handlers)) 86 | } 87 | return r 88 | } 89 | 90 | // Group creates a RouteGroup with the given route path prefix and handlers. 91 | // The new group will combine the existing path prefix with the new one. 92 | // If no handler is provided, the new group will inherit the handlers registered 93 | // with the current group. 
94 | func (rg *RouteGroup) Group(prefix string, handlers ...Handler) *RouteGroup { 95 | if len(handlers) == 0 { 96 | handlers = make([]Handler, len(rg.handlers)) 97 | copy(handlers, rg.handlers) 98 | } 99 | return newRouteGroup(rg.prefix+prefix, rg.router, handlers) 100 | } 101 | 102 | // Use registers one or multiple handlers to the current route group. 103 | // These handlers will be shared by all routes belong to this group and its subgroups. 104 | func (rg *RouteGroup) Use(handlers ...Handler) { 105 | rg.handlers = append(rg.handlers, handlers...) 106 | } 107 | 108 | func (rg *RouteGroup) add(method, path string, handlers []Handler) *Route { 109 | r := rg.newRoute(method, path) 110 | rg.router.addRoute(r, combineHandlers(rg.handlers, handlers)) 111 | return r 112 | } 113 | 114 | // newRoute creates a new Route with the given route path and route group. 115 | func (rg *RouteGroup) newRoute(method, path string) *Route { 116 | return &Route{ 117 | group: rg, 118 | method: method, 119 | path: path, 120 | template: buildURLTemplate(rg.prefix + path), 121 | } 122 | } 123 | 124 | // combineHandlers merges two lists of handlers into a new list. 125 | func combineHandlers(h1 []Handler, h2 []Handler) []Handler { 126 | hh := make([]Handler, len(h1)+len(h2)) 127 | copy(hh, h1) 128 | copy(hh[len(h1):], h2) 129 | return hh 130 | } 131 | 132 | // buildURLTemplate converts a route pattern into a URL template by removing regular expressions in parameter tokens. 
133 | func buildURLTemplate(path string) string { 134 | path = strings.TrimRight(path, "*") 135 | template, start, end := "", -1, -1 136 | for i := 0; i < len(path); i++ { 137 | if path[i] == '<' && start < 0 { 138 | start = i 139 | } else if path[i] == '>' && start >= 0 { 140 | name := path[start+1 : i] 141 | for j := start + 1; j < i; j++ { 142 | if path[j] == ':' { 143 | name = path[start+1 : j] 144 | break 145 | } 146 | } 147 | template += path[end+1:start] + "<" + name + ">" 148 | end = i 149 | start = -1 150 | } 151 | } 152 | if end < 0 { 153 | template = path 154 | } else if end < len(path)-1 { 155 | template += path[end+1:] 156 | } 157 | return template 158 | } 159 | -------------------------------------------------------------------------------- /group_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 Qiang Xue. All rights reserved. 2 | // Use of this source code is governed by a MIT-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package routing 6 | 7 | import ( 8 | "bytes" 9 | "testing" 10 | 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func TestRouteGroupTo(t *testing.T) { 15 | router := New() 16 | for _, method := range Methods { 17 | store := newMockStore() 18 | router.stores[method] = store 19 | } 20 | group := newRouteGroup("/admin", router, nil) 21 | 22 | group.Any("/users") 23 | for _, method := range Methods { 24 | assert.Equal(t, 1, router.stores[method].(*mockStore).count, "router.stores["+method+"].count@1 =") 25 | } 26 | 27 | group.To("GET", "/articles") 28 | assert.Equal(t, 2, router.stores["GET"].(*mockStore).count, "router.stores[GET].count@2 =") 29 | assert.Equal(t, 1, router.stores["POST"].(*mockStore).count, "router.stores[POST].count@2 =") 30 | 31 | group.To("GET,POST", "/comments") 32 | assert.Equal(t, 3, router.stores["GET"].(*mockStore).count, "router.stores[GET].count@3 =") 33 | assert.Equal(t, 2, router.stores["POST"].(*mockStore).count, "router.stores[POST].count@3 =") 34 | } 35 | 36 | func TestRouteGroupMethods(t *testing.T) { 37 | router := New() 38 | for _, method := range Methods { 39 | store := newMockStore() 40 | router.stores[method] = store 41 | assert.Equal(t, 0, store.count, "router.stores["+method+"].count =") 42 | } 43 | group := newRouteGroup("/admin", router, nil) 44 | 45 | group.Get("/users") 46 | assert.Equal(t, 1, router.stores["GET"].(*mockStore).count, "router.stores[GET].count =") 47 | group.Post("/users") 48 | assert.Equal(t, 1, router.stores["POST"].(*mockStore).count, "router.stores[POST].count =") 49 | group.Patch("/users") 50 | assert.Equal(t, 1, router.stores["PATCH"].(*mockStore).count, "router.stores[PATCH].count =") 51 | group.Put("/users") 52 | assert.Equal(t, 1, router.stores["PUT"].(*mockStore).count, "router.stores[PUT].count =") 53 | group.Delete("/users") 54 | assert.Equal(t, 1, router.stores["DELETE"].(*mockStore).count, "router.stores[DELETE].count =") 55 | group.Connect("/users") 56 | assert.Equal(t, 1, 
router.stores["CONNECT"].(*mockStore).count, "router.stores[CONNECT].count =") 57 | group.Head("/users") 58 | assert.Equal(t, 1, router.stores["HEAD"].(*mockStore).count, "router.stores[HEAD].count =") 59 | group.Options("/users") 60 | assert.Equal(t, 1, router.stores["OPTIONS"].(*mockStore).count, "router.stores[OPTIONS].count =") 61 | group.Trace("/users") 62 | assert.Equal(t, 1, router.stores["TRACE"].(*mockStore).count, "router.stores[TRACE].count =") 63 | } 64 | 65 | func TestRouteGroupGroup(t *testing.T) { 66 | group := newRouteGroup("/admin", New(), nil) 67 | g1 := group.Group("/users") 68 | assert.Equal(t, "/admin/users", g1.prefix, "g1.prefix =") 69 | assert.Equal(t, 0, len(g1.handlers), "len(g1.handlers) =") 70 | var buf bytes.Buffer 71 | g2 := group.Group("", newHandler("1", &buf), newHandler("2", &buf)) 72 | assert.Equal(t, "/admin", g2.prefix, "g2.prefix =") 73 | assert.Equal(t, 2, len(g2.handlers), "len(g2.handlers) =") 74 | 75 | group2 := newRouteGroup("/admin", New(), []Handler{newHandler("1", &buf), newHandler("2", &buf)}) 76 | g3 := group2.Group("/users") 77 | assert.Equal(t, "/admin/users", g3.prefix, "g3.prefix =") 78 | assert.Equal(t, 2, len(g3.handlers), "len(g3.handlers) =") 79 | g4 := group2.Group("", newHandler("3", &buf)) 80 | assert.Equal(t, "/admin", g4.prefix, "g4.prefix =") 81 | assert.Equal(t, 1, len(g4.handlers), "len(g4.handlers) =") 82 | } 83 | 84 | func TestRouteGroupUse(t *testing.T) { 85 | var buf bytes.Buffer 86 | group := newRouteGroup("/admin", New(), nil) 87 | group.Use(newHandler("1", &buf), newHandler("2", &buf)) 88 | assert.Equal(t, 2, len(group.handlers), "len(group.handlers) =") 89 | 90 | group2 := newRouteGroup("/admin", New(), []Handler{newHandler("1", &buf), newHandler("2", &buf)}) 91 | group2.Use(newHandler("3", &buf)) 92 | assert.Equal(t, 3, len(group2.handlers), "len(group2.handlers) =") 93 | } 94 | -------------------------------------------------------------------------------- /reader.go: 
-------------------------------------------------------------------------------- 1 | package routing 2 | 3 | import ( 4 | "encoding" 5 | "encoding/json" 6 | "encoding/xml" 7 | "errors" 8 | "net/http" 9 | "reflect" 10 | "strconv" 11 | ) 12 | 13 | // MIME types used when doing request data reading and response data writing. 14 | const ( 15 | MIME_JSON = "application/json" 16 | MIME_XML = "application/xml" 17 | MIME_XML2 = "text/xml" 18 | MIME_HTML = "text/html" 19 | MIME_FORM = "application/x-www-form-urlencoded" 20 | MIME_MULTIPART_FORM = "multipart/form-data" 21 | ) 22 | 23 | var ( 24 | textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() 25 | ) 26 | 27 | // DataReader is used by Context.Read() to read data from an HTTP request. 28 | type DataReader interface { 29 | // Read reads from the given HTTP request and populate the specified data. 30 | Read(*http.Request, interface{}) error 31 | } 32 | 33 | var ( 34 | // DataReaders lists all supported content types and the corresponding data readers. 35 | // Context.Read() will choose a matching reader from this list according to the "Content-Type" 36 | // header from the current request. 37 | // You may modify this variable to add new supported content types. 38 | DataReaders = map[string]DataReader{ 39 | MIME_FORM: &FormDataReader{}, 40 | MIME_MULTIPART_FORM: &FormDataReader{}, 41 | MIME_JSON: &JSONDataReader{}, 42 | MIME_XML: &XMLDataReader{}, 43 | MIME_XML2: &XMLDataReader{}, 44 | } 45 | // DefaultFormDataReader is the reader used when there is no matching reader in DataReaders 46 | // or if the current request is a GET request. 47 | DefaultFormDataReader DataReader = &FormDataReader{} 48 | ) 49 | 50 | // JSONDataReader reads the request body as JSON-formatted data. 
51 | type JSONDataReader struct{} 52 | 53 | func (r *JSONDataReader) Read(req *http.Request, data interface{}) error { 54 | return json.NewDecoder(req.Body).Decode(data) 55 | } 56 | 57 | // XMLDataReader reads the request body as XML-formatted data. 58 | type XMLDataReader struct{} 59 | 60 | func (r *XMLDataReader) Read(req *http.Request, data interface{}) error { 61 | return xml.NewDecoder(req.Body).Decode(data) 62 | } 63 | 64 | // FormDataReader reads the query parameters and request body as form data. 65 | type FormDataReader struct{} 66 | 67 | func (r *FormDataReader) Read(req *http.Request, data interface{}) error { 68 | // Do not check return result. Otherwise GET request will cause problem. 69 | req.ParseMultipartForm(32 << 20) 70 | return ReadFormData(req.Form, data) 71 | } 72 | 73 | const formTag = "form" 74 | 75 | // ReadFormData populates the data variable with the data from the given form values. 76 | func ReadFormData(form map[string][]string, data interface{}) error { 77 | rv := reflect.ValueOf(data) 78 | if rv.Kind() != reflect.Ptr || rv.IsNil() { 79 | return errors.New("data must be a pointer") 80 | } 81 | rv = indirect(rv) 82 | if rv.Kind() != reflect.Struct { 83 | return errors.New("data must be a pointer to a struct") 84 | } 85 | 86 | return readForm(form, "", rv) 87 | } 88 | 89 | func readForm(form map[string][]string, prefix string, rv reflect.Value) error { 90 | rv = indirect(rv) 91 | rt := rv.Type() 92 | n := rt.NumField() 93 | for i := 0; i < n; i++ { 94 | field := rt.Field(i) 95 | tag := field.Tag.Get(formTag) 96 | 97 | // only handle anonymous or exported fields 98 | if !field.Anonymous && field.PkgPath != "" || tag == "-" { 99 | continue 100 | } 101 | 102 | ft := field.Type 103 | if ft.Kind() == reflect.Ptr { 104 | ft = ft.Elem() 105 | } 106 | 107 | name := tag 108 | if name == "" && !field.Anonymous { 109 | name = field.Name 110 | } 111 | if name != "" && prefix != "" { 112 | name = prefix + "." 
+ name 113 | } 114 | 115 | // check if type implements a known type, like encoding.TextUnmarshaler 116 | if ok, err := readFormFieldKnownType(form, name, rv.Field(i)); err != nil { 117 | return err 118 | } else if ok { 119 | continue 120 | } 121 | 122 | if ft.Kind() != reflect.Struct { 123 | if err := readFormField(form, name, rv.Field(i)); err != nil { 124 | return err 125 | } 126 | continue 127 | } 128 | 129 | if name == "" { 130 | name = prefix 131 | } 132 | if err := readForm(form, name, rv.Field(i)); err != nil { 133 | return err 134 | } 135 | } 136 | return nil 137 | } 138 | 139 | func readFormFieldKnownType(form map[string][]string, name string, rv reflect.Value) (bool, error) { 140 | value, ok := form[name] 141 | if !ok { 142 | return false, nil 143 | } 144 | rv = indirect(rv) 145 | rt := rv.Type() 146 | 147 | // check if type implements encoding.TextUnmarshaler 148 | if rt.Implements(textUnmarshalerType) { 149 | return true, rv.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(value[0])) 150 | } else if reflect.PtrTo(rt).Implements(textUnmarshalerType) { 151 | return true, rv.Addr().Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(value[0])) 152 | } 153 | return false, nil 154 | } 155 | 156 | func readFormField(form map[string][]string, name string, rv reflect.Value) error { 157 | value, ok := form[name] 158 | if !ok { 159 | return nil 160 | } 161 | rv = indirect(rv) 162 | if rv.Kind() != reflect.Slice { 163 | return setFormFieldValue(rv, value[0]) 164 | } 165 | 166 | n := len(value) 167 | slice := reflect.MakeSlice(rv.Type(), n, n) 168 | for i := 0; i < n; i++ { 169 | if err := setFormFieldValue(slice.Index(i), value[i]); err != nil { 170 | return err 171 | } 172 | } 173 | rv.Set(slice) 174 | return nil 175 | } 176 | 177 | func setFormFieldValue(rv reflect.Value, value string) error { 178 | switch rv.Kind() { 179 | case reflect.Bool: 180 | if value == "" { 181 | value = "false" 182 | } 183 | v, err := strconv.ParseBool(value) 184 | if 
err == nil { 185 | rv.SetBool(v) 186 | } 187 | return err 188 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 189 | if value == "" { 190 | value = "0" 191 | } 192 | v, err := strconv.ParseInt(value, 10, 64) 193 | if err == nil { 194 | rv.SetInt(v) 195 | } 196 | return err 197 | case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 198 | if value == "" { 199 | value = "0" 200 | } 201 | v, err := strconv.ParseUint(value, 10, 64) 202 | if err == nil { 203 | rv.SetUint(v) 204 | } 205 | return err 206 | case reflect.Float32, reflect.Float64: 207 | if value == "" { 208 | value = "0" 209 | } 210 | v, err := strconv.ParseFloat(value, 64) 211 | if err == nil { 212 | rv.SetFloat(v) 213 | } 214 | return err 215 | case reflect.String: 216 | rv.SetString(value) 217 | return nil 218 | default: 219 | return errors.New("Unknown type: " + rv.Kind().String()) 220 | } 221 | } 222 | 223 | // indirect dereferences pointers and returns the actual value it points to. 224 | // If a pointer is nil, it will be initialized with a new value. 
225 | func indirect(v reflect.Value) reflect.Value { 226 | for v.Kind() == reflect.Ptr { 227 | if v.IsNil() { 228 | v.Set(reflect.New(v.Type().Elem())) 229 | } 230 | v = v.Elem() 231 | } 232 | return v 233 | } 234 | -------------------------------------------------------------------------------- /reader_test.go: -------------------------------------------------------------------------------- 1 | package routing 2 | 3 | import ( 4 | "bytes" 5 | "net/http" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | type FA struct { 12 | A1 string 13 | A2 int 14 | } 15 | 16 | type FB struct { 17 | B1 string 18 | B2 bool 19 | B3 float64 20 | } 21 | 22 | func TestReadForm(t *testing.T) { 23 | var a struct { 24 | X1 string `form:"x1"` 25 | FA 26 | X2 int 27 | B *FB 28 | FB `form:"c"` 29 | E FB `form:"e"` 30 | c int 31 | D []int 32 | } 33 | values := map[string][]string{ 34 | "x1": {"abc", "123"}, 35 | "A1": {"a1"}, 36 | "x2": {"1", "2"}, 37 | "B.B1": {"b1", "b2"}, 38 | "B.B2": {"true"}, 39 | "B.B3": {"1.23"}, 40 | "c.B1": {"fb1", "fb2"}, 41 | "e.B1": {"fe1", "fe2"}, 42 | "c": {"100"}, 43 | "D": {"100", "200", "300"}, 44 | } 45 | err := ReadFormData(values, &a) 46 | assert.Nil(t, err) 47 | assert.Equal(t, "abc", a.X1) 48 | assert.Equal(t, "a1", a.A1) 49 | assert.Equal(t, 0, a.X2) 50 | assert.Equal(t, "b1", a.B.B1) 51 | assert.True(t, a.B.B2) 52 | assert.Equal(t, 1.23, a.B.B3) 53 | assert.Equal(t, "fb1", a.B1) 54 | assert.Equal(t, "fe1", a.E.B1) 55 | assert.Equal(t, 0, a.c) 56 | assert.Equal(t, []int{100, 200, 300}, a.D) 57 | } 58 | 59 | func TestDefaultDataReader(t *testing.T) { 60 | tests := []struct { 61 | tag string 62 | header string 63 | method, URL string 64 | body string 65 | }{ 66 | {"t1", "", "GET", "/test?A1=abc&A2=100", ""}, 67 | {"t2", "", "POST", "/test?A1=abc&A2=100", ""}, 68 | {"t3", "application/x-www-form-urlencoded", "POST", "/test", "A1=abc&A2=100"}, 69 | {"t4", "application/json", "POST", "/test", `{"A1":"abc","A2":100}`}, 70 | {"t5", 
"application/xml", "POST", "/test", `abc100`}, 71 | } 72 | 73 | expected := FA{ 74 | A1: "abc", 75 | A2: 100, 76 | } 77 | for _, test := range tests { 78 | var data FA 79 | req, _ := http.NewRequest(test.method, test.URL, bytes.NewBufferString(test.body)) 80 | req.Header.Set("Content-Type", test.header) 81 | c := NewContext(nil, req) 82 | err := c.Read(&data) 83 | assert.Nil(t, err, test.tag) 84 | assert.Equal(t, expected, data, test.tag) 85 | } 86 | } 87 | 88 | type TU struct { 89 | UValue string 90 | } 91 | 92 | func (tu *TU) UnmarshalText(text []byte) error { 93 | tu.UValue = "TU_" + string(text[:]) 94 | return nil 95 | } 96 | 97 | func TestTextUnmarshaler(t *testing.T) { 98 | var a struct { 99 | ATU TU `form:"atu"` 100 | NTU string `form:"ntu"` 101 | } 102 | values := map[string][]string{ 103 | "atu": {"ORIGINAL"}, 104 | "ntu": {"ORIGINAL"}, 105 | } 106 | err := ReadFormData(values, &a) 107 | assert.Nil(t, err) 108 | assert.Equal(t, "TU_ORIGINAL", a.ATU.UValue) 109 | assert.Equal(t, "ORIGINAL", a.NTU) 110 | } 111 | -------------------------------------------------------------------------------- /route.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 Qiang Xue. All rights reserved. 2 | // Use of this source code is governed by a MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package routing 6 | 7 | import ( 8 | "fmt" 9 | "net/url" 10 | "strings" 11 | ) 12 | 13 | // Route represents a URL path pattern that can be used to match requested URLs. 14 | type Route struct { 15 | group *RouteGroup 16 | method, path string 17 | name, template string 18 | tags []interface{} 19 | routes []*Route 20 | } 21 | 22 | // Name sets the name of the route. 23 | // This method will update the registration of the route in the router as well. 
24 | func (r *Route) Name(name string) *Route { 25 | r.name = name 26 | r.group.router.namedRoutes[name] = r 27 | return r 28 | } 29 | 30 | // Tag associates some custom data with the route. 31 | func (r *Route) Tag(value interface{}) *Route { 32 | if len(r.routes) > 0 { 33 | // this route is a composite one (a path with multiple methods) 34 | for _, route := range r.routes { 35 | route.Tag(value) 36 | } 37 | return r 38 | } 39 | if r.tags == nil { 40 | r.tags = []interface{}{} 41 | } 42 | r.tags = append(r.tags, value) 43 | return r 44 | } 45 | 46 | // Method returns the HTTP method that this route is associated with. 47 | func (r *Route) Method() string { 48 | return r.method 49 | } 50 | 51 | // Path returns the request path that this route should match. 52 | func (r *Route) Path() string { 53 | return r.group.prefix + r.path 54 | } 55 | 56 | // Tags returns all custom data associated with the route. 57 | func (r *Route) Tags() []interface{} { 58 | return r.tags 59 | } 60 | 61 | // Get adds the route to the router using the GET HTTP method. 62 | func (r *Route) Get(handlers ...Handler) *Route { 63 | return r.group.add("GET", r.path, handlers) 64 | } 65 | 66 | // Post adds the route to the router using the POST HTTP method. 67 | func (r *Route) Post(handlers ...Handler) *Route { 68 | return r.group.add("POST", r.path, handlers) 69 | } 70 | 71 | // Put adds the route to the router using the PUT HTTP method. 72 | func (r *Route) Put(handlers ...Handler) *Route { 73 | return r.group.add("PUT", r.path, handlers) 74 | } 75 | 76 | // Patch adds the route to the router using the PATCH HTTP method. 77 | func (r *Route) Patch(handlers ...Handler) *Route { 78 | return r.group.add("PATCH", r.path, handlers) 79 | } 80 | 81 | // Delete adds the route to the router using the DELETE HTTP method. 
82 | func (r *Route) Delete(handlers ...Handler) *Route { 83 | return r.group.add("DELETE", r.path, handlers) 84 | } 85 | 86 | // Connect adds the route to the router using the CONNECT HTTP method. 87 | func (r *Route) Connect(handlers ...Handler) *Route { 88 | return r.group.add("CONNECT", r.path, handlers) 89 | } 90 | 91 | // Head adds the route to the router using the HEAD HTTP method. 92 | func (r *Route) Head(handlers ...Handler) *Route { 93 | return r.group.add("HEAD", r.path, handlers) 94 | } 95 | 96 | // Options adds the route to the router using the OPTIONS HTTP method. 97 | func (r *Route) Options(handlers ...Handler) *Route { 98 | return r.group.add("OPTIONS", r.path, handlers) 99 | } 100 | 101 | // Trace adds the route to the router using the TRACE HTTP method. 102 | func (r *Route) Trace(handlers ...Handler) *Route { 103 | return r.group.add("TRACE", r.path, handlers) 104 | } 105 | 106 | // To adds the route to the router with the given HTTP methods and handlers. 107 | // Multiple HTTP methods should be separated by commas (without any surrounding spaces). 108 | func (r *Route) To(methods string, handlers ...Handler) *Route { 109 | return r.group.To(methods, r.path, handlers...) 110 | } 111 | 112 | // URL creates a URL using the current route and the given parameters. 113 | // The parameters should be given in the sequence of name1, value1, name2, value2, and so on. 114 | // If a parameter in the route is not provided a value, the parameter token will remain in the resulting URL. 115 | // The method will perform URL encoding for all given parameter values. 
116 | func (r *Route) URL(pairs ...interface{}) (s string) { 117 | s = r.template 118 | for i := 0; i < len(pairs); i++ { 119 | name := fmt.Sprintf("<%v>", pairs[i]) 120 | value := "" 121 | if i < len(pairs)-1 { 122 | value = url.QueryEscape(fmt.Sprint(pairs[i+1])) 123 | } 124 | s = strings.Replace(s, name, value, -1) 125 | } 126 | return 127 | } 128 | 129 | // String returns the string representation of the route. 130 | func (r *Route) String() string { 131 | return r.method + " " + r.group.prefix + r.path 132 | } 133 | -------------------------------------------------------------------------------- /route_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 Qiang Xue. All rights reserved. 2 | // Use of this source code is governed by a MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package routing 6 | 7 | import ( 8 | "bytes" 9 | "fmt" 10 | "testing" 11 | 12 | "github.com/stretchr/testify/assert" 13 | ) 14 | 15 | type mockStore struct { 16 | *store 17 | data map[string]interface{} 18 | } 19 | 20 | func newMockStore() *mockStore { 21 | return &mockStore{newStore(), make(map[string]interface{})} 22 | } 23 | 24 | func (s *mockStore) Add(key string, data interface{}) int { 25 | for _, handler := range data.([]Handler) { 26 | handler(nil) 27 | } 28 | return s.store.Add(key, data) 29 | } 30 | 31 | func TestRouteNew(t *testing.T) { 32 | router := New() 33 | group := newRouteGroup("/admin", router, nil) 34 | 35 | r1 := group.newRoute("GET", "/users").Get() 36 | assert.Equal(t, "", r1.name, "route.name =") 37 | assert.Equal(t, "/users", r1.path, "route.path =") 38 | assert.Equal(t, 1, len(router.Routes())) 39 | 40 | r2 := group.newRoute("GET", "/users//*").Post() 41 | assert.Equal(t, "", r2.name, "route.name =") 42 | assert.Equal(t, "/users//*", r2.path, "route.path =") 43 | assert.Equal(t, "/admin/users//", r2.template, "route.template =") 44 | assert.Equal(t, 2, len(router.Routes())) 45 
| } 46 | 47 | func TestRouteName(t *testing.T) { 48 | router := New() 49 | group := newRouteGroup("/admin", router, nil) 50 | 51 | r1 := group.newRoute("GET", "/users") 52 | assert.Equal(t, "", r1.name, "route.name =") 53 | r1.Name("user") 54 | assert.Equal(t, "user", r1.name, "route.name =") 55 | _, exists := router.namedRoutes[r1.name] 56 | assert.True(t, exists) 57 | } 58 | 59 | func TestRouteURL(t *testing.T) { 60 | router := New() 61 | group := newRouteGroup("/admin", router, nil) 62 | r := group.newRoute("GET", "/users///*") 63 | assert.Equal(t, "/admin/users/123/address/", r.URL("id", 123, "action", "address")) 64 | assert.Equal(t, "/admin/users/123//", r.URL("id", 123)) 65 | assert.Equal(t, "/admin/users/123//", r.URL("id", 123, "action")) 66 | assert.Equal(t, "/admin/users/123/profile/", r.URL("id", 123, "action", "profile", "")) 67 | assert.Equal(t, "/admin/users/123/profile/", r.URL("id", 123, "action", "profile", "", "xyz/abc")) 68 | assert.Equal(t, "/admin/users/123/a%2C%3C%3E%3F%23/", r.URL("id", 123, "action", "a,<>?#")) 69 | } 70 | 71 | func newHandler(tag string, buf *bytes.Buffer) Handler { 72 | return func(*Context) error { 73 | fmt.Fprintf(buf, tag) 74 | return nil 75 | } 76 | } 77 | 78 | func TestRouteAdd(t *testing.T) { 79 | store := newMockStore() 80 | router := New() 81 | router.stores["GET"] = store 82 | assert.Equal(t, 0, store.count, "router.stores[GET].count =") 83 | 84 | var buf bytes.Buffer 85 | 86 | group := newRouteGroup("/admin", router, []Handler{newHandler("1.", &buf), newHandler("2.", &buf)}) 87 | group.newRoute("GET", "/users").Get(newHandler("3.", &buf), newHandler("4.", &buf)) 88 | assert.Equal(t, "1.2.3.4.", buf.String(), "buf@1 =") 89 | 90 | buf.Reset() 91 | group = newRouteGroup("/admin", router, []Handler{}) 92 | group.newRoute("GET", "/users").Get(newHandler("3.", &buf), newHandler("4.", &buf)) 93 | assert.Equal(t, "3.4.", buf.String(), "buf@2 =") 94 | 95 | buf.Reset() 96 | group = newRouteGroup("/admin", router, 
[]Handler{newHandler("1.", &buf), newHandler("2.", &buf)}) 97 | group.newRoute("GET", "/users").Get() 98 | assert.Equal(t, "1.2.", buf.String(), "buf@3 =") 99 | } 100 | 101 | func TestRouteTag(t *testing.T) { 102 | router := New() 103 | router.Get("/posts").Tag("posts") 104 | router.Any("/users").Tag("users") 105 | router.To("PUT,PATCH", "/comments").Tag("comments") 106 | router.Get("/orders").Tag("GET orders").Post().Tag("POST orders") 107 | routes := router.Routes() 108 | for _, route := range routes { 109 | if !assert.True(t, len(route.Tags()) > 0, route.method+" "+route.path+" should have a tag") { 110 | continue 111 | } 112 | tag := route.Tags()[0].(string) 113 | switch route.path { 114 | case "/posts": 115 | assert.Equal(t, "posts", tag) 116 | case "/users": 117 | assert.Equal(t, "users", tag) 118 | case "/comments": 119 | assert.Equal(t, "comments", tag) 120 | case "/orders": 121 | if route.method == "GET" { 122 | assert.Equal(t, "GET orders", tag) 123 | } else { 124 | assert.Equal(t, "POST orders", tag) 125 | } 126 | } 127 | } 128 | } 129 | 130 | func TestRouteMethods(t *testing.T) { 131 | router := New() 132 | for _, method := range Methods { 133 | store := newMockStore() 134 | router.stores[method] = store 135 | assert.Equal(t, 0, store.count, "router.stores["+method+"].count =") 136 | } 137 | group := newRouteGroup("/admin", router, nil) 138 | 139 | group.newRoute("GET", "/users").Get() 140 | assert.Equal(t, 1, router.stores["GET"].(*mockStore).count, "router.stores[GET].count =") 141 | group.newRoute("GET", "/users").Post() 142 | assert.Equal(t, 1, router.stores["POST"].(*mockStore).count, "router.stores[POST].count =") 143 | group.newRoute("GET", "/users").Patch() 144 | assert.Equal(t, 1, router.stores["PATCH"].(*mockStore).count, "router.stores[PATCH].count =") 145 | group.newRoute("GET", "/users").Put() 146 | assert.Equal(t, 1, router.stores["PUT"].(*mockStore).count, "router.stores[PUT].count =") 147 | group.newRoute("GET", "/users").Delete() 148 | 
assert.Equal(t, 1, router.stores["DELETE"].(*mockStore).count, "router.stores[DELETE].count =") 149 | group.newRoute("GET", "/users").Connect() 150 | assert.Equal(t, 1, router.stores["CONNECT"].(*mockStore).count, "router.stores[CONNECT].count =") 151 | group.newRoute("GET", "/users").Head() 152 | assert.Equal(t, 1, router.stores["HEAD"].(*mockStore).count, "router.stores[HEAD].count =") 153 | group.newRoute("GET", "/users").Options() 154 | assert.Equal(t, 1, router.stores["OPTIONS"].(*mockStore).count, "router.stores[OPTIONS].count =") 155 | group.newRoute("GET", "/users").Trace() 156 | assert.Equal(t, 1, router.stores["TRACE"].(*mockStore).count, "router.stores[TRACE].count =") 157 | 158 | group.newRoute("GET", "/posts").To("GET,POST") 159 | assert.Equal(t, 2, router.stores["GET"].(*mockStore).count, "router.stores[GET].count =") 160 | assert.Equal(t, 2, router.stores["POST"].(*mockStore).count, "router.stores[POST].count =") 161 | assert.Equal(t, 1, router.stores["PUT"].(*mockStore).count, "router.stores[PUT].count =") 162 | 163 | group.newRoute("GET", "/posts").To("POST") 164 | assert.Equal(t, 2, router.stores["GET"].(*mockStore).count, "router.stores[GET].count =") 165 | assert.Equal(t, 3, router.stores["POST"].(*mockStore).count, "router.stores[POST].count =") 166 | assert.Equal(t, 1, router.stores["PUT"].(*mockStore).count, "router.stores[PUT].count =") 167 | } 168 | 169 | func TestBuildURLTemplate(t *testing.T) { 170 | tests := []struct { 171 | path, expected string 172 | }{ 173 | {"", ""}, 174 | {"/users", "/users"}, 175 | {"", ""}, 176 | {"", "/users/"}, 178 | {"/users/", "/users/"}, 179 | {"/users/<:\\d+>", "/users/<>"}, 180 | {"/users//xyz", "/users//xyz"}, 181 | {"/users//xyz", "/users//xyz"}, 182 | {"/users//", "/users//"}, 183 | {"/users///", "/users///"}, 184 | {"/users/", "/users/"}, 185 | {"/users//", "/users//"}, 186 | } 187 | for _, test := range tests { 188 | actual := buildURLTemplate(test.path) 189 | assert.Equal(t, test.expected, actual, 
"buildURLTemplate("+test.path+") =") 190 | } 191 | } 192 | 193 | func TestRouteString(t *testing.T) { 194 | router := New() 195 | router.Get("/users/") 196 | router.To("GET,POST", "/users//profile") 197 | group := router.Group("/admin") 198 | group.Post("/users") 199 | s := "" 200 | for _, route := range router.Routes() { 201 | s += fmt.Sprintln(route) 202 | } 203 | 204 | assert.Equal(t, `GET /users/ 205 | GET /users//profile 206 | POST /users//profile 207 | POST /admin/users 208 | `, s) 209 | } 210 | -------------------------------------------------------------------------------- /router.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 Qiang Xue. All rights reserved. 2 | // Use of this source code is governed by a MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | // Package routing provides high performance and powerful HTTP routing capabilities. 6 | package routing 7 | 8 | import ( 9 | "net/http" 10 | "net/url" 11 | "sort" 12 | "strings" 13 | "sync" 14 | ) 15 | 16 | type ( 17 | // Handler is the function for handling HTTP requests. 18 | Handler func(*Context) error 19 | 20 | // Router manages routes and dispatches HTTP requests to the handlers of the matching routes. 21 | Router struct { 22 | RouteGroup 23 | IgnoreTrailingSlash bool // whether to ignore trailing slashes in the end of the request URL 24 | UseEscapedPath bool // whether to use encoded URL instead of decoded URL to match routes 25 | pool sync.Pool 26 | routes []*Route 27 | namedRoutes map[string]*Route 28 | stores map[string]routeStore 29 | maxParams int 30 | notFound []Handler 31 | notFoundHandlers []Handler 32 | } 33 | 34 | // routeStore stores route paths and the corresponding handlers. 
35 | routeStore interface { 36 | Add(key string, data interface{}) int 37 | Get(key string, pvalues []string) (data interface{}, pnames []string) 38 | String() string 39 | } 40 | ) 41 | 42 | // Methods lists all supported HTTP methods by Router. 43 | var Methods = []string{ 44 | "CONNECT", 45 | "DELETE", 46 | "GET", 47 | "HEAD", 48 | "OPTIONS", 49 | "PATCH", 50 | "POST", 51 | "PUT", 52 | "TRACE", 53 | } 54 | 55 | // New creates a new Router object. 56 | func New() *Router { 57 | r := &Router{ 58 | namedRoutes: make(map[string]*Route), 59 | stores: make(map[string]routeStore), 60 | } 61 | r.RouteGroup = *newRouteGroup("", r, make([]Handler, 0)) 62 | r.NotFound(MethodNotAllowedHandler, NotFoundHandler) 63 | r.pool.New = func() interface{} { 64 | return &Context{ 65 | pvalues: make([]string, r.maxParams), 66 | router: r, 67 | } 68 | } 69 | return r 70 | } 71 | 72 | // ServeHTTP handles the HTTP request. 73 | // It is required by http.Handler 74 | func (r *Router) ServeHTTP(res http.ResponseWriter, req *http.Request) { 75 | c := r.pool.Get().(*Context) 76 | c.init(res, req) 77 | if r.UseEscapedPath { 78 | c.handlers, c.pnames = r.find(req.Method, r.normalizeRequestPath(req.URL.EscapedPath()), c.pvalues) 79 | for i, v := range c.pvalues { 80 | c.pvalues[i], _ = url.QueryUnescape(v) 81 | } 82 | } else { 83 | c.handlers, c.pnames = r.find(req.Method, r.normalizeRequestPath(req.URL.Path), c.pvalues) 84 | } 85 | if err := c.Next(); err != nil { 86 | r.handleError(c, err) 87 | } 88 | r.pool.Put(c) 89 | } 90 | 91 | // Route returns the named route. 92 | // Nil is returned if the named route cannot be found. 93 | func (r *Router) Route(name string) *Route { 94 | return r.namedRoutes[name] 95 | } 96 | 97 | // Routes returns all routes managed by the router. 98 | func (r *Router) Routes() []*Route { 99 | return r.routes 100 | } 101 | 102 | // Use appends the specified handlers to the router and shares them with all routes. 
103 | func (r *Router) Use(handlers ...Handler) { 104 | r.RouteGroup.Use(handlers...) 105 | r.notFoundHandlers = combineHandlers(r.handlers, r.notFound) 106 | } 107 | 108 | // NotFound specifies the handlers that should be invoked when the router cannot find any route matching a request. 109 | // Note that the handlers registered via Use will be invoked first in this case. 110 | func (r *Router) NotFound(handlers ...Handler) { 111 | r.notFound = handlers 112 | r.notFoundHandlers = combineHandlers(r.handlers, r.notFound) 113 | } 114 | 115 | // Find determines the handlers and parameters to use for a specified method and path. 116 | func (r *Router) Find(method, path string) (handlers []Handler, params map[string]string) { 117 | pvalues := make([]string, r.maxParams) 118 | handlers, pnames := r.find(method, path, pvalues) 119 | params = make(map[string]string, len(pnames)) 120 | for i, n := range pnames { 121 | params[n] = pvalues[i] 122 | } 123 | return handlers, params 124 | } 125 | 126 | // handleError is the error handler for handling any unhandled errors. 
127 | func (r *Router) handleError(c *Context, err error) { 128 | if httpError, ok := err.(HTTPError); ok { 129 | http.Error(c.Response, httpError.Error(), httpError.StatusCode()) 130 | } else { 131 | http.Error(c.Response, err.Error(), http.StatusInternalServerError) 132 | } 133 | } 134 | 135 | func (r *Router) addRoute(route *Route, handlers []Handler) { 136 | path := route.group.prefix + route.path 137 | 138 | r.routes = append(r.routes, route) 139 | 140 | store := r.stores[route.method] 141 | if store == nil { 142 | store = newStore() 143 | r.stores[route.method] = store 144 | } 145 | 146 | // an asterisk at the end matches any number of characters 147 | if strings.HasSuffix(path, "*") { 148 | path = path[:len(path)-1] + "<:.*>" 149 | } 150 | 151 | if n := store.Add(path, handlers); n > r.maxParams { 152 | r.maxParams = n 153 | } 154 | } 155 | 156 | func (r *Router) find(method, path string, pvalues []string) (handlers []Handler, pnames []string) { 157 | var hh interface{} 158 | if store := r.stores[method]; store != nil { 159 | hh, pnames = store.Get(path, pvalues) 160 | } 161 | if hh != nil { 162 | return hh.([]Handler), pnames 163 | } 164 | return r.notFoundHandlers, pnames 165 | } 166 | 167 | func (r *Router) findAllowedMethods(path string) map[string]bool { 168 | methods := make(map[string]bool) 169 | pvalues := make([]string, r.maxParams) 170 | for m, store := range r.stores { 171 | if handlers, _ := store.Get(path, pvalues); handlers != nil { 172 | methods[m] = true 173 | } 174 | } 175 | return methods 176 | } 177 | 178 | func (r *Router) normalizeRequestPath(path string) string { 179 | if r.IgnoreTrailingSlash && len(path) > 1 && path[len(path)-1] == '/' { 180 | for i := len(path) - 2; i > 0; i-- { 181 | if path[i] != '/' { 182 | return path[0 : i+1] 183 | } 184 | } 185 | return path[0:1] 186 | } 187 | return path 188 | } 189 | 190 | // NotFoundHandler returns a 404 HTTP error indicating a request has no matching route. 
191 | func NotFoundHandler(*Context) error { 192 | return NewHTTPError(http.StatusNotFound) 193 | } 194 | 195 | // MethodNotAllowedHandler handles the situation when a request has matching route without matching HTTP method. 196 | // In this case, the handler will respond with an Allow HTTP header listing the allowed HTTP methods. 197 | // Otherwise, the handler will do nothing and let the next handler (usually a NotFoundHandler) to handle the problem. 198 | func MethodNotAllowedHandler(c *Context) error { 199 | methods := c.Router().findAllowedMethods(c.Request.URL.Path) 200 | if len(methods) == 0 { 201 | return nil 202 | } 203 | methods["OPTIONS"] = true 204 | ms := make([]string, len(methods)) 205 | i := 0 206 | for method := range methods { 207 | ms[i] = method 208 | i++ 209 | } 210 | sort.Strings(ms) 211 | c.Response.Header().Set("Allow", strings.Join(ms, ", ")) 212 | if c.Request.Method != "OPTIONS" { 213 | c.Response.WriteHeader(http.StatusMethodNotAllowed) 214 | } 215 | c.Abort() 216 | return nil 217 | } 218 | 219 | // HTTPHandlerFunc adapts a http.HandlerFunc into a routing.Handler. 220 | func HTTPHandlerFunc(h http.HandlerFunc) Handler { 221 | return func(c *Context) error { 222 | h(c.Response, c.Request) 223 | return nil 224 | } 225 | } 226 | 227 | // HTTPHandler adapts a http.Handler into a routing.Handler. 
func HTTPHandler(h http.Handler) Handler {
	return func(c *Context) error {
		h.ServeHTTP(c.Response, c.Request)
		// h reports failures through the response only, so there is never an error to return.
		return nil
	}
}

--------------------------------------------------------------------------------
/router_test.go:
--------------------------------------------------------------------------------
package routing

import (
	"errors"
	"fmt"
	"github.com/stretchr/testify/assert"
	"net/http"
	"net/http/httptest"
	"testing"
)

// TestRouterNotFound covers a matching route, a known path with an unregistered
// method (405 + Allow), an OPTIONS request (200 + Allow), an unknown path (404),
// and IgnoreTrailingSlash.
func TestRouterNotFound(t *testing.T) {
	r := New()
	h := func(c *Context) error {
		fmt.Fprint(c.Response, "ok")
		return nil
	}
	r.Get("/users", h)
	r.Post("/users", h)
	r.NotFound(MethodNotAllowedHandler, NotFoundHandler)

	res := httptest.NewRecorder()
	req, _ := http.NewRequest("GET", "/users", nil)
	r.ServeHTTP(res, req)
	assert.Equal(t, "ok", res.Body.String(), "response body")
	assert.Equal(t, http.StatusOK, res.Code, "HTTP status code")

	res = httptest.NewRecorder()
	req, _ = http.NewRequest("PUT", "/users", nil)
	r.ServeHTTP(res, req)
	assert.Equal(t, "GET, OPTIONS, POST", res.Header().Get("Allow"), "Allow header")
	assert.Equal(t, http.StatusMethodNotAllowed, res.Code, "HTTP status code")

	// OPTIONS gets the Allow header but no explicit status, so the recorder reports 200
	res = httptest.NewRecorder()
	req, _ = http.NewRequest("OPTIONS", "/users", nil)
	r.ServeHTTP(res, req)
	assert.Equal(t, "GET, OPTIONS, POST", res.Header().Get("Allow"), "Allow header")
	assert.Equal(t, http.StatusOK, res.Code, "HTTP status code")

	res = httptest.NewRecorder()
	req, _ = http.NewRequest("GET", "/users/", nil)
	r.ServeHTTP(res, req)
	assert.Equal(t, "", res.Header().Get("Allow"), "Allow header")
	assert.Equal(t, http.StatusNotFound, res.Code, "HTTP status code")

	r.IgnoreTrailingSlash = true
	res = httptest.NewRecorder()
	req, _ = http.NewRequest("GET", "/users/", nil)
	r.ServeHTTP(res, req)
	assert.Equal(t, "ok", res.Body.String(), "response body")
	assert.Equal(t, http.StatusOK, res.Code, "HTTP status code")
}

// TestRouterUse checks that middleware added via Use is prepended to the not-found chain.
func TestRouterUse(t *testing.T) {
	r := New()
	assert.Equal(t, 2, len(r.notFoundHandlers))
	r.Use(NotFoundHandler)
	assert.Equal(t, 3, len(r.notFoundHandlers))
}

// TestRouterRoute checks lookup of named routes.
func TestRouterRoute(t *testing.T) {
	r := New()
	r.Get("/users").Name("users")
	assert.NotNil(t, r.Route("users"))
	assert.Nil(t, r.Route("users2"))
}

// TestRouterAdd checks that adding a parametric route raises maxParams.
// NOTE(review): the route literal below contains no <param> token, yet maxParams
// is expected to become 1; a token (e.g. "<id>") appears to have been stripped
// from this listing. Verify against the source file.
func TestRouterAdd(t *testing.T) {
	r := New()
	assert.Equal(t, 0, r.maxParams)
	r.add("GET", "/users/", nil)
	assert.Equal(t, 1, r.maxParams)
}

// TestRouterFind checks that Find returns the handlers and the named parameters.
// NOTE(review): params["id"] implies the route was "/users/<id>"; the token
// appears to have been stripped from this listing.
func TestRouterFind(t *testing.T) {
	r := New()
	r.add("GET", "/users/", []Handler{NotFoundHandler})
	handlers, params := r.Find("GET", "/users/1")
	assert.Equal(t, 1, len(handlers))
	if assert.Equal(t, 1, len(params)) {
		assert.Equal(t, "1", params["id"])
	}
}

// TestRouterNormalizeRequestPath checks trailing-slash stripping with IgnoreTrailingSlash enabled.
func TestRouterNormalizeRequestPath(t *testing.T) {
	tests := []struct {
		path     string
		expected string
	}{
		{"/", "/"},
		{"/users", "/users"},
		{"/users/", "/users"},
		{"/users//", "/users"},
		{"///", "/"},
	}
	r := New()
	r.IgnoreTrailingSlash = true
	for _, test := range tests {
		result := r.normalizeRequestPath(test.path)
		assert.Equal(t, test.expected, result)
	}
}

// TestRouterHandleError checks the status written for plain errors (500) and HTTP errors (their own code).
func TestRouterHandleError(t *testing.T) {
	r := New()
	res := httptest.NewRecorder()
	c := &Context{Response: res}
	r.handleError(c, errors.New("abc"))
	assert.Equal(t, http.StatusInternalServerError, res.Code)

	res = httptest.NewRecorder()
	c = &Context{Response: res}
	r.handleError(c, NewHTTPError(http.StatusNotFound))
	assert.Equal(t, http.StatusNotFound, res.Code)
}

// TestHTTPHandler checks both net/http adapters.
func TestHTTPHandler(t *testing.T) {
	res := httptest.NewRecorder()
	req, _ := http.NewRequest("GET", "/users/", nil)
	c := NewContext(res, req)

	h1 := HTTPHandlerFunc(http.NotFound)
	assert.Nil(t, h1(c))
	assert.Equal(t, http.StatusNotFound, res.Code)

	res = httptest.NewRecorder()
	req, _ = http.NewRequest("GET", "/users/", nil)
	c = NewContext(res, req)
	h2 := HTTPHandler(http.NotFoundHandler())
	assert.Nil(t, h2(c))
	assert.Equal(t, http.StatusNotFound, res.Code)
}

--------------------------------------------------------------------------------
/slash/remover.go:
--------------------------------------------------------------------------------
// Copyright 2016 Qiang Xue. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.

// Package slash provides a trailing slash remover handler for the ozzo routing package.
package slash

import (
	"net/http"
	"strings"

	"github.com/go-ozzo/ozzo-routing/v2"
)

// Remover returns a handler that removes the trailing slash (if any) from the requested URL.
// The handler will redirect the browser to the new URL without the trailing slash.
// The status parameter should be either http.StatusMovedPermanently (301) or http.StatusFound (302), which is to
// be used for redirecting GET requests. For other requests, the status code will be http.StatusTemporaryRedirect (307).
// If the original URL has no trailing slash, the handler will do nothing. For example,
//
//	import (
//		"net/http"
//		"github.com/go-ozzo/ozzo-routing/v2"
//		"github.com/go-ozzo/ozzo-routing/v2/slash"
//	)
//
//	r := routing.New()
//	r.Use(slash.Remover(http.StatusMovedPermanently))
//
// Note that Remover relies on HTTP redirection to remove the trailing slashes.
31 | // If you do not want redirection, please set `Router.IgnoreTrailingSlash` to be true without using Remover. 32 | func Remover(status int) routing.Handler { 33 | return func(c *routing.Context) error { 34 | if c.Request.URL.Path != "/" && strings.HasSuffix(c.Request.URL.Path, "/") { 35 | if c.Request.Method != "GET" { 36 | status = http.StatusTemporaryRedirect 37 | } 38 | http.Redirect(c.Response, c.Request, strings.TrimRight(c.Request.URL.Path, "/"), status) 39 | c.Abort() 40 | } 41 | return nil 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /slash/remover_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 Qiang Xue. All rights reserved. 2 | // Use of this source code is governed by a MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package slash 6 | 7 | import ( 8 | "net/http" 9 | "net/http/httptest" 10 | "testing" 11 | 12 | "github.com/go-ozzo/ozzo-routing/v2" 13 | "github.com/stretchr/testify/assert" 14 | ) 15 | 16 | func TestRemover(t *testing.T) { 17 | h := Remover(http.StatusMovedPermanently) 18 | res := httptest.NewRecorder() 19 | req, _ := http.NewRequest("GET", "/users/", nil) 20 | c := routing.NewContext(res, req) 21 | err := h(c) 22 | assert.Nil(t, err, "return value is nil") 23 | assert.Equal(t, http.StatusMovedPermanently, res.Code) 24 | assert.Equal(t, "/users", res.Header().Get("Location")) 25 | 26 | res = httptest.NewRecorder() 27 | req, _ = http.NewRequest("GET", "/", nil) 28 | c = routing.NewContext(res, req) 29 | err = h(c) 30 | assert.Nil(t, err, "return value is nil") 31 | assert.Equal(t, http.StatusOK, res.Code) 32 | assert.Equal(t, "", res.Header().Get("Location")) 33 | 34 | res = httptest.NewRecorder() 35 | req, _ = http.NewRequest("GET", "/users", nil) 36 | c = routing.NewContext(res, req) 37 | err = h(c) 38 | assert.Nil(t, err, "return value is nil") 39 | assert.Equal(t, http.StatusOK, res.Code) 40 
| assert.Equal(t, "", res.Header().Get("Location")) 41 | 42 | res = httptest.NewRecorder() 43 | req, _ = http.NewRequest("POST", "/users/", nil) 44 | c = routing.NewContext(res, req) 45 | err = h(c) 46 | assert.Nil(t, err, "return value is nil") 47 | assert.Equal(t, http.StatusTemporaryRedirect, res.Code) 48 | assert.Equal(t, "/users", res.Header().Get("Location")) 49 | } 50 | -------------------------------------------------------------------------------- /store.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 Qiang Xue. All rights reserved. 2 | // Use of this source code is governed by a MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package routing 6 | 7 | import ( 8 | "fmt" 9 | "math" 10 | "regexp" 11 | "strings" 12 | ) 13 | 14 | // store is a radix tree that supports storing data with parametric keys and retrieving them back with concrete keys. 15 | // When retrieving a data item with a concrete key, the matching parameter names and values will be returned as well. 16 | // A parametric key is a string containing tokens in the format of "", "", or "<:pattern>". 17 | // Each token represents a single parameter. 18 | type store struct { 19 | root *node // the root node of the radix tree 20 | count int // the number of data nodes in the tree 21 | } 22 | 23 | // newStore creates a new store. 24 | func newStore() *store { 25 | return &store{ 26 | root: &node{ 27 | static: true, 28 | children: make([]*node, 256), 29 | pchildren: make([]*node, 0), 30 | pindex: -1, 31 | pnames: []string{}, 32 | }, 33 | } 34 | } 35 | 36 | // Add adds a new data item with the given parametric key. 37 | // The number of parameters in the key is returned. 38 | func (s *store) Add(key string, data interface{}) int { 39 | s.count++ 40 | return s.root.add(key, data, s.count) 41 | } 42 | 43 | // Get returns the data item matching the given concrete key. 
44 | // If the data item was added to the store with a parametric key before, the matching 45 | // parameter names and values will be returned as well. 46 | func (s *store) Get(path string, pvalues []string) (data interface{}, pnames []string) { 47 | data, pnames, _ = s.root.get(path, pvalues) 48 | return 49 | } 50 | 51 | // String dumps the radix tree kept in the store as a string. 52 | func (s *store) String() string { 53 | return s.root.print(0) 54 | } 55 | 56 | // node represents a radix trie node 57 | type node struct { 58 | static bool // whether the node is a static node or param node 59 | 60 | key string // the key identifying this node 61 | data interface{} // the data associated with this node. nil if not a data node. 62 | 63 | order int // the order at which the data was added. used to be pick the first one when matching multiple 64 | minOrder int // minimum order among all the child nodes and this node 65 | 66 | children []*node // child static nodes, indexed by the first byte of each child key 67 | pchildren []*node // child param nodes 68 | 69 | regex *regexp.Regexp // regular expression for a param node containing regular expression key 70 | pindex int // the parameter index, meaningful only for param node 71 | pnames []string // the parameter names collected from the root till this node 72 | } 73 | 74 | // add adds a new data item to the tree rooted at the current node. 75 | // The number of parameters in the key is returned. 
func (n *node) add(key string, data interface{}, order int) int {
	matched := 0

	// find the common prefix
	for ; matched < len(key) && matched < len(n.key); matched++ {
		if key[matched] != n.key[matched] {
			break
		}
	}

	if matched == len(n.key) {
		if matched == len(key) {
			// the node key is the same as the key: make the current node as data node
			// if the node is already a data node, ignore the new data since we only care the first matched node
			if n.data == nil {
				n.data = data
				n.order = order
			}
			return n.pindex + 1
		}

		// the node key is a prefix of the key: create a child node
		// (newKey is non-empty because matched < len(key) here)
		newKey := key[matched:]

		// try adding to a static child
		if child := n.children[newKey[0]]; child != nil {
			if pn := child.add(newKey, data, order); pn >= 0 {
				return pn
			}
		}
		// try adding to a param child
		for _, child := range n.pchildren {
			if pn := child.add(newKey, data, order); pn >= 0 {
				return pn
			}
		}

		return n.addChild(newKey, data, order)
	}

	if matched == 0 || !n.static {
		// no common prefix, or partial common prefix with a non-static node: should skip this node
		return -1
	}

	// the node key shares a partial prefix with the key: split the node key.
	// n1 takes over the unmatched suffix together with n's data and children;
	// n keeps only the common prefix. n.order is left as is, but get only reads
	// it when n.data != nil.
	n1 := &node{
		static:    true,
		key:       n.key[matched:],
		data:      n.data,
		order:     n.order,
		minOrder:  n.minOrder,
		pchildren: n.pchildren,
		children:  n.children,
		pindex:    n.pindex,
		pnames:    n.pnames,
	}

	n.key = key[0:matched]
	n.data = nil
	n.pchildren = make([]*node, 0)
	n.children = make([]*node, 256)
	n.children[n1.key[0]] = n1

	// retry: n.key is now a prefix of key, so the matched == len(n.key) branch runs
	return n.add(key, data, order)
}

// addChild creates static and param nodes to store the given data.
// key must be non-empty, since key[0] is indexed unconditionally. The key is
// consumed one "<...>" token at a time: a static node for any leading text,
// then a param node for the token, then recursion on the rest of the key.
// It returns the number of parameters on the path to the new data node, and
// panics (via regexp.MustCompile) if a token's pattern is not a valid regular expression.
func (n *node) addChild(key string, data interface{}, order int) int {
	// find the first occurrence of a param token
	p0, p1 := -1, -1
	for i := 0; i < len(key); i++ {
		if p0 < 0 && key[i] == '<' {
			p0 = i
		}
		if p0 >= 0 && key[i] == '>' {
			p1 = i
			break
		}
	}

	// && binds tighter than ||: (a complete token after some static text) or
	// (no complete token at all, which includes a '<' without a closing '>')
	if p0 > 0 && p1 > 0 || p1 < 0 {
		// param token occurs after a static string, or no param token: create a static node
		child := &node{
			static:    true,
			key:       key,
			minOrder:  order,
			children:  make([]*node, 256),
			pchildren: make([]*node, 0),
			pindex:    n.pindex,
			pnames:    n.pnames,
		}
		n.children[key[0]] = child
		if p1 > 0 {
			// param token occurs after a static string
			child.key = key[:p0]
			n = child
		} else {
			// no param token: done adding the child
			child.data = data
			child.order = order
			return child.pindex + 1
		}
	}

	// add param node
	child := &node{
		static:    false,
		key:       key[p0 : p1+1],
		minOrder:  order,
		children:  make([]*node, 256),
		pchildren: make([]*node, 0),
		pindex:    n.pindex,
		pnames:    n.pnames,
	}
	// split "<name:pattern>" into its name and optional pattern; "<name>" has no pattern
	pattern := ""
	pname := key[p0+1 : p1]
	for i := p0 + 1; i < p1; i++ {
		if key[i] == ':' {
			pname = key[p0+1 : i]
			pattern = key[i+1 : p1]
			break
		}
	}
	if pattern != "" {
		// the param token contains a regular expression
		child.regex = regexp.MustCompile("^" + pattern)
	}
	// copy rather than append so sibling nodes never share the parent's pnames backing array
	pnames := make([]string, len(n.pnames)+1)
	copy(pnames, n.pnames)
	pnames[len(n.pnames)] = pname
	child.pnames = pnames
	child.pindex = len(pnames) - 1
	n.pchildren = append(n.pchildren, child)

	if p1 == len(key)-1 {
		// the param token is at the end of the key
		child.data = data
		child.order = order
		return child.pindex + 1
	}

	// process the rest of the key
	return child.addChild(key[p1+1:], data, order)
}
// get returns the data item with the key matching the tree rooted at the current node.
// It also returns the matched parameter names and the insertion order of the data
// (math.MaxInt32 when nothing matches). Parameter values are written into pvalues
// at each parameter node's pindex. When several keys match, the one with the
// smallest order wins.
func (n *node) get(key string, pvalues []string) (data interface{}, pnames []string, order int) {
	order = math.MaxInt32

repeat:
	if n.static {
		// check if the node key is a prefix of the given key
		// a slightly optimized version of strings.HasPrefix
		nkl := len(n.key)
		if nkl > len(key) {
			return
		}
		for i := nkl - 1; i >= 0; i-- {
			if n.key[i] != key[i] {
				return
			}
		}
		key = key[nkl:]
	} else if n.regex != nil {
		// param node with regular expression
		if n.regex.String() == "^.*" {
			// shortcut for the wildcard pattern: take the whole remaining key
			pvalues[n.pindex] = key
			key = ""
		} else if match := n.regex.FindStringIndex(key); match != nil {
			pvalues[n.pindex] = key[0:match[1]]
			key = key[match[1]:]
		} else {
			return
		}
	} else {
		// param node matching non-"/" characters (possibly none)
		i, kl := 0, len(key)
		for ; i < kl; i++ {
			if key[i] == '/' {
				pvalues[n.pindex] = key[0:i]
				key = key[i:]
				break
			}
		}
		if i == kl {
			pvalues[n.pindex] = key
			key = ""
		}
	}

	if len(key) > 0 {
		// find a static child that can match the rest of the key
		if child := n.children[key[0]]; child != nil {
			if len(n.pchildren) == 0 {
				// use goto to avoid recursion when no param children
				n = child
				goto repeat
			}
			data, pnames, order = child.get(key, pvalues)
		}
	} else if n.data != nil {
		// do not return yet: a param node may match an empty string with smaller order
		data, pnames, order = n.data, n.pnames, n.order
	}

	// try matching param children.
	// Once a candidate is found, further attempts write into a scratch copy
	// (tvalues) so a failed attempt cannot clobber the candidate's values in
	// pvalues; a better match's values are copied back afterwards.
	tvalues := pvalues
	allocated := false
	for _, child := range n.pchildren {
		// skip subtrees that cannot contain an earlier-added match
		if child.minOrder >= order {
			continue
		}
		if data != nil && !allocated {
			tvalues = make([]string, len(pvalues))
			allocated = true
		}
		if d, p, s := child.get(key, tvalues); d != nil && s < order {
			if allocated {
				for i := child.pindex; i < len(p); i++ {
					pvalues[i] = tvalues[i]
				}
			}
			data, pnames, order = d, p, s
		}
	}

	return
}

// print renders the subtree rooted at n, one node per line, indenting each
// level by level<<2 copies of the indent string; static children come first
// (in byte order), then param children.
func (n *node) print(level int) string {
	r := fmt.Sprintf("%v{key: %v, regex: %v, data: %v, order: %v, minOrder: %v, pindex: %v, pnames: %v}\n", strings.Repeat(" ", level<<2), n.key, n.regex, n.data, n.order, n.minOrder, n.pindex, n.pnames)
	for _, child := range n.children {
		if child != nil {
			r += child.print(level + 1)
		}
	}
	for _, child := range n.pchildren {
		r += child.print(level + 1)
	}
	return r
}

--------------------------------------------------------------------------------
/store_test.go:
--------------------------------------------------------------------------------
// Copyright 2016 Qiang Xue. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
4 | 5 | package routing 6 | 7 | import ( 8 | "fmt" 9 | "testing" 10 | 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | type storeTestEntry struct { 15 | key, data string 16 | params int 17 | } 18 | 19 | func TestStoreAdd(t *testing.T) { 20 | tests := []struct { 21 | id string 22 | entries []storeTestEntry 23 | expected string 24 | }{ 25 | { 26 | "all static", 27 | []storeTestEntry{ 28 | {"/gopher/bumper.png", "1", 0}, 29 | {"/gopher/bumper192x108.png", "2", 0}, 30 | {"/gopher/doc.png", "3", 0}, 31 | {"/gopher/bumper320x180.png", "4", 0}, 32 | {"/gopher/docpage.png", "5", 0}, 33 | {"/gopher/doc.png", "6", 0}, 34 | {"/gopher/doc", "7", 0}, 35 | }, 36 | `{key: , regex: , data: , order: 0, minOrder: 0, pindex: -1, pnames: []} 37 | {key: /gopher/, regex: , data: , order: 1, minOrder: 1, pindex: -1, pnames: []} 38 | {key: bumper, regex: , data: , order: 1, minOrder: 1, pindex: -1, pnames: []} 39 | {key: .png, regex: , data: 1, order: 1, minOrder: 1, pindex: -1, pnames: []} 40 | {key: 192x108.png, regex: , data: 2, order: 2, minOrder: 2, pindex: -1, pnames: []} 41 | {key: 320x180.png, regex: , data: 4, order: 4, minOrder: 4, pindex: -1, pnames: []} 42 | {key: doc, regex: , data: 7, order: 7, minOrder: 3, pindex: -1, pnames: []} 43 | {key: .png, regex: , data: 3, order: 3, minOrder: 3, pindex: -1, pnames: []} 44 | {key: page.png, regex: , data: 5, order: 5, minOrder: 5, pindex: -1, pnames: []} 45 | `, 46 | }, 47 | { 48 | "parametric", 49 | []storeTestEntry{ 50 | {"/users/", "11", 1}, 51 | {"/users//profile", "12", 1}, 52 | {"/users///address", "13", 2}, 53 | {"/users//age", "14", 1}, 54 | {"/users//", "15", 2}, 55 | }, 56 | `{key: , regex: , data: , order: 0, minOrder: 0, pindex: -1, pnames: []} 57 | {key: /users/, regex: , data: , order: 0, minOrder: 1, pindex: -1, pnames: []} 58 | {key: , regex: , data: 11, order: 1, minOrder: 1, pindex: 0, pnames: [id]} 59 | {key: /, regex: , data: , order: 2, minOrder: 2, pindex: 0, pnames: [id]} 60 | {key: age, regex: , 
data: 14, order: 4, minOrder: 4, pindex: 0, pnames: [id]} 61 | {key: profile, regex: , data: 12, order: 2, minOrder: 2, pindex: 0, pnames: [id]} 62 | {key: , regex: ^\d+, data: 15, order: 5, minOrder: 3, pindex: 1, pnames: [id accnt]} 63 | {key: /address, regex: , data: 13, order: 3, minOrder: 3, pindex: 1, pnames: [id accnt]} 64 | `, 65 | }, 66 | { 67 | "corner cases", 68 | []storeTestEntry{ 69 | {"/users//test/", "101", 2}, 70 | {"/users/abc//", "102", 2}, 71 | {"", "103", 0}, 72 | }, 73 | `{key: , regex: , data: 103, order: 3, minOrder: 0, pindex: -1, pnames: []} 74 | {key: /users/, regex: , data: , order: 0, minOrder: 1, pindex: -1, pnames: []} 75 | {key: abc/, regex: , data: , order: 0, minOrder: 2, pindex: -1, pnames: []} 76 | {key: , regex: , data: , order: 0, minOrder: 2, pindex: 0, pnames: [id]} 77 | {key: /, regex: , data: , order: 0, minOrder: 2, pindex: 0, pnames: [id]} 78 | {key: , regex: , data: 102, order: 2, minOrder: 2, pindex: 1, pnames: [id name]} 79 | {key: , regex: , data: , order: 0, minOrder: 1, pindex: 0, pnames: [id]} 80 | {key: /test/, regex: , data: , order: 0, minOrder: 1, pindex: 0, pnames: [id]} 81 | {key: , regex: , data: 101, order: 1, minOrder: 1, pindex: 1, pnames: [id name]} 82 | `, 83 | }, 84 | } 85 | for _, test := range tests { 86 | h := newStore() 87 | for _, entry := range test.entries { 88 | n := h.Add(entry.key, entry.data) 89 | assert.Equal(t, entry.params, n, test.id+" > "+entry.key+" > param count =") 90 | } 91 | assert.Equal(t, test.expected, h.String(), test.id+" > store.String() =") 92 | } 93 | } 94 | 95 | func TestStoreGet(t *testing.T) { 96 | pairs := []struct { 97 | key, value string 98 | }{ 99 | {"/gopher/bumper.png", "1"}, 100 | {"/gopher/bumper192x108.png", "2"}, 101 | {"/gopher/doc.png", "3"}, 102 | {"/gopher/bumper320x180.png", "4"}, 103 | {"/gopher/docpage.png", "5"}, 104 | {"/gopher/doc.png", "6"}, 105 | {"/gopher/doc", "7"}, 106 | {"/users/", "8"}, 107 | {"/users//profile", "9"}, 108 | {"/users///address", 
"10"}, 109 | {"/users//age", "11"}, 110 | {"/users//", "12"}, 111 | {"/users//test/", "13"}, 112 | {"/users/abc//", "14"}, 113 | {"", "15"}, 114 | {"/all/<:.*>", "16"}, 115 | } 116 | h := newStore() 117 | maxParams := 0 118 | for _, pair := range pairs { 119 | n := h.Add(pair.key, pair.value) 120 | if n > maxParams { 121 | maxParams = n 122 | } 123 | } 124 | assert.Equal(t, 2, maxParams, "param count = ") 125 | 126 | tests := []struct { 127 | key string 128 | value interface{} 129 | params string 130 | }{ 131 | {"/gopher/bumper.png", "1", ""}, 132 | {"/gopher/bumper192x108.png", "2", ""}, 133 | {"/gopher/doc.png", "3", ""}, 134 | {"/gopher/bumper320x180.png", "4", ""}, 135 | {"/gopher/docpage.png", "5", ""}, 136 | {"/gopher/doc.png", "3", ""}, 137 | {"/gopher/doc", "7", ""}, 138 | {"/users/abc", "8", "id:abc,"}, 139 | {"/users/abc/profile", "9", "id:abc,"}, 140 | {"/users/abc/123/address", "10", "id:abc,accnt:123,"}, 141 | {"/users/abcd/age", "11", "id:abcd,"}, 142 | {"/users/abc/123", "12", "id:abc,accnt:123,"}, 143 | {"/users/abc/test/123", "13", "id:abc,name:123,"}, 144 | {"/users/abc/xyz/123", "14", "id:xyz,name:123,"}, 145 | {"", "15", ""}, 146 | {"/g", nil, ""}, 147 | {"/all", nil, ""}, 148 | {"/all/", "16", ":,"}, 149 | {"/all/abc", "16", ":abc,"}, 150 | {"/users/abc/xyz", nil, ""}, 151 | } 152 | pvalues := make([]string, maxParams) 153 | for _, test := range tests { 154 | data, pnames := h.Get(test.key, pvalues) 155 | assert.Equal(t, test.value, data, "store.Get("+test.key+") =") 156 | params := "" 157 | if len(pnames) > 0 { 158 | for i, name := range pnames { 159 | params += fmt.Sprintf("%v:%v,", name, pvalues[i]) 160 | } 161 | } 162 | assert.Equal(t, test.params, params, "store.Get("+test.key+").params =") 163 | } 164 | } 165 | -------------------------------------------------------------------------------- /writer.go: -------------------------------------------------------------------------------- 1 | package routing 2 | 3 | import ( 4 | "fmt" 5 | 
"net/http" 6 | ) 7 | 8 | // DataWriter is used by Context.Write() to write arbitrary data into an HTTP response. 9 | type DataWriter interface { 10 | // SetHeader sets necessary response headers. 11 | SetHeader(http.ResponseWriter) 12 | // Write writes the given data into the response. 13 | Write(http.ResponseWriter, interface{}) error 14 | } 15 | 16 | // DefaultDataWriter writes the given data in an HTTP response. 17 | // If the data is neither string nor byte array, it will use fmt.Fprint() to write it into the response. 18 | var DefaultDataWriter DataWriter = &dataWriter{} 19 | 20 | type dataWriter struct{} 21 | 22 | func (w *dataWriter) SetHeader(res http.ResponseWriter) {} 23 | 24 | func (w *dataWriter) Write(res http.ResponseWriter, data interface{}) error { 25 | var bytes []byte 26 | switch data.(type) { 27 | case []byte: 28 | bytes = data.([]byte) 29 | case string: 30 | bytes = []byte(data.(string)) 31 | default: 32 | if data != nil { 33 | _, err := fmt.Fprint(res, data) 34 | return err 35 | } 36 | } 37 | _, err := res.Write(bytes) 38 | return err 39 | } 40 | -------------------------------------------------------------------------------- /writer_test.go: -------------------------------------------------------------------------------- 1 | package routing 2 | 3 | import ( 4 | "net/http/httptest" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestDefaultDataWriter(t *testing.T) { 11 | res := httptest.NewRecorder() 12 | err := DefaultDataWriter.Write(res, "abc") 13 | assert.Nil(t, err) 14 | assert.Equal(t, "abc", res.Body.String()) 15 | 16 | res = httptest.NewRecorder() 17 | err = DefaultDataWriter.Write(res, []byte("abc")) 18 | assert.Nil(t, err) 19 | assert.Equal(t, "abc", res.Body.String()) 20 | 21 | res = httptest.NewRecorder() 22 | err = DefaultDataWriter.Write(res, 123) 23 | assert.Nil(t, err) 24 | assert.Equal(t, "123", res.Body.String()) 25 | 26 | res = httptest.NewRecorder() 27 | err = DefaultDataWriter.Write(res, nil) 
28 | assert.Nil(t, err) 29 | assert.Equal(t, "", res.Body.String()) 30 | 31 | res = httptest.NewRecorder() 32 | c := &Context{} 33 | c.init(res, nil) 34 | assert.Nil(t, c.Write("abc")) 35 | assert.Equal(t, "abc", res.Body.String()) 36 | } 37 | --------------------------------------------------------------------------------