├── .gitignore ├── .travis.yml ├── examples ├── 2-advanced │ ├── oauth2-client-id.gif │ ├── config.json │ ├── static │ │ ├── error.tpl.html │ │ ├── landing-page.tpl.html │ │ ├── home.tpl.html │ │ └── main.js │ ├── README.md │ └── main.go ├── 0-helloworld │ └── main.go └── 1-simple │ ├── README.md │ └── main.go ├── go.mod ├── fs_test.go ├── fs.go ├── LICENSE ├── chain ├── benchmark_reflect_call_test.go ├── codegen_test.go ├── naming_test.go ├── ordinal.go ├── util.go ├── example_test.go ├── codegen.go ├── naming.go ├── chain_test.go └── chain.go ├── gzip_test.go ├── response_writer.go ├── go.sum ├── gzip.go ├── example_test.go ├── wrap.go ├── errors.go ├── logger.go ├── benchmark_simple_handler_test.go ├── logger_test.go ├── doc.go ├── router_test.go ├── README.md └── router.go /.gitignore: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | go: 3 | - 1.19 4 | - master 5 | notifications: 6 | email: 7 | on_failure: change 8 | -------------------------------------------------------------------------------- /examples/2-advanced/oauth2-client-id.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/augustoroman/sandwich/HEAD/examples/2-advanced/oauth2-client-id.gif -------------------------------------------------------------------------------- /examples/2-advanced/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "host": "localhost:5000", 3 | "port": 5000, 4 | "cookie-secret": "123456789012345678901234", 5 | "oauth2-client-id": "see README.md", 6 | "oauth2-client-secret": "see README.md" 7 | } 8 | -------------------------------------------------------------------------------- /examples/2-advanced/static/error.tpl.html: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Failed 5 | 6 | 7 | OMG, something went horribly wrong.

8 | If this were real, this would be a beautifully-styled fail page.

9 | Your error is: 10 |

({{.Error.Code}}) - {{.Error.ClientMsg}}
11 | 12 | 13 | -------------------------------------------------------------------------------- /examples/0-helloworld/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "net/http" 7 | 8 | "github.com/augustoroman/sandwich" 9 | ) 10 | 11 | func main() { 12 | mux := sandwich.TheUsual() 13 | mux.Get("/", func(w http.ResponseWriter) { 14 | fmt.Fprintf(w, "Hello world!") 15 | }) 16 | if err := http.ListenAndServe(":8080", mux); err != nil { 17 | log.Fatal(err) 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /examples/1-simple/README.md: -------------------------------------------------------------------------------- 1 | # Basic sandwich usage 2 | 3 | This example demonstrates most of the basic features of sandwich, including: 4 | 5 | * Providing user types to the middleware chain 6 | * Adding middleware handlers to the stack. 7 | * Writing handlers that provide request-scoped values. 8 | * Writing handlers using injected values. 9 | * Using the default sandwich logging system. 10 | * Using the default sandwich error system. 
11 | 12 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/augustoroman/sandwich 2 | 3 | go 1.19 4 | 5 | require ( 6 | github.com/bradrydzewski/go.auth v0.0.0-20130828171325-d0051b5cc538 7 | github.com/stretchr/testify v1.7.0 8 | ) 9 | 10 | require ( 11 | github.com/davecgh/go-spew v1.1.1 // indirect 12 | github.com/dchest/authcookie v0.0.0-20190824115100-f900d2294c8e // indirect 13 | github.com/pmezard/go-difflib v1.0.0 // indirect 14 | gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect 15 | ) 16 | -------------------------------------------------------------------------------- /fs_test.go: -------------------------------------------------------------------------------- 1 | package sandwich 2 | 3 | import ( 4 | "embed" 5 | "io/fs" 6 | "net/http/httptest" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | //go:embed examples 14 | var examples embed.FS 15 | 16 | func TestServeFS(t *testing.T) { 17 | helloworld := ServeFS(examples, "examples/0-helloworld", "path") 18 | w := httptest.NewRecorder() 19 | helloworld(w, httptest.NewRequest("", "/foo", nil), Params{"path": "main.go"}) 20 | contents, err := fs.ReadFile(examples, "examples/0-helloworld/main.go") 21 | require.NoError(t, err) 22 | assert.Equal(t, string(contents), w.Body.String()) 23 | } 24 | -------------------------------------------------------------------------------- /fs.go: -------------------------------------------------------------------------------- 1 | package sandwich 2 | 3 | import ( 4 | "io/fs" 5 | "net/http" 6 | ) 7 | 8 | // ServeFS is a simple helper that will serve static files from an fs.FS 9 | // filesystem. It allows serving files identified by a sandwich path parameter 10 | // out of a subdirectory of the filesystem. 
This is especially useful when 11 | // embedding static files: 12 | // 13 | // //go:embed static dist 14 | // var allFiles embed.FS 15 | // 16 | // mux.Get("/css/:path*", sandwich.ServeFS(allFiles, "static/css", "path")) 17 | // mux.Get("/js/:path*", sandwich.ServeFS(allFiles, "dist/js", "path")) 18 | // mux.Get("/i/:path*", sandwich.ServeFS(allFiles, "static/images", "path")) 19 | // 20 | // ServeFS panics if fs.Sub rejects fsRoot (for example, when fsRoot is not a 21 | // valid path according to fs.ValidPath). The request passed to the returned 22 | // handler is not modified: the file server receives a clone whose URL path is 23 | // replaced with the value of the pathParam route parameter. 24 | func ServeFS( 25 | f fs.FS, 26 | fsRoot string, 27 | pathParam string, 28 | ) func(w http.ResponseWriter, r *http.Request, p Params) { 29 | sub, err := fs.Sub(f, fsRoot) 30 | if err != nil { 31 | panic(err) 32 | } 33 | handler := http.FileServer(http.FS(sub)) 34 | return func(w http.ResponseWriter, r *http.Request, p Params) { 35 | // Handlers should not modify the provided request, and deferred middleware 36 | // may still read r.URL after this returns, so serve from a copy instead. 37 | fileReq := r.Clone(r.Context()) 38 | fileReq.URL.Path = p[pathParam] 39 | handler.ServeHTTP(w, fileReq) 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /examples/2-advanced/static/landing-page.tpl.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Welcome to a sample todo app 5 | 13 | 14 | 15 | Please sign in to start making your todo list: 16 | . 29 | 30 | 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2016 Augusto Roman 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software.
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /chain/benchmark_reflect_call_test.go: -------------------------------------------------------------------------------- 1 | package chain 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | ) 7 | 8 | func twoArg(a, b string) string { 9 | return a 10 | } 11 | 12 | func manyArg(a, b, c, d, e, f, g string) (s, t, u, v, w, y, z string) { 13 | return "s", "t", "u", "v", "w", "y", "z" 14 | } 15 | 16 | func BenchmarkReflectCall_TwoArgs(b *testing.B) { 17 | arg1, arg2 := reflect.ValueOf("foo"), reflect.ValueOf("bar") 18 | fn := reflect.ValueOf(twoArg) 19 | args := []reflect.Value{arg1, arg2} 20 | for i := 0; i < b.N; i++ { 21 | fn.Call(args) 22 | } 23 | } 24 | func BenchmarkDirectCall_TwoArgs(b *testing.B) { 25 | arg1, arg2 := "foo", "bar" 26 | for i := 0; i < b.N; i++ { 27 | twoArg(arg1, arg2) 28 | } 29 | } 30 | 31 | func BenchmarkReflectCall_ManyArgs(b *testing.B) { 32 | arg := reflect.ValueOf("arg") 33 | fn := reflect.ValueOf(manyArg) 34 | args := []reflect.Value{arg, arg, arg, arg, arg, arg, arg} 35 | for i := 0; i < b.N; i++ { 36 | fn.Call(args) 37 | } 38 | } 39 | func BenchmarkDirectCall_ManyArgs(bb *testing.B) { 40 | a, b, c, d, e, f, g := "a", "b", "c", "d", "e", "f", "g" 41 | for i := 0; i < bb.N; i++ { 42 | manyArg(a, b, c, d, e, f, g) 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /examples/2-advanced/static/home.tpl.html: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | TODO List 5 | 19 | 20 | 21 | Hello {{.User.Name}} 22 |
23 | [ 24 | Normal handlers | 25 | Sign out 26 | ] 27 |

28 | 29 | 30 | Here's your TODO list: 31 | 36 |
37 | Add a new task: 38 |
39 | Description: 40 | 41 |
42 | 43 |
44 | 45 | 46 | 47 | 48 | -------------------------------------------------------------------------------- /examples/2-advanced/README.md: -------------------------------------------------------------------------------- 1 | # Advanced sandwich usage: Simple TODO app. 2 | 3 | This example demonstrates more advanced features of sandwich, including: 4 | 5 | * Providing interface types to the middleware chain.
6 | TaskDb is the interface provided to the handlers; the actual value injected 7 | in main() is a taskDbImpl. 8 | * Using 3rd party middleware (go.auth, go.rice) 9 | * Using a 3rd party router (gorilla/mux) 10 | * Using multiple error handlers and custom error handlers.
11 | Most web servers will want to serve a custom HTML error page for user-facing 12 | errors. An example of that is included here. For AJAX calls, however, 13 | we don't want to serve HTML. Instead, we always respond with JSON using sandwich's JSON error handler (`sandwich.HandleErrorJson`). With 14 | sandwich, the errors returned from handlers are format-agnostic. Instead, the 15 | error handler decides what format to respond in. 16 | * Early exit of the middleware chain via the sandwich.Done error
17 | See `CheckForFakeLogin()` for usage. 18 | 19 | ## Google Login (Oauth2 Authentication) 20 | 21 | In order to use the Google login, you need an Oauth2 client ID & client secret. 22 | See this animation for an example of getting these values: 23 | 24 | ![oauth2 client id demo](oauth2-client-id.gif) 25 | 26 | More documentation is available at https://developers.google.com/identity/protocols/OAuth2WebServer 27 | -------------------------------------------------------------------------------- /chain/codegen_test.go: -------------------------------------------------------------------------------- 1 | package chain 2 | 3 | import ( 4 | "bytes" 5 | "net/http" 6 | "regexp" 7 | "strings" 8 | "testing" 9 | ) 10 | 11 | func normalizeWhitespace(s string) string { 12 | return regexp.MustCompile(`\s+`).ReplaceAllLiteralString(strings.TrimSpace(s), " ") 13 | } 14 | 15 | type TestDb struct{} 16 | 17 | func (t *TestDb) Validate(s string) {} 18 | 19 | func TestCodeGen(t *testing.T) { 20 | var buf bytes.Buffer 21 | 22 | type User struct{} 23 | type Http struct{} 24 | 25 | New(). 26 | Arg((*http.ResponseWriter)(nil)). 27 | Set(""). 28 | Set(int64(0)). 29 | Set(int(1)). 30 | Set((*User)(nil)). 31 | Set(User{}). 32 | Set(Http{}). 33 | Set(&TestDb{}). 34 | Then((*TestDb).Validate). 35 | Then(a, b, c). 
36 | Code("foo", "chain", &buf) 37 | 38 | const expected = `func foo( 39 | str string, 40 | i64 int64, 41 | i int, 42 | pUser *User, 43 | user User, 44 | chain_Http Http, 45 | pTestDb *TestDb, 46 | ) func( 47 | rw http.ResponseWriter, 48 | ) { 49 | return func( 50 | rw http.ResponseWriter, 51 | ) { 52 | (*TestDb).Validate(pTestDb, str) 53 | 54 | str = a() 55 | 56 | str, i = b(str) 57 | 58 | c(str, i) 59 | 60 | } 61 | }` 62 | if normalizeWhitespace(buf.String()) != normalizeWhitespace(expected) { 63 | t.Errorf("Wrong code generated: %s\nExp: %q\nGot: %q", buf.String(), 64 | normalizeWhitespace(expected), normalizeWhitespace(buf.String())) 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /gzip_test.go: -------------------------------------------------------------------------------- 1 | package sandwich 2 | 3 | import ( 4 | "compress/gzip" 5 | "fmt" 6 | "io" 7 | "net/http" 8 | "net/http/httptest" 9 | "testing" 10 | ) 11 | 12 | func TestGzip(t *testing.T) { 13 | greet := func(w http.ResponseWriter, r *http.Request) { 14 | fmt.Fprintf(w, "Hi there!") 15 | } 16 | handler := BuildYourOwn() 17 | handler.Use(Gzip) 18 | handler.Any("/", greet) 19 | 20 | resp := httptest.NewRecorder() 21 | req, _ := http.NewRequest("GET", "/", nil) 22 | req.Header.Add(headerAcceptEncoding, "gzip") 23 | handler.ServeHTTP(resp, req) 24 | 25 | if resp.Header().Get(headerContentEncoding) != "gzip" { 26 | t.Errorf("Not gzip'd? Content-encoding: %q", resp.Header()) 27 | } 28 | 29 | if resp.Header().Get(headerContentLength) != "" { 30 | t.Errorf("Not supposed to include content-length: %q", resp.Header()) 31 | } 32 | 33 | r, err := gzip.NewReader(resp.Body) 34 | if err != nil { 35 | t.Fatal(err) 36 | } 37 | defer r.Close() 38 | if body, err := io.ReadAll(r); err != nil { 39 | t.Fatal(err) 40 | } else if string(body) != "Hi there!" 
{ 41 | t.Errorf("Wrong response: %q", string(body)) 42 | } 43 | 44 | // Also, test without the accept header and make sure it's NOT gzip'd. 45 | resp = httptest.NewRecorder() 46 | req, _ = http.NewRequest("GET", "/", nil) 47 | handler.ServeHTTP(resp, req) 48 | if resp.Header().Get(headerContentEncoding) == "gzip" { 49 | t.Errorf("Unexpectedly gzip'd: Content-encoding: %q", resp.Header()) 50 | } 51 | if resp.Body.String() != "Hi there!" { 52 | t.Errorf("Wrong response: %q", resp.Body.String()) 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /response_writer.go: -------------------------------------------------------------------------------- 1 | package sandwich 2 | 3 | import ( 4 | "bufio" 5 | "fmt" 6 | "net" 7 | "net/http" 8 | ) 9 | 10 | // WrapResponseWriter creates a ResponseWriter and returns it as both an 11 | // http.ResponseWriter and a *ResponseWriter. The double return is redundant 12 | // for native Go code, but is a necessary hint to the dependency injection. 13 | func WrapResponseWriter(w http.ResponseWriter) (http.ResponseWriter, *ResponseWriter) { 14 | rw := &ResponseWriter{w, 0, 0} 15 | return rw, rw 16 | } 17 | 18 | // ResponseWriter wraps http.ResponseWriter to add tracking of the response size 19 | // and response code. 20 | type ResponseWriter struct { 21 | http.ResponseWriter 22 | Size int // The size of the response written so far, in bytes. 23 | Code int // The status code of the response, or 0 if not written yet. 
24 | } 25 | 26 | func (w *ResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { 27 | hijacker, ok := w.ResponseWriter.(http.Hijacker) 28 | if !ok { 29 | return nil, nil, fmt.Errorf("the ResponseWriter doesn't support the Hijacker interface") 30 | } 31 | return hijacker.Hijack() 32 | } 33 | 34 | func (w *ResponseWriter) Flush() { 35 | flusher, ok := w.ResponseWriter.(http.Flusher) 36 | if ok { 37 | flusher.Flush() 38 | } 39 | } 40 | 41 | func (w *ResponseWriter) WriteHeader(code int) { 42 | if w.Code == 0 { 43 | w.Code = code 44 | } 45 | w.ResponseWriter.WriteHeader(code) 46 | } 47 | 48 | func (w *ResponseWriter) Write(p []byte) (int, error) { 49 | if w.Code == 0 { 50 | w.Code = 200 51 | } 52 | n, err := w.ResponseWriter.Write(p) 53 | w.Size += n 54 | return n, err 55 | } 56 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/bradrydzewski/go.auth v0.0.0-20130828171325-d0051b5cc538 h1:xdNrK3humN0dHUoArlWMrkDUBcxtDHlJ7196exYTbsI= 2 | github.com/bradrydzewski/go.auth v0.0.0-20130828171325-d0051b5cc538/go.mod h1:uPPs9JS166ZU50k/tsHQ5Dux4qOBjsJyD2zrU0WaUbk= 3 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 5 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 6 | github.com/dchest/authcookie v0.0.0-20190824115100-f900d2294c8e h1:xizeG5ksKSdyNaom2//2Bow4hLWqXkCql36nrL9iEUI= 7 | github.com/dchest/authcookie v0.0.0-20190824115100-f900d2294c8e/go.mod h1:x7AK2h2QzaXVEFi1tbMYMDuvHcCEr1QdMDrg3hkW24Q= 8 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 9 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 10 | github.com/stretchr/objx v0.1.0/go.mod 
h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 11 | github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= 12 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 13 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 14 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 15 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 16 | gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= 17 | gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 18 | -------------------------------------------------------------------------------- /gzip.go: -------------------------------------------------------------------------------- 1 | package sandwich 2 | 3 | import ( 4 | "compress/gzip" 5 | "net/http" 6 | "strings" 7 | ) 8 | 9 | const ( 10 | headerAcceptEncoding = "Accept-Encoding" 11 | headerContentEncoding = "Content-Encoding" 12 | headerContentLength = "Content-Length" 13 | headerContentType = "Content-Type" 14 | headerVary = "Vary" 15 | ) 16 | 17 | // Gzip wraps a sandwich.Middleware to add gzip compression to the output for 18 | // all subsequent handlers. 19 | // 20 | // For example, to gzip everything you could use: 21 | // 22 | // router.Use(sandwich.Gzip) 23 | // ...use as normal... 24 | // 25 | // Or, to gzip just a particular route you could do: 26 | // 27 | // router.Get("/foo/bar", sandwich.Gzip, MyHandleFooBar) 28 | // 29 | // Note that this does NOT auto-detect the content and disable compression for 30 | // already-compressed data (e.g. jpg images). 
31 | var Gzip = Wrap{provideGZipWriter, (*gZipWriter).Flush} 32 | 33 | func provideGZipWriter(w http.ResponseWriter, r *http.Request) (http.ResponseWriter, *gZipWriter) { 34 | if !strings.Contains(r.Header.Get(headerAcceptEncoding), "gzip") { 35 | return w, nil 36 | } 37 | headers := w.Header() 38 | headers.Set(headerContentEncoding, "gzip") 39 | headers.Set(headerVary, headerAcceptEncoding) 40 | 41 | wr := &gZipWriter{w, gzip.NewWriter(w)} 42 | return wr, wr 43 | } 44 | 45 | type gZipWriter struct { 46 | http.ResponseWriter 47 | w *gzip.Writer 48 | } 49 | 50 | func (g *gZipWriter) Write(p []byte) (int, error) { 51 | if len(g.Header().Get(headerContentType)) == 0 { 52 | g.Header().Set(headerContentType, http.DetectContentType(p)) 53 | } 54 | return g.w.Write(p) 55 | } 56 | 57 | func (g *gZipWriter) Flush() { 58 | g.Header().Del(headerContentLength) 59 | g.w.Close() 60 | } 61 | -------------------------------------------------------------------------------- /chain/naming_test.go: -------------------------------------------------------------------------------- 1 | package chain 2 | 3 | import ( 4 | "net/http" 5 | "reflect" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestNameMapper(t *testing.T) { 12 | var n nameMapper 13 | 14 | assert.Equal(t, "i64", n.For(reflect.TypeOf(int64(0))), "int64") 15 | assert.Equal(t, "i", n.For(reflect.TypeOf(int(0))), "int") 16 | assert.Equal(t, "u", n.For(reflect.TypeOf(uint(0))), "uint") 17 | assert.Equal(t, "f32", n.For(reflect.TypeOf(float32(0))), "float32") 18 | assert.Equal(t, "f64", n.For(reflect.TypeOf(float64(0))), "float64") 19 | assert.Equal(t, "flag", n.For(reflect.TypeOf(false)), "bool") 20 | 21 | assert.Equal(t, "rw", n.For(reflect.TypeOf((*http.ResponseWriter)(nil)).Elem()), "http.ResponseWriter") 22 | assert.Equal(t, "req", n.For(reflect.TypeOf((*http.Request)(nil))), "*http.Request") 23 | 24 | type Req struct{} 25 | assert.Equal(t, "chain_Req", n.For(reflect.TypeOf(Req{})), "chain.Req 
(could've been req, but that's taken)") 26 | assert.Equal(t, "pReq", n.For(reflect.TypeOf(&Req{})), "*chain.Req") 27 | assert.Equal(t, "pppReq", n.For(reflect.TypeOf((***Req)(nil))), "***chain.Req") 28 | 29 | var c map[string]struct { 30 | A []byte 31 | B chan bool 32 | } 33 | assert.Equal(t, "map_string_struct_A_uint8_B_chan_bool", n.For(reflect.TypeOf(c)), 34 | "crazy inlined struct") 35 | 36 | var d map[string]struct { 37 | A_uint8_B chan bool 38 | } 39 | assert.Equal(t, "__var12__", n.For(reflect.TypeOf(d)), 40 | "inlined struct with var name that conflicts") 41 | 42 | assert.Equal(t, "u8", n.For(reflect.TypeOf(byte(0))), "single byte") 43 | assert.Equal(t, "sliceOfUint8", n.For(reflect.TypeOf([]byte{})), "byte slice") 44 | assert.Equal(t, "sliceOfInt", n.For(reflect.TypeOf([]int{})), "int slice") 45 | } 46 | -------------------------------------------------------------------------------- /examples/2-advanced/static/main.js: -------------------------------------------------------------------------------- 1 | function addTask() { 2 | var el = document.querySelector('input[name=desc]'); 3 | var data = { desc: el.value }; 4 | el.value = ""; 5 | var xhr = { 6 | url: 'api/task', 7 | method: 'POST', 8 | credentials: 'include', 9 | body: JSON.stringify(data), 10 | cache: 'no-cache', 11 | }; 12 | fetch(xhr.url, xhr).then(response => response.json()).then(data => { 13 | if (data.error) { 14 | setNotice('#FEE', 'Failed: ' + data.error); 15 | } else if (data.task) { 16 | var li = html('
  • '); 17 | li.id = 'task' + data.task.id; 18 | li.innerText = data.task.desc; // use innerText so it's escaped! 19 | document.querySelector('#tasks').appendChild(li); 20 | setNotice('#EFE', 'Added task ' + data.task.desc); 21 | } else { 22 | setNotice('#FEE', 'Server error'); 23 | console.error('Bad response: ', data); 24 | } 25 | }) 26 | 27 | return false; 28 | } 29 | 30 | function toggle(el) { 31 | var id = el.id.substr(4); // skip the "task" prefix 32 | var xhr = { 33 | url: 'api/task/' + id, 34 | method: 'POST', 35 | credentials: 'include', 36 | body: '{"toggle":true,"id":"' + id + '"}', 37 | cache: 'no-cache', 38 | }; 39 | fetch(xhr.url, xhr).then(response => response.json()).then(data => { 40 | if (data.error) { 41 | setNotice('#FEE', 'Failed: ' + data.error); 42 | } else if (data.task) { 43 | el.classList.toggle('done', data.task.done); 44 | setNotice('#EFE', 'Updated!') 45 | } else { 46 | setNotice('#FEE', 'Server error'); 47 | console.error('Bad response: ', data); 48 | } 49 | }); 50 | } 51 | 52 | // Updates the notice div. 53 | function setNotice(col, msg) { 54 | var el = document.querySelector('#notice'); 55 | el.innerText = msg; 56 | el.style.background = col; 57 | } 58 | 59 | // Creates an html element from a string. 
60 | function html(content) { 61 | var div = document.createElement('div'); 62 | div.innerHTML = content; 63 | return div.firstChild; 64 | } 65 | -------------------------------------------------------------------------------- /chain/ordinal.go: -------------------------------------------------------------------------------- 1 | package chain 2 | 3 | // adapted from https://github.com/martinusso/inflect/blob/master/ordinal.go#L38 4 | 5 | // The MIT License (MIT) 6 | // 7 | // Copyright (c) 2016 Breno Martinusso 8 | // 9 | // Permission is hereby granted, free of charge, to any person obtaining a copy 10 | // of this software and associated documentation files (the "Software"), to deal 11 | // in the Software without restriction, including without limitation the rights 12 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 13 | // copies of the Software, and to permit persons to whom the Software is 14 | // furnished to do so, subject to the following conditions: 15 | // 16 | // The above copyright notice and this permission notice shall be included in all 17 | // copies or substantial portions of the Software. 18 | // 19 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 20 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 21 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 22 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 23 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 24 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 25 | // SOFTWARE. 26 | 27 | import ( 28 | "fmt" 29 | "math" 30 | ) 31 | 32 | const ( 33 | st = "st" 34 | nd = "nd" 35 | rd = "rd" 36 | th = "th" 37 | ) 38 | 39 | // Ordinal returns the ordinal suffix that should be added to a number to denote the position in an ordered sequence such as 1st, 2nd, 3rd, 4th... 
40 | func ordinal(number int) string { 41 | switch abs(number) % 100 { 42 | case 11, 12, 13: 43 | return th 44 | default: 45 | switch abs(number) % 10 { 46 | case 1: 47 | return st 48 | case 2: 49 | return nd 50 | case 3: 51 | return rd 52 | } 53 | } 54 | return th 55 | } 56 | 57 | func abs(number int) int { 58 | return int(math.Abs(float64(number))) 59 | } 60 | 61 | // Ordinalize turns a number into an ordinal string 62 | func ordinalize(number int) string { 63 | ordinal := ordinal(number) 64 | return fmt.Sprintf("%d%s", number, ordinal) 65 | } 66 | -------------------------------------------------------------------------------- /example_test.go: -------------------------------------------------------------------------------- 1 | package sandwich_test 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "io/fs" 7 | "net/http" 8 | 9 | "github.com/augustoroman/sandwich" 10 | ) 11 | 12 | type UserID string 13 | 14 | type User struct{} 15 | type UserDB interface { 16 | Get(UserID) (*User, error) 17 | New(*User) (UserID, error) 18 | Del(UserID) error 19 | List() ([]*User, error) 20 | } 21 | 22 | func ExampleRouter() { 23 | var db UserDB // = NewUserDB 24 | 25 | root := sandwich.TheUsual() 26 | root.SetAs(db, &db) 27 | 28 | api := root.SubRouter("/api") 29 | api.OnErr(sandwich.HandleErrorJson) 30 | 31 | apiUsers := api.SubRouter("/users") 32 | apiUsers.Get("/:uid", UserIDFromParam, UserDB.Get, SendUser) 33 | apiUsers.Delete("/:uid", UserIDFromParam, UserDB.Del) 34 | apiUsers.Get("/", UserDB.List, SendUserList) 35 | apiUsers.Post("/", UserFromCreateRequest, UserDB.New, SendUserID) 36 | 37 | var staticFS fs.FS 38 | root.Get("/home/", GetLoggedInUser, UserDB.Get, Home) 39 | root.Get("/:path", http.FileServer(http.FS(staticFS))) 40 | 41 | // Output: 42 | } 43 | 44 | func GetLoggedInUser(r *http.Request) (UserID, error) { 45 | token := r.Header.Get("user-token") 46 | uid := UserID(token) // decode the token to get the user info 47 | if uid == "" { 48 | return "", 
sandwich.Error{Code: http.StatusUnauthorized, ClientMsg: "invalid user token"} 49 | } 50 | return uid, nil 51 | } 52 | 53 | func Home(w http.ResponseWriter, u *User) { 54 | fmt.Fprintf(w, "Hello %v", u) 55 | } 56 | 57 | // UserIDFromParam reads the user ID from the ":uid" route parameter used by 58 | // the /api/users routes registered in ExampleRouter. 59 | func UserIDFromParam(p sandwich.Params) (UserID, error) { 60 | uid := UserID(p["uid"]) 61 | if uid == "" { 62 | return "", sandwich.Error{Code: 400, ClientMsg: "Missing UID param", LogMsg: "Request missing UID param"} 63 | } 64 | return uid, nil 65 | } 66 | 67 | func UserFromCreateRequest(r *http.Request) (*User, error) { 68 | u := &User{} 69 | return u, json.NewDecoder(r.Body).Decode(u) 70 | } 71 | 72 | func SendUserList(w http.ResponseWriter, users []*User) error { return SendJson(w, users) } 73 | func SendUser(w http.ResponseWriter, user *User) error { return SendJson(w, user) } 74 | func SendUserID(w http.ResponseWriter, id UserID) error { return SendJson(w, id) } 75 | 76 | func SendJson(w http.ResponseWriter, val interface{}) error { 77 | w.Header().Set("Content-Type", "application/json") 78 | return json.NewEncoder(w).Encode(val) 79 | } 80 | -------------------------------------------------------------------------------- /wrap.go: -------------------------------------------------------------------------------- 1 | package sandwich 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/augustoroman/sandwich/chain" 7 | ) 8 | 9 | // ChainMutation is a special type that allows modifying the chain directly when 10 | // added to a router. This allows advanced usage and should generally not be 11 | // used unless you know what you're doing. In particular, don't add `Arg`s to 12 | // the chain, that will break the router. 13 | type ChainMutation interface { 14 | // Modify the provided chain and return the modified chain.
15 | Apply(c chain.Func) chain.Func 16 | } 17 | 18 | // Wrap provides a mechanism to add two handlers: one that runs during the 19 | // normal course of middleware handling (Before) and one that is defer'd and 20 | // runs after the main set of middleware has executed (After). The defer'd 21 | // handler may accept the `error` type and handle or ignore errors as desired. 22 | // 23 | // This is generally useful for specifying operations that need to run before 24 | // and after subsequent middleware, such as timing, logging, or 25 | // allocation/cleanup operations. 26 | type Wrap struct { 27 | // `Before` is run in the normal course of middleware evaluation. Any returned 28 | // types from this will be available to the defer'd After handler. If Before 29 | // itself returns an error, After will not run. 30 | Before any 31 | // `After` is defer`d and run after the normal course of middleware has 32 | // completed, in reverse order of any registered `defer` handlers. Defer`d 33 | // handlers will always be executed if `Before` was executed, even in the case 34 | // of errors. The `After` handler may accept the `error` type -- that will be 35 | // nil unless a subsequent handler has returned an error. 36 | After any 37 | } 38 | 39 | // Apply modifies the chain to add Before and After. 
40 | func (w Wrap) Apply(c chain.Func) chain.Func { 41 | return c.Then(toHandlerFunc(w.Before)).Defer(toHandlerFunc(w.After)) 42 | } 43 | 44 | func apply(c chain.Func, handlers ...any) chain.Func { 45 | for _, h := range handlers { 46 | if mod, ok := h.(ChainMutation); ok { 47 | c = mod.Apply(c) 48 | } else { 49 | c = c.Then(toHandlerFunc(h)) 50 | } 51 | } 52 | return c 53 | } 54 | 55 | func toHandlerFunc(h any) any { 56 | if handlerInterface, ok := h.(http.Handler); ok { 57 | return handlerInterface.ServeHTTP 58 | } 59 | return h 60 | } 61 | -------------------------------------------------------------------------------- /chain/util.go: -------------------------------------------------------------------------------- 1 | package chain 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "runtime" 7 | "sort" 8 | ) 9 | 10 | // TODO(aroman) Replace calls with an explicit error type 11 | func panicf(msgfmt string, args ...interface{}) { 12 | panic(fmt.Errorf(msgfmt, args...)) 13 | } 14 | 15 | func valueOfFunction(handler interface{}) (FuncInfo, error) { 16 | if handler == nil { 17 | return FuncInfo{}, fmt.Errorf("should be a function, handler is ") 18 | } 19 | val := reflect.ValueOf(handler) 20 | if !val.IsValid() || val.Kind() != reflect.Func { 21 | return FuncInfo{}, fmt.Errorf("should be a function, handler is %s", val.Type()) 22 | } 23 | info := runtime.FuncForPC(val.Pointer()) 24 | file, line := info.FileLine(val.Pointer()) 25 | return FuncInfo{info.Name(), file, line, val}, nil 26 | } 27 | 28 | func checkCanCall(available map[reflect.Type]bool, fn FuncInfo) error { 29 | fn_typ := fn.Func.Type() 30 | for i := 0; i < fn_typ.NumIn(); i++ { 31 | t := fn_typ.In(i) 32 | if available[t] { 33 | continue 34 | } 35 | 36 | // Un-oh, not available. Let's see what we can do to make a helpful 37 | // error message. 
38 | provided := []string{} 39 | candidates := []string{} 40 | for typ := range available { 41 | provided = append(provided, typ.String()) 42 | if t.Kind() == reflect.Interface && typ.Implements(t) { 43 | candidates = append(candidates, typ.String()) 44 | } 45 | } 46 | sort.Strings(provided) 47 | 48 | suggestion := "" 49 | if len(candidates) == 0 && t.Kind() == reflect.Interface { 50 | suggestion = fmt.Sprintf(" Type %s is an interface, but not "+ 51 | "implemented by any of the provided types.", t) 52 | } else if len(candidates) == 1 { 53 | suggestion = fmt.Sprintf(" Type %s is an interface that is "+ 54 | "implemented by the provided type %s. Did you mean to use "+ 55 | "'.SetAs(val, (*%s)(nil))' instead of '.Set(val)'?", 56 | t, candidates[0], strip("main", t)) 57 | } else if len(candidates) > 1 { 58 | suggestion = fmt.Sprintf(" Type %s is an interface that is implemented "+ 59 | "by %d provided types: %s. If you meant to use one of those, use "+ 60 | "'.SetAs(val, (*someInterface)(nil))' to explicitly assign "+ 61 | "to that type.", 62 | t, len(candidates), candidates) 63 | } 64 | 65 | return fmt.Errorf("can't be called: type %s required for %s arg "+ 66 | "of %s (%s) has not been provided. Types that have been provided: %s. %s", 67 | t, ordinalize(i+1), fn.Name, fn_typ, provided, suggestion) 68 | } 69 | return nil 70 | } 71 | -------------------------------------------------------------------------------- /chain/example_test.go: -------------------------------------------------------------------------------- 1 | package chain_test 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "time" 7 | 8 | "github.com/augustoroman/sandwich/chain" 9 | ) 10 | 11 | func ExampleFunc() { 12 | example := chain.Func{}. 13 | // Indicate that the chain will will receive a time.Duration as the first 14 | // arg when it's executed. 15 | Arg(time.Duration(0)). 
16 | // When the chain is executed, it will first call time.Now which takes no 17 | // arguments but will return a time.Time value that will be available to 18 | // later calls. 19 | Then(time.Now). 20 | // Next, time.Sleep will be invoked, which requires a time.Duration 21 | // parameter. That's available since it's provided as an input to the chain. 22 | Then(time.Sleep). 23 | // Next, time.Since will be invoked, which requires a time.Time value that 24 | // was provided by the earlier time.Now call. It will return a time.Duration 25 | // value that will overwrite the input the chain. 26 | Then(time.Since). 27 | // Finally, we'll print out the stored time.Duration value. 28 | Then(func(dt time.Duration) { 29 | // Round to the nearest 10ms -- this makes the test not-flaky since the 30 | // sleep duration will not have been exact. 31 | dt = dt.Truncate(10 * time.Millisecond) 32 | fmt.Println("elapsed:", dt) 33 | }) 34 | 35 | example.MustRun(time.Duration(30 * time.Millisecond)) 36 | 37 | // Print the equivalent code: 38 | fmt.Println("Generated code is:") 39 | example.Code("example", "main", os.Stdout) 40 | 41 | // Output: 42 | // elapsed: 30ms 43 | // Generated code is: 44 | // func example( 45 | // ) func( 46 | // duration time.Duration, 47 | // ) { 48 | // return func( 49 | // duration time.Duration, 50 | // ) { 51 | // var time_Time time.Time 52 | // time_Time = time.Now() 53 | // 54 | // time.Sleep(duration) 55 | // 56 | // duration = time.Since(time_Time) 57 | // 58 | // chain_test.ExampleFunc.func1(duration) 59 | // 60 | // } 61 | // } 62 | } 63 | 64 | func ExampleFunc_file() { 65 | // Chains can be used to do file operations! 66 | 67 | writeToFile := chain.Func{}. 68 | Arg(""). // filename 69 | Arg([]byte(nil)). // data 70 | Then(os.Create). 71 | Then((*os.File).Write). 
72 | Then((*os.File).Close) 73 | 74 | // This never fails -- any errors in creating the file or writing to it will 75 | // be handled by the default error handler that logs a message, but the `Run` 76 | // itself doesn't fail unless the args are incorrect. 77 | writeToFile.MustRun("test.txt", []byte("the data")) 78 | 79 | content, err := os.ReadFile("test.txt") 80 | panicOnErr(err) 81 | fmt.Printf("test.txt: %s\n", content) 82 | 83 | panicOnErr(os.Remove("test.txt")) 84 | 85 | // Output: 86 | // test.txt: the data 87 | } 88 | 89 | func panicOnErr(err error) { 90 | if err != nil { 91 | panic(err) 92 | } 93 | } 94 | -------------------------------------------------------------------------------- /chain/codegen.go: -------------------------------------------------------------------------------- 1 | package chain 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "path/filepath" 7 | "reflect" 8 | "runtime" 9 | "strings" 10 | ) 11 | 12 | // Code writes the Go code for the current chain out to w assuming it lives in 13 | // package "pkg" with the specified handler function name. 
func (c Func) Code(name, pkg string, w io.Writer) {
	// vars hands out one unique variable name per reflect.Type. The steps'
	// type names and the last element of their package paths are reserved
	// first so that generated variables don't reuse those identifiers.
	vars := &nameMapper{}

	for _, s := range c.steps {
		vars.Reserve(s.valTyp.Name())
		vars.Reserve(filepath.Base(s.valTyp.PkgPath()))
	}

	// Outer constructor: one parameter per tVALUE step (values bound into
	// the chain up front).
	fmt.Fprintf(w, "func %s(\n", name)
	for _, s := range c.steps {
		switch s.typ {
		case tVALUE:
			fmt.Fprintf(w, "\t%s %s,\n", vars.For(s.valTyp), strip(pkg, s.valTyp))
		}
	}
	// Result type of the constructor: a func taking one parameter per tARG
	// step (the chain's run-time arguments).
	fmt.Fprintf(w, ") func(\n")
	for _, s := range c.steps {
		if s.typ == tARG {
			fmt.Fprintf(w, "\t%s %s,\n", vars.For(s.valTyp), strip(pkg, s.valTyp))
		}
	}
	fmt.Fprintf(w, ") {\n")

	// The returned closure repeats the tARG parameter list.
	fmt.Fprintf(w, "\treturn func(\n")
	for _, s := range c.steps {
		if s.typ == tARG {
			fmt.Fprintf(w, "\t\t%s %s,\n", vars.For(s.valTyp), strip(pkg, s.valTyp))
		}
	}
	fmt.Fprintf(w, "\t) {\n")

	// Calls that return an error are followed by a call to the error handler
	// in effect at that point: DefaultErrorHandler until a tERROR_HANDLER
	// step is encountered, then the most recent such step.
	errHandler := step{tERROR_HANDLER, reflect.ValueOf(DefaultErrorHandler), nil}
	for _, s := range c.steps {
		// Args and values were already emitted as parameters above.
		if s.typ == tARG || s.typ == tVALUE {
			continue
		}

		// Error-handler steps emit no code themselves; they only change which
		// handler subsequent error checks call.
		if s.typ == tERROR_HANDLER {
			errHandler = s
			continue
		}

		// Declare a variable for every result type that has not been named
		// yet; later steps reuse the same variable for the same type.
		for i := 0; i < s.valTyp.NumOut(); i++ {
			t := s.valTyp.Out(i)
			if !vars.Has(t) {
				fmt.Fprintf(w, "\t\tvar %s %s\n", vars.For(t), strip(pkg, t))
			}
		}

		// tPOST_HANDLER steps are emitted inside a deferred closure. The
		// trailing "\t" indents the call line below one extra level.
		if s.typ == tPOST_HANDLER {
			fmt.Fprintf(w, "\t\tdefer func() {\n\t")
		}

		// Note: this `name` shadows the function's `name` parameter for the
		// rest of the loop body.
		name, inVars, outVars, returnsError := getArgNames(pkg, vars, s.val)

		// Emit "out1, out2 = fn(in1, in2)" (or just "fn(...)" with no results).
		fmt.Fprintf(w, "\t\t")
		if len(outVars) > 0 {
			fmt.Fprintf(w, "%s = ", strings.Join(outVars, ", "))
		}
		fmt.Fprintf(w, "%s(%s)\n", name, strings.Join(inVars, ", "))

		if returnsError {
			name, inVars, _, _ := getArgNames(pkg, vars, errHandler.val)
			fmt.Fprintf(w, "\t\tif err != nil {\n")
			fmt.Fprintf(w, "\t\t\t%s(%s)\n", name, strings.Join(inVars, ", "))
			fmt.Fprintf(w, "\t\t\treturn\n")
			fmt.Fprintf(w, "\t\t}\n")
		}

		if s.typ == tPOST_HANDLER {
			fmt.Fprintf(w, "\t\t}()\n")
		}
		fmt.Fprintf(w, "\n")
	}
	fmt.Fprintf(w, "\t}\n")
	fmt.Fprintf(w, "}\n")
}

// strip returns t's string form with a leading "pkg." qualifier removed (see
// stripStr).
func strip(pkg string, t reflect.Type) string {
	return stripStr(pkg, t.String())
}

// stripStr removes a "pkg." prefix that follows any leading '*' characters,
// e.g. "*main.Foo" becomes "*Foo" when pkg is "main". Qualifiers elsewhere in
// s (such as inside "[]main.Foo") are left untouched.
//
// NOTE(review): assumes s contains at least one non-'*' character; otherwise
// IndexFunc returns -1 and the slice expression panics.
func stripStr(pkg, s string) string {
	pos := strings.IndexFunc(s, func(r rune) bool { return r != '*' })
	s = s[:pos] + strings.TrimPrefix(s[pos:], pkg+".")
	return s
}

// getArgNames describes the function value v for code generation. It returns:
//   - name: the runtime function name with directories and any "pkg." prefix
//     removed, with method expressions such as "os.(*File).Write" rewritten
//     to "(*os.File).Write";
//   - in/out: the variable names assigned (via vars) to its parameter and
//     result types;
//   - returnsError: whether any result has type error.
func getArgNames(pkg string, vars *nameMapper, v reflect.Value) (name string, in, out []string, returnsError bool) {
	name = runtime.FuncForPC(v.Pointer()).Name()
	name = filepath.Base(name)
	name = strings.TrimPrefix(name, pkg+".")

	// Move the package qualifier inside the receiver parentheses:
	// "os.(*File).Write" -> "(*os.File).Write".
	if pos := strings.Index(name, ".(*"); pos > 0 {
		pkgName := name[:pos+1]
		pkgName = strings.TrimPrefix(pkgName, pkg+".")
		name = "(*" + pkgName + name[pos+3:]
	}

	t := v.Type()
	out = make([]string, t.NumOut())
	for i := 0; i < t.NumOut(); i++ {
		out[i] = vars.For(t.Out(i))
		if t.Out(i) == errorType {
			returnsError = true
		}
	}
	in = make([]string, t.NumIn())
	for i := 0; i < t.NumIn(); i++ {
		in[i] = vars.For(t.In(i))
	}
	return name, in, out, returnsError
}

--------------------------------------------------------------------------------
/errors.go:
--------------------------------------------------------------------------------
package sandwich

import (
	"errors"
	"fmt"
	"net/http"
)

// Error is an error implementation that provides the ability to specify three
// things to the sandwich error handler:
//   - The HTTP status code that should be used in the response.
//   - The client-facing message that should be sent. Typically this is a
//     sanitized error message, such as "Internal Server Error".
14 | // - Internal debugging detail including a log message and the underlying 15 | // error that should be included in the server logs. 16 | // 17 | // Note that Cause may be nil. 18 | // 19 | // The sandwich standard Error handlers (HandleError and HandleErrorJson) will 20 | // respect these Errors and respond with the appropriate status code and client 21 | // message. Additionally, the sandwich standard log handling will log LogMsg. 22 | type Error struct { 23 | Code int 24 | ClientMsg string 25 | LogMsg string 26 | Cause error 27 | } 28 | 29 | func (e Error) Error() string { 30 | msg := fmt.Sprintf("(%d) %s", e.Code, e.LogMsg) 31 | if e.LogMsg == "" { 32 | msg += e.ClientMsg 33 | } 34 | if e.Cause != nil { 35 | msg += ": " + e.Cause.Error() 36 | } 37 | return msg 38 | } 39 | 40 | // LogIfMsg will set the Error field on the LogEntry if the Error's LogMsg 41 | // field has something. 42 | func (e Error) LogIfMsg(l *LogEntry) { 43 | if e.LogMsg != "" { 44 | l.Error = e 45 | } 46 | } 47 | 48 | // Done is a sentinel error value that can be used to interrupt the middleware 49 | // chain without triggering the default error handling. HandleError will not 50 | // attempt to write any status code or client message, nor will it add the error 51 | // to the log. 52 | var Done = errors.New("") 53 | 54 | // ToError will convert a generic non-nil error to an explicit sandwich.Error 55 | // type. If err is already a sandwich.Error, it will be returned. Otherwise, a 56 | // generic 500 Error (internal server error) will be initialized and returned. 57 | // Note that if err is nil, it will still return a generic 500 Error. 
58 | func ToError(err error) Error { 59 | var e Error 60 | if errors.As(err, &e) { 61 | if e.Code == 0 { 62 | e.Code = 500 63 | } 64 | if e.ClientMsg == "" { 65 | e.ClientMsg = http.StatusText(e.Code) 66 | } 67 | return e 68 | } 69 | return Error{ 70 | Code: http.StatusInternalServerError, 71 | LogMsg: "Failure", 72 | Cause: err, 73 | ClientMsg: http.StatusText(http.StatusInternalServerError), 74 | } 75 | } 76 | 77 | // HandleError is the default error handler included in sandwich.TheUsual. 78 | // If the error is a sandwich.Error, it responds with the specified status code 79 | // and client message. Otherwise, it responds with a 500. In both cases, the 80 | // underlying error is added to the request log. 81 | // 82 | // If the error is sandwich.Done, HandleError does nothing. 83 | func HandleError(w http.ResponseWriter, r *http.Request, l *LogEntry, err error) { 84 | if err == Done { 85 | return 86 | } 87 | e := ToError(err) 88 | e.LogIfMsg(l) 89 | http.Error(w, e.ClientMsg, e.Code) 90 | } 91 | 92 | // HandleErrorJson is identical to HandleError except that it responds to the 93 | // client as JSON instead of plain text. Again, detailed error info is added 94 | // to the request log. 95 | // 96 | // If the error is sandwich.Done, HandleErrorJson does nothing. 
97 | func HandleErrorJson(w http.ResponseWriter, r *http.Request, l *LogEntry, err error) { 98 | if err == Done { 99 | return 100 | } 101 | e := ToError(err) 102 | e.LogIfMsg(l) 103 | w.Header().Set("Content-Type", "application/json") 104 | w.WriteHeader(e.Code) 105 | fmt.Fprintf(w, `{"error":%q}`, e.ClientMsg) 106 | } 107 | -------------------------------------------------------------------------------- /logger.go: -------------------------------------------------------------------------------- 1 | package sandwich 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "net/http" 7 | "os" 8 | "sort" 9 | "strings" 10 | "time" 11 | ) 12 | 13 | // Injected for testing 14 | var time_Now = time.Now 15 | var os_Stderr io.Writer = os.Stderr 16 | 17 | // LogEntry is the information tracked on a per-request basis for the sandwich 18 | // Logger. All fields other than Note are automatically filled in. The Note 19 | // field is a generic key-value string map for adding additional per-request 20 | // metadata to the logs. You can take *sandwich.LogEntry to your functions to 21 | // add fields to Note. 22 | // 23 | // For example: 24 | // 25 | // func MyAuthCheck(r *http.Request, e *sandwich.LogEntry) (User, error) { 26 | // user, err := decodeAuthCookie(r) 27 | // if user != nil { 28 | // e.Note["user"] = user.Id() // indicate which user is auth'd 29 | // } 30 | // return user, err 31 | // } 32 | type LogEntry struct { 33 | RemoteIp string 34 | Start time.Time 35 | Request *http.Request 36 | StatusCode int 37 | ResponseSize int 38 | Elapsed time.Duration 39 | Error error 40 | Note map[string]string 41 | // set to true to suppress logging this request 42 | Quiet bool 43 | } 44 | 45 | // NoLog is a middleware function that suppresses log output for this request. 46 | // For example: 47 | // 48 | // // suppress logging of the favicon request to reduce log spam. 
49 | // router.Get("/favicon.ico", sandwich.NoLog, staticHandler) 50 | // 51 | // This depends on WriteLog respecting the Quiet flag, which the default 52 | // implementation does. 53 | func NoLog(e *LogEntry) { e.Quiet = true } 54 | 55 | // LogRequests is a middleware wrap that creates a log entry during middleware 56 | // processing and then commits the log entry after the middleware has executed. 57 | var LogRequests = Wrap{NewLogEntry, (*LogEntry).Commit} 58 | 59 | // NewLogEntry creates a *LogEntry and initializes it with basic request 60 | // information. 61 | func NewLogEntry(r *http.Request) *LogEntry { 62 | return &LogEntry{ 63 | RemoteIp: remoteIp(r), 64 | Start: time_Now(), 65 | Request: r, 66 | Note: map[string]string{}, 67 | } 68 | } 69 | 70 | // Commit fills in the remaining *LogEntry fields and writes the entry out. 71 | func (entry *LogEntry) Commit(w *ResponseWriter) { 72 | entry.Elapsed = time_Now().Sub(entry.Start) 73 | entry.ResponseSize = w.Size 74 | entry.StatusCode = w.Code 75 | WriteLog(*entry) 76 | } 77 | 78 | // Some nice escape codes 79 | const ( 80 | _GREEN = "\033[32m" 81 | _YELLOW = "\033[33m" 82 | _RESET = "\033[0m" 83 | _RED = "\033[91m" 84 | ) 85 | 86 | // WriteLog is called to actually write a LogEntry out to the log. By default, 87 | // it writes to stderr and colors normal requests green, slow requests yellow, 88 | // and errors red. You can replace the function to adjust the formatting or use 89 | // whatever logging library you like. 90 | var WriteLog = func(e LogEntry) { 91 | if e.Quiet { 92 | return 93 | } 94 | col, reset := logColors(e) 95 | fmt.Fprintf(os_Stderr, "%s%s %s \"%s %s\" (%d %dB %s) %s%s\n", 96 | col, 97 | e.Start.Format(time.RFC3339), e.RemoteIp, 98 | e.Request.Method, e.Request.RequestURI, 99 | e.StatusCode, e.ResponseSize, e.Elapsed, 100 | e.NotesAndError(), 101 | reset) 102 | } 103 | 104 | // NotesAndError formats the Note values and error (if any) for logging. 
105 | func (l LogEntry) NotesAndError() string { 106 | pairs := make([]string, len(l.Note)) 107 | for k, v := range l.Note { 108 | pairs = append(pairs, fmt.Sprintf("%s=%q", k, v)) 109 | } 110 | sort.Strings(pairs) 111 | msg := strings.Join(pairs, " ") 112 | if l.Error != nil { 113 | msg += "\n ERROR: " + l.Error.Error() 114 | } 115 | return msg 116 | } 117 | 118 | func logColors(e LogEntry) (start, reset string) { 119 | col, reset := _GREEN, _RESET 120 | if e.Elapsed > 30*time.Millisecond { 121 | col = _YELLOW 122 | } 123 | if e.StatusCode >= 400 || e.Error != nil { 124 | col, reset = _RED, _RESET // high-intensity red + reset 125 | } 126 | return col, reset 127 | } 128 | 129 | // remoteIp extracts the remote IP from the request. Adapted from code in 130 | // Martini: 131 | // 132 | // https://github.com/go-martini/martini/blob/1d33529c15f19/logger.go#L14..L20 133 | func remoteIp(r *http.Request) string { 134 | if addr := r.Header.Get("X-Real-IP"); addr != "" { 135 | return addr 136 | } else if addr := r.Header.Get("X-Forwarded-For"); addr != "" { 137 | return addr 138 | } 139 | return r.RemoteAddr 140 | } 141 | -------------------------------------------------------------------------------- /benchmark_simple_handler_test.go: -------------------------------------------------------------------------------- 1 | package sandwich 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "net/http" 7 | "net/http/httptest" 8 | "testing" 9 | ) 10 | 11 | // Sample data to JSON-encode for benchmarking. 
var userInfo = struct {
	Id        uint64 `json:"id"`
	Name      string `json:"name"`
	Email     string `json:"email"`
	AvatarUrl string `json:"avatar_url"`
}{
	12345467, "John Doe", "john.doe@example.com", "https://www.example.com/users/john.doe/image",
}

// Benchmark endpoints: an empty 204 response, a short plain-text body, and a
// JSON-encoded userInfo.
func write204(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNoContent) }
func hello(w http.ResponseWriter, r *http.Request)    { _, _ = w.Write([]byte("Hello there!")) }
func sendjson(w http.ResponseWriter, r *http.Request) { _ = json.NewEncoder(w).Encode(userInfo) }

// addTestRoutes registers each endpoint under several route shapes: a short
// static path, a long static path, one path parameter, many path parameters,
// and a ":var*" parameter (the "greedy" routes).
func addTestRoutes(mux Router) {
	mux.Get("/204", write204)
	mux.Get("/hello", hello)
	mux.Get("/jsonuser", sendjson)

	mux.Get("/long/1/2/3/4/5/6/7/8/9/xyz/204", write204)
	mux.Get("/long/1/2/3/4/5/6/7/8/9/xyz/hello", hello)
	mux.Get("/long/1/2/3/4/5/6/7/8/9/xyz/jsonuser", sendjson)

	mux.Get("/1param/:var/204", write204)
	mux.Get("/1param/:var/hello", hello)
	mux.Get("/1param/:var/jsonuser", sendjson)

	mux.Get("/manyparams/:var/:x/:y/:z/:a/:b/:c/204", write204)
	mux.Get("/manyparams/:var/:x/:y/:z/:a/:b/:c/hello", hello)
	mux.Get("/manyparams/:var/:x/:y/:z/:a/:b/:c/jsonuser", sendjson)

	mux.Get("/greedy/:var*/204", write204)
	mux.Get("/greedy/:var*/hello", hello)
	mux.Get("/greedy/:var*/jsonuser", sendjson)
}

// usualRouter is TheUsual() plus NoLog, so benchmark requests are not written
// to stderr; bareRouter is built from BuildYourOwn().
var usualRouter = func() Router {
	mux := TheUsual()
	mux.Use(NoLog)
	addTestRoutes(mux)
	return mux
}()
var bareRouter = func() Router {
	mux := BuildYourOwn()
	addTestRoutes(mux)
	return mux
}()

// bench serves N GET requests for route through mux. The request is built
// once and reused; every iteration gets a fresh response recorder.
func bench(N int, route string, mux Router) {
	req := httptest.NewRequest("GET", route, nil)
	for i := 0; i < N; i++ {
		w := httptest.NewRecorder()
		mux.ServeHTTP(w, req)
	}
}

// S is shorthand for a []string literal.
func S(s ...string) []string { return s }

// runBenches runs one sub-benchmark per (endpoint, route prefix) pair, named
// "<endpoint>::<prefix>/<endpoint>".
func runBenches(b *testing.B, mux Router, routes, endpoints []string) {
	for _, ep := range endpoints {
		for _, route := range routes {
			path := route + "/" + ep
			b.Run(ep+"::"+path, func(b *testing.B) {
				bench(b.N, path, mux)
			})
		}
	}
}

// Just to shorten the following benchmark functions (currently referenced
// only by the commented-out benchmarks at the end of this file):
type Handler = http.HandlerFunc

func BenchmarkUsual(b *testing.B) {
	runBenches(b, usualRouter,
		S("", "/long/1/2/3/4/5/6/7/8/9/xyz", "/1param/foo", "/manyparams/foo/x/y/z/a/b/c", "/greedy/short", "/greedy/x/y/z/a/b/c"),
		S("204", "hello", "jsonuser"),
	)
}
func BenchmarkBare(b *testing.B) {
	runBenches(b, bareRouter,
		S("", "/long/1/2/3/4/5/6/7/8/9/xyz", "/1param/foo", "/manyparams/foo/x/y/z/a/b/c", "/greedy/short", "/greedy/x/y/z/a/b/c"),
		S("204", "hello", "jsonuser"),
	)
}

// BenchmarkCalls measures a single route whose handler list is `hello`
// repeated i times, for odd i from 1 through 19. Router setup happens before
// b.ResetTimer so only request serving is timed.
func BenchmarkCalls(b *testing.B) {
	for i := 1; i < 20; i += 2 {
		b.Run(fmt.Sprintf("%02d", i), func(b *testing.B) {
			var calls []any
			for j := 0; j < i; j++ {
				calls = append(calls, hello)
			}
			mux := BuildYourOwn()
			mux.Get("/", calls...)
			b.ResetTimer()
			bench(b.N, "/", mux)
		})
	}
}

// func Benchmark_Usual_Short_Write204(b *testing.B) { bench(b.N, "/204", usualRouter) }
// func Benchmark_Usual_Short_Hello(b *testing.B) { bench(b.N, "/hello", usualRouter) }
// func Benchmark_Usual_Short_SnedJson(b *testing.B) { bench(b.N, "/204", usualRouter) }

// // func Benchmark_Hello_RawHTTP(b *testing.B) { bench(b.N, Handler(hello)) }
// // func Benchmark_Hello_Dynamic_Bare(b *testing.B) { bench(b.N, makeHello_Bare()) }
// // func Benchmark_Hello_Dynamic_TheUsual(b *testing.B) { bench(b.N, makeHello_TheUsual()) }

// // func Benchmark_Write204_RawHTTP(b *testing.B) { bench(b.N, Handler(write204)) }

// // func Benchmark_Write204_Dynamic_Bare(b *testing.B) { bench(b.N, makeWrite204_Bare()) }
// // func Benchmark_Write204_Dynamic_TheUsual(b *testing.B) { bench(b.N, makeWrite204_TheUsual()) }

// // func Benchmark_SendJson_RawHTTP(b *testing.B) { bench(b.N, Handler(sendjson)) }
// // func Benchmark_SendJson_Dynamic_Bare(b *testing.B) { bench(b.N, makeSendJson_Bare()) }
// // func Benchmark_SendJson_Dynamic_TheUsual(b *testing.B) { bench(b.N, makeSendJson_TheUsual()) }

--------------------------------------------------------------------------------
/logger_test.go:
--------------------------------------------------------------------------------
package sandwich

import (
	"bytes"
	"errors"
	"net/http"
	"net/http/httptest"
	"os"
	"strings"
	"testing"
	"time"

	"github.com/augustoroman/sandwich/chain"
)

// fakeClock is a deterministic clock: each Now call returns the current time
// and then advances it by advance; Sleep advances it by dt without blocking.
type fakeClock struct {
	now     time.Time
	advance time.Duration
}

func (f *fakeClock) Now() time.Time {
	now := f.now
	f.now = now.Add(f.advance)
	return now
}

func (f *fakeClock) Sleep(dt time.Duration) {
	f.now = f.now.Add(dt)
}

// validateLogMessage checks that logs (whitespace-trimmed) begins with
// expectedColor and ends with _RESET, and that the text between them equals
// expectedMsg after trimming surrounding whitespace from both.
func validateLogMessage(t *testing.T, logs, expectedColor, expectedMsg string) {
	logs = strings.TrimSpace(logs)

	if !strings.HasPrefix(logs, expectedColor) {
		t.Errorf("Expected color prefix of %q: %q", expectedColor, logs)
	} else {
		logs = strings.TrimPrefix(logs, expectedColor)
	}
	if !strings.HasSuffix(logs, _RESET) {
		t.Errorf("Expected reset suffix: %q", logs)
	} else {
		logs = strings.TrimSuffix(logs, _RESET)
	}
	logs = strings.TrimSpace(logs)
	expectedMsg = strings.TrimSpace(expectedMsg)
	if logs != expectedMsg {
		t.Errorf("Wrong log message:\nExp: %q\nGot: %q", expectedMsg, logs)
	}
}

// TestLogger drives TheUsual() with a fake clock and captured stderr. Each
// request calls time_Now once in NewLogEntry and once in Commit, so with a
// 13ms clock step a request that doesn't Sleep reports 13ms elapsed.
func TestLogger(t *testing.T) {
	// Restore the world from insanity when we're done:
	orig := WriteLog
	defer func() { time_Now = time.Now; os_Stderr = os.Stderr; WriteLog = orig }()

	// Setup our fake world.
	var logBuf bytes.Buffer
	os_Stderr = &logBuf
	clk := &fakeClock{time.Date(2001, 2, 3, 4, 5, 6, 7, time.UTC), 13 * time.Millisecond}
	time_Now = clk.Now

	// Useful handlers:
	sendMsg := func(w http.ResponseWriter) { _, _ = w.Write([]byte("Hi there")) }
	slowSendMsg := func(w http.ResponseWriter) { clk.Sleep(100 * time.Millisecond); sendMsg(w) }
	fail := func() error { return errors.New("It went horribly wrong") }
	slowFail := func() error { clk.Sleep(time.Second); return fail() }
	panics := func(w http.ResponseWriter) { sendMsg(w); panic("oops") }
	addsNote := func(w http.ResponseWriter, e *LogEntry) { e.Note["a"] = "x"; e.Note["b"] = "y"; sendMsg(w) }

	var resp *httptest.ResponseRecorder
	var req *http.Request

	mux := TheUsual()

	// Test a normal response:
	logBuf.Reset()
	resp = httptest.NewRecorder()
	req, _ = http.NewRequest("GET", "/", nil)
	req.RequestURI = req.URL.String()
	req.Header.Add("X-Real-IP", "123.456.789.0")
	mux.Get("/", addsNote)
	mux.ServeHTTP(resp, req)
	validateLogMessage(t, logBuf.String(), _GREEN,
		`2001-02-03T04:05:06Z 123.456.789.0 "GET /" (200 8B 13ms) a="x" b="y"`)

	// Test a slow response:
	logBuf.Reset()
	resp = httptest.NewRecorder()
	req, _ = http.NewRequest("POST", "/slow", nil)
	req.RequestURI = req.URL.String()
	req.Header.Add("X-Forwarded-For", "")
	mux.Post("/slow", slowSendMsg)
	mux.ServeHTTP(resp, req)
	// NOTE(review): no IP header is set and RemoteAddr is empty here, so
	// WriteLog's "%s %s \"..." format yields two spaces between the timestamp
	// and the quoted request, while this expectation shows one. Confirm
	// against the original file (whitespace may have been collapsed in this
	// copy).
	validateLogMessage(t, logBuf.String(), _YELLOW,
		`2001-02-03T04:05:06Z "POST /slow" (200 8B 113ms)`)

	// Test a failed response:
	logBuf.Reset()
	resp = httptest.NewRecorder()
	req, _ = http.NewRequest("BOO!", "/fail", nil)
	req.RequestURI = req.URL.String()
	req.RemoteAddr = "[::1]:56596"
	mux.On("BOO!", "/fail", fail)
	mux.ServeHTTP(resp, req)
	validateLogMessage(t, logBuf.String(), _RED,
		`2001-02-03T04:05:06Z [::1]:56596 "BOO! /fail" (500 22B 13ms) `+"\n"+
			` ERROR: (500) Failure: It went horribly wrong`)

	// Test a slow failed response (should still be red):
	logBuf.Reset()
	resp = httptest.NewRecorder()
	req, _ = http.NewRequest("PUT", "/slowfail", nil)
	req.RequestURI = req.URL.String()
	req.RemoteAddr = "[::1]:56596"
	req.Header.Add("X-Forwarded-For", "")
	req.Header.Add("X-Real-IP", "123.456.789.0") // takes precedence
	mux.Put("/slowfail", slowFail)
	mux.ServeHTTP(resp, req)
	validateLogMessage(t, logBuf.String(), _RED,
		`2001-02-03T04:05:06Z 123.456.789.0 "PUT /slowfail" (500 22B 1.013s) `+"\n"+
			` ERROR: (500) Failure: It went horribly wrong`)

	// Test a suppressed log.
	logBuf.Reset()
	resp = httptest.NewRecorder()
	req, _ = http.NewRequest("GET", "/nolog", nil)
	mux.Get("/nolog", NoLog, addsNote)
	mux.ServeHTTP(resp, req)
	if logBuf.String() != "" {
		t.Errorf("Expected no log output, but got [%s]", logBuf.String())
	}

	// Test that a panic should be recorded. WriteLog is swapped out to
	// capture the entry directly instead of parsing text.
	var log LogEntry
	WriteLog = func(e LogEntry) { log = e }
	resp = httptest.NewRecorder()
	req, _ = http.NewRequest("PUT", "/panic", nil)
	req.RequestURI = req.URL.String()
	req.RemoteAddr = ""
	mux.Put("/panic", panics)
	mux.ServeHTTP(resp, req)

	if err, ok := ToError(log.Error).Cause.(chain.PanicError); !ok {
		t.Errorf("log error should be a panic, but is: %#v", log.Error)
	} else if msg := err.Error(); !strings.Contains(msg, `Panic executing middleware`) {
		t.Errorf("Bad err message: %s", err)
	} else if !strings.Contains(msg, `oops`) {
		t.Errorf("Bad err message: %s", err)
	}

	// The handler wrote "Hi there" before panicking, so the error text is
	// appended after it.
	if resp.Body.String() != "Hi thereInternal Server Error\n" {
		t.Errorf("Incorrect client response: %q", resp.Body.String())
	}
}

--------------------------------------------------------------------------------
/examples/1-simple/main.go:
--------------------------------------------------------------------------------
// 1-simple is a demo webserver for the sandwich middleware package
// demonstrating basic usage.
//
// This example demonstrates most of the basic features of sandwich, including:
//
//   - Providing user types to the middleware chain
//   - Adding middleware handlers to the stack.
//   - Writing handlers that provide request-scoped values.
//   - Writing handlers using injected values.
//   - Using the default sandwich logging system.
//   - Using the default sandwich error system.
12 | package main 13 | 14 | import ( 15 | "fmt" 16 | "log" 17 | "net/http" 18 | "time" 19 | 20 | "github.com/augustoroman/sandwich" 21 | ) 22 | 23 | // Interface for abstracting out the user database. 24 | type UserDb interface { 25 | Lookup(id string) (User, error) 26 | } 27 | type User struct{ Id, Name, Email string } 28 | 29 | func main() { 30 | // To reduce log spam, we'll just put this here, not using any framework. 31 | http.Handle("/favicon.ico", http.NotFoundHandler()) 32 | 33 | // Setup connections to the databases. 34 | udb := userDb{{"bob", "Bob", "bob@example.com"}, {"alice", "Alice", "alice@example.com"}} 35 | 36 | // Create a typical sandwich middleware with logging and error-handling. 37 | mux := sandwich.TheUsual() 38 | // Inject config and user database; now available to all handlers. 39 | mux.SetAs(udb, (*UserDb)(nil)) 40 | // In this example, we'll always check to see if the user is logged in. 41 | // If so, we'll add the user ID to the log entries. 42 | mux.Use(ParseUserIfLoggedIn) 43 | 44 | // If the user is logged in, they'll get a personalized landing page. 45 | // Otherwise, they'll get a generic landing page. 46 | mux.Get("/", ShowLandingPage) 47 | mux.Post("/login", Login) 48 | 49 | // Some pages are only allowed if the user is logged in. 50 | mux.Get("/user/profile", FailIfNotAuthenticated, ShowUserProfile) 51 | // If you have multiple pages that require authentication, you could do: 52 | // authed := mw.Then(FailIfNotAuthenticated) 53 | // http.Handle("/user/profile", authed.Then(ShowUserProfile)) 54 | // http.Handle("/user/...", authed.Then(...)) 55 | // http.Handle("/user/...", authed.Then(...)) 56 | 57 | log.Println("Serving on http://localhost:8080/") 58 | if err := http.ListenAndServe(":8080", mux); err != nil { 59 | log.Fatal("Can't start webserver:", err) 60 | } 61 | } 62 | 63 | // The actual user DB implementation. 
64 | type userDb []User 65 | 66 | func (udb userDb) Lookup(id string) (User, error) { 67 | for _, u := range udb { 68 | if id == u.Id { 69 | return u, nil 70 | } 71 | } 72 | return User{}, fmt.Errorf("no such user %q", id) 73 | } 74 | 75 | func ShowLandingPage(w http.ResponseWriter, u *User) { 76 | fmt.Fprintln(w, "") 77 | if u == nil { 78 | fmt.Fprint(w, "Hello unknown person!") 79 | fmt.Fprintf(w, " [profile will fail and log an error]") 80 | } else { 81 | fmt.Fprintf(w, "Welcome back, %s!", u.Name) 82 | fmt.Fprintf(w, " [profile]") 83 | } 84 | fmt.Fprintln(w, `


    85 | Login
    86 |
    87 |
    88 | 89 |
    90 |
    91 | Try logging in with:
      92 |
    • "alice" will authenticate to Alice 93 |
    • "bob" will authenticate to Bob but panic during request handling 94 |
    • any other string for a non-authenticated user. 95 |
    96 | `) 97 | } 98 | 99 | func ShowUserProfile(w http.ResponseWriter, u User) { 100 | fmt.Fprintln(w, "") 101 | fmt.Fprintf(w, "Id: %s
    Name: %s
    Email:%s", u.Id, u.Name, u.Email) 102 | 103 | // Show an example of user-code panicking in a handler. 104 | if u.Id == "bob" { 105 | panic("oops") 106 | } 107 | } 108 | 109 | func Login(w http.ResponseWriter, r *http.Request, udb UserDb, e *sandwich.LogEntry) { 110 | u, err := udb.Lookup(r.FormValue("id")) 111 | if err != nil { 112 | log.Printf("No such user id: %q", r.FormValue("id")) 113 | http.SetCookie(w, &http.Cookie{Name: "auth", Value: "", Expires: time.Now()}) 114 | // Redirect to / 115 | fmt.Fprintf(w, ` 116 | `) 117 | return 118 | } 119 | 120 | e.Note["userId"] = u.Id 121 | http.SetCookie(w, &http.Cookie{ 122 | Name: "auth", 123 | Value: u.Id, // Encrypt cookie here, maybe include the whole user struct. 124 | Expires: time.Now().Add(time.Hour), 125 | MaxAge: int(time.Hour / time.Second), 126 | HttpOnly: true, 127 | }) 128 | // Redirect to /user/profile 129 | fmt.Fprintf(w, ` 130 | `) 131 | } 132 | 133 | func FailIfNotAuthenticated(u *User) (User, error) { 134 | if u == nil { 135 | return User{}, sandwich.Error{ 136 | Code: http.StatusUnauthorized, 137 | ClientMsg: "Not logged in", 138 | LogMsg: "Unauthorized access attempt", 139 | } 140 | } 141 | return *u, nil 142 | } 143 | 144 | func getAndParseCookie(r *http.Request) (string, error) { 145 | c, err := r.Cookie("auth") 146 | if err != nil { 147 | return "", err 148 | } 149 | userid := c.Value // Decrypt cookie here, maybe getting a whole user struct. 150 | return userid, nil 151 | } 152 | 153 | func ParseUserIfLoggedIn(r *http.Request, udb UserDb, e *sandwich.LogEntry) (*User, error) { 154 | if user_id, err := getAndParseCookie(r); err != nil { 155 | return nil, nil // not logged in or expired or corrupt. Ignore cookie. 156 | } else if user, err := udb.Lookup(user_id); err != nil { 157 | log.Printf("No such user: %q", user_id) 158 | return nil, nil // no such user 159 | } else { 160 | e.Note["userId"] = user.Id 161 | return &user, nil // Hello logged-in user! 
162 | } 163 | } 164 | -------------------------------------------------------------------------------- /chain/naming.go: -------------------------------------------------------------------------------- 1 | package chain 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "path/filepath" 7 | "reflect" 8 | "regexp" 9 | "strings" 10 | "unicode" 11 | ) 12 | 13 | type nameMapper struct { 14 | typToName map[reflect.Type]string 15 | used map[string]bool 16 | } 17 | 18 | func (n *nameMapper) Has(t reflect.Type) bool { 19 | _, exists := n.typToName[t] 20 | return exists 21 | } 22 | 23 | // Don't ever use these as variable names: keywords + primitive type names var 24 | var disallowed = map[string]bool{ 25 | // keywords 26 | "break": true, "default": true, "func": true, "interface": true, 27 | "select": true, "case": true, "defer": true, "go": true, "map": true, 28 | "struct": true, "chan": true, "else": true, "goto": true, "package": true, 29 | "switch": true, "const": true, "fallthrough": true, "if": true, 30 | "range": true, "type": true, "continue": true, "for": true, "import": true, 31 | "return": true, "var": true, 32 | // pre-declared identifiers: 33 | "bool": true, "byte": true, "complex64": true, "complex128": true, 34 | "error": true, "float32": true, "float64": true, "int": true, "int8": true, 35 | "int16": true, "int32": true, "int64": true, "rune": true, "string": true, 36 | "uint": true, "uint8": true, "uint16": true, "uint32": true, "uint64": true, 37 | "uintptr": true, "true": true, "false": true, "iota": true, "nil": true, 38 | "append": true, "cap": true, "close": true, "complex": true, "copy": true, 39 | "delete": true, "imag": true, "len": true, "make": true, "new": true, 40 | "panic": true, "print": true, "println": true, "real": true, 41 | "recover": true, 42 | } 43 | 44 | func (n *nameMapper) Reserve(names ...string) { 45 | if n.typToName == nil { 46 | n.typToName = map[reflect.Type]string{} 47 | n.used = map[string]bool{} 48 | } 49 | for _, name := range 
names { 50 | n.used[name] = true 51 | } 52 | } 53 | 54 | func (n *nameMapper) For(t reflect.Type) string { 55 | if name, exists := n.typToName[t]; exists { 56 | return name 57 | } 58 | if n.typToName == nil { 59 | n.Reserve() 60 | } 61 | for _, name := range n.options(t) { 62 | if !disallowed[name] && !n.used[name] { 63 | n.used[name] = true 64 | n.typToName[t] = name 65 | return name 66 | } 67 | } 68 | // This should never happen: The final option should be a completely unique 69 | // name using the full package name and type name. 70 | panic(fmt.Errorf("Could not come up with a unique variable name for %s. "+ 71 | "Used names are %v.\nTyp2Name: %q", t, n.used, n.typToName)) 72 | } 73 | 74 | func extractCaps(s string) string { 75 | caps := "" 76 | for i, r := range s { 77 | if r == '_' { 78 | return "" // If there are any underscores in the name, give up on extracting caps. 79 | } 80 | if !(unicode.IsLetter(r) || unicode.IsNumber(r)) { 81 | continue 82 | } 83 | if i == 0 || unicode.IsUpper(r) || unicode.IsNumber(r) { 84 | caps += string(r) 85 | } 86 | } 87 | if len(caps) == 0 { 88 | caps += string(s[0]) 89 | } 90 | return strings.ToLower(caps) 91 | } 92 | 93 | func upperFirstLetter(s string) string { 94 | for i, r := range s { 95 | return string(unicode.ToUpper(r)) + s[i+1:] 96 | } 97 | return "" 98 | } 99 | 100 | func lowerFirstLetter(s string) string { 101 | for i, r := range s { 102 | return string(unicode.ToLower(r)) + s[i+1:] 103 | } 104 | return "" 105 | } 106 | 107 | func ptrPrefix(t reflect.Type) (string, reflect.Type) { 108 | var s = "" 109 | for t.Kind() == reflect.Ptr || t.Kind() == reflect.Slice { 110 | if t.Kind() == reflect.Ptr { 111 | s = "p" + s 112 | } else if t.Kind() == reflect.Slice { 113 | s = "sliceOf" + s 114 | } 115 | t = t.Elem() 116 | } 117 | return s, t 118 | } 119 | func assemble(prefix, name string) string { 120 | if prefix == "" { 121 | return name 122 | } 123 | return prefix + upperFirstLetter(name) 124 | } 125 | 126 | func 
pkgNamePrefix(pkg string) string { 127 | if pkg == "" || pkg == "." { 128 | return "" 129 | } 130 | pkg = strings.ToLower(strings.Replace(pkg, ".", "_", -1)) 131 | pkg = strings.Trim(pkg, "_") 132 | return pkg + "_" 133 | } 134 | 135 | func cleanTypeName(name string) string { 136 | name = strings.Replace(name, "chan ", "chan_", -1) 137 | name = strings.Replace(name, " ", "_", -1) 138 | name = strings.Map(func(r rune) rune { 139 | if !unicode.IsLetter(r) && !unicode.IsNumber(r) { 140 | return '_' 141 | } 142 | return r 143 | }, name) 144 | name = regexp.MustCompile("_{2,}").ReplaceAllLiteralString(name, "_") 145 | name = strings.Trim(name, "_") 146 | return name 147 | } 148 | 149 | func (n nameMapper) options(t reflect.Type) []string { 150 | options := wellKnownTypesAndCommonNames[t] 151 | prefix, t := ptrPrefix(t) 152 | short_pkg := pkgNamePrefix(filepath.Base(t.PkgPath())) 153 | full_pkg := pkgNamePrefix(t.PkgPath()) 154 | 155 | name := cleanTypeName(t.Name()) 156 | if name == "" { 157 | name = cleanTypeName(t.String()) 158 | short_pkg = "" 159 | full_pkg = "" 160 | } 161 | if name != "" { 162 | lname := lowerFirstLetter(name) 163 | options = append(options, assemble(prefix, lname)) 164 | options = append(options, assemble(short_pkg+prefix, lname)) 165 | options = append(options, assemble(prefix, lname)) 166 | caps := extractCaps(name) 167 | if caps != "" { 168 | options = append(options, assemble(prefix, caps)) 169 | options = append(options, assemble(prefix, string(caps[0]))) 170 | } 171 | 172 | options = append(options, assemble(full_pkg+prefix, lname)) 173 | } 174 | // uniqueVarNameForType should always make something completely unique, but 175 | // it's a bit verbose. 
176 | // options = append(options, uniqueVarNameForType(t)) 177 | // As a completely paranoid option, this should absolutely positively make 178 | // a unique var name: 179 | options = append(options, fmt.Sprintf("__var%d__", len(n.used))) 180 | return options 181 | } 182 | 183 | var wellKnownTypesAndCommonNames = map[reflect.Type][]string{ 184 | reflect.TypeOf((*http.ResponseWriter)(nil)).Elem(): {"rw", "w"}, 185 | reflect.TypeOf((*http.Request)(nil)): {"req", "r"}, 186 | reflect.TypeOf(""): {"str"}, 187 | reflect.TypeOf(false): {"flag"}, 188 | errorType: {"err"}, 189 | } 190 | -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | // Package sandwich is a middleware framework for Go that lets you write 2 | // testable web servers. 3 | // 4 | // Sandwich allows writing robust middleware handlers that are easily tested: 5 | // - Avoid globals; instead, propagate per-request state automatically from 6 | // one handler to the next. 7 | // - Write your handlers to accept the parameters they need rather than 8 | // type-asserting from an untyped per-request context. 9 | // - Abort request handling by returning an error. 10 | // 11 | // Sandwich also provides a basic PAT-style router.
12 | // 13 | // # Example 14 | // 15 | // Here's a simple complete program using sandwich: 16 | // 17 | // package main 18 | // 19 | // import ( 20 | // "fmt" 21 | // "log" 22 | // "net/http" 23 | // 24 | // "github.com/augustoroman/sandwich" 25 | // ) 26 | // 27 | // func main() { 28 | // mux := sandwich.TheUsual() 29 | // mux.Get("/", func(w http.ResponseWriter) { 30 | // fmt.Fprintf(w, "Hello world!") 31 | // }) 32 | // if err := http.ListenAndServe(":6060", mux); err != nil { 33 | // log.Fatal(err) 34 | // } 35 | // } 36 | // 37 | // # Providing 38 | // 39 | // Sandwich automatically calls your middleware with the necessary arguments to 40 | // run them based on the types they require. These types can be provided by 41 | // previous middleware or directly during the initial setup. 42 | // 43 | // For example, you can use this to provide your database to all handlers: 44 | // 45 | // func main() { 46 | // db_conn := ConnectToDatabase(...) 47 | // mux := sandwich.TheUsual() 48 | // mux.Set(db_conn) 49 | // mux.Get("/", Home) 50 | // } 51 | // 52 | // func Home(w http.ResponseWriter, r *http.Request, db_conn *Database) { 53 | // // process the request here, using the provided db_conn 54 | // } 55 | // 56 | // Set(...) and SetAs(...) are excellent alternatives to using global values, 57 | // plus they keep your functions easy to test! 58 | // 59 | // # Handlers 60 | // 61 | // In many cases you want to initialize a value based on the request, for 62 | // example extracting the user login: 63 | // 64 | // func main() { 65 | // mux := sandwich.TheUsual() 66 | // mux.Get("/", ParseUserCookie, SayHi) 67 | // } 68 | // // You can write & test exactly this signature: 69 | // func ParseUserCookie(r *http.Request) (User, error) { ... } 70 | // // Then write your handler assuming User is available: 71 | // func SayHi(w http.ResponseWriter, u User) { 72 | // fmt.Fprintf(w, "Hello %s", u.Name) 73 | // } 74 | // 75 | // This starts to show off the real power of sandwich. 
For each request, the 76 | // following occurs: 77 | // - First ParseUserCookie is called. If it returns a non-nil error, 78 | // sandwich's HandleError is called and the request is aborted. If the error 79 | // is nil, processing continues. 80 | // - Next SayHi is called with the User value returned from ParseUserCookie. 81 | // 82 | // This allows you to write small, independently testable functions and let 83 | // sandwich chain them together for you. Sandwich works hard to ensure that you 84 | // don't get annoying run-time errors: it's structured such that it must always 85 | // be possible to call your functions when the middleware is initialized rather 86 | // than when the http handler is being executed, so you don't get surprised 87 | // while your server is running. 88 | // 89 | // # Error Handlers 90 | // 91 | // When a handler returns an error, sandwich aborts the middleware chain and 92 | // looks for the most recently registered error handler and calls that. Error 93 | // handlers may accept any types that have been provided so far in the 94 | // middleware stack as well as the error type. They must not have any return 95 | // values. 96 | // 97 | // # Wrapping Handlers 98 | // 99 | // Sandwich also allows registering handlers to run during AND after the 100 | // middleware (and error handling) stack has completed. This is especially 101 | // useful for handlers such as logging or gzip wrappers. Once the 'before' handler 102 | // is run, the 'after' handlers are queued to run and will be run regardless of 103 | // whether an error aborts any subsequent middleware handlers. 104 | // 105 | // Typically this is done with the first function creating and initializing some 106 | // state to pass to the deferred handler. For example, the logging handlers are: 107 | // 108 | // // NewLogEntry creates a *LogEntry and initializes it with basic request 109 | // // information.
110 | // func NewLogEntry(r *http.Request) *LogEntry { 111 | // return &LogEntry{Start: time.Now(), ...} 112 | // } 113 | // 114 | // // Commit fills in the remaining *LogEntry fields and writes the entry out. 115 | // func (entry *LogEntry) Commit(w *ResponseWriter) { 116 | // entry.Elapsed = time.Since(entry.Start) 117 | // ... 118 | // WriteLog(*entry) 119 | // } 120 | // 121 | // and are added to the chain using: 122 | // 123 | // var LogRequests = Wrap{NewLogEntry, (*LogEntry).Commit} 124 | // 125 | // In this case, the `Wrap` runs NewLogEntry during middleware processing; the 126 | // *LogEntry it returns is provided to downstream handlers, including 127 | // the deferred Commit handler -- in this case a method expression 128 | // (https://golang.org/ref/spec#Method_expressions) that takes the *LogEntry 129 | // receiver as its first argument. 130 | // 131 | // # Providing Interfaces 132 | // 133 | // Unfortunately, providing interfaces is a little tricky. Since interfaces in 134 | // Go are only used for static typing, the interface type is lost when a value is passed to 135 | // functions that accept interface{}, like Set(). 136 | // 137 | // This means that if you have an interface and a concrete implementation, such 138 | // as: 139 | // 140 | // type UserDatabase interface{ 141 | // GetUserProfile(u User) (Profile, error) 142 | // } 143 | // type userDbImpl struct { ... } 144 | // func (db *userDbImpl) GetUserProfile(u User) (Profile, error) { ... } 145 | // 146 | // You cannot provide this to handlers directly via the Set() call.
147 | // 148 | // udb := &userDbImpl{...} 149 | // // DOESN'T WORK: this will provide *userDbImpl, not UserDatabase 150 | // mux.Set(udb) 151 | // // STILL DOESN'T WORK 152 | // mux.Set((UserDatabase)(udb)) 153 | // // *STILL* DOESN'T WORK 154 | // udb_iface := UserDatabase(udb) 155 | // mux.Set(&udb_iface) 156 | // 157 | // Instead, you have to either use SetAs() or a dedicated middleware function: 158 | // 159 | // udb := &userDbImpl{...} 160 | // mux.SetAs(udb, (*UserDatabase)(nil)) // either use SetAs() with a pointer to the interface 161 | // mux.Use(func() UserDatabase { return udb }) // or add a handler that returns the interface 162 | // 163 | // It's a bit silly, but that's how it is. 164 | package sandwich 165 | -------------------------------------------------------------------------------- /router_test.go: -------------------------------------------------------------------------------- 1 | package sandwich 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "net/http" 7 | "net/http/httptest" 8 | "strings" 9 | "testing" 10 | 11 | "github.com/stretchr/testify/assert" 12 | "github.com/stretchr/testify/require" 13 | ) 14 | 15 | // Shorthand for keeping tests concise below 16 | type M = Params 17 | 18 | func TestMuxRegisterAndMatch(t *testing.T) { 19 | const REGISTRATION_ERROR = "•ERR:" 20 | fail := func(reason string) string { 21 | return REGISTRATION_ERROR + reason 22 | } 23 | split := func(combined_pattern string) (pattern, errmsg string) { 24 | pos := strings.Index(combined_pattern, REGISTRATION_ERROR) 25 | if pos == -1 { 26 | return combined_pattern, "" 27 | } 28 | return combined_pattern[:pos], combined_pattern[pos+len(REGISTRATION_ERROR):] 29 | } 30 | patterns := []string{ 31 | "/", 32 | "/a", 33 | "/a" + fail("repeated entry"), 34 | "/a/:x/:x" + fail("repeated param name"), 35 | "/a/", 36 | "/a/b", 37 | "/a/b/c", 38 | "/a/b/c" + fail("repeated entry"), 39 | "/a/b/c/d/e", // NOTE: /a/b/c/d not registered 40 | "/a/:x/c", 41 | "/a/:x/c" + fail("repeated 
entry"), 42 | "/a/:y/c" + fail("ambiguous param var"), 43 | "/a/:y/c2", 44 | "/a/:m*", 45 | "/a/:m*/", 46 | "/b/:a*/x", 47 | "/b/:b*/y", 48 | "/b/:b*/x" + fail("ambiguous greedy pattern"), 49 | "/c/:x/y", 50 | "/c/:x*/y" + fail("ambiguous param (greedy or not)"), 51 | "/:m*/b/c", 52 | "/:m*/:x/c", 53 | "/:m*/:x*/c" + fail("multiple greedy patterns"), 54 | "/:x*/b/c" + fail("ambiguous greedy var"), 55 | "/x/:x*/y/:y/z/:z*/blah" + fail("multiple greedy patterns"), 56 | 57 | // literal colon in static URL 58 | "/a/::x", 59 | "/a/::x/c", 60 | } 61 | 62 | var m mux 63 | 64 | for _, combo_pattern := range patterns { 65 | pattern, errmsg := split(combo_pattern) 66 | err := m.Register(pattern, noopHandler(pattern)) 67 | if errmsg == "" { 68 | require.NoError(t, err) 69 | } else { 70 | require.Error(t, err, "Pattern %#q should have failed: %s", pattern, errmsg) 71 | } 72 | } 73 | 74 | // priority: 75 | // - static routes 76 | // - explicit parameter 77 | // - greedy parameter 78 | 79 | testCases := []struct { 80 | uri string 81 | expectedHandler noopHandler 82 | expectedParams M 83 | }{ 84 | {"/", "/", M{}}, 85 | {"/a", "/a", M{}}, 86 | {"/a/", "/a/", M{}}, 87 | {"/a/b", "/a/b", M{}}, 88 | {"/a/b/c", "/a/b/c", M{}}, 89 | {"/a/b/c/d/e", "/a/b/c/d/e", M{}}, 90 | {"/a/b/c/d", "/a/:m*", M{"m": "b/c/d"}}, 91 | 92 | {"/a/foobar/c", "/a/:x/c", M{"x": "foobar"}}, 93 | {"/a/foobar/c2", "/a/:y/c2", M{"y": "foobar"}}, 94 | 95 | {"/a/foobar/blah", "/a/:m*", M{"m": "foobar/blah"}}, 96 | {"/a/foobar/blah/", "/a/:m*/", M{"m": "foobar/blah"}}, 97 | 98 | {"/b/mm/nn/", "", nil}, 99 | {"/b/mm/nn/x", "/b/:a*/x", M{"a": "mm/nn"}}, 100 | {"/b/mm/nn/y", "/b/:b*/y", M{"b": "mm/nn"}}, 101 | 102 | {"/c/x/y", "/c/:x/y", M{"x": "x"}}, 103 | 104 | {"/b/x/y/b/c", "/:m*/b/c", M{"m": "b/x/y"}}, 105 | {"/b/x/y/bo/c", "/:m*/:x/c", M{"m": "b/x/y", "x": "bo"}}, 106 | } 107 | 108 | for _, test := range testCases { 109 | t.Run(fmt.Sprintf("%s -> %s", test.uri, test.expectedHandler), func(t *testing.T) { 110 | 
if test.expectedHandler == "" { 111 | t.Logf("Testing input uri %#q --> should not match any pattern", 112 | test.uri) 113 | } else { 114 | t.Logf("Testing input uri %#q --> should match pattern %#q", 115 | test.uri, test.expectedHandler) 116 | } 117 | params := Params{} 118 | selected := m.Match(test.uri, params) 119 | if test.expectedHandler == "" { 120 | assert.Nil(t, selected, "should not match any pattern") 121 | assert.Empty(t, params) 122 | } else { 123 | require.NotNil(t, selected) 124 | assert.Equal(t, test.expectedHandler, selected) 125 | assert.Equal(t, test.expectedParams, params) 126 | } 127 | }) 128 | } 129 | } 130 | 131 | type noopHandler string 132 | 133 | func (h noopHandler) ServeHTTP(w http.ResponseWriter, r *http.Request, p Params) {} 134 | 135 | func TestRouter(t *testing.T) { 136 | r := TheUsual() 137 | 138 | type UserID string 139 | type User string 140 | type UserDB map[UserID]User 141 | 142 | theUserDB := UserDB{"1": "bob", "2": "alice"} 143 | r.Set(theUserDB) 144 | 145 | loadUser := func(db UserDB, p Params) (User, UserID, error) { 146 | if uid := UserID(p["userID"]); uid == "" { 147 | return "", "", Error{Code: 400, ClientMsg: "Must specify user ID"} 148 | } else if u := db[uid]; u == "" { 149 | return "", uid, Error{Code: 404, ClientMsg: "No such user"} 150 | } else { 151 | return u, uid, nil 152 | } 153 | } 154 | newUserFromRequest := func(r *http.Request) (UserID, User, error) { 155 | uid := UserID(r.FormValue("uid")) 156 | user := User(r.FormValue("name")) 157 | if uid == "" { 158 | return "", "", errors.New("missing user id") 159 | } else if user == "" { 160 | return "", "", errors.New("missing user info") 161 | } 162 | return uid, user, nil 163 | } 164 | 165 | r.Get("/user/:userID", loadUser, 166 | func(w http.ResponseWriter, u User) { 167 | fmt.Fprintf(w, "Hi user %#q", u) 168 | }, 169 | ) 170 | r.Post("/user/", newUserFromRequest, 171 | func(db UserDB, uid UserID, u User) { db[uid] = u }, 172 | func(w http.ResponseWriter, uid 
UserID, u User) { 173 | fmt.Fprintf(w, "Made user %#q = %#q", uid, u) 174 | }, 175 | ) 176 | r.Any("/user/:userID/:cmd*", loadUser, 177 | func(w http.ResponseWriter, r *http.Request, p Params, u User) { 178 | fmt.Fprintf(w, "Doing %#q (%s) to user %#q", r.Method, p["cmd"], u) 179 | }, 180 | ) 181 | 182 | w := httptest.NewRecorder() 183 | r.ServeHTTP(w, httptest.NewRequest("GET", "/user/1", nil)) 184 | assert.Equal(t, http.StatusOK, w.Result().StatusCode) 185 | assert.Equal(t, "Hi user `bob`", w.Body.String()) 186 | 187 | w = httptest.NewRecorder() 188 | r.ServeHTTP(w, httptest.NewRequest("GET", "/user/2", nil)) 189 | assert.Equal(t, http.StatusOK, w.Result().StatusCode) 190 | assert.Equal(t, "Hi user `alice`", w.Body.String()) 191 | 192 | w = httptest.NewRecorder() 193 | r.ServeHTTP(w, httptest.NewRequest("GET", "/user/3", nil)) 194 | assert.Equal(t, http.StatusNotFound, w.Result().StatusCode) 195 | assert.Equal(t, "No such user\n", w.Body.String()) 196 | 197 | w = httptest.NewRecorder() 198 | r.ServeHTTP(w, httptest.NewRequest("POST", "/user/?uid=3&name=sid", nil)) 199 | assert.Equal(t, http.StatusOK, w.Result().StatusCode, "Response: %s", w.Body.String()) 200 | 201 | w = httptest.NewRecorder() 202 | r.ServeHTTP(w, httptest.NewRequest("GET", "/user/3", nil)) 203 | assert.Equal(t, http.StatusOK, w.Result().StatusCode) 204 | assert.Equal(t, "Hi user `sid`", w.Body.String()) 205 | 206 | w = httptest.NewRecorder() 207 | r.ServeHTTP(w, httptest.NewRequest("EXPLODE", "/user/3/boom", nil)) 208 | assert.Equal(t, http.StatusOK, w.Result().StatusCode) 209 | assert.Equal(t, "Doing `EXPLODE` (boom) to user `sid`", w.Body.String()) 210 | } 211 | 212 | // func TestNodeMatch(t *testing.T) { 213 | // testCases := []struct { 214 | // path, pattern string 215 | // matches bool 216 | // expectedParams M 217 | // }{ 218 | // // static paths 219 | // {"/a/b/c", "/a/b/c", true, M{}}, 220 | // {"/a/b/c", "/x/b/c", false, nil}, 221 | // {"/a/b/c", "/a/b", false, nil}, 222 | // {"/a/b/c", 
"/a/b/c/d", false, nil}, 223 | 224 | // {"/a/b/c", "/a/b/:last", true, M{"last": "c"}}, 225 | // {"/a/b/c", "/a/:mid/c", true, M{"mid": "b"}}, 226 | // {"/a/b/c", "/:first/:mid/c", true, M{"first": "a", "mid": "b"}}, 227 | // {"/a/b/c", "/:first/:mid/:last", true, M{"first": "a", "mid": "b", "last": "c"}}, 228 | // {"/a/b/c", "/:first/:mid/:last/:missing", false, nil}, 229 | // {"/a/b/c", "/:first/:mid/:last/x", false, nil}, 230 | 231 | // {"/a/b/c/d/e/f/g", "/:path*", true, M{"path": "a/b/c/d/e/f/g"}}, 232 | // {"/a/b/c/d/e/f/g", "/a/:path*", true, M{"path": "b/c/d/e/f/g"}}, 233 | // {"/a/b/c/d/e/f/g", "/a/:path*/g", true, M{"path": "b/c/d/e/f"}}, 234 | // {"/a/b/c/d/e/f/g", "/a/:path*/f/g", true, M{"path": "b/c/d/e"}}, 235 | // {"/a/b/c/d/e/f/g", "/a/:path*/:last", true, M{"path": "b/c/d/e/f", "last": "g"}}, 236 | // {"/a/b/c/d/e/f/g", "/a/:first/:mid*/:last", true, M{"first": "b", "mid": "c/d/e/f", "last": "g"}}, 237 | // {"/a/b/c/d/e/f/g", "/a/:first*/:mid/:last", true, M{"first": "b/c/d/e", "mid": "f", "last": "g"}}, 238 | // {"/a/b/c/d/e/f/g", "/a/:first/:mid/:last*", true, M{"first": "b", "mid": "c", "last": "d/e/f/g"}}, 239 | // } 240 | 241 | // for i, test := range testCases { 242 | // t.Run(fmt.Sprintf("%d:%s", i, test.pattern), func(t *testing.T) { 243 | // t.Logf("Test %d: Path %#q Pattern %#q Should match: %v", 244 | // i, test.path, test.pattern, test.matches) 245 | // root, err := makeMatchNodes(test.pattern) 246 | // require.NoError(t, err) 247 | // pathSegments := strings.Split(test.path, "/") 248 | // params := Params{} 249 | // match := root.match(pathSegments[1:], params) 250 | // if test.matches { 251 | // require.NotNil(t, match) 252 | // assert.Equal(t, test.expectedParams, params) 253 | // } else { 254 | // assert.Nil(t, match) 255 | // } 256 | // }) 257 | // } 258 | // } 259 | -------------------------------------------------------------------------------- /examples/2-advanced/main.go: 
-------------------------------------------------------------------------------- 1 | // 2-advanced is a demo webserver for the sandwich middleware package 2 | // demonstrating advanced usage. 3 | // 4 | // This provides a sample, multi-user TODO list application that allows users 5 | // to sign in via a Google account or sign in using fake credentials. 6 | // This example demonstrates more advanced features of sandwich, including: 7 | // 8 | // - Providing interface types to the middleware chain. 9 | // TaskDb is the interface provided to the handlers, the actual value injected 10 | // in main() is a taskDbImpl. 11 | // - Using 3rd party middleware (go.auth) and embedded static files (embed) 12 | // - Using sandwich's router, including a sub-router with its own middleware 13 | // - Using multiple error handlers, and custom error handlers. 14 | // Most web servers will want to serve a custom HTML error page for user-facing 15 | // errors. An example of that is included here. For AJAX calls, however, 16 | // errors are reported as JSON via sandwich.HandleErrorJson. 17 | // - Early exit of the middleware chain via the sandwich.Done error 18 | // - Auto-generating handler code 19 | package main 20 | 21 | import ( 22 | "embed" 23 | "encoding/json" 24 | "fmt" 25 | "html/template" 26 | "io/fs" 27 | "log" 28 | "net/http" 29 | "os" 30 | "time" 31 | 32 | "github.com/augustoroman/sandwich" 33 | auth "github.com/bradrydzewski/go.auth" 34 | ) 35 | 36 | //go:embed static 37 | var static embed.FS 38 | 39 | func main() { 40 | // Read in configuration: 41 | var config struct { 42 | Host string `json:"host"` 43 | Port int `json:"port"` 44 | CookieSecret string `json:"cookie-secret"` 45 | ClientId string `json:"oauth2-client-id"` 46 | ClientSecret string `json:"oauth2-client-secret"` 47 | } 48 | failOnError(readJsonFile("config.json", &config)) 49 | 50 | // Set up OAuth login framework: 51 | auth.Config.LoginRedirect = "/auth/login" // send user here to login 52 | auth.Config.LoginSuccessRedirect = "/" // send user here post-login 53 | auth.Config.CookieSecure = false // for
local-testing only 54 | auth.Config.CookieSecret = []byte(config.CookieSecret) 55 | // This must match the authorized URLs entered in the google cloud api console. 56 | redirectUrl := fmt.Sprintf("http://%s/auth/google/callback", config.Host) 57 | authHandler := auth.Google(config.ClientId, config.ClientSecret, redirectUrl) 58 | 59 | // Setup task database: 60 | taskDb := taskDbImpl{} 61 | 62 | // Load our templates. 63 | tpl := template.Must(template.ParseFS(static, "static/*.tpl.html")) 64 | 65 | // Start setting up our server: 66 | mux := sandwich.TheUsual() 67 | mux.Use(ParseUserCookie, LogUser) 68 | mux.SetAs(taskDb, (*TaskDb)(nil)) 69 | mux.Set(tpl) 70 | mux.OnErr(CustomErrorPage) 71 | 72 | // Don't log these requests since we don't have a favicon, it's just a 73 | // bunch of 404 spam. 74 | mux.Get("/favicon.ico", sandwich.NoLog, NotFound) 75 | 76 | // When login is called, we'll FIRST call our very own CheckForFakeLogin 77 | // handler. If we detect the fake login form params, we'll process that 78 | // and then abort the middleware chain. 79 | // However, if we don't have fake parameters, we'll continue on and let 80 | // the authHandler take care of things. 81 | mux.Any("/auth/login", CheckForFakeLogin, authHandler) 82 | mux.Any("/auth/google/callback", authHandler) 83 | // Note that we can use auth.DeleteUserCookie directly. 84 | mux.Any("/auth/logout", auth.DeleteUserCookie, 85 | http.RedirectHandler("/", http.StatusTemporaryRedirect)) 86 | 87 | // Static file handling. The s.Then(...) wrapper isn't strictly necessary, 88 | // but it gives us logging (and potentially gzip or other middleware). 
89 | // static := http.StripPrefix("/static", http.FileServer(http.FS( 90 | // mustSubFS(static, "static")))) 91 | mux.Get("/static/:path*", sandwich.ServeFS(static, "static", "path")) 92 | 93 | // OK, here are the core handlers: 94 | mux.Get("/", Home) 95 | // All API calls will use the api middleware that responds with JSON for 96 | // errors and requires users to be logged in. 97 | api := mux.SubRouter("/api/") 98 | api.OnErr(sandwich.HandleErrorJson) 99 | api.Use(RequireLoggedIn) 100 | api.Post("/task", TaskFromAddRequest, TaskDb.Add, SendTaskAsJson) 101 | api.Post("/task/:id", TaskOpFromUpdateRequest, UpdateTask) 102 | 103 | // Catch all remaining URLs and respond with not-found errors. We 104 | // explicitly use the error-return mechanism so that we get the JSON 105 | // response under /api/ and normal HTML responses elsewhere. 106 | api.Any("/:*", NotFound) 107 | mux.Any("/:*", NotFound) 108 | 109 | // Otherwise, start serving! 110 | addr := fmt.Sprintf("localhost:%d", config.Port) 111 | log.Printf("Server listening on http://%s", addr) 112 | failOnError(http.ListenAndServe(addr, mux)) 113 | } 114 | 115 | // ============================================================================ 116 | // Database 117 | 118 | type UserId string 119 | type TaskDb interface { 120 | List(UserId) ([]Task, error) 121 | Add(UserId, *Task) error 122 | Update(UserId, Task) error 123 | } 124 | type Task struct { 125 | Id string `json:"id"` 126 | Desc string `json:"desc"` 127 | Done bool `json:"done"` 128 | } 129 | 130 | type taskDbImpl map[UserId][]Task 131 | 132 | func (db taskDbImpl) List(u UserId) ([]Task, error) { return db[u], nil } 133 | func (db taskDbImpl) Add(u UserId, t *Task) error { 134 | t.Id = fmt.Sprint(time.Now().UnixNano()) 135 | db[u] = append(db[u], *t) 136 | return nil 137 | } 138 | func (db taskDbImpl) Update(u UserId, t Task) error { 139 | tasks := db[u] 140 | for i, task := range tasks { 141 | if task.Id == t.Id { 142 | tasks[i] = t 143 | return nil 144 | } 
145 | } 146 | return sandwich.Error{ 147 | Code: http.StatusBadRequest, 148 | ClientMsg: "No such task", 149 | Cause: fmt.Errorf("No such task: %q", t.Id), 150 | } 151 | } 152 | 153 | // ============================================================================ 154 | // Core Handlers 155 | 156 | func Home( 157 | w http.ResponseWriter, 158 | r *http.Request, 159 | uid UserId, 160 | u auth.User, 161 | db TaskDb, 162 | tpl *template.Template, 163 | ) error { 164 | if u == nil { 165 | // tpl := template.Must(template.New("").Parse( 166 | // rice.MustFindBox("static").MustString("landing-page.tpl.html"))) 167 | return tpl.ExecuteTemplate(w, "landing-page.tpl.html", nil) 168 | } 169 | tasks, err := db.List(uid) 170 | if err != nil { 171 | return err 172 | } 173 | // tpl := template.Must(template.New("").Parse( 174 | // rice.MustFindBox("static").MustString("home.tpl.html"))) 175 | return tpl.ExecuteTemplate(w, "home.tpl.html", map[string]interface{}{ 176 | "User": u, 177 | "Tasks": tasks, 178 | }) 179 | } 180 | 181 | func TaskFromAddRequest(r *http.Request) (*Task, error) { 182 | var t Task 183 | if err := json.NewDecoder(r.Body).Decode(&t); err != nil { 184 | return nil, sandwich.Error{Code: http.StatusBadRequest, Cause: err} 185 | } 186 | if t.Desc == "" { 187 | return nil, sandwich.Error{Code: http.StatusBadRequest, 188 | ClientMsg: "Please include a task description"} 189 | } 190 | return &t, nil 191 | } 192 | 193 | func SendTaskAsJson(w http.ResponseWriter, t *Task) error { 194 | return json.NewEncoder(w).Encode(map[string]interface{}{"task": t}) 195 | } 196 | 197 | type TaskOp struct { 198 | Toggle bool 199 | Id string 200 | } 201 | 202 | func TaskOpFromUpdateRequest(r *http.Request) (TaskOp, error) { 203 | var op TaskOp 204 | if err := json.NewDecoder(r.Body).Decode(&op); err != nil { 205 | return op, sandwich.Error{Code: http.StatusBadRequest, Cause: err} 206 | } 207 | if op.Id == "" { 208 | return op, sandwich.Error{ 209 | Code: http.StatusBadRequest, 210 | 
ClientMsg: "Invalid op: missing task id", 211 | } 212 | } 213 | return op, nil 214 | } 215 | 216 | func UpdateTask(w http.ResponseWriter, r *http.Request, uid UserId, op TaskOp, db TaskDb) error { 217 | tasks, err := db.List(uid) 218 | if err != nil { 219 | return err 220 | } 221 | var t Task 222 | for i := range tasks { 223 | if tasks[i].Id == op.Id { 224 | t = tasks[i] 225 | break 226 | } 227 | } 228 | if t.Id == "" { 229 | return sandwich.Error{ 230 | Code: http.StatusBadRequest, 231 | ClientMsg: "No such task id: " + op.Id, 232 | } 233 | } 234 | 235 | if op.Toggle { 236 | t.Done = !t.Done 237 | } 238 | 239 | if err := db.Update(uid, t); err != nil { 240 | return err 241 | } 242 | return json.NewEncoder(w).Encode(map[string]interface{}{"task": t}) 243 | } 244 | 245 | // This will get called for any error that occurs outside of the API calls. 246 | func CustomErrorPage( 247 | w http.ResponseWriter, 248 | r *http.Request, 249 | err error, 250 | tpl *template.Template, 251 | l *sandwich.LogEntry, 252 | ) { 253 | // Make sure we actually have a real error: 254 | if err == sandwich.Done { 255 | return 256 | } 257 | // Convert the error to a sandwich.Error that has an error code. 258 | e := sandwich.ToError(err) 259 | // Always log the error and error details. 260 | l.Error = e 261 | 262 | w.WriteHeader(e.Code) 263 | err = tpl.ExecuteTemplate(w, "error.tpl.html", map[string]interface{}{ 264 | "Error": e, 265 | }) 266 | 267 | // But... what if our fancy template rendering fails? At this point, we 268 | // fall back to the simplest possible thing: http.Error(...). Maybe it'll 269 | // work, but we'll also log the error so it doesn't disappear. 
270 | if err != nil { 271 | // Try putting a typo in the template name above, and you'll see this: 272 | l.Error = fmt.Errorf("Failed to render error page: %v\nTriggering error: %v", 273 | err, e) 274 | http.Error(w, "Internal server error", http.StatusInternalServerError) 275 | } 276 | } 277 | 278 | func CheckForFakeLogin(w http.ResponseWriter, r *http.Request) error { 279 | if r.FormValue("id") == "" { 280 | return nil 281 | } 282 | 283 | user := &auth.GoogleUser{ 284 | UserId: r.FormValue("id"), 285 | UserEmail: r.FormValue("email"), 286 | UserName: r.FormValue("name"), 287 | } 288 | auth.SetUserCookie(w, r, user) 289 | http.Redirect(w, r, "/", http.StatusTemporaryRedirect) 290 | 291 | // Great, everything is handled, so don't continue with the Google auth. 292 | return sandwich.Done 293 | } 294 | 295 | // ============================================================================ 296 | // Basic Handlers 297 | 298 | // You could also use .Then(http.NotFound), but that wouldn't go through the 299 | // error-handlers. The advantage of using the error handlers is that you 300 | // automatically get JSON vs HTML handling. 301 | func NotFound() error { return sandwich.Error{Code: http.StatusNotFound} } 302 | 303 | func RequireLoggedIn(u auth.User) error { 304 | if u == nil { 305 | return sandwich.Error{Code: http.StatusUnauthorized} 306 | } 307 | return nil 308 | } 309 | func ParseUserCookie(r *http.Request) (auth.User, UserId) { 310 | // Ignore errors. If the cookie is invalid or expired or corrupt or 311 | // missing, just consider the user not-logged-in. 312 | u, _ := auth.GetUserCookie(r) 313 | var uid UserId 314 | if u != nil { 315 | uid = UserId(u.Id()) 316 | } 317 | return u, uid 318 | } 319 | 320 | // Adds the current user to the per-request log notes, if logged in. 
321 | func LogUser(u auth.User, e *sandwich.LogEntry) { 322 | if u == nil { 323 | e.Note["user"] = "" 324 | return 325 | } 326 | e.Note["user"] = u.Email() 327 | e.Note["userId"] = u.Id() 328 | } 329 | 330 | // ============================================================================ 331 | // Simple utilities 332 | 333 | // readJsonFile opens filename and decodes the first JSON value it contains into dst. 334 | func readJsonFile(filename string, dst interface{}) error { 335 | file, openErr := os.Open(filename) 336 | if openErr != nil { 337 | return openErr 338 | } 339 | defer file.Close() 340 | return json.NewDecoder(file).Decode(dst) 341 | } 342 | 343 | // failOnError panics with err when it is non-nil; a nil err is a no-op. 344 | func failOnError(err error) { 345 | if err == nil { 346 | return 347 | } 348 | panic(err) 349 | } 350 | 351 | // mustSubFS returns the sub-filesystem of base rooted at dir, panicking if fs.Sub fails. 352 | func mustSubFS(base fs.FS, dir string) fs.FS { 353 | subtree, subErr := fs.Sub(base, dir) 354 | if subErr != nil { 355 | panic(subErr) 356 | } 357 | return subtree 358 | } 359 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Sandwich: Delicious HTTP Middleware [![Build Status](https://travis-ci.org/augustoroman/sandwich.svg?branch=master)](https://travis-ci.org/augustoroman/sandwich) [![Coverage](https://gocover.io/_badge/github.com/augustoroman/sandwich?1)](https://gocover.io/github.com/augustoroman/sandwich) [![Go Report Card](https://goreportcard.com/badge/github.com/augustoroman/sandwich)](https://goreportcard.com/report/github.com/augustoroman/sandwich) [![GoDoc](https://pkg.go.dev/badge/github.com/augustoroman/sandwich)](https://pkg.go.dev/github.com/augustoroman/sandwich) 4 | 5 | *Keep pilin' it on!* 6 | 7 | Sandwich is a middleware & routing framework that lets you write your handlers 8 | and middleware the way you want to and it takes care of tracking & validating 9 | dependencies. 10 | 11 | ## Features 12 | 13 | * Keeps middleware and handlers simple and testable. 14 | * Consolidates error handling.
15 | * Ensures that middleware dependencies are safely provided -- avoids unsafe 16 | casting from generic context objects. 17 | * Detects missing dependencies *during route construction* (before the server 18 | starts listening!), not when the route is actually called. 19 | * Provides clear and helpful error messages. 20 | * Compatible with the [http.Handler](https://pkg.go.dev/net/http#Handler) 21 | interface and lots of existing middleware. 22 | * Provides just a touch of magic: enough to make things easier, but not enough 23 | to induce a debugging nightmare. 24 | 25 | ## Getting started 26 | 27 | Here's a very simple example of using sandwich with the standard HTTP stack: 28 | 29 | ```go 30 | package main 31 | 32 | import ( 33 | "fmt" 34 | "log" 35 | "net/http" 36 | 37 | "github.com/augustoroman/sandwich" 38 | ) 39 | 40 | func main() { 41 | // Create a default sandwich middleware stack that includes logging and 42 | // a simple error handler. 43 | mux := sandwich.TheUsual() 44 | mux.Get("/", func(w http.ResponseWriter) { 45 | fmt.Fprintf(w, "Hello world!") 46 | }) 47 | if err := http.ListenAndServe(":6060", mux); err != nil { 48 | log.Fatal(err) 49 | } 50 | } 51 | ``` 52 | 53 | See the [examples directory](examples/) for: 54 | * A [hello-world sample](examples/0-helloworld) 55 | * A [basic usage](examples/1-simple) demo 56 | * A [TODO app](examples/2-advanced) showing advanced usage including custom 57 | error handling, embedded files, code generation, & login and authentication 58 | via oauth. 59 | 60 | ## Usage 61 | 62 | ### Providing 63 | 64 | Sandwich automatically calls your middleware with the necessary arguments to 65 | run them based on the types they require. These types can be provided by 66 | previous middleware or directly during the initial setup. 67 | 68 | For example, you can use this to provide your database to all handlers: 69 | 70 | ```go 71 | func main() { 72 | db_conn := ConnectToDatabase(...)
73 | mux := sandwich.TheUsual() 74 | mux.Set(db_conn) 75 | mux.Get("/", Home) 76 | } 77 | 78 | func Home(w http.ResponseWriter, r *http.Request, db_conn *Database) { 79 | // process the request here, using the provided db_conn 80 | } 81 | ``` 82 | 83 | Set(...) and SetAs(...) are excellent alternatives to using global 84 | values, plus they keep your functions easy to test! 85 | 86 | 87 | ### Handlers 88 | 89 | In many cases you want to initialize a value based on the request, for 90 | example extracting the user login: 91 | 92 | ```go 93 | func main() { 94 | mux := sandwich.TheUsual() 95 | mux.Get("/", ParseUserCookie, SayHi) 96 | } 97 | // You can write & test exactly this signature: 98 | func ParseUserCookie(r *http.Request) (User, error) { ... } 99 | // Then write your handler assuming User is available: 100 | func SayHi(w http.ResponseWriter, u User) { 101 | fmt.Fprintf(w, "Hello %s", u.Name) 102 | } 103 | ``` 104 | 105 | This starts to show off the real power of sandwich. For each request, the 106 | following occurs: 107 | 108 | * First `ParseUserCookie` is called. If it returns a non-nil error, 109 | sandwich's `HandleError` is called and the request is aborted. If the error 110 | is nil, processing continues. 111 | * Next `SayHi` is called with `User` returned from `ParseUserCookie`. 112 | 113 | This allows you to write small, independently testable functions and let 114 | sandwich chain them together for you. Sandwich works hard to ensure that 115 | you don't get annoying run-time errors: it's structured such that it must 116 | always be possible to call your functions when the middleware is initialized 117 | rather than when the http handler is being executed, so you don't get 118 | surprised while your server is running. 119 | 120 | 121 | ### Error Handlers 122 | 123 | When a handler returns an error, sandwich aborts the middleware chain and 124 | looks for the most recently registered error handler and calls that. 
125 | Error handlers may accept any types that have been provided so far in the 126 | middleware stack as well as the error type. They must not have any return 127 | values. 128 | 129 | Here's an example of rendering errors with a custom error page: 130 | 131 | ```go 132 | type ErrorPageTemplate struct{ *template.Template } 133 | func main() { 134 | tpl := template.Must(template.ParseFiles("path/to/my/error_page.tpl")) 135 | mux := sandwich.TheUsual() 136 | mux.Set(ErrorPageTemplate{tpl}) 137 | mux.OnErr(MyErrorHandler) 138 | ... 139 | } 140 | func MyErrorHandler(w http.ResponseWriter, t ErrorPageTemplate, l *sandwich.LogEntry, err error) { 141 | if err == sandwich.Done { // sandwich.Done can be returned to abort middleware. 142 | return // It indicates there was no actual error, so just return. 143 | } 144 | // Unwrap to a sandwich.Error that has Code, ClientMsg, and internal LogMsg. 145 | e := sandwich.ToError(err) 146 | // If there's an internal log message, add it to the request log. 147 | e.LogIfMsg(l) 148 | // Respond with my custom html error page, including the client-facing msg. 149 | w.WriteHeader(e.Code) 150 | t.Execute(w, map[string]string{"Msg": e.ClientMsg}) 151 | } 152 | ``` 153 | 154 | Error handlers allow you to consolidate the error handling of your web app. You 155 | can customize the error page, assign user-facing error codes, detect and fire 156 | alerts for certain errors, and control which errors get logged -- all in one 157 | place. 158 | 159 | By default, sandwich never sends internal error details to the client and 160 | instead logs the details. 161 | 162 | ### Wrapping Handlers 163 | 164 | Sandwich also allows registering handlers to run during AND after the middleware 165 | (and error handling) stack has completed. This is especially useful for handlers 166 | such as logging or gzip wrappers.
Once the 'before' handler has run, the 'after' 167 | handlers are queued to run and will be run regardless of whether an error aborts 168 | any subsequent middleware handlers. 169 | 170 | Typically this is done with the first function creating and initializing some 171 | state to pass to the deferred handler. For example, the logging handlers 172 | are: 173 | 174 | ```go 175 | // NewLogEntry creates a *LogEntry and initializes it with basic request 176 | // information. 177 | func NewLogEntry(r *http.Request) *LogEntry { 178 | return &LogEntry{Start: time.Now(), ...} 179 | } 180 | 181 | // Commit fills in the remaining *LogEntry fields and writes the entry out. 182 | func (entry *LogEntry) Commit(w *ResponseWriter) { 183 | entry.Elapsed = time.Since(entry.Start) 184 | ... 185 | WriteLog(*entry) 186 | } 187 | ``` 188 | 189 | and are added to the chain using: 190 | 191 | ```go 192 | var LogRequests = Wrap{NewLogEntry, (*LogEntry).Commit} 193 | ``` 194 | 195 | In this case, `NewLogEntry` returns a `*LogEntry` that is then provided to 196 | downstream handlers, including the deferred Commit handler -- in this case a 197 | [method expression](https://golang.org/ref/spec#Method_expressions) that takes 198 | the `*LogEntry` receiver as its first argument. 199 | 200 | 201 | ### Providing Interfaces 202 | 203 | Unfortunately, setting interface values is a little tricky. Since interfaces in Go 204 | are only used for static typing, the interface type is lost when a value is passed to a function 205 | that accepts interface{}, like Set(): only the value's concrete type is kept. 206 | 207 | This means that if you have an interface and a concrete implementation, such 208 | as: 209 | 210 | ```go 211 | type UserDatabase interface { 212 | GetUserProfile(u User) (Profile, error) 213 | } 214 | type userDbImpl struct { ... } 215 | func (db *userDbImpl) GetUserProfile(u User) (Profile, error) { ... } 216 | ``` 217 | 218 | You cannot provide this to handlers directly via the Set() call.
219 | 220 | ```go 221 | udb := &userDbImpl{...} 222 | // DOESN'T WORK: this will provide *userDbImpl, not UserDatabase 223 | mux.Set(udb) 224 | mux.Set((UserDatabase)(udb)) // DOESN'T WORK EITHER 225 | udb_iface := UserDatabase(udb) 226 | mux.Set(&udb_iface) // STILL DOESN'T WORK! 227 | ``` 228 | 229 | Instead, you have to either use SetAs() or a dedicated middleware function that 230 | returns the interface: 231 | 232 | ```go 233 | udb := &userDbImpl{...} 234 | // either use SetAs() with a pointer to the interface 235 | mux.SetAs(udb, (*UserDatabase)(nil)) 236 | // or add a handler that returns the interface 237 | mux.Use(func() UserDatabase { return udb }) 238 | ``` 239 | 240 | It's a bit silly, but there you are. 241 | 242 | 243 | ## FAQ 244 | 245 | Sandwich uses reflection-based dependency-injection to call the middleware 246 | functions with the parameters they need. 247 | 248 | **Q: OMG reflection and dependency-injection, isn't that terrible and slow and 249 | non-idiomatic go?!** 250 | 251 | Whoa, nelly. Let's deal with those one at a time, m'kay? 252 | 253 | **Q: Isn't reflection slow?** 254 | 255 | Not compared to everything else a webserver needs to do. 256 | 257 | Yes, sandwich's reflection-based dependency-injection code is slower than 258 | middleware code that directly calls functions; however, **the vast majority of 259 | server code (especially during development) is not impacted by time spent 260 | calling a few functions, but rather by HTTP network I/O, request parsing, 261 | database I/O, response marshalling, etc.** 262 | 263 | **Q: Ok, but aren't both reflection and dependency-injection non-idiomatic Go?** 264 | 265 | Sorta. The use of reflection in and of itself isn't non-idiomatic, but the use 266 | of magical dependency injection is: Go eschews magic.
267 | 268 | However, one of the major goals of this library is to allow the HTTP handler 269 | code (and all middleware) to be really clean, idiomatic Go functions that are 270 | testable by themselves. The idea is that the magic is small, contained, doesn't 271 | leak, and provides substantial benefit. 272 | 273 | **Q: But wait, don't you get annoying run-time "dependency-not-found" errors 274 | with dependency-injection?** 275 | 276 | While it's true that you can't get the same compile-time checking that you do 277 | with direct-call-based middleware, sandwich works really hard to ensure that you 278 | don't get surprises while running your server. 279 | 280 | At the time each middleware function is added to the stack, the library ensures 281 | that its dependencies have been explicitly provided. One of the *features* of 282 | sandwich is that you can't arbitrarily inject values -- they need to have an 283 | explicit provisioning source. 284 | 285 | **Q: Doesn't the http.Request.Context in Go 1.7 solve the middleware dependency 286 | problem?** 287 | 288 | Having a request-scoped context allows you to pass values between middleware 289 | handlers, it's true. However, there's no guarantee that the values are 290 | available, so you get the same run-time bugs that you might get with a naive 291 | dependency-injection framework. In addition, you have to do type-assertions to 292 | get your values, so there's another possible source of bugs. One of the goals 293 | of sandwich is to avoid these two types of bugs. 294 | 295 | **Q: Why do I have to use _two_ functions (before & after) to wrap a request? 296 | Why can't I just have one with a next() function?** 297 | 298 | Many middleware frameworks provide the capability to wrap a request via a next() 299 | function.
Sometimes it's part of a context object 300 | ([martini's Context.Next()](https://pkg.go.dev/github.com/go-martini/martini#Context), 301 | [gin's Context.Next()](https://pkg.go.dev/github.com/gin-gonic/gin#Context.Next)) 302 | and sometimes it's directly provided 303 | ([negroni's third handler arg](https://pkg.go.dev/github.com/urfave/negroni#HandlerFunc)). 304 | 305 | While implementing sandwich, I initially included a `next()` function until I 306 | realized it was impossible to validate the dependencies with such a function. 307 | Sandwich guarantees that dependencies can be supplied, and therefore `next()` 308 | had to go. 309 | 310 | Instead, I took a tip from go and instead implemented 311 | [defer](https://pkg.go.dev/github.com/augustoroman/sandwich/chain#Func.Defer). 312 | The wrap interface simply makes it obvious that there's a before and after. 313 | This allows me to keep my dependency guarantee. 314 | 315 | **Q: I don't know, it's still scary and terrible!** 316 | 317 | Don't get scared off. Take a look at the library, try it out, and I hope you 318 | enjoy it. If you don't, there are lots of great alternatives. 319 | -------------------------------------------------------------------------------- /chain/chain_test.go: -------------------------------------------------------------------------------- 1 | package chain 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | "fmt" 7 | "testing" 8 | "time" 9 | 10 | "github.com/stretchr/testify/assert" 11 | "github.com/stretchr/testify/require" 12 | ) 13 | // New returns a zero-value Func, the starting point of every chain built in these tests. 14 | func New() Func { return Func{} } 15 | // TestInitialInjection checks that values passed to Run (declared via Arg) and values added later via Set both reach handlers, in order. 16 | func TestInitialInjection(t *testing.T) { 17 | var args []interface{} 18 | recordArgs := func(a int, b string) { args = append(args, a, b) } 19 | 20 | err := New(). 21 | Arg(0). 22 | Arg(""). 23 | Then(recordArgs). 24 | Set(3). 25 | Set("four"). 26 | Then(recordArgs).
27 | Run(1, "two") 28 | assert.NoError(t, err) 29 | 30 | assert.EqualValues(t, []interface{}{1, "two", 3, "four"}, args) 31 | } 32 | // TestInitialDeferredInjection checks that values declared with Arg are supplied from Run's arguments. 33 | func TestInitialDeferredInjection(t *testing.T) { 34 | var args []interface{} 35 | recordArgs := func(a int, b string) { args = append(args, a, b) } 36 | 37 | err := New().Arg(0).Arg("").Then(recordArgs).Run(2, "xyz") 38 | assert.NoError(t, err) 39 | 40 | assert.EqualValues(t, []interface{}{2, "xyz"}, args) 41 | } 42 | // TestDeferredExecutionOrder checks that Then handlers run in order and Defer handlers run afterwards in reverse registration order. 43 | func TestDeferredExecutionOrder(t *testing.T) { 44 | var buf bytes.Buffer 45 | say := func(s string) func() { return func() { buf.WriteString(s + ":") } } 46 | err := New(). 47 | Then(say("a"), say("b")). 48 | Defer(say("f")). 49 | Defer(say("e")). 50 | Then(say("c")). 51 | Defer(say("d")). 52 | Run() 53 | assert.NoError(t, err) 54 | assert.Equal(t, "a:b:c:d:e:f:", buf.String()) 55 | } 56 | // TestDeferredExecutionOrderWithErrors checks three things when a handler fails: later handlers and defers registered after the failure are skipped, the active error handler runs, and earlier defers still run in reverse order. 57 | func TestDeferredExecutionOrderWithErrors(t *testing.T) { 58 | var buf bytes.Buffer 59 | say := func(s string) func() { return func() { buf.WriteString(s + ":") } } 60 | onErr := func(e error) { buf.WriteString("err[" + e.Error() + "]:") } 61 | fail := func() error { return errors.New("failed") } 62 | err := New(). 63 | Then(say("a"), say("b")). 64 | Defer(say("f")). 65 | OnErr(onErr). 66 | Defer(say("e")). 67 | Then(fail). 68 | Then(say("c")). 69 | Defer(say("d")).
70 | Run() 71 | assert.NoError(t, err) 72 | assert.Equal(t, "a:b:err[failed]:e:f:", buf.String()) 73 | } 74 | // TestBasicFuncExecution checks that pointers returned by one handler are injected into later handlers, which can mutate the shared values. 75 | func TestBasicFuncExecution(t *testing.T) { 76 | a, b := 5, "hi" 77 | provide_initial := func() (*int, *string) { return &a, &b } 78 | verify_injected := func(x *int, y *string) { 79 | if *x != 5 { 80 | t.Errorf("Expected *int to be 5, got %d", *x) 81 | } 82 | if *y != "hi" { 83 | t.Errorf("Expected *string to be 'hi', got %s", *y) 84 | } 85 | } 86 | modify_injected := func(x *int, y *string) { *x = 6; *y = "bye" } 87 | 88 | err := New().Then(provide_initial, verify_injected, modify_injected).Run() 89 | assert.NoError(t, err) 90 | 91 | assert.Equal(t, 6, a) 92 | assert.Equal(t, "bye", b) 93 | } 94 | // TestSimpleErrorHandling checks that an error goes to the most recently registered OnErr handler at the point of failure. 95 | func TestSimpleErrorHandling(t *testing.T) { 96 | var out string 97 | chain := New().Arg("").Arg(0). 98 | OnErr(func(err error) { out += "First error handler: " + err.Error() }). 99 | Then( 100 | func(val string) (string, error) { 101 | if val == "foo" { 102 | return "bar", nil 103 | } else { 104 | return "", fmt.Errorf("%q is not foo", val) 105 | } 106 | }). 107 | OnErr(func(err error) { out += "Second error handler: " + err.Error() }).
108 | Then( 109 | func(num int) error { 110 | if num != 3 { 111 | return fmt.Errorf("%d is not 3", num) 112 | } 113 | return nil 114 | }) 115 | 116 | out = "" 117 | assert.NoError(t, chain.Run("", 0)) 118 | assert.Equal(t, `First error handler: "" is not foo`, out) 119 | 120 | out = "" 121 | assert.NoError(t, chain.Run("foo", 7)) 122 | assert.Equal(t, `Second error handler: 7 is not 3`, out) 123 | 124 | out = "" 125 | assert.NoError(t, chain.Run("foo", 3)) 126 | assert.Equal(t, ``, out) 127 | } 128 | // TestMustProvideTypes checks that adding a handler or error handler whose argument types have not been provided panics while the chain is being built. 129 | func TestMustProvideTypes(t *testing.T) { 130 | assert.Panics(t, func() { New().Then(func(string) {}) }) 131 | assert.NotPanics(t, func() { New().Set("").Then(func(string) {}) }) 132 | 133 | assert.Panics(t, func() { New().Then(func(int, string) {}) }) 134 | assert.NotPanics(t, func() { New().Set("").Set(3).Then(func(int, string) {}) }) 135 | 136 | assert.NotPanics(t, func() { 137 | New().Then( 138 | func() string { return "" }, 139 | func(string) int { return 3 }, 140 | func(int) {}, 141 | ) 142 | }, "Should be OK: Everything is provided by earlier functions.") 143 | assert.NotPanics(t, func() { 144 | New().Then( 145 | func() string { return "" }, 146 | func(string) int { return 3 }, 147 | ).OnErr(func(int, error) {}) 148 | }, "Should be OK: Everything is provided by earlier functions") 149 | 150 | assert.Panics(t, func() { 151 | New().Then( 152 | func() string { return "" }, 153 | func(string) int { return 3 }, 154 | func(bool) {}, 155 | func(int) {}, 156 | ) 157 | }, "Should FAIL: bool isn't provided anywhere") 158 | assert.Panics(t, func() { 159 | New().Then(func() string { return "" }, func(string) int { return 3 }). 160 | OnErr(func(bool, error) {}).
161 | Then(func(int) {}) 162 | }, "Should FAIL: bool isn't provided anywhere (even error handlers need proper provisioning)") 163 | } 164 | // TestErrorAbortsHandling checks that handlers after a failing one are skipped and the error handler receives the failure. 165 | func TestErrorAbortsHandling(t *testing.T) { 166 | var out string 167 | err := New().OnErr(func(err error) { out += "Failed @ " + err.Error() }).Then( 168 | func() error { out += "1 "; return nil }, 169 | func() error { out += "2 "; return errors.New("2") }, 170 | func() error { out += "3 "; return nil }, 171 | ).Run() 172 | assert.NoError(t, err) 173 | assert.Equal(t, "1 2 Failed @ 2", out) 174 | } 175 | // a, b and c are named (non-closure) middleware so panic reports can be checked for their names and signatures. 176 | func a() string { return "hello " } 177 | func b(s string) (string, int) { return s + "world", 42 } 178 | func c(s string, n int) {} 179 | // TestCatchesPanics checks that a panicking handler is turned into a PanicError carrying the panic value, its source location, and the middleware stack (most recent first). 180 | func TestCatchesPanics(t *testing.T) { 181 | var err error 182 | captureError := func(e error) { err = e } 183 | panics := func() { panic("ahhhh! 🔥") } 184 | 185 | assert.NoError(t, 186 | New().OnErr(captureError).Then(a, b, c).Defer(c).Then(panics).Run()) 187 | 188 | require.Error(t, err) 189 | require.IsType(t, PanicError{}, err) 190 | e := err.(PanicError) 191 | assert.Equal(t, "ahhhh! 🔥", e.Val) 192 | require.Len(t, e.MiddlewareStack, 4) // defers haven't run yet. 193 | assert.Contains(t, e.MiddlewareStack[0].Name, "chain.TestCatchesPanics.func2") 194 | assert.Contains(t, e.MiddlewareStack[1].Name, "chain.c") 195 | assert.Contains(t, e.MiddlewareStack[2].Name, "chain.b") 196 | assert.Contains(t, e.MiddlewareStack[3].Name, "chain.a") 197 | 198 | assert.Contains(t, err.Error(), "Panic executing middleware") 199 | assert.Contains(t, err.Error(), "ahhhh! 🔥") 200 | // This is where the panic actually occurred. This will need to be updated if 201 | // this file changes, sadly.
202 | assert.Contains(t, err.Error(), "chain/chain_test.go:183") 203 | assert.Contains(t, err.Error(), "func() string") 204 | assert.Contains(t, err.Error(), "func(string) (string, int)") 205 | assert.Contains(t, err.Error(), "func(string, int)") 206 | assert.Contains(t, err.Error(), "chain.a") 207 | assert.Contains(t, err.Error(), "chain.b") 208 | assert.Contains(t, err.Error(), "chain.c") 209 | } 210 | // TestDefersCanAcceptErrors checks that deferred handlers may take the chain's error, whether or not a handler actually failed. 211 | func TestDefersCanAcceptErrors(t *testing.T) { 212 | var buf bytes.Buffer 213 | onerr := func(err error) { fmt.Fprintf(&buf, "onerr[%v]:", err) } 214 | deferred := func(err error) { fmt.Fprintf(&buf, "defer[%v]:", err) } 215 | fails := func() error { return errors.New("💣") } 216 | 217 | assert.NoError(t, New(). 218 | OnErr(onerr). 219 | Then(a, b, c). 220 | Defer(deferred). 221 | Then(fails). 222 | Run()) 223 | 224 | assert.Equal(t, "onerr[💣]:defer[💣]:", buf.String()) 225 | 226 | // But what if nothing actually fails? Defers can still accept errors. 227 | buf.Reset() 228 | assert.NoError(t, New(). 229 | OnErr(onerr). 230 | Then(a, b, c). 231 | Defer(deferred). 232 | // Then(fails). // no failure! 233 | Run()) 234 | 235 | assert.Equal(t, "defer[]:", buf.String()) 236 | } 237 | // TestDefaultErrorHandler checks that DefaultErrorHandler is used when no OnErr handler has been registered. 238 | func TestDefaultErrorHandler(t *testing.T) { 239 | var buf bytes.Buffer 240 | onerr := func(err error) { fmt.Fprintf(&buf, "onerr[%v]:", err) } 241 | fails := func() error { return errors.New("☠") } 242 | 243 | // Restore the default error handler when we're done with the test.
244 | defer func(orig interface{}) { DefaultErrorHandler = orig }(DefaultErrorHandler) 245 | DefaultErrorHandler = onerr 246 | 247 | assert.NoError(t, New().Then(fails).Run()) 248 | 249 | assert.Equal(t, "onerr[☠]:", buf.String()) 250 | } 251 | // TestSetAs_Nil checks that SetAs can provide a nil value for an interface type. 252 | func TestSetAs_Nil(t *testing.T) { 253 | worked := false 254 | check := func(s fmt.Stringer) { 255 | require.Nil(t, s) 256 | worked = true 257 | } 258 | err := New().SetAs(nil, (*fmt.Stringer)(nil)).Then(check).Run() 259 | require.NoError(t, err) 260 | require.True(t, worked) 261 | } 262 | // TestProvidingBadValues checks that Set and SetAs panic on values they cannot provide. 263 | func TestProvidingBadValues(t *testing.T) { 264 | assert.Panics(t, func() { New().Set(nil) }) 265 | 266 | // ifacePtr must be a pointer to an interface 267 | assert.Panics(t, func() { New().SetAs(nil, 5) }) 268 | type Struct struct{} 269 | assert.Panics(t, func() { New().SetAs(nil, Struct{}) }) 270 | assert.Panics(t, func() { New().SetAs(nil, &Struct{}) }) 271 | 272 | // SetAs value must actually implement the specified interface 273 | assert.Panics(t, func() { New().SetAs(5, (*fmt.Stringer)(nil)) }) 274 | assert.Panics(t, func() { New().SetAs(Struct{}, (*fmt.Stringer)(nil)) }) 275 | } 276 | // TestWithBadValues checks that Then panics when given something that is not a function. 277 | func TestWithBadValues(t *testing.T) { 278 | type Struct struct{} 279 | assert.Panics(t, func() { New().Then(nil) }) 280 | assert.Panics(t, func() { New().Then(5) }) 281 | assert.Panics(t, func() { New().Then(Struct{}) }) 282 | } 283 | // TestBadErrorHandler checks that OnErr rejects non-functions, functions with results, and functions needing unprovided types. 284 | func TestBadErrorHandler(t *testing.T) { 285 | // The error handler must actually be a function 286 | assert.Panics(t, func() { New().OnErr(true) }) 287 | // The error handler may not return any values. 288 | returnsSomething := func(err error) bool { return true } 289 | assert.Panics(t, func() { New().OnErr(returnsSomething) }) 290 | // The error handler can't take args of types that have not yet been 291 | // provided.
292 | takesAString := func(str string, err error) {} 293 | assert.Panics(t, func() { New().OnErr(takesAString) }) 294 | } 295 | // TestBadDefer checks that Defer rejects non-functions, functions with results, and functions needing unprovided types. 296 | func TestBadDefer(t *testing.T) { 297 | assert.Panics(t, func() { New().Defer(true) }, 298 | "deferred func must actually be a function") 299 | 300 | returnsSomething := func(err error) bool { return true } 301 | assert.Panics(t, func() { New().Defer(returnsSomething) }, 302 | "deferred func may not return any values") 303 | 304 | takesAString := func(str string) {} 305 | assert.Panics(t, func() { New().Defer(takesAString) }, 306 | "deferred func arg types must have already been provided") 307 | } 308 | // TestInterfaceConversionOnRun checks which values Run accepts for an Arg declared as an interface or pointer type. 309 | func TestInterfaceConversionOnRun(t *testing.T) { 310 | chain := New().Arg((*fmt.Stringer)(nil)) 311 | 312 | assert.Error(t, chain.Run(), "missing arg") 313 | assert.NoError(t, chain.Run(nil), "nil value is ok") 314 | 315 | var stringer Stringer 316 | assert.NoError(t, chain.Run(stringer), "implements stringer") 317 | assert.NoError(t, chain.Run(&stringer), "implements stringer") 318 | 319 | var ptrStringer PtrStringer 320 | assert.Error(t, chain.Run(ptrStringer), "does not implement stringer") 321 | assert.NoError(t, chain.Run(&ptrStringer), "implements stringer") 322 | 323 | var ptrToPtrStringer *PtrStringer 324 | assert.NoError(t, chain.Run(ptrToPtrStringer), "nil implements stringer") 325 | 326 | var nilStringer fmt.Stringer 327 | assert.NoError(t, chain.Run(nilStringer), "nil values are ok") 328 | 329 | chain = New().Arg(0) 330 | assert.Error(t, chain.Run(nil), 331 | "nil values are not ok for non-pointers and non-interfaces") 332 | 333 | type Struct struct{} 334 | chain = New().Arg(&Struct{}) 335 | assert.NoError(t, chain.Run(nil), "nil values are ok for pointers to structs") 336 | assert.NoError(t, chain.Run(&Struct{}), "non-nil pointers to structs are ok") 337 | assert.Error(t, chain.Run(1), "ints don't match struct pointers") 338 | } 339 | // Stringer implements fmt.Stringer with a value receiver; PtrStringer implements it only with a pointer receiver. 340 | type Stringer struct{} 341 | type PtrStringer struct{} 342 | 343 |
func (Stringer) String() string { return "yup" } 344 | func (*PtrStringer) String() string { return "yup" } 345 | 346 | func TestBadRunArgs(t *testing.T) { 347 | chain := New(). 348 | Arg(int(0)). 349 | Arg((*fmt.Stringer)(nil)) 350 | 351 | assert.Error(t, chain.Run(), "not all args are specified") 352 | assert.Error(t, chain.Run(0), "not all args are specified") 353 | assert.NoError(t, chain.Run(0, nil), "all args are specified") 354 | } 355 | 356 | func TestRunArgsMustExactlyMatchSpecifiedArgs(t *testing.T) { 357 | chain := New(). 358 | Arg(int(0)). 359 | Arg(""). 360 | Arg(true) 361 | 362 | // OK 363 | assert.NoError(t, chain.Run(0, "hi", true)) 364 | 365 | // Wrong ordering 366 | assert.EqualError(t, chain.Run(true, "hi", 0), 367 | "bad arg: 1st arg of Run(...) should be a int but is bool") 368 | assert.EqualError(t, chain.Run(0, true, "hi"), 369 | "bad arg: 2nd arg of Run(...) should be a string but is bool") 370 | 371 | // Too many 372 | assert.EqualError(t, chain.Run(0, "hi", true, 'x', 0), 373 | "too many args: expected 3 args but got 5 args") 374 | // Not enough 375 | assert.EqualError(t, chain.Run(0, "hi"), 376 | "missing args of types: [bool]") 377 | assert.EqualError(t, chain.Run(0), 378 | "missing args of types: [string bool]") 379 | } 380 | 381 | func TestRunWithNilReservedInterface(t *testing.T) { 382 | var capturedStringer fmt.Stringer = time.Now() 383 | chain := New(). 384 | Arg((*fmt.Stringer)(nil)). 
385 | Then(func(s fmt.Stringer) { capturedStringer = s }) 386 | 387 | require.NoError(t, chain.Run(nil)) 388 | assert.Nil(t, capturedStringer) 389 | } 390 | -------------------------------------------------------------------------------- /router.go: -------------------------------------------------------------------------------- 1 | package sandwich 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "net/http" 7 | "strings" 8 | 9 | "github.com/augustoroman/sandwich/chain" 10 | ) 11 | 12 | // Router implements the sandwich middleware chaining and routing functionality. 13 | type Router interface { 14 | // Set a value that will be available to all handlers subsequent referenced. 15 | // This is typically used for concrete values. For interfaces to be correctly 16 | // provided to subsequent middleware, use SetAs. 17 | Set(vals ...any) 18 | // SetAs sets a value as the specified interface that will be available to all 19 | // handlers. 20 | // 21 | // Example: 22 | // type DB interface { ... } 23 | // var db DB = ... 24 | // mux.SetAs(db, (*DB)(nil)) 25 | // 26 | // That is functionally equivalent to using a middleware function that returns 27 | // the desired interface instance: 28 | // type DB interface { ... } 29 | // var db DB = ... 30 | // mux.Use(func() DB { return db }) 31 | SetAs(val, ifacePtr any) 32 | 33 | // Use adds middleware to be invoked for all routes registered by the 34 | // returned Router. The current router is not affected. This is equivalent to 35 | // adding the specified middelwareHandlers to each registered route. 36 | Use(middlewareHandlers ...any) 37 | 38 | // On will register a handler for the given method and path. 39 | On(method, path string, handlers ...any) 40 | 41 | // Get registers handlers for the specified path for the 'GET' HTTP method. 42 | // Get is shorthand for `On("GET", ...)`. 43 | Get(path string, handlers ...any) 44 | // Put registers handlers for the specified path for the 'PUT' HTTP method. 
45 | // Put is shorthand for `On("PUT", ...)`. 46 | Put(path string, handlers ...any) 47 | // Post registers handlers for the specified path for the 'POST' HTTP method. 48 | // Post is shorthand for `On("POST", ...)`. 49 | Post(path string, handlers ...any) 50 | // Patch registers handlers for the specified path for the 'PATCH' HTTP 51 | // method. Patch is shorthand for `On("PATCH", ...)`. 52 | Patch(path string, handlers ...any) 53 | // Delete registers handlers for the specified path for the 'DELETE' HTTP 54 | // method. Delete is shorthand for `On("DELETE", ...)`. 55 | Delete(path string, handlers ...any) 56 | // Any registers a handlers for the specified path for any HTTP method. This 57 | // will always be superceded by dedicated method handlers. For example, if the 58 | // path '/users/:id/' is registered for Get, Put and Any, GET and PUT requests 59 | // will be handled by the Get(...) and Put(...) registrations, but DELETE, 60 | // CONNECT, or HEAD would be handled by the Any(...) registration. Any is a 61 | // shortcut for `On("*", ...)`. 62 | Any(path string, handlers ...any) 63 | 64 | // OnErr uses the specified error handler to handle any errors that occur on 65 | // any routes in this router. 66 | OnErr(handler any) 67 | 68 | // SubRouter derives a router that will called for all suffixes (and methods) 69 | // for the specified path. For example, `sub := root.SubRouter("/api")` will 70 | // create a router that will handle `/api/`, `/api/foo`. 71 | SubRouter(pathPrefix string) Router 72 | 73 | // ServeHTTP implements the http.Handler interface for the router. 74 | ServeHTTP(w http.ResponseWriter, r *http.Request) 75 | } 76 | 77 | // BuildYourOwn returns a minimal router that has no initial middleware 78 | // handling. 
79 | func BuildYourOwn() Router { 80 | r := &router{} 81 | r.base = r.base.Arg((*http.ResponseWriter)(nil)) 82 | r.base = r.base.Arg((*http.Request)(nil)) 83 | r.base = r.base.Arg((Params)(nil)) 84 | return r 85 | } 86 | 87 | // TheUsual returns a router initialized with useful middleware. 88 | func TheUsual() Router { 89 | r := BuildYourOwn() 90 | r.Use(WrapResponseWriter, LogRequests) 91 | r.OnErr(HandleError) 92 | return r 93 | } 94 | 95 | type router struct { 96 | base chain.Func 97 | subRouters map[string]*router 98 | byMethod map[string]*mux 99 | anyMethod *mux 100 | notFound http.Handler 101 | } 102 | 103 | func (r *router) ServeHTTP(w http.ResponseWriter, req *http.Request) { 104 | params := Params{} 105 | h := r.match(req.Method, req.URL.Path, params) 106 | if h != nil { 107 | h.ServeHTTP(w, req, params) 108 | } else if r.notFound != nil { 109 | r.notFound.ServeHTTP(w, req) 110 | } else { 111 | http.Error(w, "Not found", http.StatusNotFound) 112 | } 113 | } 114 | 115 | func (r *router) SubRouter(prefix string) Router { 116 | if r.subRouters == nil { 117 | r.subRouters = map[string]*router{} 118 | } 119 | prefix = strings.TrimRight(prefix, "/") + "/" 120 | for existingPrefix := range r.subRouters { 121 | if existingPrefix == prefix || strings.HasPrefix(existingPrefix, prefix) || strings.HasPrefix(prefix, existingPrefix) { 122 | panic(fmt.Sprintf( 123 | "SubRouter with prefix %#q conflicts with existing SubRouter with prefix %#q", 124 | prefix, existingPrefix, 125 | )) 126 | } 127 | } 128 | r.subRouters[prefix] = &router{ 129 | base: r.base, 130 | notFound: r.notFound, 131 | } 132 | return r.subRouters[prefix] 133 | } 134 | 135 | func (r *router) match(method, uri string, params Params) httpHandlerWithParams { 136 | method = strings.ToUpper(method) 137 | for prefix, sub := range r.subRouters { 138 | if strings.HasPrefix(uri, prefix) { 139 | return sub.match(method, strings.TrimPrefix(uri, prefix), params) 140 | } 141 | } 142 | if h := 
r.byMethod[method].Match(uri, params); h != nil { 143 | return h 144 | } 145 | if h := r.anyMethod.Match(uri, params); h != nil { 146 | return h 147 | } 148 | return nil 149 | } 150 | 151 | func (r *router) Set(vals ...any) { 152 | for _, val := range vals { 153 | r.base = r.base.Set(val) 154 | } 155 | } 156 | 157 | func (r *router) SetAs(val, ifacePtr any) { 158 | r.base = r.base.SetAs(val, ifacePtr) 159 | } 160 | 161 | func (r *router) Use(middlewareHandlers ...any) { 162 | r.base = apply(r.base, middlewareHandlers...) 163 | } 164 | 165 | func (r *router) OnErr(errorHandler any) { 166 | r.base = r.base.OnErr(errorHandler) 167 | } 168 | 169 | func (r *router) On(method, path string, handlers ...any) { 170 | method = strings.ToUpper(method) 171 | m := r.getOrAllocateMux(method) 172 | if err := m.Register(path, handler{apply(r.base, handlers...)}); err != nil { 173 | panic(fmt.Errorf("Cannot register route: %v", err)) 174 | } 175 | } 176 | 177 | func (r *router) Any(path string, handlers ...any) { r.On("*", path, handlers...) } 178 | func (r *router) Get(path string, handlers ...any) { r.On("GET", path, handlers...) } 179 | func (r *router) Put(path string, handlers ...any) { r.On("PUT", path, handlers...) } 180 | func (r *router) Post(path string, handlers ...any) { r.On("POST", path, handlers...) } 181 | func (r *router) Patch(path string, handlers ...any) { r.On("PATCH", path, handlers...) } 182 | func (r *router) Delete(path string, handlers ...any) { r.On("DELETE", path, handlers...) 
} 183 | 184 | func (r *router) getOrAllocateMux(method string) *mux { 185 | if method == "*" { 186 | if r.anyMethod == nil { 187 | r.anyMethod = &mux{} 188 | } 189 | return r.anyMethod 190 | } 191 | if r.byMethod == nil { 192 | r.byMethod = map[string]*mux{} 193 | } 194 | m := r.byMethod[method] 195 | if m == nil { 196 | m = &mux{} 197 | r.byMethod[method] = m 198 | } 199 | return m 200 | } 201 | 202 | type handler struct{ chain.Func } 203 | 204 | func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request, p Params) { 205 | h.Func.MustRun(w, r, p) 206 | } 207 | 208 | type Params map[string]string 209 | 210 | type mux struct { 211 | static map[string]*mux 212 | params []muxParam 213 | handler httpHandlerWithParams 214 | } 215 | 216 | type muxParam struct { 217 | paramName string 218 | greedy bool 219 | mux *mux 220 | } 221 | 222 | type httpHandlerWithParams interface { 223 | ServeHTTP(w http.ResponseWriter, r *http.Request, p Params) 224 | } 225 | 226 | func (m *mux) Register(pattern string, h httpHandlerWithParams) error { 227 | if !strings.HasPrefix(pattern, "/") { 228 | return errors.New("patterns must begin with /") 229 | } 230 | segments := strings.Split(pattern[1:], "/") 231 | reg := registerInfo{ 232 | seenParams: map[string]bool{}, 233 | seenGreedy: false, 234 | } 235 | if m.static == nil { 236 | m.static = map[string]*mux{} 237 | } 238 | if err := reg.registerSegments(m, segments, h); err != nil { 239 | return fmt.Errorf("%#q: bad pattern: %w", pattern, err) 240 | } 241 | return nil 242 | } 243 | 244 | type registerInfo struct { 245 | seenParams map[string]bool 246 | seenGreedy bool 247 | } 248 | 249 | func (r *registerInfo) registerSegments(m *mux, segments []string, h httpHandlerWithParams) error { 250 | if len(segments) == 0 { 251 | if m.handler != nil { 252 | return fmt.Errorf("repeated entry") 253 | } 254 | m.handler = h 255 | return nil 256 | } 257 | next, remaining := segments[0], segments[1:] 258 | if strings.HasPrefix(next, "::") { 259 | 
return r.registerStatic(m, next[1:], remaining, h) 260 | } else if strings.HasPrefix(next, ":") { 261 | return r.registerParam(m, next[1:], remaining, h) 262 | } else { 263 | return r.registerStatic(m, next, remaining, h) 264 | } 265 | } 266 | 267 | func (r *registerInfo) registerStatic(m *mux, path string, remaining []string, h httpHandlerWithParams) error { 268 | sub := m.static[path] 269 | if sub == nil { 270 | sub = &mux{ 271 | static: map[string]*mux{}, 272 | } 273 | } 274 | err := r.registerSegments(sub, remaining, h) 275 | if err == nil { 276 | m.static[path] = sub 277 | } 278 | return err 279 | } 280 | 281 | func (r *registerInfo) registerParam(m *mux, param string, remaining []string, h httpHandlerWithParams) error { 282 | greedy := strings.HasSuffix(param, "*") 283 | name := strings.TrimSuffix(param, "*") 284 | if greedy && r.seenGreedy { 285 | return fmt.Errorf("only one greedy param allowed per pattern: %#q", name) 286 | } else if r.seenParams[name] { 287 | return fmt.Errorf("param used twice: %#q", name) 288 | } 289 | // Check to see if the param already exists. E.g. we've already registered 290 | // param at this level via: 291 | // /root/:param/path1 --> h1 292 | // and now we're registering: 293 | // /root/:param/path2 --> h2 294 | for _, p := range m.params { 295 | if p.paramName == name { 296 | if p.greedy != greedy { 297 | return fmt.Errorf("param %#q is sometimes greedy and sometimes not", name) 298 | } 299 | return r.registerSegments(p.mux, remaining, h) 300 | } 301 | // If we haven't registered this one yet, then we need to avoid ambiguous 302 | // path registrations. 
For example: 303 | // /root/:p1/path 304 | // /root/:p2/path 305 | // should not be allowed, nor should: 306 | // /root/:p1/:x/:y 307 | // /root/:p2/:a/:b 308 | if err := p.mux.checkAmbiguous(remaining); err != nil { 309 | return fmt.Errorf("ambiguous route: %w", err) 310 | } 311 | } 312 | sub := &mux{ 313 | static: map[string]*mux{}, 314 | } 315 | r.seenParams[name] = true 316 | r.seenGreedy = r.seenGreedy || greedy 317 | err := r.registerSegments(sub, remaining, h) 318 | if err == nil { 319 | m.params = append(m.params, muxParam{ 320 | paramName: name, 321 | greedy: greedy, 322 | mux: sub, 323 | }) 324 | } 325 | return err 326 | } 327 | 328 | func (m *mux) checkAmbiguous(segments []string) error { 329 | if len(segments) == 0 { 330 | if m.handler != nil { 331 | return fmt.Errorf("ambiguous route") 332 | } 333 | return nil 334 | } 335 | static, isStatic, _, _ := entryToInfo(segments[0]) 336 | if isStatic { 337 | if child := m.static[static]; child != nil { 338 | return child.checkAmbiguous(segments[1:]) 339 | } 340 | return nil 341 | } 342 | for _, p := range m.params { 343 | if err := p.mux.checkAmbiguous(segments[1:]); err != nil { 344 | return err 345 | } 346 | } 347 | return nil 348 | } 349 | 350 | func entryToInfo(entry string) (static string, isStatic bool, paramName string, greedy bool) { 351 | if strings.HasPrefix(entry, "::") { 352 | // double colon prefix escapes to single colon static path name. 
353 | return entry[1:], true, "", false 354 | } else if !strings.HasPrefix(entry, ":") { 355 | return entry, true, "", false 356 | } 357 | paramName = strings.TrimSuffix(entry[1:], "*") 358 | greedy = strings.HasSuffix(entry, "*") 359 | return "", false, paramName, greedy 360 | } 361 | 362 | func (m *mux) Match(uri string, params Params) httpHandlerWithParams { 363 | uri = strings.TrimPrefix(uri, "/") 364 | segments := strings.Split(uri, "/") 365 | matched := m.matchPrefix(segments, params) 366 | if matched == nil { 367 | return nil 368 | } 369 | return matched 370 | } 371 | 372 | func (m *mux) matchPrefix(segments []string, params Params) httpHandlerWithParams { 373 | if m == nil { 374 | return nil 375 | } 376 | if len(segments) == 0 { 377 | return m.handler 378 | } 379 | path, remaining := segments[0], segments[1:] 380 | if sub := m.static[path]; sub != nil { 381 | match := sub.matchPrefix(remaining, params) 382 | if match != nil { 383 | return match 384 | } 385 | } 386 | for _, param := range m.params { 387 | if !param.greedy { 388 | matched := param.mux.matchPrefix(remaining, params) 389 | if matched != nil { 390 | params[param.paramName] = path 391 | return matched 392 | } 393 | } else { 394 | matched, used := param.mux.matchSuffix(remaining, params) 395 | if matched != nil { 396 | N := len(segments) 397 | params[param.paramName] = strings.Join(segments[:N-used], "/") 398 | return matched 399 | } 400 | } 401 | } 402 | return nil 403 | } 404 | 405 | func (m *mux) matchSuffix(segments []string, params Params) (h httpHandlerWithParams, depth int) { 406 | N := len(segments) 407 | if N == 0 { 408 | return m.handler, 0 409 | } 410 | for staticPath, sub := range m.static { 411 | match, d := sub.matchSuffix(segments, params) 412 | if match == nil { 413 | continue 414 | } 415 | depth = d + 1 416 | actualPath := segments[N-depth] 417 | if actualPath != staticPath { 418 | continue 419 | } 420 | return match, depth 421 | } 422 | for _, param := range m.params { 423 | 
match, d := param.mux.matchSuffix(segments, params) 424 | if match == nil { 425 | continue 426 | } 427 | depth = d + 1 428 | actualPath := segments[N-depth] 429 | params[param.paramName] = actualPath // TODO: might be rejected, might spam params 430 | return match, depth 431 | } 432 | return m.handler, 0 433 | } 434 | -------------------------------------------------------------------------------- /chain/chain.go: -------------------------------------------------------------------------------- 1 | // Package chain is a reflection-based dependency-injected chain of functions 2 | // that powers the sandwich middleware framework. 3 | // 4 | // A Func chain represents a sequence of functions to call along with an initial 5 | // input. The parameters to each function are automatically provided from either 6 | // the initial inputs or return values of earlier functions in the sequence. 7 | // 8 | // In contrast to other dependency-injection frameworks, chain does not 9 | // automatically determine how to provide dependencies -- it merely uses the 10 | // most recently-provided value. This enables chains to report errors 11 | // immediately during the chain construction and, if successfully constructed, a 12 | // chain can always be executed. 13 | // 14 | // # HTTP Middleware Example 15 | // 16 | // As a common example, chains in http handling middleware typically start with 17 | // the http.ResponseWriter and *http.Request provided by the http framework: 18 | // 19 | // base := chain.Func{}. 20 | // Arg((*http.ResponseWriter)(nil)). // declared as an arg when Run 21 | // Arg((*http.Request)(nil)). 
// declared as an arg when Run 22 | // 23 | // Given the following functions: 24 | // 25 | // func GetDB() (*UserDB, error) {...} 26 | // func GetUserFromRequest(db *UserDB, req *http.Request) (*User, error) {...} 27 | // func SendUserAsJSON(w http.ResponseWriter, u *User) error {...} 28 | // 29 | // func GetUserID(r *http.Request) (UserID, error) {...} 30 | // func (db *UserDB) Lookup(UserID) (*User, error) { ... } 31 | // 32 | // func SendProjectAsJSON(w http.ResponseWriter, p *Project) error {...} 33 | // 34 | // then these chains would work fine: 35 | // 36 | // base.Then( 37 | // GetDB, // takes no args ✅, provides *UserDB to later funcs 38 | // GetUserFromRequest, // takes *UserDB ✅ and *Request ✅, provides *User 39 | // SendUserAsJSON, // takes ResponseWriter ✅ and *User ✅ 40 | // ) 41 | // 42 | // base.Then( 43 | // GetDB, // takes no args ✅, provides *UserDB to later funcs 44 | // GetUserID, // takes *Request ✅, provides UserID 45 | // (*UserDB).Lookup, // takes *UserDB ✅ and UserID ✅, provides *User 46 | // SendUserAsJSON, // takes ResponseWriter ✅ and *User ✅ 47 | // ) 48 | // 49 | // but these chains would fail: 50 | // 51 | // base.Then( 52 | // GetUserFromRequest, // takes *UserDB ❌ and *Request ✅ 53 | // GetDB, // this *UserDB isn't available yet. 
54 | // SendUserAsJSON, // 55 | // ) 56 | // 57 | // base.Then( 58 | // GetDB, // takes no args ✅, provides *UserDB to later funcs 59 | // GetUserID, // takes *Request ✅, provides UserID 60 | // (*UserDB).Lookup, // takes *UserDB ✅ and UserID ✅, provides *User 61 | // SendProjectAsJSON, // takes ResponseWriter ✅ and *Project ❌ 62 | // ) 63 | // 64 | // base.Then( 65 | // GetUserFromRequest, // takes *UserDB ❌ and *Request ✅ 66 | // SendUserAsJSON, // 67 | // ) 68 | package chain 69 | 70 | import ( 71 | "bytes" 72 | "fmt" 73 | "reflect" 74 | "runtime" 75 | "strings" 76 | "text/tabwriter" 77 | ) 78 | 79 | var errorType = reflect.TypeOf((*error)(nil)).Elem() 80 | 81 | // DefaultErrorHandler is called when an error in the chain occurs and no error 82 | // handler has been registered. Warning! The default error handler is not 83 | // checked to verify that it's arguments can be provided. It's STRONGLY 84 | // recommended to keep this as absolutely simple as possible. 85 | var DefaultErrorHandler interface{} = func(err error) { panic(err) } 86 | 87 | // Func defines the chain of functions to invoke when Run. Each Func is 88 | // immutable: all operations will return a new Func chain. 89 | type Func struct{ steps []step } 90 | 91 | // step is a single value or handler in the middleware stack. Each step has a 92 | // typ flag that indicates what kind of step it is. 93 | type step struct { 94 | typ stepType 95 | val reflect.Value 96 | // For tVALUE steps, this may optionally be non-nil to specific an 97 | // additional interface type that is provided. 98 | // For tRESERVE steps, this must be non-nil to declare the reserved type. 99 | // For t*_HANDLER steps, this is the function type. 
100 | valTyp reflect.Type 101 | } 102 | 103 | type stepType uint8 104 | 105 | const ( 106 | tARG stepType = iota 107 | tVALUE 108 | tPRE_HANDLER // PRE handlers are the normal handlers 109 | tPOST_HANDLER // POST handlers are deferred handlers 110 | tERROR_HANDLER 111 | ) 112 | 113 | // Clone this chain and add the extra steps to the clone. 114 | func (c Func) with(steps ...step) Func { 115 | s := make([]step, 0, len(c.steps)+len(steps)) 116 | s = append(s, c.steps...) 117 | s = append(s, steps...) 118 | return Func{s} 119 | } 120 | 121 | // Arg indicates that a value with the specified type will be a parameter to Run 122 | // when the Func is invoked. This is typically necessary to start the chain for 123 | // a given middleware framework. Arg should not be exposed to users of sandwich 124 | // since it bypasses the causal checks and risks runtime errors. 125 | func (c Func) Arg(typeOrInterfacePtr interface{}) Func { 126 | typ := reflect.TypeOf(typeOrInterfacePtr) 127 | if typ.Kind() == reflect.Ptr && typ.Elem().Kind() == reflect.Interface { 128 | typ = typ.Elem() 129 | } 130 | return c.with(step{tARG, reflect.Value{}, typ}) 131 | } 132 | 133 | // Set an immediate value. This cannot be used to provide an interface, instead 134 | // use SetAs(...) or With(...) with a function that returns the interface. 135 | func (c Func) Set(value interface{}) Func { 136 | if value == nil { 137 | panicf("Set(nil) is not allowed -- " + 138 | "did you mean to use SetAs(val, (*IFace)(nil))?") 139 | } 140 | return c.with(step{tVALUE, reflect.ValueOf(value), reflect.TypeOf(value)}) 141 | } 142 | 143 | // SetAs provides an immediate value as the specified interface type. 
144 | func (c Func) SetAs(value, ifacePtr interface{}) Func { 145 | val := reflect.ValueOf(value) 146 | typ := reflect.TypeOf(ifacePtr) 147 | if typ.Kind() != reflect.Ptr || typ.Elem().Kind() != reflect.Interface { 148 | panicf("ifacePtr must be a pointer to an interface for "+ 149 | "SetAs, instead got %s", typ) 150 | } 151 | typ = typ.Elem() 152 | // It's ok to pass in a nil value here if you want the interface to actually 153 | // be nil. 154 | if !val.IsValid() { 155 | val = reflect.Zero(typ) 156 | } 157 | if !val.Type().Implements(typ) { 158 | panicf("%s doesn't implement %s", val.Type(), typ) 159 | } 160 | return c.with(step{tVALUE, val, typ}) 161 | } 162 | 163 | // Compute what types are available from the reserved values, provide values, 164 | // and function return values of the current handler chain. This excludes 165 | // error handlers and deferred handlers. 166 | func (c Func) typesAvailable() map[reflect.Type]bool { 167 | m := map[reflect.Type]bool{} 168 | for _, s := range c.steps { 169 | switch s.typ { 170 | case tARG: 171 | m[s.valTyp] = true 172 | case tVALUE: 173 | m[s.val.Type()] = true 174 | m[s.valTyp] = true 175 | case tPRE_HANDLER: 176 | for i := 0; i < s.valTyp.NumOut(); i++ { 177 | m[s.valTyp.Out(i)] = true 178 | } 179 | case tPOST_HANDLER, tERROR_HANDLER: 180 | // ignored, we don't allow any return values for these. 181 | } 182 | } 183 | return m 184 | } 185 | 186 | // Then adds one or more handlers to the middleware chain. It may only accept 187 | // args of types that have already been provided. 188 | func (c Func) Then(handlers ...interface{}) Func { 189 | steps := make([]step, len(handlers)) 190 | available := c.typesAvailable() 191 | for i, handler := range handlers { 192 | fn, err := valueOfFunction(handler) 193 | if err != nil { 194 | panicf("%s arg of With(...) %v", ordinalize(i+1), err) 195 | } 196 | if err := checkCanCall(available, fn); err != nil { 197 | panicf("%s arg of With(...) 
%v", ordinalize(i+1), err) 198 | } 199 | fnType := fn.Func.Type() 200 | steps[i] = step{tPRE_HANDLER, fn.Func, fnType} 201 | for i := 0; i < fnType.NumOut(); i++ { 202 | available[fnType.Out(i)] = true 203 | } 204 | } 205 | return c.with(steps...) 206 | } 207 | 208 | // OnErr registers an error handler to be called for failures of subsequent 209 | // handlers. It may only accept args of types that have already been provided. 210 | func (c Func) OnErr(errorHandler interface{}) Func { 211 | fn, err := valueOfFunction(errorHandler) 212 | if err != nil { 213 | panicf("Error handler %v", err) 214 | } 215 | available := c.typesAvailable() 216 | available[errorType] = true // Set internally by chain. 217 | if err := checkCanCall(available, fn); err != nil { 218 | panicf("Error handler %v", err) 219 | } 220 | if fn.Func.Type().NumOut() > 0 { 221 | panicf("Error handler %s may not have any return values, signature is %s", 222 | fn.Name, fn.Func.Type()) 223 | } 224 | return c.with(step{tERROR_HANDLER, fn.Func, fn.Func.Type()}) 225 | } 226 | 227 | // Defer adds a deferred handler to be executed after all normal handlers and 228 | // error handlers have been called. Deferred handlers are executed in reverse 229 | // order that they were registered (most recent first). Deferred handlers can 230 | // accept the error type even if it hasn't been explicitly provided yet. If no 231 | // error has occurred, it will be nil. 232 | func (c Func) Defer(handler interface{}) Func { 233 | fn, err := valueOfFunction(handler) 234 | if err != nil { 235 | panicf("Defer(...) arg %v", err) 236 | } 237 | available := c.typesAvailable() 238 | available[errorType] = true // Set internally by chain. 239 | if err := checkCanCall(available, fn); err != nil { 240 | panicf("Defer(...) 
arg %v", err) 241 | } 242 | if fn.Func.Type().NumOut() > 0 { 243 | panicf("Defer'd handler %s may not have any return values, signature is %s", 244 | fn.Name, fn.Func.Type()) 245 | } 246 | return c.with(step{tPOST_HANDLER, fn.Func, fn.Func.Type()}) 247 | } 248 | 249 | // MustRun will function chain with the provided args and panic if the args 250 | // don't match the expected arg values. 251 | func (c Func) MustRun(argValues ...interface{}) { 252 | // This will only ever return an error if the arguments to Run don't match. 253 | // Runtime failures of the functions in the chain are handled by the 254 | // registered error handlers (or the default error handler which may panic). 255 | if err := c.Run(argValues...); err != nil { 256 | panic(err) 257 | } 258 | } 259 | 260 | // Run executes the function chain. All declared args must be provided in the 261 | // order than they were declared. This will return an error only if the 262 | // arguments do not exactly correspond to the declared args. Interface values 263 | // must be passed as pointers to the interface. 264 | // 265 | // Important note: The returned error is NOT related to whether any the calls of 266 | // chain returns an error -- any errors returned by functions in the chain are 267 | // handled by the registered error handlers. 268 | func (c Func) Run(argValues ...interface{}) error { 269 | data := map[reflect.Type]reflect.Value{} 270 | postSteps := []step{} // collect post steps here 271 | errHandler := step{ // Initialize using the default error handler. 272 | tERROR_HANDLER, 273 | reflect.ValueOf(DefaultErrorHandler), 274 | reflect.TypeOf(DefaultErrorHandler), 275 | } 276 | stack := []step{} 277 | 278 | // 1: Apply all of the arguments to the available data. Make sure that the 279 | // provided arguments match the Arg calls, otherwise we bomb. 280 | if err := c.processRunArgs(data, argValues...); err != nil { 281 | return err 282 | } 283 | 284 | // Start executing the function chain. 
First pass through is the normal call 285 | // chain, so we skip execution of error handlers and deferred handlers, 286 | // although we keep track of them. 287 | execution: 288 | for _, step := range c.steps { 289 | switch step.typ { 290 | case tARG: 291 | // ignored now, already handled during initialization above. 292 | case tVALUE: 293 | data[step.val.Type()] = step.val 294 | data[step.valTyp] = step.val 295 | case tPRE_HANDLER: 296 | c.call(step, data, &stack) 297 | // Check to see if there's an error. If so, abort the chain. 298 | if errorVal := data[errorType]; errorVal.IsValid() && !errorVal.IsNil() { 299 | break execution 300 | } 301 | case tPOST_HANDLER: 302 | postSteps = append(postSteps, step) 303 | case tERROR_HANDLER: 304 | errHandler = step 305 | } 306 | } 307 | 308 | // Execute the error handler if there is any error. 309 | if errorVal := data[errorType]; errorVal.IsValid() && !errorVal.IsNil() { 310 | c.call(errHandler, data, &stack) 311 | } else { 312 | data[errorType] = reflect.Zero(errorType) 313 | } 314 | 315 | // Finally, call any deferred functions that we've gotten to. 316 | for i := len(postSteps) - 1; i >= 0; i-- { 317 | c.call(postSteps[i], data, &stack) 318 | } 319 | 320 | return nil 321 | } 322 | 323 | func (c Func) processRunArgs( 324 | data map[reflect.Type]reflect.Value, 325 | argValues ...interface{}, 326 | ) error { 327 | argIndex := 0 328 | expectedNumArgs := 0 329 | var missingArgs []string 330 | for _, step := range c.steps { 331 | if step.typ != tARG { 332 | continue 333 | } 334 | expectedNumArgs++ 335 | if argIndex >= len(argValues) { 336 | missingArgs = append(missingArgs, step.valTyp.String()) 337 | continue 338 | } 339 | val := argValues[argIndex] 340 | argIndex++ 341 | 342 | if val == nil { 343 | if step.valTyp.Kind() == reflect.Interface || step.valTyp.Kind() == reflect.Ptr { 344 | data[step.valTyp] = reflect.New(step.valTyp).Elem() 345 | continue 346 | } 347 | return fmt.Errorf("bad arg: %s arg of Run(...) 
should be a %s but is %v", 348 | ordinalize(argIndex), step.valTyp, val) 349 | } 350 | 351 | rv := reflect.ValueOf(val) 352 | if !rv.CanConvert(step.valTyp) { 353 | return fmt.Errorf("bad arg: %s arg of Run(...) should be a %s but is %s", 354 | ordinalize(argIndex), step.valTyp, rv.Type()) 355 | } 356 | data[step.valTyp] = rv.Convert(step.valTyp) 357 | } 358 | if len(missingArgs) > 0 { 359 | return fmt.Errorf("missing args of types: %s", missingArgs) 360 | } 361 | if argIndex != len(argValues) { 362 | return fmt.Errorf("too many args: expected %d args but got %d args", 363 | expectedNumArgs, len(argValues)) 364 | } 365 | return nil 366 | } 367 | 368 | func (c Func) call(s step, data map[reflect.Type]reflect.Value, stack *[]step) { 369 | t := s.valTyp 370 | in := make([]reflect.Value, t.NumIn()) 371 | for i := range in { 372 | in[i] = data[t.In(i)] 373 | // This isn't supposed to happen if we've done all our checks right. 374 | if !in[i].IsValid() { 375 | name := runtime.FuncForPC(s.val.Pointer()).Name() 376 | panicf("Cannot inject %s arg of type %s into %s (%s). 
Data: %v", 377 | ordinalize(i+1), t.In(i), name, t, data) 378 | } 379 | } 380 | defer func() { 381 | if err := c.wrapPanic(recover(), *stack); err != nil { 382 | data[errorType] = reflect.ValueOf((*error)(&err)).Elem() 383 | } 384 | }() 385 | *stack = append(*stack, s) 386 | out := s.val.Call(in) 387 | for _, val := range out { 388 | data[val.Type()] = val 389 | } 390 | } 391 | 392 | func (c Func) wrapPanic(x interface{}, steps []step) error { 393 | if x == nil { 394 | return nil 395 | } 396 | var stack [8192]byte 397 | n := runtime.Stack(stack[:], false) 398 | 399 | N := len(steps) 400 | mwStack := make([]FuncInfo, N) 401 | for i := range steps { 402 | step := steps[N-i-1] 403 | info := runtime.FuncForPC(step.val.Pointer()) 404 | file, line := info.FileLine(step.val.Pointer()) 405 | mwStack[i] = FuncInfo{info.Name(), file, line, step.val} 406 | } 407 | 408 | return PanicError{ 409 | Val: x, 410 | RawStack: string(stack[:n]), 411 | MiddlewareStack: mwStack, 412 | } 413 | } 414 | 415 | // PanicError is the error that is returned if a handler panics. It includes 416 | // the panic'd value (Val), the raw Go stack trace (RawStack), and the 417 | // middleware execution history (MiddlewareStack) that shows what middleware 418 | // functions have already been called. 419 | type PanicError struct { 420 | Val interface{} 421 | RawStack string 422 | MiddlewareStack []FuncInfo 423 | } 424 | 425 | // FuncInfo describes a registered middleware function. 426 | type FuncInfo struct { 427 | Name string // fully-qualified name, e.g.: github.com/foo/bar.FuncName 428 | File string 429 | Line int 430 | Func reflect.Value 431 | } 432 | 433 | // FilteredStack returns the stack trace without some internal chain.* functions 434 | // and without reflect.Value.call stack frames, since these are generally just 435 | // noise. The reflect.Value.call removal could affect user stack frames. 
436 | // 437 | // TODO(aroman): Refine filtering so that it only removes reflect.Value.call 438 | // frames due to sandwich. 439 | func (p PanicError) FilteredStack() []string { 440 | lines := strings.Split(p.RawStack, "\n") 441 | var filtered []string 442 | for i := 0; i < len(lines); i++ { 443 | line := lines[i] 444 | if strings.HasPrefix(line, "github.com/augustoroman/sandwich/chain") && 445 | !strings.HasPrefix(line, "github.com/augustoroman/sandwich/chain.Func.Run(") && 446 | !strings.HasPrefix(line, "github.com/augustoroman/sandwich/chain.Test") { 447 | i++ 448 | continue 449 | } 450 | if strings.HasPrefix(line, "reflect.Value.call") || strings.HasPrefix(line, "reflect.Value.Call") { 451 | i++ 452 | continue 453 | } 454 | filtered = append(filtered, line) 455 | } 456 | return filtered 457 | } 458 | 459 | func (p PanicError) Error() string { 460 | var mwStack bytes.Buffer 461 | w := tabwriter.NewWriter(&mwStack, 5, 7, 2, ' ', 0) 462 | for _, fn := range p.MiddlewareStack { 463 | fmt.Fprintf(w, " %s\t%s\n", fn.Name, fn.Func.Type()) 464 | } 465 | w.Flush() 466 | return fmt.Sprintf( 467 | "Panic executing middleware %s: %v\n"+ 468 | " Middleware executed:\n%s"+ 469 | " Filtered call stack:\n %s", 470 | p.MiddlewareStack[0].Name, p.Val, 471 | mwStack.String(), 472 | strings.Join(p.FilteredStack(), "\n ")) 473 | } 474 | --------------------------------------------------------------------------------