├── LICENSE ├── README.md ├── bench_test.go ├── defaults.go ├── defaults_appengine.go ├── globalmux.go ├── globalmux_17.go ├── globalmux_old.go ├── handler_17.go ├── handler_19.go ├── handler_new.go ├── handler_old.go ├── kami.go ├── kami_17.go ├── kami_17_test.go ├── kami_old_test.go ├── middleware.go ├── middleware_17.go ├── middleware_19.go ├── middleware_new.go ├── middleware_old.go ├── middleware_test.go ├── mux.go ├── mux_17.go ├── mux_test.go ├── params.go ├── params_test.go ├── serve.go ├── serve_appengine.go └── treemux ├── LICENSE ├── README.md ├── router.go ├── tree.go └── tree_test.go /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015 Gregory Roseberry (greg@toki.waseda.jp) 2 | Uses code from Goji: Copyright (c) 2014, 2015 Carl Jackson (carl@avtok.com) 3 | 4 | MIT License 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy of 7 | this software and associated documentation files (the "Software"), to deal in 8 | the Software without restriction, including without limitation the rights to 9 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 10 | the Software, and to permit persons to whom the Software is furnished to do so, 11 | subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 18 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 19 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 20 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 21 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## kami [![GoDoc](https://godoc.org/github.com/guregu/kami?status.svg)](https://godoc.org/github.com/guregu/kami) [![CircleCI](https://circleci.com/gh/guregu/kami.svg?style=svg)](https://circleci.com/gh/guregu/kami) 2 | `import "github.com/guregu/kami"` [or](http://gopkg.in) `import "gopkg.in/guregu/kami.v2"` 3 | 4 | kami (神) is a tiny web framework using [context](https://blog.golang.org/context) for request context and [httptreemux](https://github.com/dimfeld/httptreemux) for routing. It includes a simple system for running hierarchical middleware before and after requests, in addition to log and panic hooks. Graceful restart via einhorn is also supported. 5 | 6 | kami is designed to be used as a central registration point for your routes, middleware, and context "god object". You are encouraged to use the global functions, but kami supports multiple muxes with `kami.New()`. 7 | 8 | You are free to mount `kami.Handler()` wherever you wish, but a helpful `kami.Serve()` function is provided. 9 | 10 | Here is a [presentation about the birth of kami](http://go-talks.appspot.com/github.com/guregu/slides/kami/kami.slide), explaining some of the design choices. 11 | 12 | Both `context` and `x/net/context` are supported. 13 | 14 | ### Example 15 | A contrived example using kami and context to localize greetings.
16 | 17 | [Skip :fast_forward:](#usage) 18 | 19 | 20 | ```go 21 | // Our webserver 22 | package main 23 | 24 | import ( 25 | "fmt" 26 | "net/http" 27 | "context" 28 | 29 | "github.com/guregu/kami" 30 | 31 | "github.com/my-github/greeting" // see package greeting below 32 | ) 33 | 34 | func greet(ctx context.Context, w http.ResponseWriter, r *http.Request) { 35 | hello := greeting.FromContext(ctx) 36 | name := kami.Param(ctx, "name") 37 | fmt.Fprintf(w, "%s, %s!", hello, name) 38 | } 39 | 40 | func main() { 41 | ctx := context.Background() 42 | ctx = greeting.WithContext(ctx, "Hello") // set default greeting 43 | kami.Context = ctx // set our "god context", the base context for all requests 44 | 45 | kami.Use("/hello/", greeting.Guess) // use this middleware for paths under /hello/ 46 | kami.Get("/hello/:name", greet) // add a GET handler with a parameter in the URL 47 | kami.Serve() // gracefully serve with support for einhorn and systemd 48 | } 49 | ``` 50 | 51 | ```go 52 | // Package greeting stores greeting settings in context. 53 | package greeting 54 | 55 | import ( 56 | "net/http" 57 | "context" 58 | 59 | "golang.org/x/text/language" 60 | ) 61 | 62 | // For more information about context and why we're doing this, 63 | // see https://blog.golang.org/context 64 | type ctxkey int 65 | 66 | var key ctxkey = 0 67 | 68 | var greetings = map[language.Tag]string{ 69 | language.AmericanEnglish: "Yo", 70 | language.Japanese: "こんにちは", 71 | } 72 | 73 | // Guess is kami middleware that examines Accept-Language and sets 74 | // the greeting to a better one if possible. 
75 | func Guess(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 76 | if tag, _, err := language.ParseAcceptLanguage(r.Header.Get("Accept-Language")); err == nil { 77 | for _, t := range tag { 78 | if g, ok := greetings[t]; ok { 79 | ctx = WithContext(ctx, g) 80 | return ctx 81 | } 82 | } 83 | } 84 | return ctx 85 | } 86 | 87 | // WithContext returns a new context with the given greeting. 88 | func WithContext(ctx context.Context, greeting string) context.Context { 89 | return context.WithValue(ctx, key, greeting) 90 | } 91 | 92 | // FromContext retrieves the greeting from this context, 93 | // or returns an empty string if missing. 94 | func FromContext(ctx context.Context) string { 95 | hello, _ := ctx.Value(key).(string) 96 | return hello 97 | } 98 | ``` 99 | 100 | ### Usage 101 | 102 | * Set up routes using `kami.Get("/path", handler)`, `kami.Post(...)`, etc. You can use named parameters or wildcards in URLs like `/hello/:name/edit` or `/files/*path`, and access them using the context kami gives you: `kami.Param(ctx, "name")`. See the [routing rules](https://github.com/dimfeld/httptreemux#routing-rules) and [routing priority](https://github.com/dimfeld/httptreemux#routing-priority). The following kinds of handlers are accepted: 103 | * types that implement `kami.ContextHandler` 104 | * `func(context.Context, http.ResponseWriter, *http.Request)` 105 | * types that implement `http.Handler` 106 | * `func(http.ResponseWriter, *http.Request)` 107 | * All contexts that kami uses are descended from `kami.Context`: this is the "god object" and the namesake of this project. By default, this is `context.Background()`, but feel free to replace it with a pre-initialized context suitable for your application. 108 | * Builds targeting Google App Engine will automatically wrap the "god object" Context with App Engine's per-request Context. 109 | * Add middleware with `kami.Use("/path", kami.Middleware)`. 
Middleware runs before requests and can stop them early. More on middleware below. 110 | * Add afterware with `kami.After("/path", kami.Afterware)`. Afterware runs after requests. 111 | * Set `kami.Cancel` to `true` to automatically cancel each request's context after the request is finished. Unlike the standard library, kami does not cancel contexts by default. 112 | * You can provide a panic handler by setting `kami.PanicHandler`. When the panic handler is called, you can access the panic error with `kami.Exception(ctx)`. 113 | * You can also provide a `kami.LogHandler` that will wrap every request. `kami.LogHandler` has a different function signature, taking a WriterProxy that has access to the response status code, etc. 114 | * Use `kami.Serve()` to gracefully serve your application, or mount `kami.Handler()` somewhere convenient. 115 | 116 | ### Middleware 117 | ```go 118 | type Middleware func(context.Context, http.ResponseWriter, *http.Request) context.Context 119 | ``` 120 | Middleware differs from a HandlerType in that it returns a new context. You can take advantage of this to build your context by registering middleware at the appropriate paths. As a special case, you may return **nil** to halt execution of the middleware chain. 121 | 122 | Middleware is hierarchical. For example, a request for `/hello/greg` will run middleware registered under the following paths, in order: 123 | 124 | 1. `/` 125 | 2. `/hello/` 126 | 3. `/hello/greg` 127 | 128 | Within a path, middleware is run in the order of registration.
129 | 130 | ```go 131 | func init() { 132 | kami.Use("/", Login) 133 | kami.Use("/private/", LoginRequired) 134 | } 135 | 136 | // Login returns a new context with the appropriate user object inside 137 | func Login(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 138 | if u, err := user.GetByToken(ctx, r.FormValue("auth_token")); err == nil { 139 | ctx = user.NewContext(ctx, u) 140 | } 141 | return ctx 142 | } 143 | 144 | // LoginRequired stops the request if we don't have a user object 145 | func LoginRequired(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 146 | if _, ok := user.FromContext(ctx); !ok { 147 | w.WriteHeader(http.StatusForbidden) 148 | // ... render 403 Forbidden page 149 | return nil 150 | } 151 | return ctx 152 | } 153 | ``` 154 | 155 | #### Named parameters, wildcards, and middleware 156 | 157 | Named parameters and wildcards in middleware are supported now. Middleware registered under a path with a wildcard will run **after** all hierarchical middleware. 158 | 159 | ```go 160 | kami.Use("/user/:id/edit", CheckAdminPermissions) // Matches only /user/:id/edit 161 | kami.Use("/user/:id/edit/*", CheckAdminPermissions) // Matches all inheriting paths, behaves like non-parameterized paths 162 | ``` 163 | 164 | #### Vanilla net/http middleware 165 | 166 | kami can use vanilla http middleware as well. `kami.Use` accepts functions in the form of `func(next http.Handler) http.Handler`. Be advised that kami will run such middleware in sequence, not in a chain. This means that standard loggers and panic handlers won't work as you expect. You should use `kami.LogHandler` and `kami.PanicHandler` instead. 167 | 168 | The following example uses [goji/httpauth](https://github.com/goji/httpauth) to add HTTP Basic Authentication to paths under `/secret/`.
169 | 170 | ```go 171 | import ( 172 | "github.com/goji/httpauth" 173 | "github.com/guregu/kami" 174 | ) 175 | 176 | func main() { 177 | kami.Use("/secret/", httpauth.SimpleBasicAuth("username", "password")) 178 | kami.Get("/secret/message", secretMessageHandler) 179 | kami.Serve() 180 | } 181 | ``` 182 | 183 | #### Afterware 184 | 185 | ```go 186 | type Afterware func(context.Context, mutil.WriterProxy, *http.Request) context.Context 187 | ``` 188 | 189 | ```go 190 | func init() { 191 | kami.After("/", cleanup) 192 | } 193 | ``` 194 | 195 | Running after the request handler, afterware is useful for cleaning up. Afterware is like a mirror image of middleware. Afterware also runs hierarchically, but in the reverse order of middleware. Wildcards are evaluated **before** hierarchical afterware. 196 | 197 | For example, a request for `/hello/greg` will run afterware registered under the following paths: 198 | 199 | 1. `/hello/greg` 200 | 2. `/hello/` 201 | 3. `/` 202 | 203 | This gives afterware under specific paths the ability to use resources that may be closed by `/`. 204 | 205 | Unlike middleware, afterware returning **nil** will not stop the remaining afterware from being evaluated. 206 | 207 | `kami.After("/path", afterware)` supports many different types of functions; see the docs for `kami.AfterwareType` for more details. 208 | 209 | ### Independent stacks with `*kami.Mux` 210 | 211 | kami was originally designed to be the "glue" between multiple packages in a complex web application. The global functions and `kami.Context` are an easy way for your packages to work together. However, if you would like to use kami as an embedded server within another app, serve two separate kami stacks on different ports, or otherwise would like to have a non-global version of kami, `kami.New()` may come in handy. 212 | 213 | Calling `kami.New()` returns a fresh `*kami.Mux`, a completely independent kami stack.
Changes to `kami.Context`, paths registered with `kami.Get()` et al, and global middleware registered with `kami.Use()` will not affect a `*kami.Mux`. 214 | 215 | Instead, with `mux := kami.New()` you can change `mux.Context`, call `mux.Use()`, `mux.Get()`, `mux.NotFound()`, etc. 216 | 217 | `*kami.Mux` implements `http.Handler`, so you may use it however you'd like! 218 | 219 | ```go 220 | // package admin is an admin panel web server plugin 221 | package admin 222 | 223 | import ( 224 | "net/http" 225 | "github.com/guregu/kami" 226 | ) 227 | 228 | // automatically mount our secret admin stuff 229 | func init() { 230 | mux := kami.New() 231 | mux.Context = adminContext 232 | mux.Use("/", authorize) 233 | mux.Get("/admin/memstats", memoryStats) 234 | mux.Post("/admin/die", shutdown) 235 | // ... 236 | http.Handle("/admin/", mux) 237 | } 238 | ``` 239 | 240 | ### License 241 | 242 | MIT 243 | 244 | ### Acknowledgements 245 | 246 | * [httptreemux](https://github.com/dimfeld/httptreemux): router 247 | * [Goji](https://github.com/zenazn/goji): graceful, WriterProxy 248 | -------------------------------------------------------------------------------- /bench_test.go: -------------------------------------------------------------------------------- 1 | package kami_test 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "testing" 7 | 8 | "golang.org/x/net/context" 9 | 10 | "github.com/guregu/kami" 11 | ) 12 | // routeBench benchmarks serving a GET request for route through the global mux, with noopMW registered as both middleware and afterware under /Z/, a prefix none of this file's benchmarked routes fall under. 13 | func routeBench(b *testing.B, route string) { 14 | kami.Reset() 15 | kami.Use("/Z/", noopMW) 16 | kami.After("/Z/", noopMW) 17 | kami.Get(route, noop) 18 | req, _ := http.NewRequest("GET", route, nil) 19 | b.ResetTimer() 20 | for n := 0; n < b.N; n++ { 21 | resp := httptest.NewRecorder() 22 | kami.Handler().ServeHTTP(resp, req) 23 | } 24 | } 25 | 26 | func BenchmarkShortRoute(b *testing.B) { 27 | routeBench(b, "/hello") 28 | } 29 | 30 | func BenchmarkLongRoute(b *testing.B) { 31 | routeBench(b, "/aaaaaaaaaaaa/") 32 | } 33 | 34 | func
BenchmarkDeepRoute(b *testing.B) { 35 | routeBench(b, "/a/b/c/d/e/f/g") 36 | } 37 | 38 | func BenchmarkDeepRouteUnicode(b *testing.B) { 39 | routeBench(b, "/ä/蜂/海/🐶/神/🍺/🍻") 40 | } 41 | 42 | func BenchmarkSuperDeepRoute(b *testing.B) { 43 | routeBench(b, "/a/b/c/d/e/f/g/h/i/l/k/l/m/n/o/p/q/r/hello world") 44 | } 45 | 46 | // Param benchmarks test accessing URL params. 47 | 48 | func BenchmarkParameter(b *testing.B) { 49 | kami.Reset() 50 | kami.Get("/hello/:name", func(ctx context.Context, w http.ResponseWriter, r *http.Request) { 51 | kami.Param(ctx, "name") 52 | }) 53 | req, _ := http.NewRequest("GET", "/hello/bob", nil) 54 | b.ResetTimer() 55 | for n := 0; n < b.N; n++ { 56 | resp := httptest.NewRecorder() 57 | kami.Handler().ServeHTTP(resp, req) 58 | } 59 | } 60 | 61 | func BenchmarkParameter5(b *testing.B) { 62 | kami.Reset() 63 | kami.Get("/:a/:b/:c/:d/:e", func(ctx context.Context, w http.ResponseWriter, r *http.Request) { 64 | for _, v := range []string{"a", "b", "c", "d", "e"} { 65 | kami.Param(ctx, v) 66 | } 67 | }) 68 | req, _ := http.NewRequest("GET", "/a/b/c/d/e", nil) 69 | b.ResetTimer() 70 | for n := 0; n < b.N; n++ { 71 | resp := httptest.NewRecorder() 72 | kami.Handler().ServeHTTP(resp, req) 73 | } 74 | } 75 | 76 | // Middleware benchmarks test setting and using values with middleware. 77 | // They measure the speed of kami's middleware engine AND of using 78 | // x/net/context to store values, so they should give a somewhat 79 | // realistic idea of what using kami would be like.
80 | 81 | func BenchmarkMiddleware(b *testing.B) { 82 | kami.Reset() 83 | kami.Use("/test", func(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 84 | return context.WithValue(ctx, "test", "ok") 85 | }) 86 | kami.Get("/test", func(ctx context.Context, w http.ResponseWriter, r *http.Request) { 87 | if ctx.Value("test") != "ok" { 88 | w.WriteHeader(http.StatusServiceUnavailable) 89 | } 90 | }) 91 | req, _ := http.NewRequest("GET", "/test", nil) 92 | b.ResetTimer() 93 | for n := 0; n < b.N; n++ { 94 | resp := httptest.NewRecorder() 95 | kami.Handler().ServeHTTP(resp, req) 96 | } 97 | } 98 | 99 | func BenchmarkMiddleware5(b *testing.B) { 100 | kami.Reset() 101 | numbers := []int{1, 2, 3, 4, 5} 102 | for _, n := range numbers { 103 | n := n // give each closure its own copy of the loop variable (needed before Go 1.22) 104 | kami.Use("/", func(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 105 | return context.WithValue(ctx, n, n) 106 | }) 107 | } 108 | kami.Get("/test", func(ctx context.Context, w http.ResponseWriter, r *http.Request) { 109 | for _, n := range numbers { 110 | if ctx.Value(n) != n { 111 | w.WriteHeader(http.StatusServiceUnavailable) 112 | return 113 | } 114 | } 115 | }) 116 | req, _ := http.NewRequest("GET", "/test", nil) 117 | b.ResetTimer() 118 | for n := 0; n < b.N; n++ { 119 | resp := httptest.NewRecorder() 120 | kami.Handler().ServeHTTP(resp, req) 121 | } 122 | } 123 | 124 | func BenchmarkMiddleware1Afterware1(b *testing.B) { 125 | kami.Reset() 126 | numbers := []int{1} 127 | for _, n := range numbers { 128 | n := n // give each closure its own copy of the loop variable (needed before Go 1.22) 129 | kami.Use("/", func(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 130 | return context.WithValue(ctx, n, n) 131 | }) 132 | } 133 | kami.After("/", func(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 134 | for _, n := range numbers { 135 | if ctx.Value(n) != n { 136 | panic(n) 137 | } 138 | } 139 | return ctx 140 | }) 141 | kami.Get("/test", func(ctx context.Context, w
http.ResponseWriter, r *http.Request) { 142 | // intentionally empty: only middleware and afterware are measured 143 | }) 144 | req, _ := http.NewRequest("GET", "/test", nil) 145 | b.ResetTimer() 146 | for n := 0; n < b.N; n++ { 147 | resp := httptest.NewRecorder() 148 | kami.Handler().ServeHTTP(resp, req) 149 | } 150 | } 151 | 152 | func BenchmarkMiddleware5Afterware1(b *testing.B) { 153 | kami.Reset() 154 | numbers := []int{1, 2, 3, 4, 5} 155 | for _, n := range numbers { 156 | n := n // give each closure its own copy of the loop variable (needed before Go 1.22) 157 | kami.Use("/", func(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 158 | return context.WithValue(ctx, n, n) 159 | }) 160 | } 161 | kami.After("/", func(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 162 | for _, n := range numbers { 163 | if ctx.Value(n) != n { 164 | panic(n) 165 | } 166 | } 167 | return ctx 168 | }) 169 | kami.Get("/test", func(ctx context.Context, w http.ResponseWriter, r *http.Request) { 170 | for _, n := range numbers { 171 | if ctx.Value(n) != n { 172 | w.WriteHeader(http.StatusServiceUnavailable) 173 | return 174 | } 175 | } 176 | }) 177 | req, _ := http.NewRequest("GET", "/test", nil) 178 | b.ResetTimer() 179 | for n := 0; n < b.N; n++ { 180 | resp := httptest.NewRecorder() 181 | kami.Handler().ServeHTTP(resp, req) 182 | } 183 | } 184 | 185 | // BenchmarkMiddlewareAfterwareMiss tests just the URL-walking middleware engine: its middleware and afterware are registered under /dog/, which the request path never matches.
186 | func BenchmarkMiddlewareAfterwareMiss(b *testing.B) { 187 | kami.Reset() 188 | kami.Use("/dog/", func(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 189 | return nil 190 | }) 191 | kami.After("/dog/", func(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 192 | return nil 193 | }) 194 | kami.Get("/a/bbb/cc/d/e", func(ctx context.Context, w http.ResponseWriter, r *http.Request) { 195 | w.WriteHeader(http.StatusOK) 196 | }) 197 | req, _ := http.NewRequest("GET", "/a/bbb/cc/d/e", nil) 198 | b.ResetTimer() 199 | for n := 0; n < b.N; n++ { 200 | resp := httptest.NewRecorder() 201 | kami.Handler().ServeHTTP(resp, req) 202 | } 203 | } 204 | -------------------------------------------------------------------------------- /defaults.go: -------------------------------------------------------------------------------- 1 | // +build !appengine,!appenginevm 2 | 3 | package kami 4 | 5 | import ( 6 | "net/http" 7 | 8 | "golang.org/x/net/context" 9 | ) 10 | // defaultContext returns ctx unchanged; r is unused in non-App Engine builds. 11 | func defaultContext(ctx context.Context, r *http.Request) context.Context { 12 | return ctx 13 | } 14 | -------------------------------------------------------------------------------- /defaults_appengine.go: -------------------------------------------------------------------------------- 1 | // +build appengine appenginevm 2 | 3 | package kami 4 | 5 | import ( 6 | "net/http" 7 | 8 | "golang.org/x/net/context" 9 | "google.golang.org/appengine" 10 | ) 11 | 12 | // defaultContext attaches the App Engine context for request r to ctx via appengine.WithContext. 13 | func defaultContext(ctx context.Context, r *http.Request) context.Context { 13 | return appengine.WithContext(ctx, r) 14 | } 15 | -------------------------------------------------------------------------------- /globalmux.go: -------------------------------------------------------------------------------- 1 | package kami 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/dimfeld/httptreemux" 7 | ) 8 | // routes is the router used by the package-level functions; enable405 toggles automatic 405 handling (see EnableMethodNotAllowed). 9 | var ( 10 | routes = newRouter() 11 | enable405 = true 12 | ) 13 | 14 | func init() { 15 | // set
up the default 404/405 handlers 16 | NotFound(nil) 17 | MethodNotAllowed(nil) 18 | } 19 | // newRouter returns a TreeMux that matches on httptreemux.URLPath and whose redirects use 307, except for GET requests, which use 301. 20 | func newRouter() *httptreemux.TreeMux { 21 | r := httptreemux.New() 22 | r.PathSource = httptreemux.URLPath 23 | r.RedirectBehavior = httptreemux.Redirect307 24 | r.RedirectMethodBehavior = map[string]httptreemux.RedirectBehavior{ 25 | "GET": httptreemux.Redirect301, 26 | } 27 | return r 28 | } 29 | 30 | // Handler returns an http.Handler serving registered routes. 31 | func Handler() http.Handler { 32 | return routes 33 | } 34 | 35 | // Handle registers an arbitrary method handler under the given path. 36 | func Handle(method, path string, handler HandlerType) { 37 | routes.Handle(method, path, bless(wrap(handler))) 38 | } 39 | 40 | // Get registers a GET handler under the given path. 41 | func Get(path string, handler HandlerType) { 42 | Handle("GET", path, handler) 43 | } 44 | 45 | // Post registers a POST handler under the given path. 46 | func Post(path string, handler HandlerType) { 47 | Handle("POST", path, handler) 48 | } 49 | 50 | // Put registers a PUT handler under the given path. 51 | func Put(path string, handler HandlerType) { 52 | Handle("PUT", path, handler) 53 | } 54 | 55 | // Patch registers a PATCH handler under the given path. 56 | func Patch(path string, handler HandlerType) { 57 | Handle("PATCH", path, handler) 58 | } 59 | 60 | // Head registers a HEAD handler under the given path. 61 | func Head(path string, handler HandlerType) { 62 | Handle("HEAD", path, handler) 63 | } 64 | 65 | // Options registers an OPTIONS handler under the given path. 66 | func Options(path string, handler HandlerType) { 67 | Handle("OPTIONS", path, handler) 68 | } 69 | 70 | // Delete registers a DELETE handler under the given path. 71 | func Delete(path string, handler HandlerType) { 72 | Handle("DELETE", path, handler) 73 | } 74 | 75 | // EnableMethodNotAllowed enables or disables automatic Method Not Allowed handling. 76 | // Note that this is enabled by default. When disabled, requests whose method is not allowed are served by the NotFound handler instead.
77 | func EnableMethodNotAllowed(enabled bool) { 78 | enable405 = enabled 79 | } 80 | -------------------------------------------------------------------------------- /globalmux_17.go: -------------------------------------------------------------------------------- 1 | // +build go1.7 2 | 3 | package kami 4 | 5 | import ( 6 | "context" 7 | "net/http" 8 | 9 | "github.com/dimfeld/httptreemux" 10 | "github.com/zenazn/goji/web/mutil" 11 | ) 12 | 13 | var ( 14 | // Context is the root "god object" from which every request's context will derive. 15 | Context = context.Background() 16 | // Cancel will, if true, automatically cancel the context of incoming requests after they finish. 17 | Cancel bool 18 | 19 | // PanicHandler will, if set, be called on panics. 20 | // You can use kami.Exception(ctx) within the panic handler to get panic details. 21 | PanicHandler HandlerType 22 | // LogHandler will, if set, wrap every request and be called at the very end. 23 | LogHandler func(context.Context, mutil.WriterProxy, *http.Request) 24 | ) 25 | 26 | // NotFound registers a special handler for unregistered (404) paths. 27 | // If handler is nil, the default http.NotFound behavior is used. 28 | func NotFound(handler HandlerType) { 29 | // set up the default handler if needed 30 | // we need to bless this so middleware will still run for a 404 request 31 | if handler == nil { 32 | handler = HandlerFunc(func(_ context.Context, w http.ResponseWriter, r *http.Request) { 33 | http.NotFound(w, r) 34 | }) 35 | } 36 | 37 | h := bless(wrap(handler)) 38 | routes.NotFoundHandler = func(w http.ResponseWriter, r *http.Request) { 39 | h(w, r, nil) 40 | } 41 | } 42 | 43 | // MethodNotAllowed registers a special handler for automatically responding 44 | // to invalid method requests (405). If handler is nil, a plain 405 error response is sent.
45 | func MethodNotAllowed(handler HandlerType) { 46 | if handler == nil { 47 | handler = HandlerFunc(func(_ context.Context, w http.ResponseWriter, r *http.Request) { 48 | http.Error(w, 49 | http.StatusText(http.StatusMethodNotAllowed), 50 | http.StatusMethodNotAllowed, 51 | ) 52 | }) 53 | } 54 | 55 | h := bless(wrap(handler)) 56 | routes.MethodNotAllowedHandler = func(w http.ResponseWriter, r *http.Request, methods map[string]httptreemux.HandlerFunc) { 57 | if !enable405 { // 405 handling disabled: fall back to the 404 handler 58 | routes.NotFoundHandler(w, r) 59 | return 60 | } 61 | h(w, r, nil) 62 | } 63 | } 64 | 65 | // bless creates a new kamified handler using the global mux and middleware. The global settings are captured by pointer, not by value. 66 | func bless(h ContextHandler) httptreemux.HandlerFunc { 67 | k := kami{ 68 | handler: h, 69 | base: &Context, 70 | autocancel: &Cancel, 71 | middleware: defaultMW, 72 | panicHandler: &PanicHandler, 73 | logHandler: &LogHandler, 74 | } 75 | return k.handle 76 | } 77 | 78 | // Reset changes the root Context to context.Background(). 79 | // It removes every handler and all middleware, and clears Cancel, PanicHandler and LogHandler. NOTE(review): the EnableMethodNotAllowed setting (enable405) is left unchanged. 80 | func Reset() { 81 | Context = context.Background() 82 | Cancel = false 83 | PanicHandler = nil 84 | LogHandler = nil 85 | defaultMW = newWares() 86 | routes = newRouter() 87 | NotFound(nil) 88 | MethodNotAllowed(nil) 89 | } 90 | -------------------------------------------------------------------------------- /globalmux_old.go: -------------------------------------------------------------------------------- 1 | // +build !go1.7 2 | 3 | package kami 4 | 5 | import ( 6 | "net/http" 7 | 8 | "github.com/dimfeld/httptreemux" 9 | "github.com/zenazn/goji/web/mutil" 10 | "golang.org/x/net/context" 11 | ) 12 | 13 | var ( 14 | // Context is the root "god object" from which every request's context will derive. 15 | Context = context.Background() 16 | // Cancel will, if true, automatically cancel the context of incoming requests after they finish. 17 | Cancel bool 18 | 19 | // PanicHandler will, if set, be called on panics.
20 | // You can use kami.Exception(ctx) within the panic handler to get panic details. 21 | PanicHandler HandlerType 22 | // LogHandler will, if set, wrap every request and be called at the very end. 23 | LogHandler func(context.Context, mutil.WriterProxy, *http.Request) 24 | ) 25 | 26 | // NotFound registers a special handler for unregistered (404) paths. 27 | // If handler is nil, the default http.NotFound behavior is used. 28 | func NotFound(handler HandlerType) { 29 | // set up the default handler if needed 30 | // we need to bless this so middleware will still run for a 404 request 31 | if handler == nil { 32 | handler = HandlerFunc(func(_ context.Context, w http.ResponseWriter, r *http.Request) { 33 | http.NotFound(w, r) 34 | }) 35 | } 36 | 37 | h := bless(wrap(handler)) 38 | routes.NotFoundHandler = func(w http.ResponseWriter, r *http.Request) { 39 | h(w, r, nil) 40 | } 41 | } 42 | 43 | // MethodNotAllowed registers a special handler for automatically responding 44 | // to invalid method requests (405). If handler is nil, a plain 405 error response is sent. 45 | func MethodNotAllowed(handler HandlerType) { 46 | if handler == nil { 47 | handler = HandlerFunc(func(_ context.Context, w http.ResponseWriter, r *http.Request) { 48 | http.Error(w, 49 | http.StatusText(http.StatusMethodNotAllowed), 50 | http.StatusMethodNotAllowed, 51 | ) 52 | }) 53 | } 54 | 55 | h := bless(wrap(handler)) 56 | routes.MethodNotAllowedHandler = func(w http.ResponseWriter, r *http.Request, methods map[string]httptreemux.HandlerFunc) { 57 | if !enable405 { // 405 handling disabled: fall back to the 404 handler 58 | routes.NotFoundHandler(w, r) 59 | return 60 | } 61 | h(w, r, nil) 62 | } 63 | } 64 | 65 | // bless creates a new kamified handler using the global mux and middleware. The global settings are captured by pointer, not by value.
66 | func bless(h ContextHandler) httptreemux.HandlerFunc { 67 | k := kami{ 68 | handler: h, 69 | base: &Context, 70 | autocancel: &Cancel, 71 | middleware: defaultMW, 72 | panicHandler: &PanicHandler, 73 | logHandler: &LogHandler, 74 | } 75 | return k.handle 76 | } 77 | 78 | // Reset changes the root Context to context.Background(). 79 | // It removes every handler and all middleware, and clears Cancel, PanicHandler and LogHandler. NOTE(review): the EnableMethodNotAllowed setting (enable405) is left unchanged. 80 | func Reset() { 81 | Context = context.Background() 82 | Cancel = false 83 | PanicHandler = nil 84 | LogHandler = nil 85 | defaultMW = newWares() 86 | routes = newRouter() 87 | NotFound(nil) 88 | MethodNotAllowed(nil) 89 | } 90 | -------------------------------------------------------------------------------- /handler_17.go: -------------------------------------------------------------------------------- 1 | // +build go1.7,!go1.9 2 | 3 | package kami 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | "net/http" 9 | 10 | netcontext "golang.org/x/net/context" 11 | ) 12 | 13 | // OldContextHandler is like ContextHandler but uses the old x/net/context.
14 | type OldContextHandler interface { 15 | ServeHTTPContext(netcontext.Context, http.ResponseWriter, *http.Request) 16 | } 17 | 18 | // wrap converts a HandlerType into a ContextHandler. Handlers written against 19 | // golang.org/x/net/context (OldContextHandler and the matching func type) are 20 | // adapted to the standard context. It panics on any other type. 21 | func wrap(h HandlerType) ContextHandler { 22 | switch x := h.(type) { 23 | case ContextHandler: 24 | return x 25 | case OldContextHandler: 26 | // In go1.7/go1.8 builds netcontext.Context is a distinct type, so adapt explicitly. 27 | return HandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) { 28 | x.ServeHTTPContext(ctx, w, r) 29 | }) 30 | case func(context.Context, http.ResponseWriter, *http.Request): 31 | return HandlerFunc(x) 32 | case func(netcontext.Context, http.ResponseWriter, *http.Request): 33 | return HandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) { 34 | x(ctx, w, r) 35 | }) 36 | case http.Handler: 37 | return HandlerFunc(func(_ context.Context, w http.ResponseWriter, r *http.Request) { 38 | x.ServeHTTP(w, r) 39 | }) 40 | case func(http.ResponseWriter, *http.Request): 41 | return HandlerFunc(func(_ context.Context, w http.ResponseWriter, r *http.Request) { 42 | x(w, r) 43 | }) 44 | } 45 | panic(fmt.Errorf("unsupported HandlerType: %T", h)) 46 | } 47 | -------------------------------------------------------------------------------- /handler_19.go: -------------------------------------------------------------------------------- 1 | // +build go1.9 2 | 3 | package kami 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | "net/http" 9 | ) 10 | 11 | // wrap converts a HandlerType into a ContextHandler, panicking on unsupported types. 12 | func wrap(h HandlerType) ContextHandler { 13 | switch x := h.(type) { 14 | case ContextHandler: 15 | return x 16 | case func(context.Context, http.ResponseWriter, *http.Request): 17 | return HandlerFunc(x) 18 | case http.Handler: 19 | return HandlerFunc(func(_ context.Context, w http.ResponseWriter, r *http.Request) { 20 | x.ServeHTTP(w, r) 21 | }) 22 | case func(http.ResponseWriter, *http.Request): 23 | return HandlerFunc(func(_ context.Context, w http.ResponseWriter, r *http.Request) { 24 | x(w, r) 25 | }) 26 | } 27 | panic(fmt.Errorf("unsupported HandlerType: %T", h)) 28 | } 29 |
-------------------------------------------------------------------------------- /handler_new.go: -------------------------------------------------------------------------------- 1 | // +build go1.7 2 | 3 | package kami 4 | 5 | import ( 6 | "context" 7 | "net/http" 8 | ) 9 | 10 | // HandlerType is the type of Handlers and types that kami internally converts to 11 | // ContextHandler. In order to provide an expressive API, this type is simply 12 | // interface{}, named for the purposes of documentation; however, only the 13 | // following concrete types are accepted: 14 | // - types that implement http.Handler 15 | // - types that implement ContextHandler 16 | // - func(http.ResponseWriter, *http.Request) 17 | // - func(context.Context, http.ResponseWriter, *http.Request) 18 | type HandlerType interface{} 19 | 20 | // ContextHandler is like http.Handler but supports context. 21 | type ContextHandler interface { 22 | ServeHTTPContext(context.Context, http.ResponseWriter, *http.Request) 23 | } 24 | 25 | // HandlerFunc is like http.HandlerFunc with context. 26 | type HandlerFunc func(context.Context, http.ResponseWriter, *http.Request) 27 | // ServeHTTPContext calls h(ctx, w, r), making HandlerFunc satisfy ContextHandler. 28 | func (h HandlerFunc) ServeHTTPContext(ctx context.Context, w http.ResponseWriter, r *http.Request) { 29 | h(ctx, w, r) 30 | } 31 | -------------------------------------------------------------------------------- /handler_old.go: -------------------------------------------------------------------------------- 1 | // +build !go1.7 2 | 3 | package kami 4 | 5 | import ( 6 | "fmt" 7 | "net/http" 8 | 9 | "golang.org/x/net/context" 10 | ) 11 | 12 | // HandlerType is the type of Handlers and types that kami internally converts to 13 | // ContextHandler.
In order to provide an expressive API, this type is an alias for 14 | // interface{} that is named for the purposes of documentation, however only the 15 | // following concrete types are accepted: 16 | // - types that implement http.Handler 17 | // - types that implement ContextHandler 18 | // - func(http.ResponseWriter, *http.Request) 19 | // - func(context.Context, http.ResponseWriter, *http.Request) 20 | type HandlerType interface{} 21 | 22 | // ContextHandler is like http.Handler but supports context. 23 | type ContextHandler interface { 24 | ServeHTTPContext(context.Context, http.ResponseWriter, *http.Request) 25 | } 26 | 27 | // HandlerFunc is like http.HandlerFunc with context. 28 | type HandlerFunc func(context.Context, http.ResponseWriter, *http.Request) 29 | 30 | func (h HandlerFunc) ServeHTTPContext(ctx context.Context, w http.ResponseWriter, r *http.Request) { 31 | h(ctx, w, r) 32 | } 33 | 34 | // wrap tries to turn a HandlerType into a ContextHandler 35 | func wrap(h HandlerType) ContextHandler { 36 | switch x := h.(type) { 37 | case ContextHandler: 38 | return x 39 | case func(context.Context, http.ResponseWriter, *http.Request): 40 | return HandlerFunc(x) 41 | case http.Handler: 42 | return HandlerFunc(func(_ context.Context, w http.ResponseWriter, r *http.Request) { 43 | x.ServeHTTP(w, r) 44 | }) 45 | case func(http.ResponseWriter, *http.Request): 46 | return HandlerFunc(func(_ context.Context, w http.ResponseWriter, r *http.Request) { 47 | x(w, r) 48 | }) 49 | } 50 | panic(fmt.Errorf("unsupported HandlerType: %T", h)) 51 | } 52 | -------------------------------------------------------------------------------- /kami.go: -------------------------------------------------------------------------------- 1 | // +build !go1.7 2 | 3 | package kami 4 | 5 | import ( 6 | "net/http" 7 | 8 | "github.com/zenazn/goji/web/mutil" 9 | "golang.org/x/net/context" 10 | ) 11 | 12 | // kami is the heart of the package. 
13 | // It wraps a ContextHandler into an httprouter compatible request, 14 | // in order to run all the middleware and other special handlers. 15 | type kami struct { 16 | handler ContextHandler 17 | autocancel *bool 18 | base *context.Context 19 | middleware *wares 20 | panicHandler *HandlerType 21 | logHandler *func(context.Context, mutil.WriterProxy, *http.Request) 22 | } 23 | 24 | func (k kami) handle(w http.ResponseWriter, r *http.Request, params map[string]string) { 25 | var ( 26 | ctx = defaultContext(*k.base, r) 27 | autocancel = *k.autocancel 28 | handler = k.handler 29 | mw = *k.middleware 30 | panicHandler = *k.panicHandler 31 | logHandler = *k.logHandler 32 | ranLogHandler = false // track this in case the log handler blows up 33 | ) 34 | if len(params) > 0 { 35 | ctx = newContextWithParams(ctx, params) 36 | } 37 | 38 | if autocancel { 39 | var cancel context.CancelFunc 40 | ctx, cancel = context.WithCancel(ctx) 41 | defer cancel() 42 | } 43 | 44 | var proxy mutil.WriterProxy 45 | if logHandler != nil || mw.needsWrapper() { 46 | proxy = mutil.WrapWriter(w) 47 | w = proxy 48 | } 49 | 50 | if panicHandler != nil { 51 | defer func() { 52 | if err := recover(); err != nil { 53 | ctx = newContextWithException(ctx, err) 54 | wrap(panicHandler).ServeHTTPContext(ctx, w, r) 55 | 56 | if logHandler != nil && !ranLogHandler { 57 | logHandler(ctx, proxy, r) 58 | // should only happen if header hasn't been written 59 | proxy.WriteHeader(http.StatusInternalServerError) 60 | } 61 | } 62 | }() 63 | } 64 | 65 | ctx, ok := mw.run(ctx, w, r) 66 | if ok { 67 | handler.ServeHTTPContext(ctx, w, r) 68 | } 69 | if proxy != nil { 70 | ctx = mw.after(ctx, proxy, r) 71 | } 72 | 73 | if logHandler != nil { 74 | ranLogHandler = true 75 | logHandler(ctx, proxy, r) 76 | // should only happen if header hasn't been written 77 | proxy.WriteHeader(http.StatusInternalServerError) 78 | } 79 | } 80 | -------------------------------------------------------------------------------- 
/kami_17.go: -------------------------------------------------------------------------------- 1 | // +build go1.7 2 | 3 | package kami 4 | 5 | import ( 6 | "context" 7 | "net/http" 8 | 9 | "github.com/zenazn/goji/web/mutil" 10 | ) 11 | 12 | // kami is the heart of the package. 13 | // It wraps a ContextHandler into an httprouter compatible request, 14 | // in order to run all the middleware and other special handlers. 15 | type kami struct { 16 | handler ContextHandler 17 | autocancel *bool 18 | base *context.Context 19 | middleware *wares 20 | panicHandler *HandlerType 21 | logHandler *func(context.Context, mutil.WriterProxy, *http.Request) 22 | } 23 | 24 | func (k kami) handle(w http.ResponseWriter, r *http.Request, params map[string]string) { 25 | var ( 26 | ctx = defaultContext(*k.base, r) 27 | autocancel = *k.autocancel 28 | handler = k.handler 29 | mw = *k.middleware 30 | panicHandler = *k.panicHandler 31 | logHandler = *k.logHandler 32 | ranLogHandler = false // track this in case the log handler blows up 33 | ) 34 | if len(params) > 0 { 35 | ctx = newContextWithParams(ctx, params) 36 | } 37 | 38 | if autocancel { 39 | var cancel context.CancelFunc 40 | ctx, cancel = context.WithCancel(ctx) 41 | defer cancel() 42 | } 43 | 44 | if ctx != context.Background() { 45 | r = r.WithContext(ctx) 46 | } 47 | 48 | var proxy mutil.WriterProxy 49 | if logHandler != nil || mw.needsWrapper() { 50 | proxy = mutil.WrapWriter(w) 51 | w = proxy 52 | } 53 | 54 | if panicHandler != nil { 55 | defer func() { 56 | if err := recover(); err != nil { 57 | ctx = newContextWithException(ctx, err) 58 | r = r.WithContext(ctx) 59 | wrap(panicHandler).ServeHTTPContext(ctx, w, r) 60 | 61 | if logHandler != nil && !ranLogHandler { 62 | logHandler(ctx, proxy, r) 63 | // should only happen if header hasn't been written 64 | proxy.WriteHeader(http.StatusInternalServerError) 65 | } 66 | } 67 | }() 68 | } 69 | 70 | r, ctx, ok := mw.run(ctx, w, r) 71 | if ok { 72 | handler.ServeHTTPContext(ctx, w, 
r) 73 | } 74 | if proxy != nil { 75 | r, ctx = mw.after(ctx, proxy, r) 76 | } 77 | 78 | if logHandler != nil { 79 | ranLogHandler = true 80 | logHandler(ctx, proxy, r) 81 | // should only happen if header hasn't been written 82 | proxy.WriteHeader(http.StatusInternalServerError) 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /kami_17_test.go: -------------------------------------------------------------------------------- 1 | // +build go1.7 2 | 3 | package kami_test 4 | 5 | import ( 6 | "context" 7 | "net/http" 8 | "net/http/httptest" 9 | "testing" 10 | "time" 11 | 12 | "github.com/zenazn/goji/web/mutil" 13 | 14 | "github.com/guregu/kami" 15 | ) 16 | 17 | func TestKami(t *testing.T) { 18 | kami.Reset() 19 | kami.Cancel = true 20 | 21 | done := make(chan struct{}) 22 | 23 | expect := func(ctx context.Context, i int) context.Context { 24 | if prev := ctx.Value(i - 1).(int); prev != i-1 { 25 | t.Error("missing", i) 26 | } 27 | if curr := ctx.Value(i); curr != nil { 28 | t.Error("pre-existing", i) 29 | } 30 | return context.WithValue(ctx, i, i) 31 | } 32 | expectEqual := func(one, two context.Context, i int) { 33 | if one != two { 34 | t.Error(i, "mismatched contexes", one, "\n≠\n", two) 35 | } 36 | } 37 | 38 | kami.Use("/", func(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 39 | expectEqual(ctx, r.Context(), 1) 40 | ctx = context.WithValue(ctx, 1, 1) 41 | ctx = context.WithValue(ctx, "handler", new(bool)) 42 | ctx = context.WithValue(ctx, "done", new(bool)) 43 | ctx = context.WithValue(ctx, "recovered", new(bool)) 44 | go func() { 45 | <-ctx.Done() 46 | close(done) 47 | }() 48 | return ctx 49 | }) 50 | kami.Use("/a/", func(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 51 | expectEqual(ctx, r.Context(), 2) 52 | ctx = expect(ctx, 2) 53 | return ctx 54 | }) 55 | kami.Use("/a/", func(ctx context.Context, w http.ResponseWriter, r *http.Request) 
context.Context { 56 | expectEqual(ctx, r.Context(), 3) 57 | ctx = expect(ctx, 3) 58 | return ctx 59 | }) 60 | kami.Use("/a/b", func(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 61 | expectEqual(ctx, r.Context(), 4) 62 | ctx = expect(ctx, 4) 63 | return ctx 64 | }) 65 | kami.Use("/a/*files", func(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 66 | expectEqual(ctx, r.Context(), 5) 67 | ctx = expect(ctx, 5) 68 | return ctx 69 | }) 70 | kami.Use("/a/*files", func(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 71 | expectEqual(ctx, r.Context(), 6) 72 | ctx = expect(ctx, 6) 73 | return ctx 74 | }) 75 | kami.Get("/a/b", func(ctx context.Context, w http.ResponseWriter, r *http.Request) { 76 | expectEqual(ctx, r.Context(), 6) 77 | if prev := ctx.Value(6).(int); prev != 6 { 78 | t.Error("handler: missing", 6) 79 | } 80 | *(ctx.Value("handler").(*bool)) = true 81 | 82 | w.WriteHeader(http.StatusTeapot) 83 | }) 84 | kami.After("/a/*files", func(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 85 | expectEqual(ctx, r.Context(), 8) 86 | ctx = expect(ctx, 8) 87 | if !*(ctx.Value("handler").(*bool)) { 88 | t.Error("ran before handler") 89 | } 90 | return ctx 91 | }) 92 | kami.After("/a/*files", func(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 93 | expectEqual(ctx, r.Context(), 7) 94 | ctx = expect(ctx, 7) 95 | if !*(ctx.Value("handler").(*bool)) { 96 | t.Error("ran before handler") 97 | } 98 | return ctx 99 | }) 100 | kami.After("/a/b", kami.Afterware(func(ctx context.Context, w mutil.WriterProxy, r *http.Request) context.Context { 101 | expectEqual(ctx, r.Context(), 9) 102 | ctx = expect(ctx, 9) 103 | return ctx 104 | })) 105 | kami.After("/a/", func(ctx context.Context, w mutil.WriterProxy, r *http.Request) context.Context { 106 | expectEqual(ctx, r.Context(), 11) 107 | ctx = expect(ctx, 11) 108 | return ctx 109 | }) 
110 | kami.After("/a/", func(ctx context.Context, w mutil.WriterProxy, r *http.Request) context.Context { 111 | expectEqual(ctx, r.Context(), 10) 112 | ctx = expect(ctx, 10) 113 | return ctx 114 | }) 115 | kami.After("/", func(ctx context.Context, w mutil.WriterProxy, r *http.Request) context.Context { 116 | expectEqual(ctx, r.Context(), 12) 117 | if status := w.Status(); status != http.StatusTeapot { 118 | t.Error("wrong status", status) 119 | } 120 | 121 | ctx = expect(ctx, 12) 122 | *(ctx.Value("done").(*bool)) = true 123 | panic("🍣") 124 | return nil 125 | }) 126 | kami.PanicHandler = func(ctx context.Context, w http.ResponseWriter, r *http.Request) { 127 | expectEqual(ctx, r.Context(), 13) 128 | if got := kami.Exception(ctx); got.(string) != "🍣" { 129 | t.Error("panic handler: expected sushi, got", got) 130 | } 131 | if !*(ctx.Value("done").(*bool)) { 132 | t.Error("didn't finish") 133 | } 134 | *(ctx.Value("recovered").(*bool)) = true 135 | } 136 | kami.LogHandler = func(ctx context.Context, w mutil.WriterProxy, r *http.Request) { 137 | expectEqual(ctx, r.Context(), 14) 138 | if !*(ctx.Value("recovered").(*bool)) { 139 | t.Error("didn't recover") 140 | } 141 | } 142 | 143 | expectResponseCode(t, "GET", "/a/b", http.StatusTeapot) 144 | select { 145 | case <-done: 146 | // ok 147 | case <-time.After(10 * time.Second): 148 | panic("didn't cancel") 149 | } 150 | } 151 | 152 | func TestLoggerAndPanic(t *testing.T) { 153 | kami.Reset() 154 | // test logger with panic 155 | status := 0 156 | kami.LogHandler = func(ctx context.Context, w mutil.WriterProxy, r *http.Request) { 157 | status = w.Status() 158 | } 159 | kami.PanicHandler = kami.HandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) { 160 | err := kami.Exception(ctx) 161 | if err != "test panic" { 162 | t.Error("unexpected exception:", err) 163 | } 164 | w.WriteHeader(http.StatusServiceUnavailable) 165 | }) 166 | kami.Post("/test", func(ctx context.Context, w http.ResponseWriter, r 
*http.Request) { 167 | panic("test panic") 168 | }) 169 | kami.Put("/ok", func(ctx context.Context, w http.ResponseWriter, r *http.Request) { 170 | w.WriteHeader(http.StatusOK) 171 | }) 172 | 173 | expectResponseCode(t, "POST", "/test", http.StatusServiceUnavailable) 174 | if status != http.StatusServiceUnavailable { 175 | t.Error("log handler received wrong status code", status, "≠", http.StatusServiceUnavailable) 176 | } 177 | 178 | // test loggers without panics 179 | expectResponseCode(t, "PUT", "/ok", http.StatusOK) 180 | if status != http.StatusOK { 181 | t.Error("log handler received wrong status code", status, "≠", http.StatusOK) 182 | } 183 | } 184 | 185 | func TestPanickingLogger(t *testing.T) { 186 | kami.Reset() 187 | kami.LogHandler = func(ctx context.Context, w mutil.WriterProxy, r *http.Request) { 188 | t.Log("log handler") 189 | panic("test panic") 190 | } 191 | kami.PanicHandler = kami.HandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) { 192 | t.Log("panic handler") 193 | err := kami.Exception(ctx) 194 | if err != "test panic" { 195 | t.Error("unexpected exception:", err) 196 | } 197 | w.WriteHeader(http.StatusServiceUnavailable) 198 | }) 199 | kami.Options("/test", noop) 200 | 201 | expectResponseCode(t, "OPTIONS", "/test", http.StatusServiceUnavailable) 202 | } 203 | 204 | func TestNotFound(t *testing.T) { 205 | kami.Reset() 206 | kami.Use("/missing/", func(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 207 | return context.WithValue(ctx, "ok", true) 208 | }) 209 | kami.NotFound(func(ctx context.Context, w http.ResponseWriter, r *http.Request) { 210 | ok, _ := ctx.Value("ok").(bool) 211 | if !ok { 212 | w.WriteHeader(http.StatusInternalServerError) 213 | return 214 | } 215 | w.WriteHeader(http.StatusTeapot) 216 | }) 217 | 218 | expectResponseCode(t, "GET", "/missing/hello", http.StatusTeapot) 219 | } 220 | 221 | func TestNotFoundDefault(t *testing.T) { 222 | kami.Reset() 223 | 224 | 
expectResponseCode(t, "GET", "/missing/hello", http.StatusNotFound) 225 | } 226 | 227 | func TestMethodNotAllowed(t *testing.T) { 228 | kami.Reset() 229 | kami.Use("/test", func(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 230 | return context.WithValue(ctx, "ok", true) 231 | }) 232 | kami.Post("/test", func(ctx context.Context, w http.ResponseWriter, r *http.Request) { 233 | panic("test panic") 234 | }) 235 | 236 | kami.MethodNotAllowed(func(ctx context.Context, w http.ResponseWriter, r *http.Request) { 237 | ok, _ := ctx.Value("ok").(bool) 238 | if !ok { 239 | w.WriteHeader(http.StatusInternalServerError) 240 | return 241 | } 242 | w.WriteHeader(http.StatusTeapot) 243 | }) 244 | 245 | expectResponseCode(t, "GET", "/test", http.StatusTeapot) 246 | } 247 | 248 | func TestEnableMethodNotAllowed(t *testing.T) { 249 | kami.Reset() 250 | kami.Post("/test", func(ctx context.Context, w http.ResponseWriter, r *http.Request) { 251 | panic("test panic") 252 | }) 253 | 254 | // Handling enabled by default 255 | expectResponseCode(t, "GET", "/test", http.StatusMethodNotAllowed) 256 | 257 | // Not found deals with it when handling disabled 258 | kami.EnableMethodNotAllowed(false) 259 | expectResponseCode(t, "GET", "/test", http.StatusNotFound) 260 | 261 | // And MethodNotAllowed status when handling enabled 262 | kami.EnableMethodNotAllowed(true) 263 | expectResponseCode(t, "GET", "/test", http.StatusMethodNotAllowed) 264 | } 265 | 266 | func TestMethodNotAllowedDefault(t *testing.T) { 267 | kami.Reset() 268 | kami.Post("/test", func(ctx context.Context, w http.ResponseWriter, r *http.Request) { 269 | panic("test panic") 270 | }) 271 | 272 | expectResponseCode(t, "GET", "/test", http.StatusMethodNotAllowed) 273 | } 274 | 275 | func noop(ctx context.Context, w http.ResponseWriter, r *http.Request) {} 276 | 277 | func noopMW(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 278 | return ctx 279 | } 280 | 281 | func 
expectResponseCode(t *testing.T, method, path string, expected int) { 282 | resp := httptest.NewRecorder() 283 | req, err := http.NewRequest(method, path, nil) 284 | if err != nil { 285 | t.Fatal(err) 286 | } 287 | 288 | kami.Handler().ServeHTTP(resp, req) 289 | 290 | if resp.Code != expected { 291 | t.Error("should return HTTP", http.StatusText(expected)+":", resp.Code, "≠", expected) 292 | } 293 | } 294 | -------------------------------------------------------------------------------- /kami_old_test.go: -------------------------------------------------------------------------------- 1 | // +build !go1.7 2 | 3 | package kami_test 4 | 5 | import ( 6 | "net/http" 7 | "net/http/httptest" 8 | "testing" 9 | "time" 10 | 11 | "github.com/zenazn/goji/web/mutil" 12 | "golang.org/x/net/context" 13 | 14 | "github.com/guregu/kami" 15 | ) 16 | 17 | func TestKami(t *testing.T) { 18 | kami.Reset() 19 | kami.Cancel = true 20 | 21 | done := make(chan struct{}) 22 | 23 | expect := func(ctx context.Context, i int) context.Context { 24 | if prev := ctx.Value(i - 1).(int); prev != i-1 { 25 | t.Error("missing", i) 26 | } 27 | if curr := ctx.Value(i); curr != nil { 28 | t.Error("pre-existing", i) 29 | } 30 | return context.WithValue(ctx, i, i) 31 | } 32 | 33 | kami.Use("/", func(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 34 | ctx = context.WithValue(ctx, 1, 1) 35 | ctx = context.WithValue(ctx, "handler", new(bool)) 36 | ctx = context.WithValue(ctx, "done", new(bool)) 37 | ctx = context.WithValue(ctx, "recovered", new(bool)) 38 | go func() { 39 | <-ctx.Done() 40 | close(done) 41 | }() 42 | return ctx 43 | }) 44 | kami.Use("/a/", func(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 45 | ctx = expect(ctx, 2) 46 | return ctx 47 | }) 48 | kami.Use("/a/", func(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 49 | ctx = expect(ctx, 3) 50 | return ctx 51 | }) 52 | kami.Use("/a/b", func(ctx 
context.Context, w http.ResponseWriter, r *http.Request) context.Context { 53 | ctx = expect(ctx, 4) 54 | return ctx 55 | }) 56 | kami.Use("/a/*files", func(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 57 | ctx = expect(ctx, 5) 58 | return ctx 59 | }) 60 | kami.Use("/a/*files", func(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 61 | ctx = expect(ctx, 6) 62 | return ctx 63 | }) 64 | kami.Get("/a/b", func(ctx context.Context, w http.ResponseWriter, r *http.Request) { 65 | if prev := ctx.Value(6).(int); prev != 6 { 66 | t.Error("handler: missing", 6) 67 | } 68 | *(ctx.Value("handler").(*bool)) = true 69 | 70 | w.WriteHeader(http.StatusTeapot) 71 | }) 72 | kami.After("/a/*files", func(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 73 | ctx = expect(ctx, 8) 74 | if !*(ctx.Value("handler").(*bool)) { 75 | t.Error("ran before handler") 76 | } 77 | return ctx 78 | }) 79 | kami.After("/a/*files", func(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 80 | ctx = expect(ctx, 7) 81 | if !*(ctx.Value("handler").(*bool)) { 82 | t.Error("ran before handler") 83 | } 84 | return ctx 85 | }) 86 | kami.After("/a/b", kami.Afterware(func(ctx context.Context, w mutil.WriterProxy, r *http.Request) context.Context { 87 | ctx = expect(ctx, 9) 88 | return ctx 89 | })) 90 | kami.After("/a/", func(ctx context.Context, w mutil.WriterProxy, r *http.Request) context.Context { 91 | ctx = expect(ctx, 11) 92 | return ctx 93 | }) 94 | kami.After("/a/", func(ctx context.Context, w mutil.WriterProxy, r *http.Request) context.Context { 95 | ctx = expect(ctx, 10) 96 | return ctx 97 | }) 98 | kami.After("/", func(ctx context.Context, w mutil.WriterProxy, r *http.Request) context.Context { 99 | if status := w.Status(); status != http.StatusTeapot { 100 | t.Error("wrong status", status) 101 | } 102 | 103 | ctx = expect(ctx, 12) 104 | *(ctx.Value("done").(*bool)) = true 105 | 
panic("🍣") 106 | return nil 107 | }) 108 | kami.PanicHandler = func(ctx context.Context, w http.ResponseWriter, r *http.Request) { 109 | if got := kami.Exception(ctx); got.(string) != "🍣" { 110 | t.Error("panic handler: expected sushi, got", got) 111 | } 112 | if !*(ctx.Value("done").(*bool)) { 113 | t.Error("didn't finish") 114 | } 115 | *(ctx.Value("recovered").(*bool)) = true 116 | } 117 | kami.LogHandler = func(ctx context.Context, w mutil.WriterProxy, r *http.Request) { 118 | if !*(ctx.Value("recovered").(*bool)) { 119 | t.Error("didn't recover") 120 | } 121 | } 122 | 123 | expectResponseCode(t, "GET", "/a/b", http.StatusTeapot) 124 | select { 125 | case <-done: 126 | // ok 127 | case <-time.After(10 * time.Second): 128 | panic("didn't cancel") 129 | } 130 | } 131 | 132 | func TestLoggerAndPanic(t *testing.T) { 133 | kami.Reset() 134 | // test logger with panic 135 | status := 0 136 | kami.LogHandler = func(ctx context.Context, w mutil.WriterProxy, r *http.Request) { 137 | status = w.Status() 138 | } 139 | kami.PanicHandler = kami.HandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) { 140 | err := kami.Exception(ctx) 141 | if err != "test panic" { 142 | t.Error("unexpected exception:", err) 143 | } 144 | w.WriteHeader(http.StatusServiceUnavailable) 145 | }) 146 | kami.Post("/test", func(ctx context.Context, w http.ResponseWriter, r *http.Request) { 147 | panic("test panic") 148 | }) 149 | kami.Put("/ok", func(ctx context.Context, w http.ResponseWriter, r *http.Request) { 150 | w.WriteHeader(http.StatusOK) 151 | }) 152 | 153 | expectResponseCode(t, "POST", "/test", http.StatusServiceUnavailable) 154 | if status != http.StatusServiceUnavailable { 155 | t.Error("log handler received wrong status code", status, "≠", http.StatusServiceUnavailable) 156 | } 157 | 158 | // test loggers without panics 159 | expectResponseCode(t, "PUT", "/ok", http.StatusOK) 160 | if status != http.StatusOK { 161 | t.Error("log handler received wrong status 
code", status, "≠", http.StatusOK) 162 | } 163 | } 164 | 165 | func TestPanickingLogger(t *testing.T) { 166 | kami.Reset() 167 | kami.LogHandler = func(ctx context.Context, w mutil.WriterProxy, r *http.Request) { 168 | t.Log("log handler") 169 | panic("test panic") 170 | } 171 | kami.PanicHandler = kami.HandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) { 172 | t.Log("panic handler") 173 | err := kami.Exception(ctx) 174 | if err != "test panic" { 175 | t.Error("unexpected exception:", err) 176 | } 177 | w.WriteHeader(http.StatusServiceUnavailable) 178 | }) 179 | kami.Options("/test", noop) 180 | 181 | expectResponseCode(t, "OPTIONS", "/test", http.StatusServiceUnavailable) 182 | } 183 | 184 | func TestNotFound(t *testing.T) { 185 | kami.Reset() 186 | kami.Use("/missing/", func(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 187 | return context.WithValue(ctx, "ok", true) 188 | }) 189 | kami.NotFound(func(ctx context.Context, w http.ResponseWriter, r *http.Request) { 190 | ok, _ := ctx.Value("ok").(bool) 191 | if !ok { 192 | w.WriteHeader(http.StatusInternalServerError) 193 | return 194 | } 195 | w.WriteHeader(http.StatusTeapot) 196 | }) 197 | 198 | expectResponseCode(t, "GET", "/missing/hello", http.StatusTeapot) 199 | } 200 | 201 | func TestNotFoundDefault(t *testing.T) { 202 | kami.Reset() 203 | 204 | expectResponseCode(t, "GET", "/missing/hello", http.StatusNotFound) 205 | } 206 | 207 | func TestMethodNotAllowed(t *testing.T) { 208 | kami.Reset() 209 | kami.Use("/test", func(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 210 | return context.WithValue(ctx, "ok", true) 211 | }) 212 | kami.Post("/test", func(ctx context.Context, w http.ResponseWriter, r *http.Request) { 213 | panic("test panic") 214 | }) 215 | 216 | kami.MethodNotAllowed(func(ctx context.Context, w http.ResponseWriter, r *http.Request) { 217 | ok, _ := ctx.Value("ok").(bool) 218 | if !ok { 219 | 
w.WriteHeader(http.StatusInternalServerError) 220 | return 221 | } 222 | w.WriteHeader(http.StatusTeapot) 223 | }) 224 | 225 | expectResponseCode(t, "GET", "/test", http.StatusTeapot) 226 | } 227 | 228 | func TestEnableMethodNotAllowed(t *testing.T) { 229 | kami.Reset() 230 | kami.Post("/test", func(ctx context.Context, w http.ResponseWriter, r *http.Request) { 231 | panic("test panic") 232 | }) 233 | 234 | // Handling enabled by default 235 | expectResponseCode(t, "GET", "/test", http.StatusMethodNotAllowed) 236 | 237 | // Not found deals with it when handling disabled 238 | kami.EnableMethodNotAllowed(false) 239 | expectResponseCode(t, "GET", "/test", http.StatusNotFound) 240 | 241 | // And MethodNotAllowed status when handling enabled 242 | kami.EnableMethodNotAllowed(true) 243 | expectResponseCode(t, "GET", "/test", http.StatusMethodNotAllowed) 244 | } 245 | 246 | func TestMethodNotAllowedDefault(t *testing.T) { 247 | kami.Reset() 248 | kami.Post("/test", func(ctx context.Context, w http.ResponseWriter, r *http.Request) { 249 | panic("test panic") 250 | }) 251 | 252 | expectResponseCode(t, "GET", "/test", http.StatusMethodNotAllowed) 253 | } 254 | 255 | func noop(ctx context.Context, w http.ResponseWriter, r *http.Request) {} 256 | 257 | func noopMW(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 258 | return ctx 259 | } 260 | 261 | func expectResponseCode(t *testing.T, method, path string, expected int) { 262 | resp := httptest.NewRecorder() 263 | req, err := http.NewRequest(method, path, nil) 264 | if err != nil { 265 | t.Fatal(err) 266 | } 267 | 268 | kami.Handler().ServeHTTP(resp, req) 269 | 270 | if resp.Code != expected { 271 | t.Error("should return HTTP", http.StatusText(expected)+":", resp.Code, "≠", expected) 272 | } 273 | } 274 | -------------------------------------------------------------------------------- /middleware.go: -------------------------------------------------------------------------------- 1 | package 
kami 2 | 3 | import ( 4 | "strings" 5 | 6 | "github.com/guregu/kami/treemux" 7 | ) 8 | 9 | type wares struct { 10 | middleware map[string][]Middleware 11 | afterware map[string][]Afterware 12 | wildcards *treemux.TreeMux 13 | afterWildcards *treemux.TreeMux 14 | } 15 | 16 | func newWares() *wares { 17 | return new(wares) 18 | } 19 | 20 | // Use registers middleware to run for the given path. 21 | // See the global Use function's documents for information on how middleware works. 22 | func (m *wares) Use(path string, mw MiddlewareType) { 23 | if containsWildcard(path) { 24 | if m.wildcards == nil { 25 | m.wildcards = treemux.New() 26 | } 27 | mw := convert(mw) 28 | iface, _ := m.wildcards.Get(path) 29 | if chain, ok := iface.(*[]Middleware); ok { 30 | *chain = append(*chain, mw) 31 | } else { 32 | chain := []Middleware{mw} 33 | m.wildcards.Set(path, &chain) 34 | } 35 | } else { 36 | if m.middleware == nil { 37 | m.middleware = make(map[string][]Middleware) 38 | } 39 | fn := convert(mw) 40 | chain := m.middleware[path] 41 | chain = append(chain, fn) 42 | m.middleware[path] = chain 43 | } 44 | } 45 | 46 | // After registers middleware to run for the given path after normal middleware added with Use has run. 47 | // See the global After function's documents for information on how middleware works. 48 | func (m *wares) After(path string, afterware AfterwareType) { 49 | aw := convertAW(afterware) 50 | if containsWildcard(path) { 51 | if m.afterWildcards == nil { 52 | m.afterWildcards = treemux.New() 53 | } 54 | iface, _ := m.afterWildcards.Get(path) 55 | if chain, ok := iface.(*[]Afterware); ok { 56 | *chain = append([]Afterware{aw}, *chain...) 57 | } else { 58 | chain := []Afterware{aw} 59 | m.afterWildcards.Set(path, &chain) 60 | } 61 | } else { 62 | if m.afterware == nil { 63 | m.afterware = make(map[string][]Afterware) 64 | } 65 | m.afterware[path] = append([]Afterware{aw}, m.afterware[path]...) 
66 | } 67 | } 68 | 69 | var defaultMW = newWares() // for the global router 70 | 71 | // Use registers middleware to run for the given path. 72 | // Middleware will be executed hierarchically, starting with the least specific path. 73 | // Middleware under the same path will be executed in order of registration. 74 | // You may use wildcards in the path. Wildcard middleware will be run last, 75 | // after all hierarchical middleware has run. 76 | // 77 | // Adding middleware is not threadsafe. 78 | // 79 | // WARNING: kami middleware is run in sequence, but standard middleware is chained; 80 | // middleware that expects its code to run after the next handler, such as 81 | // standard loggers and panic handlers, will not work as expected. 82 | // Use kami.LogHandler and kami.PanicHandler instead. 83 | // Standard middleware that does not call the next handler to stop the request is supported. 84 | func Use(path string, mw MiddlewareType) { 85 | defaultMW.Use(path, mw) 86 | } 87 | 88 | // After registers afterware to run after middleware and the request handler has run. 89 | // Afterware is like middleware, but everything is in reverse. 90 | // Afterware will be executed hierarchically, starting with wildcards and then 91 | // the most specific path, ending with /. 92 | // Afterware under the same path will be executed in the opposite order of registration. 93 | func After(path string, aw AfterwareType) { 94 | defaultMW.After(path, aw) 95 | } 96 | 97 | // Middleware run functions are in versioned files. 
98 | 99 | func (m *wares) needsWrapper() bool { 100 | return m.afterware != nil || m.afterWildcards != nil // true if any hierarchical or wildcard afterware is registered 101 | } 102 | 103 | func containsWildcard(path string) bool { 104 | return strings.Contains(path, "/:") || strings.Contains(path, "/*") // true if path has a :param or *wildcard segment 105 | } 106 | -------------------------------------------------------------------------------- /middleware_17.go: -------------------------------------------------------------------------------- 1 | // +build go1.7,!go1.9 2 | 3 | package kami 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | "net/http" 9 | 10 | "github.com/zenazn/goji/web/mutil" 11 | netcontext "golang.org/x/net/context" 12 | ) 13 | 14 | // convert turns standard http middleware into kami Middleware if needed. It panics on unsupported types. 15 | func convert(mw MiddlewareType) Middleware { 16 | switch x := mw.(type) { 17 | case Middleware: 18 | return x 19 | case func(context.Context, http.ResponseWriter, *http.Request) context.Context: 20 | return Middleware(x) 21 | case func(netcontext.Context, http.ResponseWriter, *http.Request) netcontext.Context: 22 | return Middleware(func(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 23 | return x(ctx, w, r) 24 | }) 25 | case func(ContextHandler) ContextHandler: 26 | return func(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 27 | var dh dummyHandler 28 | x(&dh).ServeHTTPContext(ctx, w, r) 29 | if !dh { 30 | return nil 31 | } 32 | return ctx 33 | } 34 | case func(OldContextHandler) OldContextHandler: 35 | return func(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 36 | var dh oldDummyHandler 37 | x(&dh).ServeHTTPContext(ctx, w, r) 38 | if !dh { 39 | return nil 40 | } 41 | return ctx 42 | } 43 | case func(http.Handler) http.Handler: 44 | return func(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 45 | var dh dummyHandler 46 | x(&dh).ServeHTTP(w, r) 47 | if !dh { 48 | return nil 49 | } 50 | return ctx 51 | } 52 | case http.Handler:
53 | return Middleware(func(_ context.Context, w http.ResponseWriter, r *http.Request) context.Context { 54 | x.ServeHTTP(w, r) 55 | return r.Context() 56 | }) 57 | case func(w http.ResponseWriter, r *http.Request): 58 | return Middleware(func(_ context.Context, w http.ResponseWriter, r *http.Request) context.Context { 59 | x(w, r) 60 | return r.Context() 61 | }) 62 | case func(w http.ResponseWriter, r *http.Request) context.Context: 63 | return Middleware(func(_ context.Context, w http.ResponseWriter, r *http.Request) context.Context { 64 | return x(w, r) 65 | }) 66 | } 67 | panic(fmt.Errorf("unsupported MiddlewareType: %T", mw)) 68 | } 69 | 70 | // convertAW turns supported afterware types into kami Afterware. It panics on unsupported types. 71 | func convertAW(aw AfterwareType) Afterware { 72 | switch x := aw.(type) { 73 | case Afterware: 74 | return x 75 | case func(context.Context, mutil.WriterProxy, *http.Request) context.Context: 76 | return Afterware(x) 77 | case func(netcontext.Context, mutil.WriterProxy, *http.Request) netcontext.Context: 78 | return func(ctx context.Context, w mutil.WriterProxy, r *http.Request) context.Context { 79 | return x(ctx, w, r) 80 | } 81 | case func(context.Context, *http.Request) context.Context: 82 | return func(ctx context.Context, _ mutil.WriterProxy, r *http.Request) context.Context { 83 | return x(ctx, r) 84 | } 85 | case func(netcontext.Context, *http.Request) netcontext.Context: 86 | return func(ctx context.Context, _ mutil.WriterProxy, r *http.Request) context.Context { 87 | return x(ctx, r) 88 | } 89 | case func(context.Context) context.Context: 90 | return func(ctx context.Context, _ mutil.WriterProxy, _ *http.Request) context.Context { 91 | return x(ctx) 92 | } 93 | case func(netcontext.Context) netcontext.Context: 94 | return func(ctx context.Context, _ mutil.WriterProxy, _ *http.Request) context.Context { 95 | return x(ctx) 96 | } 97 | case Middleware: 98 | return func(ctx context.Context, w mutil.WriterProxy, r *http.Request) context.Context { 99 | return x(ctx, w, r) 100 | } 101 | case
func(context.Context, http.ResponseWriter, *http.Request) context.Context: 102 | return func(ctx context.Context, w mutil.WriterProxy, r *http.Request) context.Context { 103 | return x(ctx, w, r) 104 | } 105 | case func(netcontext.Context, http.ResponseWriter, *http.Request) netcontext.Context: 106 | return func(ctx context.Context, w mutil.WriterProxy, r *http.Request) context.Context { 107 | return x(ctx, w, r) 108 | } 109 | case func(w http.ResponseWriter, r *http.Request) context.Context: 110 | return Afterware(func(_ context.Context, w mutil.WriterProxy, r *http.Request) context.Context { 111 | return x(w, r) 112 | }) 113 | case func(w mutil.WriterProxy, r *http.Request) context.Context: 114 | return Afterware(func(_ context.Context, w mutil.WriterProxy, r *http.Request) context.Context { 115 | return x(w, r) 116 | }) 117 | case http.Handler: 118 | return Afterware(func(_ context.Context, w mutil.WriterProxy, r *http.Request) context.Context { 119 | x.ServeHTTP(w, r) 120 | return r.Context() 121 | }) 122 | case func(w http.ResponseWriter, r *http.Request): 123 | return Afterware(func(_ context.Context, w mutil.WriterProxy, r *http.Request) context.Context { 124 | x(w, r) 125 | return r.Context() 126 | }) 127 | case func(w mutil.WriterProxy, r *http.Request): 128 | return Afterware(func(_ context.Context, w mutil.WriterProxy, r *http.Request) context.Context { 129 | x(w, r) 130 | return r.Context() 131 | }) 132 | } 133 | panic(fmt.Errorf("unsupported AfterwareType: %T", aw)) 134 | } 135 | 136 | // oldDummyHandler is dummyHandler for the old x/net/context handler type: it records whether the wrapped middleware called the next handler.
137 | type oldDummyHandler bool 138 | 139 | func (dh *oldDummyHandler) ServeHTTP(http.ResponseWriter, *http.Request) { 140 | *dh = true 141 | } 142 | 143 | func (dh *oldDummyHandler) ServeHTTPContext(_ netcontext.Context, _ http.ResponseWriter, _ *http.Request) { 144 | *dh = true 145 | } 146 | -------------------------------------------------------------------------------- /middleware_19.go: -------------------------------------------------------------------------------- 1 | // +build go1.9 2 | 3 | package kami 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | "net/http" 9 | 10 | "github.com/zenazn/goji/web/mutil" 11 | ) 12 | 13 | // convert turns standard http middleware into kami Middleware if needed. It panics on unsupported types. 14 | func convert(mw MiddlewareType) Middleware { 15 | switch x := mw.(type) { 16 | case Middleware: 17 | return x 18 | case func(context.Context, http.ResponseWriter, *http.Request) context.Context: 19 | return Middleware(x) 20 | case func(ContextHandler) ContextHandler: 21 | return func(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 22 | var dh dummyHandler 23 | x(&dh).ServeHTTPContext(ctx, w, r) 24 | if !dh { 25 | return nil 26 | } 27 | return ctx 28 | } 29 | case func(http.Handler) http.Handler: 30 | return func(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 31 | var dh dummyHandler 32 | x(&dh).ServeHTTP(w, r) 33 | if !dh { 34 | return nil 35 | } 36 | return ctx 37 | } 38 | case http.Handler: 39 | return Middleware(func(_ context.Context, w http.ResponseWriter, r *http.Request) context.Context { 40 | x.ServeHTTP(w, r) 41 | return r.Context() 42 | }) 43 | case func(w http.ResponseWriter, r *http.Request): 44 | return Middleware(func(_ context.Context, w http.ResponseWriter, r *http.Request) context.Context { 45 | x(w, r) 46 | return r.Context() 47 | }) 48 | case func(w http.ResponseWriter, r *http.Request) context.Context: 49 | return Middleware(func(_ context.Context, w http.ResponseWriter, r
*http.Request) context.Context { 50 | return x(w, r) 51 | }) 52 | } 53 | panic(fmt.Errorf("unsupported MiddlewareType: %T", mw)) 54 | } 55 | 56 | // convertAW turns supported afterware types into kami Afterware. It panics on unsupported types. 57 | func convertAW(aw AfterwareType) Afterware { 58 | switch x := aw.(type) { 59 | case Afterware: 60 | return x 61 | case func(context.Context, mutil.WriterProxy, *http.Request) context.Context: 62 | return Afterware(x) 63 | case func(context.Context, *http.Request) context.Context: 64 | return func(ctx context.Context, _ mutil.WriterProxy, r *http.Request) context.Context { 65 | return x(ctx, r) 66 | } 67 | case func(context.Context) context.Context: 68 | return func(ctx context.Context, _ mutil.WriterProxy, _ *http.Request) context.Context { 69 | return x(ctx) 70 | } 71 | case Middleware: 72 | return func(ctx context.Context, w mutil.WriterProxy, r *http.Request) context.Context { 73 | return x(ctx, w, r) 74 | } 75 | case func(context.Context, http.ResponseWriter, *http.Request) context.Context: 76 | return func(ctx context.Context, w mutil.WriterProxy, r *http.Request) context.Context { 77 | return x(ctx, w, r) 78 | } 79 | case func(w http.ResponseWriter, r *http.Request) context.Context: 80 | return Afterware(func(_ context.Context, w mutil.WriterProxy, r *http.Request) context.Context { 81 | return x(w, r) 82 | }) 83 | case func(w mutil.WriterProxy, r *http.Request) context.Context: 84 | return Afterware(func(_ context.Context, w mutil.WriterProxy, r *http.Request) context.Context { 85 | return x(w, r) 86 | }) 87 | case http.Handler: 88 | return Afterware(func(_ context.Context, w mutil.WriterProxy, r *http.Request) context.Context { 89 | x.ServeHTTP(w, r) 90 | return r.Context() 91 | }) 92 | case func(w http.ResponseWriter, r *http.Request): 93 | return Afterware(func(_ context.Context, w mutil.WriterProxy, r *http.Request) context.Context { 94 | x(w, r) 95 | return r.Context() 96 | }) 97 | case func(w mutil.WriterProxy, r *http.Request): 98 | return Afterware(func(_ context.Context, w
mutil.WriterProxy, r *http.Request) context.Context { 99 | x(w, r) 100 | return r.Context() 101 | }) 102 | } 103 | panic(fmt.Errorf("unsupported AfterwareType: %T", aw)) 104 | } 105 | -------------------------------------------------------------------------------- /middleware_new.go: -------------------------------------------------------------------------------- 1 | // +build go1.7 2 | 3 | package kami 4 | 5 | import ( 6 | "context" 7 | "net/http" 8 | "unicode/utf8" 9 | 10 | "github.com/zenazn/goji/web/mutil" 11 | ) 12 | 13 | // Middleware is a function that takes the current request context and returns a new request context. 14 | // You can use middleware to build your context before your handler handles a request. 15 | // As a special case, middleware that returns nil will halt middleware and handler execution (LogHandler will still run). 16 | type Middleware func(context.Context, http.ResponseWriter, *http.Request) context.Context 17 | 18 | // MiddlewareType represents types that kami can convert to Middleware. 19 | // kami will try its best to convert standard, non-context middleware. 20 | // See the Use function for important information about how kami middleware is run. 21 | // The following concrete types are accepted: 22 | // - Middleware 23 | // - func(context.Context, http.ResponseWriter, *http.Request) context.Context 24 | // - func(http.ResponseWriter, *http.Request) context.Context 25 | // - func(http.Handler) http.Handler [* see Use docs] 26 | // - func(ContextHandler) ContextHandler [* see Use docs] 27 | // - http.Handler [read only] 28 | // - func(http.ResponseWriter, *http.Request) [read only] 29 | // The old x/net/context is also supported. 30 | type MiddlewareType interface{} 31 | 32 | // Afterware is a function that will run after middleware and the request. 33 | // Afterware takes the request context and returns a new context, but unlike middleware, 34 | // returning nil won't halt execution of other afterware.
35 | type Afterware func(context.Context, mutil.WriterProxy, *http.Request) context.Context 36 | 37 | // AfterwareType represents types that kami can convert to Afterware. 38 | // The following concrete types are accepted: 39 | // - Afterware 40 | // - func(context.Context, mutil.WriterProxy, *http.Request) context.Context 41 | // - func(context.Context, http.ResponseWriter, *http.Request) context.Context 42 | // - func(context.Context, *http.Request) context.Context 43 | // - func(context.Context) context.Context 44 | // - Middleware types 45 | // The old x/net/context is also supported. 46 | type AfterwareType interface{} 47 | 48 | // run runs the middleware chain for a particular request: hierarchical middleware for every path prefix ending in '/' plus the full path (shortest first), then wildcard middleware. 49 | // It returns the possibly-updated request and context, and false if a middleware returned nil to stop the chain early. 50 | func (m *wares) run(ctx context.Context, w http.ResponseWriter, r *http.Request) (*http.Request, context.Context, bool) { 51 | if m.middleware != nil { 52 | // hierarchical middleware; iterate bytes, not runes, so the full path still matches when it ends in a multi-byte rune ('/' never occurs inside a multi-byte UTF-8 sequence) 53 | for i := 0; i < len(r.URL.Path); i++ { 54 | if r.URL.Path[i] == '/' || i == len(r.URL.Path)-1 { 55 | mws, ok := m.middleware[r.URL.Path[:i+1]] 56 | if !ok { 57 | continue 58 | } 59 | for _, mw := range mws { 60 | // return nil context to stop 61 | result := mw(ctx, w, r) 62 | if result == nil { 63 | return r, ctx, false 64 | } 65 | if result != ctx { 66 | r = r.WithContext(result) 67 | } 68 | ctx = result 69 | } 70 | } 71 | } 72 | } 73 | 74 | if m.wildcards != nil { 75 | // wildcard middleware 76 | if wild, params := m.wildcards.Get(r.URL.Path); wild != nil { 77 | if mws, ok := wild.(*[]Middleware); ok { 78 | ctx = mergeParams(ctx, params) 79 | r = r.WithContext(ctx) 80 | for _, mw := range *mws { 81 | result := mw(ctx, w, r) 82 | if result == nil { 83 | return r, ctx, false 84 | } 85 | if result != ctx { 86 | r = r.WithContext(result) 87 | } 88 | ctx = result 89 | } 90 | } 91 | } 92 | } 93 | 94 | return r, ctx, true 95 | } 96 | 97 | // after runs the afterware chain for a particular request.
98 | // Wildcard afterware runs first, then hierarchical afterware from the full path back towards the root. after can't stop early: an afterware that returns nil is skipped and the previous context is kept. 99 | func (m *wares) after(ctx context.Context, w mutil.WriterProxy, r *http.Request) (*http.Request, context.Context) { 100 | if m.afterWildcards != nil { 101 | // wildcard afterware 102 | if wild, params := m.afterWildcards.Get(r.URL.Path); wild != nil { 103 | if aws, ok := wild.(*[]Afterware); ok { 104 | ctx = mergeParams(ctx, params) 105 | r = r.WithContext(ctx) 106 | for _, aw := range *aws { 107 | result := aw(ctx, w, r) 108 | if result != nil { 109 | if result != ctx { 110 | r = r.WithContext(result) 111 | } 112 | ctx = result 113 | } 114 | } 115 | } 116 | } 117 | } 118 | 119 | if m.afterware != nil { 120 | // hierarchical afterware, like middleware in reverse (longest matching prefix first) 121 | path := r.URL.Path 122 | for len(path) > 0 { 123 | chr, size := utf8.DecodeLastRuneInString(path) 124 | if chr == '/' || len(path) == len(r.URL.Path) { 125 | for _, aw := range m.afterware[path] { 126 | result := aw(ctx, w, r) 127 | if result != nil { 128 | if result != ctx { 129 | r = r.WithContext(result) 130 | } 131 | ctx = result 132 | } 133 | } 134 | } 135 | path = path[:len(path)-size] 136 | } 137 | } 138 | 139 | return r, ctx 140 | } 141 | 142 | // dummyHandler is used to keep track of whether the next middleware was called or not.
143 | type dummyHandler bool 144 | 145 | func (dh *dummyHandler) ServeHTTP(http.ResponseWriter, *http.Request) { 146 | *dh = true 147 | } 148 | 149 | func (dh *dummyHandler) ServeHTTPContext(_ context.Context, _ http.ResponseWriter, _ *http.Request) { 150 | *dh = true 151 | } 152 | -------------------------------------------------------------------------------- /middleware_old.go: -------------------------------------------------------------------------------- 1 | // +build !go1.7 2 | 3 | package kami 4 | 5 | import ( 6 | "fmt" 7 | "net/http" 8 | "unicode/utf8" 9 | 10 | "github.com/zenazn/goji/web/mutil" 11 | "golang.org/x/net/context" 12 | ) 13 | 14 | // Middleware is a function that takes the current request context and returns a new request context. 15 | // You can use middleware to build your context before your handler handles a request. 16 | // As a special case, middleware that returns nil will halt middleware and handler execution (LogHandler will still run). 17 | type Middleware func(context.Context, http.ResponseWriter, *http.Request) context.Context 18 | 19 | // MiddlewareType represents types that kami can convert to Middleware. 20 | // kami will try its best to convert standard, non-context middleware. 21 | // See the Use function for important information about how kami middleware is run. 22 | // The following concrete types are accepted: 23 | // - Middleware 24 | // - func(context.Context, http.ResponseWriter, *http.Request) context.Context 25 | // - func(http.Handler) http.Handler [* see Use docs] 26 | // - func(ContextHandler) ContextHandler [* see Use docs] 27 | type MiddlewareType interface{} 28 | 29 | // Afterware is a function that will run after middleware and the request. 30 | // Afterware takes the request context and returns a new context, but unlike middleware, 31 | // returning nil won't halt execution of other afterware.
32 | type Afterware func(context.Context, mutil.WriterProxy, *http.Request) context.Context 33 | 34 | // AfterwareType represents types that kami can convert to Afterware. 35 | // The following concrete types are accepted: 36 | // - Afterware 37 | // - func(context.Context, mutil.WriterProxy, *http.Request) context.Context 38 | // - func(context.Context, http.ResponseWriter, *http.Request) context.Context 39 | // - func(context.Context, *http.Request) context.Context 40 | // - func(context.Context) context.Context 41 | // - Middleware 42 | type AfterwareType interface{} 43 | 44 | // run runs the middleware chain for a particular request: hierarchical middleware for every path prefix ending in '/' plus the full path (shortest first), then wildcard middleware. 45 | // It returns the updated context, and false if a middleware returned nil to stop the chain early. 46 | func (m *wares) run(ctx context.Context, w http.ResponseWriter, r *http.Request) (context.Context, bool) { 47 | if m.middleware != nil { 48 | // hierarchical middleware; iterate bytes, not runes, so the full path still matches when it ends in a multi-byte rune ('/' never occurs inside a multi-byte UTF-8 sequence) 49 | for i := 0; i < len(r.URL.Path); i++ { 50 | if r.URL.Path[i] == '/' || i == len(r.URL.Path)-1 { 51 | mws, ok := m.middleware[r.URL.Path[:i+1]] 52 | if !ok { 53 | continue 54 | } 55 | for _, mw := range mws { 56 | // return nil context to stop 57 | result := mw(ctx, w, r) 58 | if result == nil { 59 | return ctx, false 60 | } 61 | ctx = result 62 | } 63 | } 64 | } 65 | } 66 | 67 | if m.wildcards != nil { 68 | // wildcard middleware 69 | if wild, params := m.wildcards.Get(r.URL.Path); wild != nil { 70 | if mws, ok := wild.(*[]Middleware); ok { 71 | ctx = mergeParams(ctx, params) 72 | for _, mw := range *mws { 73 | result := mw(ctx, w, r) 74 | if result == nil { 75 | return ctx, false 76 | } 77 | ctx = result 78 | } 79 | } 80 | } 81 | } 82 | 83 | return ctx, true 84 | } 85 | 86 | // after runs the afterware chain for a particular request.
87 | // Wildcard afterware runs first, then hierarchical afterware from the full path back towards the root. after can't stop early: an afterware that returns nil is skipped and the previous context is kept. 88 | func (m *wares) after(ctx context.Context, w mutil.WriterProxy, r *http.Request) context.Context { 89 | if m.afterWildcards != nil { 90 | // wildcard afterware 91 | if wild, params := m.afterWildcards.Get(r.URL.Path); wild != nil { 92 | if aws, ok := wild.(*[]Afterware); ok { 93 | ctx = mergeParams(ctx, params) 94 | for _, aw := range *aws { 95 | result := aw(ctx, w, r) 96 | if result != nil { 97 | ctx = result 98 | } 99 | } 100 | } 101 | } 102 | } 103 | 104 | if m.afterware != nil { 105 | // hierarchical afterware, like middleware in reverse (longest matching prefix first) 106 | path := r.URL.Path 107 | for len(path) > 0 { 108 | chr, size := utf8.DecodeLastRuneInString(path) 109 | if chr == '/' || len(path) == len(r.URL.Path) { 110 | for _, aw := range m.afterware[path] { 111 | result := aw(ctx, w, r) 112 | if result != nil { 113 | ctx = result 114 | } 115 | } 116 | } 117 | path = path[:len(path)-size] 118 | } 119 | } 120 | 121 | return ctx 122 | } 123 | 124 | // convert turns standard http middleware into kami Middleware if needed. It panics on unsupported types.
125 | func convert(mw MiddlewareType) Middleware { 126 | switch x := mw.(type) { 127 | case Middleware: 128 | return x 129 | case func(context.Context, http.ResponseWriter, *http.Request) context.Context: 130 | return Middleware(x) 131 | case func(ContextHandler) ContextHandler: 132 | return func(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 133 | var dh dummyHandler 134 | x(&dh).ServeHTTPContext(ctx, w, r) 135 | if !dh { 136 | return nil 137 | } 138 | return ctx 139 | } 140 | case func(http.Handler) http.Handler: 141 | return func(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 142 | var dh dummyHandler 143 | x(&dh).ServeHTTP(w, r) 144 | if !dh { 145 | return nil 146 | } 147 | return ctx 148 | } 149 | } 150 | panic(fmt.Errorf("unsupported MiddlewareType: %T", mw)) 151 | } 152 | 153 | // convertAW turns supported afterware types into kami Afterware. It panics on unsupported types. 154 | func convertAW(aw AfterwareType) Afterware { 155 | switch x := aw.(type) { 156 | case Afterware: 157 | return x 158 | case func(context.Context, mutil.WriterProxy, *http.Request) context.Context: 159 | return Afterware(x) 160 | case func(context.Context, *http.Request) context.Context: 161 | return func(ctx context.Context, _ mutil.WriterProxy, r *http.Request) context.Context { 162 | return x(ctx, r) 163 | } 164 | case func(context.Context) context.Context: 165 | return func(ctx context.Context, _ mutil.WriterProxy, _ *http.Request) context.Context { 166 | return x(ctx) 167 | } 168 | case Middleware: 169 | return func(ctx context.Context, w mutil.WriterProxy, r *http.Request) context.Context { 170 | return x(ctx, w, r) 171 | } 172 | case func(context.Context, http.ResponseWriter, *http.Request) context.Context: 173 | return func(ctx context.Context, w mutil.WriterProxy, r *http.Request) context.Context { 174 | return x(ctx, w, r) 175 | } 176 | } 177 | panic(fmt.Errorf("unsupported AfterwareType: %T", aw)) 178 | } 179 | 180 | // dummyHandler is used to keep track of whether the next middleware was
called or not. 181 | type dummyHandler bool 182 | 183 | func (dh *dummyHandler) ServeHTTP(http.ResponseWriter, *http.Request) { 184 | *dh = true 185 | } 186 | 187 | func (dh *dummyHandler) ServeHTTPContext(_ context.Context, _ http.ResponseWriter, _ *http.Request) { 188 | *dh = true 189 | } 190 | -------------------------------------------------------------------------------- /middleware_test.go: -------------------------------------------------------------------------------- 1 | package kami_test 2 | 3 | import ( 4 | "net/http" 5 | "testing" 6 | 7 | "github.com/guregu/kami" 8 | "golang.org/x/net/context" 9 | ) 10 | 11 | func TestWildcardMiddleware(t *testing.T) { 12 | kami.Reset() 13 | kami.Use("/user/:mid/edit", func(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 14 | if kami.Param(ctx, "mid") == "403" { 15 | w.WriteHeader(http.StatusForbidden) 16 | return nil 17 | } 18 | 19 | return context.WithValue(ctx, "middleware id", kami.Param(ctx, "mid")) 20 | }) 21 | kami.Patch("/user/:id/edit", func(ctx context.Context, w http.ResponseWriter, r *http.Request) { 22 | if kami.Param(ctx, "mid") != kami.Param(ctx, "id") { 23 | t.Error("mid != id") 24 | } 25 | 26 | if ctx.Value("middleware id").(string) != kami.Param(ctx, "id") { 27 | t.Error("middleware values not propagating") 28 | } 29 | }) 30 | kami.Head("/user/:id", func(ctx context.Context, w http.ResponseWriter, r *http.Request) { 31 | if ctx.Value("middleware id") != nil { 32 | t.Error("wildcard middleware shouldn't have been called") 33 | w.WriteHeader(http.StatusInternalServerError) 34 | return 35 | } 36 | w.WriteHeader(http.StatusOK) 37 | }) 38 | 39 | // normal case 40 | expectResponseCode(t, "PATCH", "/user/42/edit", http.StatusOK) 41 | 42 | // should stop early 43 | expectResponseCode(t, "PATCH", "/user/403/edit", http.StatusForbidden) 44 | 45 | // make sure the wildcard middleware isn't over-eager: /user/403 must not run middleware registered for /user/:mid/edit 46 | expectResponseCode(t, "HEAD", "/user/403", http.StatusOK) 47 | } 48 | 49 | func
TestHierarchicalStop(t *testing.T) { 50 | kami.Reset() 51 | kami.Use("/nope/", func(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 52 | w.WriteHeader(http.StatusForbidden) 53 | return nil 54 | }) 55 | kami.Delete("/nope/test", func(ctx context.Context, w http.ResponseWriter, r *http.Request) { 56 | w.WriteHeader(http.StatusOK) 57 | }) 58 | 59 | expectResponseCode(t, "DELETE", "/nope/test", http.StatusForbidden) 60 | } 61 | -------------------------------------------------------------------------------- /mux.go: -------------------------------------------------------------------------------- 1 | // +build !go1.7 2 | 3 | package kami 4 | 5 | import ( 6 | "net/http" 7 | 8 | "github.com/dimfeld/httptreemux" 9 | "github.com/zenazn/goji/web/mutil" 10 | "golang.org/x/net/context" 11 | ) 12 | 13 | // Mux is an independent kami router and middleware stack. Manipulating it is not threadsafe. 14 | type Mux struct { 15 | // Context is the root "god object" for this mux, 16 | // from which every request's context will derive. 17 | Context context.Context 18 | // Cancel will, if true, automatically cancel the context of incoming requests after they finish. 19 | Cancel bool 20 | // PanicHandler will, if set, be called on panics. 21 | // You can use kami.Exception(ctx) within the panic handler to get panic details. 22 | PanicHandler HandlerType 23 | // LogHandler will, if set, wrap every request and be called at the very end. 24 | LogHandler func(context.Context, mutil.WriterProxy, *http.Request) 25 | 26 | routes *httptreemux.TreeMux 27 | enable405 bool 28 | *wares 29 | } 30 | 31 | // New creates a new independent kami router and middleware stack. 32 | // It is totally separate from the global kami.Context and middleware stack.
33 | func New() *Mux { 34 | m := &Mux{ 35 | Context: context.Background(), 36 | routes: newRouter(), 37 | wares: newWares(), 38 | enable405: true, 39 | } 40 | m.NotFound(nil) // install the default 404 handler 41 | m.MethodNotAllowed(nil) // install the default 405 handler 42 | return m 43 | } 44 | 45 | // ServeHTTP handles an HTTP request, running middleware and forwarding the request to the appropriate handler. 46 | // Implements the http.Handler interface for easy composition with other frameworks. 47 | func (m *Mux) ServeHTTP(w http.ResponseWriter, r *http.Request) { 48 | m.routes.ServeHTTP(w, r) 49 | } 50 | 51 | // Handle registers an arbitrary method handler under the given path. 52 | func (m *Mux) Handle(method, path string, handler HandlerType) { 53 | m.routes.Handle(method, path, m.bless(wrap(handler))) 54 | } 55 | 56 | // Get registers a GET handler under the given path. 57 | func (m *Mux) Get(path string, handler HandlerType) { 58 | m.Handle("GET", path, handler) 59 | } 60 | 61 | // Post registers a POST handler under the given path. 62 | func (m *Mux) Post(path string, handler HandlerType) { 63 | m.Handle("POST", path, handler) 64 | } 65 | 66 | // Put registers a PUT handler under the given path. 67 | func (m *Mux) Put(path string, handler HandlerType) { 68 | m.Handle("PUT", path, handler) 69 | } 70 | 71 | // Patch registers a PATCH handler under the given path. 72 | func (m *Mux) Patch(path string, handler HandlerType) { 73 | m.Handle("PATCH", path, handler) 74 | } 75 | 76 | // Head registers a HEAD handler under the given path. 77 | func (m *Mux) Head(path string, handler HandlerType) { 78 | m.Handle("HEAD", path, handler) 79 | } 80 | 81 | // Options registers an OPTIONS handler under the given path. 82 | func (m *Mux) Options(path string, handler HandlerType) { 83 | m.Handle("OPTIONS", path, handler) 84 | } 85 | 86 | // Delete registers a DELETE handler under the given path.
87 | func (m *Mux) Delete(path string, handler HandlerType) { 88 | m.Handle("DELETE", path, handler) 89 | } 90 | 91 | // NotFound registers a special handler for unregistered (404) paths. 92 | // If handler is nil, the default http.NotFound behavior is used. 93 | func (m *Mux) NotFound(handler HandlerType) { 94 | // set up the default handler if needed 95 | // we need to bless this so middleware will still run for a 404 request 96 | if handler == nil { 97 | handler = HandlerFunc(func(_ context.Context, w http.ResponseWriter, r *http.Request) { 98 | http.NotFound(w, r) 99 | }) 100 | } 101 | 102 | h := m.bless(wrap(handler)) 103 | m.routes.NotFoundHandler = func(w http.ResponseWriter, r *http.Request) { 104 | h(w, r, nil) 105 | } 106 | } 107 | 108 | // MethodNotAllowed registers a special handler for automatically responding 109 | // to invalid method requests (405). If handler is nil, a plain 405 error response is written. 110 | func (m *Mux) MethodNotAllowed(handler HandlerType) { 111 | if handler == nil { 112 | handler = HandlerFunc(func(_ context.Context, w http.ResponseWriter, r *http.Request) { 113 | http.Error(w, 114 | http.StatusText(http.StatusMethodNotAllowed), 115 | http.StatusMethodNotAllowed, 116 | ) 117 | }) 118 | } 119 | 120 | h := m.bless(wrap(handler)) 121 | m.routes.MethodNotAllowedHandler = func(w http.ResponseWriter, r *http.Request, methods map[string]httptreemux.HandlerFunc) { 122 | if !m.enable405 { // 405 handling disabled: fall back to the current NotFound handler 123 | m.routes.NotFoundHandler(w, r) 124 | return 125 | } 126 | h(w, r, nil) 127 | } 128 | } 129 | 130 | // EnableMethodNotAllowed enables or disables automatic Method Not Allowed handling. 131 | // Note that this is enabled by default; when disabled, such requests go to the NotFound handler. 132 | func (m *Mux) EnableMethodNotAllowed(enabled bool) { 133 | m.enable405 = enabled 134 | } 135 | 136 | // bless wraps h in a kami handler bound to this mux's base context, middleware, and panic/log hooks.
137 | func (m *Mux) bless(h ContextHandler) httptreemux.HandlerFunc { 138 | k := kami{ 139 | handler: h, 140 | base: &m.Context, // NOTE(review): Mux fields are captured by pointer, presumably so later changes to them affect already-registered handlers - confirm in kami.handle 141 | autocancel: &m.Cancel, 142 | middleware: m.wares, 143 | panicHandler: &m.PanicHandler, 144 | logHandler: &m.LogHandler, 145 | } 146 | return k.handle 147 | } 148 | -------------------------------------------------------------------------------- /mux_17.go: -------------------------------------------------------------------------------- 1 | // +build go1.7 2 | 3 | package kami 4 | 5 | import ( 6 | "context" 7 | "net/http" 8 | 9 | "github.com/dimfeld/httptreemux" 10 | "github.com/zenazn/goji/web/mutil" 11 | ) 12 | 13 | // Mux is an independent kami router and middleware stack. Manipulating it is not threadsafe. 14 | type Mux struct { 15 | // Context is the root "god object" for this mux, 16 | // from which every request's context will derive. 17 | Context context.Context 18 | // Cancel will, if true, automatically cancel the context of incoming requests after they finish. 19 | Cancel bool 20 | // PanicHandler will, if set, be called on panics. 21 | // You can use kami.Exception(ctx) within the panic handler to get panic details. 22 | PanicHandler HandlerType 23 | // LogHandler will, if set, wrap every request and be called at the very end. 24 | LogHandler func(context.Context, mutil.WriterProxy, *http.Request) 25 | 26 | routes *httptreemux.TreeMux 27 | enable405 bool 28 | *wares 29 | } 30 | 31 | // New creates a new independent kami router and middleware stack. 32 | // It is totally separate from the global kami.Context and middleware stack. 33 | func New() *Mux { 34 | m := &Mux{ 35 | Context: context.Background(), 36 | routes: newRouter(), 37 | wares: newWares(), 38 | enable405: true, 39 | } 40 | m.NotFound(nil) // install the default 404 handler 41 | m.MethodNotAllowed(nil) // install the default 405 handler 42 | return m 43 | } 44 | 45 | // ServeHTTP handles an HTTP request, running middleware and forwarding the request to the appropriate handler.
46 | // Implements the http.Handler interface for easy composition with other frameworks. 47 | func (m *Mux) ServeHTTP(w http.ResponseWriter, r *http.Request) { 48 | m.routes.ServeHTTP(w, r) 49 | } 50 | 51 | // Handle registers an arbitrary method handler under the given path. 52 | func (m *Mux) Handle(method, path string, handler HandlerType) { 53 | m.routes.Handle(method, path, m.bless(wrap(handler))) 54 | } 55 | 56 | // Get registers a GET handler under the given path. 57 | func (m *Mux) Get(path string, handler HandlerType) { 58 | m.Handle("GET", path, handler) 59 | } 60 | 61 | // Post registers a POST handler under the given path. 62 | func (m *Mux) Post(path string, handler HandlerType) { 63 | m.Handle("POST", path, handler) 64 | } 65 | 66 | // Put registers a PUT handler under the given path. 67 | func (m *Mux) Put(path string, handler HandlerType) { 68 | m.Handle("PUT", path, handler) 69 | } 70 | 71 | // Patch registers a PATCH handler under the given path. 72 | func (m *Mux) Patch(path string, handler HandlerType) { 73 | m.Handle("PATCH", path, handler) 74 | } 75 | 76 | // Head registers a HEAD handler under the given path. 77 | func (m *Mux) Head(path string, handler HandlerType) { 78 | m.Handle("HEAD", path, handler) 79 | } 80 | 81 | // Options registers an OPTIONS handler under the given path. 82 | func (m *Mux) Options(path string, handler HandlerType) { 83 | m.Handle("OPTIONS", path, handler) 84 | } 85 | 86 | // Delete registers a DELETE handler under the given path. 87 | func (m *Mux) Delete(path string, handler HandlerType) { 88 | m.Handle("DELETE", path, handler) 89 | } 90 | 91 | // NotFound registers a special handler for unregistered (404) paths. 92 | // If handler is nil, the default http.NotFound behavior is used.
93 | func (m *Mux) NotFound(handler HandlerType) { 94 | // set up the default handler if needed 95 | // we need to bless this so middleware will still run for a 404 request 96 | if handler == nil { 97 | handler = HandlerFunc(func(_ context.Context, w http.ResponseWriter, r *http.Request) { 98 | http.NotFound(w, r) 99 | }) 100 | } 101 | 102 | h := m.bless(wrap(handler)) 103 | m.routes.NotFoundHandler = func(w http.ResponseWriter, r *http.Request) { 104 | h(w, r, nil) 105 | } 106 | } 107 | 108 | // MethodNotAllowed registers a special handler for automatically responding 109 | // to invalid method requests (405). If handler is nil, a plain 405 error response is written. 110 | func (m *Mux) MethodNotAllowed(handler HandlerType) { 111 | if handler == nil { 112 | handler = HandlerFunc(func(_ context.Context, w http.ResponseWriter, r *http.Request) { 113 | http.Error(w, 114 | http.StatusText(http.StatusMethodNotAllowed), 115 | http.StatusMethodNotAllowed, 116 | ) 117 | }) 118 | } 119 | 120 | h := m.bless(wrap(handler)) 121 | m.routes.MethodNotAllowedHandler = func(w http.ResponseWriter, r *http.Request, methods map[string]httptreemux.HandlerFunc) { 122 | if !m.enable405 { // 405 handling disabled: fall back to the current NotFound handler 123 | m.routes.NotFoundHandler(w, r) 124 | return 125 | } 126 | h(w, r, nil) 127 | } 128 | } 129 | 130 | // EnableMethodNotAllowed enables or disables automatic Method Not Allowed handling. 131 | // Note that this is enabled by default; when disabled, such requests go to the NotFound handler. 132 | func (m *Mux) EnableMethodNotAllowed(enabled bool) { 133 | m.enable405 = enabled 134 | } 135 | 136 | // bless wraps h in a kami handler bound to this mux's base context, middleware, and panic/log hooks.
137 | func (m *Mux) bless(h ContextHandler) httptreemux.HandlerFunc { 138 | k := kami{ 139 | handler: h, 140 | base: &m.Context, 141 | autocancel: &m.Cancel, 142 | middleware: m.wares, 143 | panicHandler: &m.PanicHandler, 144 | logHandler: &m.LogHandler, 145 | } 146 | return k.handle 147 | } 148 | -------------------------------------------------------------------------------- /mux_test.go: -------------------------------------------------------------------------------- 1 | package kami_test 2 | 3 | import ( 4 | "io" 5 | "io/ioutil" 6 | "net/http" 7 | "net/http/httptest" 8 | "testing" 9 | 10 | // "github.com/zenazn/goji/web/mutil" 11 | "golang.org/x/net/context" 12 | 13 | "github.com/guregu/kami" 14 | ) 15 | 16 | // TODO: this mostly a copy/paste of kami_test.go, rewrite it! 17 | func TestKamiMux(t *testing.T) { 18 | mux := kami.New() 19 | 20 | // normal stuff 21 | mux.Use("/mux/", func(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 22 | return context.WithValue(ctx, "test1", "1") 23 | }) 24 | mux.Use("/mux/v2/", func(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 25 | return context.WithValue(ctx, "test2", "2") 26 | }) 27 | mux.Get("/mux/v2/papers/:page", func(ctx context.Context, w http.ResponseWriter, r *http.Request) { 28 | page := kami.Param(ctx, "page") 29 | if page == "" { 30 | panic("blank page") 31 | } 32 | io.WriteString(w, page) 33 | 34 | test1 := ctx.Value("test1").(string) 35 | test2 := ctx.Value("test2").(string) 36 | 37 | if test1 != "1" || test2 != "2" { 38 | t.Error("unexpected ctx value:", test1, test2) 39 | } 40 | }) 41 | 42 | // 404 stuff 43 | mux.Use("/mux/missing/", func(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 44 | return context.WithValue(ctx, "ok", true) 45 | }) 46 | mux.NotFound(func(ctx context.Context, w http.ResponseWriter, r *http.Request) { 47 | ok, _ := ctx.Value("ok").(bool) 48 | if !ok { 49 | 
w.WriteHeader(http.StatusInternalServerError) 50 | return 51 | } 52 | w.WriteHeader(http.StatusTeapot) 53 | }) 54 | 55 | // 405 stuff 56 | mux.Use("/mux/method_not_allowed", func(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 57 | return context.WithValue(ctx, "ok", true) 58 | }) 59 | mux.MethodNotAllowed(func(ctx context.Context, w http.ResponseWriter, r *http.Request) { 60 | ok, _ := ctx.Value("ok").(bool) 61 | if !ok { 62 | w.WriteHeader(http.StatusInternalServerError) 63 | return 64 | } 65 | w.WriteHeader(http.StatusTeapot) 66 | }) 67 | mux.Post("/mux/method_not_allowed", func(ctx context.Context, w http.ResponseWriter, r *http.Request) { 68 | panic("test panic") 69 | }) 70 | 71 | stdMux := http.NewServeMux() 72 | stdMux.Handle("/mux/", mux) 73 | 74 | // test normal stuff 75 | resp := httptest.NewRecorder() 76 | req, err := http.NewRequest("GET", "/mux/v2/papers/3", nil) 77 | if err != nil { 78 | t.Fatal(err) 79 | } 80 | 81 | stdMux.ServeHTTP(resp, req) 82 | if resp.Code != http.StatusOK { 83 | t.Error("should return HTTP OK", resp.Code, "≠", http.StatusOK) 84 | } 85 | 86 | data, err := ioutil.ReadAll(resp.Body) 87 | if err != nil { 88 | panic(err) 89 | } 90 | 91 | if string(data) != "3" { 92 | t.Error("expected page 3, got", string(data)) 93 | } 94 | 95 | // test 404 96 | resp = httptest.NewRecorder() 97 | req, err = http.NewRequest("GET", "/mux/missing/hello", nil) 98 | if err != nil { 99 | t.Fatal(err) 100 | } 101 | 102 | stdMux.ServeHTTP(resp, req) 103 | if resp.Code != http.StatusTeapot { 104 | t.Error("should return HTTP Teapot", resp.Code, "≠", http.StatusTeapot) 105 | } 106 | 107 | // test 405 108 | resp = httptest.NewRecorder() 109 | req, err = http.NewRequest("GET", "/mux/method_not_allowed", nil) 110 | if err != nil { 111 | t.Fatal(err) 112 | } 113 | 114 | stdMux.ServeHTTP(resp, req) 115 | if resp.Code != http.StatusTeapot { 116 | t.Error("should return HTTP Teapot", resp.Code, "≠", http.StatusTeapot) 117 | } 118 | 119 
| 	// test EnableMethodNotAllowed method
120 | 	resp = httptest.NewRecorder()
121 | 	req, err = http.NewRequest("GET", "/mux/method_not_allowed", nil)
122 | 	if err != nil {
123 | 		t.Fatal(err)
124 | 	}
125 |
126 | 	// Reset NotFound handler to receive default 404 instead of custom handler 418(Teapot)
127 | 	mux.NotFound(nil)
128 |
129 | 	mux.EnableMethodNotAllowed(false)
130 | 	stdMux.ServeHTTP(resp, req)
131 | 	if resp.Code != http.StatusNotFound {
132 | 		t.Error("should return HTTP NotFound", resp.Code, "≠", http.StatusNotFound)
133 | 	}
134 | }
135 |
--------------------------------------------------------------------------------
/params.go:
--------------------------------------------------------------------------------
1 | package kami
2 |
3 | import (
4 | 	"golang.org/x/net/context"
5 | )
6 |
7 | type paramsKey struct{} // context key for the map[string]string of path parameters
8 | type panicKey struct{}  // context key for the value reported by Exception
9 |
10 | // Param returns a request path parameter, or a blank string if it doesn't exist.
11 | // For example, with the path /v2/papers/:page
12 | // use kami.Param(ctx, "page") to access the :page variable.
13 | func Param(ctx context.Context, name string) string {
14 | 	params, ok := ctx.Value(paramsKey{}).(map[string]string)
15 | 	if !ok {
16 | 		return "" // no parameter map has been attached to ctx
17 | 	}
18 | 	return params[name]
19 | }
20 |
21 | // SetParam will set the value of a path parameter in a given context and return the context to use from then on.
22 | // This is intended for testing and should not be used otherwise.
23 | func SetParam(ctx context.Context, name string, value string) context.Context {
24 | 	params, ok := ctx.Value(paramsKey{}).(map[string]string)
25 | 	if !ok {
26 | 		params = map[string]string{name: value} // first parameter: attach a fresh map via a derived context
27 | 		return context.WithValue(ctx, paramsKey{}, params)
28 | 	}
29 | 	params[name] = value // the map already stored in ctx is modified in place; every context sharing it sees the change
30 | 	return ctx
31 | }
32 |
33 | // Exception gets the "v" in panic(v): the panic details, or nil if no panic value was recorded in ctx.
34 | // Only PanicHandler will receive a context you can use this with.
35 | func Exception(ctx context.Context) interface{} { 36 | return ctx.Value(panicKey{}) 37 | } 38 | 39 | func newContextWithParams(ctx context.Context, params map[string]string) context.Context { 40 | return context.WithValue(ctx, paramsKey{}, params) 41 | } 42 | 43 | func mergeParams(ctx context.Context, params map[string]string) context.Context { 44 | current, _ := ctx.Value(paramsKey{}).(map[string]string) 45 | if current == nil { 46 | return context.WithValue(ctx, paramsKey{}, params) 47 | } 48 | 49 | for k, v := range params { 50 | current[k] = v 51 | } 52 | return ctx 53 | } 54 | 55 | func newContextWithException(ctx context.Context, exception interface{}) context.Context { 56 | return context.WithValue(ctx, panicKey{}, exception) 57 | } 58 | -------------------------------------------------------------------------------- /params_test.go: -------------------------------------------------------------------------------- 1 | package kami_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/guregu/kami" 7 | "golang.org/x/net/context" 8 | ) 9 | 10 | func TestParams(t *testing.T) { 11 | ctx := context.Background() 12 | if result := kami.Param(ctx, "test"); result != "" { 13 | t.Error("expected blank, got", result) 14 | } 15 | ctx = kami.SetParam(ctx, "test", "abc") 16 | if result := kami.Param(ctx, "test"); result != "abc" { 17 | t.Error("expected abc, got", result) 18 | } 19 | ctx = kami.SetParam(ctx, "test", "overwritten") 20 | if result := kami.Param(ctx, "test"); result != "overwritten" { 21 | t.Error("expected overwritten, got", result) 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /serve.go: -------------------------------------------------------------------------------- 1 | // +build !appengine 2 | 3 | package kami 4 | 5 | import ( 6 | "crypto/tls" 7 | "flag" 8 | "log" 9 | "net" 10 | "net/http" 11 | "time" 12 | 13 | "github.com/zenazn/goji/bind" 14 | "github.com/zenazn/goji/graceful" 15 | ) 16 | 17 | 
func init() { 18 | bind.WithFlag() 19 | graceful.DoubleKickWindow(2 * time.Second) 20 | } 21 | 22 | // Serve starts kami with reasonable defaults. 23 | // The bind address can be changed by setting the GOJI_BIND environment variable, or 24 | // by setting the "bind" command line flag. 25 | // Serve detects einhorn and systemd for you. 26 | // It works exactly like zenazn/goji. 27 | func Serve() { 28 | if !flag.Parsed() { 29 | flag.Parse() 30 | } 31 | 32 | serveListener(Handler(), bind.Default()) 33 | } 34 | 35 | // ServeTLS is like Serve, but enables TLS using the given config. 36 | func ServeTLS(config *tls.Config) { 37 | if !flag.Parsed() { 38 | flag.Parse() 39 | } 40 | 41 | serveListener(Handler(), tls.NewListener(bind.Default(), config)) 42 | } 43 | 44 | // ServeListener is like Serve, but runs kami on top of an arbitrary net.Listener. 45 | func ServeListener(listener net.Listener) { 46 | serveListener(Handler(), listener) 47 | } 48 | 49 | // Serve starts serving this mux with reasonable defaults. 50 | // The bind address can be changed by setting the GOJI_BIND environment variable, or 51 | // by setting the "--bind" command line flag. 52 | // Serve detects einhorn and systemd for you. 53 | // It works exactly like zenazn/goji. Only one mux may be served at a time. 54 | func (m *Mux) Serve() { 55 | if !flag.Parsed() { 56 | flag.Parse() 57 | } 58 | 59 | serveListener(m, bind.Default()) 60 | } 61 | 62 | // ServeTLS is like Serve, but enables TLS using the given config. 63 | func (m *Mux) ServeTLS(config *tls.Config) { 64 | if !flag.Parsed() { 65 | flag.Parse() 66 | } 67 | 68 | serveListener(m, tls.NewListener(bind.Default(), config)) 69 | } 70 | 71 | // ServeListener is like Serve, but runs kami on top of an arbitrary net.Listener. 72 | func (m *Mux) ServeListener(listener net.Listener) { 73 | serveListener(m, listener) 74 | } 75 | 76 | // ServeListener is like Serve, but runs kami on top of an arbitrary net.Listener. 
77 | func serveListener(h http.Handler, listener net.Listener) {
78 | 	// Install our handler at the root of the standard net/http default mux.
79 | 	// This allows packages like expvar to continue working as expected.
80 | 	// NOTE: net/http panics if "/" is registered twice, so this can succeed only once per process.
81 | 	http.Handle("/", h)
82 |
83 | 	log.Println("Starting kami on", listener.Addr())
84 |
85 | 	// Install goji's signal handling, then report readiness via bind.Ready
86 | 	// (presumably for einhorn/systemd, which the Serve docs say are detected).
87 | 	graceful.HandleSignals()
88 | 	bind.Ready()
89 | 	// Log around a graceful stop: by goji's naming, PreHook runs as it begins and PostHook after it ends.
90 | 	graceful.PreHook(func() { log.Printf("kami received signal, gracefully stopping") })
91 | 	graceful.PostHook(func() { log.Printf("kami stopped") })
92 |
93 | 	// Serve DefaultServeMux rather than h, so other handlers registered on it stay reachable.
94 | 	err := graceful.Serve(listener, http.DefaultServeMux)
95 |
96 | 	if err != nil {
97 | 		log.Fatal(err)
98 | 	}
99 |
100 | 	// Presumably blocks until in-flight connections have drained after shutdown.
101 | 	graceful.Wait()
102 | }
103 |
--------------------------------------------------------------------------------
/serve_appengine.go:
--------------------------------------------------------------------------------
1 | // +build appengine
2 |
3 | package kami
4 |
5 | import (
6 | 	"net/http"
7 | )
8 |
9 | // Serve starts kami with reasonable defaults.
10 | // Under the appengine build tag it only registers Handler() at "/" on
11 | // http.DefaultServeMux and returns; it does not listen on its own
12 | // (presumably the App Engine runtime serves the default mux).
13 | func Serve() {
14 | 	http.Handle("/", Handler())
15 | }
16 |
--------------------------------------------------------------------------------
/treemux/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2014 Daniel Imfeld
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/treemux/README.md:
--------------------------------------------------------------------------------
1 | treemux [![GoDoc](http://godoc.org/github.com/guregu/kami/treemux?status.png)](http://godoc.org/github.com/guregu/kami/treemux)
2 | ===========
3 |
4 | Generic router ripped from [httptreemux](https://github.com/dimfeld/httptreemux).
5 |
--------------------------------------------------------------------------------
/treemux/router.go:
--------------------------------------------------------------------------------
1 | // Package treemux is a generic treemux ripped from httptreemux.
2 | package treemux
3 |
4 | import (
5 | 	"fmt"
6 | )
7 |
8 | // TreeMux maps slash-prefixed path patterns to arbitrary values. A pattern
9 | // segment may be a :name wildcard, matching one path segment, and the last
10 | // segment may be a *name catch-all, matching the rest of the path.
11 | type TreeMux struct {
12 | 	root *node
13 | }
14 |
15 | // Dump returns a human-readable rendering of the routing tree, for debugging.
16 | func (t *TreeMux) Dump() string {
17 | 	return t.root.dumpTree("", "")
18 | }
19 |
20 | // Set stores v under the pattern path, which must start with a slash.
21 | // It panics if path is empty or does not start with a slash, if the pattern
22 | // is malformed or ambiguous with one already set, or if it already has a value.
23 | func (t *TreeMux) Set(path string, v interface{}) {
24 | 	if len(path) == 0 || path[0] != '/' {
25 | 		panic(fmt.Sprintf("Path %s must start with slash", path))
26 | 	}
27 |
28 | 	leaf := t.root.addPath(path[1:], nil)
29 | 	leaf.setValue(v)
30 | }
31 |
32 | // Get looks up path and returns the value stored for the matching pattern
33 | // along with the captured parameters keyed by name (nil if there are none).
34 | // It returns (nil, nil) when nothing matches, including for an empty path.
35 | func (t *TreeMux) Get(path string) (interface{}, map[string]string) {
36 | 	// path[1:] below would panic on an empty path (possible for e.g. CONNECT
37 | 	// requests), so treat it as a miss instead.
38 | 	if path == "" {
39 | 		return nil, nil
40 | 	}
41 |
42 | 	n, params := t.root.search(path[1:])
43 | 	if n == nil {
44 | 		return nil, nil
45 | 	}
46 |
47 | 	var paramMap map[string]string
48 | 	if len(params) != 0 {
49 | 		if len(params) != len(n.leafWildcardNames) {
50 | 			// Need better behavior here. Should this be a panic?
51 | 			panic(fmt.Sprintf("treemux parameter list length mismatch: %v, %v",
52 | 				params, n.leafWildcardNames))
53 | 		}
54 |
55 | 		// search returns values innermost-first, while leafWildcardNames is
56 | 		// in pattern order, so pair them up in reverse.
57 | 		last := len(params) - 1
58 | 		paramMap = make(map[string]string, len(params))
59 | 		for i, value := range params {
60 | 			paramMap[n.leafWildcardNames[last-i]] = value
61 | 		}
62 | 	}
63 |
64 | 	return n.leafValue, paramMap
65 | }
66 |
67 | // New returns an empty TreeMux.
68 | func New() *TreeMux {
69 | 	return &TreeMux{root: &node{path: "/"}}
70 | }
71 |
--------------------------------------------------------------------------------
/treemux/tree.go:
--------------------------------------------------------------------------------
1 | package treemux
2 |
3 | import (
4 | 	"fmt"
5 | 	"net/url"
6 | 	"strings"
7 | )
8 |
9 | // node is one vertex of the routing tree. During a search its static
10 | // children are tried first, then its wildcard child, then its catch-all.
11 | type node struct {
12 | 	// path is the literal text a static node matches. A catch-all node holds
13 | 	// its parameter name here; a wildcard node holds the placeholder "wildcard".
14 | 	path string
15 |
16 | 	// priority is bumped each time another path is added through this node;
17 | 	// static siblings are kept in descending priority order.
18 | 	priority int
19 |
20 | 	// The list of static children to check. staticIndices[i] is the first
21 | 	// byte of staticChild[i].path.
22 | 	staticIndices []byte
23 | 	staticChild   []*node
24 |
25 | 	// If none of the above match, check the wildcard children
26 | 	wildcardChild *node
27 |
28 | 	// If none of the above match, then we use the catch-all, if applicable.
29 | 	catchAllChild *node
30 |
31 | 	// Data for the node is below.
32 |
33 | 	isCatchAll bool
34 | 	// leafValue is the value stored for a pattern ending at this node, or nil.
35 | 	leafValue interface{}
36 |
37 | 	// The names of the parameters to apply, in pattern order.
38 | 	leafWildcardNames []string
39 | }
40 |
41 | // sortStaticChild bubbles staticChild[i] (and its index byte) toward the
42 | // front until the static children are in descending priority order again.
43 | func (n *node) sortStaticChild(i int) {
44 | 	for i > 0 && n.staticChild[i].priority > n.staticChild[i-1].priority {
45 | 		n.staticChild[i], n.staticChild[i-1] = n.staticChild[i-1], n.staticChild[i]
46 | 		n.staticIndices[i], n.staticIndices[i-1] = n.staticIndices[i-1], n.staticIndices[i]
47 | 		i--
48 | 	}
49 | }
50 |
51 | // setValue stores v at this node. It panics if a value is already stored.
52 | func (n *node) setValue(v interface{}) {
53 | 	if n.leafValue != nil {
54 | 		panic(fmt.Errorf("treemux: duplicate value for path %s", n.path))
55 | 	}
56 | 	n.leafValue = v
57 | }
58 |
59 | // addPath inserts the pattern path below n and returns the node where it
60 | // ends. wildcards holds the :name and *name parameter names seen so far.
61 | // It panics on malformed patterns (":" or "*" inside a segment, a segment
62 | // after a catch-all, a renamed catch-all) and on wildcard names that are
63 | // ambiguous with an already registered pattern.
64 | func (n *node) addPath(path string, wildcards []string) *node {
65 | 	if len(path) == 0 {
66 | 		// End of the pattern: n is the leaf.
67 | 		if wildcards != nil {
68 | 			// Make sure the current wildcards are the same as the old ones.
69 | 			// If not then we have an ambiguous path.
70 | 			if n.leafWildcardNames != nil {
71 | 				if len(n.leafWildcardNames) != len(wildcards) {
72 | 					// This should never happen.
73 | 					panic("Reached leaf node with differing wildcard array length. Please report this as a bug.")
74 | 				}
75 |
76 | 				for i := range wildcards {
77 | 					if n.leafWildcardNames[i] != wildcards[i] {
78 | 						panic(fmt.Sprintf("Wildcards %v are ambiguous with wildcards %v",
79 | 							n.leafWildcardNames, wildcards))
80 | 					}
81 | 				}
82 | 			} else {
83 | 				// No wildcards yet, so just add the existing set.
84 | 				n.leafWildcardNames = wildcards
85 | 			}
86 | 		}
87 |
88 | 		return n
89 | 	}
90 |
91 | 	// Split off the first token: a lone "/", or everything up to the next
92 | 	// slash (or to the end of the path).
93 | 	c := path[0]
94 | 	nextSlash := strings.Index(path, "/")
95 | 	var thisToken string
96 | 	var tokenEnd int
97 | 	switch {
98 | 	case c == '/':
99 | 		thisToken = "/"
100 | 		tokenEnd = 1
101 | 	case nextSlash == -1:
102 | 		thisToken = path
103 | 		tokenEnd = len(path)
104 | 	default:
105 | 		thisToken = path[0:nextSlash]
106 | 		tokenEnd = nextSlash
107 | 	}
108 | 	remainingPath := path[tokenEnd:]
109 |
110 | 	switch c {
111 | 	case '*':
112 | 		// Token starts with a *, so it's a catch-all.
113 | 		thisToken = thisToken[1:]
114 |
115 | 		// Reject a trailing segment before touching the tree, so the panic
116 | 		// names the actual problem and a rejected pattern leaves n unchanged.
117 | 		if nextSlash != -1 {
118 | 			panic("/ after catch-all found in " + path)
119 | 		}
120 |
121 | 		if n.catchAllChild == nil {
122 | 			n.catchAllChild = &node{path: thisToken, isCatchAll: true}
123 | 		}
124 |
125 | 		if thisToken != n.catchAllChild.path {
126 | 			panic(fmt.Sprintf("Catch-all name in %s doesn't match %s",
127 | 				path, n.catchAllChild.path))
128 | 		}
129 |
130 | 		wildcards = append(wildcards, thisToken)
131 | 		n.catchAllChild.leafWildcardNames = wildcards
132 |
133 | 		return n.catchAllChild
134 |
135 | 	case ':':
136 | 		// Token starts with a :, so it's a single-segment wildcard.
137 | 		thisToken = thisToken[1:]
138 | 		wildcards = append(wildcards, thisToken)
139 |
140 | 		if n.wildcardChild == nil {
141 | 			n.wildcardChild = &node{path: "wildcard"}
142 | 		}
143 |
144 | 		return n.wildcardChild.addPath(remainingPath, wildcards)
145 | 	}
146 |
147 | 	if strings.ContainsAny(thisToken, ":*") {
148 | 		panic("* or : in middle of path component " + path)
149 | 	}
150 |
151 | 	// Do we have an existing node that starts with the same byte?
152 | 	for i, index := range n.staticIndices {
153 | 		if c == index {
154 | 			// Yes. Split it based on the common prefix of the existing
155 | 			// node and the new one.
156 | 			child, prefixSplit := n.splitCommonPrefix(i, thisToken)
157 | 			child.priority++
158 | 			n.sortStaticChild(i)
159 | 			return child.addPath(path[prefixSplit:], wildcards)
160 | 		}
161 | 	}
162 |
163 | 	// No existing node starts with this byte, so create it.
164 | 	child := &node{path: thisToken}
165 | 	n.staticIndices = append(n.staticIndices, c)
166 | 	n.staticChild = append(n.staticChild, child)
167 | 	return child.addPath(remainingPath, wildcards)
168 | }
169 |
170 | // splitCommonPrefix compares path with staticChild[existingNodeIndex]. If the
171 | // child's path is a prefix of path, the child is returned unchanged; otherwise
172 | // an intermediate node holding the shared prefix replaces it, with the child
173 | // (trimmed to the rest of its path) below it. It returns the node to continue
174 | // from and the number of bytes of path that node consumes.
175 | func (n *node) splitCommonPrefix(existingNodeIndex int, path string) (*node, int) {
176 | 	childNode := n.staticChild[existingNodeIndex]
177 |
178 | 	if strings.HasPrefix(path, childNode.path) {
179 | 		// No split needs to be done. Rather, the new path shares the entire
180 | 		// prefix with the existing node, so the new node is just a child of
181 | 		// the existing one. Or the new path is the same as the existing path,
182 | 		// which means that we just move on to the next token. Either way,
183 | 		// this return accomplishes that
184 | 		return childNode, len(childNode.path)
185 | 	}
186 |
187 | 	// Find the length of the common prefix of the child node and the new path.
188 | 	// Compare bytes, not runes: ranging over the string steps over whole UTF-8
189 | 	// sequences, misses a difference in a continuation byte, and can yield an
190 | 	// empty prefix that sends addPath into unbounded recursion.
191 | 	i := 0
192 | 	for i < len(path) && i < len(childNode.path) && path[i] == childNode.path[i] {
193 | 		i++
194 | 	}
195 |
196 | 	commonPrefix := path[0:i]
197 | 	childNode.path = childNode.path[i:]
198 |
199 | 	// Create a new intermediary node in the place of the existing node, with
200 | 	// the existing node as a child.
201 | 	newNode := &node{
202 | 		path:     commonPrefix,
203 | 		priority: childNode.priority,
204 | 		// Index is the first byte of the non-common part of the path.
205 | 		staticIndices: []byte{childNode.path[0]},
206 | 		staticChild:   []*node{childNode},
207 | 	}
208 | 	n.staticChild[existingNodeIndex] = newNode
209 |
210 | 	return newNode, i
211 | }
212 |
213 | // search matches path against the subtree rooted at n. It returns the node
214 | // holding the matched value and the captured parameter values, innermost
215 | // (last in the pattern) first, or (nil, nil) when nothing matches. Static
216 | // children win over the wildcard child, which wins over the catch-all.
217 | func (n *node) search(path string) (found *node, params []string) {
218 | 	pathLen := len(path)
219 | 	if pathLen == 0 {
220 | 		// The path is used up: it matches only if a value ends here.
221 | 		if n.leafValue == nil {
222 | 			return nil, nil
223 | 		}
224 | 		return n, nil
225 | 	}
226 |
227 | 	// First see if this matches a static token.
228 | 	firstChar := path[0]
229 | 	for i, staticIndex := range n.staticIndices {
230 | 		if staticIndex == firstChar {
231 | 			child := n.staticChild[i]
232 | 			if strings.HasPrefix(path, child.path) {
233 | 				found, params = child.search(path[len(child.path):])
234 | 			}
235 | 			break
236 | 		}
237 | 	}
238 |
239 | 	if found != nil {
240 | 		return found, params
241 | 	}
242 |
243 | 	if n.wildcardChild != nil {
244 | 		// Didn't find a static token, so check for a wildcard, which spans
245 | 		// up to the next slash.
246 | 		nextSlash := strings.IndexByte(path, '/')
247 | 		if nextSlash == -1 {
248 | 			nextSlash = pathLen
249 | 		}
250 |
251 | 		thisToken := path[0:nextSlash]
252 | 		nextToken := path[nextSlash:]
253 |
254 | 		if len(thisToken) > 0 { // Don't match on empty tokens.
255 | 			found, params = n.wildcardChild.search(nextToken)
256 | 			if found != nil {
257 | 				// Values are appended while the recursion unwinds, which is
258 | 				// why the deepest parameter ends up first.
259 | 				return found, append(params, unescapeToken(thisToken))
260 | 			}
261 | 		}
262 | 	}
263 |
264 | 	if n.catchAllChild != nil {
265 | 		// Hit the catchall, so just assign the whole remaining path.
266 | 		return n.catchAllChild, []string{unescapeToken(path)}
267 | 	}
268 |
269 | 	return nil, nil
270 | }
271 |
272 | // unescapeToken percent-decodes a captured path value, falling back to the
273 | // raw text when it is not validly escaped.
274 | // NOTE(review): url.QueryUnescape also turns '+' into a space, which is
275 | // form encoding rather than path encoding; confirm callers rely on that.
276 | func unescapeToken(token string) string {
277 | 	unescaped, err := url.QueryUnescape(token)
278 | 	if err != nil {
279 | 		return token
280 | 	}
281 | 	return unescaped
282 | }
283 |
284 | // dumpTree renders the subtree rooted at n, one node per line, indenting
285 | // each level of children further. nodeType is printed before the path to
286 | // mark wildcard (":") and catch-all ("*") children.
287 | func (n *node) dumpTree(prefix, nodeType string) string {
288 | 	line := fmt.Sprintf("%s %02d %s%s [%d] %v wildcards %v\n", prefix, n.priority, nodeType, n.path,
289 | 		len(n.staticChild), n.leafValue, n.leafWildcardNames)
290 | 	prefix += " "
291 | 	for _, child := range n.staticChild {
292 | 		line += child.dumpTree(prefix, "")
293 | 	}
294 | 	if n.wildcardChild != nil {
295 | 		line += n.wildcardChild.dumpTree(prefix, ":")
296 | 	}
297 | 	if n.catchAllChild != nil {
298 | 		line += n.catchAllChild.dumpTree(prefix, "*")
299 | 	}
300 | 	return line
301 | }
302 |
--------------------------------------------------------------------------------
/treemux/tree_test.go:
--------------------------------------------------------------------------------
1 | package treemux
2 |
3 | import (
4 | 	"net/http"
5 | 	"strings"
6 | 	"testing"
7 | )
8 |
9 | // dummyHandler is a no-op value for tests that only need something to store.
10 | func dummyHandler(w http.ResponseWriter, r *http.Request, urlParams map[string]string) {
11 |
12 | }
13 |
14 | // addPath registers path on tree with a handler that records path under the
15 | // "path" key of the urlParams map it is called with.
16 | func addPath(t *testing.T, tree *node, path string) {
17 | 	t.Logf("Adding path %s", path)
18 | 	n := tree.addPath(path[1:], nil)
19 | 	handler := func(w http.ResponseWriter, r *http.Request, urlParams map[string]string) {
20 | 		urlParams["path"] = path
21 | 	}
22 | 	n.setValue(handler)
23 | }
24 |
25 | // testPath searches tree for path and checks that it matched the pattern
26 | // expectPath ("" meaning no match) with exactly the parameters in
27 | // expectedParams (nil meaning none).
28 | func testPath(t *testing.T, tree *node, path string, expectPath string, expectedParams map[string]string) {
29 | 	if t.Failed() {
30 | 		t.Log(tree.dumpTree("", " "))
31 | 		t.FailNow()
32 | 	}
33 |
34 | 	expectCatchAll := strings.Contains(expectPath, "/*")
35 |
36 | 	t.Log("Testing", path)
37 | 	n, paramList := tree.search(path[1:])
38 | 	switch {
39 | 	case expectPath != "" && n == nil:
40 | 		t.Errorf("No match for %s, expected %s", path, expectPath)
41 | 		return
42 | 	case expectPath == "" && n != nil:
43 | 		t.Errorf("Expected no match for %s but got %v with params %v", path, n, paramList)
44 | 		t.Error("Node and subtree was\n" + n.dumpTree("", " "))
45 | 		return
46 | 	case n == nil:
47 | 		return
48 | 	}
49 |
50 | 	if expectCatchAll != n.isCatchAll {
51 | 		t.Errorf("For path %s expectCatchAll %v but saw %v", path, expectCatchAll, n.isCatchAll)
52 | 	}
53 |
54 | 	handler, ok := n.leafValue.(func(http.ResponseWriter, *http.Request, map[string]string))
55 | 	if !ok {
56 | 		t.Errorf("Path %s returned node without handler", path)
57 | 		t.Error("Node and subtree was\n" + n.dumpTree("", " "))
58 | 		return
59 | 	}
60 |
61 | 	pathMap := make(map[string]string)
62 | 	handler(nil, nil, pathMap)
63 | 	matchedPath := pathMap["path"]
64 |
65 | 	if matchedPath != expectPath {
66 | 		t.Errorf("Path %s matched %s, expected %s", path, matchedPath, expectPath)
67 | 		t.Error("Node and subtree was\n" + n.dumpTree("", " "))
68 | 	}
69 |
70 | 	if expectedParams == nil {
71 | 		if len(paramList) != 0 {
72 | 			t.Errorf("Path %s expected no parameters, saw %v", path, paramList)
73 | 		}
74 | 		return
75 | 	}
76 |
77 | 	if len(paramList) != len(n.leafWildcardNames) {
78 | 		t.Errorf("Got %d params back but node specifies %d",
79 | 			len(paramList), len(n.leafWildcardNames))
80 | 		if len(paramList) > len(n.leafWildcardNames) {
81 | 			return // pairing values with names below would index out of range
82 | 		}
83 | 	}
84 |
85 | 	// search returns values innermost-first; names are stored in pattern order.
86 | 	params := map[string]string{}
87 | 	for i, value := range paramList {
88 | 		params[n.leafWildcardNames[len(paramList)-i-1]] = value
89 | 	}
90 | 	t.Log("\tGot params", params)
91 |
92 | 	for key, val := range expectedParams {
93 | 		sawVal, ok := params[key]
94 | 		if !ok {
95 | 			t.Errorf("Path %s matched without key %s", path, key)
96 | 		} else if sawVal != val {
97 | 			t.Errorf("Path %s expected param %s to be %s, saw %s", path, key, val, sawVal)
98 | 		}
99 |
100 | 		delete(params, key)
101 | 	}
102 |
103 | 	for key, val := range params {
104 | 		t.Errorf("Path %s returned unexpected param %s=%s", path, key, val)
105 | 	}
106 | }
107 |
108 | // checkHandlerNodes reports every node in the subtree rooted at n that has
109 | // wildcard names but no stored value.
110 | func checkHandlerNodes(t *testing.T, n *node) {
111 | 	if len(n.leafWildcardNames) != 0 && n.leafValue == nil {
112 | 		t.Errorf("Node %s has wildcards without handlers", n.path)
113 | 	}
114 | 	for _, child := range n.staticChild {
115 | 		checkHandlerNodes(t, child)
116 | 	}
117 | 	if n.wildcardChild != nil {
118 | 		checkHandlerNodes(t, n.wildcardChild)
119 | 	}
120 | 	if n.catchAllChild != nil {
121 | 		checkHandlerNodes(t, n.catchAllChild)
122 | 	}
123 | }
124 |
125 | func TestTree(t *testing.T) {
126 | 	tree := &node{path: "/"}
127 |
128 | 	addPath(t, tree, "/")
129 | 	addPath(t, tree, "/i")
130 | 	addPath(t, tree, "/i/:aaa")
131 | 	addPath(t, tree, "/images")
132 | 	addPath(t, tree, "/images/abc.jpg")
133 | 	addPath(t, tree, "/images/:imgname")
134 | 	addPath(t, tree, "/images/*path")
135 | 	addPath(t, tree, "/ima")
136 | 	addPath(t, tree, "/ima/:par")
137 | 	addPath(t, tree, "/images1")
138 | 	addPath(t, tree, "/images2")
139 | 	addPath(t, tree, "/apples")
140 | 	addPath(t, tree, "/app/les")
141 | 	addPath(t, tree, "/apples1")
142 | 	addPath(t, tree, "/appeasement")
143 | 	addPath(t, tree, "/appealing")
144 | 	addPath(t, tree, "/date/:year/:month")
145 | 	addPath(t, tree, "/date/:year/month")
146 | 	addPath(t, tree, "/date/:year/:month/abc")
147 | 	addPath(t, tree, "/date/:year/:month/:post")
148 | 	addPath(t, tree, "/date/:year/:month/*post")
149 | 	addPath(t, tree, "/:page")
150 | 	addPath(t, tree, "/:page/:index")
151 | 	addPath(t, tree, "/post/:post/page/:page")
152 | 	addPath(t, tree, "/plaster")
153 | 	addPath(t, tree, "/users/:pk/:related")
154 | 	addPath(t, tree, "/users/:id/updatePassword")
155 | 	addPath(t, tree, "/:something/abc")
156 | 	addPath(t, tree, "/:something/def")
157 |
158 | 	testPath(t, tree, "/users/abc/updatePassword", "/users/:id/updatePassword",
159 | 		map[string]string{"id": "abc"})
160 | 	testPath(t, tree, "/users/all/something", "/users/:pk/:related",
161 | 		map[string]string{"pk": "all", "related": "something"})
162 |
163 | 	testPath(t, tree, "/aaa/abc", "/:something/abc",
164 | 		map[string]string{"something": "aaa"})
165 | 	testPath(t, tree, "/aaa/def", "/:something/def",
166 | 		map[string]string{"something": "aaa"})
167 |
168 | 	testPath(t, tree, "/paper", "/:page",
169 | 		map[string]string{"page": "paper"})
170 |
171 | 	testPath(t, tree, "/", "/", nil)
172 | 	testPath(t, tree, "/i", "/i", nil)
173 | 	testPath(t, tree, "/images", "/images", nil)
174 | 	testPath(t, tree, "/images/abc.jpg", "/images/abc.jpg", nil)
175 | 	testPath(t, tree, "/images/something", "/images/:imgname",
176 | 		map[string]string{"imgname": "something"})
177 | 	testPath(t, tree, "/images/long/path", "/images/*path",
178 | 		map[string]string{"path": "long/path"})
179 | 	testPath(t, tree, "/images/even/longer/path", "/images/*path",
180 | 		map[string]string{"path": "even/longer/path"})
181 | 	testPath(t, tree, "/ima", "/ima", nil)
182 | 	testPath(t, tree, "/apples", "/apples", nil)
183 | 	testPath(t, tree, "/app/les", "/app/les", nil)
184 | 	testPath(t, tree, "/abc", "/:page",
185 | 		map[string]string{"page": "abc"})
186 | 	testPath(t, tree, "/abc/100", "/:page/:index",
187 | 		map[string]string{"page": "abc", "index": "100"})
188 | 	testPath(t, tree, "/post/a/page/2", "/post/:post/page/:page",
189 | 		map[string]string{"post": "a", "page": "2"})
190 | 	testPath(t, tree, "/date/2014/5", "/date/:year/:month",
191 | 		map[string]string{"year": "2014", "month": "5"})
192 | 	testPath(t, tree, "/date/2014/month", "/date/:year/month",
193 | 		map[string]string{"year": "2014"})
194 | 	testPath(t, tree, "/date/2014/5/abc", "/date/:year/:month/abc",
195 | 		map[string]string{"year": "2014", "month": "5"})
196 | 	testPath(t, tree, "/date/2014/5/def", "/date/:year/:month/:post",
197 | 		map[string]string{"year": "2014", "month": "5", "post": "def"})
198 | 	testPath(t, tree, "/date/2014/5/def/hij", "/date/:year/:month/*post",
199 | 		map[string]string{"year": "2014", "month": "5", "post": "def/hij"})
200 | 	testPath(t, tree, "/date/2014/5/def/hij/", "/date/:year/:month/*post",
201 | 		map[string]string{"year": "2014", "month": "5", "post": "def/hij/"})
202 |
203 | 	testPath(t, tree, "/date/2014/ab%2f", "/date/:year/:month",
204 | 		map[string]string{"year": "2014", "month": "ab/"})
205 | 	testPath(t, tree, "/post/ab%2fdef/page/2%2f", "/post/:post/page/:page",
206 | 		map[string]string{"post": "ab/def", "page": "2/"})
207 |
208 | 	testPath(t, tree, "/ima/bcd/fgh", "", nil)
209 | 	testPath(t, tree, "/date/2014//month", "", nil)
210 | 	testPath(t, tree, "/date/2014/05/", "", nil) // Empty catchall should not match
211 | 	testPath(t, tree, "/post//abc/page/2", "", nil)
212 | 	testPath(t, tree, "/post/abc//page/2", "", nil)
213 | 	testPath(t, tree, "/post/abc/page//2", "", nil)
214 | 	testPath(t, tree, "//post/abc/page/2", "", nil)
215 | 	testPath(t, tree, "//post//abc//page//2", "", nil)
216 |
217 | 	t.Log("Test retrieval of duplicate paths")
218 | 	params := make(map[string]string)
219 | 	p := "date/:year/:month/abc"
220 | 	n := tree.addPath(p, nil)
221 | 	if n == nil {
222 | 		t.Errorf("Duplicate add of %s didn't return a node", p)
223 | 	} else {
224 | 		handler, ok := n.leafValue.(func(http.ResponseWriter, *http.Request, map[string]string))
225 | 		matchPath := ""
226 | 		if ok {
227 | 			handler(nil, nil, params)
228 | 			matchPath = params["path"]
229 | 		}
230 |
231 | 		if len(matchPath) < 2 || matchPath[1:] != p {
232 | 			t.Errorf("Duplicate add of %s returned node for %s\n%s", p, matchPath,
233 | 				n.dumpTree("", " "))
234 | 		}
235 | 	}
236 |
237 | 	checkHandlerNodes(t, tree)
238 |
239 | 	t.Log(tree.dumpTree("", " "))
240 | }
241 |
242 | // TestNonASCIIPrefixSplit covers static paths whose first differing byte is
243 | // inside a multi-byte UTF-8 sequence; adding the second path of each pair
244 | // used to make addPath recurse forever.
245 | func TestNonASCIIPrefixSplit(t *testing.T) {
246 | 	tree := &node{path: "/"}
247 |
248 | 	addPath(t, tree, "/日本")
249 | 	addPath(t, tree, "/日曜")
250 | 	addPath(t, tree, "/é")
251 | 	addPath(t, tree, "/è")
252 |
253 | 	testPath(t, tree, "/日本", "/日本", nil)
254 | 	testPath(t, tree, "/日曜", "/日曜", nil)
255 | 	testPath(t, tree, "/é", "/é", nil)
256 | 	testPath(t, tree, "/è", "/è", nil)
257 | 	testPath(t, tree, "/日", "", nil)
258 |
259 | 	checkHandlerNodes(t, tree)
260 | }
261 |
262 | // TestTreeMuxEmptyPath checks that TreeMux rejects an empty pattern with a
263 | // panic and treats an empty lookup path as a miss instead of panicking.
264 | func TestTreeMuxEmptyPath(t *testing.T) {
265 | 	mux := New()
266 | 	mux.Set("/", "root")
267 |
268 | 	if v, params := mux.Get(""); v != nil || params != nil {
269 | 		t.Errorf("Get(\"\") = %v, %v; want no match", v, params)
270 | 	}
271 | 	if v, _ := mux.Get("/"); v != "root" {
272 | 		t.Errorf("Get(\"/\") = %v, want root", v)
273 | 	}
274 |
275 | 	func() {
276 | 		defer func() {
277 | 			if recover() == nil {
278 | 				t.Error("Expected panic when setting an empty path")
279 | 			}
280 | 		}()
281 | 		mux.Set("", "empty")
282 | 	}()
283 | }
284 |
285 | func TestPanics(t *testing.T) {
286 | 	sawPanic := false
287 |
288 | 	panicHandler := func() {
289 | 		if err := recover(); err != nil {
290 | 			sawPanic = true
291 | 		}
292 | 	}
293 |
294 | 	addPathPanic := func(p ...string) {
295 | 		sawPanic = false
296 | 		defer panicHandler()
297 | 		tree := &node{path: "/"}
298 | 		for _, path := range p {
299 | 			tree.addPath(path, nil)
300 | 		}
301 | 	}
302 |
303 | 	addPathPanic("abc/*path/")
304 | 	if !sawPanic {
305 | 		t.Error("Expected panic with slash after catch-all")
306 | 	}
307 |
308 | 	addPathPanic("abc/*path/def")
309 | 	if !sawPanic {
310 | 		t.Error("Expected panic with path segment after catch-all")
311 | 	}
312 |
313 | 	addPathPanic("abc/*path", "abc/*paths")
314 | 	if !sawPanic {
315 | 		t.Error("Expected panic when adding conflicting catch-alls")
316 | 	}
317 |
318 | 	func() {
319 | 		sawPanic = false
320 | 		defer panicHandler()
321 | 		tree := &node{path: "/"}
322 | 		tree.setValue(dummyHandler)
323 | 		tree.setValue(dummyHandler)
324 | 	}()
325 | 	if !sawPanic {
326 | 		t.Error("Expected panic when adding a duplicate handler for a pattern")
327 | 	}
328 |
329 | 	addPathPanic("abc/ab:cd")
330 | 	if !sawPanic {
331 | 		t.Error("Expected panic with : in middle of path segment")
332 | 	}
333 |
334 | 	addPathPanic("abc/ab", "abc/ab:cd")
335 | 	if !sawPanic {
336 | 		t.Error("Expected panic with : in middle of path segment with existing path")
337 | 	}
338 |
339 | 	addPathPanic("abc/ab*cd")
340 | 	if !sawPanic {
341 | 		t.Error("Expected panic with * in middle of path segment")
342 | 	}
343 |
344 | 	addPathPanic("abc/ab", "abc/ab*cd")
345 | 	if !sawPanic {
346 | 		t.Error("Expected panic with * in middle of path segment with existing path")
347 | 	}
348 |
349 | 	twoPathPanic := func(first, second string) {
350 | 		addPathPanic(first, second)
351 | 		if !sawPanic {
352 | 			t.Errorf("Expected panic with ambiguous wildcards on paths %s and %s", first, second)
353 | 		}
354 | 	}
355 |
356 | 	twoPathPanic("abc/:ab/def/:cd", "abc/:ad/def/:cd")
357 | 	twoPathPanic("abc/:ab/def/:cd", "abc/:ab/def/:ef")
358 | 	twoPathPanic(":abc", ":def")
359 | 	twoPathPanic(":abc/ggg", ":def/ggg")
360 | }
361 |
362 | func BenchmarkTreeNullRequest(b *testing.B) {
363 | 	b.ReportAllocs()
364 | 	tree := &node{path: "/"}
365 |
366 | 	b.ResetTimer()
367 | 	for i := 0; i < b.N; i++ {
368 | 		tree.search("")
369 | 	}
370 | }
371 |
372 | func BenchmarkTreeOneStatic(b *testing.B) {
373 | 	b.ReportAllocs()
374 | 	tree := &node{path: "/"}
375 | 	tree.addPath("abc", nil)
376 |
377 | 	b.ResetTimer()
378 | 	for i := 0; i < b.N; i++ {
379 | 		tree.search("abc")
380 | 	}
381 | }
382 |
383 | func BenchmarkTreeOneParam(b *testing.B) {
384 | 	b.ReportAllocs()
385 | 	tree := &node{path: "/"}
386 | 	tree.addPath(":abc", nil)
387 |
388 | 	b.ResetTimer()
389 | 	for i := 0; i < b.N; i++ {
390 | 		tree.search("abc")
391 | 	}
392 | }
393 |
--------------------------------------------------------------------------------