├── .github ├── ISSUE_TEMPLATE.md ├── PULL_REQUEST_TEMPLATE.md └── workflows │ └── build.yml ├── .gitignore ├── .whitesource ├── LICENSE ├── README.md ├── codecov.yml ├── context.go ├── context_test.go ├── examples ├── authentication │ ├── .gitignore │ ├── README.md │ ├── main.go │ ├── middleware.go │ ├── resource.go │ └── state.go ├── params │ └── main.go └── simple │ └── main.go ├── go.mod ├── handler.go ├── handler_test.go ├── params.go ├── params_test.go ├── routeparams.go ├── service.go ├── service_test.go ├── tree.go └── tree_test.go /.github/ISSUE_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | Before you file an issue, please consider: 2 | 3 | We only accept pull requests for minor fixes or improvements. This includes: 4 | 5 | * Small bug fixes 6 | * Typos 7 | * Documentation or comments 8 | 9 | Please open issues to discuss new features. Pull requests for new features will be rejected, 10 | so we recommend forking the repository and making changes in your fork for your use case. 11 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | Before you create a pull request, please consider: 2 | 3 | We only accept pull requests for minor fixes or improvements. This includes: 4 | 5 | * Small bug fixes 6 | * Typos 7 | * Documentation or comments 8 | 9 | Please open issues to discuss new features. Pull requests for new features will be rejected, 10 | so we recommend forking the repository and making changes in your fork for your use case. 
11 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: build 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | paths-ignore: 8 | - .github/** 9 | - .gitignore 10 | - .whitesource 11 | - codecov.yml 12 | - README.md 13 | pull_request: 14 | paths-ignore: 15 | - .github/** 16 | - .gitignore 17 | - .whitesource 18 | - codecov.yml 19 | - README.md 20 | 21 | jobs: 22 | build: 23 | runs-on: ubuntu-latest 24 | strategy: 25 | fail-fast: false 26 | matrix: 27 | include: 28 | - go: 1.15 29 | build-with: true 30 | - go: 1.16 31 | build-with: false 32 | continue-on-error: ${{ matrix.build-with == false }} 33 | name: Build with ${{ matrix.go }} 34 | env: 35 | GO111MODULE: on 36 | 37 | steps: 38 | - name: Set up Go 39 | uses: actions/setup-go@v1 40 | with: 41 | go-version: ${{ matrix.go }} 42 | 43 | - name: Checkout code 44 | uses: actions/checkout@v2 45 | 46 | - name: Vet 47 | run: go vet ./... 48 | 49 | - name: Test 50 | run: go test -vet=off -race -coverprofile=coverage.txt -covermode=atomic ./... 
51 | 52 | - name: Upload code coverage report 53 | if: matrix.build-with == true 54 | env: 55 | CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} 56 | run: bash <(curl -s https://raw.githubusercontent.com/VividCortex/codecov-bash/master/codecov) 57 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /coverage.txt -------------------------------------------------------------------------------- /.whitesource: -------------------------------------------------------------------------------- 1 | { 2 | "settingsInheritedFrom": "VividCortex/whitesource-config@master" 3 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2014 VividCortex 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # siesta 2 | 3 | [![GoDoc](https://godoc.org/github.com/VividCortex/siesta?status.svg)](https://godoc.org/github.com/VividCortex/siesta) 4 | ![build](https://github.com/VividCortex/siesta/workflows/build/badge.svg) 5 | [![codecov](https://codecov.io/gh/VividCortex/siesta/branch/master/graph/badge.svg)](https://codecov.io/gh/VividCortex/siesta) 6 | 7 | Siesta is a framework for writing composable HTTP handlers in Go. It supports typed URL parameters, middleware chains, and context passing. 8 | 9 | ## Getting started 10 | 11 | Siesta offers a `Service` type, which is a collection of middleware chains and handlers rooted at a base URI. There is no distinction between a middleware function and a handler function; they are all considered to be handlers and have access to the same arguments. 12 | 13 | Siesta accepts many types of handlers. Refer to the [GoDoc](https://godoc.org/github.com/VividCortex/siesta#Service.Route) documentation for `Service.Route` for more information. 14 | 15 | Here is the `simple` program in the examples directory. It demonstrates the use of a `Service`, routing, middleware, and a `Context`. 16 | 17 | ```go 18 | package main 19 | 20 | import ( 21 | "fmt" 22 | "log" 23 | "net/http" 24 | "time" 25 | 26 | "github.com/VividCortex/siesta" 27 | ) 28 | 29 | func main() { 30 | // Create a new Service rooted at "/" 31 | service := siesta.NewService("/") 32 | 33 | // Route accepts normal http.Handlers. 34 | // The arguments are the method, path, description, 35 | // and the handler. 
36 | service.Route("GET", "/", "Sends 'Hello, world!'", 37 | func(w http.ResponseWriter, r *http.Request) { 38 | fmt.Fprintln(w, "Hello, world!") 39 | }) 40 | 41 | // Let's create some simple "middleware." 42 | // This handler will accept a Context argument and will add the current 43 | // time to it. 44 | timestamper := func(c siesta.Context, w http.ResponseWriter, r *http.Request) { 45 | c.Set("start", time.Now()) 46 | } 47 | 48 | // This is the handler that will actually send data back to the client. 49 | // It also takes a Context argument so it can get the timestamp from the 50 | // previous handler. 51 | timeHandler := func(c siesta.Context, w http.ResponseWriter, r *http.Request) { 52 | start := c.Get("start").(time.Time) 53 | delta := time.Now().Sub(start) 54 | fmt.Fprintf(w, "That took %v.\n", delta) 55 | } 56 | 57 | // We can compose these handlers together. 58 | timeHandlers := siesta.Compose(timestamper, timeHandler) 59 | 60 | // Finally, we'll add the new handler we created using composition to a new route. 61 | service.Route("GET", "/time", "Sends how long it took to send a message", timeHandlers) 62 | 63 | // service is an http.Handler, so we can pass it directly to ListenAndServe. 64 | log.Fatal(http.ListenAndServe(":8080", service)) 65 | } 66 | ``` 67 | 68 | Siesta also provides utilities to manage URL parameters similar to the flag package. Refer to the `params` [example](https://github.com/VividCortex/siesta/blob/master/examples/params/main.go) for a demonstration. 69 | 70 | ## Contributing 71 | 72 | We only accept pull requests for minor fixes or improvements. This includes: 73 | 74 | * Small bug fixes 75 | * Typos 76 | * Documentation or comments 77 | 78 | Please open issues to discuss new features. Pull requests for new features will be rejected, 79 | so we recommend forking the repository and making changes in your fork for your use case. 80 | 81 | ## License 82 | 83 | Siesta is licensed under the MIT license. 
The router, which is adapted from [httprouter](https://github.com/julienschmidt/httprouter), is licensed [separately](https://github.com/VividCortex/siesta/blob/6ce42bf31875cc845310b1f4775129edfc8d9967/tree.go#L2-L24).
84 |
--------------------------------------------------------------------------------
/codecov.yml:
--------------------------------------------------------------------------------
1 | coverage:
2 |   status:
3 |     project:
4 |       default:
5 |         threshold: 15%
6 |     patch: off
7 |
--------------------------------------------------------------------------------
/context.go:
--------------------------------------------------------------------------------
1 | package siesta
2 |
3 | // Prepending nullByteStr avoids accidental context key collisions.
4 | const nullByteStr = "\x00"
5 |
6 | // UsageContextKey is a special context key to get the route usage information
7 | // within a handler.
8 | const UsageContextKey = nullByteStr + "usage"
9 |
10 | // Context is a context interface that gets passed to each ContextHandler.
11 | type Context interface {
12 | 	Set(string, interface{})
13 | 	Get(string) interface{}
14 | }
15 |
16 | // Compile-time checks that both implementations satisfy Context.
17 | var (
18 | 	_ Context = EmptyContext{}
19 | 	_ Context = SiestaContext{}
20 | )
21 |
22 | // EmptyContext is a blank context: Set discards its arguments and Get
23 | // always returns nil.
24 | type EmptyContext struct{}
25 |
26 | // Set discards key and value; an EmptyContext never stores anything.
27 | func (c EmptyContext) Set(key string, value interface{}) {
28 | }
29 |
30 | // Get always returns nil because an EmptyContext never stores anything.
31 | func (c EmptyContext) Get(key string) interface{} {
32 | 	return nil
33 | }
34 |
35 | // SiestaContext is a concrete implementation of the siesta.Context
36 | // interface. Typically this will be created by the siesta framework
37 | // itself upon each request. However creating your own SiestaContext
38 | // might be useful for testing to isolate the behavior of a single
39 | // handler.
40 | //
41 | // SiestaContext is a plain map: it must not be written concurrently with
42 | // other reads or writes, and Set on a nil SiestaContext panics, so create
43 | // one with NewSiestaContext.
44 | type SiestaContext map[string]interface{}
45 |
46 | // NewSiestaContext returns an empty SiestaContext ready for use.
47 | func NewSiestaContext() SiestaContext {
48 | 	return SiestaContext{}
49 | }
50 |
51 | // Set stores value under key, replacing any previous value.
52 | func (c SiestaContext) Set(key string, value interface{}) {
53 | 	c[key] = value
54 | }
55 |
56 | // Get returns the value stored under key, or nil if there is none.
57 | func (c SiestaContext) Get(key string) interface{} {
58 | 	return c[key]
59 | }
60 |
--------------------------------------------------------------------------------
/context_test.go:
--------------------------------------------------------------------------------
1 | package siesta
2 |
3 | import (
4 | 	"testing"
5 | )
6 |
7 | func TestContext(t *testing.T) {
8 | 	var c Context = NewSiestaContext()
9 | 	c.Set("foo", "bar")
10 | 	v := c.Get("foo")
11 | 	if v == nil {
12 | 		t.Fatal("expected to see a value for key `foo`")
13 | 	}
14 |
15 | 	// A checked assertion makes an unexpected type fail the test instead of panicking.
16 | 	if s, ok := v.(string); !ok || s != "bar" {
17 | 		t.Errorf("expected value %v, got %v", "bar", v)
18 | 	}
19 | }
20 |
21 | func TestEmptyContext(t *testing.T) {
22 | 	var c Context = EmptyContext{}
23 | 	c.Set("foo", "bar")
24 | 	v := c.Get("foo")
25 | 	if v != nil {
26 | 		t.Fatal("expected to not see a value for key `foo`")
27 | 	}
28 | }
29 |
--------------------------------------------------------------------------------
/examples/authentication/.gitignore:
--------------------------------------------------------------------------------
1 | authentication
2 |
--------------------------------------------------------------------------------
/examples/authentication/README.md:
--------------------------------------------------------------------------------
1 | Authentication example
2 | ===
3 | This program demonstrates the use of Siesta's contexts and middleware chaining to handle authentication. In addition, there are also other features like request identification and logging that are extremely useful in practice.
4 | 5 | Suppose we have some state with the following data: 6 | 7 | | Token | User | 8 | | ----- | ---- | 9 | | abcde | alice | 10 | | 12345 | bob | 11 | 12 | | User | Resource ID | Value | 13 | | ---- | ----------- | ----- | 14 | | alice | 1 | foo | 15 | | alice | 2 | bar | 16 | | bob | 3 | baz | 17 | 18 | Users of the API have to supply a valid token to be able to access the secured resources that they are assigned to. 19 | 20 | There is a single endpoint: `GET /resources/:resourceID` 21 | 22 | The token will be provided by the user for every request as the HTTP basic authentication username. This is similar to [Stripe](https://stripe.com/docs/api#authentication)'s API authentication. 23 | 24 | Example requests 25 | --- 26 | ``` 27 | $ curl -i localhost:8080 28 | HTTP/1.1 401 Unauthorized 29 | X-Request-Id: 4d65822107fcfd52 30 | Date: Wed, 10 Jun 2015 13:03:36 GMT 31 | Content-Length: 27 32 | Content-Type: text/plain; charset=utf-8 33 | 34 | {"error":"token required"} 35 | ``` 36 | 37 | ``` 38 | $ curl -i localhost:8080/resources/1 -u abcde: 39 | HTTP/1.1 200 OK 40 | Content-Type: application/json 41 | X-Request-Id: 55104dc76695721d 42 | Date: Wed, 10 Jun 2015 13:04:23 GMT 43 | Content-Length: 15 44 | 45 | {"data":"foo"} 46 | ``` 47 | 48 | ``` 49 | $ curl -i localhost:8080/resources/3 -u 12345: 50 | HTTP/1.1 200 OK 51 | Content-Type: application/json 52 | X-Request-Id: 380704bb7b4d7c03 53 | Date: Wed, 10 Jun 2015 13:05:07 GMT 54 | Content-Length: 15 55 | 56 | {"data":"baz"} 57 | ``` 58 | 59 | ``` 60 | $ curl -i localhost:8080/resources/2 -u 12345: 61 | HTTP/1.1 404 Not Found 62 | X-Request-Id: 365a858149c6e2d1 63 | Date: Wed, 10 Jun 2015 13:05:28 GMT 64 | Content-Length: 22 65 | Content-Type: text/plain; charset=utf-8 66 | 67 | {"error":"not found"} 68 | ``` 69 | 70 | Logging 71 | --- 72 | You'll notice that the server supplies an `X-Request-Id` header. This ID is generated for every request and is provided in the log output.
73 |
74 | ```
75 | $ ./authentication
76 | 2015/06/10 09:03:24 Listening on :8080
77 | 2015/06/10 09:03:36 [Req 4d65822107fcfd52] GET /
78 | 2015/06/10 09:03:36 [Req 4d65822107fcfd52] Did not provide a token
79 | 2015/06/10 09:04:19 [Req 78629a0f5f3f164f] GET /resources/1
80 | 2015/06/10 09:04:19 [Req 78629a0f5f3f164f] Provided a token for: bob
81 | 2015/06/10 09:04:23 [Req 55104dc76695721d] GET /resources/1
82 | 2015/06/10 09:04:23 [Req 55104dc76695721d] Provided a token for: alice
83 | 2015/06/10 09:05:07 [Req 380704bb7b4d7c03] GET /resources/3
84 | 2015/06/10 09:05:07 [Req 380704bb7b4d7c03] Provided a token for: bob
85 | 2015/06/10 09:05:28 [Req 365a858149c6e2d1] GET /resources/2
86 | 2015/06/10 09:05:28 [Req 365a858149c6e2d1] Provided a token for: bob
87 | ```
88 |
--------------------------------------------------------------------------------
/examples/authentication/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"log"
5 | 	"net/http"
6 |
7 | 	"github.com/VividCortex/siesta"
8 | )
9 |
10 | func main() {
11 | 	// Create a new service rooted at /.
12 | 	service := siesta.NewService("/")
13 |
14 | 	// requestIdentifier assigns an ID to every request
15 | 	// and adds it to the context for that request.
16 | 	// This is useful for logging.
17 | 	service.AddPre(requestIdentifier)
18 |
19 | 	// Add access to the state via the context in every handler.
20 | 	service.AddPre(func(c siesta.Context, w http.ResponseWriter, r *http.Request) {
21 | 		c.Set("db", state)
22 | 	})
23 |
24 | 	// We'll add the authenticator middleware to the "pre" chain.
25 | 	// It will ensure that every request has a valid token.
26 | 	service.AddPre(authenticator)
27 |
28 | 	// Response generation
29 | 	service.AddPost(responseGenerator)
30 | 	service.AddPost(responseWriter)
31 |
32 | 	// Custom 404 handler
33 | 	service.SetNotFound(func(c siesta.Context, w http.ResponseWriter, r *http.Request) {
34 | 		c.Set("status-code", http.StatusNotFound)
35 | 		c.Set("error", "not found")
36 | 	})
37 |
38 | 	// Routes
39 | 	service.Route("GET", "/resources/:resourceID", "Retrieves a resource",
40 | 		getResource)
41 |
42 | 	log.Println("Listening on :8080")
43 | 	// ListenAndServe only returns on failure; log.Fatal reports it and exits without a panic trace.
44 | 	log.Fatal(http.ListenAndServe(":8080", service))
45 | }
46 |
--------------------------------------------------------------------------------
/examples/authentication/middleware.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"encoding/json"
5 | 	"fmt"
6 | 	"log"
7 | 	"math/rand"
8 | 	"net/http"
9 |
10 | 	"github.com/VividCortex/siesta"
11 | )
12 |
13 | // apiResponse defines the structure of the responses.
14 | type apiResponse struct {
15 | 	Data  interface{} `json:"data,omitempty"`
16 | 	Error string      `json:"error,omitempty"`
17 | }
18 |
19 | // requestIdentifier generates a request ID and sets the "request-id"
20 | // key in the context. It also logs the request ID and the requested URL.
21 | func requestIdentifier(c siesta.Context, w http.ResponseWriter, r *http.Request) {
22 | 	requestID := fmt.Sprintf("%x", rand.Int())
23 | 	c.Set("request-id", requestID)
24 | 	log.Printf("[Req %s] %s %s", requestID, r.Method, r.URL)
25 | }
26 |
27 | // authenticator validates the token sent as the HTTP basic authentication
28 | // username and sets the "user" key in the context to the token's user. A missing
29 | // or unknown token sets "error" and a 401 "status-code" and quits the chain.
30 | func authenticator(c siesta.Context, w http.ResponseWriter, r *http.Request,
31 | 	quit func()) {
32 | 	// Context variables
33 | 	requestID := c.Get("request-id").(string)
34 | 	db := c.Get("db").(*DB)
35 |
36 | 	// Check for a token in the HTTP basic authentication username field.
37 | 	token, _, ok := r.BasicAuth()
38 | 	if !ok {
39 | 		log.Printf("[Req %s] Did not provide a token", requestID)
40 |
41 | 		c.Set("error", "token required")
42 | 		c.Set("status-code", http.StatusUnauthorized)
43 |
44 | 		// Exit the chain here.
45 | 		quit()
46 | 		return
47 | 	}
48 |
49 | 	user, err := db.validateToken(token)
50 | 	if err != nil {
51 | 		log.Printf("[Req %s] Did not provide a valid token", requestID)
52 | 		c.Set("status-code", http.StatusUnauthorized)
53 | 		c.Set("error", "invalid token")
54 |
55 | 		// Exit the chain here.
56 | 		quit()
57 | 		return
58 | 	}
59 |
60 | 	log.Printf("[Req %s] Provided a token for: %s", requestID, user)
61 |
62 | 	// Add the user to the context.
63 | 	c.Set("user", user)
64 | }
65 |
66 | // responseGenerator converts response and/or error data passed through the
67 | // context into a structured response.
68 | func responseGenerator(c siesta.Context, w http.ResponseWriter, r *http.Request) {
69 | 	response := apiResponse{}
70 |
71 | 	if data := c.Get("data"); data != nil {
72 | 		response.Data = data
73 | 	}
74 |
75 | 	if e := c.Get("error"); e != nil {
76 | 		// Handlers in this example store strings under "error"; fmt.Sprint
77 | 		// returns those unchanged and turns any other value (such as an error)
78 | 		// into text instead of panicking on a failed type assertion.
79 | 		response.Error = fmt.Sprint(e)
80 | 	}
81 |
82 | 	c.Set("response", response)
83 | }
84 |
85 | // responseWriter sets the proper headers and status code, and
86 | // writes a JSON-encoded response to the client.
87 | func responseWriter(c siesta.Context, w http.ResponseWriter, r *http.Request,
88 | 	quit func()) {
89 | 	// Set the request ID header.
90 | 	requestID, hasID := c.Get("request-id").(string)
91 | 	if hasID {
92 | 		w.Header().Set("X-Request-ID", requestID)
93 | 	}
94 |
95 | 	// Set the content type. This must happen before WriteHeader below:
96 | 	// header changes made after the status code is written have no effect.
97 | 	w.Header().Set("Content-Type", "application/json")
98 |
99 | 	// If we have a status code set in the context,
100 | 	// send that in the header.
101 | 	//
102 | 	// Go defaults to 200 OK.
103 | 	if statusCode, ok := c.Get("status-code").(int); ok {
104 | 		w.WriteHeader(statusCode)
105 | 	}
106 |
107 | 	// Check to see if we have some sort of response.
108 | 	if response := c.Get("response"); response != nil {
109 | 		// We'll encode it as JSON without knowing what it exactly is.
110 | 		// An encoding failure can no longer be reported to the client as
111 | 		// JSON at this point, so it is logged instead of silently dropped.
112 | 		if err := json.NewEncoder(w).Encode(response); err != nil {
113 | 			log.Printf("[Req %s] Could not encode response: %v", requestID, err)
114 | 		}
115 | 	}
116 |
117 | 	// We're at the end of the middleware chain, so quit.
118 | 	quit()
119 | }
120 |
--------------------------------------------------------------------------------
/examples/authentication/resource.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"log"
5 | 	"net/http"
6 |
7 | 	"github.com/VividCortex/siesta"
8 | )
9 |
10 | // getResource is the function that handles the GET /resources/:resourceID route.
11 | func getResource(c siesta.Context, w http.ResponseWriter, r *http.Request) {
12 | 	// Context variables
13 | 	requestID := c.Get("request-id").(string)
14 | 	db := c.Get("db").(*DB)
15 | 	user := c.Get("user").(string)
16 |
17 | 	// Check parameters
18 | 	var params siesta.Params
19 | 	resourceID := params.Int("resourceID", -1, "Resource identifier")
20 | 	err := params.Parse(r.Form)
21 | 	if err != nil {
22 | 		log.Printf("[Req %s] %v", requestID, err)
23 | 		c.Set("error", err.Error())
24 | 		c.Set("status-code", http.StatusBadRequest)
25 | 		return
26 | 	}
27 |
28 | 	// Make sure we have a valid resource ID.
29 | 	if *resourceID == -1 {
30 | 		c.Set("error", "invalid or missing resource ID")
31 | 		c.Set("status-code", http.StatusBadRequest)
32 | 		return
33 | 	}
34 |
35 | 	resource, err := db.resource(user, *resourceID)
36 | 	if err != nil {
37 | 		c.Set("status-code", http.StatusNotFound)
38 | 		c.Set("error", "not found")
39 | 		return
40 | 	}
41 |
42 | 	c.Set("data", resource)
43 | }
44 |
--------------------------------------------------------------------------------
/examples/authentication/state.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"errors"
5 | )
6 |
7 | var (
8 | 	ErrInvalidToken     = errors.New("invalid token")
9 | 	ErrResourceNotFound = errors.New("resource not found")
10 | )
11 |
12 | // DB represents a handler for some sort of state.
13 | type DB struct {
14 | 	tokenUsers    map[string]string
15 | 	userResources map[string]map[int]string
16 | }
17 |
18 | // state contains some actual state.
19 | var state = &DB{
20 | 	tokenUsers: map[string]string{
21 | 		"abcde": "alice",
22 | 		"12345": "bob",
23 | 	},
24 |
25 | 	userResources: map[string]map[int]string{
26 | 		"alice": {
27 | 			1: "foo",
28 | 			2: "bar",
29 | 		},
30 | 		"bob": {
31 | 			3: "baz",
32 | 		},
33 | 	},
34 | }
35 |
36 | // validateToken returns the user corresponding to the token.
37 | // An error is returned if the token is not recognized.
38 | func (db *DB) validateToken(token string) (string, error) {
39 | 	user, ok := db.tokenUsers[token]
40 | 	if !ok {
41 | 		return "", ErrInvalidToken
42 | 	}
43 |
44 | 	return user, nil
45 | }
46 |
47 | // resource returns the resource with id for user.
48 | // An error is returned if the resource is not found.
49 | func (db *DB) resource(user string, id int) (string, error) {
50 | 	resources, ok := db.userResources[user]
51 | 	if !ok {
52 | 		return "", ErrResourceNotFound
53 | 	}
54 |
55 | 	resource, ok := resources[id]
56 | 	if !ok {
57 | 		return "", ErrResourceNotFound
58 | 	}
59 |
60 | 	return resource, nil
61 | }
62 |
--------------------------------------------------------------------------------
/examples/params/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"fmt"
5 | 	"log"
6 | 	"math"
7 | 	"net/http"
8 |
9 | 	"github.com/VividCortex/siesta"
10 | )
11 |
12 | func main() {
13 | 	// Create a new Service rooted at "/"
14 | 	service := siesta.NewService("/")
15 |
16 | 	// Here's a handler that uses a URL parameter.
17 | 	// Example: GET /greet/Bob
18 | 	service.Route("GET", "/greet/:name", "Greets with a name.",
19 | 		func(w http.ResponseWriter, r *http.Request) {
20 | 			var params siesta.Params
21 | 			name := params.String("name", "", "Person's name")
22 |
23 | 			err := params.Parse(r.Form)
24 | 			if err != nil {
25 | 				log.Println("Error parsing parameters!", err)
26 | 				// Report bad input to the client instead of an empty 200 OK.
27 | 				http.Error(w, err.Error(), http.StatusBadRequest)
28 | 				return
29 | 			}
30 |
31 | 			fmt.Fprintf(w, "Hello, %s!", *name)
32 | 		},
33 | 	)
34 |
35 | 	// Here's a handler that uses a query string parameter.
36 | 	// Example: GET /square?number=10
37 | 	service.Route("GET", "/square", "Prints the square of a number.",
38 | 		func(w http.ResponseWriter, r *http.Request) {
39 | 			var params siesta.Params
40 | 			number := params.Int("number", 0, "A number to square")
41 |
42 | 			err := params.Parse(r.Form)
43 | 			if err != nil {
44 | 				log.Println("Error parsing parameters!", err)
45 | 				http.Error(w, err.Error(), http.StatusBadRequest)
46 | 				return
47 | 			}
48 |
49 | 			fmt.Fprintf(w, "%d * %d = %d.", *number, *number, (*number)*(*number))
50 | 		},
51 | 	)
52 |
53 | 	// We can also use both URL and query string parameters.
54 | 	// Example: GET /exponentiate/10?power=10
55 | 	service.Route("GET", "/exponentiate/:number", "Exponentiates a number.",
56 | 		func(w http.ResponseWriter, r *http.Request) {
57 | 			var params siesta.Params
58 | 			number := params.Float64("number", 0, "A number to exponentiate")
59 | 			power := params.Float64("power", 1, "Power")
60 |
61 | 			err := params.Parse(r.Form)
62 | 			if err != nil {
63 | 				log.Println("Error parsing parameters!", err)
64 | 				http.Error(w, err.Error(), http.StatusBadRequest)
65 | 				return
66 | 			}
67 |
68 | 			fmt.Fprintf(w, "%g ^ %g = %g.", *number, *power, math.Pow(*number, *power))
69 | 		},
70 | 	)
71 |
72 | 	// service is an http.Handler, so we can pass it directly to ListenAndServe.
73 | 	log.Fatal(http.ListenAndServe(":8080", service))
74 | }
75 |
--------------------------------------------------------------------------------
/examples/simple/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"fmt"
5 | 	"log"
6 | 	"net/http"
7 | 	"time"
8 |
9 | 	"github.com/VividCortex/siesta"
10 | )
11 |
12 | func main() {
13 | 	// Create a new Service rooted at "/"
14 | 	service := siesta.NewService("/")
15 |
16 | 	// Route accepts normal http.Handlers.
17 | 	service.Route("GET", "/", "Sends 'Hello, world!'", func(w http.ResponseWriter, r *http.Request) {
18 | 		fmt.Fprintln(w, "Hello, world!")
19 | 	})
20 |
21 | 	// Let's create some simple "middleware."
22 | 	// This handler will accept a Context argument and will add the current
23 | 	// time to it.
24 | 	timestamper := func(c siesta.Context, w http.ResponseWriter, r *http.Request) {
25 | 		c.Set("start", time.Now())
26 | 	}
27 |
28 | 	// This is the handler that will actually send data back to the client.
29 | 	// It also takes a Context argument so it can get the timestamp from the
30 | 	// previous handler.
31 | 	timeHandler := func(c siesta.Context, w http.ResponseWriter, r *http.Request) {
32 | 		start := c.Get("start").(time.Time)
33 | 		delta := time.Now().Sub(start)
34 | 		fmt.Fprintf(w, "That took %v.\n", delta)
35 | 	}
36 |
37 | 	// We can compose these handlers together.
38 | 	timeHandlers := siesta.Compose(timestamper, timeHandler)
39 |
40 | 	// Finally, we'll add the new handler we created using composition to a new route.
41 | 	service.Route("GET", "/time", "Sends how long it took to send a message", timeHandlers)
42 |
43 | 	// service is an http.Handler, so we can pass it directly to ListenAndServe.
44 | 	log.Fatal(http.ListenAndServe(":8080", service))
45 | }
46 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/VividCortex/siesta
2 |
3 | go 1.12
4 |
--------------------------------------------------------------------------------
/handler.go:
--------------------------------------------------------------------------------
1 | package siesta
2 |
3 | import (
4 | 	"errors"
5 | 	"net/http"
6 | )
7 |
8 | var ErrUnsupportedHandler = errors.New("siesta: unsupported handler")
9 |
10 | // ContextHandler is a siesta handler. Its final func() argument is the quit function.
11 | type ContextHandler func(Context, http.ResponseWriter, *http.Request, func())
12 |
13 | func (h ContextHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
14 | 	h(EmptyContext{}, w, r, func() {}) // no-op quit: handlers may call quit, so it must not be nil
15 | }
16 |
17 | func (h ContextHandler) ServeHTTPInContext(c Context, w http.ResponseWriter, r *http.Request) {
18 | 	h(c, w, r, func() {}) // no-op quit: handlers may call quit, so it must not be nil
19 | }
20 |
21 | // ToContextHandler transforms f into a ContextHandler.
22 | // f must be a function with one of the following signatures:
23 | //
24 | //	func(http.ResponseWriter, *http.Request)
25 | //	func(http.ResponseWriter, *http.Request, func())
26 | //	func(Context, http.ResponseWriter, *http.Request)
27 | //	func(Context, http.ResponseWriter, *http.Request, func())
28 | //
29 | // f may also be a ContextHandler or an http.Handler; any other value makes
30 | // ToContextHandler panic with ErrUnsupportedHandler. Adapted handlers whose
31 | // signature has no func() parameter never see the quit function.
32 | func ToContextHandler(f interface{}) ContextHandler {
33 | 	switch t := f.(type) {
34 | 	case func(Context, http.ResponseWriter, *http.Request, func()):
35 | 		return ContextHandler(t)
36 | 	case ContextHandler:
37 | 		return t
38 | 	case func(Context, http.ResponseWriter, *http.Request):
39 | 		return func(c Context, w http.ResponseWriter, r *http.Request, q func()) {
40 | 			t(c, w, r)
41 | 		}
42 | 	case func(http.ResponseWriter, *http.Request, func()):
43 | 		return func(c Context, w http.ResponseWriter, r *http.Request, q func()) {
44 | 			t(w, r, q)
45 | 		}
46 | 	case func(http.ResponseWriter, *http.Request):
47 | 		return func(c Context, w http.ResponseWriter, r *http.Request, q func()) {
48 | 			t(w, r)
49 | 		}
50 | 	case http.Handler:
51 | 		// Also covers http.HandlerFunc: as a defined type it does not match
52 | 		// the unnamed func(http.ResponseWriter, *http.Request) case above.
53 | 		return func(c Context, w http.ResponseWriter, r *http.Request, q func()) {
54 | 			t.ServeHTTP(w, r)
55 | 		}
56 | 	default:
57 | 		panic(ErrUnsupportedHandler)
58 | 	}
59 | }
60 |
61 | // Compose composes multiple ContextHandlers into a single ContextHandler.
62 | // Each element of stack is converted with ToContextHandler, so Compose panics
63 | // with ErrUnsupportedHandler on an unsupported value. The composed handler runs
64 | // the handlers in order and stops after the first one that calls its quit
65 | // function; it then calls its own quit function, if it was given one.
66 | func Compose(stack ...interface{}) ContextHandler {
67 | 	contextStack := make([]ContextHandler, 0, len(stack))
68 | 	for _, h := range stack {
69 | 		contextStack = append(contextStack, ToContextHandler(h))
70 | 	}
71 |
72 | 	return func(c Context, w http.ResponseWriter, r *http.Request, quit func()) {
73 | 		quitStack := false
74 | 		stop := func() { quitStack = true }
75 |
76 | 		for _, m := range contextStack {
77 | 			m(c, w, r, stop)
78 |
79 | 			if quitStack {
80 | 				// quit may be nil (e.g. when the caller supplies none), and calling
81 | 				// a nil func panics, so only propagate the quit when one exists.
82 | 				if quit != nil {
83 | 					quit()
84 | 				}
85 | 				break
86 | 			}
87 | 		}
88 | 	}
89 | }
90 |
--------------------------------------------------------------------------------
/handler_test.go:
--------------------------------------------------------------------------------
1 | package siesta
2 |
3 | import (
4 | 	"net/http"
5 | 	"net/http/httptest"
6 | 	"testing"
7 | )
8 |
9 | func TestCompose(t *testing.T) {
10 | 	key := "i"
11 | 	stack := Compose(
12 | 		func(c Context, w http.ResponseWriter, r *http.Request, quit func()) {
13 | 			r.Header.Set(key, r.Header.Get(key)+"a")
14 | 			i, _ := c.Get(key).(int)
15 | 			c.Set(key, i+2)
16 | 		},
17 | 		func(c Context, w http.ResponseWriter, r *http.Request) {
18 | 			r.Header.Set(key, r.Header.Get(key)+"b")
19 | 			i, _ := c.Get(key).(int)
20 | 			c.Set(key, i+4)
21 | 		},
22 | 		func(w http.ResponseWriter, r *http.Request, quit func()) {
23 | 			r.Header.Set(key, r.Header.Get(key)+"c")
24 | 		},
25 | 		func(w http.ResponseWriter, r *http.Request) {
26 | 			r.Header.Set(key, r.Header.Get(key)+"d")
27 | 		},
28 | 		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
29 | 			r.Header.Set(key, r.Header.Get(key)+"e")
30 | 		}),
31 | 	)
32 |
33 | 	c := NewSiestaContext()
34 | 	r := httptest.NewRequest(http.MethodGet, "/", nil)
35 | 	w := httptest.NewRecorder()
36 | 	stack.ServeHTTPInContext(c, w, r)
37 |
38 | 	i, _ := c.Get(key).(int)
39 | 	if want, got := 6, i; want != got {
40 | 		t.Errorf("expected %d got %d", want, got)
41 | 	}
42 | 	if want, got := "abcde", r.Header.Get(key); want != got {
43 | 		t.Errorf("expected %s got %s", want, got)
44 | 	}
45 | }
46 |
47 | func TestComposeStopsOnQuitWithoutParentQuit(t *testing.T) {
48 | 	ran := false
49 | 	stack := Compose(
50 | 		func(c Context, w http.ResponseWriter, r *http.Request, quit func()) {
51 | 			quit()
52 | 		},
53 | 		func(w http.ResponseWriter, r *http.Request) {
54 | 			ran = true
55 | 		},
56 | 	)
57 |
58 | 	r := httptest.NewRequest(http.MethodGet, "/", nil)
59 | 	w := httptest.NewRecorder()
60 | 	// A nil parent quit must not panic when an inner handler quits.
61 | 	stack(NewSiestaContext(), w, r, nil)
62 |
63 | 	if ran {
64 | 		t.Error("expected the chain to stop after quit was called")
65 | 	}
66 | }
67 |
68 | func TestToContextHandlerUnsupportedHandler(t *testing.T) {
69 | 	defer func() {
70 | 		r := recover()
71 | 		if r == nil {
72 | 			t.Fatal("expected a panic")
73 | 		}
74 | 		err, _ := r.(error)
75 | 		if want, got := ErrUnsupportedHandler, err; want != got {
76 | 			t.Fatalf("expected %v got %v", want, got)
77 | 		}
78 | 	}()
79 |
80 | 	_ = ToContextHandler(func() {})
81 | }
82 |
--------------------------------------------------------------------------------
/params.go:
--------------------------------------------------------------------------------
1 | package siesta
2 |
3 | import (
4 | 	"flag"
5 | 	"fmt"
6 | 	"net/url"
7 | 	"strconv"
8 | 	"strings"
9 | 	"time"
10 | )
11 |
12 | // Params represents a set of URL parameters from a request's query string.
13 | // The interface is similar to a flag.FlagSet, but a) there is no usage string,
14 | // b) there are no custom Var()s, and c) there are SliceXXX types. Sliced types
15 | // support two ways of generating a multi-valued parameter: setting the parameter
16 | // multiple times, and using a comma-delimited string. This adds the limitation
17 | // that you can't have a value with a comma if in a Sliced type.
18 | // Under the covers, Params uses flag.FlagSet.
19 | type Params struct {
20 | 	fset *flag.FlagSet
21 | }
22 |
23 | // Parse sets the defined params from args (e.g. r.URL.Query() or r.Form).
24 | // Undefined names are ignored and an empty value sets a bool param to true;
25 | // the first value that fails to parse is returned as a "bad param" error.
26 | func (rp *Params) Parse(args url.Values) error {
27 | 	if rp.fset == nil {
28 | 		rp.fset = flag.NewFlagSet("anonymous", flag.ExitOnError) // both args are unused.
29 | 	}
30 |
31 | 	// Parse items from URL query string
32 | 	for name, vals := range args {
33 | 		f := rp.fset.Lookup(name)
34 | 		if f == nil {
35 | 			// Flag wasn't found: undefined params are ignored.
36 | 			continue
37 | 		}
38 |
39 | 		for _, v := range vals {
40 | 			// An empty value for a bool param (e.g. "?verbose") means true;
41 | 			// any later values given for that name are skipped.
42 | 			if v == "" {
43 | 				if bv, ok := f.Value.(boolFlag); ok && bv.IsBoolFlag() {
44 | 					if err := bv.Set("true"); err != nil {
45 | 						return fmt.Errorf("bad param '%s': %s", name, err.Error())
46 | 					}
47 | 					break
48 | 				}
49 | 			}
50 |
51 | 			if err := rp.fset.Set(name, v); err != nil {
52 | 				// Lookup succeeded, so this is a bad value; name the param in the error.
53 | 				return fmt.Errorf("bad param '%s': %s", name, err.Error())
54 | 			}
55 | 		}
56 | 	}
57 |
58 | 	return nil
59 | }
60 |
61 | // Usage returns a map keyed on parameter names. The map values are an array of
62 | // name, type, and usage information for each parameter. A Params with no
63 | // defined params (including the zero value) yields an empty map.
64 | func (rp *Params) Usage() map[string][3]string {
65 | 	docs := make(map[string][3]string)
66 | 	if rp.fset == nil {
67 | 		return docs // VisitAll on a nil *flag.FlagSet would panic.
68 | 	}
69 | 	translations := map[string]string{
70 | 		"*flag.stringValue":   "string",
71 | 		"*flag.durationValue": "duration",
72 | 		"*flag.intValue":      "int",
73 | 		"*flag.boolValue":     "bool",
74 | 		"*flag.float64Value":  "float64",
75 | 		"*flag.int64Value":    "int64",
76 | 		"*flag.uintValue":     "uint",
77 | 		"*flag.uint64Value":   "uint64",
78 | 		"*siesta.SString":     "[]string",
79 | 		"*siesta.SDuration":   "[]duration",
80 | 		"*siesta.SInt":        "[]int",
81 | 		"*siesta.SBool":       "[]bool",
82 | 		"*siesta.SFloat64":    "[]float64",
83 | 		"*siesta.SInt64":      "[]int64",
84 | 		"*siesta.SUint":       "[]uint",
85 | 		"*siesta.SUint64":     "[]uint64",
86 | 	}
87 | 	rp.fset.VisitAll(func(f *flag.Flag) {
88 | 		typeName := fmt.Sprintf("%T", f.Value)
89 | 		niceName, ok := translations[typeName]
90 | 		if !ok {
91 | 			niceName = typeName
92 | 		}
93 | 		docs[f.Name] = [...]string{f.Name, niceName, f.Usage}
94 | 	})
95 | 	return docs
96 | }
97 |
98 | type boolFlag interface {
99 | 	flag.Value
100 | 	IsBoolFlag() bool
101 | }
102 |
103 | // Bool defines a bool param with specified name and default value.
// The return value is the address of a bool variable that stores the value of the param.
func (rp *Params) Bool(name string, value bool, usage string) *bool {
	if rp.fset == nil {
		// Lazily create the FlagSet. Its name and error-handling mode are
		// irrelevant because FlagSet.Parse is never called on it.
		rp.fset = flag.NewFlagSet("anonymous", flag.ExitOnError)
	}
	return rp.fset.Bool(name, value, usage)
}

// SBool is a slice of bool.
type SBool []bool

// String formats the collected values for diagnostics; it is part of the
// flag.Value interface.
func (s *SBool) String() string {
	return fmt.Sprintf("%v", *s)
}

// Set implements flag.Value. value is split on commas; empty fields are
// skipped and each remaining field is parsed with strconv.ParseBool and
// appended. Fields parsed before an invalid one stay appended.
func (s *SBool) Set(value string) error {
	fields := strings.Split(value, ",")
	for i := range fields {
		if fields[i] == "" {
			continue
		}
		b, err := strconv.ParseBool(fields[i])
		if err != nil {
			return err
		}
		*s = append(*s, b)
	}
	return nil
}

// SliceBool defines a multi-value bool param with the given name and usage,
// returning the address of the SBool that collects its values. The value
// argument is not used: the slice always starts out empty.
func (rp *Params) SliceBool(name string, value bool, usage string) *SBool {
	if rp.fset == nil {
		rp.fset = flag.NewFlagSet("anonymous", flag.ExitOnError) // both args are unused.
	}
	bools := new(SBool)
	rp.fset.Var(bools, name, usage)
	return bools
}

// Int defines an int param with specified name and default value.
// The return value is the address of an int variable that stores the value of the param.
func (rp *Params) Int(name string, value int, usage string) *int {
	if rp.fset == nil {
		rp.fset = flag.NewFlagSet("anonymous", flag.ExitOnError) // both args are unused.
	}
	p := new(int)
	rp.fset.IntVar(p, name, value, usage)
	return p
}

// SInt is a slice of int.
type SInt []int

// String is the method to format the param's value, part of the flag.Value interface.
// The String method's output will be used in diagnostics.
func (s *SInt) String() string {
	return fmt.Sprint(*s)
}

// Set is the method to set the param value, part of the flag.Value interface.
// Set's argument is a string to be parsed to set the param.
// It's a comma-separated list, so we split it. Empty fields are skipped and
// each field uses strconv.ParseInt base-0 syntax (e.g. "42", "-7", "0x2a").
// Values that do not fit in this platform's int are rejected, not truncated.
// Fields parsed before an invalid one stay appended.
func (s *SInt) Set(value string) error {
	for _, dt := range strings.Split(value, ",") {
		if len(dt) > 0 {
			// Parse with the platform's int width (as flag's own int value
			// does) so out-of-range input yields an error instead of being
			// silently wrapped by the int conversion on 32-bit targets.
			parsed, err := strconv.ParseInt(dt, 0, strconv.IntSize)
			if err != nil {
				return err
			}
			*s = append(*s, int(parsed))
		}
	}
	return nil
}

// SliceInt defines a multi-value int param with the given name and usage.
// The return value is the address of a SInt variable that stores the values of the param.
// The value argument is not used; the slice starts out empty.
func (rp *Params) SliceInt(name string, value int, usage string) *SInt {
	if rp.fset == nil {
		rp.fset = flag.NewFlagSet("anonymous", flag.ExitOnError) // both args are unused.
	}
	p := new(SInt)
	rp.fset.Var(p, name, usage)
	return p
}

// Int64 defines an int64 param with specified name and default value.
// The return value is the address of an int64 variable that stores the value of the param.
func (rp *Params) Int64(name string, value int64, usage string) *int64 {
	if rp.fset == nil {
		rp.fset = flag.NewFlagSet("anonymous", flag.ExitOnError) // both args are unused.
	}
	p := new(int64)
	rp.fset.Int64Var(p, name, value, usage)
	return p
}

// SInt64 is a slice of int64.
type SInt64 []int64

// String is the method to format the param's value, part of the flag.Value interface.
// The String method's output will be used in diagnostics.
func (s *SInt64) String() string {
	return fmt.Sprint(*s)
}

// Set is the method to set the param value, part of the flag.Value interface.
// Set's argument is a string to be parsed to set the param.
// It's a comma-separated list, so we split it. Empty fields are skipped and
// each field uses strconv.ParseInt base-0 syntax. Fields parsed before an
// invalid one stay appended.
func (s *SInt64) Set(value string) error {
	for _, dt := range strings.Split(value, ",") {
		if len(dt) > 0 {
			parsed, err := strconv.ParseInt(dt, 0, 64)
			if err != nil {
				return err
			}
			*s = append(*s, parsed)
		}
	}
	return nil
}

// SliceInt64 defines a multi-value int64 param with the given name and usage.
// The return value is the address of a SInt64 variable that stores the values of the param.
// The value argument is not used; the slice starts out empty.
func (rp *Params) SliceInt64(name string, value int64, usage string) *SInt64 {
	if rp.fset == nil {
		rp.fset = flag.NewFlagSet("anonymous", flag.ExitOnError) // both args are unused.
	}
	p := new(SInt64)
	rp.fset.Var(p, name, usage)
	return p
}

// Uint defines a uint param with specified name and default value.
// The return value is the address of a uint variable that stores the value of the param.
func (rp *Params) Uint(name string, value uint, usage string) *uint {
	if rp.fset == nil {
		rp.fset = flag.NewFlagSet("anonymous", flag.ExitOnError) // both args are unused.
	}
	p := new(uint)
	rp.fset.UintVar(p, name, value, usage)
	return p
}

// SUint is a slice of uint.
256 | type SUint []uint 257 | 258 | // String is the method to format the param's value, part of the flag.Value interface. 259 | // The String method's output will be used in diagnostics. 260 | func (s *SUint) String() string { 261 | return fmt.Sprint(*s) 262 | } 263 | 264 | // Set is the method to set the param value, part of the flag.Value interface. 265 | // Set's argument is a string to be parsed to set the param. 266 | // It's a comma-separated list, so we split it. 267 | func (s *SUint) Set(value string) error { 268 | for _, dt := range strings.Split(value, ",") { 269 | if len(dt) > 0 { 270 | parsed, err := strconv.ParseUint(dt, 10, 64) 271 | if err != nil { 272 | return err 273 | } 274 | *s = append(*s, uint(parsed)) 275 | } 276 | } 277 | return nil 278 | } 279 | 280 | // SliceUint defines a multi-value uint param with specified name and default value. 281 | // The return value is the address of a SUint variable that stores the values of the param. 282 | func (rp *Params) SliceUint(name string, value uint, usage string) *SUint { 283 | if rp.fset == nil { 284 | rp.fset = flag.NewFlagSet("anonymous", flag.ExitOnError) // both args are unused. 285 | } 286 | p := new(SUint) 287 | rp.fset.Var(p, name, usage) 288 | return p 289 | } 290 | 291 | // Uint64 defines a uint64 param with specified name and default value. 292 | // The return value is the address of a uint64 variable that stores the value of the param. 293 | func (rp *Params) Uint64(name string, value uint64, usage string) *uint64 { 294 | if rp.fset == nil { 295 | rp.fset = flag.NewFlagSet("anonymous", flag.ExitOnError) // both args are unused. 296 | } 297 | p := new(uint64) 298 | rp.fset.Uint64Var(p, name, value, usage) 299 | return p 300 | } 301 | 302 | // SUint64 is a slice of uint64. 303 | type SUint64 []uint64 304 | 305 | // String is the method to format the param's value, part of the flag.Value interface. 306 | // The String method's output will be used in diagnostics. 
func (s *SUint64) String() string {
	return fmt.Sprintf("%v", *s)
}

// Set implements flag.Value for SUint64. value is split on commas; empty
// fields are ignored and every other field must be a decimal uint64.
// Fields parsed before an invalid one stay appended.
func (s *SUint64) Set(value string) error {
	fields := strings.Split(value, ",")
	for i := range fields {
		if fields[i] == "" {
			continue
		}
		n, err := strconv.ParseUint(fields[i], 10, 64)
		if err != nil {
			return err
		}
		*s = append(*s, n)
	}
	return nil
}

// SliceUint64 defines a multi-value uint64 param with the given name and
// usage, returning the address of the SUint64 that collects its values.
// The value argument is not used: the slice always starts out empty.
func (rp *Params) SliceUint64(name string, value uint64, usage string) *SUint64 {
	if rp.fset == nil {
		rp.fset = flag.NewFlagSet("anonymous", flag.ExitOnError) // both args are unused.
	}
	values := new(SUint64)
	rp.fset.Var(values, name, usage)
	return values
}

// String defines a string param with specified name and default value.
// The return value is the address of a string variable that stores the value of the param.
func (rp *Params) String(name string, value string, usage string) *string {
	if rp.fset == nil {
		// Lazily create the FlagSet; FlagSet.Parse is never called on it,
		// so its name and error-handling mode do not matter.
		rp.fset = flag.NewFlagSet("anonymous", flag.ExitOnError)
	}
	return rp.fset.String(name, value, usage)
}

// SString is a slice of string.
type SString []string

// String formats the collected values for diagnostics by joining them with
// commas; it is part of the flag.Value interface.
func (s *SString) String() string {
	return strings.Join([]string(*s), ",")
}

// Set is the method to set the param value, part of the flag.Value interface.
// Set's argument is a string to be parsed to set the param.
// It's a comma-separated list, so we split it. Unlike the other slice types,
// empty fields are kept: "a,,b" appends the three values "a", "" and "b".
func (s *SString) Set(value string) error {
	*s = append(*s, strings.Split(value, ",")...)
	return nil
}

// SliceString defines a multi-value string param with the given name and usage.
// The return value is the address of a SString variable that stores the values of the param.
// The value argument is not used; the slice starts out empty.
func (rp *Params) SliceString(name string, value string, usage string) *SString {
	if rp.fset == nil {
		rp.fset = flag.NewFlagSet("anonymous", flag.ExitOnError) // both args are unused.
	}
	p := new(SString)
	rp.fset.Var(p, name, usage)
	return p
}

// Float64 defines a float64 param with specified name and default value.
// The return value is the address of a float64 variable that stores the value of the param.
func (rp *Params) Float64(name string, value float64, usage string) *float64 {
	if rp.fset == nil {
		rp.fset = flag.NewFlagSet("anonymous", flag.ExitOnError) // both args are unused.
	}
	p := new(float64)
	rp.fset.Float64Var(p, name, value, usage)
	return p
}

// SFloat64 is a slice of float64.
type SFloat64 []float64

// String is the method to format the param's value, part of the flag.Value interface.
// The String method's output will be used in diagnostics. Values are printed
// with %v, as the other slice types do, so no precision is lost
// (%f would render 1e-7 as 0.000000).
func (s *SFloat64) String() string {
	return fmt.Sprint(*s)
}

// Set is the method to set the param value, part of the flag.Value interface.
// Set's argument is a string to be parsed to set the param.
// It's a comma-separated list, so we split it.
func (s *SFloat64) Set(value string) error {
	// Empty fields are ignored; each other field must parse as a float64.
	// Fields parsed before an invalid one stay appended.
	fields := strings.Split(value, ",")
	for i := range fields {
		if fields[i] == "" {
			continue
		}
		f, err := strconv.ParseFloat(fields[i], 64)
		if err != nil {
			return err
		}
		*s = append(*s, f)
	}
	return nil
}

// SliceFloat64 defines a multi-value float64 param with the given name and
// usage, returning the address of the SFloat64 that collects its values.
// The value argument is not used: the slice always starts out empty.
func (rp *Params) SliceFloat64(name string, value float64, usage string) *SFloat64 {
	if rp.fset == nil {
		rp.fset = flag.NewFlagSet("anonymous", flag.ExitOnError) // both args are unused.
	}
	floats := new(SFloat64)
	rp.fset.Var(floats, name, usage)
	return floats
}

// Duration defines a time.Duration param with specified name and default value.
// The return value is the address of a time.Duration variable that stores the value of the param.
func (rp *Params) Duration(name string, value time.Duration, usage string) *time.Duration {
	if rp.fset == nil {
		// Lazily create the FlagSet; FlagSet.Parse is never called on it,
		// so its name and error-handling mode do not matter.
		rp.fset = flag.NewFlagSet("anonymous", flag.ExitOnError)
	}
	return rp.fset.Duration(name, value, usage)
}

// SDuration is a slice of time.Duration.
type SDuration []time.Duration

// String formats the collected durations for diagnostics (for example
// "[10ms 10s]"); it is part of the flag.Value interface.
func (s *SDuration) String() string {
	return fmt.Sprintf("%v", *s)
}

// Set is the method to set the param value, part of the flag.Value interface.
// Set's argument is a string to be parsed to set the param.
// It's a comma-separated list, so we split it.
449 | func (s *SDuration) Set(value string) error { 450 | for _, dt := range strings.Split(value, ",") { 451 | if len(dt) > 0 { 452 | parsed, err := time.ParseDuration(dt) 453 | if err != nil { 454 | return err 455 | } 456 | *s = append(*s, parsed) 457 | } 458 | } 459 | return nil 460 | } 461 | 462 | // SliceDuration defines a multi-value time.Duration param with specified name and default value. 463 | // The return value is the address of a SDuration variable that stores the values of the param. 464 | func (rp *Params) SliceDuration(name string, value time.Duration, usage string) *SDuration { 465 | if rp.fset == nil { 466 | rp.fset = flag.NewFlagSet("anonymous", flag.ExitOnError) // both args are unused. 467 | } 468 | p := new(SDuration) 469 | rp.fset.Var(p, name, usage) 470 | return p 471 | } 472 | -------------------------------------------------------------------------------- /params_test.go: -------------------------------------------------------------------------------- 1 | package siesta 2 | 3 | import ( 4 | "net/url" 5 | "testing" 6 | "time" 7 | ) 8 | 9 | func TestParamsSimple(t *testing.T) { 10 | v := url.Values{} 11 | p := Params{} 12 | v.Set("company", "VividCortex") 13 | v.Set("founded", "2012") 14 | v.Set("startup", "true") 15 | v.Set("duration", "10ms") 16 | v.Set("float", "12.89") 17 | v.Set("uint64", "1234") 18 | v.Set("int64", "-9876") 19 | v.Set("uint", "2345") 20 | v.Set("nonexistent", "8765") 21 | v.Set("valueless", "") 22 | v.Set("falseBool", "f") 23 | company := p.String("company", "", "the company name") 24 | founded := p.Int("founded", 0, "when it was founded") 25 | startup := p.Bool("startup", false, "whether it's a startup") 26 | duration := p.Duration("duration", 0, "how long it's been") 27 | floatVar := p.Float64("float", 0, "some float64") 28 | uint64Var := p.Uint64("uint64", 0, "some uint64") 29 | int64Var := p.Int64("int64", 0, "some int64") 30 | uintVar := p.Uint("uint", 0, "some uint") 31 | valueless := p.Bool("valueless", false, 
"some bool") 32 | falseBool := p.Bool("falseBool", true, "a bool with value false") 33 | err := p.Parse(v) 34 | if err != nil { 35 | t.Error(err) 36 | } else if *company != "VividCortex" { 37 | t.Errorf("expected VividCortex, got %s", *company) 38 | } else if *founded != 2012 { 39 | t.Errorf("expected 2012, got %d", *founded) 40 | } else if !*startup { 41 | t.Errorf("expected true, got %t", *startup) 42 | } else if *duration != 10*time.Millisecond { 43 | t.Errorf("expected 10ms, got %s", *duration) 44 | } else if *floatVar != 12.89 { 45 | t.Errorf("expected 12.89, got %f", *floatVar) 46 | } else if *uint64Var != 1234 { 47 | t.Errorf("expected 1234, got %d", *uint64Var) 48 | } else if *int64Var != -9876 { 49 | t.Errorf("expected -9876, got %d", *int64Var) 50 | } else if *uintVar != 2345 { 51 | t.Errorf("expected 2345, got %d", *uintVar) 52 | } else if *valueless != true { 53 | t.Errorf("expected true, got %t", *valueless) 54 | } else if *falseBool != false { 55 | t.Errorf("expected false, got %t", *falseBool) 56 | } 57 | 58 | usage := p.Usage() 59 | var expected map[string][3]string = map[string][3]string{ 60 | "company": [3]string{"company", "string", "the company name"}, 61 | "founded": [3]string{"founded", "int", "when it was founded"}, 62 | "startup": [3]string{"startup", "bool", "whether it's a startup"}, 63 | "duration": [3]string{"duration", "duration", "how long it's been"}, 64 | "float": [3]string{"float", "float64", "some float64"}, 65 | "uint64": [3]string{"uint64", "uint64", "some uint64"}, 66 | "int64": [3]string{"int64", "int64", "some int64"}, 67 | "uint": [3]string{"uint", "uint", "some uint"}, 68 | "valueless": [3]string{"valueless", "bool", "some bool"}, 69 | "falseBool": [3]string{"falseBool", "bool", "a bool with value false"}, 70 | } 71 | compareUsageMaps(t, usage, expected) 72 | } 73 | 74 | func compareUsageMaps(t *testing.T, got, expected map[string][3]string) { 75 | seen := make(map[string]bool) 76 | for k, v := range got { 77 | seen[k] = 
true 78 | v2, ok := expected[k] 79 | if !ok { 80 | t.Errorf("%s doesn't exist in expected", k) 81 | } else if v2 != v { 82 | t.Errorf("%s: got '%s', expected '%s'", k, v, v2) 83 | } 84 | } 85 | for k, _ := range expected { 86 | if !seen[k] { 87 | _, ok := got[k] 88 | if !ok { 89 | t.Errorf("%s doesn't exist in got", k) 90 | } 91 | } 92 | } 93 | } 94 | 95 | func compareSlices(a, b []interface{}) bool { 96 | if len(a) != len(b) { 97 | return false 98 | } 99 | for i, aVal := range a { 100 | if aVal != b[i] { 101 | return false 102 | } 103 | } 104 | return true 105 | } 106 | 107 | func TestParamsSlices(t *testing.T) { 108 | v := url.Values{} 109 | p := Params{} 110 | v.Add("company", "VividCortex") 111 | v.Add("company", "Inc,,comma,") 112 | company := p.SliceString("company", "", "the company name") 113 | v.Add("founded", "2012") 114 | v.Add("founded", "2012,,2102,") 115 | founded := p.SliceInt("founded", 0, "when it was founded") 116 | v.Add("startup", "true") 117 | v.Add("startup", "false,,true,") 118 | startup := p.SliceBool("startup", false, "whether it's a startup") 119 | v.Add("float", "12.89") 120 | v.Add("float", "1.25,,12.625,") 121 | floatVar := p.SliceFloat64("float", 0, "some float64") 122 | v.Add("uint64", "1234") 123 | v.Add("uint64", "1234,,5678,") 124 | v.Add("uint64", "18446744073709551615") // 2^63-1 125 | uint64Var := p.SliceUint64("uint64", 0, "some uint64") 126 | v.Add("int64", "-9876") 127 | v.Add("int64", "-9876,,8765,") 128 | int64Var := p.SliceInt64("int64", 0, "some int64") 129 | v.Add("uint", "2345") 130 | v.Add("uint", "2345,,3456,") 131 | uintVar := p.SliceUint("uint", 0, "some uint") 132 | v.Add("duration", "10ms") 133 | v.Add("duration", "10s,,12ms,") 134 | duration := p.SliceDuration("duration", 0, "how long it's been") 135 | 136 | err := p.Parse(v) 137 | if err != nil { 138 | t.Error(err) 139 | } 140 | 141 | companies := []string{"VividCortex", "Inc", "", "comma", ""} 142 | for i, v := range *company { 143 | if v != companies[i] { 144 
| t.Errorf("expected %s, got %s", companies[i], v) 145 | } 146 | } 147 | 148 | foundings := []int{2012, 2012, 2102} 149 | for i, v := range *founded { 150 | if v != foundings[i] { 151 | t.Errorf("expected %d, got %d", foundings[i], v) 152 | } 153 | } 154 | 155 | startups := []bool{true, false, true} 156 | for i, v := range *startup { 157 | if v != startups[i] { 158 | t.Errorf("expected %t, got %t", startups[i], v) 159 | } 160 | } 161 | 162 | floats := []float64{12.89, 1.25, 12.625} 163 | for i, v := range *floatVar { 164 | if v != floats[i] { 165 | t.Errorf("expected %f, got %f", floats[i], v) 166 | } 167 | } 168 | 169 | uint64s := []uint64{1234, 1234, 5678, 18446744073709551615} 170 | for i, v := range *uint64Var { 171 | if v != uint64s[i] { 172 | t.Errorf("expected %d, got %d", uint64s[i], v) 173 | } 174 | } 175 | 176 | int64s := []int64{-9876, -9876, 8765} 177 | for i, v := range *int64Var { 178 | if v != int64s[i] { 179 | t.Errorf("expected %d, got %d", int64s[i], v) 180 | } 181 | } 182 | 183 | uints := []uint{2345, 2345, 3456} 184 | for i, v := range *uintVar { 185 | if v != uints[i] { 186 | t.Errorf("expected %d, got %d", uints[i], v) 187 | } 188 | } 189 | 190 | durations := []time.Duration{10 * time.Millisecond, 10 * time.Second, 12 * time.Millisecond} 191 | for i, v := range *duration { 192 | if v != durations[i] { 193 | t.Errorf("expected %s, got %s", durations[i], v) 194 | } 195 | } 196 | 197 | usage := p.Usage() 198 | var expected map[string][3]string = map[string][3]string{ 199 | "company": [3]string{"company", "[]string", "the company name"}, 200 | "founded": [3]string{"founded", "[]int", "when it was founded"}, 201 | "startup": [3]string{"startup", "[]bool", "whether it's a startup"}, 202 | "duration": [3]string{"duration", "[]duration", "how long it's been"}, 203 | "float": [3]string{"float", "[]float64", "some float64"}, 204 | "uint64": [3]string{"uint64", "[]uint64", "some uint64"}, 205 | "int64": [3]string{"int64", "[]int64", "some int64"}, 206 | 
"uint": [3]string{"uint", "[]uint", "some uint"}, 207 | } 208 | compareUsageMaps(t, usage, expected) 209 | 210 | } 211 | -------------------------------------------------------------------------------- /routeparams.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (c) 2013 Julien Schmidt. All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | * Redistributions of source code must retain the above copyright 7 | notice, this list of conditions and the following disclaimer. 8 | * Redistributions in binary form must reproduce the above copyright 9 | notice, this list of conditions and the following disclaimer in the 10 | documentation and/or other materials provided with the distribution. 11 | * The names of the contributors may not be used to endorse or promote 12 | products derived from this software without specific prior written 13 | permission. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | DISCLAIMED. IN NO EVENT SHALL JULIEN SCHMIDT BE LIABLE FOR ANY 19 | DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | */ 26 | 27 | package siesta 28 | 29 | // Param is a single URL parameter, consisting of a key and a value. 
30 | type routeParam struct { 31 | Key string 32 | Value string 33 | } 34 | 35 | // Params is a Param-slice, as returned by the router. 36 | // The slice is ordered, the first URL parameter is also the first slice value. 37 | // It is therefore safe to read values by the index. 38 | type routeParams []routeParam 39 | -------------------------------------------------------------------------------- /service.go: -------------------------------------------------------------------------------- 1 | package siesta 2 | 3 | import ( 4 | "io" 5 | "io/ioutil" 6 | "net/http" 7 | "path" 8 | "strings" 9 | ) 10 | 11 | // Registered services keyed by base URI. 12 | var services = map[string]*Service{} 13 | 14 | // A Service is a container for routes with a common base URI. 15 | // It also has two middleware chains, named "pre" and "post". 16 | // 17 | // The "pre" chain is run before the main handler. The first 18 | // handler in the "pre" chain is guaranteed to run, but execution 19 | // may quit anywhere else in the chain. 20 | // 21 | // If the "pre" chain executes completely, the main handler is executed. 22 | // It is skipped otherwise. 23 | // 24 | // The "post" chain runs after the main handler, whether it is skipped 25 | // or not. The first handler in the "post" chain is guaranteed to run, but 26 | // execution may quit anywhere else in the chain if the quit function 27 | // is called. 28 | type Service struct { 29 | baseURI string 30 | trimSlash bool 31 | 32 | pre []ContextHandler 33 | post []ContextHandler 34 | 35 | routes map[string]*node 36 | 37 | notFound ContextHandler 38 | 39 | // postExecutionFunc runs at the end of the request 40 | postExecutionFunc func(c Context, r *http.Request, panicValue interface{}) 41 | } 42 | 43 | // NewService returns a new Service with the given base URI 44 | // or panics if the base URI has already been registered. 
func NewService(baseURI string) *Service {
	// NOTE(review): NewService never stores the new Service in services, so
	// this guard only trips if some other code populates that map.
	if existing := services[baseURI]; existing != nil {
		panic("service already registered")
	}

	svc := &Service{
		routes:    make(map[string]*node),
		trimSlash: true,
	}
	// path.Join roots and cleans the prefix, dropping any trailing slash
	// (the bare root stays "/").
	svc.baseURI = path.Join("/", baseURI, "/")
	return svc
}

// SetPostExecutionFunc installs f to be called once at the end of every
// request served by s. panicValue is non-nil when a panic was recovered;
// that panic is re-raised after f returns.
func (s *Service) SetPostExecutionFunc(f func(c Context, r *http.Request, panicValue interface{})) {
	s.postExecutionFunc = f
}

// DisableTrimSlash stops s from stripping trailing slashes from the
// request path before route matching.
func (s *Service) DisableTrimSlash() {
	s.trimSlash = false
}

// addToChain converts f with ToContextHandler and returns chain with the
// result appended. It panics if f has an unsupported signature.
func addToChain(f interface{}, chain []ContextHandler) []ContextHandler {
	return append(chain, ToContextHandler(f))
}

// AddPre appends f to the "pre" chain.
// It panics if f cannot be converted to a ContextHandler (see Service.Route).
func (s *Service) AddPre(f interface{}) {
	s.pre = addToChain(f, s.pre)
}

// AddPost appends f to the "post" chain.
// It panics if f cannot be converted to a ContextHandler (see Service.Route).
func (s *Service) AddPost(f interface{}) {
	s.post = addToChain(f, s.post)
}

// ServeHTTP implements http.Handler by serving r within a fresh Context
// obtained from NewSiestaContext.
func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	c := NewSiestaContext()
	s.ServeHTTPInContext(c, w, r)
}

// ServeHTTPInContext serves an HTTP request within the Context c.
// It runs the "pre" chain, then the matching route handler (unless the
// "pre" chain quit), then the "post" chain, honoring quit requests.
func (s *Service) ServeHTTPInContext(c Context, w http.ResponseWriter, r *http.Request) {
	// Always report the outcome to the post-execution hook. recover must be
	// called directly by the deferred function; a recovered panic is
	// re-raised once the hook has seen it.
	defer func() {
		panicValue := recover()
		if hook := s.postExecutionFunc; hook != nil {
			hook(c, r, panicValue)
		}
		if panicValue != nil {
			panic(panicValue)
		}
	}()

	// Any ParseForm error is ignored; handlers see whatever r.Form holds.
	r.ParseForm()

	// A single flag, set by every quit callback handed out below. It is
	// cleared again before the "post" chain starts.
	stopped := false
	stop := func() { stopped = true }

	// "pre" chain: stop at the first handler that calls quit.
	for _, m := range s.pre {
		m(c, w, r, stop)
		if stopped {
			break
		}
	}

	if !stopped {
		if s.trimSlash && r.URL.Path != "/" {
			r.URL.Path = strings.TrimRight(r.URL.Path, "/")
		}

		var (
			handler ContextHandler
			params  routeParams
		)
		if tree, found := s.routes[r.Method]; found {
			var usage string
			handler, usage, params, _ = tree.getValue(r.URL.Path)
			c.Set(UsageContextKey, usage)
		}

		switch {
		case handler != nil:
			// Route parameters replace any same-named form values.
			for _, p := range params {
				r.Form.Set(p.Key, p.Value)
			}
			handler(c, w, r, stop)
			// Drain and close the body once the handler is done with it.
			if r.Body != nil {
				io.Copy(ioutil.Discard, r.Body)
				r.Body.Close()
			}
		case s.notFound != nil:
			// The custom not-found handler gets a no-op quit callback.
			s.notFound(c, w, r, func() {})
		default:
			http.NotFoundHandler().ServeHTTP(w, r)
		}
	}

	// "post" chain: always started; stops at the first handler that calls quit.
	stopped = false
	for _, m := range s.post {
		m(c, w, r, stop)
		if stopped {
			return
		}
	}
}

// Route adds a new route to the Service.
182 | // f must be a function with one of the following signatures: 183 | // 184 | // func(http.ResponseWriter, *http.Request) 185 | // func(http.ResponseWriter, *http.Request, func()) 186 | // func(Context, http.ResponseWriter, *http.Request) 187 | // func(Context, http.ResponseWriter, *http.Request, func()) 188 | // 189 | // Note that Context is an interface type defined in this package. 190 | // The last argument is a function which is called to signal the 191 | // quitting of the current execution sequence. 192 | func (s *Service) Route(verb, uriPath, usage string, f interface{}) { 193 | handler := ToContextHandler(f) 194 | 195 | if n := s.routes[verb]; n == nil { 196 | s.routes[verb] = &node{} 197 | } 198 | 199 | s.routes[verb].addRoute( 200 | path.Join(s.baseURI, strings.TrimRight(uriPath, "/")), 201 | usage, handler) 202 | } 203 | 204 | // SetNotFound sets the handler for all paths that do not 205 | // match any existing routes. It accepts the same function 206 | // signatures that Route does with the addition of `nil`. 207 | func (s *Service) SetNotFound(f interface{}) { 208 | if f == nil { 209 | s.notFound = nil 210 | return 211 | } 212 | 213 | handler := ToContextHandler(f) 214 | s.notFound = handler 215 | } 216 | 217 | // Register registers s by adding it as a handler to the 218 | // DefaultServeMux in the net/http package. 
219 | func (s *Service) Register() { 220 | http.Handle(s.baseURI, s) 221 | } 222 | -------------------------------------------------------------------------------- /service_test.go: -------------------------------------------------------------------------------- 1 | package siesta 2 | 3 | import ( 4 | "encoding/json" 5 | "net/http" 6 | "net/http/httptest" 7 | "reflect" 8 | "testing" 9 | ) 10 | 11 | func TestServiceRoute(t *testing.T) { 12 | s := NewService("foos") 13 | s.Route(http.MethodGet, "/bars/:id/bazs", "Handles bars' bazs", func(Context, http.ResponseWriter, *http.Request, func()) {}) 14 | 15 | srv := httptest.NewServer(s) 16 | defer srv.Close() 17 | 18 | resp, err := http.Get(srv.URL + "/foos/bars/1/bazs") 19 | if err != nil { 20 | t.Fatal(err) 21 | } 22 | if want, got := http.StatusOK, resp.StatusCode; want != got { 23 | t.Fatalf("expected status %d got %d", want, got) 24 | } 25 | } 26 | 27 | func TestServiceDefaultNotFound(t *testing.T) { 28 | s := NewService("") 29 | 30 | srv := httptest.NewServer(s) 31 | defer srv.Close() 32 | 33 | resp, err := http.Get(srv.URL + "/no/where/to/be/found") 34 | if err != nil { 35 | t.Fatal(err) 36 | } 37 | if want, got := http.StatusNotFound, resp.StatusCode; want != got { 38 | t.Fatalf("expected status %d got %d", want, got) 39 | } 40 | } 41 | 42 | func TestServiceCustomNotFound(t *testing.T) { 43 | type payload struct { 44 | Code int `json:"code"` 45 | Message string `json:"message"` 46 | } 47 | want := payload{Code: http.StatusNotFound, Message: http.StatusText(http.StatusNotFound)} 48 | 49 | s := NewService("") 50 | s.SetNotFound(func(c Context, w http.ResponseWriter, r *http.Request) { 51 | w.Header().Set("Content-Type", "application/json") 52 | w.WriteHeader(http.StatusNotFound) 53 | _ = json.NewEncoder(w).Encode(&want) 54 | }) 55 | 56 | srv := httptest.NewServer(s) 57 | defer srv.Close() 58 | 59 | resp, err := http.Get(srv.URL + "/no/where/to/be/found") 60 | if resp != nil { 61 | defer resp.Body.Close() 62 | } 63 
| if err != nil { 64 | t.Fatal(err) 65 | } 66 | if want, got := http.StatusNotFound, resp.StatusCode; want != got { 67 | t.Fatalf("expected status %d got %d", want, got) 68 | } 69 | 70 | var got payload 71 | if err = json.NewDecoder(resp.Body).Decode(&got); err != nil { 72 | t.Fatal(err) 73 | } 74 | if !reflect.DeepEqual(want, got) { 75 | t.Errorf("expected payload %v got %v", want, got) 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /tree.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (c) 2013 Julien Schmidt. All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | * Redistributions of source code must retain the above copyright 7 | notice, this list of conditions and the following disclaimer. 8 | * Redistributions in binary form must reproduce the above copyright 9 | notice, this list of conditions and the following disclaimer in the 10 | documentation and/or other materials provided with the distribution. 11 | * The names of the contributors may not be used to endorse or promote 12 | products derived from this software without specific prior written 13 | permission. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | DISCLAIMED. 
IN NO EVENT SHALL JULIEN SCHMIDT BE LIABLE FOR ANY 19 | DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | */ 26 | 27 | package siesta 28 | 29 | import ( 30 | "strings" 31 | "unicode" 32 | ) 33 | 34 | func min(a, b int) int { 35 | if a <= b { 36 | return a 37 | } 38 | return b 39 | } 40 | 41 | func countParams(path string) uint8 { 42 | var n uint 43 | for i := 0; i < len(path); i++ { 44 | if path[i] != ':' && path[i] != '*' { 45 | continue 46 | } 47 | n++ 48 | } 49 | if n >= 255 { 50 | return 255 51 | } 52 | return uint8(n) 53 | } 54 | 55 | type nodeType uint8 56 | 57 | const ( 58 | static nodeType = 0 59 | param nodeType = 1 60 | catchAll nodeType = 2 61 | ) 62 | 63 | type node struct { 64 | path string 65 | wildChild bool 66 | nType nodeType 67 | maxParams uint8 68 | indices []byte 69 | children []*node 70 | handle ContextHandler 71 | usage string 72 | priority uint32 73 | } 74 | 75 | // increments priority of the given child and reorders if necessary 76 | func (n *node) incrementChildPrio(i int) int { 77 | n.children[i].priority++ 78 | prio := n.children[i].priority 79 | 80 | // adjust position (move to front) 81 | for j := i - 1; j >= 0 && n.children[j].priority < prio; j-- { 82 | // swap node positions 83 | tmpN := n.children[j] 84 | n.children[j] = n.children[i] 85 | n.children[i] = tmpN 86 | tmpI := n.indices[j] 87 | n.indices[j] = n.indices[i] 88 | n.indices[i] = tmpI 89 | 90 | i-- 91 | } 92 | return i 93 | } 94 | 95 | // addRoute adds a node with the given handle to the path. 96 | // Not concurrency-safe! 
97 | func (n *node) addRoute(path string, usage string, handle ContextHandler) { 98 | n.priority++ 99 | numParams := countParams(path) 100 | 101 | // non-empty tree 102 | if len(n.path) > 0 || len(n.children) > 0 { 103 | WALK: 104 | for { 105 | // Update maxParams of the current node 106 | if numParams > n.maxParams { 107 | n.maxParams = numParams 108 | } 109 | 110 | // Find the longest common prefix. 111 | // This also implies that the commom prefix contains no ':' or '*' 112 | // since the existing key can't contain this chars. 113 | i := 0 114 | for max := min(len(path), len(n.path)); i < max && path[i] == n.path[i]; i++ { 115 | } 116 | 117 | // Split edge 118 | if i < len(n.path) { 119 | child := node{ 120 | path: n.path[i:], 121 | wildChild: n.wildChild, 122 | indices: n.indices, 123 | children: n.children, 124 | handle: n.handle, 125 | usage: n.usage, 126 | priority: n.priority - 1, 127 | } 128 | 129 | // Update maxParams (max of all children) 130 | for i := range child.children { 131 | if child.children[i].maxParams > child.maxParams { 132 | child.maxParams = child.children[i].maxParams 133 | } 134 | } 135 | 136 | n.children = []*node{&child} 137 | n.indices = []byte{n.path[i]} 138 | n.path = path[:i] 139 | n.handle = nil 140 | n.usage = "" 141 | n.wildChild = false 142 | } 143 | 144 | // Make new node a child of this node 145 | if i < len(path) { 146 | path = path[i:] 147 | 148 | if n.wildChild { 149 | n = n.children[0] 150 | n.priority++ 151 | 152 | // Update maxParams of the child node 153 | if numParams > n.maxParams { 154 | n.maxParams = numParams 155 | } 156 | numParams-- 157 | 158 | // Check if the wildcard matches 159 | if len(path) >= len(n.path) && n.path == path[:len(n.path)] { 160 | // check for longer wildcard, e.g. 
:name and :names 161 | if len(n.path) >= len(path) || path[len(n.path)] == '/' { 162 | continue WALK 163 | } 164 | } 165 | 166 | panic("conflict with wildcard route") 167 | } 168 | 169 | c := path[0] 170 | 171 | // slash after param 172 | if n.nType == param && c == '/' && len(n.children) == 1 { 173 | n = n.children[0] 174 | n.priority++ 175 | continue WALK 176 | } 177 | 178 | // Check if a child with the next path byte exists 179 | for i, index := range n.indices { 180 | if c == index { 181 | i = n.incrementChildPrio(i) 182 | n = n.children[i] 183 | continue WALK 184 | } 185 | } 186 | 187 | // Otherwise insert it 188 | if c != ':' && c != '*' { 189 | n.indices = append(n.indices, c) 190 | child := &node{ 191 | maxParams: numParams, 192 | } 193 | n.children = append(n.children, child) 194 | n.incrementChildPrio(len(n.indices) - 1) 195 | n = child 196 | } 197 | n.insertChild(numParams, path, usage, handle) 198 | return 199 | 200 | } else if i == len(path) { // Make node a (in-path) leaf 201 | if n.handle != nil { 202 | panic("a Handle is already registered for this path") 203 | } 204 | n.handle = handle 205 | n.usage = usage 206 | } 207 | return 208 | } 209 | } else { // Empty tree 210 | n.insertChild(numParams, path, usage, handle) 211 | } 212 | } 213 | 214 | func (n *node) insertChild(numParams uint8, path string, usage string, handle ContextHandler) { 215 | var offset int // already handled bytes of the path 216 | 217 | // find prefix until first wildcard (beginning with ':'' or '*'') 218 | for i, max := 0, len(path); numParams > 0; i++ { 219 | c := path[i] 220 | if c != ':' && c != '*' { 221 | continue 222 | } 223 | 224 | // check if this Node existing children which would be 225 | // unreachable if we insert the wildcard here 226 | if len(n.children) > 0 { 227 | panic("wildcard route conflicts with existing children") 228 | } 229 | 230 | // find wildcard end (either '/' or path end) 231 | end := i + 1 232 | for end < max && path[end] != '/' { 233 | switch 
path[end] { 234 | // the wildcard name must not contain ':' and '*' 235 | case ':', '*': 236 | panic("only one wildcard per path segment is allowed") 237 | default: 238 | end++ 239 | } 240 | } 241 | 242 | // check if the wildcard has a name 243 | if end-i < 2 { 244 | panic("wildcards must be named with a non-empty name") 245 | } 246 | 247 | if c == ':' { // param 248 | // split path at the beginning of the wildcard 249 | if i > 0 { 250 | n.path = path[offset:i] 251 | offset = i 252 | } 253 | 254 | child := &node{ 255 | nType: param, 256 | maxParams: numParams, 257 | } 258 | n.children = []*node{child} 259 | n.wildChild = true 260 | n = child 261 | n.priority++ 262 | numParams-- 263 | 264 | // if the path doesn't end with the wildcard, then there 265 | // will be another non-wildcard subpath starting with '/' 266 | if end < max { 267 | n.path = path[offset:end] 268 | offset = end 269 | 270 | child := &node{ 271 | maxParams: numParams, 272 | priority: 1, 273 | } 274 | n.children = []*node{child} 275 | n = child 276 | } 277 | 278 | } else { // catchAll 279 | if end != max || numParams > 1 { 280 | panic("catch-all routes are only allowed at the end of the path") 281 | } 282 | 283 | if len(n.path) > 0 && n.path[len(n.path)-1] == '/' { 284 | panic("catch-all conflicts with existing handle for the path segment root") 285 | } 286 | 287 | // currently fixed width 1 for '/' 288 | i-- 289 | if path[i] != '/' { 290 | panic("no / before catch-all") 291 | } 292 | 293 | n.path = path[offset:i] 294 | 295 | // first node: catchAll node with empty path 296 | child := &node{ 297 | wildChild: true, 298 | nType: catchAll, 299 | maxParams: 1, 300 | } 301 | n.children = []*node{child} 302 | n.indices = []byte{path[i]} 303 | n = child 304 | n.priority++ 305 | 306 | // second node: node holding the variable 307 | child = &node{ 308 | path: path[i:], 309 | nType: catchAll, 310 | maxParams: 1, 311 | handle: handle, 312 | usage: usage, 313 | priority: 1, 314 | } 315 | n.children = 
[]*node{child} 316 | 317 | return 318 | } 319 | } 320 | 321 | // insert remaining path part and handle to the leaf 322 | n.path = path[offset:] 323 | n.handle = handle 324 | n.usage = usage 325 | } 326 | 327 | // Returns the handle registered with the given path (key). The values of 328 | // wildcards are saved to a map. 329 | // If no handle can be found, a TSR (trailing slash redirect) recommendation is 330 | // made if a handle exists with an extra (without the) trailing slash for the 331 | // given path. 332 | func (n *node) getValue(path string) (handle ContextHandler, usage string, p routeParams, tsr bool) { 333 | walk: // Outer loop for walking the tree 334 | for { 335 | if len(path) > len(n.path) { 336 | if path[:len(n.path)] == n.path { 337 | path = path[len(n.path):] 338 | // If this node does not have a wildcard (param or catchAll) 339 | // child, we can just look up the next child node and continue 340 | // to walk down the tree 341 | if !n.wildChild { 342 | c := path[0] 343 | for i, index := range n.indices { 344 | if c == index { 345 | n = n.children[i] 346 | continue walk 347 | } 348 | } 349 | 350 | // Nothing found. 351 | // We can recommend to redirect to the same URL without a 352 | // trailing slash if a leaf exists for that path. 353 | tsr = (path == "/" && n.handle != nil) 354 | return 355 | 356 | } 357 | 358 | // handle wildcard child 359 | n = n.children[0] 360 | switch n.nType { 361 | case param: 362 | // find param end (either '/' or path end) 363 | end := 0 364 | for end < len(path) && path[end] != '/' { 365 | end++ 366 | } 367 | 368 | // save param value 369 | if p == nil { 370 | // lazy allocation 371 | p = make(routeParams, 0, n.maxParams) 372 | } 373 | i := len(p) 374 | p = p[:i+1] // expand slice within preallocated capacity 375 | p[i].Key = n.path[1:] 376 | p[i].Value = path[:end] 377 | 378 | // we need to go deeper! 
379 | if end < len(path) { 380 | if len(n.children) > 0 { 381 | path = path[end:] 382 | n = n.children[0] 383 | continue walk 384 | } 385 | 386 | // ... but we can't 387 | tsr = (len(path) == end+1) 388 | return 389 | } 390 | 391 | if handle, usage = n.handle, n.usage; handle != nil { 392 | return 393 | } else if len(n.children) == 1 { 394 | // No handle found. Check if a handle for this path + a 395 | // trailing slash exists for TSR recommendation 396 | n = n.children[0] 397 | tsr = (n.path == "/" && n.handle != nil) 398 | } 399 | 400 | return 401 | 402 | case catchAll: 403 | // save param value 404 | if p == nil { 405 | // lazy allocation 406 | p = make(routeParams, 0, n.maxParams) 407 | } 408 | i := len(p) 409 | p = p[:i+1] // expand slice within preallocated capacity 410 | p[i].Key = n.path[2:] 411 | p[i].Value = path 412 | 413 | handle = n.handle 414 | usage = n.usage 415 | return 416 | 417 | default: 418 | panic("Invalid node type") 419 | } 420 | } 421 | } else if path == n.path { 422 | // We should have reached the node containing the handle. 423 | // Check if this node has a handle registered. 424 | if handle, usage = n.handle, n.usage; handle != nil { 425 | return 426 | } 427 | 428 | // No handle found. Check if a handle for this path + a 429 | // trailing slash exists for trailing slash recommendation 430 | for i, index := range n.indices { 431 | if index == '/' { 432 | n = n.children[i] 433 | tsr = (n.path == "/" && n.handle != nil) || 434 | (n.nType == catchAll && n.children[0].handle != nil) 435 | return 436 | } 437 | } 438 | 439 | return 440 | } 441 | 442 | // Nothing found. 
We can recommend to redirect to the same URL with an 443 | // extra trailing slash if a leaf exists for that path 444 | tsr = (path == "/") || 445 | (len(n.path) == len(path)+1 && n.path[len(path)] == '/' && 446 | path == n.path[:len(n.path)-1] && n.handle != nil) 447 | return 448 | } 449 | } 450 | 451 | // Makes a case-insensitive lookup of the given path and tries to find a handler. 452 | // It can optionally also fix trailing slashes. 453 | // It returns the case-corrected path and a bool indicating wether the lookup 454 | // was successful. 455 | func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) (ciPath []byte, found bool) { 456 | ciPath = make([]byte, 0, len(path)+1) // preallocate enough memory 457 | 458 | // Outer loop for walking the tree 459 | for len(path) >= len(n.path) && strings.ToLower(path[:len(n.path)]) == strings.ToLower(n.path) { 460 | path = path[len(n.path):] 461 | ciPath = append(ciPath, n.path...) 462 | 463 | if len(path) > 0 { 464 | // If this node does not have a wildcard (param or catchAll) child, 465 | // we can just look up the next child node and continue to walk down 466 | // the tree 467 | if !n.wildChild { 468 | r := unicode.ToLower(rune(path[0])) 469 | for i, index := range n.indices { 470 | // must use recursive approach since both index and 471 | // ToLower(index) could exist. We must check both. 472 | if r == unicode.ToLower(rune(index)) { 473 | out, found := n.children[i].findCaseInsensitivePath(path, fixTrailingSlash) 474 | if found { 475 | return append(ciPath, out...), true 476 | } 477 | } 478 | } 479 | 480 | // Nothing found. 
We can recommend to redirect to the same URL 481 | // without a trailing slash if a leaf exists for that path 482 | found = (fixTrailingSlash && path == "/" && n.handle != nil) 483 | return 484 | } 485 | 486 | n = n.children[0] 487 | switch n.nType { 488 | case param: 489 | // find param end (either '/' or path end) 490 | k := 0 491 | for k < len(path) && path[k] != '/' { 492 | k++ 493 | } 494 | 495 | // add param value to case insensitive path 496 | ciPath = append(ciPath, path[:k]...) 497 | 498 | // we need to go deeper! 499 | if k < len(path) { 500 | if len(n.children) > 0 { 501 | path = path[k:] 502 | n = n.children[0] 503 | continue 504 | } 505 | 506 | // ... but we can't 507 | if fixTrailingSlash && len(path) == k+1 { 508 | return ciPath, true 509 | } 510 | return 511 | } 512 | 513 | if n.handle != nil { 514 | return ciPath, true 515 | } else if fixTrailingSlash && len(n.children) == 1 { 516 | // No handle found. Check if a handle for this path + a 517 | // trailing slash exists 518 | n = n.children[0] 519 | if n.path == "/" && n.handle != nil { 520 | return append(ciPath, '/'), true 521 | } 522 | } 523 | return 524 | 525 | case catchAll: 526 | return append(ciPath, path...), true 527 | 528 | default: 529 | panic("Invalid node type") 530 | } 531 | } else { 532 | // We should have reached the node containing the handle. 533 | // Check if this node has a handle registered. 534 | if n.handle != nil { 535 | return ciPath, true 536 | } 537 | 538 | // No handle found. 539 | // Try to fix the path by adding a trailing slash 540 | if fixTrailingSlash { 541 | for i, index := range n.indices { 542 | if index == '/' { 543 | n = n.children[i] 544 | if (n.path == "/" && n.handle != nil) || 545 | (n.nType == catchAll && n.children[0].handle != nil) { 546 | return append(ciPath, '/'), true 547 | } 548 | return 549 | } 550 | } 551 | } 552 | return 553 | } 554 | } 555 | 556 | // Nothing found. 
557 | // Try to fix the path by adding / removing a trailing slash 558 | if fixTrailingSlash { 559 | if path == "/" { 560 | return ciPath, true 561 | } 562 | if len(path)+1 == len(n.path) && n.path[len(path)] == '/' && 563 | strings.ToLower(path) == strings.ToLower(n.path[:len(path)]) && 564 | n.handle != nil { 565 | return append(ciPath, n.path...), true 566 | } 567 | } 568 | return 569 | } 570 | -------------------------------------------------------------------------------- /tree_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (c) 2013 Julien Schmidt. All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | * Redistributions of source code must retain the above copyright 7 | notice, this list of conditions and the following disclaimer. 8 | * Redistributions in binary form must reproduce the above copyright 9 | notice, this list of conditions and the following disclaimer in the 10 | documentation and/or other materials provided with the distribution. 11 | * The names of the contributors may not be used to endorse or promote 12 | products derived from this software without specific prior written 13 | permission. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | DISCLAIMED. 
IN NO EVENT SHALL JULIEN SCHMIDT BE LIABLE FOR ANY 19 | DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | */ 26 | 27 | package siesta 28 | 29 | import ( 30 | "fmt" 31 | "net/http" 32 | "reflect" 33 | "strings" 34 | "testing" 35 | ) 36 | 37 | func printChildren(n *node, prefix string) { 38 | fmt.Printf(" %02d:%02d %s%s[%d] %v %t %d \r\n", n.priority, n.maxParams, prefix, n.path, len(n.children), n.handle, n.wildChild, n.nType) 39 | for l := len(n.path); l > 0; l-- { 40 | prefix += " " 41 | } 42 | for _, child := range n.children { 43 | printChildren(child, prefix) 44 | } 45 | } 46 | 47 | // Used as a workaround since we can't compare functions or their adresses 48 | var fakeHandlerValue string 49 | 50 | func fakeHandler(val string) ContextHandler { 51 | return ToContextHandler(func(http.ResponseWriter, *http.Request) { 52 | fakeHandlerValue = val 53 | }) 54 | } 55 | 56 | type testRequests []struct { 57 | path string 58 | nilHandler bool 59 | route string 60 | ps routeParams 61 | } 62 | 63 | func checkRequests(t *testing.T, tree *node, requests testRequests) { 64 | for _, request := range requests { 65 | handler, _, ps, _ := tree.getValue(request.path) 66 | 67 | if handler == nil { 68 | if !request.nilHandler { 69 | t.Errorf("handle mismatch for route '%s': Expected non-nil handle", request.path) 70 | } 71 | } else if request.nilHandler { 72 | t.Errorf("handle mismatch for route '%s': Expected nil handle", request.path) 73 | } else { 74 | handler(nil, nil, nil, nil) 75 | if fakeHandlerValue != request.route { 76 | t.Errorf("handle mismatch for route 
'%s': Wrong handle (%s != %s)", request.path, fakeHandlerValue, request.route) 77 | } 78 | } 79 | 80 | if !reflect.DeepEqual(ps, request.ps) { 81 | t.Errorf("Params mismatch for route '%s'", request.path) 82 | } 83 | } 84 | } 85 | 86 | func checkPriorities(t *testing.T, n *node) uint32 { 87 | var prio uint32 88 | for i := range n.children { 89 | prio += checkPriorities(t, n.children[i]) 90 | } 91 | 92 | if n.handle != nil { 93 | prio++ 94 | } 95 | 96 | if n.priority != prio { 97 | t.Errorf( 98 | "priority mismatch for node '%s': is %d, should be %d", 99 | n.path, n.priority, prio, 100 | ) 101 | } 102 | 103 | return prio 104 | } 105 | 106 | func checkMaxParams(t *testing.T, n *node) uint8 { 107 | var maxParams uint8 108 | for i := range n.children { 109 | params := checkMaxParams(t, n.children[i]) 110 | if params > maxParams { 111 | maxParams = params 112 | } 113 | } 114 | if n.nType != static && !n.wildChild { 115 | maxParams++ 116 | } 117 | 118 | if n.maxParams != maxParams { 119 | t.Errorf( 120 | "maxParams mismatch for node '%s': is %d, should be %d", 121 | n.path, n.maxParams, maxParams, 122 | ) 123 | } 124 | 125 | return maxParams 126 | } 127 | 128 | func TestCountParams(t *testing.T) { 129 | if countParams("/path/:param1/static/*catch-all") != 2 { 130 | t.Fail() 131 | } 132 | if countParams(strings.Repeat("/:param", 256)) != 255 { 133 | t.Fail() 134 | } 135 | } 136 | 137 | func TestTreeAddAndGet(t *testing.T) { 138 | tree := &node{} 139 | 140 | routes := [...]string{ 141 | "/hi", 142 | "/contact", 143 | "/co", 144 | "/c", 145 | "/a", 146 | "/ab", 147 | "/doc/", 148 | "/doc/go_faq.html", 149 | "/doc/go1.html", 150 | } 151 | for _, route := range routes { 152 | tree.addRoute(route, "", fakeHandler(route)) 153 | } 154 | 155 | checkRequests(t, tree, testRequests{ 156 | {"/a", false, "/a", nil}, 157 | {"/", true, "", nil}, 158 | {"/hi", false, "/hi", nil}, 159 | {"/contact", false, "/contact", nil}, 160 | {"/co", false, "/co", nil}, 161 | {"/con", true, "", nil}, 
// key mismatch 162 | {"/cona", true, "", nil}, // key mismatch 163 | {"/no", true, "", nil}, // no matching child 164 | {"/ab", false, "/ab", nil}, 165 | }) 166 | 167 | checkPriorities(t, tree) 168 | checkMaxParams(t, tree) 169 | } 170 | 171 | func TestTreeWildcard(t *testing.T) { 172 | tree := &node{} 173 | 174 | routes := [...]string{ 175 | "/", 176 | "/cmd/:tool/:sub", 177 | "/cmd/:tool/", 178 | "/src/*filepath", 179 | "/search/", 180 | "/search/:query", 181 | "/user_:name", 182 | "/user_:name/about", 183 | "/files/:dir/*filepath", 184 | "/doc/", 185 | "/doc/go_faq.html", 186 | "/doc/go1.html", 187 | "/info/:user/public", 188 | "/info/:user/project/:project", 189 | } 190 | for _, route := range routes { 191 | tree.addRoute(route, "", fakeHandler(route)) 192 | } 193 | 194 | checkRequests(t, tree, testRequests{ 195 | {"/", false, "/", nil}, 196 | {"/cmd/test/", false, "/cmd/:tool/", routeParams{routeParam{"tool", "test"}}}, 197 | {"/cmd/test", true, "", routeParams{routeParam{"tool", "test"}}}, 198 | {"/cmd/test/3", false, "/cmd/:tool/:sub", routeParams{routeParam{"tool", "test"}, routeParam{"sub", "3"}}}, 199 | {"/src/", false, "/src/*filepath", routeParams{routeParam{"filepath", "/"}}}, 200 | {"/src/some/file.png", false, "/src/*filepath", routeParams{routeParam{"filepath", "/some/file.png"}}}, 201 | {"/search/", false, "/search/", nil}, 202 | {"/search/someth!ng+in+ünìcodé", false, "/search/:query", routeParams{routeParam{"query", "someth!ng+in+ünìcodé"}}}, 203 | {"/search/someth!ng+in+ünìcodé/", true, "", routeParams{routeParam{"query", "someth!ng+in+ünìcodé"}}}, 204 | {"/user_gopher", false, "/user_:name", routeParams{routeParam{"name", "gopher"}}}, 205 | {"/user_gopher/about", false, "/user_:name/about", routeParams{routeParam{"name", "gopher"}}}, 206 | {"/files/js/inc/framework.js", false, "/files/:dir/*filepath", routeParams{routeParam{"dir", "js"}, routeParam{"filepath", "/inc/framework.js"}}}, 207 | {"/info/gordon/public", false, "/info/:user/public", 
routeParams{routeParam{"user", "gordon"}}}, 208 | {"/info/gordon/project/go", false, "/info/:user/project/:project", routeParams{routeParam{"user", "gordon"}, routeParam{"project", "go"}}}, 209 | }) 210 | 211 | checkPriorities(t, tree) 212 | checkMaxParams(t, tree) 213 | } 214 | 215 | func catchPanic(testFunc func()) (recv interface{}) { 216 | defer func() { 217 | recv = recover() 218 | }() 219 | 220 | testFunc() 221 | return 222 | } 223 | 224 | type testRoute struct { 225 | path string 226 | conflict bool 227 | } 228 | 229 | func testRoutes(t *testing.T, routes []testRoute) { 230 | tree := &node{} 231 | 232 | for _, route := range routes { 233 | recv := catchPanic(func() { 234 | tree.addRoute(route.path, "", nil) 235 | }) 236 | 237 | if route.conflict { 238 | if recv == nil { 239 | t.Errorf("no panic for conflicting route '%s'", route.path) 240 | } 241 | } else if recv != nil { 242 | t.Errorf("unexpected panic for route '%s': %v", route.path, recv) 243 | } 244 | } 245 | } 246 | 247 | func TestTreeWildcardConflict(t *testing.T) { 248 | routes := []testRoute{ 249 | {"/cmd/:tool/:sub", false}, 250 | {"/cmd/vet", true}, 251 | {"/src/*filepath", false}, 252 | {"/src/*filepathx", true}, 253 | {"/src/", true}, 254 | {"/src1/", false}, 255 | {"/src1/*filepath", true}, 256 | {"/src2*filepath", true}, 257 | {"/search/:query", false}, 258 | {"/search/invalid", true}, 259 | {"/user_:name", false}, 260 | {"/user_x", true}, 261 | {"/user_:name", false}, 262 | {"/id:id", false}, 263 | {"/id/:id", true}, 264 | } 265 | testRoutes(t, routes) 266 | } 267 | 268 | func TestTreeChildConflict(t *testing.T) { 269 | routes := []testRoute{ 270 | {"/cmd/vet", false}, 271 | {"/cmd/:tool/:sub", true}, 272 | {"/src/AUTHORS", false}, 273 | {"/src/*filepath", true}, 274 | {"/user_x", false}, 275 | {"/user_:name", true}, 276 | {"/id/:id", false}, 277 | {"/id:id", true}, 278 | {"/:id", true}, 279 | {"/*filepath", true}, 280 | } 281 | testRoutes(t, routes) 282 | } 283 | 284 | func 
TestTreeDupliatePath(t *testing.T) { 285 | tree := &node{} 286 | 287 | routes := [...]string{ 288 | "/", 289 | "/doc/", 290 | "/src/*filepath", 291 | "/search/:query", 292 | "/user_:name", 293 | } 294 | for _, route := range routes { 295 | recv := catchPanic(func() { 296 | tree.addRoute(route, "", fakeHandler(route)) 297 | }) 298 | if recv != nil { 299 | t.Fatalf("panic inserting route '%s': %v", route, recv) 300 | } 301 | 302 | // Add again 303 | recv = catchPanic(func() { 304 | tree.addRoute(route, "", nil) 305 | }) 306 | if recv == nil { 307 | t.Fatalf("no panic while inserting duplicate route '%s", route) 308 | } 309 | } 310 | 311 | checkRequests(t, tree, testRequests{ 312 | {"/", false, "/", nil}, 313 | {"/doc/", false, "/doc/", nil}, 314 | {"/src/some/file.png", false, "/src/*filepath", routeParams{routeParam{"filepath", "/some/file.png"}}}, 315 | {"/search/someth!ng+in+ünìcodé", false, "/search/:query", routeParams{routeParam{"query", "someth!ng+in+ünìcodé"}}}, 316 | {"/user_gopher", false, "/user_:name", routeParams{routeParam{"name", "gopher"}}}, 317 | }) 318 | } 319 | 320 | func TestEmptyWildcardName(t *testing.T) { 321 | tree := &node{} 322 | 323 | routes := [...]string{ 324 | "/user:", 325 | "/user:/", 326 | "/cmd/:/", 327 | "/src/*", 328 | } 329 | for _, route := range routes { 330 | recv := catchPanic(func() { 331 | tree.addRoute(route, "", nil) 332 | }) 333 | if recv == nil { 334 | t.Fatalf("no panic while inserting route with empty wildcard name '%s", route) 335 | } 336 | } 337 | } 338 | 339 | func TestTreeCatchAllConflict(t *testing.T) { 340 | routes := []testRoute{ 341 | {"/src/*filepath/x", true}, 342 | {"/src2/", false}, 343 | {"/src2/*filepath/x", true}, 344 | } 345 | testRoutes(t, routes) 346 | } 347 | 348 | func TestTreeCatchAllConflictRoot(t *testing.T) { 349 | routes := []testRoute{ 350 | {"/", false}, 351 | {"/*filepath", true}, 352 | } 353 | testRoutes(t, routes) 354 | } 355 | 356 | func TestTreeDoubleWildcard(t *testing.T) { 357 | const 
panicMsg = "only one wildcard per path segment is allowed"

	// Each route places two wildcards inside a single path segment.
	routes := [...]string{
		"/:foo:bar",
		"/:foo:bar/",
		"/:foo*bar",
	}

	for _, route := range routes {
		tree := &node{}
		recv := catchPanic(func() {
			tree.addRoute(route, "", nil)
		})

		if rs, ok := recv.(string); !ok || rs != panicMsg {
			t.Fatalf(`Expected panic "%s" for route '%s', got "%v"`, panicMsg, route, recv)
		}
	}
}

// TestTreeTrailingSlashRedirect checks the trailing-slash-redirect (TSR)
// recommendation returned by getValue. For every path in tsrRoutes, getValue
// must return no handler and a TSR recommendation. For every path in
// noTsrRoutes, it must return no handler and no recommendation.
func TestTreeTrailingSlashRedirect(t *testing.T) {
	tree := &node{}

	routes := [...]string{
		"/hi",
		"/b/",
		"/search/:query",
		"/cmd/:tool/",
		"/src/*filepath",
		"/x",
		"/x/y",
		"/y/",
		"/y/z",
		"/0/:id",
		"/0/:id/1",
		"/1/:id/",
		"/1/:id/2",
		"/aa",
		"/a/",
		"/doc",
		"/doc/go_faq.html",
		"/doc/go1.html",
		"/no/a",
		"/no/b",
		"/api/hello/:name",
	}
	for _, route := range routes {
		recv := catchPanic(func() {
			tree.addRoute(route, "", fakeHandler(route))
		})
		if recv != nil {
			t.Fatalf("panic inserting route '%s': %v", route, recv)
		}
	}

	// Unregistered paths that differ from a registered route by a trailing
	// slash: no handler, but a TSR recommendation is expected.
	tsrRoutes := [...]string{
		"/hi/",
		"/b",
		"/search/gopher/",
		"/cmd/vet",
		"/src",
		"/x/",
		"/y",
		"/0/go/",
		"/1/go",
		"/a",
		"/doc/",
	}
	for _, route := range tsrRoutes {
		handler, _, _, tsr := tree.getValue(route)
		if handler != nil {
			t.Fatalf("non-nil handler for TSR route '%s'", route)
		} else if !tsr {
			t.Errorf("expected TSR recommendation for route '%s'", route)
		}
	}

	// Unregistered paths for which no TSR recommendation is expected.
	noTsrRoutes := [...]string{
		"/",
		"/no",
		"/no/",
		"/_",
		"/_/",
		"/api/world/abc",
	}
	for _, route := range noTsrRoutes {
		handler, _, _, tsr := tree.getValue(route)
		if handler != nil {
			t.Fatalf("non-nil handler for No-TSR route '%s'", route)
		} else if tsr {
			t.Errorf("expected no TSR recommendation for route '%s'", route)
		}
	}
}

// TestTreeFindCaseInsensitivePath checks findCaseInsensitivePath in three ways:
//   - every registered route resolves to itself;
//   - case variants resolve to the registered spelling, while wildcard
//     segment values keep the caller's case (e.g. "/SEARCH/QUERY" ->
//     "/search/QUERY");
//   - variants that differ only by a trailing slash are found only when
//     fixTrailingSlash is true.
func TestTreeFindCaseInsensitivePath(t *testing.T) {
	tree := &node{}

	routes := [...]string{
		"/hi",
		"/b/",
		"/ABC/",
		"/search/:query",
		"/cmd/:tool/",
		"/src/*filepath",
		"/x",
		"/x/y",
		"/y/",
		"/y/z",
		"/0/:id",
		"/0/:id/1",
		"/1/:id/",
		"/1/:id/2",
		"/aa",
		"/a/",
		"/doc",
		"/doc/go_faq.html",
		"/doc/go1.html",
		"/doc/go/away",
		"/no/a",
		"/no/b",
	}

	for _, route := range routes {
		recv := catchPanic(func() {
			tree.addRoute(route, "", fakeHandler(route))
		})
		if recv != nil {
			t.Fatalf("panic inserting route '%s': %v", route, recv)
		}
	}

	// Check out == in for all registered routes, first with
	// fixTrailingSlash = true, then with fixTrailingSlash = false.
	for _, fixTrailingSlash := range [...]bool{true, false} {
		for _, route := range routes {
			out, found := tree.findCaseInsensitivePath(route, fixTrailingSlash)
			if !found {
				t.Errorf("Route '%s' not found! (fixTrailingSlash=%t)", route, fixTrailingSlash)
			} else if string(out) != route {
				t.Errorf("Wrong result for route '%s': %s (fixTrailingSlash=%t)",
					route, string(out), fixTrailingSlash)
			}
		}
	}

	// slash marks inputs that only match after adding or removing a
	// trailing slash.
	tests := []struct {
		in    string
		out   string
		found bool
		slash bool
	}{
		{"/HI", "/hi", true, false},
		{"/HI/", "/hi", true, true},
		{"/B", "/b/", true, true},
		{"/B/", "/b/", true, false},
		{"/abc", "/ABC/", true, true},
		{"/abc/", "/ABC/", true, false},
		{"/aBc", "/ABC/", true, true},
		{"/aBc/", "/ABC/", true, false},
		{"/abC", "/ABC/", true, true},
		{"/abC/", "/ABC/", true, false},
		{"/SEARCH/QUERY", "/search/QUERY", true, false},
		{"/SEARCH/QUERY/", "/search/QUERY", true, true},
		{"/CMD/TOOL/", "/cmd/TOOL/", true, false},
		{"/CMD/TOOL", "/cmd/TOOL/", true, true},
		{"/SRC/FILE/PATH", "/src/FILE/PATH", true, false},
		{"/x/Y", "/x/y", true, false},
		{"/x/Y/", "/x/y", true, true},
		{"/X/y", "/x/y", true, false},
		{"/X/y/", "/x/y", true, true},
		{"/X/Y", "/x/y", true, false},
		{"/X/Y/", "/x/y", true, true},
		{"/Y/", "/y/", true, false},
		{"/Y", "/y/", true, true},
		{"/Y/z", "/y/z", true, false},
		{"/Y/z/", "/y/z", true, true},
		{"/Y/Z", "/y/z", true, false},
		{"/Y/Z/", "/y/z", true, true},
		{"/y/Z", "/y/z", true, false},
		{"/y/Z/", "/y/z", true, true},
		{"/Aa", "/aa", true, false},
		{"/Aa/", "/aa", true, true},
		{"/AA", "/aa", true, false},
		{"/AA/", "/aa", true, true},
		{"/aA", "/aa", true, false},
		{"/aA/", "/aa", true, true},
		{"/A/", "/a/", true, false},
		{"/A", "/a/", true, true},
		{"/DOC", "/doc", true, false},
		{"/DOC/", "/doc", true, true},
		{"/NO", "", false, true},
		{"/DOC/GO", "", false, true},
	}

	// With fixTrailingSlash = true every case must match the table.
	// Failures are reported and the loop continues (rather than returning),
	// so one bad entry does not hide the rest or skip the second pass.
	for _, test := range tests {
		out, found := tree.findCaseInsensitivePath(test.in, true)
		if found != test.found || (found && (string(out) != test.out)) {
			t.Errorf("Wrong result for '%s': got %s, %t; want %s, %t (fixTrailingSlash=true)",
				test.in, string(out), found, test.out, test.found)
		}
	}

	// With fixTrailingSlash = false, cases that need a trailing-slash fix
	// must not be found; all other cases must match the table.
	for _, test := range tests {
		out, found := tree.findCaseInsensitivePath(test.in, false)
		if test.slash {
			if found { // test needs a trailingSlash fix. It must not be found!
				t.Errorf("Found without fixTrailingSlash: %s; got %s", test.in, string(out))
			}
			continue
		}
		if found != test.found || (found && (string(out) != test.out)) {
			t.Errorf("Wrong result for '%s': got %s, %t; want %s, %t (fixTrailingSlash=false)",
				test.in, string(out), found, test.out, test.found)
		}
	}
}

// TestTreeInvalidNodeType overwrites a child node's nType with an invalid
// value (42). It then checks that both getValue and findCaseInsensitivePath
// panic with the string "Invalid node type".
func TestTreeInvalidNodeType(t *testing.T) {
	const panicMsg = "Invalid node type"

	tree := &node{}
	tree.addRoute("/", "", fakeHandler("/"))
	tree.addRoute("/:page", "", fakeHandler("/:page"))

	// set invalid node type
	tree.children[0].nType = 42

	// expectPanic runs f and fails the test unless it panics with panicMsg.
	expectPanic := func(lookup string, f func()) {
		t.Helper()
		recv := catchPanic(f)
		if rs, ok := recv.(string); !ok || rs != panicMsg {
			t.Fatalf(`%s: expected panic "%s", got "%v"`, lookup, panicMsg, recv)
		}
	}

	expectPanic("normal lookup", func() {
		tree.getValue("/test")
	})
	expectPanic("case-insensitive lookup", func() {
		tree.findCaseInsensitivePath("/test", true)
	})
}
--------------------------------------------------------------------------------