├── .gitignore
├── .travis.yml
├── examples
├── 2-advanced
│ ├── oauth2-client-id.gif
│ ├── config.json
│ ├── static
│ │ ├── error.tpl.html
│ │ ├── landing-page.tpl.html
│ │ ├── home.tpl.html
│ │ └── main.js
│ ├── README.md
│ └── main.go
├── 0-helloworld
│ └── main.go
└── 1-simple
│ ├── README.md
│ └── main.go
├── go.mod
├── fs_test.go
├── fs.go
├── LICENSE
├── chain
├── benchmark_reflect_call_test.go
├── codegen_test.go
├── naming_test.go
├── ordinal.go
├── util.go
├── example_test.go
├── codegen.go
├── naming.go
├── chain_test.go
└── chain.go
├── gzip_test.go
├── response_writer.go
├── go.sum
├── gzip.go
├── example_test.go
├── wrap.go
├── errors.go
├── logger.go
├── benchmark_simple_handler_test.go
├── logger_test.go
├── doc.go
├── router_test.go
├── README.md
└── router.go
/.gitignore:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: go
2 | go:
3 | - 1.19
4 | - master
5 | notifications:
6 | email:
7 | on_failure: change
8 |
--------------------------------------------------------------------------------
/examples/2-advanced/oauth2-client-id.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/augustoroman/sandwich/HEAD/examples/2-advanced/oauth2-client-id.gif
--------------------------------------------------------------------------------
/examples/2-advanced/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "host": "localhost:5000",
3 | "port": 5000,
4 | "cookie-secret": "123456789012345678901234",
5 | "oauth2-client-id": "see README.md",
6 | "oauth2-client-secret": "see README.md"
7 | }
8 |
--------------------------------------------------------------------------------
/examples/2-advanced/static/error.tpl.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | Failed
5 |
6 |
7 | OMG, something went horribly wrong.
8 | If this were real, this would be a beautifully-styled fail page.
9 | Your error is:
10 | ({{.Error.Code}}) - {{.Error.ClientMsg}}
11 |
12 |
13 |
--------------------------------------------------------------------------------
/examples/0-helloworld/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "fmt"
5 | "log"
6 | "net/http"
7 |
8 | "github.com/augustoroman/sandwich"
9 | )
10 |
11 | func main() {
12 | mux := sandwich.TheUsual()
13 | mux.Get("/", func(w http.ResponseWriter) {
14 | fmt.Fprintf(w, "Hello world!")
15 | })
16 | if err := http.ListenAndServe(":8080", mux); err != nil {
17 | log.Fatal(err)
18 | }
19 | }
20 |
--------------------------------------------------------------------------------
/examples/1-simple/README.md:
--------------------------------------------------------------------------------
1 | # Basic sandwich usage
2 |
3 | This example demonstrates most of the basic features of sandwich, including:
4 |
5 | * Providing user types to the middleware chain
6 | * Adding middleware handlers to the stack.
7 | * Writing handlers that provide request-scoped values.
8 | * Writing handlers using injected values.
9 | * Using the default sandwich logging system.
10 | * Using the default sandwich error system.
11 |
12 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/augustoroman/sandwich
2 |
3 | go 1.19
4 |
5 | require (
6 | github.com/bradrydzewski/go.auth v0.0.0-20130828171325-d0051b5cc538
7 | github.com/stretchr/testify v1.7.0
8 | )
9 |
10 | require (
11 | github.com/davecgh/go-spew v1.1.1 // indirect
12 | github.com/dchest/authcookie v0.0.0-20190824115100-f900d2294c8e // indirect
13 | github.com/pmezard/go-difflib v1.0.0 // indirect
14 | gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
15 | )
16 |
--------------------------------------------------------------------------------
/fs_test.go:
--------------------------------------------------------------------------------
1 | package sandwich
2 |
3 | import (
4 | "embed"
5 | "io/fs"
6 | "net/http/httptest"
7 | "testing"
8 |
9 | "github.com/stretchr/testify/assert"
10 | "github.com/stretchr/testify/require"
11 | )
12 |
13 | //go:embed examples
14 | var examples embed.FS
15 |
16 | func TestServeFS(t *testing.T) {
17 | helloworld := ServeFS(examples, "examples/0-helloworld", "path")
18 | w := httptest.NewRecorder()
19 | helloworld(w, httptest.NewRequest("", "/foo", nil), Params{"path": "main.go"})
20 | contents, err := fs.ReadFile(examples, "examples/0-helloworld/main.go")
21 | require.NoError(t, err)
22 | assert.Equal(t, string(contents), w.Body.String())
23 | }
24 |
--------------------------------------------------------------------------------
/fs.go:
--------------------------------------------------------------------------------
1 | package sandwich
2 |
3 | import (
4 | "io/fs"
5 | "net/http"
6 | )
7 |
8 | // ServeFS is a simple helper that will serve static files from an fs.FS
9 | // filesystem. It allows serving files identified by a sandwich path parameter
10 | // out of a subdirectory of the filesystem. This is especially useful when
11 | // embedding static files:
12 | //
13 | // //go:embed server_files
14 | // var all_files embed.FS
15 | //
16 | // mux.Get("/css/:path*", sandwich.ServeFS(all_files, "static/css", "path"))
17 | // mux.Get("/js/:path*", sandwich.ServeFS(all_files, "dist/js", "path"))
18 | // mux.Get("/i/:path*", sandwich.ServeFS(all_files, "static/images", "path"))
19 | func ServeFS(
20 | f fs.FS,
21 | fsRoot string,
22 | pathParam string,
23 | ) func(w http.ResponseWriter, r *http.Request, p Params) {
24 | sub, err := fs.Sub(f, fsRoot)
25 | if err != nil {
26 | panic(err)
27 | }
28 | handler := http.FileServer(http.FS(sub))
29 | return func(w http.ResponseWriter, r *http.Request, p Params) {
30 | r.URL.Path = p[pathParam]
31 | handler.ServeHTTP(w, r)
32 | }
33 | }
34 |
--------------------------------------------------------------------------------
/examples/2-advanced/static/landing-page.tpl.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | Welcome to a sample todo app
5 |
13 |
14 |
15 | Please sign in to start making your todo list:
16 |
17 | - Use the fake sign-in to sign in as anyone:
18 |
24 |
25 | - Sign in with Google
For this, you must have
26 | configured the oauth client id and secret into the config.json file.
27 | See README.md for details.
28 |
.
29 |
30 |
31 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2016 Augusto Roman
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/chain/benchmark_reflect_call_test.go:
--------------------------------------------------------------------------------
1 | package chain
2 |
3 | import (
4 | "reflect"
5 | "testing"
6 | )
7 |
// twoArg is the call target for the TwoArgs benchmarks. It returns its first
// argument and ignores the second.
//
// Inlining is disabled so that the direct-call benchmark measures a real
// function call rather than an inlined body.
//
//go:noinline
func twoArg(a, b string) string {
	return a
}
11 |
// manyArg is the call target for the ManyArgs benchmarks. It accepts seven
// strings, ignores them, and returns seven fixed strings.
//
// Inlining is disabled so that the direct-call benchmark measures a real
// function call rather than an inlined body.
//
//go:noinline
func manyArg(a, b, c, d, e, f, g string) (s, t, u, v, w, y, z string) {
	return "s", "t", "u", "v", "w", "y", "z"
}
15 |
16 | func BenchmarkReflectCall_TwoArgs(b *testing.B) {
17 | arg1, arg2 := reflect.ValueOf("foo"), reflect.ValueOf("bar")
18 | fn := reflect.ValueOf(twoArg)
19 | args := []reflect.Value{arg1, arg2}
20 | for i := 0; i < b.N; i++ {
21 | fn.Call(args)
22 | }
23 | }
24 | func BenchmarkDirectCall_TwoArgs(b *testing.B) {
25 | arg1, arg2 := "foo", "bar"
26 | for i := 0; i < b.N; i++ {
27 | twoArg(arg1, arg2)
28 | }
29 | }
30 |
31 | func BenchmarkReflectCall_ManyArgs(b *testing.B) {
32 | arg := reflect.ValueOf("arg")
33 | fn := reflect.ValueOf(manyArg)
34 | args := []reflect.Value{arg, arg, arg, arg, arg, arg, arg}
35 | for i := 0; i < b.N; i++ {
36 | fn.Call(args)
37 | }
38 | }
39 | func BenchmarkDirectCall_ManyArgs(bb *testing.B) {
40 | a, b, c, d, e, f, g := "a", "b", "c", "d", "e", "f", "g"
41 | for i := 0; i < bb.N; i++ {
42 | manyArg(a, b, c, d, e, f, g)
43 | }
44 | }
45 |
--------------------------------------------------------------------------------
/examples/2-advanced/static/home.tpl.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | TODO List
5 |
19 |
20 |
21 | Hello {{.User.Name}}
22 |
23 | [
24 | Normal handlers |
25 | Sign out
26 | ]
27 |
28 |
29 |
30 | Here's your TODO list:
31 |
32 | {{ range .Tasks }}
33 | - {{.Desc}}
34 | {{ end }}
35 |
36 |
37 | Add a new task:
38 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
--------------------------------------------------------------------------------
/examples/2-advanced/README.md:
--------------------------------------------------------------------------------
1 | # Advanced sandwich usage: Simple TODO app.
2 |
3 | This example demonstrates more advanced features of sandwich, including:
4 |
5 | * Providing interface types to the middleware chain.
6 | TaskDb is the interface provided to the handlers, the actual value injected
7 | in main() is a taskDbImpl.
8 | * Using 3rd party middleware (go.auth, go.rice)
9 | * Using a 3rd party router (gorilla/mux)
10 | * Using multiple error handlers, and custom error handlers.
11 | Most web servers will want to serve a custom HTML error page for user-facing
12 | error pages. An example of that is included here. For AJAX calls, however,
13 | we don't want to serve HTML. Instead, we always respond with JSON using the sandwich.HandleErrorJson error handler. With
14 | sandwich, the errors returned from handlers are format-agnostic. Instead, the
15 | error handler decides what format to respond in.
16 | * Early exit of the middleware chain via the sandwich.Done error
17 | See `CheckForFakeLogin()` for usage.
18 |
19 | ## Google Login (Oauth2 Authentication)
20 |
21 | In order to use the Google login, you need an Oauth2 client ID & client secret.
22 | See this animation for an example of getting these values:
23 |
24 | ![Animation: creating an OAuth2 client ID and secret](oauth2-client-id.gif)
25 |
26 | More documentation is available at https://developers.google.com/identity/protocols/OAuth2WebServer
27 |
--------------------------------------------------------------------------------
/chain/codegen_test.go:
--------------------------------------------------------------------------------
1 | package chain
2 |
3 | import (
4 | "bytes"
5 | "net/http"
6 | "regexp"
7 | "strings"
8 | "testing"
9 | )
10 |
// whitespaceRun matches one or more consecutive whitespace characters. It is
// compiled once at package scope instead of on every normalizeWhitespace call.
var whitespaceRun = regexp.MustCompile(`\s+`)

// normalizeWhitespace trims leading and trailing whitespace from s and
// replaces every remaining run of whitespace with a single space, so that two
// pieces of text can be compared without regard to layout.
func normalizeWhitespace(s string) string {
	return whitespaceRun.ReplaceAllLiteralString(strings.TrimSpace(s), " ")
}
14 |
// TestDb is an empty stand-in type; TestCodeGen provides a *TestDb to the
// chain under test.
type TestDb struct{}

// Validate does nothing. It exists so that the method expression
// (*TestDb).Validate can be registered as a handler.
func (db *TestDb) Validate(_ string) {}
18 |
19 | func TestCodeGen(t *testing.T) {
20 | var buf bytes.Buffer
21 |
22 | type User struct{}
23 | type Http struct{}
24 |
25 | New().
26 | Arg((*http.ResponseWriter)(nil)).
27 | Set("").
28 | Set(int64(0)).
29 | Set(int(1)).
30 | Set((*User)(nil)).
31 | Set(User{}).
32 | Set(Http{}).
33 | Set(&TestDb{}).
34 | Then((*TestDb).Validate).
35 | Then(a, b, c).
36 | Code("foo", "chain", &buf)
37 |
38 | const expected = `func foo(
39 | str string,
40 | i64 int64,
41 | i int,
42 | pUser *User,
43 | user User,
44 | chain_Http Http,
45 | pTestDb *TestDb,
46 | ) func(
47 | rw http.ResponseWriter,
48 | ) {
49 | return func(
50 | rw http.ResponseWriter,
51 | ) {
52 | (*TestDb).Validate(pTestDb, str)
53 |
54 | str = a()
55 |
56 | str, i = b(str)
57 |
58 | c(str, i)
59 |
60 | }
61 | }`
62 | if normalizeWhitespace(buf.String()) != normalizeWhitespace(expected) {
63 | t.Errorf("Wrong code generated: %s\nExp: %q\nGot: %q", buf.String(),
64 | normalizeWhitespace(expected), normalizeWhitespace(buf.String()))
65 | }
66 | }
67 |
--------------------------------------------------------------------------------
/gzip_test.go:
--------------------------------------------------------------------------------
1 | package sandwich
2 |
3 | import (
4 | "compress/gzip"
5 | "fmt"
6 | "io"
7 | "net/http"
8 | "net/http/httptest"
9 | "testing"
10 | )
11 |
12 | func TestGzip(t *testing.T) {
13 | greet := func(w http.ResponseWriter, r *http.Request) {
14 | fmt.Fprintf(w, "Hi there!")
15 | }
16 | handler := BuildYourOwn()
17 | handler.Use(Gzip)
18 | handler.Any("/", greet)
19 |
20 | resp := httptest.NewRecorder()
21 | req, _ := http.NewRequest("GET", "/", nil)
22 | req.Header.Add(headerAcceptEncoding, "gzip")
23 | handler.ServeHTTP(resp, req)
24 |
25 | if resp.Header().Get(headerContentEncoding) != "gzip" {
26 | t.Errorf("Not gzip'd? Content-encoding: %q", resp.Header())
27 | }
28 |
29 | if resp.Header().Get(headerContentLength) != "" {
30 | t.Errorf("Not supposed to include content-length: %q", resp.Header())
31 | }
32 |
33 | r, err := gzip.NewReader(resp.Body)
34 | if err != nil {
35 | t.Fatal(err)
36 | }
37 | defer r.Close()
38 | if body, err := io.ReadAll(r); err != nil {
39 | t.Fatal(err)
40 | } else if string(body) != "Hi there!" {
41 | t.Errorf("Wrong response: %q", string(body))
42 | }
43 |
44 | // Also, test without the accept header and make sure it's NOT gzip'd.
45 | resp = httptest.NewRecorder()
46 | req, _ = http.NewRequest("GET", "/", nil)
47 | handler.ServeHTTP(resp, req)
48 | if resp.Header().Get(headerContentEncoding) == "gzip" {
49 | t.Errorf("Unexpectedly gzip'd: Content-encoding: %q", resp.Header())
50 | }
51 | if resp.Body.String() != "Hi there!" {
52 | t.Errorf("Wrong response: %q", resp.Body.String())
53 | }
54 | }
55 |
--------------------------------------------------------------------------------
/response_writer.go:
--------------------------------------------------------------------------------
1 | package sandwich
2 |
3 | import (
4 | "bufio"
5 | "fmt"
6 | "net"
7 | "net/http"
8 | )
9 |
// WrapResponseWriter creates a ResponseWriter and returns it as both an
// http.ResponseWriter and a *ResponseWriter. The double return is redundant
// for native Go code, but is a necessary hint to the dependency injection.
func WrapResponseWriter(w http.ResponseWriter) (http.ResponseWriter, *ResponseWriter) {
	rw := &ResponseWriter{ResponseWriter: w}
	return rw, rw
}

// ResponseWriter wraps http.ResponseWriter to add tracking of the response size
// and response code.
type ResponseWriter struct {
	http.ResponseWriter
	Size int // The size of the response written so far, in bytes.
	Code int // The status code of the response, or 0 if not written yet.
}

// Hijack forwards to the wrapped writer's Hijack method. It returns an error
// if the wrapped writer does not implement http.Hijacker.
func (w *ResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	hijacker, ok := w.ResponseWriter.(http.Hijacker)
	if !ok {
		return nil, nil, fmt.Errorf("the ResponseWriter doesn't support the Hijacker interface")
	}
	return hijacker.Hijack()
}

// Flush forwards to the wrapped writer's Flush method. It does nothing if the
// wrapped writer does not implement http.Flusher.
func (w *ResponseWriter) Flush() {
	if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
		flusher.Flush()
	}
}

// WriteHeader records code and forwards it to the wrapped writer.
//
// The first final status code is the one kept in Code. An interim 1xx code
// (other than 101 Switching Protocols) is recorded only until a later call
// supplies the final status, so that Code ends up describing the response the
// client actually received.
func (w *ResponseWriter) WriteHeader(code int) {
	if !w.hasFinalCode() {
		w.Code = code
	}
	w.ResponseWriter.WriteHeader(code)
}

// Write forwards p to the wrapped writer and adds the number of bytes written
// to Size. If no final status has been recorded yet, Code is set to 200, which
// is the status net/http sends when Write is called before WriteHeader.
func (w *ResponseWriter) Write(p []byte) (int, error) {
	if !w.hasFinalCode() {
		w.Code = http.StatusOK
	}
	n, err := w.ResponseWriter.Write(p)
	w.Size += n
	return n, err
}

// hasFinalCode reports whether Code holds a final status, i.e. it is non-zero
// and is not an interim 1xx code. 101 counts as final because no further
// status follows a protocol switch.
func (w *ResponseWriter) hasFinalCode() bool {
	if w.Code == 0 {
		return false
	}
	interim := w.Code >= 100 && w.Code < 200 && w.Code != http.StatusSwitchingProtocols
	return !interim
}
56 |
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | github.com/bradrydzewski/go.auth v0.0.0-20130828171325-d0051b5cc538 h1:xdNrK3humN0dHUoArlWMrkDUBcxtDHlJ7196exYTbsI=
2 | github.com/bradrydzewski/go.auth v0.0.0-20130828171325-d0051b5cc538/go.mod h1:uPPs9JS166ZU50k/tsHQ5Dux4qOBjsJyD2zrU0WaUbk=
3 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
4 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
5 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
6 | github.com/dchest/authcookie v0.0.0-20190824115100-f900d2294c8e h1:xizeG5ksKSdyNaom2//2Bow4hLWqXkCql36nrL9iEUI=
7 | github.com/dchest/authcookie v0.0.0-20190824115100-f900d2294c8e/go.mod h1:x7AK2h2QzaXVEFi1tbMYMDuvHcCEr1QdMDrg3hkW24Q=
8 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
9 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
10 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
11 | github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
12 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
13 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
14 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
15 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
16 | gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo=
17 | gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
18 |
--------------------------------------------------------------------------------
/gzip.go:
--------------------------------------------------------------------------------
1 | package sandwich
2 |
3 | import (
4 | "compress/gzip"
5 | "net/http"
6 | "strings"
7 | )
8 |
// Canonical names of the HTTP headers that the gzip middleware and its test
// read or write.
const (
	headerAcceptEncoding  = "Accept-Encoding"
	headerContentEncoding = "Content-Encoding"
	headerContentLength   = "Content-Length"
	headerContentType     = "Content-Type"
	headerVary            = "Vary"
)
16 |
// Gzip wraps a sandwich.Middleware to add gzip compression to the output for
// all subsequent handlers.
//
// It is a Wrap: provideGZipWriter runs as the Before step and supplies the
// http.ResponseWriter that later handlers write to, and (*gZipWriter).Flush
// runs as the deferred After step. Compression is applied only when the
// request's Accept-Encoding header contains "gzip"; otherwise the original
// writer is passed through unchanged.
//
// For example, to gzip everything you could use:
//
//	router.Use(sandwich.Gzip)
//	...use as normal...
//
// Or, to gzip just a particular route you could do:
//
//	router.Get("/foo/bar", sandwich.Gzip, MyHandleFooBar)
//
// Note that this does NOT auto-detect the content and disable compression for
// already-compressed data (e.g. jpg images).
var Gzip = Wrap{provideGZipWriter, (*gZipWriter).Flush}
32 |
33 | func provideGZipWriter(w http.ResponseWriter, r *http.Request) (http.ResponseWriter, *gZipWriter) {
34 | if !strings.Contains(r.Header.Get(headerAcceptEncoding), "gzip") {
35 | return w, nil
36 | }
37 | headers := w.Header()
38 | headers.Set(headerContentEncoding, "gzip")
39 | headers.Set(headerVary, headerAcceptEncoding)
40 |
41 | wr := &gZipWriter{w, gzip.NewWriter(w)}
42 | return wr, wr
43 | }
44 |
45 | type gZipWriter struct {
46 | http.ResponseWriter
47 | w *gzip.Writer
48 | }
49 |
50 | func (g *gZipWriter) Write(p []byte) (int, error) {
51 | if len(g.Header().Get(headerContentType)) == 0 {
52 | g.Header().Set(headerContentType, http.DetectContentType(p))
53 | }
54 | return g.w.Write(p)
55 | }
56 |
57 | func (g *gZipWriter) Flush() {
58 | g.Header().Del(headerContentLength)
59 | g.w.Close()
60 | }
61 |
--------------------------------------------------------------------------------
/chain/naming_test.go:
--------------------------------------------------------------------------------
1 | package chain
2 |
3 | import (
4 | "net/http"
5 | "reflect"
6 | "testing"
7 |
8 | "github.com/stretchr/testify/assert"
9 | )
10 |
// TestNameMapper checks the variable names that nameMapper.For returns for a
// range of types.
//
// All assertions share a single nameMapper, so their order is significant:
// some expected names below are what they are only because an earlier call
// already claimed the more obvious name.
func TestNameMapper(t *testing.T) {
	var n nameMapper

	// Built-in numeric and boolean types.
	assert.Equal(t, "i64", n.For(reflect.TypeOf(int64(0))), "int64")
	assert.Equal(t, "i", n.For(reflect.TypeOf(int(0))), "int")
	assert.Equal(t, "u", n.For(reflect.TypeOf(uint(0))), "uint")
	assert.Equal(t, "f32", n.For(reflect.TypeOf(float32(0))), "float32")
	assert.Equal(t, "f64", n.For(reflect.TypeOf(float64(0))), "float64")
	assert.Equal(t, "flag", n.For(reflect.TypeOf(false)), "bool")

	// The two standard net/http handler parameter types.
	assert.Equal(t, "rw", n.For(reflect.TypeOf((*http.ResponseWriter)(nil)).Elem()), "http.ResponseWriter")
	assert.Equal(t, "req", n.For(reflect.TypeOf((*http.Request)(nil))), "*http.Request")

	// A local type whose natural name collides with "req" above, plus pointer
	// and multi-level pointer variants of it.
	type Req struct{}
	assert.Equal(t, "chain_Req", n.For(reflect.TypeOf(Req{})), "chain.Req (could've been req, but that's taken)")
	assert.Equal(t, "pReq", n.For(reflect.TypeOf(&Req{})), "*chain.Req")
	assert.Equal(t, "pppReq", n.For(reflect.TypeOf((***Req)(nil))), "***chain.Req")

	// An unnamed map-of-struct type.
	var c map[string]struct {
		A []byte
		B chan bool
	}
	assert.Equal(t, "map_string_struct_A_uint8_B_chan_bool", n.For(reflect.TypeOf(c)),
		"crazy inlined struct")

	// A different unnamed type that would flatten to the same name as c.
	// NOTE(review): the 12 in "__var12__" presumably counts the names handed
	// out so far; confirm against nameMapper before adding assertions above.
	var d map[string]struct {
		A_uint8_B chan bool
	}
	assert.Equal(t, "__var12__", n.For(reflect.TypeOf(d)),
		"inlined struct with var name that conflicts")

	// Bytes and slices.
	assert.Equal(t, "u8", n.For(reflect.TypeOf(byte(0))), "single byte")
	assert.Equal(t, "sliceOfUint8", n.For(reflect.TypeOf([]byte{})), "byte slice")
	assert.Equal(t, "sliceOfInt", n.For(reflect.TypeOf([]int{})), "int slice")
}
46 |
--------------------------------------------------------------------------------
/examples/2-advanced/static/main.js:
--------------------------------------------------------------------------------
// Reads the task description from the "desc" input, clears the input, and
// POSTs the description as JSON to api/task. On success the new task is
// appended to the #tasks list; on failure the notice area shows the error.
// Always returns false.
function addTask() {
  var el = document.querySelector('input[name=desc]');
  var data = { desc: el.value };
  el.value = "";
  var xhr = {
    url: 'api/task',
    method: 'POST',
    credentials: 'include',
    body: JSON.stringify(data),
    cache: 'no-cache',
  };
  fetch(xhr.url, xhr).then(response => response.json()).then(data => {
    if (data.error) {
      setNotice('#FEE', 'Failed: ' + data.error);
    } else if (data.task) {
      // Create the element directly: html('') returns null because an empty
      // string produces no child node, and setting li.id would then throw.
      // NOTE(review): if existing list items carry attributes or handlers
      // (e.g. a click handler calling toggle), add the same ones here.
      var li = document.createElement('li');
      li.id = 'task' + data.task.id;
      li.innerText = data.task.desc; // use innerText so it's escaped!
      document.querySelector('#tasks').appendChild(li);
      setNotice('#EFE', 'Added task ' + data.task.desc);
    } else {
      setNotice('#FEE', 'Server error');
      console.error('Bad response: ', data);
    }
  }).catch(err => {
    // Network failures and non-JSON responses end up here instead of being
    // dropped as unhandled promise rejections.
    setNotice('#FEE', 'Server error');
    console.error('Request failed: ', err);
  });

  return false;
}
29 |
// Asks the server to toggle the task shown by el. The task id is taken from
// el.id after the "task" prefix. On success the element's "done" class is set
// to match the task's done flag returned by the server.
function toggle(el) {
  var id = el.id.slice(4); // skip the "task" prefix
  var xhr = {
    // Encode the id so characters such as "/" or "?" cannot change the URL.
    url: 'api/task/' + encodeURIComponent(id),
    method: 'POST',
    credentials: 'include',
    // Let JSON.stringify do the quoting so an id containing a quote or
    // backslash still produces valid JSON.
    body: JSON.stringify({ toggle: true, id: id }),
    cache: 'no-cache',
  };
  fetch(xhr.url, xhr).then(response => response.json()).then(data => {
    if (data.error) {
      setNotice('#FEE', 'Failed: ' + data.error);
    } else if (data.task) {
      el.classList.toggle('done', data.task.done);
      setNotice('#EFE', 'Updated!');
    } else {
      setNotice('#FEE', 'Server error');
      console.error('Bad response: ', data);
    }
  }).catch(err => {
    // Network failures and non-JSON responses end up here instead of being
    // dropped as unhandled promise rejections.
    setNotice('#FEE', 'Server error');
    console.error('Request failed: ', err);
  });
}
51 |
// Shows msg in the #notice element and sets that element's background to col.
function setNotice(col, msg) {
  var notice = document.querySelector('#notice');
  notice.style.background = col;
  notice.innerText = msg;
}
58 |
// Parses content as HTML inside a detached div and returns the first node it
// produced (null when content yields no nodes).
function html(content) {
  var wrapper = document.createElement('div');
  wrapper.innerHTML = content;
  var first = wrapper.firstChild;
  return first;
}
65 |
--------------------------------------------------------------------------------
/chain/ordinal.go:
--------------------------------------------------------------------------------
1 | package chain
2 |
3 | // adapted from https://github.com/martinusso/inflect/blob/master/ordinal.go#L38
4 |
5 | // The MIT License (MIT)
6 | //
7 | // Copyright (c) 2016 Breno Martinusso
8 | //
9 | // Permission is hereby granted, free of charge, to any person obtaining a copy
10 | // of this software and associated documentation files (the "Software"), to deal
11 | // in the Software without restriction, including without limitation the rights
12 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13 | // copies of the Software, and to permit persons to whom the Software is
14 | // furnished to do so, subject to the following conditions:
15 | //
16 | // The above copyright notice and this permission notice shall be included in all
17 | // copies or substantial portions of the Software.
18 | //
19 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
25 | // SOFTWARE.
26 |
27 | import (
28 | "fmt"
29 | "math"
30 | )
31 |
// English ordinal suffixes returned by ordinal.
const (
	st = "st"
	nd = "nd"
	rd = "rd"
	th = "th"
)
38 |
39 | // Ordinal returns the ordinal suffix that should be added to a number to denote the position in an ordered sequence such as 1st, 2nd, 3rd, 4th...
40 | func ordinal(number int) string {
41 | switch abs(number) % 100 {
42 | case 11, 12, 13:
43 | return th
44 | default:
45 | switch abs(number) % 10 {
46 | case 1:
47 | return st
48 | case 2:
49 | return nd
50 | case 3:
51 | return rd
52 | }
53 | }
54 | return th
55 | }
56 |
// abs returns the absolute value of number using integer arithmetic only.
//
// Going through float64 loses the low digits of values beyond 2^53, and those
// low digits are exactly what ordinal inspects. The one value whose magnitude
// does not fit in an int, math.MinInt, is clamped to math.MaxInt.
func abs(number int) int {
	switch {
	case number == math.MinInt:
		return math.MaxInt
	case number < 0:
		return -number
	default:
		return number
	}
}
60 |
61 | // Ordinalize turns a number into an ordinal string
62 | func ordinalize(number int) string {
63 | ordinal := ordinal(number)
64 | return fmt.Sprintf("%d%s", number, ordinal)
65 | }
66 |
--------------------------------------------------------------------------------
/example_test.go:
--------------------------------------------------------------------------------
1 | package sandwich_test
2 |
3 | import (
4 | "encoding/json"
5 | "fmt"
6 | "io/fs"
7 | "net/http"
8 |
9 | "github.com/augustoroman/sandwich"
10 | )
11 |
// UserID identifies a user. In this example it is produced from a request
// header or a path parameter and consumed by the UserDB methods.
type UserID string

// User is a placeholder user record; it has no fields in this example.
type User struct{}

// UserDB is the storage interface used by the example. Its method expressions
// (UserDB.Get, UserDB.New, UserDB.Del, UserDB.List) are registered directly as
// handlers in ExampleRouter.
type UserDB interface {
	Get(UserID) (*User, error)
	New(*User) (UserID, error)
	Del(UserID) error
	List() ([]*User, error)
}
21 |
22 | func ExampleRouter() {
23 | var db UserDB // = NewUserDB
24 |
25 | root := sandwich.TheUsual()
26 | root.SetAs(db, &db)
27 |
28 | api := root.SubRouter("/api")
29 | api.OnErr(sandwich.HandleErrorJson)
30 |
31 | apiUsers := api.SubRouter("/users")
32 | apiUsers.Get("/:uid", UserIDFromParam, UserDB.Get, SendUser)
33 | apiUsers.Delete("/:uid", UserIDFromParam, UserDB.Del)
34 | apiUsers.Get("/", UserDB.List, SendUserList)
35 | apiUsers.Post("/", UserFromCreateRequest, UserDB.New, SendUserID)
36 |
37 | var staticFS fs.FS
38 | root.Get("/home/", GetLoggedInUser, UserDB.Get, Home)
39 | root.Get("/:path", http.FileServer(http.FS(staticFS)))
40 |
41 | // Output:
42 | }
43 |
44 | func GetLoggedInUser(r *http.Request) (UserID, error) {
45 | token := r.Header.Get("user-token")
46 | uid := UserID(token) // decode the token to get the user info
47 | if uid == "" {
48 | return "", sandwich.Error{Code: http.StatusUnauthorized, ClientMsg: "invalid user token"}
49 | }
50 | return uid, nil
51 | }
52 |
53 | func Home(w http.ResponseWriter, u *User) {
54 | fmt.Fprintf(w, "Hello %v", u)
55 | }
56 |
57 | func UserIDFromParam(p sandwich.Params) (UserID, error) {
58 | uid := UserID(p["id"])
59 | if uid == "" {
60 | return "", sandwich.Error{Code: 400, ClientMsg: "Missing UID param", LogMsg: "Request missing UID param"}
61 | }
62 | return "", nil
63 | }
64 |
65 | func UserFromCreateRequest(r *http.Request) (*User, error) {
66 | u := &User{}
67 | return u, json.NewDecoder(r.Body).Decode(u)
68 | }
69 |
70 | func SendUserList(w http.ResponseWriter, users []*User) error { return SendJson(w, users) }
71 | func SendUser(w http.ResponseWriter, user *User) error { return SendJson(w, user) }
72 | func SendUserID(w http.ResponseWriter, id UserID) error { return SendJson(w, id) }
73 |
// SendJson sets the Content-Type header to application/json and writes val to
// the response as JSON. It returns the encoder's error, if any.
func SendJson(w http.ResponseWriter, val interface{}) error {
	enc := json.NewEncoder(w)
	w.Header().Set("Content-Type", "application/json")
	return enc.Encode(val)
}
78 |
--------------------------------------------------------------------------------
/wrap.go:
--------------------------------------------------------------------------------
1 | package sandwich
2 |
3 | import (
4 | "net/http"
5 |
6 | "github.com/augustoroman/sandwich/chain"
7 | )
8 |
// ChainMutation is a special type that allows modifying the chain directly when
// added to a router. This allows advanced usage and should generally not be
// used unless you know what you're doing. In particular, don't add `Arg`s to
// the chain, that will break the router.
//
// When a handler passed to the router implements ChainMutation, its Apply
// method is called with the chain built so far and the returned chain is used
// from then on, instead of the handler being appended as an ordinary step.
type ChainMutation interface {
	// Apply modifies the provided chain and returns the modified chain.
	Apply(c chain.Func) chain.Func
}
17 |
18 | // Wrap provides a mechanism to add two handlers: one that runs during the
19 | // normal course of middleware handling (Before) and one that is defer'd and
20 | // runs after the main set of middleware has executed (After). The defer'd
21 | // handler may accept the `error` type and handle or ignore errors as desired.
22 | //
23 | // This is generally useful for specifying operations that need to run before
24 | // and after subsequent middleware, such as timing, logging, or
25 | // allocation/cleanup operations.
26 | type Wrap struct {
27 | // `Before` is run in the normal course of middleware evaluation. Any returned
28 | // types from this will be available to the defer'd After handler. If Before
29 | // itself returns an error, After will not run.
30 | Before any
31 | // `After` is defer`d and run after the normal course of middleware has
32 | // completed, in reverse order of any registered `defer` handlers. Defer`d
33 | // handlers will always be executed if `Before` was executed, even in the case
34 | // of errors. The `After` handler may accept the `error` type -- that will be
35 | // nil unless a subsequent handler has returned an error.
36 | After any
37 | }
38 |
39 | // Apply modifies the chain to add Before and After.
40 | func (w Wrap) Apply(c chain.Func) chain.Func {
41 | return c.Then(toHandlerFunc(w.Before)).Defer(toHandlerFunc(w.After))
42 | }
43 |
44 | func apply(c chain.Func, handlers ...any) chain.Func {
45 | for _, h := range handlers {
46 | if mod, ok := h.(ChainMutation); ok {
47 | c = mod.Apply(c)
48 | } else {
49 | c = c.Then(toHandlerFunc(h))
50 | }
51 | }
52 | return c
53 | }
54 |
// toHandlerFunc converts an http.Handler into its ServeHTTP method value so
// that it can be used wherever a plain function is expected. Any other value
// is returned unchanged.
func toHandlerFunc(h any) any {
	switch v := h.(type) {
	case http.Handler:
		return v.ServeHTTP
	default:
		return h
	}
}
61 |
--------------------------------------------------------------------------------
/chain/util.go:
--------------------------------------------------------------------------------
1 | package chain
2 |
3 | import (
4 | "fmt"
5 | "reflect"
6 | "runtime"
7 | "sort"
8 | )
9 |
// TODO(aroman) Replace calls with an explicit error type

// panicf builds an error from the format string and arguments (as
// fmt.Errorf does) and panics with that error value.
func panicf(msgfmt string, args ...interface{}) {
	err := fmt.Errorf(msgfmt, args...)
	panic(err)
}
14 |
15 | func valueOfFunction(handler interface{}) (FuncInfo, error) {
16 | if handler == nil {
17 | return FuncInfo{}, fmt.Errorf("should be a function, handler is ")
18 | }
19 | val := reflect.ValueOf(handler)
20 | if !val.IsValid() || val.Kind() != reflect.Func {
21 | return FuncInfo{}, fmt.Errorf("should be a function, handler is %s", val.Type())
22 | }
23 | info := runtime.FuncForPC(val.Pointer())
24 | file, line := info.FileLine(val.Pointer())
25 | return FuncInfo{info.Name(), file, line, val}, nil
26 | }
27 |
28 | func checkCanCall(available map[reflect.Type]bool, fn FuncInfo) error {
29 | fn_typ := fn.Func.Type()
30 | for i := 0; i < fn_typ.NumIn(); i++ {
31 | t := fn_typ.In(i)
32 | if available[t] {
33 | continue
34 | }
35 |
36 | // Un-oh, not available. Let's see what we can do to make a helpful
37 | // error message.
38 | provided := []string{}
39 | candidates := []string{}
40 | for typ := range available {
41 | provided = append(provided, typ.String())
42 | if t.Kind() == reflect.Interface && typ.Implements(t) {
43 | candidates = append(candidates, typ.String())
44 | }
45 | }
46 | sort.Strings(provided)
47 |
48 | suggestion := ""
49 | if len(candidates) == 0 && t.Kind() == reflect.Interface {
50 | suggestion = fmt.Sprintf(" Type %s is an interface, but not "+
51 | "implemented by any of the provided types.", t)
52 | } else if len(candidates) == 1 {
53 | suggestion = fmt.Sprintf(" Type %s is an interface that is "+
54 | "implemented by the provided type %s. Did you mean to use "+
55 | "'.SetAs(val, (*%s)(nil))' instead of '.Set(val)'?",
56 | t, candidates[0], strip("main", t))
57 | } else if len(candidates) > 1 {
58 | suggestion = fmt.Sprintf(" Type %s is an interface that is implemented "+
59 | "by %d provided types: %s. If you meant to use one of those, use "+
60 | "'.SetAs(val, (*someInterface)(nil))' to explicitly assign "+
61 | "to that type.",
62 | t, len(candidates), candidates)
63 | }
64 |
65 | return fmt.Errorf("can't be called: type %s required for %s arg "+
66 | "of %s (%s) has not been provided. Types that have been provided: %s. %s",
67 | t, ordinalize(i+1), fn.Name, fn_typ, provided, suggestion)
68 | }
69 | return nil
70 | }
71 |
--------------------------------------------------------------------------------
/chain/example_test.go:
--------------------------------------------------------------------------------
1 | package chain_test
2 |
3 | import (
4 | "fmt"
5 | "os"
6 | "time"
7 |
8 | "github.com/augustoroman/sandwich/chain"
9 | )
10 |
11 | func ExampleFunc() {
12 | example := chain.Func{}.
13 | // Indicate that the chain will receive a time.Duration as the first
14 | // arg when it's executed.
15 | Arg(time.Duration(0)).
16 | // When the chain is executed, it will first call time.Now which takes no
17 | // arguments but will return a time.Time value that will be available to
18 | // later calls.
19 | Then(time.Now).
20 | // Next, time.Sleep will be invoked, which requires a time.Duration
21 | // parameter. That's available since it's provided as an input to the chain.
22 | Then(time.Sleep).
23 | // Next, time.Since will be invoked, which requires a time.Time value that
24 | // was provided by the earlier time.Now call. It will return a time.Duration
25 | // value that will overwrite the input to the chain.
26 | Then(time.Since).
27 | // Finally, we'll print out the stored time.Duration value.
28 | Then(func(dt time.Duration) {
29 | // Round to the nearest 10ms -- this makes the test not-flaky since the
30 | // sleep duration will not have been exact.
31 | dt = dt.Truncate(10 * time.Millisecond)
32 | fmt.Println("elapsed:", dt)
33 | })
34 |
35 | example.MustRun(time.Duration(30 * time.Millisecond))
36 |
37 | // Print the equivalent code:
38 | fmt.Println("Generated code is:")
39 | example.Code("example", "main", os.Stdout)
40 |
41 | // Output:
42 | // elapsed: 30ms
43 | // Generated code is:
44 | // func example(
45 | // ) func(
46 | // duration time.Duration,
47 | // ) {
48 | // return func(
49 | // duration time.Duration,
50 | // ) {
51 | // var time_Time time.Time
52 | // time_Time = time.Now()
53 | //
54 | // time.Sleep(duration)
55 | //
56 | // duration = time.Since(time_Time)
57 | //
58 | // chain_test.ExampleFunc.func1(duration)
59 | //
60 | // }
61 | // }
62 | }
63 |
64 | func ExampleFunc_file() {
65 | // Chains can be used to do file operations!
66 |
67 | writeToFile := chain.Func{}.
68 | Arg(""). // filename
69 | Arg([]byte(nil)). // data
70 | Then(os.Create).
71 | Then((*os.File).Write).
72 | Then((*os.File).Close)
73 |
74 | // This never fails -- any errors in creating the file or writing to it will
75 | // be handled by the default error handler that logs a message, but the `Run`
76 | // itself doesn't fail unless the args are incorrect.
77 | writeToFile.MustRun("test.txt", []byte("the data"))
78 |
79 | content, err := os.ReadFile("test.txt")
80 | panicOnErr(err)
81 | fmt.Printf("test.txt: %s\n", content)
82 |
83 | panicOnErr(os.Remove("test.txt"))
84 |
85 | // Output:
86 | // test.txt: the data
87 | }
88 |
// panicOnErr panics with err when it is non-nil and does nothing otherwise.
func panicOnErr(err error) {
	if err == nil {
		return
	}
	panic(err)
}
94 |
--------------------------------------------------------------------------------
/chain/codegen.go:
--------------------------------------------------------------------------------
1 | package chain
2 |
3 | import (
4 | "fmt"
5 | "io"
6 | "path/filepath"
7 | "reflect"
8 | "runtime"
9 | "strings"
10 | )
11 |
12 | // Code writes the Go code for the current chain out to w assuming it lives in
13 | // package "pkg" with the specified handler function name.
14 | func (c Func) Code(name, pkg string, w io.Writer) {
15 | vars := &nameMapper{}
16 |
17 | for _, s := range c.steps {
18 | vars.Reserve(s.valTyp.Name())
19 | vars.Reserve(filepath.Base(s.valTyp.PkgPath()))
20 | }
21 |
22 | fmt.Fprintf(w, "func %s(\n", name)
23 | for _, s := range c.steps {
24 | switch s.typ {
25 | case tVALUE:
26 | fmt.Fprintf(w, "\t%s %s,\n", vars.For(s.valTyp), strip(pkg, s.valTyp))
27 | }
28 | }
29 | fmt.Fprintf(w, ") func(\n")
30 | for _, s := range c.steps {
31 | if s.typ == tARG {
32 | fmt.Fprintf(w, "\t%s %s,\n", vars.For(s.valTyp), strip(pkg, s.valTyp))
33 | }
34 | }
35 | fmt.Fprintf(w, ") {\n")
36 |
37 | fmt.Fprintf(w, "\treturn func(\n")
38 | for _, s := range c.steps {
39 | if s.typ == tARG {
40 | fmt.Fprintf(w, "\t\t%s %s,\n", vars.For(s.valTyp), strip(pkg, s.valTyp))
41 | }
42 | }
43 | fmt.Fprintf(w, "\t) {\n")
44 |
45 | errHandler := step{tERROR_HANDLER, reflect.ValueOf(DefaultErrorHandler), nil}
46 | for _, s := range c.steps {
47 | if s.typ == tARG || s.typ == tVALUE {
48 | continue
49 | }
50 |
51 | if s.typ == tERROR_HANDLER {
52 | errHandler = s
53 | continue
54 | }
55 |
56 | for i := 0; i < s.valTyp.NumOut(); i++ {
57 | t := s.valTyp.Out(i)
58 | if !vars.Has(t) {
59 | fmt.Fprintf(w, "\t\tvar %s %s\n", vars.For(t), strip(pkg, t))
60 | }
61 | }
62 |
63 | if s.typ == tPOST_HANDLER {
64 | fmt.Fprintf(w, "\t\tdefer func() {\n\t")
65 | }
66 |
67 | name, inVars, outVars, returnsError := getArgNames(pkg, vars, s.val)
68 |
69 | fmt.Fprintf(w, "\t\t")
70 | if len(outVars) > 0 {
71 | fmt.Fprintf(w, "%s = ", strings.Join(outVars, ", "))
72 | }
73 | fmt.Fprintf(w, "%s(%s)\n", name, strings.Join(inVars, ", "))
74 |
75 | if returnsError {
76 | name, inVars, _, _ := getArgNames(pkg, vars, errHandler.val)
77 | fmt.Fprintf(w, "\t\tif err != nil {\n")
78 | fmt.Fprintf(w, "\t\t\t%s(%s)\n", name, strings.Join(inVars, ", "))
79 | fmt.Fprintf(w, "\t\t\treturn\n")
80 | fmt.Fprintf(w, "\t\t}\n")
81 | }
82 |
83 | if s.typ == tPOST_HANDLER {
84 | fmt.Fprintf(w, "\t\t}()\n")
85 | }
86 | fmt.Fprintf(w, "\n")
87 | }
88 | fmt.Fprintf(w, "\t}\n")
89 | fmt.Fprintf(w, "}\n")
90 | }
91 |
// strip returns t's string form with the "pkg." qualifier removed; see
// stripStr for the exact rules.
func strip(pkg string, t reflect.Type) string {
	return stripStr(pkg, t.String())
}

// stripStr removes a leading "pkg." qualifier from the type name s, looking
// past any leading '*' pointer markers (e.g. "**main.Foo" -> "**Foo" for pkg
// "main"). Names qualified with a different package are returned unchanged.
// An empty string, or one made up solely of '*', is returned as-is.
func stripStr(pkg, s string) string {
	pos := strings.IndexFunc(s, func(r rune) bool { return r != '*' })
	if pos < 0 {
		// No non-'*' rune: nothing to strip, and s[:pos] would panic.
		return s
	}
	return s[:pos] + strings.TrimPrefix(s[pos:], pkg+".")
}
100 |
101 | func getArgNames(pkg string, vars *nameMapper, v reflect.Value) (name string, in, out []string, returnsError bool) {
102 | name = runtime.FuncForPC(v.Pointer()).Name()
103 | name = filepath.Base(name)
104 | name = strings.TrimPrefix(name, pkg+".")
105 |
106 | if pos := strings.Index(name, ".(*"); pos > 0 {
107 | pkgName := name[:pos+1]
108 | pkgName = strings.TrimPrefix(pkgName, pkg+".")
109 | name = "(*" + pkgName + name[pos+3:]
110 | }
111 |
112 | t := v.Type()
113 | out = make([]string, t.NumOut())
114 | for i := 0; i < t.NumOut(); i++ {
115 | out[i] = vars.For(t.Out(i))
116 | if t.Out(i) == errorType {
117 | returnsError = true
118 | }
119 | }
120 | in = make([]string, t.NumIn())
121 | for i := 0; i < t.NumIn(); i++ {
122 | in[i] = vars.For(t.In(i))
123 | }
124 | return name, in, out, returnsError
125 | }
126 |
--------------------------------------------------------------------------------
/errors.go:
--------------------------------------------------------------------------------
1 | package sandwich
2 |
import (
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
)
8 |
// Error is an error implementation that provides the ability to specify three
// things to the sandwich error handler:
//   - The HTTP status code that should be used in the response.
//   - The client-facing message that should be sent. Typically this is a
//     sanitized error message, such as "Internal Server Error".
//   - Internal debugging detail including a log message and the underlying
//     error that should be included in the server logs.
//
// Note that Cause may be nil.
//
// The sandwich standard Error handlers (HandleError and HandleErrorJson) will
// respect these Errors and respond with the appropriate status code and client
// message. Additionally, the sandwich standard log handling will log LogMsg.
type Error struct {
	Code      int
	ClientMsg string
	LogMsg    string
	Cause     error
}

// Error formats the error as "(<Code>) <LogMsg>", falling back to ClientMsg
// when LogMsg is empty, followed by ": <Cause>" when Cause is non-nil.
func (e Error) Error() string {
	msg := fmt.Sprintf("(%d) %s", e.Code, e.LogMsg)
	if e.LogMsg == "" {
		msg += e.ClientMsg
	}
	if e.Cause != nil {
		msg += ": " + e.Cause.Error()
	}
	return msg
}

// Unwrap returns Cause (possibly nil) so that errors.Is and errors.As can
// inspect the underlying error.
func (e Error) Unwrap() error {
	return e.Cause
}
39 |
40 | // LogIfMsg will set the Error field on the LogEntry if the Error's LogMsg
41 | // field has something.
42 | func (e Error) LogIfMsg(l *LogEntry) {
43 | if e.LogMsg != "" {
44 | l.Error = e
45 | }
46 | }
47 |
// Done is a sentinel error value that can be used to interrupt the middleware
// chain without triggering the default error handling. HandleError will not
// attempt to write any status code or client message, nor will it add the error
// to the log.
//
// It carries a non-empty message so that it is recognizable if it ever shows
// up in logs or in a custom error handler.
var Done = errors.New("<done>")
53 |
54 | // ToError will convert a generic non-nil error to an explicit sandwich.Error
55 | // type. If err is already a sandwich.Error, it will be returned. Otherwise, a
56 | // generic 500 Error (internal server error) will be initialized and returned.
57 | // Note that if err is nil, it will still return a generic 500 Error.
58 | func ToError(err error) Error {
59 | var e Error
60 | if errors.As(err, &e) {
61 | if e.Code == 0 {
62 | e.Code = 500
63 | }
64 | if e.ClientMsg == "" {
65 | e.ClientMsg = http.StatusText(e.Code)
66 | }
67 | return e
68 | }
69 | return Error{
70 | Code: http.StatusInternalServerError,
71 | LogMsg: "Failure",
72 | Cause: err,
73 | ClientMsg: http.StatusText(http.StatusInternalServerError),
74 | }
75 | }
76 |
77 | // HandleError is the default error handler included in sandwich.TheUsual.
78 | // If the error is a sandwich.Error, it responds with the specified status code
79 | // and client message. Otherwise, it responds with a 500. In both cases, the
80 | // underlying error is added to the request log.
81 | //
82 | // If the error is sandwich.Done, HandleError does nothing.
83 | func HandleError(w http.ResponseWriter, r *http.Request, l *LogEntry, err error) {
84 | if err == Done {
85 | return
86 | }
87 | e := ToError(err)
88 | e.LogIfMsg(l)
89 | http.Error(w, e.ClientMsg, e.Code)
90 | }
91 |
92 | // HandleErrorJson is identical to HandleError except that it responds to the
93 | // client as JSON instead of plain text. Again, detailed error info is added
94 | // to the request log.
95 | //
96 | // If the error is sandwich.Done, HandleErrorJson does nothing.
97 | func HandleErrorJson(w http.ResponseWriter, r *http.Request, l *LogEntry, err error) {
98 | if err == Done {
99 | return
100 | }
101 | e := ToError(err)
102 | e.LogIfMsg(l)
103 | w.Header().Set("Content-Type", "application/json")
104 | w.WriteHeader(e.Code)
105 | fmt.Fprintf(w, `{"error":%q}`, e.ClientMsg)
106 | }
107 |
--------------------------------------------------------------------------------
/logger.go:
--------------------------------------------------------------------------------
1 | package sandwich
2 |
3 | import (
4 | "fmt"
5 | "io"
6 | "net/http"
7 | "os"
8 | "sort"
9 | "strings"
10 | "time"
11 | )
12 |
13 | // Injected for testing
14 | var time_Now = time.Now
15 | var os_Stderr io.Writer = os.Stderr
16 |
// LogEntry is the information tracked on a per-request basis for the sandwich
// Logger. All fields other than Note are automatically filled in. The Note
// field is a generic key-value string map for adding additional per-request
// metadata to the logs. You can take *sandwich.LogEntry to your functions to
// add fields to Note.
//
// For example:
//
//	func MyAuthCheck(r *http.Request, e *sandwich.LogEntry) (User, error) {
//	    user, err := decodeAuthCookie(r)
//	    if user != nil {
//	        e.Note["user"] = user.Id() // indicate which user is auth'd
//	    }
//	    return user, err
//	}
type LogEntry struct {
	// RemoteIp is the client address as determined by remoteIp.
	RemoteIp string
	// Start is the time at which NewLogEntry created the entry.
	Start time.Time
	// Request is the request being logged.
	Request *http.Request
	// StatusCode and ResponseSize are copied from the ResponseWriter by Commit.
	StatusCode   int
	ResponseSize int
	// Elapsed is the time between Start and the call to Commit.
	Elapsed time.Duration
	// Error is the error to log, if any; Error.LogIfMsg sets it.
	Error error
	// Note holds user-supplied key/value annotations. NewLogEntry initializes
	// it to an empty, non-nil map.
	Note map[string]string
	// set to true to suppress logging this request
	Quiet bool
}
44 |
45 | // NoLog is a middleware function that suppresses log output for this request.
46 | // For example:
47 | //
48 | // // suppress logging of the favicon request to reduce log spam.
49 | // router.Get("/favicon.ico", sandwich.NoLog, staticHandler)
50 | //
51 | // This depends on WriteLog respecting the Quiet flag, which the default
52 | // implementation does.
53 | func NoLog(e *LogEntry) { e.Quiet = true }
54 |
55 | // LogRequests is a middleware wrap that creates a log entry during middleware
56 | // processing and then commits the log entry after the middleware has executed.
57 | var LogRequests = Wrap{NewLogEntry, (*LogEntry).Commit}
58 |
59 | // NewLogEntry creates a *LogEntry and initializes it with basic request
60 | // information.
61 | func NewLogEntry(r *http.Request) *LogEntry {
62 | return &LogEntry{
63 | RemoteIp: remoteIp(r),
64 | Start: time_Now(),
65 | Request: r,
66 | Note: map[string]string{},
67 | }
68 | }
69 |
70 | // Commit fills in the remaining *LogEntry fields and writes the entry out.
71 | func (entry *LogEntry) Commit(w *ResponseWriter) {
72 | entry.Elapsed = time_Now().Sub(entry.Start)
73 | entry.ResponseSize = w.Size
74 | entry.StatusCode = w.Code
75 | WriteLog(*entry)
76 | }
77 |
// Some nice escape codes: ANSI terminal sequences used by logColors to
// colorize log lines.
const (
	_GREEN  = "\033[32m" // normal requests
	_YELLOW = "\033[33m" // slow requests
	_RESET  = "\033[0m"  // restores default attributes at the end of a line
	_RED    = "\033[91m" // high-intensity red, for errors
)
85 |
86 | // WriteLog is called to actually write a LogEntry out to the log. By default,
87 | // it writes to stderr and colors normal requests green, slow requests yellow,
88 | // and errors red. You can replace the function to adjust the formatting or use
89 | // whatever logging library you like.
90 | var WriteLog = func(e LogEntry) {
91 | if e.Quiet {
92 | return
93 | }
94 | col, reset := logColors(e)
95 | fmt.Fprintf(os_Stderr, "%s%s %s \"%s %s\" (%d %dB %s) %s%s\n",
96 | col,
97 | e.Start.Format(time.RFC3339), e.RemoteIp,
98 | e.Request.Method, e.Request.RequestURI,
99 | e.StatusCode, e.ResponseSize, e.Elapsed,
100 | e.NotesAndError(),
101 | reset)
102 | }
103 |
104 | // NotesAndError formats the Note values and error (if any) for logging.
105 | func (l LogEntry) NotesAndError() string {
106 | pairs := make([]string, len(l.Note))
107 | for k, v := range l.Note {
108 | pairs = append(pairs, fmt.Sprintf("%s=%q", k, v))
109 | }
110 | sort.Strings(pairs)
111 | msg := strings.Join(pairs, " ")
112 | if l.Error != nil {
113 | msg += "\n ERROR: " + l.Error.Error()
114 | }
115 | return msg
116 | }
117 |
118 | func logColors(e LogEntry) (start, reset string) {
119 | col, reset := _GREEN, _RESET
120 | if e.Elapsed > 30*time.Millisecond {
121 | col = _YELLOW
122 | }
123 | if e.StatusCode >= 400 || e.Error != nil {
124 | col, reset = _RED, _RESET // high-intensity red + reset
125 | }
126 | return col, reset
127 | }
128 |
// remoteIp extracts the remote IP from the request: the X-Real-IP header if
// set, else the X-Forwarded-For header if set, else r.RemoteAddr. Header
// values are returned verbatim. Adapted from code in Martini:
//
//	https://github.com/go-martini/martini/blob/1d33529c15f19/logger.go#L14..L20
func remoteIp(r *http.Request) string {
	for _, header := range []string{"X-Real-IP", "X-Forwarded-For"} {
		if addr := r.Header.Get(header); addr != "" {
			return addr
		}
	}
	return r.RemoteAddr
}
141 |
--------------------------------------------------------------------------------
/benchmark_simple_handler_test.go:
--------------------------------------------------------------------------------
1 | package sandwich
2 |
3 | import (
4 | "encoding/json"
5 | "fmt"
6 | "net/http"
7 | "net/http/httptest"
8 | "testing"
9 | )
10 |
// userInfo is a small fixed record that the sendjson handler JSON-encodes for
// benchmarking.
var userInfo = struct {
	Id        uint64 `json:"id"`
	Name      string `json:"name"`
	Email     string `json:"email"`
	AvatarUrl string `json:"avatar_url"`
}{
	Id:        12345467,
	Name:      "John Doe",
	Email:     "john.doe@example.com",
	AvatarUrl: "https://www.example.com/users/john.doe/image",
}
20 |
// write204 responds with 204 No Content and no body.
func write204(w http.ResponseWriter, r *http.Request) {
	w.WriteHeader(http.StatusNoContent)
}
// hello writes a short fixed greeting; any write error is ignored.
func hello(w http.ResponseWriter, r *http.Request) {
	greeting := []byte("Hello there!")
	_, _ = w.Write(greeting)
}
23 | func sendjson(w http.ResponseWriter, r *http.Request) { _ = json.NewEncoder(w).Encode(userInfo) }
24 |
25 | func addTestRoutes(mux Router) {
26 | mux.Get("/204", write204)
27 | mux.Get("/hello", hello)
28 | mux.Get("/jsonuser", sendjson)
29 |
30 | mux.Get("/long/1/2/3/4/5/6/7/8/9/xyz/204", write204)
31 | mux.Get("/long/1/2/3/4/5/6/7/8/9/xyz/hello", hello)
32 | mux.Get("/long/1/2/3/4/5/6/7/8/9/xyz/jsonuser", sendjson)
33 |
34 | mux.Get("/1param/:var/204", write204)
35 | mux.Get("/1param/:var/hello", hello)
36 | mux.Get("/1param/:var/jsonuser", sendjson)
37 |
38 | mux.Get("/manyparams/:var/:x/:y/:z/:a/:b/:c/204", write204)
39 | mux.Get("/manyparams/:var/:x/:y/:z/:a/:b/:c/hello", hello)
40 | mux.Get("/manyparams/:var/:x/:y/:z/:a/:b/:c/jsonuser", sendjson)
41 |
42 | mux.Get("/greedy/:var*/204", write204)
43 | mux.Get("/greedy/:var*/hello", hello)
44 | mux.Get("/greedy/:var*/jsonuser", sendjson)
45 | }
46 |
47 | var usualRouter = func() Router {
48 | mux := TheUsual()
49 | mux.Use(NoLog)
50 | addTestRoutes(mux)
51 | return mux
52 | }()
53 | var bareRouter = func() Router {
54 | mux := BuildYourOwn()
55 | addTestRoutes(mux)
56 | return mux
57 | }()
58 |
59 | func bench(N int, route string, mux Router) {
60 | req := httptest.NewRequest("GET", route, nil)
61 | for i := 0; i < N; i++ {
62 | w := httptest.NewRecorder()
63 | mux.ServeHTTP(w, req)
64 | }
65 | }
66 |
// S returns its arguments as a string slice; shorthand for a []string literal.
func S(s ...string) []string {
	return s
}
68 |
69 | func runBenches(b *testing.B, mux Router, routes, endpoints []string) {
70 | for _, ep := range endpoints {
71 | for _, route := range routes {
72 | path := route + "/" + ep
73 | b.Run(ep+"::"+path, func(b *testing.B) {
74 | bench(b.N, path, mux)
75 | })
76 | }
77 | }
78 | }
79 |
80 | // Just to shorten the following benchmark functions:
81 | type Handler = http.HandlerFunc
82 |
83 | func BenchmarkUsual(b *testing.B) {
84 | runBenches(b, usualRouter,
85 | S("", "/long/1/2/3/4/5/6/7/8/9/xyz", "/1param/foo", "/manyparams/foo/x/y/z/a/b/c", "/greedy/short", "/greedy/x/y/z/a/b/c"),
86 | S("204", "hello", "jsonuser"),
87 | )
88 | }
89 | func BenchmarkBare(b *testing.B) {
90 | runBenches(b, bareRouter,
91 | S("", "/long/1/2/3/4/5/6/7/8/9/xyz", "/1param/foo", "/manyparams/foo/x/y/z/a/b/c", "/greedy/short", "/greedy/x/y/z/a/b/c"),
92 | S("204", "hello", "jsonuser"),
93 | )
94 | }
95 |
96 | func BenchmarkCalls(b *testing.B) {
97 | for i := 1; i < 20; i += 2 {
98 | b.Run(fmt.Sprintf("%02d", i), func(b *testing.B) {
99 | var calls []any
100 | for j := 0; j < i; j++ {
101 | calls = append(calls, hello)
102 | }
103 | mux := BuildYourOwn()
104 | mux.Get("/", calls...)
105 | b.ResetTimer()
106 | bench(b.N, "/", mux)
107 | })
108 | }
109 | }
110 |
111 | // func Benchmark_Usual_Short_Write204(b *testing.B) { bench(b.N, "/204", usualRouter) }
112 | // func Benchmark_Usual_Short_Hello(b *testing.B) { bench(b.N, "/hello", usualRouter) }
113 | // func Benchmark_Usual_Short_SnedJson(b *testing.B) { bench(b.N, "/204", usualRouter) }
114 |
115 | // // func Benchmark_Hello_RawHTTP(b *testing.B) { bench(b.N, Handler(hello)) }
116 | // // func Benchmark_Hello_Dynamic_Bare(b *testing.B) { bench(b.N, makeHello_Bare()) }
117 | // // func Benchmark_Hello_Dynamic_TheUsual(b *testing.B) { bench(b.N, makeHello_TheUsual()) }
118 |
119 | // // func Benchmark_Write204_RawHTTP(b *testing.B) { bench(b.N, Handler(write204)) }
120 |
121 | // // func Benchmark_Write204_Dynamic_Bare(b *testing.B) { bench(b.N, makeWrite204_Bare()) }
122 | // // func Benchmark_Write204_Dynamic_TheUsual(b *testing.B) { bench(b.N, makeWrite204_TheUsual()) }
123 |
124 | // // func Benchmark_SendJson_RawHTTP(b *testing.B) { bench(b.N, Handler(sendjson)) }
125 | // // func Benchmark_SendJson_Dynamic_Bare(b *testing.B) { bench(b.N, makeSendJson_Bare()) }
126 | // // func Benchmark_SendJson_Dynamic_TheUsual(b *testing.B) { bench(b.N, makeSendJson_TheUsual()) }
127 |
--------------------------------------------------------------------------------
/logger_test.go:
--------------------------------------------------------------------------------
1 | package sandwich
2 |
3 | import (
4 | "bytes"
5 | "errors"
6 | "net/http"
7 | "net/http/httptest"
8 | "os"
9 | "strings"
10 | "testing"
11 | "time"
12 |
13 | "github.com/augustoroman/sandwich/chain"
14 | )
15 |
// fakeClock is a manually driven clock for tests.
type fakeClock struct {
	now     time.Time     // the time the next call to Now will report
	advance time.Duration // how far each call to Now moves the clock forward
}

// Now reports the current fake time and then moves the clock forward by
// f.advance.
func (f *fakeClock) Now() time.Time {
	current := f.now
	f.Sleep(f.advance)
	return current
}

// Sleep moves the clock forward by dt without blocking.
func (f *fakeClock) Sleep(dt time.Duration) {
	f.now = f.now.Add(dt)
}
30 |
31 | func validateLogMessage(t *testing.T, logs, expectedColor, expectedMsg string) {
32 | logs = strings.TrimSpace(logs)
33 |
34 | if !strings.HasPrefix(logs, expectedColor) {
35 | t.Errorf("Expected color prefix of %q: %q", expectedColor, logs)
36 | } else {
37 | logs = strings.TrimPrefix(logs, expectedColor)
38 | }
39 | if !strings.HasSuffix(logs, _RESET) {
40 | t.Errorf("Expected reset suffix: %q", logs)
41 | } else {
42 | logs = strings.TrimSuffix(logs, _RESET)
43 | }
44 | logs = strings.TrimSpace(logs)
45 | expectedMsg = strings.TrimSpace(expectedMsg)
46 | if logs != expectedMsg {
47 | t.Errorf("Wrong log message:\nExp: %q\nGot: %q", expectedMsg, logs)
48 | }
49 | }
50 |
51 | func TestLogger(t *testing.T) {
52 | // Restore the world from insanity when we're done:
53 | orig := WriteLog
54 | defer func() { time_Now = time.Now; os_Stderr = os.Stderr; WriteLog = orig }()
55 |
56 | // Setup our fake world.
57 | var logBuf bytes.Buffer
58 | os_Stderr = &logBuf
59 | clk := &fakeClock{time.Date(2001, 2, 3, 4, 5, 6, 7, time.UTC), 13 * time.Millisecond}
60 | time_Now = clk.Now
61 |
62 | // Useful handlers:
63 | sendMsg := func(w http.ResponseWriter) { _, _ = w.Write([]byte("Hi there")) }
64 | slowSendMsg := func(w http.ResponseWriter) { clk.Sleep(100 * time.Millisecond); sendMsg(w) }
65 | fail := func() error { return errors.New("It went horribly wrong") }
66 | slowFail := func() error { clk.Sleep(time.Second); return fail() }
67 | panics := func(w http.ResponseWriter) { sendMsg(w); panic("oops") }
68 | addsNote := func(w http.ResponseWriter, e *LogEntry) { e.Note["a"] = "x"; e.Note["b"] = "y"; sendMsg(w) }
69 |
70 | var resp *httptest.ResponseRecorder
71 | var req *http.Request
72 |
73 | mux := TheUsual()
74 |
75 | // Test a normal response:
76 | logBuf.Reset()
77 | resp = httptest.NewRecorder()
78 | req, _ = http.NewRequest("GET", "/", nil)
79 | req.RequestURI = req.URL.String()
80 | req.Header.Add("X-Real-IP", "123.456.789.0")
81 | mux.Get("/", addsNote)
82 | mux.ServeHTTP(resp, req)
83 | validateLogMessage(t, logBuf.String(), _GREEN,
84 | `2001-02-03T04:05:06Z 123.456.789.0 "GET /" (200 8B 13ms) a="x" b="y"`)
85 |
86 | // Test a slow response:
87 | logBuf.Reset()
88 | resp = httptest.NewRecorder()
89 | req, _ = http.NewRequest("POST", "/slow", nil)
90 | req.RequestURI = req.URL.String()
91 | req.Header.Add("X-Forwarded-For", "")
92 | mux.Post("/slow", slowSendMsg)
93 | mux.ServeHTTP(resp, req)
94 | validateLogMessage(t, logBuf.String(), _YELLOW,
95 | `2001-02-03T04:05:06Z "POST /slow" (200 8B 113ms)`)
96 |
97 | // Test a failed response:
98 | logBuf.Reset()
99 | resp = httptest.NewRecorder()
100 | req, _ = http.NewRequest("BOO!", "/fail", nil)
101 | req.RequestURI = req.URL.String()
102 | req.RemoteAddr = "[::1]:56596"
103 | mux.On("BOO!", "/fail", fail)
104 | mux.ServeHTTP(resp, req)
105 | validateLogMessage(t, logBuf.String(), _RED,
106 | `2001-02-03T04:05:06Z [::1]:56596 "BOO! /fail" (500 22B 13ms) `+"\n"+
107 | ` ERROR: (500) Failure: It went horribly wrong`)
108 |
109 | // Test a slow failed response (should still be red):
110 | logBuf.Reset()
111 | resp = httptest.NewRecorder()
112 | req, _ = http.NewRequest("PUT", "/slowfail", nil)
113 | req.RequestURI = req.URL.String()
114 | req.RemoteAddr = "[::1]:56596"
115 | req.Header.Add("X-Forwarded-For", "")
116 | req.Header.Add("X-Real-IP", "123.456.789.0") // takes precedence
117 | mux.Put("/slowfail", slowFail)
118 | mux.ServeHTTP(resp, req)
119 | validateLogMessage(t, logBuf.String(), _RED,
120 | `2001-02-03T04:05:06Z 123.456.789.0 "PUT /slowfail" (500 22B 1.013s) `+"\n"+
121 | ` ERROR: (500) Failure: It went horribly wrong`)
122 |
123 | // Test a suppressed log.
124 | logBuf.Reset()
125 | resp = httptest.NewRecorder()
126 | req, _ = http.NewRequest("GET", "/nolog", nil)
127 | mux.Get("/nolog", NoLog, addsNote)
128 | mux.ServeHTTP(resp, req)
129 | if logBuf.String() != "" {
130 | t.Errorf("Expected no log output, but got [%s]", logBuf.String())
131 | }
132 |
133 | // Test that a panic should be recorded.
134 | var log LogEntry
135 | WriteLog = func(e LogEntry) { log = e }
136 | resp = httptest.NewRecorder()
137 | req, _ = http.NewRequest("PUT", "/panic", nil)
138 | req.RequestURI = req.URL.String()
139 | req.RemoteAddr = ""
140 | mux.Put("/panic", panics)
141 | mux.ServeHTTP(resp, req)
142 |
143 | if err, ok := ToError(log.Error).Cause.(chain.PanicError); !ok {
144 | t.Errorf("log error should be a panic, but is: %#v", log.Error)
145 | } else if msg := err.Error(); !strings.Contains(msg, `Panic executing middleware`) {
146 | t.Errorf("Bad err message: %s", err)
147 | } else if !strings.Contains(msg, `oops`) {
148 | t.Errorf("Bad err message: %s", err)
149 | }
150 |
151 | if resp.Body.String() != "Hi thereInternal Server Error\n" {
152 | t.Errorf("Incorrect client response: %q", resp.Body.String())
153 | }
154 | }
155 |
--------------------------------------------------------------------------------
/examples/1-simple/main.go:
--------------------------------------------------------------------------------
1 | // 1-simple is a demo webserver for the sandwich middleware package
2 | // demonstrating basic usage.
3 | //
4 | // This example demonstrates most of the basic features of sandwich, including:
5 | //
6 | // * Providing user types to the middleware chain
7 | // * Adding middleware handlers to the stack.
8 | // * Writing handlers that provide request-scoped values.
9 | // * Writing handlers using injected values.
10 | // * Using the default sandwich logging system.
11 | // * Using the default sandwich error system.
12 | package main
13 |
14 | import (
15 | "fmt"
16 | "log"
17 | "net/http"
18 | "time"
19 |
20 | "github.com/augustoroman/sandwich"
21 | )
22 |
23 | // Interface for abstracting out the user database.
24 | type UserDb interface {
25 | Lookup(id string) (User, error)
26 | }
27 | type User struct{ Id, Name, Email string }
28 |
29 | func main() {
30 | // To reduce log spam, we'll just put this here, not using any framework.
31 | http.Handle("/favicon.ico", http.NotFoundHandler())
32 |
33 | // Setup connections to the databases.
34 | udb := userDb{{"bob", "Bob", "bob@example.com"}, {"alice", "Alice", "alice@example.com"}}
35 |
36 | // Create a typical sandwich middleware with logging and error-handling.
37 | mux := sandwich.TheUsual()
38 | // Inject config and user database; now available to all handlers.
39 | mux.SetAs(udb, (*UserDb)(nil))
40 | // In this example, we'll always check to see if the user is logged in.
41 | // If so, we'll add the user ID to the log entries.
42 | mux.Use(ParseUserIfLoggedIn)
43 |
44 | // If the user is logged in, they'll get a personalized landing page.
45 | // Otherwise, they'll get a generic landing page.
46 | mux.Get("/", ShowLandingPage)
47 | mux.Post("/login", Login)
48 |
49 | // Some pages are only allowed if the user is logged in.
50 | mux.Get("/user/profile", FailIfNotAuthenticated, ShowUserProfile)
51 | // If you have multiple pages that require authentication, you could do:
52 | // authed := mw.Then(FailIfNotAuthenticated)
53 | // http.Handle("/user/profile", authed.Then(ShowUserProfile))
54 | // http.Handle("/user/...", authed.Then(...))
55 | // http.Handle("/user/...", authed.Then(...))
56 |
57 | log.Println("Serving on http://localhost:8080/")
58 | if err := http.ListenAndServe(":8080", mux); err != nil {
59 | log.Fatal("Can't start webserver:", err)
60 | }
61 | }
62 |
63 | // The actual user DB implementation.
64 | type userDb []User
65 |
66 | func (udb userDb) Lookup(id string) (User, error) {
67 | for _, u := range udb {
68 | if id == u.Id {
69 | return u, nil
70 | }
71 | }
72 | return User{}, fmt.Errorf("no such user %q", id)
73 | }
74 |
75 | func ShowLandingPage(w http.ResponseWriter, u *User) {
76 | fmt.Fprintln(w, "")
77 | if u == nil {
78 | fmt.Fprint(w, "Hello unknown person!")
79 | fmt.Fprintf(w, " [profile will fail and log an error]")
80 | } else {
81 | fmt.Fprintf(w, "Welcome back, %s!", u.Name)
82 | fmt.Fprintf(w, " [profile]")
83 | }
84 | fmt.Fprintln(w, `
85 | Login
86 |
90 |
91 | Try logging in with:
92 | - "alice" will authenticate to Alice
93 |
- "bob" will authenticate to Bob but panic during request handling
94 |
- any other string for a non-authenticated user.
95 |
96 | `)
97 | }
98 |
99 | func ShowUserProfile(w http.ResponseWriter, u User) {
100 | fmt.Fprintln(w, "")
101 | fmt.Fprintf(w, "Id: %s
Name: %s
Email:%s", u.Id, u.Name, u.Email)
102 |
103 | // Show an example of user-code panicking in a handler.
104 | if u.Id == "bob" {
105 | panic("oops")
106 | }
107 | }
108 |
// Login looks up the user named by the "id" form value in udb.
//
// If the lookup fails it logs the unknown id, sets an empty "auth" cookie that
// expires immediately, writes a response body and returns. Otherwise it
// records the user id under the "userId" key of the log entry's Note map, sets
// an HttpOnly "auth" cookie holding the (unencrypted) user id that is valid
// for one hour, and writes a response body.
//
// NOTE(review): the two raw-string response bodies appear to be redirect
// markup that was stripped from this copy of the file (see the "Redirect to"
// comments) -- confirm against the real source.
109 | func Login(w http.ResponseWriter, r *http.Request, udb UserDb, e *sandwich.LogEntry) {
110 | u, err := udb.Lookup(r.FormValue("id"))
111 | if err != nil {
112 | log.Printf("No such user id: %q", r.FormValue("id"))
113 | http.SetCookie(w, &http.Cookie{Name: "auth", Value: "", Expires: time.Now()})
114 | // Redirect to /
115 | fmt.Fprintf(w, `
116 | `)
117 | return
118 | }
119 |
120 | e.Note["userId"] = u.Id
121 | http.SetCookie(w, &http.Cookie{
122 | Name: "auth",
123 | Value: u.Id, // Encrypt cookie here, maybe include the whole user struct.
124 | Expires: time.Now().Add(time.Hour),
125 | MaxAge: int(time.Hour / time.Second),
126 | HttpOnly: true,
127 | })
128 | // Redirect to /user/profile
129 | fmt.Fprintf(w, `
130 | `)
131 | }
132 |
133 | func FailIfNotAuthenticated(u *User) (User, error) {
134 | if u == nil {
135 | return User{}, sandwich.Error{
136 | Code: http.StatusUnauthorized,
137 | ClientMsg: "Not logged in",
138 | LogMsg: "Unauthorized access attempt",
139 | }
140 | }
141 | return *u, nil
142 | }
143 |
// getAndParseCookie returns the user id stored in the request's "auth" cookie.
// If the cookie cannot be read, it returns an empty string and the error
// reported by r.Cookie.
func getAndParseCookie(r *http.Request) (string, error) {
	cookie, err := r.Cookie("auth")
	if err == nil {
		// Decrypt cookie here, maybe getting a whole user struct.
		return cookie.Value, nil
	}
	return "", err
}
152 |
153 | func ParseUserIfLoggedIn(r *http.Request, udb UserDb, e *sandwich.LogEntry) (*User, error) {
154 | if user_id, err := getAndParseCookie(r); err != nil {
155 | return nil, nil // not logged in or expired or corrupt. Ignore cookie.
156 | } else if user, err := udb.Lookup(user_id); err != nil {
157 | log.Printf("No such user: %q", user_id)
158 | return nil, nil // no such user
159 | } else {
160 | e.Note["userId"] = user.Id
161 | return &user, nil // Hello logged-in user!
162 | }
163 | }
164 |
--------------------------------------------------------------------------------
/chain/naming.go:
--------------------------------------------------------------------------------
1 | package chain
2 |
3 | import (
4 | "fmt"
5 | "net/http"
6 | "path/filepath"
7 | "reflect"
8 | "regexp"
9 | "strings"
10 | "unicode"
11 | )
12 |
// nameMapper assigns variable names to types and remembers which names have
// already been handed out.
type nameMapper struct {
	// typToName holds the name chosen for each type seen so far.
	typToName map[reflect.Type]string
	// used is the set of names that are already taken, either assigned to a
	// type or reserved explicitly.
	used map[string]bool
}

// Has reports whether a name has already been assigned to t.
func (n *nameMapper) Has(t reflect.Type) bool {
	_, assigned := n.typToName[t]
	return assigned
}
22 |
// disallowed is the set of names that must never be used as variable names:
// Go keywords plus the predeclared identifiers (types, constants and builtin
// functions), including the ones added in newer Go releases (any, comparable,
// clear, max, min).
var disallowed = map[string]bool{
	// keywords
	"break": true, "default": true, "func": true, "interface": true,
	"select": true, "case": true, "defer": true, "go": true, "map": true,
	"struct": true, "chan": true, "else": true, "goto": true, "package": true,
	"switch": true, "const": true, "fallthrough": true, "if": true,
	"range": true, "type": true, "continue": true, "for": true, "import": true,
	"return": true, "var": true,
	// pre-declared identifiers:
	"any": true, "bool": true, "byte": true, "comparable": true,
	"complex64": true, "complex128": true,
	"error": true, "float32": true, "float64": true, "int": true, "int8": true,
	"int16": true, "int32": true, "int64": true, "rune": true, "string": true,
	"uint": true, "uint8": true, "uint16": true, "uint32": true, "uint64": true,
	"uintptr": true, "true": true, "false": true, "iota": true, "nil": true,
	"append": true, "cap": true, "clear": true, "close": true, "complex": true,
	"copy": true, "delete": true, "imag": true, "len": true, "make": true,
	"max": true, "min": true, "new": true,
	"panic": true, "print": true, "println": true, "real": true,
	"recover": true,
}
43 |
44 | func (n *nameMapper) Reserve(names ...string) {
45 | if n.typToName == nil {
46 | n.typToName = map[reflect.Type]string{}
47 | n.used = map[string]bool{}
48 | }
49 | for _, name := range names {
50 | n.used[name] = true
51 | }
52 | }
53 |
54 | func (n *nameMapper) For(t reflect.Type) string {
55 | if name, exists := n.typToName[t]; exists {
56 | return name
57 | }
58 | if n.typToName == nil {
59 | n.Reserve()
60 | }
61 | for _, name := range n.options(t) {
62 | if !disallowed[name] && !n.used[name] {
63 | n.used[name] = true
64 | n.typToName[t] = name
65 | return name
66 | }
67 | }
68 | // This should never happen: The final option should be a completely unique
69 | // name using the full package name and type name.
70 | panic(fmt.Errorf("Could not come up with a unique variable name for %s. "+
71 | "Used names are %v.\nTyp2Name: %q", t, n.used, n.typToName))
72 | }
73 |
// extractCaps builds a lower-cased abbreviation of s from its first character
// plus every upper-case letter and every number in it; characters that are
// neither letters nor numbers are skipped.
//
// It returns "" when s is empty or contains an underscore. If no character
// qualifies (for example when s starts with a symbol and has no capitals),
// the first rune of s is used as the abbreviation.
func extractCaps(s string) string {
	var caps strings.Builder
	for i, r := range s {
		if r == '_' {
			return "" // If there are any underscores in the name, give up on extracting caps.
		}
		if !unicode.IsLetter(r) && !unicode.IsNumber(r) {
			continue
		}
		if i == 0 || unicode.IsUpper(r) || unicode.IsNumber(r) {
			caps.WriteRune(r)
		}
	}
	if caps.Len() == 0 {
		// Fall back to the first rune (not the first byte, which would mangle
		// multi-byte characters). For empty input the loop body never runs,
		// so the result stays empty instead of panicking.
		for _, r := range s {
			caps.WriteRune(r)
			break
		}
	}
	return strings.ToLower(caps.String())
}
92 |
// upperFirstLetter returns s with its first rune converted to upper case.
// An empty string is returned unchanged.
func upperFirstLetter(s string) string {
	if s == "" {
		return ""
	}
	// Locate the byte offset at which the second rune starts. The first rune
	// may be several bytes wide, so slicing at a fixed offset of 1 would cut
	// it in half.
	var first rune
	rest := len(s)
	for i, r := range s {
		if i == 0 {
			first = r
			continue
		}
		rest = i
		break
	}
	return string(unicode.ToUpper(first)) + s[rest:]
}
99 |
// lowerFirstLetter returns s with its first rune converted to lower case.
// An empty string is returned unchanged.
func lowerFirstLetter(s string) string {
	if s == "" {
		return ""
	}
	// Locate the byte offset at which the second rune starts. The first rune
	// may be several bytes wide, so slicing at a fixed offset of 1 would cut
	// it in half.
	var first rune
	rest := len(s)
	for i, r := range s {
		if i == 0 {
			first = r
			continue
		}
		rest = i
		break
	}
	return string(unicode.ToLower(first)) + s[rest:]
}
106 |
// ptrPrefix strips every leading pointer and slice layer from t. It returns
// the innermost element type together with a prefix describing the removed
// layers: "p" for each pointer and "sliceOf" for each slice. The marker for
// each layer is prepended, so the outermost layer ends up last in the prefix.
func ptrPrefix(t reflect.Type) (string, reflect.Type) {
	prefix := ""
	for {
		switch t.Kind() {
		case reflect.Ptr:
			prefix = "p" + prefix
		case reflect.Slice:
			prefix = "sliceOf" + prefix
		default:
			return prefix, t
		}
		t = t.Elem()
	}
}
119 | func assemble(prefix, name string) string {
120 | if prefix == "" {
121 | return name
122 | }
123 | return prefix + upperFirstLetter(name)
124 | }
125 |
// pkgNamePrefix converts a package name or import path into a lower-case
// prefix for a variable name, terminated by an underscore. It returns "" for
// an empty path and for ".".
//
// Every character that is not a letter or digit (such as '.', '/' and '-' in
// "github.com/some-org/pkg") is replaced with an underscore, so the result
// contains only characters that are legal in a Go identifier.
func pkgNamePrefix(pkg string) string {
	if pkg == "" || pkg == "." {
		return ""
	}
	pkg = strings.Map(func(r rune) rune {
		if unicode.IsLetter(r) || unicode.IsDigit(r) {
			return unicode.ToLower(r)
		}
		return '_'
	}, pkg)
	pkg = strings.Trim(pkg, "_")
	return pkg + "_"
}
134 |
// repeatedUnderscores matches runs of two or more underscores. It is compiled
// once at package initialization rather than on every cleanTypeName call.
var repeatedUnderscores = regexp.MustCompile("_{2,}")

// cleanTypeName turns a type name or type expression (for example
// "map[string]int" or "chan int") into identifier-friendly text: every
// character that is not a letter or number becomes an underscore, runs of
// underscores collapse into one, and leading/trailing underscores are removed.
func cleanTypeName(name string) string {
	// Spaces (including the one after "chan") are covered by this mapping, so
	// no separate replacement pass is needed for them.
	name = strings.Map(func(r rune) rune {
		if unicode.IsLetter(r) || unicode.IsNumber(r) {
			return r
		}
		return '_'
	}, name)
	name = repeatedUnderscores.ReplaceAllLiteralString(name, "_")
	return strings.Trim(name, "_")
}
148 |
149 | func (n nameMapper) options(t reflect.Type) []string {
150 | options := wellKnownTypesAndCommonNames[t]
151 | prefix, t := ptrPrefix(t)
152 | short_pkg := pkgNamePrefix(filepath.Base(t.PkgPath()))
153 | full_pkg := pkgNamePrefix(t.PkgPath())
154 |
155 | name := cleanTypeName(t.Name())
156 | if name == "" {
157 | name = cleanTypeName(t.String())
158 | short_pkg = ""
159 | full_pkg = ""
160 | }
161 | if name != "" {
162 | lname := lowerFirstLetter(name)
163 | options = append(options, assemble(prefix, lname))
164 | options = append(options, assemble(short_pkg+prefix, lname))
165 | options = append(options, assemble(prefix, lname))
166 | caps := extractCaps(name)
167 | if caps != "" {
168 | options = append(options, assemble(prefix, caps))
169 | options = append(options, assemble(prefix, string(caps[0])))
170 | }
171 |
172 | options = append(options, assemble(full_pkg+prefix, lname))
173 | }
174 | // uniqueVarNameForType should always make something completely unique, but
175 | // it's a bit verbose.
176 | // options = append(options, uniqueVarNameForType(t))
177 | // As a completely paranoid option, this should absolutely positively make
178 | // a unique var name:
179 | options = append(options, fmt.Sprintf("__var%d__", len(n.used)))
180 | return options
181 | }
182 |
183 | var wellKnownTypesAndCommonNames = map[reflect.Type][]string{
184 | reflect.TypeOf((*http.ResponseWriter)(nil)).Elem(): {"rw", "w"},
185 | reflect.TypeOf((*http.Request)(nil)): {"req", "r"},
186 | reflect.TypeOf(""): {"str"},
187 | reflect.TypeOf(false): {"flag"},
188 | errorType: {"err"},
189 | }
190 |
--------------------------------------------------------------------------------
/doc.go:
--------------------------------------------------------------------------------
1 | // Package sandwich is a middleware framework for go that lets you write
2 | // testable web servers.
3 | //
4 | // Sandwich allows writing robust middleware handlers that are easily tested:
5 | // - Avoid globals, instead propagate per-request state automatically from
6 | // one handler to the next.
7 | // - Write your handlers to accept the parameters they need rather than
8 | // type-asserting from an untyped per-request context.
9 | // - Abort request handling by returning an error.
10 | //
11 | // Sandwich provides a basic PAT-style router.
12 | //
13 | // # Example
14 | //
15 | // Here's a simple complete program using sandwich:
16 | //
17 | // package main
18 | //
19 | // import (
20 | // "fmt"
21 | // "log"
22 | // "net/http"
23 | //
24 | // "github.com/augustoroman/sandwich"
25 | // )
26 | //
27 | // func main() {
28 | // mux := sandwich.TheUsual()
29 | // mux.Get("/", func(w http.ResponseWriter) {
30 | // fmt.Fprintf(w, "Hello world!")
31 | // })
32 | // if err := http.ListenAndServe(":6060", mux); err != nil {
33 | // log.Fatal(err)
34 | // }
35 | // }
36 | //
37 | // # Providing
38 | //
39 | // Sandwich automatically calls your middleware with the necessary arguments to
40 | // run them based on the types they require. These types can be provided by
41 | // previous middleware or directly during the initial setup.
42 | //
43 | // For example, you can use this to provide your database to all handlers:
44 | //
45 | // func main() {
46 | // db_conn := ConnectToDatabase(...)
47 | // mux := sandwich.TheUsual()
48 | // mux.Set(db_conn)
49 | // mux.Get("/", Home)
50 | // }
51 | //
52 | // func Home(w http.ResponseWriter, r *http.Request, db_conn *Database) {
53 | // // process the request here, using the provided db_conn
54 | // }
55 | //
56 | // Set(...) and SetAs(...) are excellent alternatives to using global values,
57 | // plus they keep your functions easy to test!
58 | //
59 | // # Handlers
60 | //
61 | // In many cases you want to initialize a value based on the request, for
62 | // example extracting the user login:
63 | //
64 | // func main() {
65 | // mux := sandwich.TheUsual()
66 | // mux.Get("/", ParseUserCookie, SayHi)
67 | // }
68 | // // You can write & test exactly this signature:
69 | // func ParseUserCookie(r *http.Request) (User, error) { ... }
70 | // // Then write your handler assuming User is available:
71 | // func SayHi(w http.ResponseWriter, u User) {
72 | // fmt.Fprintf(w, "Hello %s", u.Name)
73 | // }
74 | //
75 | // This starts to show off the real power of sandwich. For each request, the
76 | // following occurs:
77 | // - First ParseUserCookie is called. If it returns a non-nil error,
78 | // sandwich's HandleError is called and the request is aborted. If the error
79 | // is nil, processing continues.
80 | // - Next SayHi is called with the User value returned from ParseUserCookie.
81 | //
82 | // This allows you to write small, independently testable functions and let
83 | // sandwich chain them together for you. Sandwich works hard to ensure that you
84 | // don't get annoying run-time errors: it's structured such that it must always
85 | // be possible to call your functions when the middleware is initialized rather
86 | // than when the http handler is being executed, so you don't get surprised
87 | // while your server is running.
88 | //
89 | // # Error Handlers
90 | //
91 | // When a handler returns an error, sandwich aborts the middleware chain and
92 | // looks for the most recently registered error handler and calls that. Error
93 | // handlers may accept any types that have been provided so far in the
94 | // middleware stack as well as the error type. They must not have any return
95 | // values.
96 | //
97 | // # Wrapping Handlers
98 | //
99 | // Sandwich also allows registering handlers to run during AND after the
100 | // middleware (and error handling) stack has completed. This is especially
101 | // useful for handlers such as logging or gzip wrappers. Once the before handler
102 | // is run, the 'after' handlers are queued to run and will be run regardless of
103 | // whether an error aborts any subsequent middleware handlers.
104 | //
105 | // Typically this is done with the first function creating and initializing some
106 | // state to pass to the deferred handler. For example, the logging handlers are:
107 | //
108 | // // NewLogEntry creates a *LogEntry and initializes it with basic request
109 | // // information.
110 | // func NewLogEntry(r *http.Request) *LogEntry {
111 | // return &LogEntry{Start: time.Now(), ...}
112 | // }
113 | //
114 | // // Commit fills in the remaining *LogEntry fields and writes the entry out.
115 | // func (entry *LogEntry) Commit(w *ResponseWriter) {
116 | // entry.Elapsed = time.Since(entry.Start)
117 | // ...
118 | // WriteLog(*entry)
119 | // }
120 | //
121 | // and are added to the chain using:
122 | //
123 | // var LogRequests = Wrap{NewLogEntry, (*LogEntry).Commit}
124 | //
125 | // In this case, the `Wrap` executes NewLogEntry during middleware processing
126 | // that returns a *LogEntry which is provided to downstream handlers, including
127 | // the deferred Commit handler -- in this case a method expression
128 | // (https://golang.org/ref/spec#Method_expressions) that takes the *LogEntry as
129 | // its receiver.
130 | //
131 | // # Providing Interfaces
132 | //
133 | // Unfortunately, providing interfaces is a little tricky. Since interfaces in
134 | // Go are only used for static typing, the encapsulation isn't passed to
135 | // functions that accept interface{}, like Set().
136 | //
137 | // This means that if you have an interface and a concrete implementation, such
138 | // as:
139 | //
140 | // type UserDatabase interface{
141 | // GetUserProfile(u User) (Profile, error)
142 | // }
143 | // type userDbImpl struct { ... }
144 | // func (u *userDbImpl) GetUserProfile(u User) (Profile, error) { ... }
145 | //
146 | // You cannot provide this to handlers directly via the Set() call.
147 | //
148 | // udb := &userDbImpl{...}
149 | // // DOESN'T WORK: this will provide *userDbImpl, not UserDatabase
150 | // mux.Set(udb)
151 | // // STILL DOESN'T WORK
152 | // mux.Set((UserDatabase)(udb))
153 | // // *STILL* DOESN'T WORK
154 | // udb_iface := UserDatabase(udb)
155 | // mux.Set(&udb_iface)
156 | //
157 | // Instead, you have to either use SetAs() or a dedicated middleware function:
158 | //
159 | // udb := &userDbImpl{...}
160 | // mux.SetAs(udb, (*UserDatabase)(nil)) // either use SetAs() with a pointer to the interface
161 | // mux.Use(func() UserDatabase { return udb }) // or add a handler that returns the interface
162 | //
163 | // It's a bit silly, but that's how it is.
164 | package sandwich
165 |
--------------------------------------------------------------------------------
/router_test.go:
--------------------------------------------------------------------------------
1 | package sandwich
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "net/http"
7 | "net/http/httptest"
8 | "strings"
9 | "testing"
10 |
11 | "github.com/stretchr/testify/assert"
12 | "github.com/stretchr/testify/require"
13 | )
14 |
15 | // Shorthand for keeping tests concise below
16 | type M = Params
17 |
18 | func TestMuxRegisterAndMatch(t *testing.T) {
19 | const REGISTRATION_ERROR = "•ERR:"
20 | fail := func(reason string) string {
21 | return REGISTRATION_ERROR + reason
22 | }
23 | split := func(combined_pattern string) (pattern, errmsg string) {
24 | pos := strings.Index(combined_pattern, REGISTRATION_ERROR)
25 | if pos == -1 {
26 | return combined_pattern, ""
27 | }
28 | return combined_pattern[:pos], combined_pattern[pos+len(REGISTRATION_ERROR):]
29 | }
30 | patterns := []string{
31 | "/",
32 | "/a",
33 | "/a" + fail("repeated entry"),
34 | "/a/:x/:x" + fail("repeated param name"),
35 | "/a/",
36 | "/a/b",
37 | "/a/b/c",
38 | "/a/b/c" + fail("repeated entry"),
39 | "/a/b/c/d/e", // NOTE: /a/b/c/d not registered
40 | "/a/:x/c",
41 | "/a/:x/c" + fail("repeated entry"),
42 | "/a/:y/c" + fail("ambiguous param var"),
43 | "/a/:y/c2",
44 | "/a/:m*",
45 | "/a/:m*/",
46 | "/b/:a*/x",
47 | "/b/:b*/y",
48 | "/b/:b*/x" + fail("ambiguous greedy pattern"),
49 | "/c/:x/y",
50 | "/c/:x*/y" + fail("ambiguous param (greedy or not)"),
51 | "/:m*/b/c",
52 | "/:m*/:x/c",
53 | "/:m*/:x*/c" + fail("multiple greedy patterns"),
54 | "/:x*/b/c" + fail("ambiguous greedy var"),
55 | "/x/:x*/y/:y/z/:z*/blah" + fail("multiple greedy patterns"),
56 |
57 | // literal colon in static URL
58 | "/a/::x",
59 | "/a/::x/c",
60 | }
61 |
62 | var m mux
63 |
64 | for _, combo_pattern := range patterns {
65 | pattern, errmsg := split(combo_pattern)
66 | err := m.Register(pattern, noopHandler(pattern))
67 | if errmsg == "" {
68 | require.NoError(t, err)
69 | } else {
70 | require.Error(t, err, "Pattern %#q should have failed: %s", pattern, errmsg)
71 | }
72 | }
73 |
74 | // priority:
75 | // - static routes
76 | // - explicit parameter
77 | // - greedy parameter
78 |
79 | testCases := []struct {
80 | uri string
81 | expectedHandler noopHandler
82 | expectedParams M
83 | }{
84 | {"/", "/", M{}},
85 | {"/a", "/a", M{}},
86 | {"/a/", "/a/", M{}},
87 | {"/a/b", "/a/b", M{}},
88 | {"/a/b/c", "/a/b/c", M{}},
89 | {"/a/b/c/d/e", "/a/b/c/d/e", M{}},
90 | {"/a/b/c/d", "/a/:m*", M{"m": "b/c/d"}},
91 |
92 | {"/a/foobar/c", "/a/:x/c", M{"x": "foobar"}},
93 | {"/a/foobar/c2", "/a/:y/c2", M{"y": "foobar"}},
94 |
95 | {"/a/foobar/blah", "/a/:m*", M{"m": "foobar/blah"}},
96 | {"/a/foobar/blah/", "/a/:m*/", M{"m": "foobar/blah"}},
97 |
98 | {"/b/mm/nn/", "", nil},
99 | {"/b/mm/nn/x", "/b/:a*/x", M{"a": "mm/nn"}},
100 | {"/b/mm/nn/y", "/b/:b*/y", M{"b": "mm/nn"}},
101 |
102 | {"/c/x/y", "/c/:x/y", M{"x": "x"}},
103 |
104 | {"/b/x/y/b/c", "/:m*/b/c", M{"m": "b/x/y"}},
105 | {"/b/x/y/bo/c", "/:m*/:x/c", M{"m": "b/x/y", "x": "bo"}},
106 | }
107 |
108 | for _, test := range testCases {
109 | t.Run(fmt.Sprintf("%s -> %s", test.uri, test.expectedHandler), func(t *testing.T) {
110 | if test.expectedHandler == "" {
111 | t.Logf("Testing input uri %#q --> should not match any pattern",
112 | test.uri)
113 | } else {
114 | t.Logf("Testing input uri %#q --> should match pattern %#q",
115 | test.uri, test.expectedHandler)
116 | }
117 | params := Params{}
118 | selected := m.Match(test.uri, params)
119 | if test.expectedHandler == "" {
120 | assert.Nil(t, selected, "should not match any pattern")
121 | assert.Empty(t, params)
122 | } else {
123 | require.NotNil(t, selected)
124 | assert.Equal(t, test.expectedHandler, selected)
125 | assert.Equal(t, test.expectedParams, params)
126 | }
127 | })
128 | }
129 | }
130 |
131 | type noopHandler string
132 |
133 | func (h noopHandler) ServeHTTP(w http.ResponseWriter, r *http.Request, p Params) {}
134 |
135 | func TestRouter(t *testing.T) {
136 | r := TheUsual()
137 |
138 | type UserID string
139 | type User string
140 | type UserDB map[UserID]User
141 |
142 | theUserDB := UserDB{"1": "bob", "2": "alice"}
143 | r.Set(theUserDB)
144 |
145 | loadUser := func(db UserDB, p Params) (User, UserID, error) {
146 | if uid := UserID(p["userID"]); uid == "" {
147 | return "", "", Error{Code: 400, ClientMsg: "Must specify user ID"}
148 | } else if u := db[uid]; u == "" {
149 | return "", uid, Error{Code: 404, ClientMsg: "No such user"}
150 | } else {
151 | return u, uid, nil
152 | }
153 | }
154 | newUserFromRequest := func(r *http.Request) (UserID, User, error) {
155 | uid := UserID(r.FormValue("uid"))
156 | user := User(r.FormValue("name"))
157 | if uid == "" {
158 | return "", "", errors.New("missing user id")
159 | } else if user == "" {
160 | return "", "", errors.New("missing user info")
161 | }
162 | return uid, user, nil
163 | }
164 |
165 | r.Get("/user/:userID", loadUser,
166 | func(w http.ResponseWriter, u User) {
167 | fmt.Fprintf(w, "Hi user %#q", u)
168 | },
169 | )
170 | r.Post("/user/", newUserFromRequest,
171 | func(db UserDB, uid UserID, u User) { db[uid] = u },
172 | func(w http.ResponseWriter, uid UserID, u User) {
173 | fmt.Fprintf(w, "Made user %#q = %#q", uid, u)
174 | },
175 | )
176 | r.Any("/user/:userID/:cmd*", loadUser,
177 | func(w http.ResponseWriter, r *http.Request, p Params, u User) {
178 | fmt.Fprintf(w, "Doing %#q (%s) to user %#q", r.Method, p["cmd"], u)
179 | },
180 | )
181 |
182 | w := httptest.NewRecorder()
183 | r.ServeHTTP(w, httptest.NewRequest("GET", "/user/1", nil))
184 | assert.Equal(t, http.StatusOK, w.Result().StatusCode)
185 | assert.Equal(t, "Hi user `bob`", w.Body.String())
186 |
187 | w = httptest.NewRecorder()
188 | r.ServeHTTP(w, httptest.NewRequest("GET", "/user/2", nil))
189 | assert.Equal(t, http.StatusOK, w.Result().StatusCode)
190 | assert.Equal(t, "Hi user `alice`", w.Body.String())
191 |
192 | w = httptest.NewRecorder()
193 | r.ServeHTTP(w, httptest.NewRequest("GET", "/user/3", nil))
194 | assert.Equal(t, http.StatusNotFound, w.Result().StatusCode)
195 | assert.Equal(t, "No such user\n", w.Body.String())
196 |
197 | w = httptest.NewRecorder()
198 | r.ServeHTTP(w, httptest.NewRequest("POST", "/user/?uid=3&name=sid", nil))
199 | assert.Equal(t, http.StatusOK, w.Result().StatusCode, "Response: %s", w.Body.String())
200 |
201 | w = httptest.NewRecorder()
202 | r.ServeHTTP(w, httptest.NewRequest("GET", "/user/3", nil))
203 | assert.Equal(t, http.StatusOK, w.Result().StatusCode)
204 | assert.Equal(t, "Hi user `sid`", w.Body.String())
205 |
206 | w = httptest.NewRecorder()
207 | r.ServeHTTP(w, httptest.NewRequest("EXPLODE", "/user/3/boom", nil))
208 | assert.Equal(t, http.StatusOK, w.Result().StatusCode)
209 | assert.Equal(t, "Doing `EXPLODE` (boom) to user `sid`", w.Body.String())
210 | }
211 |
212 | // func TestNodeMatch(t *testing.T) {
213 | // testCases := []struct {
214 | // path, pattern string
215 | // matches bool
216 | // expectedParams M
217 | // }{
218 | // // static paths
219 | // {"/a/b/c", "/a/b/c", true, M{}},
220 | // {"/a/b/c", "/x/b/c", false, nil},
221 | // {"/a/b/c", "/a/b", false, nil},
222 | // {"/a/b/c", "/a/b/c/d", false, nil},
223 |
224 | // {"/a/b/c", "/a/b/:last", true, M{"last": "c"}},
225 | // {"/a/b/c", "/a/:mid/c", true, M{"mid": "b"}},
226 | // {"/a/b/c", "/:first/:mid/c", true, M{"first": "a", "mid": "b"}},
227 | // {"/a/b/c", "/:first/:mid/:last", true, M{"first": "a", "mid": "b", "last": "c"}},
228 | // {"/a/b/c", "/:first/:mid/:last/:missing", false, nil},
229 | // {"/a/b/c", "/:first/:mid/:last/x", false, nil},
230 |
231 | // {"/a/b/c/d/e/f/g", "/:path*", true, M{"path": "a/b/c/d/e/f/g"}},
232 | // {"/a/b/c/d/e/f/g", "/a/:path*", true, M{"path": "b/c/d/e/f/g"}},
233 | // {"/a/b/c/d/e/f/g", "/a/:path*/g", true, M{"path": "b/c/d/e/f"}},
234 | // {"/a/b/c/d/e/f/g", "/a/:path*/f/g", true, M{"path": "b/c/d/e"}},
235 | // {"/a/b/c/d/e/f/g", "/a/:path*/:last", true, M{"path": "b/c/d/e/f", "last": "g"}},
236 | // {"/a/b/c/d/e/f/g", "/a/:first/:mid*/:last", true, M{"first": "b", "mid": "c/d/e/f", "last": "g"}},
237 | // {"/a/b/c/d/e/f/g", "/a/:first*/:mid/:last", true, M{"first": "b/c/d/e", "mid": "f", "last": "g"}},
238 | // {"/a/b/c/d/e/f/g", "/a/:first/:mid/:last*", true, M{"first": "b", "mid": "c", "last": "d/e/f/g"}},
239 | // }
240 |
241 | // for i, test := range testCases {
242 | // t.Run(fmt.Sprintf("%d:%s", i, test.pattern), func(t *testing.T) {
243 | // t.Logf("Test %d: Path %#q Pattern %#q Should match: %v",
244 | // i, test.path, test.pattern, test.matches)
245 | // root, err := makeMatchNodes(test.pattern)
246 | // require.NoError(t, err)
247 | // pathSegments := strings.Split(test.path, "/")
248 | // params := Params{}
249 | // match := root.match(pathSegments[1:], params)
250 | // if test.matches {
251 | // require.NotNil(t, match)
252 | // assert.Equal(t, test.expectedParams, params)
253 | // } else {
254 | // assert.Nil(t, match)
255 | // }
256 | // })
257 | // }
258 | // }
259 |
--------------------------------------------------------------------------------
/examples/2-advanced/main.go:
--------------------------------------------------------------------------------
1 | // 2-advanced is a demo webserver for the sandwich middleware package
2 | // demonstrating advanced usage.
3 | //
4 | // This provides a sample, multi-user TODO list application that allows users
5 | // to sign in via a Google account or sign in using fake credentials.
6 | // This example demonstrates more advanced features of sandwich, including:
7 | //
8 | // - Providing interface types to the middleware chain.
9 | // TaskDb is the interface provided to the handlers, the actual value injected
10 | // in main() is a taskDbImpl.
11 | // - Using 3rd party middleware (go.auth, go.rice)
12 | // - Using a 3rd party router (gorilla/mux)
13 | // - Using multiple error handlers, and custom error handlers.
14 | // Most web servers will want to serve a custom HTML error page for user-facing
15 | // error pages. An example of that is included here. For AJAX calls, however,
16 | // ....
17 | // - Early exit of the middleware chain via the sandwich.Done error
18 | // - Auto-generating handler code
19 | package main
20 |
import (
	"bytes"
	"embed"
	"encoding/json"
	"fmt"
	"html/template"
	"io/fs"
	"log"
	"net/http"
	"os"
	"time"

	"github.com/augustoroman/sandwich"
	auth "github.com/bradrydzewski/go.auth"
)
35 |
36 | //go:embed static
37 | var static embed.FS
38 |
39 | func main() {
40 | // Read in configuration:
41 | var config struct {
42 | Host string `json:"host"`
43 | Port int `json:"port"`
44 | CookieSecret string `json:"cookie-secret"`
45 | ClientId string `json:"oauth2-client-id"`
46 | ClientSecret string `json:"oauth2-client-secret"`
47 | }
48 | failOnError(readJsonFile("config.json", &config))
49 |
50 | // Setup Oauth login framework:
51 | auth.Config.LoginRedirect = "/auth/login" // send user here to login
52 | auth.Config.LoginSuccessRedirect = "/" // send user here post-login
53 | auth.Config.CookieSecure = false // for local-testing only
54 | auth.Config.CookieSecret = []byte(config.CookieSecret)
55 | // This must match the authorized URLs entered in the google cloud api console.
56 | redirectUrl := fmt.Sprintf("http://%s/auth/google/callback", config.Host)
57 | authHandler := auth.Google(config.ClientId, config.ClientSecret, redirectUrl)
58 |
59 | // Setup task database:
60 | taskDb := taskDbImpl{}
61 |
62 | // Load our templates.
63 | tpl := template.Must(template.ParseFS(static, "static/*.tpl.html"))
64 |
65 | // Start setting up our server:
66 | mux := sandwich.TheUsual()
67 | mux.Use(ParseUserCookie, LogUser)
68 | mux.SetAs(taskDb, (*TaskDb)(nil))
69 | mux.Set(tpl)
70 | mux.OnErr(CustomErrorPage)
71 |
72 | // Don't log these requests since we don't have a favicon, it's just a
73 | // bunch of 404 spam.
74 | mux.Get("/favicon.ico", sandwich.NoLog, NotFound)
75 |
76 | // When login is called, we'll FIRST call our very own CheckForFakeLogin
77 | // handler. If we detect the fake login form params, we'll process that
78 | // and then abort the middleware chain.
79 | // However, if we don't have fake parameters, we'll continue on and let
80 | // the authHandler take care of things.
81 | mux.Any("/auth/login", CheckForFakeLogin, authHandler)
82 | mux.Any("/auth/google/callback", authHandler)
83 | // Note that we can use auth.DeleteUserCookie directly.
84 | mux.Any("/auth/logout", auth.DeleteUserCookie,
85 | http.RedirectHandler("/", http.StatusTemporaryRedirect))
86 |
87 | // Static file handling. The s.Then(...) wrapper isn't strictly necessary,
88 | // but it gives us logging (and potentially gzip or other middleware).
89 | // static := http.StripPrefix("/static", http.FileServer(http.FS(
90 | // mustSubFS(static, "static"))))
91 | mux.Get("/static/:path*", sandwich.ServeFS(static, "static", "path"))
92 |
93 | // OK, here are the core handlers:
94 | mux.Get("/", Home)
95 | // All API calls will use the api middleware that responds with JSON for
96 | // errors and requires users to be logged in.
97 | api := mux.SubRouter("/api/")
98 | api.OnErr(sandwich.HandleErrorJson)
99 | api.Use(RequireLoggedIn)
100 | api.Post("/task", TaskFromAddRequest, TaskDb.Add, SendTaskAsJson)
101 | api.Post("/task/:id", TaskOpFromUpdateRequest, UpdateTask)
102 |
103 | // Catch all remaining URLs and respond with not-found errors. We
104 | // explicitly use the error-return mechanism so that we get the JSON
105 | // response under /api/ and normal HTML responses elsewhere.
106 | api.Any("/:*", NotFound)
107 | mux.Any("/:*", NotFound)
108 |
109 | // Otherwise, start serving!
110 | addr := fmt.Sprintf("localhost:%d", config.Port)
111 | log.Printf("Server listening on http://%s", addr)
112 | failOnError(http.ListenAndServe(addr, mux))
113 | }
114 |
115 | // ============================================================================
116 | // Database
117 |
type (
	// UserId identifies the user that owns a list of tasks.
	UserId string

	// TaskDb is the task storage interface provided to the handlers.
	TaskDb interface {
		// List returns the tasks belonging to the given user.
		List(UserId) ([]Task, error)
		// Add stores a new task for the given user.
		Add(UserId, *Task) error
		// Update replaces one of the given user's existing tasks.
		Update(UserId, Task) error
	}

	// Task is a single TODO item, as stored and as exchanged in JSON.
	Task struct {
		Id   string `json:"id"`
		Desc string `json:"desc"`
		Done bool   `json:"done"`
	}
)
129 |
130 | type taskDbImpl map[UserId][]Task
131 |
132 | func (db taskDbImpl) List(u UserId) ([]Task, error) { return db[u], nil }
133 | func (db taskDbImpl) Add(u UserId, t *Task) error {
134 | t.Id = fmt.Sprint(time.Now().UnixNano())
135 | db[u] = append(db[u], *t)
136 | return nil
137 | }
138 | func (db taskDbImpl) Update(u UserId, t Task) error {
139 | tasks := db[u]
140 | for i, task := range tasks {
141 | if task.Id == t.Id {
142 | tasks[i] = t
143 | return nil
144 | }
145 | }
146 | return sandwich.Error{
147 | Code: http.StatusBadRequest,
148 | ClientMsg: "No such task",
149 | Cause: fmt.Errorf("No such task: %q", t.Id),
150 | }
151 | }
152 |
153 | // ============================================================================
154 | // Core Handlers
155 |
156 | func Home(
157 | w http.ResponseWriter,
158 | r *http.Request,
159 | uid UserId,
160 | u auth.User,
161 | db TaskDb,
162 | tpl *template.Template,
163 | ) error {
164 | if u == nil {
165 | // tpl := template.Must(template.New("").Parse(
166 | // rice.MustFindBox("static").MustString("landing-page.tpl.html")))
167 | return tpl.ExecuteTemplate(w, "landing-page.tpl.html", nil)
168 | }
169 | tasks, err := db.List(uid)
170 | if err != nil {
171 | return err
172 | }
173 | // tpl := template.Must(template.New("").Parse(
174 | // rice.MustFindBox("static").MustString("home.tpl.html")))
175 | return tpl.ExecuteTemplate(w, "home.tpl.html", map[string]interface{}{
176 | "User": u,
177 | "Tasks": tasks,
178 | })
179 | }
180 |
181 | func TaskFromAddRequest(r *http.Request) (*Task, error) {
182 | var t Task
183 | if err := json.NewDecoder(r.Body).Decode(&t); err != nil {
184 | return nil, sandwich.Error{Code: http.StatusBadRequest, Cause: err}
185 | }
186 | if t.Desc == "" {
187 | return nil, sandwich.Error{Code: http.StatusBadRequest,
188 | ClientMsg: "Please include a task description"}
189 | }
190 | return &t, nil
191 | }
192 |
193 | func SendTaskAsJson(w http.ResponseWriter, t *Task) error {
194 | return json.NewEncoder(w).Encode(map[string]interface{}{"task": t})
195 | }
196 |
// TaskOp describes an update requested for one task.
type TaskOp struct {
	// Toggle requests flipping the task's Done flag.
	Toggle bool
	// Id is the id of the task to operate on.
	Id string
}
201 |
202 | func TaskOpFromUpdateRequest(r *http.Request) (TaskOp, error) {
203 | var op TaskOp
204 | if err := json.NewDecoder(r.Body).Decode(&op); err != nil {
205 | return op, sandwich.Error{Code: http.StatusBadRequest, Cause: err}
206 | }
207 | if op.Id == "" {
208 | return op, sandwich.Error{
209 | Code: http.StatusBadRequest,
210 | ClientMsg: "Invalid op: missing task id",
211 | }
212 | }
213 | return op, nil
214 | }
215 |
216 | func UpdateTask(w http.ResponseWriter, r *http.Request, uid UserId, op TaskOp, db TaskDb) error {
217 | tasks, err := db.List(uid)
218 | if err != nil {
219 | return err
220 | }
221 | var t Task
222 | for i := range tasks {
223 | if tasks[i].Id == op.Id {
224 | t = tasks[i]
225 | break
226 | }
227 | }
228 | if t.Id == "" {
229 | return sandwich.Error{
230 | Code: http.StatusBadRequest,
231 | ClientMsg: "No such task id: " + op.Id,
232 | }
233 | }
234 |
235 | if op.Toggle {
236 | t.Done = !t.Done
237 | }
238 |
239 | if err := db.Update(uid, t); err != nil {
240 | return err
241 | }
242 | return json.NewEncoder(w).Encode(map[string]interface{}{"task": t})
243 | }
244 |
245 | // This will get called for any error that occurs outside of the API calls.
246 | func CustomErrorPage(
247 | w http.ResponseWriter,
248 | r *http.Request,
249 | err error,
250 | tpl *template.Template,
251 | l *sandwich.LogEntry,
252 | ) {
253 | // Make sure we actually have a real error:
254 | if err == sandwich.Done {
255 | return
256 | }
257 | // Convert the error to a sandwich.Error that has an error code.
258 | e := sandwich.ToError(err)
259 | // Always log the error and error details.
260 | l.Error = e
261 |
262 | w.WriteHeader(e.Code)
263 | err = tpl.ExecuteTemplate(w, "error.tpl.html", map[string]interface{}{
264 | "Error": e,
265 | })
266 |
267 | // But... what if our fancy template rendering fails? At this point, we
268 | // fall back to the simplest possible thing: http.Error(...). Maybe it'll
269 | // work, but we'll also log the error so it doesn't disappear.
270 | if err != nil {
271 | // Try putting a typo in the template name above, and you'll see this:
272 | l.Error = fmt.Errorf("Failed to render error page: %v\nTriggering error: %v",
273 | err, e)
274 | http.Error(w, "Internal server error", http.StatusInternalServerError)
275 | }
276 | }
277 |
278 | func CheckForFakeLogin(w http.ResponseWriter, r *http.Request) error {
279 | if r.FormValue("id") == "" {
280 | return nil
281 | }
282 |
283 | user := &auth.GoogleUser{
284 | UserId: r.FormValue("id"),
285 | UserEmail: r.FormValue("email"),
286 | UserName: r.FormValue("name"),
287 | }
288 | auth.SetUserCookie(w, r, user)
289 | http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
290 |
291 | // Great, everything is handled, so don't continue with the Google auth.
292 | return sandwich.Done
293 | }
294 |
295 | // ============================================================================
296 | // Basic Handlers
297 |
298 | // You could also use .Then(http.NotFound), but that wouldn't go through the
299 | // error-handlers. The advantage of using the error handlers is that you
300 | // automatically get JSON vs HTML handling.
301 | func NotFound() error { return sandwich.Error{Code: http.StatusNotFound} }
302 |
303 | func RequireLoggedIn(u auth.User) error {
304 | if u == nil {
305 | return sandwich.Error{Code: http.StatusUnauthorized}
306 | }
307 | return nil
308 | }
309 | func ParseUserCookie(r *http.Request) (auth.User, UserId) {
310 | // Ignore errors. If the cookie is invalid or expired or corrupt or
311 | // missing, just consider the user not-logged-in.
312 | u, _ := auth.GetUserCookie(r)
313 | var uid UserId
314 | if u != nil {
315 | uid = UserId(u.Id())
316 | }
317 | return u, uid
318 | }
319 |
320 | // Adds the current user to the per-request log notes, if logged in.
321 | func LogUser(u auth.User, e *sandwich.LogEntry) {
322 | if u != nil {
323 | e.Note["user"] = u.Email()
324 | e.Note["userId"] = u.Id()
325 | } else {
326 | e.Note["user"] = ""
327 | }
328 | }
329 |
330 | // ============================================================================
331 | // Simple utilities
332 |
// readJsonFile opens filename and decodes the first JSON value in it into
// dst, which must be a pointer as required by encoding/json.
//
// Open errors are returned as-is (they already name the path). Decode errors
// are wrapped with the filename so that a bad config file can be identified
// from the message; the underlying error stays reachable via errors.Is/As.
func readJsonFile(filename string, dst interface{}) error {
	f, err := os.Open(filename)
	if err != nil {
		return err
	}
	defer f.Close()
	if err := json.NewDecoder(f).Decode(dst); err != nil {
		return fmt.Errorf("decoding JSON from %q: %w", filename, err)
	}
	return nil
}
341 |
// failOnError panics with err when it is non-nil and does nothing otherwise.
func failOnError(err error) {
	if err == nil {
		return
	}
	panic(err)
}
347 |
// mustSubFS returns the subtree of base rooted at dir, panicking if fs.Sub
// reports an error.
func mustSubFS(base fs.FS, dir string) fs.FS {
	subtree, err := fs.Sub(base, dir)
	if err == nil {
		return subtree
	}
	panic(err)
}
355 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Sandwich: Delicious HTTP Middleware [](https://travis-ci.org/augustoroman/sandwich) [](https://gocover.io/github.com/augustoroman/sandwich) [](https://goreportcard.com/report/github.com/augustoroman/sandwich) [](https://pkg.go.dev/github.com/augustoroman/sandwich)
4 |
5 | *Keep pilin' it on!*
6 |
7 | Sandwich is a middleware & routing framework that lets you write your handlers
8 | and middleware the way you want to and it takes care of tracking & validating
9 | dependencies.
10 |
11 | ## Features
12 |
13 | * Keeps middleware and handlers simple and testable.
14 | * Consolidates error handling.
15 | * Ensures that middleware dependencies are safely provided -- avoids unsafe
16 | casting from generic context objects.
17 | * Detects missing dependencies *during route construction* (before the server
18 | starts listening!), not when the route is actually called.
19 | * Provides clear and helpful error messages.
20 | * Compatible with the [http.Handler](https://pkg.go.dev/net/http#Handler)
21 | interface and lots of existing middleware.
22 | * Provides just a touch of magic: enough to make things easier, but not enough
23 | to induce a debugging nightmare.
24 |
25 | ## Getting started
26 |
27 | Here's a very simple example of using sandwich with the standard HTTP stack:
28 |
29 | ```go
30 | package main
31 |
32 | import (
33 | "fmt"
34 | "log"
35 | "net/http"
36 |
37 | "github.com/augustoroman/sandwich"
38 | )
39 |
40 | func main() {
41 | // Create a default sandwich middleware stack that includes logging and
42 | // a simple error handler.
43 | mux := sandwich.TheUsual()
44 | mux.Get("/", func(w http.ResponseWriter) {
45 | fmt.Fprintf(w, "Hello world!")
46 | })
47 | if err := http.ListenAndServe(":6060", mux); err != nil {
48 | log.Fatal(err)
49 | }
50 | }
51 | ```
52 |
53 | See the [examples directory](examples/) for:
54 | * A [hello-world sample](examples/0-helloworld)
55 | * A [basic usage](examples/1-simple) demo
56 | * A [TODO app](examples/2-advanced) showing advanced usage including custom
57 | error handling, embedded files, code generation, & login and authentication
58 | via oauth.
59 |
60 | ## Usage
61 |
62 | ### Providing
63 |
64 | Sandwich automatically calls your middleware with the necessary arguments to
65 | run them based on the types they require. These types can be provided by
66 | previous middleware or directly during the initial setup.
67 |
68 | For example, you can use this to provide your database to all handlers:
69 |
70 | ```go
71 | func main() {
72 | db_conn := ConnectToDatabase(...)
73 | mux := sandwich.TheUsual()
74 | mux.Set(db_conn)
75 | mux.Get("/", Home)
76 | }
77 |
78 | func Home(w http.ResponseWriter, r *http.Request, db_conn *Database) {
79 | // process the request here, using the provided db_conn
80 | }
81 | ```
82 |
83 | Set(...) and SetAs(...) are excellent alternatives to using global
84 | values, plus they keep your functions easy to test!
85 |
86 |
87 | ### Handlers
88 |
89 | In many cases you want to initialize a value based on the request, for
90 | example extracting the user login:
91 |
92 | ```go
93 | func main() {
94 | mux := sandwich.TheUsual()
95 | mux.Get("/", ParseUserCookie, SayHi)
96 | }
97 | // You can write & test exactly this signature:
98 | func ParseUserCookie(r *http.Request) (User, error) { ... }
99 | // Then write your handler assuming User is available:
100 | func SayHi(w http.ResponseWriter, u User) {
101 | fmt.Fprintf(w, "Hello %s", u.Name)
102 | }
103 | ```
104 |
105 | This starts to show off the real power of sandwich. For each request, the
106 | following occurs:
107 |
108 | * First `ParseUserCookie` is called. If it returns a non-nil error,
109 | sandwich's `HandleError` is called and the request is aborted. If the error
110 | is nil, processing continues.
111 | * Next `SayHi` is called with `User` returned from `ParseUserCookie`.
112 |
113 | This allows you to write small, independently testable functions and let
114 | sandwich chain them together for you. Sandwich works hard to ensure that
115 | you don't get annoying run-time errors: it's structured such that it must
116 | always be possible to call your functions when the middleware is initialized
117 | rather than when the http handler is being executed, so you don't get
118 | surprised while your server is running.
119 |
120 |
121 | ### Error Handlers
122 |
123 | When a handler returns an error, sandwich aborts the middleware chain and
124 | looks for the most recently registered error handler and calls that.
125 | Error handlers may accept any types that have been provided so far in the
126 | middleware stack as well as the error type. They must not have any return
127 | values.
128 |
129 | Here's an example of rendering errors with a custom error page:
130 |
131 | ```go
132 | type ErrorPageTemplate *template.Template
133 | func main() {
134 | tpl := template.Must(template.ParseFiles("path/to/my/error_page.tpl"))
135 | mux := sandwich.TheUsual()
136 | mux.Set(ErrorPageTemplate(tpl))
137 | mux.OnErr(MyErrorHandler)
138 | ...
139 | }
140 | func MyErrorHandler(w http.ResponseWriter, t ErrorPageTemplate, l *sandwich.LogEntry, err error) {
141 | if err == sandwich.Done { // sandwich.Done can be returned to abort middleware.
142 | return // It indicates there was no actual error, so just return.
143 | }
144 | // Unwrap to a sandwich.Error that has Code, ClientMsg, and internal LogMsg.
145 | e := sandwich.ToError(err)
146 | // If there's an internal log message, add it to the request log.
147 | e.LogIfMsg(l)
148 | // Respond with my custom html error page, including the client-facing msg.
149 | w.WriteHeader(e.Code)
150 | t.Execute(w, map[string]string{"Msg": e.ClientMsg})
151 | }
152 | ```
153 |
154 | Error handlers allow you to consolidate the error handling of your web app. You
155 | can customize the error page, assign user-facing error codes, detect and fire
156 | alerts for certain errors, and control which errors get logged -- all in one
157 | place.
158 |
159 | By default, sandwich never sends internal error details to the client and
160 | instead logs the details.
161 |
162 | ### Wrapping Handlers
163 |
164 | Sandwich also allows registering handlers to run during AND after the middleware
165 | (and error handling) stack has completed. This is especially useful for handlers
166 | such as logging or gzip wrappers. Once the 'before' handler is run, the 'after'
167 | handlers are queued to run and will be run regardless of whether an error aborts
168 | any subsequent middleware handlers.
169 |
170 | Typically this is done with the first function creating and initializing some
171 | state to pass to the deferred handler. For example, the logging handlers
172 | are:
173 |
174 | ```go
175 | // NewLogEntry creates a *LogEntry and initializes it with basic request
176 | // information.
177 | func NewLogEntry(r *http.Request) *LogEntry {
178 | return &LogEntry{Start: time.Now(), ...}
179 | }
180 |
181 | // Commit fills in the remaining *LogEntry fields and writes the entry out.
182 | func (entry *LogEntry) Commit(w *ResponseWriter) {
183 | entry.Elapsed = time.Since(entry.Start)
184 | ...
185 | WriteLog(*entry)
186 | }
187 | ```
188 |
189 | and are added to the chain using:
190 |
191 | ```go
192 | var LogRequests = Wrap{NewLogEntry, (*LogEntry).Commit}
193 | ```
194 |
195 | In this case, `NewLogEntry` returns a `*LogEntry` that is then provided to
196 | downstream handlers, including the deferred Commit handler -- in this case a
197 | [method expression](https://golang.org/ref/spec#Method_expressions) that takes
198 | the `*LogEntry` as its value receiver.
199 |
200 |
201 | ### Providing Interfaces
202 |
203 | Unfortunately, setting interface values is a little tricky. Since interfaces in Go
204 | are only used for static typing, the encapsulation isn't passed to functions
205 | that accept interface{}, like Set().
206 |
207 | This means that if you have an interface and a concrete implementation, such
208 | as:
209 |
210 | ```go
211 | type UserDatabase interface{
212 | GetUserProfile(u User) (Profile, error)
213 | }
214 | type userDbImpl struct { ... }
215 | func (u *userDbImpl) GetUserProfile(u User) (Profile, error) { ... }
216 | ```
217 |
218 | You cannot provide this to handlers directly via the Set() call.
219 |
220 | ```go
221 | udb := &userDbImpl{...}
222 | // DOESN'T WORK: this will provide *userDbImpl, not UserDatabase
223 | mux.Set(udb)
224 | mux.Set((UserDatabase)(udb)) // DOESN'T WORK EITHER
225 | udb_iface := UserDatabase(udb)
226 | mux.Set(&udb_iface) // STILL DOESN'T WORK!
227 | ```
228 |
229 | Instead, you have to either use SetAs() or a dedicated middleware function that
230 | returns the interface:
231 |
232 | ```go
233 | udb := &userDbImpl{...}
234 | // either use SetAs() with a pointer to the interface
235 | mux.SetAs(udb, (*UserDatabase)(nil))
236 | // or add a handler that returns the interface
237 | mux.Use(func() UserDatabase { return udb })
238 | ```
239 |
240 | It's a bit silly, but there you are.
241 |
242 |
243 | ## FAQ
244 |
245 | Sandwich uses reflection-based dependency-injection to call the middleware
246 | functions with the parameters they need.
247 |
248 | **Q: OMG reflection and dependency-injection, isn't that terrible and slow and
249 | non-idiomatic go?!**
250 |
251 | Whoa, nelly. Let's deal with those one at a time, m'kay?
252 |
253 | **Q: Isn't reflection slow?**
254 |
255 | Not compared to everything else a webserver needs to do.
256 |
257 | Yes, sandwich's reflection-based dependency-injection code is slower than
258 | middleware code that directly calls functions, however **the vast majority of
259 | server code (especially during development) is not impacted by time spent
260 | calling a few functions, but rather by HTTP network I/O, request parsing,
261 | database I/O, response marshalling, etc.**
262 |
263 | **Q: Ok, but aren't both reflection and dependency-injection non-idiomatic Go?**
264 |
265 | Sorta. The use of reflection in and of itself isn't non-idiomatic, but the use
266 | of magical dependency injection is: Go eschews magic.
267 |
268 | However, one of the major goals of this library is to allow the HTTP handler
269 | code (and all middleware) to be really clean, idiomatic go functions that are
270 | testable by themselves. The idea is that the magic is small, contained, doesn't
271 | leak, and provides substantial benefit.
272 |
273 | **Q: But wait, don't you get annoying run-time "dependency-not-found" errors
274 | with dependency-injection?**
275 |
276 | While it's true that you can't get the same compile-time checking that you do
277 | with direct-call-based middleware, sandwich works really hard to ensure that you
278 | don't get surprises while running your server.
279 |
280 | At the time each middleware function is added to the stack, the library ensures
281 | that its dependencies have been explicitly provided. One of the *features* of
282 | sandwich is that you can't arbitrarily inject values -- they need to have an
283 | explicit provisioning source.
284 |
285 | **Q: Doesn't the http.Request.Context in go 1.7 solve the middleware dependency
286 | problem?**
287 |
288 | Having a request-scoped context allows you to pass values between middleware
289 | handlers, it's true. However, there's no guarantee that the values are
290 | available, so you get the same run-time bugs that you might get with a naive
291 | dependency-injection framework. In addition, you have to do type-assertions to
292 | get your values, so there's another possible source of bugs. One of the goals
293 | of sandwich is to avoid these two types of bugs.
294 |
295 | **Q: Why do I have to use _two_ functions (before & after) to wrap a request?
296 | Why can't I just have one with a next() function?**
297 |
298 | Many middleware frameworks provide the capability to wrap a request via a next()
299 | function. Sometimes it's part of a context object
300 | ([martini's Context.Next()](https://pkg.go.dev/github.com/go-martini/martini#Context),
301 | [gin's Context.Next()](https://pkg.go.dev/github.com/gin-gonic/gin#Context.Next))
302 | and sometimes it's directly provided
303 | ([negroni's third handler arg](https://pkg.go.dev/github.com/urfave/negroni#HandlerFunc)).
304 |
305 | While implementing sandwich, I initially included a `next()` function until I
306 | realized it was impossible to validate the dependencies with such a function.
307 | Sandwich guarantees that dependencies can be supplied, and therefore `next()`
308 | had to go.
309 |
310 | Instead, I took a tip from go and implemented
311 | [defer](https://pkg.go.dev/github.com/augustoroman/sandwich/chain#Func.Defer).
312 | The wrap interface simply makes it obvious that there's a before and after.
313 | This allows me to keep my dependency guarantee.
314 |
315 | **Q: I don't know, it's still scary and terrible!**
316 |
317 | Don't get scared off. Take a look at the library, try it out, and I hope you
318 | enjoy it. If you don't, there are lots of great alternatives.
319 |
--------------------------------------------------------------------------------
/chain/chain_test.go:
--------------------------------------------------------------------------------
1 | package chain
2 |
3 | import (
4 | "bytes"
5 | "errors"
6 | "fmt"
7 | "testing"
8 | "time"
9 |
10 | "github.com/stretchr/testify/assert"
11 | "github.com/stretchr/testify/require"
12 | )
13 |
// New returns an empty chain (the zero Func) for the tests to build on.
func New() Func { return Func{} }
15 |
// TestInitialInjection checks that a handler placed before Set receives the
// values passed to Run, while the same handler placed after Set receives the
// Set values.
func TestInitialInjection(t *testing.T) {
	var args []interface{}
	recordArgs := func(a int, b string) { args = append(args, a, b) }

	err := New().
		Arg(0).
		Arg("").
		Then(recordArgs).
		Set(3).
		Set("four").
		Then(recordArgs).
		Run(1, "two")
	assert.NoError(t, err)

	assert.EqualValues(t, []interface{}{1, "two", 3, "four"}, args)
}
32 |
// TestInitialDeferredInjection checks that types reserved with Arg are filled
// in by the values passed to Run.
func TestInitialDeferredInjection(t *testing.T) {
	var args []interface{}
	recordArgs := func(a int, b string) { args = append(args, a, b) }

	err := New().Arg(0).Arg("").Then(recordArgs).Run(2, "xyz")
	assert.NoError(t, err)

	assert.EqualValues(t, []interface{}{2, "xyz"}, args)
}
42 |
// TestDeferredExecutionOrder checks that Then handlers run in registration
// order and that Defer handlers run afterwards in reverse registration order.
func TestDeferredExecutionOrder(t *testing.T) {
	var buf bytes.Buffer
	say := func(s string) func() { return func() { buf.WriteString(s + ":") } }
	err := New().
		Then(say("a"), say("b")).
		Defer(say("f")).
		Defer(say("e")).
		Then(say("c")).
		Defer(say("d")).
		Run()
	assert.NoError(t, err)
	assert.Equal(t, "a:b:c:d:e:f:", buf.String())
}
56 |
// TestDeferredExecutionOrderWithErrors checks that when a handler fails, the
// error handler runs, handlers and defers registered after the failure are
// skipped, and the defers registered before it still run in reverse order.
func TestDeferredExecutionOrderWithErrors(t *testing.T) {
	var buf bytes.Buffer
	say := func(s string) func() { return func() { buf.WriteString(s + ":") } }
	onErr := func(e error) { buf.WriteString("err[" + e.Error() + "]:") }
	fail := func() error { return errors.New("failed") }
	err := New().
		Then(say("a"), say("b")).
		Defer(say("f")).
		OnErr(onErr).
		Defer(say("e")).
		Then(fail).
		Then(say("c")).
		Defer(say("d")).
		Run()
	assert.NoError(t, err)
	assert.Equal(t, "a:b:err[failed]:e:f:", buf.String())
}
74 |
75 | func TestBasicFuncExecution(t *testing.T) {
76 | a, b := 5, "hi"
77 | provide_initial := func() (*int, *string) { return &a, &b }
78 | verify_injected := func(x *int, y *string) {
79 | if *x != 5 {
80 | t.Errorf("Expected *int to be 6, got %d", *x)
81 | }
82 | if *y != "hi" {
83 | t.Errorf("Expected *string to be 'hi', got %s", *y)
84 | }
85 | }
86 | modify_injected := func(x *int, y *string) { *x = 6; *y = "bye" }
87 |
88 | err := New().Then(provide_initial, verify_injected, modify_injected).Run()
89 | assert.NoError(t, err)
90 |
91 | assert.Equal(t, 6, a)
92 | assert.Equal(t, "bye", b)
93 | }
94 |
// TestSimpleErrorHandling checks that a failing handler is reported to the
// error handler registered most recently before it, and that Run itself
// returns no error once the failure has been handled.
func TestSimpleErrorHandling(t *testing.T) {
	var out string
	chain := New().Arg("").Arg(0).
		OnErr(func(err error) { out += "First error handler: " + err.Error() }).
		Then(
			func(val string) (string, error) {
				if val == "foo" {
					return "bar", nil
				} else {
					return "", fmt.Errorf("%q is not foo", val)
				}
			}).
		OnErr(func(err error) { out += "Second error handler: " + err.Error() }).
		Then(
			func(num int) error {
				if num != 3 {
					return fmt.Errorf("%d is not 3", num)
				}
				return nil
			})

	// First handler fails: handled by the first error handler.
	out = ""
	assert.NoError(t, chain.Run("", 0))
	assert.Equal(t, `First error handler: "" is not foo`, out)

	// Second handler fails: handled by the second error handler.
	out = ""
	assert.NoError(t, chain.Run("foo", 7))
	assert.Equal(t, `Second error handler: 7 is not 3`, out)

	// Nothing fails: no error handler runs.
	out = ""
	assert.NoError(t, chain.Run("foo", 3))
	assert.Equal(t, ``, out)
}
128 |
// TestMustProvideTypes checks that adding a handler (or error handler) whose
// argument types have not been provided earlier in the chain panics at
// construction time, and that fully provisioned chains do not.
func TestMustProvideTypes(t *testing.T) {
	assert.Panics(t, func() { New().Then(func(string) {}) })
	assert.NotPanics(t, func() { New().Set("").Then(func(string) {}) })

	assert.Panics(t, func() { New().Then(func(int, string) {}) })
	assert.NotPanics(t, func() { New().Set("").Set(3).Then(func(int, string) {}) })

	assert.NotPanics(t, func() {
		New().Then(
			func() string { return "" },
			func(string) int { return 3 },
			func(int) {},
		)
	}, "Should be OK: Everything is provided by earlier functions.")
	assert.NotPanics(t, func() {
		New().Then(
			func() string { return "" },
			func(string) int { return 3 },
		).OnErr(func(int, error) {})
	}, "Should be OK: Everything is provided by earlier functions")

	assert.Panics(t, func() {
		New().Then(
			func() string { return "" },
			func(string) int { return 3 },
			func(bool) {},
			func(int) {},
		)
	}, "Should FAIL: bool isn't provided anywhere")
	assert.Panics(t, func() {
		New().Then(func() string { return "" }, func(string) int { return 3 }).
			OnErr(func(bool, error) {}).
			Then(func(int) {})
	}, "Should FAIL: bool isn't provided anywhere (even error handlers need proper provisioning)")
}
164 |
// TestErrorAbortsHandling checks that handlers after the first failing one are
// not run and that the failure is passed to the error handler.
func TestErrorAbortsHandling(t *testing.T) {
	var out string
	err := New().OnErr(func(err error) { out += "Failed @ " + err.Error() }).Then(
		func() error { out += "1 "; return nil },
		func() error { out += "2 "; return fmt.Errorf("2") },
		func() error { out += "3 "; return nil },
	).Run()
	assert.NoError(t, err)
	assert.Equal(t, "1 2 Failed @ 2", out)
}
175 |
// a, b and c are named top-level handlers used by the tests below. They are
// package-level functions (not closures) because TestCatchesPanics asserts on
// their names ("chain.a", "chain.b", "chain.c") and signatures.
func a() string                { return "hello " }
func b(s string) (string, int) { return s + "world", 42 }
func c(s string, n int)        {}
179 |
180 | func TestCatchesPanics(t *testing.T) {
181 | var err error
182 | captureError := func(e error) { err = e }
183 | panics := func() { panic("ahhhh! 🔥") }
184 |
185 | assert.NoError(t,
186 | New().OnErr(captureError).Then(a, b, c).Defer(c).Then(panics).Run())
187 |
188 | assert.NotNil(t, err)
189 |
190 | e := err.(PanicError)
191 | assert.Equal(t, e.Val, "ahhhh! 🔥")
192 | assert.Equal(t, len(e.MiddlewareStack), 4) // defers haven't run yet.
193 | assert.Contains(t, e.MiddlewareStack[0].Name, "chain.TestCatchesPanics.func2")
194 | assert.Contains(t, e.MiddlewareStack[1].Name, "chain.c")
195 | assert.Contains(t, e.MiddlewareStack[2].Name, "chain.b")
196 | assert.Contains(t, e.MiddlewareStack[3].Name, "chain.a")
197 |
198 | assert.Contains(t, err.Error(), "Panic executing middleware")
199 | assert.Contains(t, err.Error(), "ahhhh! 🔥")
200 | // This is where the panic actually occurred. This will need to be updated if
201 | // this file changes, sadly.
202 | assert.Contains(t, err.Error(), "/home/aroman/code/sandwich/chain/chain_test.go:183")
203 | assert.Contains(t, err.Error(), "func() string")
204 | assert.Contains(t, err.Error(), "func(string) (string, int)")
205 | assert.Contains(t, err.Error(), "func(string, int)")
206 | assert.Contains(t, err.Error(), "chain.a")
207 | assert.Contains(t, err.Error(), "chain.b")
208 | assert.Contains(t, err.Error(), "chain.c")
209 | }
210 |
// TestDefersCanAcceptErrors checks that a deferred handler may take an error
// argument: it sees the failure when one occurred and a nil error otherwise.
func TestDefersCanAcceptErrors(t *testing.T) {
	var buf bytes.Buffer
	onerr := func(err error) { fmt.Fprintf(&buf, "onerr[%v]:", err) }
	deferred := func(err error) { fmt.Fprintf(&buf, "defer[%v]:", err) }
	fails := func() error { return errors.New("💣") }

	assert.NoError(t, New().
		OnErr(onerr).
		Then(a, b, c).
		Defer(deferred).
		Then(fails).
		Run())

	assert.Equal(t, "onerr[💣]:defer[💣]:", buf.String())

	// But what if nothing actually fails? Defer's can still accept errors.
	buf.Reset()
	assert.NoError(t, New().
		OnErr(onerr).
		Then(a, b, c).
		Defer(deferred).
		// Then(fails). // no failure!
		Run())

	assert.Equal(t, "defer[]:", buf.String())
}
237 |
// TestDefaultErrorHandler checks that a chain without its own OnErr handler
// reports failures to the package-level DefaultErrorHandler.
func TestDefaultErrorHandler(t *testing.T) {
	var buf bytes.Buffer
	onerr := func(err error) { fmt.Fprintf(&buf, "onerr[%v]:", err) }
	fails := func() error { return errors.New("☠") }

	// Restore the default error handler when we're done with the test.
	defer func(orig interface{}) { DefaultErrorHandler = orig }(DefaultErrorHandler)
	DefaultErrorHandler = onerr

	assert.NoError(t, New().Then(fails).Run())

	assert.Equal(t, "onerr[☠]:", buf.String())
}
251 |
// TestSetAs_Nil checks that SetAs accepts a nil value for an interface type
// and that handlers then receive a nil interface.
func TestSetAs_Nil(t *testing.T) {
	worked := false
	check := func(s fmt.Stringer) {
		require.Nil(t, s)
		worked = true
	}
	err := New().SetAs(nil, (*fmt.Stringer)(nil)).Then(check).Run()
	require.NoError(t, err)
	require.True(t, worked)
}
262 |
// TestProvidingBadValues checks that Set and SetAs panic on unusable
// arguments.
func TestProvidingBadValues(t *testing.T) {
	assert.Panics(t, func() { New().Set(nil) })

	// ifacePtr must be a pointer to an interface
	assert.Panics(t, func() { New().SetAs(nil, 5) })
	type Struct struct{}
	assert.Panics(t, func() { New().SetAs(nil, Struct{}) })
	assert.Panics(t, func() { New().SetAs(nil, &Struct{}) })

	// SetAs value must actually implement the specified interface
	assert.Panics(t, func() { New().SetAs(5, (*fmt.Stringer)(nil)) })
	assert.Panics(t, func() { New().SetAs(Struct{}, (*fmt.Stringer)(nil)) })
}
276 |
// TestWithBadValues checks that Then panics when given something that is not
// a function.
func TestWithBadValues(t *testing.T) {
	type Struct struct{}
	assert.Panics(t, func() { New().Then(nil) })
	assert.Panics(t, func() { New().Then(5) })
	assert.Panics(t, func() { New().Then(Struct{}) })
}
283 |
// TestBadErrorHandler checks that OnErr panics when given an invalid error
// handler.
func TestBadErrorHandler(t *testing.T) {
	// The error handler must actually be a function
	assert.Panics(t, func() { New().OnErr(true) })
	// The error handler may not return any values.
	returnsSomething := func(err error) bool { return true }
	assert.Panics(t, func() { New().OnErr(returnsSomething) })
	// The error handler can't take args of types that have not yet been
	// provided.
	takesAString := func(str string, err error) {}
	assert.Panics(t, func() { New().OnErr(takesAString) })
}
295 |
// TestBadDefer checks that Defer panics when given an invalid deferred
// handler.
func TestBadDefer(t *testing.T) {
	assert.Panics(t, func() { New().Defer(true) },
		"deferred func must actually be a function")

	returnsSomething := func(err error) bool { return true }
	assert.Panics(t, func() { New().Defer(returnsSomething) },
		"deferred func may not return any values")

	takesAString := func(str string) {}
	assert.Panics(t, func() { New().Defer(takesAString) },
		"deferred func arg types must have already been provided")
}
308 |
309 | func TestInterfaceConversionOnRun(t *testing.T) {
310 | chain := New().Arg((*fmt.Stringer)(nil))
311 |
312 | assert.Error(t, chain.Run(), "missing arg")
313 | assert.NoError(t, chain.Run(nil), "nil value is ok")
314 |
315 | var stringer Stringer
316 | assert.NoError(t, chain.Run(stringer), "implements stringer")
317 | assert.NoError(t, chain.Run(&stringer), "implements stringer")
318 |
319 | var ptrStringer PtrStringer
320 | assert.Error(t, chain.Run(ptrStringer), "does not implement stringer")
321 | assert.NoError(t, chain.Run(&ptrStringer), "implements stringer")
322 |
323 | var ptrToPtrStringer *PtrStringer
324 | assert.NoError(t, chain.Run(ptrToPtrStringer), "nil implements stringer")
325 |
326 | var nilStringer fmt.Stringer
327 | assert.NoError(t, chain.Run(nilStringer), "nil values are ok")
328 |
329 | chain = New().Arg(0)
330 | assert.Error(t, chain.Run(nil),
331 | "nil values are not ok for non-pointers and non-interfaces")
332 |
333 | type Struct struct{}
334 | chain = New().Arg(&Struct{})
335 | assert.NoError(t, chain.Run(nil), "nil values are ok for pointers to structs")
336 | assert.NoError(t, chain.Run(&Struct{}), "nil values are ok for pointers to structs")
337 | assert.Error(t, chain.Run(1), "ints don't match struct pointers")
338 | }
339 |
// Stringer implements fmt.Stringer with a value receiver, so both Stringer
// and *Stringer satisfy the interface.
type Stringer struct{}

// PtrStringer implements fmt.Stringer with a pointer receiver, so only
// *PtrStringer satisfies the interface.
type PtrStringer struct{}

func (Stringer) String() string     { return "yup" }
func (*PtrStringer) String() string { return "yup" }
345 |
// TestBadRunArgs checks that Run returns an error unless a value is supplied
// for every type reserved with Arg.
func TestBadRunArgs(t *testing.T) {
	chain := New().
		Arg(int(0)).
		Arg((*fmt.Stringer)(nil))

	assert.Error(t, chain.Run(), "not all args are specified")
	assert.Error(t, chain.Run(0), "not all args are specified")
	assert.NoError(t, chain.Run(0, nil), "all args are specified")
}
355 |
// TestRunArgsMustExactlyMatchSpecifiedArgs pins the exact error messages Run
// returns when its arguments are in the wrong order, too many, or too few
// relative to the types reserved with Arg.
func TestRunArgsMustExactlyMatchSpecifiedArgs(t *testing.T) {
	chain := New().
		Arg(int(0)).
		Arg("").
		Arg(true)

	// OK
	assert.NoError(t, chain.Run(0, "hi", true))

	// Wrong ordering
	assert.EqualError(t, chain.Run(true, "hi", 0),
		"bad arg: 1st arg of Run(...) should be a int but is bool")
	assert.EqualError(t, chain.Run(0, true, "hi"),
		"bad arg: 2nd arg of Run(...) should be a string but is bool")

	// Too many
	assert.EqualError(t, chain.Run(0, "hi", true, 'x', 0),
		"too many args: expected 3 args but got 5 args")
	// Not enough
	assert.EqualError(t, chain.Run(0, "hi"),
		"missing args of types: [bool]")
	assert.EqualError(t, chain.Run(0),
		"missing args of types: [string bool]")
}
380 |
// TestRunWithNilReservedInterface checks that passing nil to Run for a
// reserved interface type delivers a nil interface to handlers.
func TestRunWithNilReservedInterface(t *testing.T) {
	// Start non-nil so the final assertion proves the handler overwrote it.
	var capturedStringer fmt.Stringer = time.Now()
	chain := New().
		Arg((*fmt.Stringer)(nil)).
		Then(func(s fmt.Stringer) { capturedStringer = s })

	require.NoError(t, chain.Run(nil))
	assert.Nil(t, capturedStringer)
}
390 |
--------------------------------------------------------------------------------
/router.go:
--------------------------------------------------------------------------------
1 | package sandwich
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "net/http"
7 | "strings"
8 |
9 | "github.com/augustoroman/sandwich/chain"
10 | )
11 |
12 | // Router implements the sandwich middleware chaining and routing functionality.
// Router implements the sandwich middleware chaining and routing functionality.
type Router interface {
	// Set makes the given values available to all handlers subsequently
	// referenced. This is typically used for concrete values. For interfaces to
	// be correctly provided to subsequent middleware, use SetAs.
	Set(vals ...any)
	// SetAs sets a value as the specified interface that will be available to all
	// handlers.
	//
	// Example:
	//
	//	type DB interface { ... }
	//	var db DB = ...
	//	mux.SetAs(db, (*DB)(nil))
	//
	// That is functionally equivalent to using a middleware function that returns
	// the desired interface instance:
	//
	//	type DB interface { ... }
	//	var db DB = ...
	//	mux.Use(func() DB { return db })
	SetAs(val, ifacePtr any)

	// Use adds middleware to be invoked for all routes registered by the
	// returned Router. The current router is not affected. This is equivalent to
	// adding the specified middlewareHandlers to each registered route.
	//
	// NOTE(review): Use has no return value, so the "returned Router" wording
	// above looks stale -- confirm which routes are affected against the
	// implementation.
	Use(middlewareHandlers ...any)

	// On will register a handler for the given method and path.
	On(method, path string, handlers ...any)

	// Get registers handlers for the specified path for the 'GET' HTTP method.
	// Get is shorthand for `On("GET", ...)`.
	Get(path string, handlers ...any)
	// Put registers handlers for the specified path for the 'PUT' HTTP method.
	// Put is shorthand for `On("PUT", ...)`.
	Put(path string, handlers ...any)
	// Post registers handlers for the specified path for the 'POST' HTTP method.
	// Post is shorthand for `On("POST", ...)`.
	Post(path string, handlers ...any)
	// Patch registers handlers for the specified path for the 'PATCH' HTTP
	// method. Patch is shorthand for `On("PATCH", ...)`.
	Patch(path string, handlers ...any)
	// Delete registers handlers for the specified path for the 'DELETE' HTTP
	// method. Delete is shorthand for `On("DELETE", ...)`.
	Delete(path string, handlers ...any)
	// Any registers handlers for the specified path for any HTTP method. This
	// will always be superseded by dedicated method handlers. For example, if the
	// path '/users/:id/' is registered for Get, Put and Any, GET and PUT requests
	// will be handled by the Get(...) and Put(...) registrations, but DELETE,
	// CONNECT, or HEAD would be handled by the Any(...) registration. Any is a
	// shortcut for `On("*", ...)`.
	Any(path string, handlers ...any)

	// OnErr uses the specified error handler to handle any errors that occur on
	// any routes in this router.
	OnErr(handler any)

	// SubRouter derives a router that will be called for all suffixes (and
	// methods) for the specified path. For example,
	// `sub := root.SubRouter("/api")` will create a router that will handle
	// `/api/`, `/api/foo`.
	SubRouter(pathPrefix string) Router

	// ServeHTTP implements the http.Handler interface for the router.
	ServeHTTP(w http.ResponseWriter, r *http.Request)
}
76 |
77 | // BuildYourOwn returns a minimal router that has no initial middleware
78 | // handling.
79 | func BuildYourOwn() Router {
80 | r := &router{}
81 | r.base = r.base.Arg((*http.ResponseWriter)(nil))
82 | r.base = r.base.Arg((*http.Request)(nil))
83 | r.base = r.base.Arg((Params)(nil))
84 | return r
85 | }
86 |
87 | // TheUsual returns a router initialized with useful middleware.
88 | func TheUsual() Router {
89 | r := BuildYourOwn()
90 | r.Use(WrapResponseWriter, LogRequests)
91 | r.OnErr(HandleError)
92 | return r
93 | }
94 |
// router is the concrete Router implementation.
type router struct {
	// base is the chain of args, values and middleware that every route
	// registered on this router is built on top of.
	base chain.Func
	// subRouters maps a path prefix (always ending in "/") to the router that
	// handles every URI under that prefix.
	subRouters map[string]*router
	// byMethod maps an upper-case HTTP method to the routes registered for it.
	byMethod map[string]*mux
	// anyMethod holds routes registered for "*"; it is consulted only when
	// byMethod produced no match.
	anyMethod *mux
	// notFound, when non-nil, serves requests that match no route. When nil a
	// plain 404 response is written instead.
	notFound http.Handler
}
102 |
103 | func (r *router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
104 | params := Params{}
105 | h := r.match(req.Method, req.URL.Path, params)
106 | if h != nil {
107 | h.ServeHTTP(w, req, params)
108 | } else if r.notFound != nil {
109 | r.notFound.ServeHTTP(w, req)
110 | } else {
111 | http.Error(w, "Not found", http.StatusNotFound)
112 | }
113 | }
114 |
115 | func (r *router) SubRouter(prefix string) Router {
116 | if r.subRouters == nil {
117 | r.subRouters = map[string]*router{}
118 | }
119 | prefix = strings.TrimRight(prefix, "/") + "/"
120 | for existingPrefix := range r.subRouters {
121 | if existingPrefix == prefix || strings.HasPrefix(existingPrefix, prefix) || strings.HasPrefix(prefix, existingPrefix) {
122 | panic(fmt.Sprintf(
123 | "SubRouter with prefix %#q conflicts with existing SubRouter with prefix %#q",
124 | prefix, existingPrefix,
125 | ))
126 | }
127 | }
128 | r.subRouters[prefix] = &router{
129 | base: r.base,
130 | notFound: r.notFound,
131 | }
132 | return r.subRouters[prefix]
133 | }
134 |
135 | func (r *router) match(method, uri string, params Params) httpHandlerWithParams {
136 | method = strings.ToUpper(method)
137 | for prefix, sub := range r.subRouters {
138 | if strings.HasPrefix(uri, prefix) {
139 | return sub.match(method, strings.TrimPrefix(uri, prefix), params)
140 | }
141 | }
142 | if h := r.byMethod[method].Match(uri, params); h != nil {
143 | return h
144 | }
145 | if h := r.anyMethod.Match(uri, params); h != nil {
146 | return h
147 | }
148 | return nil
149 | }
150 |
151 | func (r *router) Set(vals ...any) {
152 | for _, val := range vals {
153 | r.base = r.base.Set(val)
154 | }
155 | }
156 |
// SetAs adds val to the router's base chain as the interface type that
// ifacePtr points to, making it available to routes registered afterwards.
func (r *router) SetAs(val, ifacePtr any) {
	r.base = r.base.SetAs(val, ifacePtr)
}
160 |
// Use extends the router's base chain with the given middleware handlers.
// Only routes registered after this call are built on the extended chain.
func (r *router) Use(middlewareHandlers ...any) {
	r.base = apply(r.base, middlewareHandlers...)
}
164 |
// OnErr registers errorHandler on the router's base chain, so that it handles
// errors for routes registered afterwards.
func (r *router) OnErr(errorHandler any) {
	r.base = r.base.OnErr(errorHandler)
}
168 |
169 | func (r *router) On(method, path string, handlers ...any) {
170 | method = strings.ToUpper(method)
171 | m := r.getOrAllocateMux(method)
172 | if err := m.Register(path, handler{apply(r.base, handlers...)}); err != nil {
173 | panic(fmt.Errorf("Cannot register route: %v", err))
174 | }
175 | }
176 |
// Method shorthands: each one forwards to On with a fixed method name. Any
// uses "*", which On treats as "every method without a dedicated route".
func (r *router) Any(path string, handlers ...any)    { r.On("*", path, handlers...) }
func (r *router) Get(path string, handlers ...any)    { r.On("GET", path, handlers...) }
func (r *router) Put(path string, handlers ...any)    { r.On("PUT", path, handlers...) }
func (r *router) Post(path string, handlers ...any)   { r.On("POST", path, handlers...) }
func (r *router) Patch(path string, handlers ...any)  { r.On("PATCH", path, handlers...) }
func (r *router) Delete(path string, handlers ...any) { r.On("DELETE", path, handlers...) }
183 |
184 | func (r *router) getOrAllocateMux(method string) *mux {
185 | if method == "*" {
186 | if r.anyMethod == nil {
187 | r.anyMethod = &mux{}
188 | }
189 | return r.anyMethod
190 | }
191 | if r.byMethod == nil {
192 | r.byMethod = map[string]*mux{}
193 | }
194 | m := r.byMethod[method]
195 | if m == nil {
196 | m = &mux{}
197 | r.byMethod[method] = m
198 | }
199 | return m
200 | }
201 |
// handler adapts a chain.Func to the httpHandlerWithParams interface.
type handler struct{ chain.Func }

// ServeHTTP runs the chain with the response writer, request and path params
// as its three Run arguments, in that order. MustRun panics if those
// arguments do not match the args declared on the chain.
func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request, p Params) {
	h.Func.MustRun(w, r, p)
}
207 |
// Params holds the path parameters captured while matching a request URI,
// keyed by parameter name (without the leading ':' or trailing '*').
type Params map[string]string

// mux is one node of the routing trie. A node represents the path consumed so
// far; its children are reached by consuming one more "/"-separated segment.
//
// Pattern segment syntax:
//
//	name    literal segment
//	::name  literal segment ":name" (escaped colon)
//	:name   parameter matching exactly one segment
//	:name*  greedy parameter matching one or more segments
type mux struct {
	// static maps a literal segment to the child node for that segment.
	static map[string]*mux
	// params lists the parameter children of this node in registration order.
	params []muxParam
	// handler serves a pattern that ends at this node; nil if none does.
	handler httpHandlerWithParams
}

// muxParam is a parameter edge from a mux node to the child node reached by
// consuming the parameter.
type muxParam struct {
	paramName string
	greedy    bool
	mux       *mux
}

// httpHandlerWithParams is an HTTP handler that additionally receives the
// path parameters captured by the router.
type httpHandlerWithParams interface {
	ServeHTTP(w http.ResponseWriter, r *http.Request, p Params)
}

// Register adds handler h for pattern, which must begin with "/". It returns
// an error if the pattern is malformed, repeats a parameter name, uses more
// than one greedy parameter, duplicates an existing route, or is ambiguous
// with an existing route.
func (m *mux) Register(pattern string, h httpHandlerWithParams) error {
	if !strings.HasPrefix(pattern, "/") {
		return errors.New("patterns must begin with /")
	}
	segments := strings.Split(pattern[1:], "/")
	reg := registerInfo{
		seenParams: map[string]bool{},
		seenGreedy: false,
	}
	if m.static == nil {
		m.static = map[string]*mux{}
	}
	if err := reg.registerSegments(m, segments, h); err != nil {
		return fmt.Errorf("%#q: bad pattern: %w", pattern, err)
	}
	return nil
}

// registerInfo carries the per-pattern state needed while one pattern is
// being registered.
type registerInfo struct {
	// seenParams records the parameter names used so far in the pattern.
	seenParams map[string]bool
	// seenGreedy is true once the pattern has used a greedy parameter.
	seenGreedy bool
}

// registerSegments registers the remaining pattern segments beneath m. With no
// segments left it installs h on m, failing if m already has a handler.
func (r *registerInfo) registerSegments(m *mux, segments []string, h httpHandlerWithParams) error {
	if len(segments) == 0 {
		if m.handler != nil {
			return fmt.Errorf("repeated entry")
		}
		m.handler = h
		return nil
	}
	next, remaining := segments[0], segments[1:]
	switch {
	case strings.HasPrefix(next, "::"):
		// "::x" escapes to the literal segment ":x".
		return r.registerStatic(m, next[1:], remaining, h)
	case strings.HasPrefix(next, ":"):
		return r.registerParam(m, next[1:], remaining, h)
	default:
		return r.registerStatic(m, next, remaining, h)
	}
}

// registerStatic descends through (or creates) the literal child path of m
// and registers the remaining segments there. A newly created child is only
// attached to m once the rest of the pattern registered successfully.
func (r *registerInfo) registerStatic(m *mux, path string, remaining []string, h httpHandlerWithParams) error {
	sub := m.static[path]
	if sub == nil {
		sub = &mux{
			static: map[string]*mux{},
		}
	}
	err := r.registerSegments(sub, remaining, h)
	if err == nil {
		m.static[path] = sub
	}
	return err
}

// registerParam descends through (or creates) the parameter child of m named
// by param (a trailing '*' marks it greedy) and registers the remaining
// segments there.
func (r *registerInfo) registerParam(m *mux, param string, remaining []string, h httpHandlerWithParams) error {
	greedy := strings.HasSuffix(param, "*")
	name := strings.TrimSuffix(param, "*")
	if greedy && r.seenGreedy {
		return fmt.Errorf("only one greedy param allowed per pattern: %#q", name)
	} else if r.seenParams[name] {
		return fmt.Errorf("param used twice: %#q", name)
	}
	// Check to see if the param already exists. E.g. we've already registered
	// param at this level via:
	//   /root/:param/path1  --> h1
	// and now we're registering:
	//   /root/:param/path2  --> h2
	for _, p := range m.params {
		if p.paramName == name {
			if p.greedy != greedy {
				return fmt.Errorf("param %#q is sometimes greedy and sometimes not", name)
			}
			// Record the param for this pattern even though its node already
			// exists; otherwise a later segment of the same pattern could
			// reuse the name (or add a second greedy param) undetected.
			r.seenParams[name] = true
			r.seenGreedy = r.seenGreedy || greedy
			return r.registerSegments(p.mux, remaining, h)
		}
		// If we haven't registered this one yet, then we need to avoid ambiguous
		// path registrations. For example:
		//   /root/:p1/path
		//   /root/:p2/path
		// should not be allowed, nor should:
		//   /root/:p1/:x/:y
		//   /root/:p2/:a/:b
		if err := p.mux.checkAmbiguous(remaining); err != nil {
			return fmt.Errorf("ambiguous route: %w", err)
		}
	}
	sub := &mux{
		static: map[string]*mux{},
	}
	r.seenParams[name] = true
	r.seenGreedy = r.seenGreedy || greedy
	err := r.registerSegments(sub, remaining, h)
	if err == nil {
		m.params = append(m.params, muxParam{
			paramName: name,
			greedy:    greedy,
			mux:       sub,
		})
	}
	return err
}

// checkAmbiguous reports an error if the given pattern segments, followed
// from m, end on a node that already has a handler. Literal segments follow
// the matching static child only; parameter segments follow every parameter
// child.
func (m *mux) checkAmbiguous(segments []string) error {
	if len(segments) == 0 {
		if m.handler != nil {
			return fmt.Errorf("ambiguous route")
		}
		return nil
	}
	static, isStatic, _, _ := entryToInfo(segments[0])
	if isStatic {
		if child := m.static[static]; child != nil {
			return child.checkAmbiguous(segments[1:])
		}
		return nil
	}
	for _, p := range m.params {
		if err := p.mux.checkAmbiguous(segments[1:]); err != nil {
			return err
		}
	}
	return nil
}

// entryToInfo classifies one pattern segment. For a literal segment it
// returns the literal text and isStatic=true; for a parameter it returns the
// parameter name and whether it is greedy.
func entryToInfo(entry string) (static string, isStatic bool, paramName string, greedy bool) {
	if strings.HasPrefix(entry, "::") {
		// double colon prefix escapes to single colon static path name.
		return entry[1:], true, "", false
	} else if !strings.HasPrefix(entry, ":") {
		return entry, true, "", false
	}
	paramName = strings.TrimSuffix(entry[1:], "*")
	greedy = strings.HasSuffix(entry, "*")
	return "", false, paramName, greedy
}

// Match returns the handler registered for uri, or nil if there is none.
// Captured parameters are written into params. Match is safe to call on a nil
// *mux, in which case it returns nil.
func (m *mux) Match(uri string, params Params) httpHandlerWithParams {
	uri = strings.TrimPrefix(uri, "/")
	return m.matchPrefix(strings.Split(uri, "/"), params)
}

// matchPrefix matches segments front-to-back starting at m. Static children
// are preferred over parameter children; parameter children are tried in
// registration order.
func (m *mux) matchPrefix(segments []string, params Params) httpHandlerWithParams {
	if m == nil {
		return nil
	}
	if len(segments) == 0 {
		return m.handler
	}
	path, remaining := segments[0], segments[1:]
	if sub := m.static[path]; sub != nil {
		if match := sub.matchPrefix(remaining, params); match != nil {
			return match
		}
	}
	for _, param := range m.params {
		if !param.greedy {
			if matched := param.mux.matchPrefix(remaining, params); matched != nil {
				params[param.paramName] = path
				return matched
			}
			continue
		}
		// A greedy param always consumes the current segment, plus however
		// many of the remaining segments the pattern's tail does not claim.
		matched, used := param.mux.matchSuffix(remaining, params)
		if matched != nil {
			N := len(segments)
			params[param.paramName] = strings.Join(segments[:N-used], "/")
			return matched
		}
	}
	return nil
}

// matchSuffix matches the tail of a pattern that follows a greedy parameter
// against the END of segments. It returns the handler found and depth, the
// number of trailing segments the pattern tail consumed; the caller assigns
// the rest to the greedy parameter. If no child matches, it falls back to m's
// own handler with depth 0.
//
// When several static children could match, Go's map iteration order decides
// which is tried first. Parameters of candidates that end up rejected may
// still have been written into params.
func (m *mux) matchSuffix(segments []string, params Params) (h httpHandlerWithParams, depth int) {
	N := len(segments)
	if N == 0 {
		return m.handler, 0
	}
	for staticPath, sub := range m.static {
		match, d := sub.matchSuffix(segments, params)
		if match == nil {
			continue
		}
		used := d + 1
		// The pattern tail may be longer than the segments that are left; in
		// that case this branch cannot match (and indexing would be out of
		// range).
		if used > N || segments[N-used] != staticPath {
			continue
		}
		return match, used
	}
	for _, param := range m.params {
		match, d := param.mux.matchSuffix(segments, params)
		if match == nil {
			continue
		}
		used := d + 1
		if used > N {
			continue
		}
		params[param.paramName] = segments[N-used] // TODO: might be rejected, might spam params
		return match, used
	}
	return m.handler, 0
}
434 |
--------------------------------------------------------------------------------
/chain/chain.go:
--------------------------------------------------------------------------------
1 | // Package chain is a reflection-based dependency-injected chain of functions
2 | // that powers the sandwich middleware framework.
3 | //
4 | // A Func chain represents a sequence of functions to call along with an initial
5 | // input. The parameters to each function are automatically provided from either
6 | // the initial inputs or return values of earlier functions in the sequence.
7 | //
8 | // In contrast to other dependency-injection frameworks, chain does not
9 | // automatically determine how to provide dependencies -- it merely uses the
10 | // most recently-provided value. This enables chains to report errors
11 | // immediately during the chain construction and, if successfully constructed, a
12 | // chain can always be executed.
13 | //
14 | // # HTTP Middleware Example
15 | //
16 | // As a common example, chains in http handling middleware typically start with
17 | // the http.ResponseWriter and *http.Request provided by the http framework:
18 | //
19 | // base := chain.Func{}.
20 | // Arg((*http.ResponseWriter)(nil)). // declared as an arg when Run
21 | // Arg((*http.Request)(nil)). // declared as an arg when Run
22 | //
23 | // Given the following functions:
24 | //
25 | // func GetDB() (*UserDB, error) {...}
26 | // func GetUserFromRequest(db *UserDB, req *http.Request) (*User, error) {...}
27 | // func SendUserAsJSON(w http.ResponseWriter, u *User) error {...}
28 | //
29 | // func GetUserID(r *http.Request) (UserID, error) {...}
30 | // func (db *UserDB) Lookup(UserID) (*User, error) { ... }
31 | //
32 | // func SendProjectAsJSON(w http.ResponseWriter, p *Project) error {...}
33 | //
34 | // then these chains would work fine:
35 | //
36 | // base.Then(
37 | // GetDB, // takes no args ✅, provides *UserDB to later funcs
38 | // GetUserFromRequest, // takes *UserDB ✅ and *Request ✅, provides *User
39 | // SendUserAsJSON, // takes ResponseWriter ✅ and *User ✅
40 | // )
41 | //
42 | // base.Then(
43 | // GetDB, // takes no args ✅, provides *UserDB to later funcs
44 | // GetUserID, // takes *Request ✅, provides UserID
45 | // (*UserDB).Lookup, // takes *UserDB ✅ and UserID ✅, provides *User
46 | // SendUserAsJSON, // takes ResponseWriter ✅ and *User ✅
47 | // )
48 | //
49 | // but these chains would fail:
50 | //
51 | // base.Then(
52 | // GetUserFromRequest, // takes *UserDB ❌ and *Request ✅
53 | // GetDB, // this *UserDB isn't available yet.
54 | // SendUserAsJSON, //
55 | // )
56 | //
57 | // base.Then(
58 | // GetDB, // takes no args ✅, provides *UserDB to later funcs
59 | // GetUserID, // takes *Request ✅, provides UserID
60 | // (*UserDB).Lookup, // takes *UserDB ✅ and UserID ✅, provides *User
61 | // SendProjectAsJSON, // takes ResponseWriter ✅ and *Project ❌
62 | // )
63 | //
64 | // base.Then(
65 | // GetUserFromRequest, // takes *UserDB ❌ and *Request ✅
66 | // SendUserAsJSON, //
67 | // )
68 | package chain
69 |
70 | import (
71 | "bytes"
72 | "fmt"
73 | "reflect"
74 | "runtime"
75 | "strings"
76 | "text/tabwriter"
77 | )
78 |
// errorType is the reflect.Type of the built-in error interface. It is the
// key under which the chain stores the current error value.
var errorType = reflect.TypeOf((*error)(nil)).Elem()

// DefaultErrorHandler is called when an error in the chain occurs and no error
// handler has been registered. Warning! The default error handler is not
// checked to verify that its arguments can be provided. It's STRONGLY
// recommended to keep this as absolutely simple as possible.
var DefaultErrorHandler interface{} = func(err error) { panic(err) }
86 |
// Func defines the chain of functions to invoke when Run. Each Func is
// immutable: all operations will return a new Func chain. The zero value is
// an empty chain.
type Func struct{ steps []step }
90 |
// step is a single value or handler in the middleware stack. Each step has a
// typ flag that indicates what kind of step it is.
type step struct {
	typ stepType
	// val is the immediate value for tVALUE steps and the function value for
	// handler steps. It is the zero reflect.Value for tARG steps.
	val reflect.Value
	// For tVALUE steps, this is the type the value is additionally provided
	// as: the interface type given to SetAs, or the value's own type for Set.
	// For tARG steps, this must be non-nil to declare the argument type.
	// For t*_HANDLER steps, this is the function type.
	valTyp reflect.Type
}
102 |
// stepType identifies what kind of entry a step is within a Func chain.
type stepType uint8

const (
	// tARG declares a type whose value is supplied as an argument to Run.
	tARG stepType = iota
	// tVALUE is an immediate value added via Set or SetAs.
	tVALUE
	tPRE_HANDLER  // PRE handlers are the normal handlers
	tPOST_HANDLER // POST handlers are deferred handlers
	// tERROR_HANDLER is an error handler registered via OnErr.
	tERROR_HANDLER
)
112 |
113 | // Clone this chain and add the extra steps to the clone.
114 | func (c Func) with(steps ...step) Func {
115 | s := make([]step, 0, len(c.steps)+len(steps))
116 | s = append(s, c.steps...)
117 | s = append(s, steps...)
118 | return Func{s}
119 | }
120 |
121 | // Arg indicates that a value with the specified type will be a parameter to Run
122 | // when the Func is invoked. This is typically necessary to start the chain for
123 | // a given middleware framework. Arg should not be exposed to users of sandwich
124 | // since it bypasses the causal checks and risks runtime errors.
125 | func (c Func) Arg(typeOrInterfacePtr interface{}) Func {
126 | typ := reflect.TypeOf(typeOrInterfacePtr)
127 | if typ.Kind() == reflect.Ptr && typ.Elem().Kind() == reflect.Interface {
128 | typ = typ.Elem()
129 | }
130 | return c.with(step{tARG, reflect.Value{}, typ})
131 | }
132 |
133 | // Set an immediate value. This cannot be used to provide an interface, instead
134 | // use SetAs(...) or With(...) with a function that returns the interface.
135 | func (c Func) Set(value interface{}) Func {
136 | if value == nil {
137 | panicf("Set(nil) is not allowed -- " +
138 | "did you mean to use SetAs(val, (*IFace)(nil))?")
139 | }
140 | return c.with(step{tVALUE, reflect.ValueOf(value), reflect.TypeOf(value)})
141 | }
142 |
143 | // SetAs provides an immediate value as the specified interface type.
144 | func (c Func) SetAs(value, ifacePtr interface{}) Func {
145 | val := reflect.ValueOf(value)
146 | typ := reflect.TypeOf(ifacePtr)
147 | if typ.Kind() != reflect.Ptr || typ.Elem().Kind() != reflect.Interface {
148 | panicf("ifacePtr must be a pointer to an interface for "+
149 | "SetAs, instead got %s", typ)
150 | }
151 | typ = typ.Elem()
152 | // It's ok to pass in a nil value here if you want the interface to actually
153 | // be nil.
154 | if !val.IsValid() {
155 | val = reflect.Zero(typ)
156 | }
157 | if !val.Type().Implements(typ) {
158 | panicf("%s doesn't implement %s", val.Type(), typ)
159 | }
160 | return c.with(step{tVALUE, val, typ})
161 | }
162 |
163 | // Compute what types are available from the reserved values, provide values,
164 | // and function return values of the current handler chain. This excludes
165 | // error handlers and deferred handlers.
166 | func (c Func) typesAvailable() map[reflect.Type]bool {
167 | m := map[reflect.Type]bool{}
168 | for _, s := range c.steps {
169 | switch s.typ {
170 | case tARG:
171 | m[s.valTyp] = true
172 | case tVALUE:
173 | m[s.val.Type()] = true
174 | m[s.valTyp] = true
175 | case tPRE_HANDLER:
176 | for i := 0; i < s.valTyp.NumOut(); i++ {
177 | m[s.valTyp.Out(i)] = true
178 | }
179 | case tPOST_HANDLER, tERROR_HANDLER:
180 | // ignored, we don't allow any return values for these.
181 | }
182 | }
183 | return m
184 | }
185 |
186 | // Then adds one or more handlers to the middleware chain. It may only accept
187 | // args of types that have already been provided.
188 | func (c Func) Then(handlers ...interface{}) Func {
189 | steps := make([]step, len(handlers))
190 | available := c.typesAvailable()
191 | for i, handler := range handlers {
192 | fn, err := valueOfFunction(handler)
193 | if err != nil {
194 | panicf("%s arg of With(...) %v", ordinalize(i+1), err)
195 | }
196 | if err := checkCanCall(available, fn); err != nil {
197 | panicf("%s arg of With(...) %v", ordinalize(i+1), err)
198 | }
199 | fnType := fn.Func.Type()
200 | steps[i] = step{tPRE_HANDLER, fn.Func, fnType}
201 | for i := 0; i < fnType.NumOut(); i++ {
202 | available[fnType.Out(i)] = true
203 | }
204 | }
205 | return c.with(steps...)
206 | }
207 |
208 | // OnErr registers an error handler to be called for failures of subsequent
209 | // handlers. It may only accept args of types that have already been provided.
210 | func (c Func) OnErr(errorHandler interface{}) Func {
211 | fn, err := valueOfFunction(errorHandler)
212 | if err != nil {
213 | panicf("Error handler %v", err)
214 | }
215 | available := c.typesAvailable()
216 | available[errorType] = true // Set internally by chain.
217 | if err := checkCanCall(available, fn); err != nil {
218 | panicf("Error handler %v", err)
219 | }
220 | if fn.Func.Type().NumOut() > 0 {
221 | panicf("Error handler %s may not have any return values, signature is %s",
222 | fn.Name, fn.Func.Type())
223 | }
224 | return c.with(step{tERROR_HANDLER, fn.Func, fn.Func.Type()})
225 | }
226 |
227 | // Defer adds a deferred handler to be executed after all normal handlers and
228 | // error handlers have been called. Deferred handlers are executed in reverse
229 | // order that they were registered (most recent first). Deferred handlers can
230 | // accept the error type even if it hasn't been explicitly provided yet. If no
231 | // error has occurred, it will be nil.
232 | func (c Func) Defer(handler interface{}) Func {
233 | fn, err := valueOfFunction(handler)
234 | if err != nil {
235 | panicf("Defer(...) arg %v", err)
236 | }
237 | available := c.typesAvailable()
238 | available[errorType] = true // Set internally by chain.
239 | if err := checkCanCall(available, fn); err != nil {
240 | panicf("Defer(...) arg %v", err)
241 | }
242 | if fn.Func.Type().NumOut() > 0 {
243 | panicf("Defer'd handler %s may not have any return values, signature is %s",
244 | fn.Name, fn.Func.Type())
245 | }
246 | return c.with(step{tPOST_HANDLER, fn.Func, fn.Func.Type()})
247 | }
248 |
249 | // MustRun will function chain with the provided args and panic if the args
250 | // don't match the expected arg values.
251 | func (c Func) MustRun(argValues ...interface{}) {
252 | // This will only ever return an error if the arguments to Run don't match.
253 | // Runtime failures of the functions in the chain are handled by the
254 | // registered error handlers (or the default error handler which may panic).
255 | if err := c.Run(argValues...); err != nil {
256 | panic(err)
257 | }
258 | }
259 |
// Run executes the function chain. All declared args must be provided in the
// order that they were declared. This will return an error only if the
// arguments do not correspond to the declared args: each argument must be
// convertible to its declared type, and nil is accepted only for interface
// and pointer types.
//
// Important note: The returned error is NOT related to whether any of the
// calls in the chain returns an error -- any errors returned by functions in
// the chain are handled by the registered error handlers.
func (c Func) Run(argValues ...interface{}) error {
	// data holds the most recently provided value for each type.
	data := map[reflect.Type]reflect.Value{}
	postSteps := []step{} // collect post steps here
	errHandler := step{   // Initialize using the default error handler.
		tERROR_HANDLER,
		reflect.ValueOf(DefaultErrorHandler),
		reflect.TypeOf(DefaultErrorHandler),
	}
	// stack records the handlers invoked so far, for panic reporting.
	stack := []step{}

	// 1: Apply all of the arguments to the available data. Make sure that the
	// provided arguments match the Arg calls, otherwise we bomb.
	if err := c.processRunArgs(data, argValues...); err != nil {
		return err
	}

	// Start executing the function chain. First pass through is the normal call
	// chain, so we skip execution of error handlers and deferred handlers,
	// although we keep track of them.
execution:
	for _, step := range c.steps {
		switch step.typ {
		case tARG:
			// ignored now, already handled during initialization above.
		case tVALUE:
			data[step.val.Type()] = step.val
			data[step.valTyp] = step.val
		case tPRE_HANDLER:
			c.call(step, data, &stack)
			// Check to see if there's an error. If so, abort the chain.
			if errorVal := data[errorType]; errorVal.IsValid() && !errorVal.IsNil() {
				break execution
			}
		case tPOST_HANDLER:
			// Only deferred handlers reached before a failure are collected.
			postSteps = append(postSteps, step)
		case tERROR_HANDLER:
			// The most recently seen error handler wins.
			errHandler = step
		}
	}

	// Execute the error handler if there is any error. Otherwise make a nil
	// error available so deferred handlers that take an error can be called.
	if errorVal := data[errorType]; errorVal.IsValid() && !errorVal.IsNil() {
		c.call(errHandler, data, &stack)
	} else {
		data[errorType] = reflect.Zero(errorType)
	}

	// Finally, call any deferred functions that we've gotten to, most recently
	// registered first.
	for i := len(postSteps) - 1; i >= 0; i-- {
		c.call(postSteps[i], data, &stack)
	}

	return nil
}
322 |
323 | func (c Func) processRunArgs(
324 | data map[reflect.Type]reflect.Value,
325 | argValues ...interface{},
326 | ) error {
327 | argIndex := 0
328 | expectedNumArgs := 0
329 | var missingArgs []string
330 | for _, step := range c.steps {
331 | if step.typ != tARG {
332 | continue
333 | }
334 | expectedNumArgs++
335 | if argIndex >= len(argValues) {
336 | missingArgs = append(missingArgs, step.valTyp.String())
337 | continue
338 | }
339 | val := argValues[argIndex]
340 | argIndex++
341 |
342 | if val == nil {
343 | if step.valTyp.Kind() == reflect.Interface || step.valTyp.Kind() == reflect.Ptr {
344 | data[step.valTyp] = reflect.New(step.valTyp).Elem()
345 | continue
346 | }
347 | return fmt.Errorf("bad arg: %s arg of Run(...) should be a %s but is %v",
348 | ordinalize(argIndex), step.valTyp, val)
349 | }
350 |
351 | rv := reflect.ValueOf(val)
352 | if !rv.CanConvert(step.valTyp) {
353 | return fmt.Errorf("bad arg: %s arg of Run(...) should be a %s but is %s",
354 | ordinalize(argIndex), step.valTyp, rv.Type())
355 | }
356 | data[step.valTyp] = rv.Convert(step.valTyp)
357 | }
358 | if len(missingArgs) > 0 {
359 | return fmt.Errorf("missing args of types: %s", missingArgs)
360 | }
361 | if argIndex != len(argValues) {
362 | return fmt.Errorf("too many args: expected %d args but got %d args",
363 | expectedNumArgs, len(argValues))
364 | }
365 | return nil
366 | }
367 |
// call invokes the function held by step s. Each parameter is looked up in
// data by its type, and each return value is stored back into data under its
// type. The step is appended to stack before the function runs. If the
// function panics, the panic is recovered and stored in data as the current
// error (a PanicError), so the caller sees it like any returned error.
func (c Func) call(s step, data map[reflect.Type]reflect.Value, stack *[]step) {
	t := s.valTyp
	in := make([]reflect.Value, t.NumIn())
	for i := range in {
		in[i] = data[t.In(i)]
		// This isn't supposed to happen if we've done all our checks right.
		// Note that this panic is raised before the recover below is
		// installed, so it propagates to the caller of Run.
		if !in[i].IsValid() {
			name := runtime.FuncForPC(s.val.Pointer()).Name()
			panicf("Cannot inject %s arg of type %s into %s (%s). Data: %v",
				ordinalize(i+1), t.In(i), name, t, data)
		}
	}
	defer func() {
		if err := c.wrapPanic(recover(), *stack); err != nil {
			// Store the error as a value of interface type error, not of its
			// concrete type, so that it is found under errorType.
			data[errorType] = reflect.ValueOf((*error)(&err)).Elem()
		}
	}()
	*stack = append(*stack, s)
	out := s.val.Call(in)
	for _, val := range out {
		data[val.Type()] = val
	}
}
391 |
392 | func (c Func) wrapPanic(x interface{}, steps []step) error {
393 | if x == nil {
394 | return nil
395 | }
396 | var stack [8192]byte
397 | n := runtime.Stack(stack[:], false)
398 |
399 | N := len(steps)
400 | mwStack := make([]FuncInfo, N)
401 | for i := range steps {
402 | step := steps[N-i-1]
403 | info := runtime.FuncForPC(step.val.Pointer())
404 | file, line := info.FileLine(step.val.Pointer())
405 | mwStack[i] = FuncInfo{info.Name(), file, line, step.val}
406 | }
407 |
408 | return PanicError{
409 | Val: x,
410 | RawStack: string(stack[:n]),
411 | MiddlewareStack: mwStack,
412 | }
413 | }
414 |
// PanicError is the error that is returned if a handler panics. It includes
// the panic'd value (Val), the raw Go stack trace (RawStack), and the
// middleware execution history (MiddlewareStack) that shows what middleware
// functions have already been called, most recent first.
type PanicError struct {
	Val             interface{}
	RawStack        string
	MiddlewareStack []FuncInfo
}

// FuncInfo describes a registered middleware function.
type FuncInfo struct {
	Name string // fully-qualified name, e.g.: github.com/foo/bar.FuncName
	File string
	Line int
	Func reflect.Value
}

// FilteredStack returns the stack trace without some internal chain.* functions
// and without reflect.Value.call stack frames, since these are generally just
// noise. The reflect.Value.call removal could affect user stack frames.
//
// Each frame in the raw trace is assumed to occupy two lines (the function
// line followed by its file:line), so dropping a frame also drops the line
// after it.
//
// TODO(aroman): Refine filtering so that it only removes reflect.Value.call
// frames due to sandwich.
func (p PanicError) FilteredStack() []string {
	lines := strings.Split(p.RawStack, "\n")
	var filtered []string
	for i := 0; i < len(lines); i++ {
		line := lines[i]
		if strings.HasPrefix(line, "github.com/augustoroman/sandwich/chain") &&
			!strings.HasPrefix(line, "github.com/augustoroman/sandwich/chain.Func.Run(") &&
			!strings.HasPrefix(line, "github.com/augustoroman/sandwich/chain.Test") {
			i++
			continue
		}
		if strings.HasPrefix(line, "reflect.Value.call") || strings.HasPrefix(line, "reflect.Value.Call") {
			i++
			continue
		}
		filtered = append(filtered, line)
	}
	return filtered
}

// Error formats the panic value together with the middleware history and the
// filtered stack trace. It is safe to call on a PanicError whose
// MiddlewareStack is empty or whose entries have no Func set; placeholders
// are printed for the missing information.
func (p PanicError) Error() string {
	// The panicking middleware is the most recent entry, when there is one.
	name := "<unknown>"
	if len(p.MiddlewareStack) > 0 {
		name = p.MiddlewareStack[0].Name
	}

	var mwStack bytes.Buffer
	w := tabwriter.NewWriter(&mwStack, 5, 7, 2, ' ', 0)
	for _, fn := range p.MiddlewareStack {
		// Calling Type on a zero reflect.Value panics, so guard it.
		sig := "<unknown>"
		if fn.Func.IsValid() {
			sig = fn.Func.Type().String()
		}
		fmt.Fprintf(w, " %s\t%s\n", fn.Name, sig)
	}
	w.Flush()
	return fmt.Sprintf(
		"Panic executing middleware %s: %v\n"+
			" Middleware executed:\n%s"+
			" Filtered call stack:\n %s",
		name, p.Val,
		mwStack.String(),
		strings.Join(p.FilteredStack(), "\n "))
}
474 |
--------------------------------------------------------------------------------