├── .gitignore ├── .travis.yml ├── CODEOWNERS ├── LICENSE ├── Makefile ├── README.md ├── codecov.yml ├── debug.test ├── doc.go ├── example └── rye_example.go ├── example_overall_test.go ├── fakes └── statsdfakes │ └── fake_statter.go ├── images ├── Rye Logo.sketch ├── rye-gopher.svg ├── rye_logo.svg └── rye_logo_invision_pink.png ├── in-repo.yaml ├── middleware_accesstoken.go ├── middleware_accesstoken_test.go ├── middleware_auth.go ├── middleware_auth_test.go ├── middleware_cidr.go ├── middleware_cidr_test.go ├── middleware_cors.go ├── middleware_cors_test.go ├── middleware_getheader.go ├── middleware_getheader_test.go ├── middleware_jwt.go ├── middleware_jwt_test.go ├── middleware_routelogger.go ├── middleware_routelogger_test.go ├── middleware_static_file.go ├── middleware_static_file_test.go ├── middleware_static_filesystem.go ├── middleware_static_filesystem_test.go ├── rye.go ├── rye_suite_test.go ├── rye_test.go └── static-examples ├── dist ├── index.html ├── styles │ ├── index.css │ └── test.css └── test.html └── static_example.go /.gitignore: -------------------------------------------------------------------------------- 1 | launch.sh 2 | env-api 3 | main 4 | .tmp 5 | codeship.aes 6 | .vscode/ 7 | .idea/ 8 | debug 9 | debug.test 10 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | 3 | go: 4 | - 1.7 5 | 6 | before_install: 7 | - go get -t -v ./... 
8 | 9 | script: 10 | - make test/codecov 11 | 12 | after_success: 13 | - bash <(curl -s https://codecov.io/bash) -t 646f78c9-a2ed-4fec-8779-c94358d08b24 14 | -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @InVisionApp/architecture -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016 InVision 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Some things this makefile could make use of: 2 | # 3 | # - test coverage target(s) 4 | # - profiler target(s) 5 | # 6 | 7 | BIN = rye 8 | OUTPUT_DIR = build 9 | TMP_DIR := .tmp 10 | RELEASE_VER := $(shell git rev-parse --short HEAD) 11 | NAME = default 12 | COVERMODE = atomic 13 | 14 | TEST_PACKAGES := $(shell go list ./... | grep -v vendor | grep -v fakes) 15 | 16 | .PHONY: help 17 | .DEFAULT_GOAL := help 18 | 19 | all: test ## Run all tests (alias for test; no build or docker targets are defined) 20 | 21 | setup: installtools ## Install and setup tools and local DB 22 | 23 | # Under the hood, `go test -tags ...` also runs the "default" (unit) test case 24 | # in addition to the specified tags 25 | test: installdeps test/integration ## Perform both unit and integration tests 26 | 27 | testv: installdeps testv/integration ## Perform both unit and integration tests (with verbose flags) 28 | 29 | test/unit: ## Perform unit tests 30 | go test -cover $(TEST_PACKAGES) 31 | 32 | testv/unit: ## Perform unit tests (with verbose flag) 33 | go test -v -cover $(TEST_PACKAGES) 34 | 35 | test/integration: ## Perform integration tests 36 | go test -cover -tags integration $(TEST_PACKAGES) 37 | 38 | testv/integration: ## Perform integration tests (with verbose flag) 39 | go test -v -cover -tags integration $(TEST_PACKAGES) 40 | 41 | test/race: ## Perform unit tests and enable the race detector 42 | go test -race -cover $(TEST_PACKAGES) 43 | 44 | test/cover: ## Run all tests + open coverage report for all packages 45 | echo 'mode: $(COVERMODE)' > .coverage 46 | for PKG in $(TEST_PACKAGES); do \ 47 | go test -covermode=$(COVERMODE) -coverprofile=.coverage.tmp -tags "integration" $$PKG; \ 48 | if [ -f .coverage.tmp ]; then grep -v -E '^mode:' .coverage.tmp >> .coverage; rm .coverage.tmp; fi; \ 49 | done 50 | go tool cover -html=.coverage 51 | $(RM) .coverage .coverage.tmp 52 | 53 | test/codecov: ## Run all tests + write a merged coverage
profile to coverage.txt (for Codecov) 54 | $(RM) coverage.txt; for PKG in $(TEST_PACKAGES); do \ 55 | go test -covermode=$(COVERMODE) -coverprofile=profile.out $$PKG; \ 56 | if [ -f profile.out ]; then\ 57 | cat profile.out >> coverage.txt;\ 58 | rm profile.out;\ 59 | fi;\ 60 | done 61 | $(RM) profile.out 62 | 63 | installdeps: ## Install needed dependencies for various middlewares 64 | go get github.com/dgrijalva/jwt-go 65 | 66 | installtools: ## Install development related tools 67 | go get github.com/kardianos/govendor 68 | go get github.com/maxbrunsfeld/counterfeiter 69 | 70 | generate: ## Run generate for non-vendor packages only 71 | go list ./... | grep -v vendor | xargs go generate 72 | 73 | help: ## Display this help message 74 | @awk 'BEGIN {FS = ":.*?## "} /^[a-zA-Z_\/-]+:.*?## / {printf "\033[34m%-30s\033[0m %s\n", $$1, $$2}' $(MAKEFILE_LIST) | \ 75 | sort | \ 76 | grep -v '#' 77 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | | :warning: This project is no longer actively supported. 2 | | --- 3 | 4 | [![LICENSE](https://img.shields.io/badge/license-MIT-orange.svg)](LICENSE) 5 | [![Golang](https://img.shields.io/badge/Golang-v1.7-blue.svg)](https://golang.org/dl/) 6 | [![Godocs](https://img.shields.io/badge/golang-documentation-blue.svg)](https://godoc.org/github.com/InVisionApp/rye) 7 | [![Go Report Card](https://goreportcard.com/badge/github.com/InVisionApp/rye)](https://goreportcard.com/report/github.com/InVisionApp/rye) 8 | [![Travis Build Status](https://travis-ci.com/InVisionApp/rye.svg?token=qgpSBc6cjHgbnjqC45af&branch=master)](https://travis-ci.com/InVisionApp/rye) 9 | [![codecov](https://codecov.io/gh/InVisionApp/rye/branch/master/graph/badge.svg?token=hhqA1l88kx)](https://codecov.io/gh/InVisionApp/rye) 10 | 11 | 12 | 13 | 14 | # rye 15 | A simple library to support http services.
Currently, **rye** provides a middleware handler which can be used to chain http handlers together while providing statsd metrics for use with DataDog or other metrics aggregators. In addition, **rye** comes with various pre-built middleware handlers for enabling functionality such as CORS, access token/auth validation and CIDR-based IP whitelisting. 16 | 17 | ## Setup 18 | In order to use **rye**, you should vendor it and the **statsd** client within your project. 19 | 20 | ```sh 21 | govendor fetch github.com/InVisionApp/rye 22 | govendor fetch github.com/cactus/go-statsd-client/statsd 23 | ``` 24 | 25 | ## Why another middleware lib? 26 | 27 | * `rye` is *tiny* - the core lib is ~143 lines of code (including comments)! 28 | * Each middleware gets statsd metrics tracking for free including an overall error counter 29 | * We wanted to have an easy way to say “run these two middlewares on this endpoint, but only one middleware on this endpoint” 30 | * Of course, this is doable with negroni and gorilla-mux, but you’d have to use a subrouter with gorilla, which tends to end up in more code 31 | * Bundled helper methods for standardising JSON response messages 32 | * Unified way for handlers and middlewares to return more detailed responses via the `rye.Response` struct (if they choose to do so). 33 | * Pre-built middlewares for things like CORS support 34 | 35 | ## Example 36 | 37 | You can run an example locally to give it a try. The code for the example is [here](example/rye_example.go)!
38 | 39 | ```sh 40 | cd example 41 | go run rye_example.go 42 | ``` 43 | 44 | ## Writing custom middleware handlers 45 | 46 | Begin by importing the required libraries: 47 | 48 | ```go 49 | import ( 50 | "github.com/cactus/go-statsd-client/statsd" 51 | "github.com/InVisionApp/rye" 52 | ) 53 | ``` 54 | 55 | Create a statsd client (if desired) and create a rye Config in order to pass in optional dependencies: 56 | 57 | ```go 58 | config := rye.Config{ 59 | Statter: statsdClient, 60 | StatRate: DEFAULT_STATSD_RATE, 61 | } 62 | ``` 63 | 64 | Create a middleware handler. The purpose of the Handler is to keep Config and to provide an interface for chaining http handlers. 65 | 66 | ```go 67 | middlewareHandler := rye.NewMWHandler(config) 68 | ``` 69 | 70 | Set up any global handlers by using the `Use()` method. Global handlers get prepended to the list of your handlers for EVERY endpoint. 71 | They are bound to the MWHandler struct. Therefore, you could set up multiple MWHandler structs if you want to have different collections 72 | of global handlers. 73 | 74 | ```go 75 | middlewareHandler.Use(rye.MiddlewareRouteLogger()) 76 | ``` 77 | 78 | Build your http handlers using the Handler type from the **rye** package.
79 | 80 | ```go 81 | type Handler func(w http.ResponseWriter, r *http.Request) *rye.Response 82 | ``` 83 | 84 | Here are some example (custom) handlers: 85 | 86 | ```go 87 | func homeHandler(rw http.ResponseWriter, r *http.Request) *rye.Response { 88 | fmt.Fprint(rw, "Refer to README.md for auth-api API usage") 89 | return nil 90 | } 91 | 92 | func middlewareFirstHandler(rw http.ResponseWriter, r *http.Request) *rye.Response { 93 | fmt.Fprint(rw, "This handler fires first.") 94 | return nil 95 | } 96 | 97 | func errorHandler(rw http.ResponseWriter, r *http.Request) *rye.Response { 98 | return &rye.Response{ 99 | StatusCode: http.StatusInternalServerError, 100 | Err: errors.New("something went wrong"), 101 | } 102 | } 103 | ``` 104 | 105 | Finally, to set up your handlers in your API (Example shown using [Gorilla](https://github.com/gorilla/mux)): 106 | ```go 107 | routes := mux.NewRouter().StrictSlash(true) 108 | 109 | routes.Handle("/", middlewareHandler.Handle([]rye.Handler{ 110 | a.middlewareFirstHandler, 111 | a.homeHandler, 112 | })).Methods("GET") 113 | 114 | log.Infof("API server listening on %v", ListenAddress) 115 | 116 | srv := &http.Server{ 117 | Addr: ListenAddress, 118 | Handler: routes, 119 | } 120 | 121 | log.Fatal(srv.ListenAndServe()) 122 | ``` 123 | 124 | ## Statsd Generated by Rye 125 | 126 | Rye comes with built-in configurable `statsd` statistics that you could record to your favorite monitoring system. To configure that, you'll need to set up a `Statter` based on the `github.com/cactus/go-statsd-client` and set it in your instantiation of `MWHandler` through the `rye.Config`. 127 | 128 | When a middleware is called, its timing is recorded, and a counter associated with the http status code returned during the call is incremented. Additionally, an `errors` counter is also sent to the statter, which allows you to count any errors that occur with a status code of 500 or above.
129 | 130 | Example: If you have a middleware handler you've created with a method named `loginHandler`, successful calls to that will be recorded to `handlers.loginHandler.2xx`. Additionally you'll receive stats such as `handlers.loginHandler.400` or `handlers.loginHandler.500`. A `500` also increments the `errors` count. 131 | 132 | _If you're sending your stats into a system such as DataDog, be aware that your stats from Rye can have prefixes such as `statsd.my-service.my-k8s-cluster.handlers.loginHandler.2xx` or even `statsd.my-service.my-k8s-cluster.errors`. Just keep in mind your stats could end up in the destination sink system with prefixes._ 133 | 134 | ## Using with Golang 1.7 Context 135 | 136 | With Golang 1.7, a new feature has been added that supports a request-specific context. This is a great feature that Rye supports out-of-the-box. The tricky part of this is how the context is modified on the request. In Golang, the Context is always available on a Request through `http.Request.Context()`. Great! However, if you want to add key/value pairs to the context, you will have to add the context to the request before it gets passed to the next Middleware. To support this, the `rye.Response` has a property called `Context`. This property takes a properly created context (derived from `request.Context()`). When you return a `rye.Response` which has `Context`, the **rye** library will craft a new Request and make sure that the next middleware receives that request. 137 | 138 | Here are the details of creating a middleware with a proper `Context`. You must first pull from the current request `Context`. In the example below, you see `ctx := r.Context()`. That pulls the current context. Then, you create a NEW context with your additional context key/value.
Finally, you return `&rye.Response{Context:ctx}` 139 | 140 | ```go 141 | func addContextVar(rw http.ResponseWriter, r *http.Request) *rye.Response { 142 | // Retrieve the request's context 143 | ctx := r.Context() 144 | 145 | // Create a NEW context 146 | ctx = context.WithValue(ctx,"CONTEXT_KEY","my context value") 147 | 148 | // Return that in the Rye response 149 | // Rye will add it to the Request to 150 | // pass to the next middleware 151 | return &rye.Response{Context:ctx} 152 | } 153 | ``` 154 | Now in a later middleware, you can easily retrieve the value you set! 155 | ```go 156 | func getContextVar(rw http.ResponseWriter, r *http.Request) *rye.Response { 157 | // Retrieving the value is easy! 158 | myVal := r.Context().Value("CONTEXT_KEY") 159 | 160 | // Log it to the server log? 161 | log.Infof("Context Value: %v", myVal) 162 | 163 | return nil 164 | } 165 | ``` 166 | For another simple example, look in the [JWT middleware](middleware_jwt.go) - it adds the JWT into the context for use by other middlewares. It uses the `CONTEXT_JWT` key to push the JWT token into the `Context`. 167 | 168 | ## Using built-in middleware handlers 169 | 170 | Rye comes with various pre-built middleware handlers. Pre-built middlewares source (and docs) can be found in the package dir following the pattern `middleware_*.go`. 
171 | 172 | To use them, specify the constructor of the middleware as one of the middleware handlers when you define your routes: 173 | 174 | ```go 175 | // example 176 | routes.Handle("/", middlewareHandler.Handle([]rye.Handler{ 177 | rye.MiddlewareCORS(), // to use the CORS middleware (with defaults) 178 | a.homeHandler, 179 | })).Methods("GET") 180 | 181 | OR 182 | 183 | routes.Handle("/", middlewareHandler.Handle([]rye.Handler{ 184 | rye.NewMiddlewareCORS("*", "GET, POST", "X-Access-Token"), // to use specific config when instantiating the middleware handler 185 | a.homeHandler, 186 | })).Methods("GET") 187 | 188 | ``` 189 | 190 | ## Serving Static Files 191 | 192 | Rye has the ability to add serving static files in the chain. Two handlers 193 | have been provided: `StaticFilesystem` and `StaticFile`. These middlewares 194 | should always be used at the end of the chain. Their configuration is 195 | simply based on an absolute path on the server and possibly a skipped 196 | path prefix. 197 | 198 | The use case here could be a powerful one. Rye allows you to serve a filesystem 199 | just as a whole or a single file. Used together you could facilitate an application 200 | which does both -> fulfilling the capability to provide a single page application. 201 | For example, if you had a webpack application which served static resources and 202 | artifacts, you would use the `StaticFilesystem` to serve those. Then you'd use 203 | `StaticFile` to serve the single page which refers to the single-page application 204 | through `index.html`. 205 | 206 | A full sample is provided in the `static-examples` folder. 
Here's a snippet from 207 | the example using Gorilla: 208 | 209 | ```go 210 | pwd, err := os.Getwd() 211 | if err != nil { 212 | log.Fatalf("NewStaticFile: Could not get working directory.") 213 | } 214 | 215 | routes.PathPrefix("/dist/").Handler(middlewareHandler.Handle([]rye.Handler{ 216 | rye.MiddlewareRouteLogger(), 217 | rye.NewStaticFilesystem(pwd+"/dist/", "/dist/"), 218 | })) 219 | 220 | routes.PathPrefix("/ui/").Handler(middlewareHandler.Handle([]rye.Handler{ 221 | rye.MiddlewareRouteLogger(), 222 | rye.NewStaticFile(pwd + "/dist/index.html"), 223 | })) 224 | ``` 225 | 226 | ### Middleware list 227 | 228 | | Name | Description | 229 | |----------------------------|---------------------------------------| 230 | | [Access Token](middleware_accesstoken.go) | Provide Access Token validation | 231 | | [CIDR](middleware_cidr.go) | Provide request IP whitelisting | 232 | | [CORS](middleware_cors.go) | Provide CORS functionality for routes | 233 | | [Auth](middleware_auth.go) | Provide Authorization header validation (basic auth, JWT) | 234 | | [Route Logger](middleware_routelogger.go) | Provide basic logging for a specific route | 235 | | [Static File](middleware_static_file.go) | Provides serving a single file | 236 | | [Static Filesystem](middleware_static_filesystem.go) | Provides serving a directory tree (filesystem) | 237 | 238 | 239 | ### A Note on the JWT Middleware 240 | 241 | The [JWT Middleware](middleware_auth.go) pushes the JWT token onto the Context for use by other middlewares in the chain. This is a convenience that allows any part of your middleware chain quick access to the JWT. Example usage might include a middleware that needs access to your user id or email address stored in the JWT. To access this `Context` variable, the code is very simple: 242 | ```go 243 | func getJWTfromContext(rw http.ResponseWriter, r *http.Request) *rye.Response { 244 | // Retrieving the value is easy!
245 | // Just reference the rye.CONTEXT_JWT const as a key 246 | myVal := r.Context().Value(rye.CONTEXT_JWT) 247 | 248 | // Log it to the server log? 249 | log.Infof("Context Value: %v", myVal) 250 | 251 | return nil 252 | } 253 | ``` 254 | 255 | ## API 256 | 257 | ### Config 258 | This struct is configuration for the MWHandler. It holds references and config to dependencies such as the statsdClient. 259 | ```go 260 | type Config struct { 261 | Statter statsd.Statter 262 | StatRate float32 263 | } 264 | ``` 265 | 266 | ### MWHandler 267 | This struct is the primary handler container. It holds the `Config`, which carries the statsd client. 268 | ```go 269 | type MWHandler struct { 270 | Config Config 271 | } 272 | ``` 273 | 274 | #### Constructor 275 | ```go 276 | func NewMWHandler(config Config) *MWHandler 277 | ``` 278 | 279 | #### Use 280 | This method prepends a global handler for every Handle method you call. 281 | Use this multiple times to set up global handlers for every endpoint. 282 | Call `Use()` for each global handler before setting up additional routes. 283 | ```go 284 | func (m *MWHandler) Use(handlers Handler) 285 | ``` 286 | 287 | #### Handle 288 | This method chains middleware handlers in order and returns a complete `http.Handler`. 289 | ```go 290 | func (m *MWHandler) Handle(handlers []Handler) http.Handler 291 | ``` 292 | 293 | ### rye.Response 294 | This struct is utilized by middlewares as a way to share state; ie.
295 | ```go 296 | type Response struct { 297 | Err error 298 | StatusCode int 299 | StopExecution bool 300 | Context context.Context 301 | } 302 | ``` 303 | 304 | ### Handler 305 | This type is used to define an http handler that can be chained using the MWHandler.Handle method. The `rye.Response` is from the **rye** package and has facilities to emit StatusCode, bubble up errors and/or stop further middleware execution chain. 306 | ```go 307 | type Handler func(w http.ResponseWriter, r *http.Request) *rye.Response 308 | ``` 309 | 310 | ## Test stuff 311 | All interfacing with the project is done via `make`. Targets exist for all primary tasks such as: 312 | 313 | - Testing: `make test` or `make testv` (for verbosity) 314 | - Generate: `make generate` - this generates based on vendored libraries (from $GOPATH) 315 | - All (test, build): `make all` 316 | - .. and a few others. Run `make help` to see all available targets. 317 | - You can also test the project in Docker (and Codeship) by running `jet steps` 318 | 319 | ## Contributing 320 | Fork the repository, write a PR and we'll consider it! 321 | 322 | ## Special Thanks 323 | Thanks go out to Justin Reyna (InVisionApp.com) for the awesome logo! 
324 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | codecov: 2 | notify: 3 | require_ci_to_pass: true 4 | comment: 5 | layout: "header, diff, tree" 6 | require_changes: false 7 | branches: null 8 | behavior: default 9 | flags: null 10 | paths: null 11 | coverage: 12 | precision: 2 13 | range: 14 | - 70.0 15 | - 100.0 16 | round: down 17 | status: 18 | changes: false 19 | patch: true 20 | project: true 21 | -------------------------------------------------------------------------------- /debug.test: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InVisionApp/rye/c260759a2358155e882164fa8cc9f3b074dcef4a/debug.test -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | /* 2 | Package rye is a simple library to support http services. 3 | Rye provides a middleware handler which can be used to chain http handlers together while providing 4 | simple statsd metrics for use with a monitoring solution such as DataDog or other logging aggregators. 5 | Rye also provides some additional middleware handlers that are entirely optional but easily consumed using Rye. 6 | 7 | Setup 8 | 9 | In order to use rye, you should vendor it and the statsd client within your project.
10 | 11 | govendor fetch github.com/cactus/go-statsd-client/statsd 12 | govendor fetch github.com/InVisionApp/rye 13 | 14 | # Alternatively, clone rye into your GOPATH and add it from there: 15 | mkdir -p $GOPATH/src/github.com/InVisionApp && cd $GOPATH/src/github.com/InVisionApp 16 | git clone git@github.com:InVisionApp/rye.git 17 | 18 | govendor add github.com/InVisionApp/rye 19 | 20 | Writing custom middleware handlers 21 | 22 | Begin by importing the required libraries: 23 | import ( 24 | "github.com/cactus/go-statsd-client/statsd" 25 | "github.com/InVisionApp/rye" 26 | ) 27 | Create a statsd client (if desired) and create a rye Config in order to pass in optional dependencies: 28 | config := rye.Config{ 29 | Statter: statsdClient, 30 | StatRate: DEFAULT_STATSD_RATE, 31 | } 32 | Create a middleware handler. The purpose of the Handler is to keep Config and to provide an interface for chaining http handlers. 33 | middlewareHandler := rye.NewMWHandler(config) 34 | Build your http handlers using the Handler type from the **rye** package. 35 | type Handler func(w http.ResponseWriter, r *http.Request) *rye.Response 36 | Here are some example (custom) handlers: 37 | func homeHandler(rw http.ResponseWriter, r *http.Request) *rye.Response { 38 | fmt.Fprint(rw, "Refer to README.md for auth-api API usage") 39 | return nil 40 | } 41 | 42 | func middlewareFirstHandler(rw http.ResponseWriter, r *http.Request) *rye.Response { 43 | fmt.Fprint(rw, "This handler fires first.") 44 | return nil 45 | } 46 | 47 | func errorHandler(rw http.ResponseWriter, r *http.Request) *rye.Response { 48 | return &rye.Response { 49 | StatusCode: http.StatusInternalServerError, 50 | Err: errors.New("something went wrong"), 51 | } 52 | } 53 | Finally, to set up your handlers in your API 54 | routes := mux.NewRouter().StrictSlash(true) 55 | 56 | routes.Handle("/", middlewareHandler.Handle( 57 | []rye.Handler{ 58 | a.middlewareFirstHandler, 59 | a.homeHandler, 60 | })).Methods("GET") 61 | 62 | log.Infof("API server listening on %v", ListenAddress) 63 | 64 | srv := &http.Server{ 65
| Addr: ListenAddress, 66 | Handler: routes, 67 | } 68 | 69 | log.Fatal(srv.ListenAndServe()) 70 | 71 | Statsd Generated by Rye 72 | 73 | Rye comes with built-in configurable `statsd` statistics that you could record to your favorite monitoring system. To configure that, you'll need to set up a `Statter` based on the `github.com/cactus/go-statsd-client` and set it in your instantiation of `MWHandler` through the `rye.Config`. 74 | 75 | When a middleware is called, its timing is recorded, and a counter associated with the http status code returned during the call is incremented. Additionally, an `errors` counter is also sent to the statter, which allows you to count any errors that occur with a status code of 500 or above. 76 | 77 | Example: If you have a middleware handler you've created with a method named `loginHandler`, successful calls to that will be recorded to `handlers.loginHandler.2xx`. Additionally you'll receive stats such as `handlers.loginHandler.400` or `handlers.loginHandler.500`. A `500` also increments the `errors` count. 78 | 79 | If you're sending your stats into a system such as DataDog, be aware that your stats from Rye can have prefixes such as `statsd.my-service.my-k8s-cluster.handlers.loginHandler.2xx` or even `statsd.my-service.my-k8s-cluster.errors`. Just keep in mind your stats could end up in the destination sink system with prefixes. 80 | 81 | 82 | Using With Golang Context 83 | 84 | With Golang 1.7, a new feature has been added that supports a request-specific context. 85 | This is a great feature that Rye supports out-of-the-box. The tricky part of this is how the context 86 | is modified on the request. In Golang, the Context is always available on a 87 | Request through `http.Request.Context()`. Great! However, if you want to add key/value pairs to the 88 | context, you will have to add the context to the request before it gets passed to the next Middleware.
89 | To support this, the `rye.Response` has a property called `Context`. This property takes a properly 90 | created context (derived from `request.Context()`). When you return a `rye.Response` 91 | which has `Context`, the **rye** library will craft a new Request and make sure that the next middleware 92 | receives that request. 93 | 94 | Here are the details of creating a middleware with a proper `Context`. You must first pull from the 95 | current request `Context`. In the example below, you see `ctx := r.Context()`. That pulls the current 96 | context. Then, you create a NEW context with your additional context key/value. Finally, you 97 | return `&rye.Response{Context:ctx}` 98 | 99 | func addContextVar(rw http.ResponseWriter, r *http.Request) *rye.Response { 100 | // Retrieve the request's context 101 | ctx := r.Context() 102 | 103 | // Create a NEW context 104 | ctx = context.WithValue(ctx,"CONTEXT_KEY","my context value") 105 | 106 | // Return that in the Rye response 107 | // Rye will add it to the Request to 108 | // pass to the next middleware 109 | return &rye.Response{Context:ctx} 110 | } 111 | 112 | Now in a later middleware, you can easily retrieve the value you set! 113 | 114 | func getContextVar(rw http.ResponseWriter, r *http.Request) *rye.Response { 115 | // Retrieving the value is easy! 116 | myVal := r.Context().Value("CONTEXT_KEY") 117 | 118 | // Log it to the server log? 119 | log.Infof("Context Value: %v", myVal) 120 | 121 | return nil 122 | } 123 | 124 | For another simple example, look in the JWT middleware - it adds the JWT into the 125 | context for use by other middlewares. It uses the `CONTEXT_JWT` key to push the 126 | JWT token into the `Context`. 127 | 128 | 129 | Using built-in middleware handlers 130 | 131 | Rye comes with various pre-built middleware handlers. Pre-built middlewares source (and docs) can be found in the package dir following the pattern `middleware_*.go`.
132 | 133 | To use them, specify the constructor of the middleware as one of the middleware handlers when you define your routes: 134 | // example 135 | routes.Handle("/", middlewareHandler.Handle( 136 | []rye.Handler{ 137 | rye.MiddlewareCORS(), // to use the CORS middleware (with defaults) 138 | a.homeHandler, 139 | })).Methods("GET") 140 | OR 141 | routes.Handle("/", middlewareHandler.Handle( 142 | []rye.Handler{ 143 | rye.NewMiddlewareCORS("*", "GET, POST", "X-Access-Token"), // to use specific config when instantiating the middleware handler 144 | a.homeHandler, 145 | })).Methods("GET") 146 | 147 | A Note on the JWT Middleware 148 | 149 | The JWT Middleware pushes the JWT token onto the Context for use by other middlewares in the chain. 150 | This is a convenience that allows any part of your middleware chain quick access to the JWT. 151 | Example usage might include a middleware that needs access to your user id or email address stored in the JWT. 152 | To access this `Context` variable, the code is very simple: 153 | 154 | func getJWTfromContext(rw http.ResponseWriter, r *http.Request) *rye.Response { 155 | // Retrieving the value is easy! 156 | // Just reference the rye.CONTEXT_JWT const as a key 157 | myVal := r.Context().Value(rye.CONTEXT_JWT) 158 | 159 | // Log it to the server log? 
160 | log.Infof("Context Value: %v", myVal) 161 | 162 | return nil 163 | } 164 | */ 165 | package rye 166 | -------------------------------------------------------------------------------- /example/rye_example.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "net/http" 8 | 9 | "github.com/InVisionApp/rye" 10 | "github.com/cactus/go-statsd-client/statsd" 11 | "github.com/gorilla/mux" 12 | log "github.com/sirupsen/logrus" 13 | ) 14 | 15 | func main() { 16 | statsdClient, err := statsd.NewBufferedClient("localhost:12345", "my_service", 1.0, 0) 17 | if err != nil { 18 | log.Fatalf("Unable to instantiate statsd client: %v", err.Error()) 19 | } 20 | 21 | config := rye.Config{ 22 | Statter: statsdClient, 23 | StatRate: 1.0, 24 | } 25 | 26 | middlewareHandler := rye.NewMWHandler(config) 27 | 28 | middlewareHandler.Use(beforeAllHandler) 29 | 30 | routes := mux.NewRouter().StrictSlash(true) 31 | 32 | routes.Handle("/", middlewareHandler.Handle([]rye.Handler{ 33 | middlewareFirstHandler, 34 | homeHandler, 35 | })).Methods("GET") 36 | 37 | // If you perform a `curl -i http://localhost:8181/cors -H "Origin: *.foo.com"` 38 | // you will see that the CORS middleware is adding required headers 39 | routes.Handle("/cors", middlewareHandler.Handle([]rye.Handler{ 40 | rye.MiddlewareCORS(), 41 | homeHandler, 42 | })).Methods("GET", "OPTIONS") 43 | 44 | // If you perform a `curl -i http://localhost:8181/basic-auth \ 45 | // -H "Authorization: Basic dXNlcjE6cGFzczE="` (base64 of "user1:pass1", no trailing newline) 46 | // you will see that we are allowed through to the handler; if the header is changed, you will get a 401 47 | routes.Handle("/basic-auth", middlewareHandler.Handle([]rye.Handler{ 48 | rye.NewMiddlewareAuth(rye.NewBasicAuthFunc(map[string]string{ 49 | "user1": "pass1", 50 | "user2": "pass2", 51 | })), 52 | getJwtFromContextHandler, 53 | })).Methods("GET") 54 | 55 | // If you perform a `curl -i
http://localhost:8181/jwt \ 56 | // -H "Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ"` 57 | // you will see that we are allowed through to the handler; if the sample token is changed, you will get a 401 58 | routes.Handle("/jwt", middlewareHandler.Handle([]rye.Handler{ 59 | rye.NewMiddlewareAuth(rye.NewJWTAuthFunc("secret")), 60 | getJwtFromContextHandler, 61 | })).Methods("GET") 62 | 63 | routes.Handle("/error", middlewareHandler.Handle([]rye.Handler{ 64 | middlewareFirstHandler, 65 | errorHandler, 66 | homeHandler, 67 | })).Methods("GET") 68 | 69 | // In order to pass in a context variable, this set of 70 | // handlers works with "ctx" on the query string 71 | routes.Handle("/context", middlewareHandler.Handle( 72 | []rye.Handler{ 73 | stashContextHandler, 74 | logContextHandler, 75 | })).Methods("GET") 76 | 77 | log.Infof("API server listening on %v", "localhost:8181") 78 | 79 | srv := &http.Server{ 80 | Addr: "localhost:8181", 81 | Handler: routes, 82 | } 83 | 84 | log.Fatal(srv.ListenAndServe()) // ListenAndServe always returns a non-nil error 85 | } 86 | 87 | func beforeAllHandler(rw http.ResponseWriter, r *http.Request) *rye.Response { 88 | log.Infof("This handler is called before every endpoint: %+v", r) 89 | return nil 90 | } 91 | 92 | func homeHandler(rw http.ResponseWriter, r *http.Request) *rye.Response { 93 | log.Infof("Home handler has fired!") 94 | 95 | fmt.Fprint(rw, "This is the home handler") 96 | return nil 97 | } 98 | 99 | func middlewareFirstHandler(rw http.ResponseWriter, r *http.Request) *rye.Response { 100 | log.Infof("Middleware handler has fired!") 101 | return nil 102 | } 103 | 104 | func errorHandler(rw http.ResponseWriter, r *http.Request) *rye.Response { 105 | log.Infof("Error handler has fired!") 106 | 107 | message := "This is the error handler" 108 | 109 | return &rye.Response{ 110 | StatusCode: http.StatusInternalServerError, 111 | Err: errors.New(message), 112 | }
113 | } 114 | 115 | func stashContextHandler(rw http.ResponseWriter, r *http.Request) *rye.Response { 116 | log.Infof("Stash Context handler has fired!") 117 | 118 | // Retrieve the request's context 119 | ctx := r.Context() 120 | 121 | // A query string value to add to the context 122 | toContext := r.URL.Query().Get("ctx") 123 | 124 | if toContext != "" { 125 | log.Infof("Adding `query-string-ctx` to request.Context(). Val: %v", toContext) 126 | } else { 127 | log.Infof("Adding default `query-string-ctx` value to context") 128 | toContext = "No value added. Add querystring param `ctx` with a value to get it mirrored through context." 129 | } 130 | 131 | // Create a NEW context 132 | ctx = context.WithValue(ctx, "query-string-ctx", toContext) 133 | 134 | // Return that in the Rye response 135 | // Rye will add it to the Request to 136 | // pass to the next middleware 137 | return &rye.Response{Context: ctx} 138 | } 139 | 140 | func logContextHandler(rw http.ResponseWriter, r *http.Request) *rye.Response { 141 | log.Infof("Log Context handler has fired!") 142 | 143 | // Retrieving a context value is EASY in subsequent middlewares 144 | fromContext := r.Context().Value("query-string-ctx") 145 | 146 | // Reflect that on the http response 147 | fmt.Fprintf(rw, "Here's the `ctx` query string value you passed. 
Pulled from context: %v", fromContext) 148 | return nil 149 | } 150 | 151 | // This handler pulls the JWT from the Context and echoes it through the request 152 | func getJwtFromContextHandler(rw http.ResponseWriter, r *http.Request) *rye.Response { 153 | log.Infof("Log Context handler has fired!") 154 | 155 | jwt := r.Context().Value(rye.CONTEXT_JWT) 156 | if jwt != nil { 157 | fmt.Fprintf(rw, "JWT found in Context: %v", jwt) 158 | } 159 | return nil 160 | } 161 | -------------------------------------------------------------------------------- /example_overall_test.go: -------------------------------------------------------------------------------- 1 | package rye_test 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "net/http" 8 | 9 | "github.com/InVisionApp/rye" 10 | log "github.com/sirupsen/logrus" 11 | "github.com/cactus/go-statsd-client/statsd" 12 | "github.com/gorilla/mux" 13 | ) 14 | 15 | func Example_basic() { 16 | statsdClient, err := statsd.NewBufferedClient("localhost:12345", "my_service", 1.0, 0) 17 | if err != nil { 18 | log.Fatalf("Unable to instantiate statsd client: %v", err.Error()) 19 | } 20 | 21 | config := rye.Config{ 22 | Statter: statsdClient, 23 | StatRate: 1.0, 24 | } 25 | 26 | middlewareHandler := rye.NewMWHandler(config) 27 | 28 | middlewareHandler.Use(beforeAllHandler) 29 | 30 | routes := mux.NewRouter().StrictSlash(true) 31 | 32 | routes.Handle("/", middlewareHandler.Handle([]rye.Handler{ 33 | middlewareFirstHandler, 34 | homeHandler, 35 | })).Methods("GET") 36 | 37 | // If you perform a `curl -i http://localhost:8181/cors -H "Origin: *.foo.com"` 38 | // you will see that the CORS middleware is adding required headers 39 | routes.Handle("/cors", middlewareHandler.Handle([]rye.Handler{ 40 | rye.MiddlewareCORS(), 41 | homeHandler, 42 | })).Methods("GET", "OPTIONS") 43 | 44 | // If you perform an `curl -i http://localhost:8181/jwt \ 45 | // -H "Authorization: Bearer 
eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ" 46 | // you will see that we are allowed through to the handler, if the sample token is changed, we will get a 401 47 | routes.Handle("/jwt", middlewareHandler.Handle([]rye.Handler{ 48 | rye.NewMiddlewareJWT("secret"), 49 | getJwtFromContextHandler, 50 | })).Methods("GET") 51 | 52 | routes.Handle("/error", middlewareHandler.Handle([]rye.Handler{ 53 | middlewareFirstHandler, 54 | errorHandler, 55 | homeHandler, 56 | })).Methods("GET") 57 | 58 | // In order to pass in a context variable, this set of 59 | // handlers works with "ctx" on the query string 60 | routes.Handle("/context", middlewareHandler.Handle( 61 | []rye.Handler{ 62 | stashContextHandler, 63 | logContextHandler, 64 | })).Methods("GET") 65 | 66 | log.Infof("API server listening on %v", "localhost:8181") 67 | 68 | srv := &http.Server{ 69 | Addr: "localhost:8181", 70 | Handler: routes, 71 | } 72 | 73 | srv.ListenAndServe() 74 | } 75 | 76 | func beforeAllHandler(rw http.ResponseWriter, r *http.Request) *rye.Response { 77 | log.Infof("This handler is called before every endpoint: %+v", r) 78 | return nil 79 | } 80 | 81 | func homeHandler(rw http.ResponseWriter, r *http.Request) *rye.Response { 82 | log.Infof("Home handler has fired!") 83 | 84 | fmt.Fprint(rw, "This is the home handler") 85 | return nil 86 | } 87 | 88 | func middlewareFirstHandler(rw http.ResponseWriter, r *http.Request) *rye.Response { 89 | log.Infof("Middleware handler has fired!") 90 | return nil 91 | } 92 | 93 | func errorHandler(rw http.ResponseWriter, r *http.Request) *rye.Response { 94 | log.Infof("Error handler has fired!") 95 | 96 | message := "This is the error handler" 97 | 98 | return &rye.Response{ 99 | StatusCode: http.StatusInternalServerError, 100 | Err: errors.New(message), 101 | } 102 | } 103 | 104 | func stashContextHandler(rw http.ResponseWriter, r *http.Request) 
*rye.Response { 105 | log.Infof("Stash Context handler has fired!") 106 | 107 | // Retrieve the request's context 108 | ctx := r.Context() 109 | 110 | // A query string value to add to the context 111 | toContext := r.URL.Query().Get("ctx") 112 | 113 | if toContext != "" { 114 | log.Infof("Adding `query-string-ctx` to request.Context(). Val: %v", toContext) 115 | } else { 116 | log.Infof("Adding default `query-string-ctx` value to context") 117 | toContext = "No value added. Add querystring param `ctx` with a value to get it mirrored through context." 118 | } 119 | 120 | // Create a NEW context 121 | ctx = context.WithValue(ctx, "query-string-ctx", toContext) 122 | 123 | // Return that in the Rye response 124 | // Rye will add it to the Request to 125 | // pass to the next middleware 126 | return &rye.Response{Context: ctx} 127 | } 128 | 129 | func logContextHandler(rw http.ResponseWriter, r *http.Request) *rye.Response { 130 | log.Infof("Log Context handler has fired!") 131 | 132 | // Retrieving a context value is EASY in subsequent middlewares 133 | fromContext := r.Context().Value("query-string-ctx") 134 | 135 | // Reflect that on the http response 136 | fmt.Fprintf(rw, "Here's the `ctx` query string value you passed. 
Pulled from context: %v", fromContext) 137 | return nil 138 | } 139 | 140 | // This handler pulls the JWT from the Context and echoes it through the request 141 | func getJwtFromContextHandler(rw http.ResponseWriter, r *http.Request) *rye.Response { 142 | log.Infof("Log Context handler has fired!") 143 | 144 | jwt := r.Context().Value(rye.CONTEXT_JWT) 145 | if jwt != nil { 146 | fmt.Fprintf(rw, "JWT found in Context: %v", jwt) 147 | } 148 | return nil 149 | } 150 | -------------------------------------------------------------------------------- /fakes/statsdfakes/fake_statter.go: -------------------------------------------------------------------------------- 1 | // This file was generated by counterfeiter 2 | package statsdfakes 3 | 4 | import ( 5 | "sync" 6 | "time" 7 | 8 | "github.com/cactus/go-statsd-client/statsd" 9 | ) 10 | 11 | type FakeStatter struct { 12 | IncStub func(string, int64, float32) error 13 | incMutex sync.RWMutex 14 | incArgsForCall []struct { 15 | arg1 string 16 | arg2 int64 17 | arg3 float32 18 | } 19 | incReturns struct { 20 | result1 error 21 | } 22 | DecStub func(string, int64, float32) error 23 | decMutex sync.RWMutex 24 | decArgsForCall []struct { 25 | arg1 string 26 | arg2 int64 27 | arg3 float32 28 | } 29 | decReturns struct { 30 | result1 error 31 | } 32 | GaugeStub func(string, int64, float32) error 33 | gaugeMutex sync.RWMutex 34 | gaugeArgsForCall []struct { 35 | arg1 string 36 | arg2 int64 37 | arg3 float32 38 | } 39 | gaugeReturns struct { 40 | result1 error 41 | } 42 | GaugeDeltaStub func(string, int64, float32) error 43 | gaugeDeltaMutex sync.RWMutex 44 | gaugeDeltaArgsForCall []struct { 45 | arg1 string 46 | arg2 int64 47 | arg3 float32 48 | } 49 | gaugeDeltaReturns struct { 50 | result1 error 51 | } 52 | TimingStub func(string, int64, float32) error 53 | timingMutex sync.RWMutex 54 | timingArgsForCall []struct { 55 | arg1 string 56 | arg2 int64 57 | arg3 float32 58 | } 59 | timingReturns struct { 60 | result1 error 61 | } 62 
| TimingDurationStub func(string, time.Duration, float32) error 63 | timingDurationMutex sync.RWMutex 64 | timingDurationArgsForCall []struct { 65 | arg1 string 66 | arg2 time.Duration 67 | arg3 float32 68 | } 69 | timingDurationReturns struct { 70 | result1 error 71 | } 72 | SetStub func(string, string, float32) error 73 | setMutex sync.RWMutex 74 | setArgsForCall []struct { 75 | arg1 string 76 | arg2 string 77 | arg3 float32 78 | } 79 | setReturns struct { 80 | result1 error 81 | } 82 | SetIntStub func(string, int64, float32) error 83 | setIntMutex sync.RWMutex 84 | setIntArgsForCall []struct { 85 | arg1 string 86 | arg2 int64 87 | arg3 float32 88 | } 89 | setIntReturns struct { 90 | result1 error 91 | } 92 | RawStub func(string, string, float32) error 93 | rawMutex sync.RWMutex 94 | rawArgsForCall []struct { 95 | arg1 string 96 | arg2 string 97 | arg3 float32 98 | } 99 | rawReturns struct { 100 | result1 error 101 | } 102 | NewSubStatterStub func(string) statsd.SubStatter 103 | newSubStatterMutex sync.RWMutex 104 | newSubStatterArgsForCall []struct { 105 | arg1 string 106 | } 107 | newSubStatterReturns struct { 108 | result1 statsd.SubStatter 109 | } 110 | SetPrefixStub func(string) 111 | setPrefixMutex sync.RWMutex 112 | setPrefixArgsForCall []struct { 113 | arg1 string 114 | } 115 | CloseStub func() error 116 | closeMutex sync.RWMutex 117 | closeArgsForCall []struct{} 118 | closeReturns struct { 119 | result1 error 120 | } 121 | invocations map[string][][]interface{} 122 | invocationsMutex sync.RWMutex 123 | } 124 | 125 | func (fake *FakeStatter) Inc(arg1 string, arg2 int64, arg3 float32) error { 126 | fake.incMutex.Lock() 127 | fake.incArgsForCall = append(fake.incArgsForCall, struct { 128 | arg1 string 129 | arg2 int64 130 | arg3 float32 131 | }{arg1, arg2, arg3}) 132 | fake.recordInvocation("Inc", []interface{}{arg1, arg2, arg3}) 133 | fake.incMutex.Unlock() 134 | if fake.IncStub != nil { 135 | return fake.IncStub(arg1, arg2, arg3) 136 | } else { 137 | 
return fake.incReturns.result1 138 | } 139 | } 140 | 141 | func (fake *FakeStatter) IncCallCount() int { 142 | fake.incMutex.RLock() 143 | defer fake.incMutex.RUnlock() 144 | return len(fake.incArgsForCall) 145 | } 146 | 147 | func (fake *FakeStatter) IncArgsForCall(i int) (string, int64, float32) { 148 | fake.incMutex.RLock() 149 | defer fake.incMutex.RUnlock() 150 | return fake.incArgsForCall[i].arg1, fake.incArgsForCall[i].arg2, fake.incArgsForCall[i].arg3 151 | } 152 | 153 | func (fake *FakeStatter) IncReturns(result1 error) { 154 | fake.IncStub = nil 155 | fake.incReturns = struct { 156 | result1 error 157 | }{result1} 158 | } 159 | 160 | func (fake *FakeStatter) Dec(arg1 string, arg2 int64, arg3 float32) error { 161 | fake.decMutex.Lock() 162 | fake.decArgsForCall = append(fake.decArgsForCall, struct { 163 | arg1 string 164 | arg2 int64 165 | arg3 float32 166 | }{arg1, arg2, arg3}) 167 | fake.recordInvocation("Dec", []interface{}{arg1, arg2, arg3}) 168 | fake.decMutex.Unlock() 169 | if fake.DecStub != nil { 170 | return fake.DecStub(arg1, arg2, arg3) 171 | } else { 172 | return fake.decReturns.result1 173 | } 174 | } 175 | 176 | func (fake *FakeStatter) DecCallCount() int { 177 | fake.decMutex.RLock() 178 | defer fake.decMutex.RUnlock() 179 | return len(fake.decArgsForCall) 180 | } 181 | 182 | func (fake *FakeStatter) DecArgsForCall(i int) (string, int64, float32) { 183 | fake.decMutex.RLock() 184 | defer fake.decMutex.RUnlock() 185 | return fake.decArgsForCall[i].arg1, fake.decArgsForCall[i].arg2, fake.decArgsForCall[i].arg3 186 | } 187 | 188 | func (fake *FakeStatter) DecReturns(result1 error) { 189 | fake.DecStub = nil 190 | fake.decReturns = struct { 191 | result1 error 192 | }{result1} 193 | } 194 | 195 | func (fake *FakeStatter) Gauge(arg1 string, arg2 int64, arg3 float32) error { 196 | fake.gaugeMutex.Lock() 197 | fake.gaugeArgsForCall = append(fake.gaugeArgsForCall, struct { 198 | arg1 string 199 | arg2 int64 200 | arg3 float32 201 | }{arg1, arg2, 
arg3}) 202 | fake.recordInvocation("Gauge", []interface{}{arg1, arg2, arg3}) 203 | fake.gaugeMutex.Unlock() 204 | if fake.GaugeStub != nil { 205 | return fake.GaugeStub(arg1, arg2, arg3) 206 | } else { 207 | return fake.gaugeReturns.result1 208 | } 209 | } 210 | 211 | func (fake *FakeStatter) GaugeCallCount() int { 212 | fake.gaugeMutex.RLock() 213 | defer fake.gaugeMutex.RUnlock() 214 | return len(fake.gaugeArgsForCall) 215 | } 216 | 217 | func (fake *FakeStatter) GaugeArgsForCall(i int) (string, int64, float32) { 218 | fake.gaugeMutex.RLock() 219 | defer fake.gaugeMutex.RUnlock() 220 | return fake.gaugeArgsForCall[i].arg1, fake.gaugeArgsForCall[i].arg2, fake.gaugeArgsForCall[i].arg3 221 | } 222 | 223 | func (fake *FakeStatter) GaugeReturns(result1 error) { 224 | fake.GaugeStub = nil 225 | fake.gaugeReturns = struct { 226 | result1 error 227 | }{result1} 228 | } 229 | 230 | func (fake *FakeStatter) GaugeDelta(arg1 string, arg2 int64, arg3 float32) error { 231 | fake.gaugeDeltaMutex.Lock() 232 | fake.gaugeDeltaArgsForCall = append(fake.gaugeDeltaArgsForCall, struct { 233 | arg1 string 234 | arg2 int64 235 | arg3 float32 236 | }{arg1, arg2, arg3}) 237 | fake.recordInvocation("GaugeDelta", []interface{}{arg1, arg2, arg3}) 238 | fake.gaugeDeltaMutex.Unlock() 239 | if fake.GaugeDeltaStub != nil { 240 | return fake.GaugeDeltaStub(arg1, arg2, arg3) 241 | } else { 242 | return fake.gaugeDeltaReturns.result1 243 | } 244 | } 245 | 246 | func (fake *FakeStatter) GaugeDeltaCallCount() int { 247 | fake.gaugeDeltaMutex.RLock() 248 | defer fake.gaugeDeltaMutex.RUnlock() 249 | return len(fake.gaugeDeltaArgsForCall) 250 | } 251 | 252 | func (fake *FakeStatter) GaugeDeltaArgsForCall(i int) (string, int64, float32) { 253 | fake.gaugeDeltaMutex.RLock() 254 | defer fake.gaugeDeltaMutex.RUnlock() 255 | return fake.gaugeDeltaArgsForCall[i].arg1, fake.gaugeDeltaArgsForCall[i].arg2, fake.gaugeDeltaArgsForCall[i].arg3 256 | } 257 | 258 | func (fake *FakeStatter) GaugeDeltaReturns(result1 
error) { 259 | fake.GaugeDeltaStub = nil 260 | fake.gaugeDeltaReturns = struct { 261 | result1 error 262 | }{result1} 263 | } 264 | 265 | func (fake *FakeStatter) Timing(arg1 string, arg2 int64, arg3 float32) error { 266 | fake.timingMutex.Lock() 267 | fake.timingArgsForCall = append(fake.timingArgsForCall, struct { 268 | arg1 string 269 | arg2 int64 270 | arg3 float32 271 | }{arg1, arg2, arg3}) 272 | fake.recordInvocation("Timing", []interface{}{arg1, arg2, arg3}) 273 | fake.timingMutex.Unlock() 274 | if fake.TimingStub != nil { 275 | return fake.TimingStub(arg1, arg2, arg3) 276 | } else { 277 | return fake.timingReturns.result1 278 | } 279 | } 280 | 281 | func (fake *FakeStatter) TimingCallCount() int { 282 | fake.timingMutex.RLock() 283 | defer fake.timingMutex.RUnlock() 284 | return len(fake.timingArgsForCall) 285 | } 286 | 287 | func (fake *FakeStatter) TimingArgsForCall(i int) (string, int64, float32) { 288 | fake.timingMutex.RLock() 289 | defer fake.timingMutex.RUnlock() 290 | return fake.timingArgsForCall[i].arg1, fake.timingArgsForCall[i].arg2, fake.timingArgsForCall[i].arg3 291 | } 292 | 293 | func (fake *FakeStatter) TimingReturns(result1 error) { 294 | fake.TimingStub = nil 295 | fake.timingReturns = struct { 296 | result1 error 297 | }{result1} 298 | } 299 | 300 | func (fake *FakeStatter) TimingDuration(arg1 string, arg2 time.Duration, arg3 float32) error { 301 | fake.timingDurationMutex.Lock() 302 | fake.timingDurationArgsForCall = append(fake.timingDurationArgsForCall, struct { 303 | arg1 string 304 | arg2 time.Duration 305 | arg3 float32 306 | }{arg1, arg2, arg3}) 307 | fake.recordInvocation("TimingDuration", []interface{}{arg1, arg2, arg3}) 308 | fake.timingDurationMutex.Unlock() 309 | if fake.TimingDurationStub != nil { 310 | return fake.TimingDurationStub(arg1, arg2, arg3) 311 | } else { 312 | return fake.timingDurationReturns.result1 313 | } 314 | } 315 | 316 | func (fake *FakeStatter) TimingDurationCallCount() int { 317 | 
fake.timingDurationMutex.RLock() 318 | defer fake.timingDurationMutex.RUnlock() 319 | return len(fake.timingDurationArgsForCall) 320 | } 321 | 322 | func (fake *FakeStatter) TimingDurationArgsForCall(i int) (string, time.Duration, float32) { 323 | fake.timingDurationMutex.RLock() 324 | defer fake.timingDurationMutex.RUnlock() 325 | return fake.timingDurationArgsForCall[i].arg1, fake.timingDurationArgsForCall[i].arg2, fake.timingDurationArgsForCall[i].arg3 326 | } 327 | 328 | func (fake *FakeStatter) TimingDurationReturns(result1 error) { 329 | fake.TimingDurationStub = nil 330 | fake.timingDurationReturns = struct { 331 | result1 error 332 | }{result1} 333 | } 334 | 335 | func (fake *FakeStatter) Set(arg1 string, arg2 string, arg3 float32) error { 336 | fake.setMutex.Lock() 337 | fake.setArgsForCall = append(fake.setArgsForCall, struct { 338 | arg1 string 339 | arg2 string 340 | arg3 float32 341 | }{arg1, arg2, arg3}) 342 | fake.recordInvocation("Set", []interface{}{arg1, arg2, arg3}) 343 | fake.setMutex.Unlock() 344 | if fake.SetStub != nil { 345 | return fake.SetStub(arg1, arg2, arg3) 346 | } else { 347 | return fake.setReturns.result1 348 | } 349 | } 350 | 351 | func (fake *FakeStatter) SetCallCount() int { 352 | fake.setMutex.RLock() 353 | defer fake.setMutex.RUnlock() 354 | return len(fake.setArgsForCall) 355 | } 356 | 357 | func (fake *FakeStatter) SetArgsForCall(i int) (string, string, float32) { 358 | fake.setMutex.RLock() 359 | defer fake.setMutex.RUnlock() 360 | return fake.setArgsForCall[i].arg1, fake.setArgsForCall[i].arg2, fake.setArgsForCall[i].arg3 361 | } 362 | 363 | func (fake *FakeStatter) SetReturns(result1 error) { 364 | fake.SetStub = nil 365 | fake.setReturns = struct { 366 | result1 error 367 | }{result1} 368 | } 369 | 370 | func (fake *FakeStatter) SetInt(arg1 string, arg2 int64, arg3 float32) error { 371 | fake.setIntMutex.Lock() 372 | fake.setIntArgsForCall = append(fake.setIntArgsForCall, struct { 373 | arg1 string 374 | arg2 int64 375 | 
arg3 float32 376 | }{arg1, arg2, arg3}) 377 | fake.recordInvocation("SetInt", []interface{}{arg1, arg2, arg3}) 378 | fake.setIntMutex.Unlock() 379 | if fake.SetIntStub != nil { 380 | return fake.SetIntStub(arg1, arg2, arg3) 381 | } else { 382 | return fake.setIntReturns.result1 383 | } 384 | } 385 | 386 | func (fake *FakeStatter) SetIntCallCount() int { 387 | fake.setIntMutex.RLock() 388 | defer fake.setIntMutex.RUnlock() 389 | return len(fake.setIntArgsForCall) 390 | } 391 | 392 | func (fake *FakeStatter) SetIntArgsForCall(i int) (string, int64, float32) { 393 | fake.setIntMutex.RLock() 394 | defer fake.setIntMutex.RUnlock() 395 | return fake.setIntArgsForCall[i].arg1, fake.setIntArgsForCall[i].arg2, fake.setIntArgsForCall[i].arg3 396 | } 397 | 398 | func (fake *FakeStatter) SetIntReturns(result1 error) { 399 | fake.SetIntStub = nil 400 | fake.setIntReturns = struct { 401 | result1 error 402 | }{result1} 403 | } 404 | 405 | func (fake *FakeStatter) Raw(arg1 string, arg2 string, arg3 float32) error { 406 | fake.rawMutex.Lock() 407 | fake.rawArgsForCall = append(fake.rawArgsForCall, struct { 408 | arg1 string 409 | arg2 string 410 | arg3 float32 411 | }{arg1, arg2, arg3}) 412 | fake.recordInvocation("Raw", []interface{}{arg1, arg2, arg3}) 413 | fake.rawMutex.Unlock() 414 | if fake.RawStub != nil { 415 | return fake.RawStub(arg1, arg2, arg3) 416 | } else { 417 | return fake.rawReturns.result1 418 | } 419 | } 420 | 421 | func (fake *FakeStatter) RawCallCount() int { 422 | fake.rawMutex.RLock() 423 | defer fake.rawMutex.RUnlock() 424 | return len(fake.rawArgsForCall) 425 | } 426 | 427 | func (fake *FakeStatter) RawArgsForCall(i int) (string, string, float32) { 428 | fake.rawMutex.RLock() 429 | defer fake.rawMutex.RUnlock() 430 | return fake.rawArgsForCall[i].arg1, fake.rawArgsForCall[i].arg2, fake.rawArgsForCall[i].arg3 431 | } 432 | 433 | func (fake *FakeStatter) RawReturns(result1 error) { 434 | fake.RawStub = nil 435 | fake.rawReturns = struct { 436 | result1 error 
437 | }{result1} 438 | } 439 | 440 | func (fake *FakeStatter) NewSubStatter(arg1 string) statsd.SubStatter { 441 | fake.newSubStatterMutex.Lock() 442 | fake.newSubStatterArgsForCall = append(fake.newSubStatterArgsForCall, struct { 443 | arg1 string 444 | }{arg1}) 445 | fake.recordInvocation("NewSubStatter", []interface{}{arg1}) 446 | fake.newSubStatterMutex.Unlock() 447 | if fake.NewSubStatterStub != nil { 448 | return fake.NewSubStatterStub(arg1) 449 | } else { 450 | return fake.newSubStatterReturns.result1 451 | } 452 | } 453 | 454 | func (fake *FakeStatter) NewSubStatterCallCount() int { 455 | fake.newSubStatterMutex.RLock() 456 | defer fake.newSubStatterMutex.RUnlock() 457 | return len(fake.newSubStatterArgsForCall) 458 | } 459 | 460 | func (fake *FakeStatter) NewSubStatterArgsForCall(i int) string { 461 | fake.newSubStatterMutex.RLock() 462 | defer fake.newSubStatterMutex.RUnlock() 463 | return fake.newSubStatterArgsForCall[i].arg1 464 | } 465 | 466 | func (fake *FakeStatter) NewSubStatterReturns(result1 statsd.SubStatter) { 467 | fake.NewSubStatterStub = nil 468 | fake.newSubStatterReturns = struct { 469 | result1 statsd.SubStatter 470 | }{result1} 471 | } 472 | 473 | func (fake *FakeStatter) SetPrefix(arg1 string) { 474 | fake.setPrefixMutex.Lock() 475 | fake.setPrefixArgsForCall = append(fake.setPrefixArgsForCall, struct { 476 | arg1 string 477 | }{arg1}) 478 | fake.recordInvocation("SetPrefix", []interface{}{arg1}) 479 | fake.setPrefixMutex.Unlock() 480 | if fake.SetPrefixStub != nil { 481 | fake.SetPrefixStub(arg1) 482 | } 483 | } 484 | 485 | func (fake *FakeStatter) SetPrefixCallCount() int { 486 | fake.setPrefixMutex.RLock() 487 | defer fake.setPrefixMutex.RUnlock() 488 | return len(fake.setPrefixArgsForCall) 489 | } 490 | 491 | func (fake *FakeStatter) SetPrefixArgsForCall(i int) string { 492 | fake.setPrefixMutex.RLock() 493 | defer fake.setPrefixMutex.RUnlock() 494 | return fake.setPrefixArgsForCall[i].arg1 495 | } 496 | 497 | func (fake 
*FakeStatter) Close() error { 498 | fake.closeMutex.Lock() 499 | fake.closeArgsForCall = append(fake.closeArgsForCall, struct{}{}) 500 | fake.recordInvocation("Close", []interface{}{}) 501 | fake.closeMutex.Unlock() 502 | if fake.CloseStub != nil { 503 | return fake.CloseStub() 504 | } else { 505 | return fake.closeReturns.result1 506 | } 507 | } 508 | 509 | func (fake *FakeStatter) CloseCallCount() int { 510 | fake.closeMutex.RLock() 511 | defer fake.closeMutex.RUnlock() 512 | return len(fake.closeArgsForCall) 513 | } 514 | 515 | func (fake *FakeStatter) CloseReturns(result1 error) { 516 | fake.CloseStub = nil 517 | fake.closeReturns = struct { 518 | result1 error 519 | }{result1} 520 | } 521 | 522 | func (fake *FakeStatter) Invocations() map[string][][]interface{} { 523 | fake.invocationsMutex.RLock() 524 | defer fake.invocationsMutex.RUnlock() 525 | fake.incMutex.RLock() 526 | defer fake.incMutex.RUnlock() 527 | fake.decMutex.RLock() 528 | defer fake.decMutex.RUnlock() 529 | fake.gaugeMutex.RLock() 530 | defer fake.gaugeMutex.RUnlock() 531 | fake.gaugeDeltaMutex.RLock() 532 | defer fake.gaugeDeltaMutex.RUnlock() 533 | fake.timingMutex.RLock() 534 | defer fake.timingMutex.RUnlock() 535 | fake.timingDurationMutex.RLock() 536 | defer fake.timingDurationMutex.RUnlock() 537 | fake.setMutex.RLock() 538 | defer fake.setMutex.RUnlock() 539 | fake.setIntMutex.RLock() 540 | defer fake.setIntMutex.RUnlock() 541 | fake.rawMutex.RLock() 542 | defer fake.rawMutex.RUnlock() 543 | fake.newSubStatterMutex.RLock() 544 | defer fake.newSubStatterMutex.RUnlock() 545 | fake.setPrefixMutex.RLock() 546 | defer fake.setPrefixMutex.RUnlock() 547 | fake.closeMutex.RLock() 548 | defer fake.closeMutex.RUnlock() 549 | return fake.invocations 550 | } 551 | 552 | func (fake *FakeStatter) recordInvocation(key string, args []interface{}) { 553 | fake.invocationsMutex.Lock() 554 | defer fake.invocationsMutex.Unlock() 555 | if fake.invocations == nil { 556 | fake.invocations = 
map[string][][]interface{}{} 557 | } 558 | if fake.invocations[key] == nil { 559 | fake.invocations[key] = [][]interface{}{} 560 | } 561 | fake.invocations[key] = append(fake.invocations[key], args) 562 | } 563 | 564 | var _ statsd.Statter = new(FakeStatter) 565 | -------------------------------------------------------------------------------- /images/Rye Logo.sketch: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InVisionApp/rye/c260759a2358155e882164fa8cc9f3b074dcef4a/images/Rye Logo.sketch -------------------------------------------------------------------------------- /images/rye-gopher.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Gopher flip 5 | Created with Sketch. 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | -------------------------------------------------------------------------------- /images/rye_logo.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | rye logo 5 | Created with Sketch. 
6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 62 | 63 | -------------------------------------------------------------------------------- /images/rye_logo_invision_pink.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InVisionApp/rye/c260759a2358155e882164fa8cc9f3b074dcef4a/images/rye_logo_invision_pink.png -------------------------------------------------------------------------------- /in-repo.yaml: -------------------------------------------------------------------------------- 1 | owner: 2 | active: 3 | team: core 4 | since: 2017-03-01 5 | -------------------------------------------------------------------------------- /middleware_accesstoken.go: -------------------------------------------------------------------------------- 1 | package rye 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "net/http" 7 | ) 8 | 9 | type accessTokens struct { 10 | paramName string 11 | tokens []string 12 | getFunc func(string, *http.Request) string 13 | missingMessage string 14 | } 15 | 16 | /* 17 | NewMiddlewareAccessToken creates a new handler to verify access tokens passed as a header. 18 | 19 | Example usage: 20 | 21 | routes.Handle("/some/route", a.Dependencies.MWHandler.Handle( 22 | []rye.Handler{ 23 | rye.NewMiddlewareAccessToken(tokenHeaderName, []string{token1, token2}), 24 | yourHandler, 25 | })).Methods("POST") 26 | */ 27 | func NewMiddlewareAccessToken(headerName string, tokens []string) func(rw http.ResponseWriter, req *http.Request) *Response { 28 | return newAccessTokenHandler(headerName, tokens, "header") 29 | } 30 | 31 | /* 32 | NewMiddlewareAccessQueryToken creates a new handler to verify access tokens passed as a query parameter. 
33 | 34 | Example usage: 35 | 36 | routes.Handle("/some/route", a.Dependencies.MWHandler.Handle( 37 | []rye.Handler{ 38 | rye.NewMiddlewareAccessQueryToken(queryParamName, []string{token1, token2}), 39 | yourHandler, 40 | })).Methods("POST") 41 | */ 42 | func NewMiddlewareAccessQueryToken(queryParamName string, tokens []string) func(rw http.ResponseWriter, req *http.Request) *Response { 43 | return newAccessTokenHandler(queryParamName, tokens, "query") 44 | } 45 | 46 | func newAccessTokenHandler(name string, tokens []string, tokenType string) func(rw http.ResponseWriter, req *http.Request) *Response { 47 | a := &accessTokens{ 48 | paramName: name, 49 | tokens: tokens, 50 | } 51 | 52 | switch tokenType { 53 | 54 | case "query": 55 | a.getFunc = func(s string, r *http.Request) string { 56 | q, ok := r.URL.Query()[s] 57 | if !ok { 58 | return "" 59 | } 60 | 61 | return q[0] 62 | } 63 | a.missingMessage = fmt.Sprintf("No access token found; ensure you pass the '%s' parameter", name) 64 | 65 | default: 66 | // default to using the header 67 | a.getFunc = func(s string, r *http.Request) string { 68 | return r.Header.Get(s) 69 | } 70 | a.missingMessage = fmt.Sprintf("No access token found; ensure you pass '%s' in header", name) 71 | } 72 | 73 | return a.handle 74 | } 75 | 76 | func (a *accessTokens) handle(rw http.ResponseWriter, r *http.Request) *Response { 77 | token := a.getFunc(a.paramName, r) 78 | 79 | if token == "" { 80 | return &Response{ 81 | Err: errors.New(a.missingMessage), 82 | StatusCode: http.StatusUnauthorized, 83 | } 84 | } 85 | 86 | if ok := stringListContains(a.tokens, token); !ok { 87 | return &Response{ 88 | Err: errors.New("Unauthorized request: invalid access token"), 89 | StatusCode: http.StatusUnauthorized, 90 | } 91 | } 92 | 93 | return nil 94 | } 95 | 96 | func stringListContains(stringList []string, element string) bool { 97 | for _, v := range stringList { 98 | if v == element { 99 | return true 100 | } 101 | } 102 | 103 | return false 104 | } 
105 | -------------------------------------------------------------------------------- /middleware_accesstoken_test.go: -------------------------------------------------------------------------------- 1 | package rye 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "net/http/httptest" 7 | "net/url" 8 | 9 | . "github.com/onsi/ginkgo" 10 | . "github.com/onsi/gomega" 11 | ) 12 | 13 | var _ = Describe("AccessToken Middleware", func() { 14 | 15 | var ( 16 | request *http.Request 17 | response *httptest.ResponseRecorder 18 | 19 | testHandler func(http.ResponseWriter, *http.Request) *Response 20 | 21 | token1, token2 string 22 | ) 23 | 24 | BeforeEach(func() { 25 | response = httptest.NewRecorder() 26 | 27 | token1 = "test1" 28 | token2 = "test2" 29 | }) 30 | 31 | Context("header token", func() { 32 | var ( 33 | tokenHeaderName = "at-hname" 34 | ) 35 | 36 | BeforeEach(func() { 37 | testHandler = NewMiddlewareAccessToken(tokenHeaderName, []string{token1, token2}) 38 | request = &http.Request{ 39 | Header: map[string][]string{}, 40 | } 41 | }) 42 | 43 | Context("when a valid token is used", func() { 44 | It("should return nil", func() { 45 | request.Header.Add(tokenHeaderName, token1) 46 | resp := testHandler(response, request) 47 | Expect(resp).To(BeNil()) 48 | }) 49 | 50 | It("should return nil", func() { 51 | request.Header.Add(tokenHeaderName, token2) 52 | resp := testHandler(response, request) 53 | Expect(resp).To(BeNil()) 54 | }) 55 | }) 56 | 57 | Context("when an invalid token is used", func() { 58 | It("should return an error", func() { 59 | request.Header.Add(tokenHeaderName, "blah") 60 | resp := testHandler(response, request) 61 | Expect(resp).ToNot(BeNil()) 62 | Expect(resp.Err).To(HaveOccurred()) 63 | Expect(resp.Error()).To(ContainSubstring("invalid access token")) 64 | Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized)) 65 | }) 66 | }) 67 | 68 | Context("when no token header exists", func() { 69 | It("should return an error", func() { 70 | resp := 
testHandler(response, request) 71 | Expect(resp).ToNot(BeNil()) 72 | Expect(resp.Err).To(HaveOccurred()) 73 | Expect(resp.Error()).To(ContainSubstring("No access token found")) 74 | Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized)) 75 | }) 76 | }) 77 | 78 | Context("when token header is blank", func() { 79 | It("should return an error", func() { 80 | request.Header.Add(tokenHeaderName, "") 81 | resp := testHandler(response, request) 82 | Expect(resp).ToNot(BeNil()) 83 | Expect(resp.Err).To(HaveOccurred()) 84 | Expect(resp.Error()).To(ContainSubstring("No access token found")) 85 | Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized)) 86 | }) 87 | }) 88 | }) 89 | 90 | Context("query param token", func() { 91 | var ( 92 | qParamName string 93 | qParams string 94 | ) 95 | 96 | BeforeEach(func() { 97 | qParamName = "token" 98 | testHandler = NewMiddlewareAccessQueryToken(qParamName, []string{token1, token2}) 99 | }) 100 | 101 | JustBeforeEach(func() { 102 | u, err := url.Parse(fmt.Sprintf("http://doesntmatter.io/blah?%s", qParams)) 103 | Expect(err).ToNot(HaveOccurred()) 104 | 105 | request = &http.Request{ 106 | URL: u, 107 | } 108 | }) 109 | 110 | Context("when a valid token is used", func() { 111 | BeforeEach(func() { 112 | qParams = fmt.Sprintf("%s=%s", qParamName, token1) 113 | }) 114 | 115 | It("should return nil", func() { 116 | resp := testHandler(response, request) 117 | Expect(resp).To(BeNil()) 118 | }) 119 | }) 120 | 121 | Context("when the other valid token is used", func() { 122 | BeforeEach(func() { 123 | qParams = fmt.Sprintf("%s=%s", qParamName, token2) 124 | }) 125 | 126 | It("should return nil", func() { 127 | resp := testHandler(response, request) 128 | Expect(resp).To(BeNil()) 129 | }) 130 | }) 131 | 132 | Context("when an invalid token is used", func() { 133 | BeforeEach(func() { 134 | qParams = fmt.Sprintf("%s=blah", qParamName) 135 | }) 136 | 137 | It("should return an error", func() { 138 | resp := testHandler(response, request) 
139 | Expect(resp).ToNot(BeNil()) 140 | Expect(resp.Err).To(HaveOccurred()) 141 | Expect(resp.Error()).To(ContainSubstring("invalid access token")) 142 | Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized)) 143 | }) 144 | }) 145 | 146 | Context("when no token param exists", func() { 147 | BeforeEach(func() { 148 | qParams = "something=else" 149 | }) 150 | 151 | It("should return an error", func() { 152 | resp := testHandler(response, request) 153 | Expect(resp).ToNot(BeNil()) 154 | Expect(resp.Err).To(HaveOccurred()) 155 | Expect(resp.Error()).To(ContainSubstring("No access token found")) 156 | Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized)) 157 | }) 158 | }) 159 | 160 | Context("when token param is blank", func() { 161 | BeforeEach(func() { 162 | qParams = fmt.Sprintf("%s=''", qParamName) 163 | }) 164 | 165 | It("should return an error", func() { 166 | resp := testHandler(response, request) 167 | Expect(resp).ToNot(BeNil()) 168 | Expect(resp.Err).To(HaveOccurred()) 169 | Expect(resp.Error()).To(ContainSubstring("invalid access token")) 170 | Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized)) 171 | }) 172 | }) 173 | 174 | Context("when no query params", func() { 175 | JustBeforeEach(func() { 176 | u, err := url.Parse("http://doesntmatter.io/blah") 177 | Expect(err).ToNot(HaveOccurred()) 178 | 179 | request = &http.Request{ 180 | URL: u, 181 | } 182 | }) 183 | 184 | It("should return an error", func() { 185 | resp := testHandler(response, request) 186 | Expect(resp).ToNot(BeNil()) 187 | Expect(resp.Err).To(HaveOccurred()) 188 | Expect(resp.Error()).To(ContainSubstring("No access token found")) 189 | Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized)) 190 | }) 191 | }) 192 | 193 | }) 194 | }) 195 | -------------------------------------------------------------------------------- /middleware_auth.go: -------------------------------------------------------------------------------- 1 | package rye 2 | 3 | import ( 4 | "context" 5 | 
"encoding/base64" 6 | "errors" 7 | "fmt" 8 | "net/http" 9 | "strings" 10 | 11 | jwt "github.com/dgrijalva/jwt-go" 12 | ) 13 | 14 | /* 15 | NewMiddlewareAuth creates a new middleware to extract the Authorization header 16 | from a request and validate it. It accepts a func of type AuthFunc which is 17 | used to do the credential validation. 18 | An AuthFuncs for Basic auth and JWT are provided here. 19 | 20 | Example usage: 21 | 22 | routes.Handle("/some/route", myMWHandler.Handle( 23 | []rye.Handler{ 24 | rye.NewMiddlewareAuth(rye.NewBasicAuthFunc(map[string]string{ 25 | "user1": "my_password", 26 | })), 27 | yourHandler, 28 | })).Methods("POST") 29 | */ 30 | 31 | type AuthFunc func(context.Context, string) *Response 32 | 33 | func NewMiddlewareAuth(authFunc AuthFunc) func(rw http.ResponseWriter, req *http.Request) *Response { 34 | return func(rw http.ResponseWriter, r *http.Request) *Response { 35 | auth := r.Header.Get("Authorization") 36 | if auth == "" { 37 | return &Response{ 38 | Err: errors.New("unauthorized: no authentication provided"), 39 | StatusCode: http.StatusUnauthorized, 40 | } 41 | } 42 | 43 | return authFunc(r.Context(), auth) 44 | } 45 | } 46 | 47 | /*********** 48 | Basic Auth 49 | ***********/ 50 | 51 | func NewBasicAuthFunc(userPass map[string]string) AuthFunc { 52 | return basicAuth(userPass).authenticate 53 | } 54 | 55 | type basicAuth map[string]string 56 | 57 | const AUTH_USERNAME_KEY = "request-username" 58 | 59 | // basicAuth.authenticate meets the AuthFunc type 60 | func (b basicAuth) authenticate(ctx context.Context, auth string) *Response { 61 | errResp := &Response{ 62 | Err: errors.New("unauthorized: invalid authentication provided"), 63 | StatusCode: http.StatusUnauthorized, 64 | } 65 | 66 | // parse the Authorization header 67 | u, p, ok := parseBasicAuth(auth) 68 | if !ok { 69 | return errResp 70 | } 71 | 72 | // get the password 73 | pass, ok := b[u] 74 | if !ok { 75 | return errResp 76 | } 77 | 78 | // compare the password 79 | 
if pass != p { 80 | return errResp 81 | } 82 | 83 | // add username to the context 84 | return &Response{ 85 | Context: context.WithValue(ctx, AUTH_USERNAME_KEY, u), 86 | } 87 | } 88 | 89 | const basicPrefix = "Basic " 90 | 91 | // parseBasicAuth parses an HTTP Basic Authentication string. 92 | // taken from net/http/request.go 93 | func parseBasicAuth(auth string) (username, password string, ok bool) { 94 | if !strings.HasPrefix(auth, basicPrefix) { 95 | return 96 | } 97 | c, err := base64.StdEncoding.DecodeString(auth[len(basicPrefix):]) 98 | if err != nil { 99 | return 100 | } 101 | cs := string(c) 102 | s := strings.IndexByte(cs, ':') 103 | if s < 0 { 104 | return 105 | } 106 | return cs[:s], cs[s+1:], true 107 | } 108 | 109 | /**** 110 | JWT 111 | ****/ 112 | 113 | type jwtAuth struct { 114 | secret string 115 | } 116 | 117 | func NewJWTAuthFunc(secret string) AuthFunc { 118 | j := &jwtAuth{secret: secret} 119 | return j.authenticate 120 | } 121 | 122 | const bearerPrefix = "Bearer " 123 | 124 | func (j *jwtAuth) authenticate(ctx context.Context, auth string) *Response { 125 | // Remove 'Bearer' prefix 126 | if !strings.HasPrefix(auth, bearerPrefix) && !strings.HasPrefix(auth, strings.ToLower(bearerPrefix)) { 127 | return &Response{ 128 | Err: errors.New("unauthorized: invalid authentication provided"), 129 | StatusCode: http.StatusUnauthorized, 130 | } 131 | } 132 | 133 | token := auth[len(bearerPrefix):] 134 | 135 | _, err := jwt.Parse(token, func(token *jwt.Token) (interface{}, error) { 136 | if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { 137 | return nil, fmt.Errorf("Unexpected signing method") 138 | } 139 | return []byte(j.secret), nil 140 | }) 141 | if err != nil { 142 | return &Response{ 143 | Err: err, 144 | StatusCode: http.StatusUnauthorized, 145 | } 146 | } 147 | 148 | return &Response{ 149 | Context: context.WithValue(ctx, CONTEXT_JWT, token), 150 | } 151 | } 152 | 
-------------------------------------------------------------------------------- /middleware_auth_test.go: -------------------------------------------------------------------------------- 1 | package rye 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | 7 | "context" 8 | 9 | . "github.com/onsi/ginkgo" 10 | . "github.com/onsi/gomega" 11 | ) 12 | 13 | const AUTH_HEADER_NAME = "Authorization" 14 | 15 | var _ = Describe("Auth Middleware", func() { 16 | var ( 17 | request *http.Request 18 | response *httptest.ResponseRecorder 19 | 20 | testHandler func(http.ResponseWriter, *http.Request) *Response 21 | ) 22 | 23 | BeforeEach(func() { 24 | response = httptest.NewRecorder() 25 | }) 26 | 27 | Context("auth", func() { 28 | var ( 29 | fakeAuth *recorder 30 | ) 31 | 32 | BeforeEach(func() { 33 | fakeAuth = &recorder{} 34 | 35 | testHandler = NewMiddlewareAuth(fakeAuth.authFunc) 36 | request = &http.Request{ 37 | Header: map[string][]string{}, 38 | } 39 | }) 40 | 41 | It("passes the header to the auth func", func() { 42 | testAuth := "foobar" 43 | request.Header.Add(AUTH_HEADER_NAME, testAuth) 44 | resp := testHandler(response, request) 45 | 46 | Expect(resp).To(BeNil()) 47 | Expect(fakeAuth.header).To(Equal(testAuth)) 48 | }) 49 | 50 | Context("when no header is found", func() { 51 | It("errors", func() { 52 | resp := testHandler(response, request) 53 | 54 | Expect(resp).ToNot(BeNil()) 55 | Expect(resp.Err).ToNot(BeNil()) 56 | Expect(resp.Err.Error()).To(ContainSubstring("no authentication")) 57 | }) 58 | }) 59 | }) 60 | 61 | Context("Basic Auth", func() { 62 | var ( 63 | username = "user1" 64 | pass = "mypass" 65 | ) 66 | 67 | BeforeEach(func() { 68 | testHandler = NewMiddlewareAuth(NewBasicAuthFunc(map[string]string{ 69 | username: pass, 70 | })) 71 | 72 | request = &http.Request{ 73 | Header: map[string][]string{}, 74 | } 75 | }) 76 | 77 | It("validates the password", func() { 78 | request.SetBasicAuth(username, pass) 79 | resp := testHandler(response, request) 
80 | 81 | Expect(resp.Err).To(BeNil()) 82 | }) 83 | 84 | It("adds the username to context", func() { 85 | request.SetBasicAuth(username, pass) 86 | resp := testHandler(response, request) 87 | 88 | Expect(resp.Err).To(BeNil()) 89 | 90 | ctxUname := resp.Context.Value(AUTH_USERNAME_KEY) 91 | uname, ok := ctxUname.(string) 92 | Expect(ok).To(BeTrue()) 93 | Expect(uname).To(Equal(username)) 94 | }) 95 | 96 | It("preserves the request context", func() { 97 | 98 | }) 99 | 100 | It("errors if username unknown", func() { 101 | request.SetBasicAuth("noname", pass) 102 | resp := testHandler(response, request) 103 | 104 | Expect(resp.Err).ToNot(BeNil()) 105 | Expect(resp.Err.Error()).To(ContainSubstring("invalid auth")) 106 | }) 107 | 108 | It("errors if password wrong", func() { 109 | request.SetBasicAuth(username, "wrong") 110 | resp := testHandler(response, request) 111 | 112 | Expect(resp.Err).ToNot(BeNil()) 113 | Expect(resp.Err.Error()).To(ContainSubstring("invalid auth")) 114 | }) 115 | 116 | Context("parseBasicAuth", func() { 117 | It("errors if header not basic", func() { 118 | request.Header.Add(AUTH_HEADER_NAME, "wrong") 119 | resp := testHandler(response, request) 120 | 121 | Expect(resp.Err).ToNot(BeNil()) 122 | Expect(resp.Err.Error()).To(ContainSubstring("invalid auth")) 123 | }) 124 | 125 | It("errors if header not base64", func() { 126 | request.Header.Add(AUTH_HEADER_NAME, "Basic ------") 127 | resp := testHandler(response, request) 128 | 129 | Expect(resp.Err).ToNot(BeNil()) 130 | Expect(resp.Err.Error()).To(ContainSubstring("invalid auth")) 131 | }) 132 | 133 | It("errors if header wrong format", func() { 134 | request.Header.Add(AUTH_HEADER_NAME, "Basic YXNkZgo=") // asdf no `:` 135 | resp := testHandler(response, request) 136 | 137 | Expect(resp.Err).ToNot(BeNil()) 138 | Expect(resp.Err.Error()).To(ContainSubstring("invalid auth")) 139 | }) 140 | }) 141 | }) 142 | }) 143 | 144 | type recorder struct { 145 | header string 146 | } 147 | 148 | func (r 
*recorder) authFunc(ctx context.Context, s string) *Response { 149 | r.header = s 150 | return nil 151 | } 152 | -------------------------------------------------------------------------------- /middleware_cidr.go: -------------------------------------------------------------------------------- 1 | package rye 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | "net/http" 7 | ) 8 | 9 | type cidr struct { 10 | cidrs []string 11 | } 12 | 13 | /* 14 | NewMiddlewareCIDR creates a new handler to verify incoming IPs against a set of CIDR Notation strings in a rye chain. 15 | For reference on CIDR notation see https://en.wikipedia.org/wiki/Classless_Inter-Domain_Routing 16 | 17 | Example usage: 18 | 19 | routes.Handle("/some/route", a.Dependencies.MWHandler.Handle( 20 | []rye.Handler{ 21 | rye.NewMiddlewareCIDR(CIDRs), // []string of allowed CIDRs 22 | yourHandler, 23 | })).Methods("POST") 24 | */ 25 | func NewMiddlewareCIDR(CIDRs []string) func(rw http.ResponseWriter, req *http.Request) *Response { 26 | c := &cidr{cidrs: CIDRs} 27 | return c.handle 28 | } 29 | 30 | // Verify if incoming request comes from a valid CIDR 31 | func (c *cidr) handle(rw http.ResponseWriter, r *http.Request) *Response { 32 | // Validate the incoming IP 33 | host, _, err := net.SplitHostPort(r.RemoteAddr) 34 | if err != nil { 35 | return &Response{ 36 | Err: fmt.Errorf("Remote address error: %v", err.Error()), 37 | StatusCode: http.StatusUnauthorized, 38 | } 39 | } 40 | 41 | included, err := inCIDRs(host, c.cidrs) 42 | if err != nil { 43 | return &Response{ 44 | Err: fmt.Errorf("Error validating IP address: %v", err.Error()), 45 | StatusCode: http.StatusUnauthorized, 46 | } 47 | } 48 | 49 | if !included { 50 | return &Response{ 51 | Err: fmt.Errorf("%v is not authorized", host), 52 | StatusCode: http.StatusUnauthorized, 53 | } 54 | } 55 | 56 | return nil 57 | } 58 | 59 | // Verify that a given IP is a part of at least one CIDR in given CIDR list 60 | func inCIDRs(ipAddr string, cidrList []string) (bool, 
error) { 61 | for _, v := range cidrList { 62 | state, err := inCIDR(ipAddr, v) 63 | if err != nil { 64 | return false, err 65 | } 66 | 67 | if state { 68 | return true, nil 69 | } 70 | } 71 | 72 | return false, nil 73 | } 74 | 75 | // inCIDR reports whether ipAddr lies inside the network cidrAddr; it returns an error if either string fails to parse. 76 | func inCIDR(ipAddr, cidrAddr string) (bool, error) { 77 | _, cidrnet, err := net.ParseCIDR(cidrAddr) 78 | if err != nil { 79 | return false, err 80 | } 81 | 82 | ip := net.ParseIP(ipAddr) 83 | // ip is nil when parsing fails, so report the original input rather than "<nil>". 84 | if ip == nil { 85 | return false, fmt.Errorf("Unable to parse IP %v", ipAddr) 86 | } 87 | 88 | return cidrnet.Contains(ip), nil 89 | } 90 | -------------------------------------------------------------------------------- /middleware_cidr_test.go: -------------------------------------------------------------------------------- 1 | package rye 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | 7 | . "github.com/onsi/ginkgo" 8 | . "github.com/onsi/gomega" 9 | ) 10 | 11 | var _ = Describe("CIDR Middleware", func() { 12 | 13 | var ( 14 | request *http.Request 15 | response *httptest.ResponseRecorder 16 | 17 | cidr1, cidr2, ip1, ip2, ip3 string 18 | ) 19 | 20 | BeforeEach(func() { 21 | response = httptest.NewRecorder() 22 | request = &http.Request{} 23 | cidr1 = "10.0.0.0/24" 24 | cidr2 = "127.0.0.0/24" 25 | ip1 = "10.0.0.1:22" 26 | ip2 = "127.0.0.1:22" 27 | ip3 = "192.0.0.1:22" 28 | }) 29 | 30 | Describe("handle", func() { 31 | Context("when a valid IP is used", func() { 32 | It("should return nil", func() { 33 | request.RemoteAddr = ip1 34 | resp := NewMiddlewareCIDR([]string{cidr1, cidr2})(response, request) 35 | Expect(resp).To(BeNil()) 36 | }) 37 | 38 | It("should return nil", func() { 39 | request.RemoteAddr = ip2 40 | resp := NewMiddlewareCIDR([]string{cidr1, cidr2})(response, request) 41 | Expect(resp).To(BeNil()) 42 | }) 43 | }) 44 | 45 | Context("when an invalid IP is used", func() { 46 | It("should return an error", func() { 47 | request.RemoteAddr = ip3 48 | resp :=
NewMiddlewareCIDR([]string{cidr1, cidr2})(response, request) 49 | Expect(resp).ToNot(BeNil()) 50 | Expect(resp.Err).To(HaveOccurred()) 51 | Expect(resp.Error()).To(ContainSubstring("not authorized")) 52 | Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized)) 53 | }) 54 | }) 55 | 56 | Context("when no IP exists", func() { 57 | It("should return an error", func() { 58 | resp := NewMiddlewareCIDR([]string{cidr1, cidr2})(response, request) 59 | Expect(resp).ToNot(BeNil()) 60 | Expect(resp.Err).To(HaveOccurred()) 61 | Expect(resp.Error()).To(ContainSubstring("Remote address error")) 62 | Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized)) 63 | }) 64 | }) 65 | 66 | Context("when an unrecognizable IP is used", func() { 67 | It("should return an error", func() { 68 | request.RemoteAddr = "blah:80" 69 | resp := NewMiddlewareCIDR([]string{cidr1, cidr2})(response, request) 70 | Expect(resp).ToNot(BeNil()) 71 | Expect(resp.Err).To(HaveOccurred()) 72 | Expect(resp.Error()).To(ContainSubstring("Error validating IP address: Unable to parse IP")) 73 | Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized)) 74 | }) 75 | }) 76 | 77 | Context("when an unrecognizable CIDR is used", func() { 78 | It("should return an error", func() { 79 | request.RemoteAddr = ip1 80 | resp := NewMiddlewareCIDR([]string{"blah"})(response, request) 81 | Expect(resp).ToNot(BeNil()) 82 | Expect(resp.Err).To(HaveOccurred()) 83 | Expect(resp.Error()).To(ContainSubstring("Error validating IP address: invalid CIDR address")) 84 | Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized)) 85 | }) 86 | }) 87 | 88 | }) 89 | }) 90 | -------------------------------------------------------------------------------- /middleware_cors.go: -------------------------------------------------------------------------------- 1 | package rye 2 | 3 | import ( 4 | "net/http" 5 | ) 6 | 7 | const ( 8 | // CORS Specific constants 9 | DEFAULT_CORS_ALLOW_ORIGIN = "*" 10 | DEFAULT_CORS_ALLOW_METHODS = "POST, GET, 
OPTIONS, PUT, DELETE" 11 | DEFAULT_CORS_ALLOW_HEADERS = "Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, X-Access-Token" 12 | ) 13 | // cors holds the values written to the Access-Control-Allow-Origin, -Methods and -Headers response headers. 14 | type cors struct { 15 | CORSAllowOrigin string 16 | CORSAllowMethods string 17 | CORSAllowHeaders string 18 | } 19 | 20 | // MiddlewareCORS creates a CORS handler configured with DEFAULT_CORS_ALLOW_ORIGIN, DEFAULT_CORS_ALLOW_METHODS and DEFAULT_CORS_ALLOW_HEADERS. 21 | func MiddlewareCORS() func(rw http.ResponseWriter, req *http.Request) *Response { 22 | c := &cors{ 23 | CORSAllowOrigin: DEFAULT_CORS_ALLOW_ORIGIN, 24 | CORSAllowMethods: DEFAULT_CORS_ALLOW_METHODS, 25 | CORSAllowHeaders: DEFAULT_CORS_ALLOW_HEADERS, 26 | } 27 | 28 | return c.handle 29 | } 30 | 31 | /* 32 | NewMiddlewareCORS creates a new handler to support CORS functionality. You can use this middleware by specifying `rye.MiddlewareCORS()` or `rye.NewMiddlewareCORS(origin, methods, headers)` 33 | when defining your routes. 34 | 35 | Default CORS Values: 36 | 37 | DEFAULT_CORS_ALLOW_ORIGIN: "*" 38 | DEFAULT_CORS_ALLOW_METHODS: "POST, GET, OPTIONS, PUT, DELETE" 39 | DEFAULT_CORS_ALLOW_HEADERS: "Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, X-Access-Token" 40 | 41 | If you are planning to use this in production, you should probably use NewMiddlewareCORS *with* explicit params rather than the permissive defaults.
42 | 43 | Example use case: 44 | 45 | routes.Handle("/some/route", a.Dependencies.MWHandler.Handle( 46 | []rye.Handler{ 47 | rye.MiddlewareCORS(), // use defaults for allowed origin, headers, methods 48 | yourHandler, 49 | })).Methods("PUT", "OPTIONS") 50 | 51 | OR: 52 | 53 | routes.Handle("/some/route", a.Dependencies.MWHandler.Handle( 54 | []rye.Handler{ 55 | rye.NewMiddlewareCORS("*", "POST, GET", "SomeHeader, AnotherHeader"), 56 | yourHandler, 57 | })).Methods("PUT", "OPTIONS") 58 | */ 59 | func NewMiddlewareCORS(origin, methods, headers string) func(rw http.ResponseWriter, req *http.Request) *Response { 60 | c := &cors{ 61 | CORSAllowOrigin: origin, 62 | CORSAllowMethods: methods, 63 | CORSAllowHeaders: headers, 64 | } 65 | 66 | return c.handle 67 | } 68 | 69 | // handle sets the Access-Control-Allow-* response headers when the request carries an `Origin` header. 70 | // It returns nil unless the request is an OPTIONS preflight, in which case it returns a Response with StopExecution set. 71 | func (c *cors) handle(rw http.ResponseWriter, req *http.Request) *Response { 72 | origin := req.Header.Get("Origin") 73 | 74 | // Origin header not provided, nothing for CORS to do 75 | if origin == "" { 76 | return nil 77 | } 78 | 79 | rw.Header().Set("Access-Control-Allow-Origin", c.CORSAllowOrigin) 80 | rw.Header().Set("Access-Control-Allow-Methods", c.CORSAllowMethods) 81 | rw.Header().Set("Access-Control-Allow-Headers", c.CORSAllowHeaders) 82 | 83 | // If this was a preflight request, stop further middleware execution 84 | if req.Method == "OPTIONS" { 85 | return &Response{ 86 | StopExecution: true, 87 | } 88 | } 89 | 90 | return nil 91 | } 92 | -------------------------------------------------------------------------------- /middleware_cors_test.go: -------------------------------------------------------------------------------- 1 | package rye 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | 7 | . "github.com/onsi/ginkgo" 8 | .
"github.com/onsi/gomega" 9 | ) 10 | 11 | var _ = Describe("CORS Middleware", func() { 12 | 13 | var ( 14 | request *http.Request 15 | response *httptest.ResponseRecorder 16 | ) 17 | 18 | BeforeEach(func() { 19 | response = httptest.NewRecorder() 20 | request = &http.Request{ 21 | Header: make(map[string][]string, 0), 22 | } 23 | }) 24 | 25 | Describe("handle", func() { 26 | Context("when origin header is not set", func() { 27 | It("should return nil", func() { 28 | resp := MiddlewareCORS()(response, request) 29 | Expect(resp).To(BeNil()) 30 | }) 31 | }) 32 | 33 | Context("when origin header is set", func() { 34 | Context("and CORS was instantiated with params", func() { 35 | var ( 36 | testOrigin = "*.invisionapp.com" 37 | testHeaders = "TestHeader" 38 | testMethods = "GET, POST, TESTMETHOD" 39 | ) 40 | 41 | It("should set all CORS headers from params", func() { 42 | request.Header.Add("Origin", "*.invisionapp.com") 43 | resp := NewMiddlewareCORS(testOrigin, testMethods, testHeaders)(response, request) 44 | 45 | Expect(resp).To(BeNil()) 46 | Expect(response.Header().Get("Access-Control-Allow-Origin")).To(Equal(testOrigin)) 47 | Expect(response.Header().Get("Access-Control-Allow-Methods")).To(Equal(testMethods)) 48 | Expect(response.Header().Get("Access-Control-Allow-Headers")).To(Equal(testHeaders)) 49 | }) 50 | }) 51 | 52 | Context("and CORS was instantiated with defaults", func() { 53 | It("should set all CORS headers using defaults", func() { 54 | request.Header.Add("Origin", "*.invisionapp.com") 55 | resp := MiddlewareCORS()(response, request) 56 | 57 | Expect(resp).To(BeNil()) 58 | Expect(response.Header().Get("Access-Control-Allow-Origin")).To(Equal(DEFAULT_CORS_ALLOW_ORIGIN)) 59 | Expect(response.Header().Get("Access-Control-Allow-Methods")).To(Equal(DEFAULT_CORS_ALLOW_METHODS)) 60 | Expect(response.Header().Get("Access-Control-Allow-Headers")).To(Equal(DEFAULT_CORS_ALLOW_HEADERS)) 61 | }) 62 | }) 63 | 64 | Context("and we got a preflight request (OPTIONS)", 
func() { 65 | It("should return a response with StopExecution", func() { 66 | request.Method = "OPTIONS" 67 | request.Header.Add("Origin", "*.invisionapp.com") 68 | resp := MiddlewareCORS()(response, request) 69 | 70 | Expect(resp).ToNot(BeNil()) 71 | Expect(resp.StopExecution).To(BeTrue()) 72 | }) 73 | }) 74 | }) 75 | }) 76 | }) 77 | -------------------------------------------------------------------------------- /middleware_getheader.go: -------------------------------------------------------------------------------- 1 | package rye 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | ) 7 | 8 | type getHeader struct { 9 | headerName string 10 | contextKey string 11 | } 12 | 13 | /* 14 | NewMiddlewareGetHeader creates a new handler to extract any header and save its value into the context. 15 | headerName: the name of the header you want to extract 16 | contextKey: the value key that you would like to store this header under in the context 17 | 18 | Example usage: 19 | 20 | routes.Handle("/some/route", a.Dependencies.MWHandler.Handle( 21 | []rye.Handler{ 22 | rye.NewMiddlewareGetHeader(headerName, contextKey), 23 | yourHandler, 24 | })).Methods("POST") 25 | */ 26 | func NewMiddlewareGetHeader(headerName, contextKey string) func(rw http.ResponseWriter, req *http.Request) *Response { 27 | h := getHeader{headerName: headerName, contextKey: contextKey} 28 | return h.getHeaderMiddleware 29 | } 30 | 31 | func (h *getHeader) getHeaderMiddleware(rw http.ResponseWriter, r *http.Request) *Response { 32 | rID := r.Header.Get(h.headerName) 33 | if rID != "" { 34 | return &Response{ 35 | Context: context.WithValue(r.Context(), h.contextKey, rID), 36 | } 37 | } 38 | 39 | return nil 40 | } 41 | -------------------------------------------------------------------------------- /middleware_getheader_test.go: -------------------------------------------------------------------------------- 1 | package rye 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | 7 | . 
"github.com/onsi/ginkgo" 8 | . "github.com/onsi/gomega" 9 | ) 10 | 11 | var _ = Describe("Get Header Middleware", func() { 12 | var ( 13 | request *http.Request 14 | response *httptest.ResponseRecorder 15 | ) 16 | 17 | BeforeEach(func() { 18 | response = httptest.NewRecorder() 19 | request = &http.Request{ 20 | Header: make(map[string][]string, 0), 21 | } 22 | }) 23 | 24 | Describe("getHeaderMiddleware", func() { 25 | Context("when a valid header is passed", func() { 26 | It("should return context with value", func() { 27 | headerName := "SpecialHeader" 28 | ctxKey := "special" 29 | request.Header.Add(headerName, "secret value") 30 | resp := NewMiddlewareGetHeader(headerName, ctxKey)(response, request) 31 | Expect(resp).ToNot(BeNil()) 32 | Expect(resp.Context).ToNot(BeNil()) 33 | Expect(resp.Context.Value(ctxKey)).To(Equal("secret value")) 34 | }) 35 | }) 36 | 37 | Context("when no header is passed", func() { 38 | It("should have no value in context", func() { 39 | resp := NewMiddlewareGetHeader("something", "not there")(response, request) 40 | Expect(resp).To(BeNil()) 41 | }) 42 | }) 43 | }) 44 | }) 45 | -------------------------------------------------------------------------------- /middleware_jwt.go: -------------------------------------------------------------------------------- 1 | package rye 2 | 3 | import "net/http" 4 | 5 | const ( 6 | CONTEXT_JWT = "rye-middlewarejwt-jwt" 7 | ) 8 | 9 | type jwtVerify struct { 10 | secret string 11 | token string 12 | } 13 | 14 | /* 15 | This middleware is deprecated. Use NewMiddlewareAuth with NewJWTAuthFunc instead. 16 | 17 | This remains here as a shim for backwards compatibility. 18 | 19 | --------------------------------------------------------------------------- 20 | 21 | This middleware provides JWT verification functionality 22 | 23 | You can use this middleware by specifying `rye.NewMiddlewareJWT(shared_secret)` 24 | when defining your routes. 
25 | 26 | This middleware has no default version, it must be configured with a shared secret. 27 | 28 | Example use case: 29 | 30 | routes.Handle("/some/route", a.Dependencies.MWHandler.Handle( 31 | []rye.Handler{ 32 | rye.NewMiddlewareJWT("this is a big secret"), 33 | yourHandler, 34 | })).Methods("PUT", "OPTIONS") 35 | 36 | Additionally, this middleware puts the JWT token into the context for use by other 37 | middlewares in your chain. 38 | 39 | Access to that is simple (using the CONTEXT_JWT constant as a key) 40 | 41 | func getJWTfromContext(rw http.ResponseWriter, r *http.Request) *rye.Response { 42 | 43 | // Retrieving the value is easy! 44 | // Just reference the rye.CONTEXT_JWT const as a key 45 | myVal := r.Context().Value(rye.CONTEXT_JWT) 46 | 47 | // Log it to the server log? 48 | log.Infof("Context Value: %v", myVal) 49 | 50 | return nil 51 | } 52 | 53 | */ 54 | func NewMiddlewareJWT(secret string) func(rw http.ResponseWriter, req *http.Request) *Response { 55 | return NewMiddlewareAuth(NewJWTAuthFunc(secret)) 56 | } 57 | -------------------------------------------------------------------------------- /middleware_jwt_test.go: -------------------------------------------------------------------------------- 1 | package rye 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "net/http/httptest" 7 | 8 | . "github.com/onsi/ginkgo" 9 | . 
"github.com/onsi/gomega"
)

var _ = Describe("JWT Middleware", func() {

	var (
		request       *http.Request
		response      *httptest.ResponseRecorder
		shared_secret = "secret"
		// hs256_jwt is HS256-signed (header {"alg":"HS256","typ":"JWT"}) and
		// verifies against shared_secret.
		hs256_jwt = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ"
		// rs256_jwt is RS256-signed, so checking it with the HMAC shared secret
		// is expected to fail on the signing method (see the spec below).
		rs256_jwt = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2p3dC1pZHAuZXhhbXBsZS5jb20iLCJzdWIiOiJtYWlsdG86bWlrZUBleGFtcGxlLmNvbSIsIm5iZiI6MTQ3ODIwMTkxNiwiZXhwIjoxNDc4MjA1NTE2LCJpYXQiOjE0NzgyMDE5MTYsImp0aSI6ImlkMTIzNDU2IiwidHlwIjoiaHR0cHM6Ly9leGFtcGxlLmNvbS9yZWdpc3RlciJ9.B9zAEk_zm_Hz3cn8QLtAZNizZtHlZ0ENQ0nn5Jl734cYKO6Rn2JJct24u3UPXl01atIre2Z8oKIs9gpePpBsvR50Z-gCFtTGM_5dTPw45H4hY4KkvjP9JvnYGz4V4DeQDTZz-HUByKHSbKNm4pCmhGLcuF2SBwBmj-xOoy4eCc4Zf77fSXz9ctwv3FHCteXQnXD6M2m243fVkPiWq7qaE0Z0CfR0vRHjUcbA2qWVAM1kuOUSIIqhc0hZ5sVIW1UYZ4XHJ7unXG_SRRT6sYEE3FRRKhURutRRkhLtMMF14TpcQdZC0UjkcJVVMnQR0HQiDG-L7TModfNRhBO5PpjDng"
	)

	BeforeEach(func() {
		response = httptest.NewRecorder()
		request = &http.Request{
			Header: make(map[string][]string, 0),
		}
	})

	Describe("handle", func() {
		Context("when a valid token is passed", func() {
			// NOTE(review): despite its description, this spec expects a non-nil
			// *Response whose Context carries the raw token under CONTEXT_JWT.
			It("should return nil", func() {
				request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", hs256_jwt))
				resp := NewMiddlewareJWT(shared_secret)(response, request)
				Expect(resp).ToNot(BeNil())
				Expect(resp.Context).ToNot(BeNil())
				Expect(resp.Context.Value(CONTEXT_JWT)).To(Equal(hs256_jwt))
			})

			It("lower case bearer is also accepted", func() {
				request.Header.Add("Authorization", fmt.Sprintf("bearer %s", hs256_jwt))
				resp := NewMiddlewareJWT(shared_secret)(response, request)
				Expect(resp).ToNot(BeNil())
				Expect(resp.Context).ToNot(BeNil())
				Expect(resp.Context.Value(CONTEXT_JWT)).To(Equal(hs256_jwt))
			})
		})

		Context("when no token is passed", func() {
			It("should return an error", func() {
				resp := NewMiddlewareJWT(shared_secret)(response, request)
				Expect(resp).ToNot(BeNil())
				Expect(resp.Error()).To(ContainSubstring("no authentication provided"))
			})
		})

		Context("when an invalid token is passed", func() {
			It("should return an error", func() {
				request.Header.Add("Authorization", "Bearer foo")
				resp := NewMiddlewareJWT(shared_secret)(response, request)
				Expect(resp).ToNot(BeNil())
				Expect(resp.Error()).To(ContainSubstring("invalid"))
			})
		})

		Context("when a token with an incorrectly signed signature is passed", func() {
			It("should return an error", func() {
				request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", rs256_jwt))
				resp := NewMiddlewareJWT(shared_secret)(response, request)
				Expect(resp).ToNot(BeNil())
				Expect(resp.Error()).To(ContainSubstring("signing method"))
			})
		})

		Context("token with wrong header format", func() {
			It("should return an error", func() {
				// The scheme is neither "Bearer" nor "bearer".
				request.Header.Add("Authorization", fmt.Sprintf("foo %s", rs256_jwt))
				resp := NewMiddlewareJWT(shared_secret)(response, request)
				Expect(resp).ToNot(BeNil())
				Expect(resp.Error()).To(ContainSubstring("invalid authentication"))
			})
		})
	})
})

--------------------------------------------------------------------------------
/middleware_routelogger.go:
--------------------------------------------------------------------------------
package rye

import (
	"net/http"

	log "github.com/sirupsen/logrus"
)

/*
MiddlewareRouteLogger creates a new handler to provide simple logging output for the specific route. You can use this middleware by specifying `rye.MiddlewareRouteLogger`
when defining your routes.

The returned handler logs each request at Info level and always returns nil,
so it never interrupts the middleware chain.

Example use case:

	routes.Handle("/some/route", a.Dependencies.MWHandler.Handle(
		[]rye.Handler{
			rye.MiddlewareRouteLogger(),
			yourHandler,
		})).Methods("PUT", "OPTIONS")
*/
func MiddlewareRouteLogger() func(rw http.ResponseWriter, req *http.Request) *Response {
	return func(rw http.ResponseWriter, r *http.Request) *Response {
		// Access-log style line: remote address, then the quoted request line
		// (method, request URI, protocol).
		log.Infof("%s \"%s %s %s\"", r.RemoteAddr, r.Method, r.RequestURI, r.Proto)
		// A nil *Response lets the rest of the handler chain run.
		return nil
	}
}

--------------------------------------------------------------------------------
/middleware_routelogger_test.go:
--------------------------------------------------------------------------------
package rye

import (
	"net/http"
	"net/http/httptest"

	. "github.com/onsi/ginkgo"
	. "github.com/onsi/gomega"
)

var _ = Describe("Route Logger Middleware", func() {

	var (
		request  *http.Request
		response *httptest.ResponseRecorder
	)

	BeforeEach(func() {
		response = httptest.NewRecorder()
		request = &http.Request{
			Header: make(map[string][]string, 0),
		}
	})

	Describe("MiddlewareRouteLogger", func() {
		Context("when the route logger is called", func() {
			It("should return nil", func() {
				resp := MiddlewareRouteLogger()(response, request)
				Expect(resp).To(BeNil())
			})
		})
	})
})

--------------------------------------------------------------------------------
/middleware_static_file.go:
--------------------------------------------------------------------------------
package rye

import (
	"net/http"
)

// staticFile holds the single file path served by NewStaticFile's handler.
type staticFile struct {
	path string
}

/*
NewStaticFile creates a new handler to serve a file from a path on the local filesystem.
The path should be an absolute path -> i.e., it's up to the program using Rye to
correctly determine what path it should be serving from. An example is available
in the `static_example.go` file which shows setting up a path relative to
the go executable.

The purpose of this handler is to serve a specific file for any requests through the
route handler. For instance, in the example below, any requests made to `/ui` will
always be routed to /dist/index.html. This is important for single page applications
which happen to use client-side routers. Therefore, you might have a webpack application
with its entrypoint `/dist/index.html`. That file may point at your `bundle.js`.
Every request into the app will need to always be routed to `/dist/index.html`

Example use case:

	routes.PathPrefix("/ui/").Handler(middlewareHandler.Handle([]rye.Handler{
		rye.MiddlewareRouteLogger(),
		rye.NewStaticFile(pwd + "/dist/index.html"),
	}))

*/
func NewStaticFile(path string) func(rw http.ResponseWriter, req *http.Request) *Response {
	s := &staticFile{
		path: path,
	}
	return s.handle
}

// handle writes the configured file to rw and returns nil; the response
// (including any error status from http.ServeFile) has already been written.
func (s *staticFile) handle(rw http.ResponseWriter, req *http.Request) *Response {
	// The request URL does not pick the file: s.path is served for every
	// request (http.ServeFile may still answer with a redirect or 404).
	http.ServeFile(rw, req, s.path)
	return nil
}

--------------------------------------------------------------------------------
/middleware_static_file_test.go:
--------------------------------------------------------------------------------
package rye

import (
	"io/ioutil"
	"net/http"
	"net/http/httptest"
	"net/url"
	"os"

	. "github.com/onsi/ginkgo"
	.
"github.com/onsi/gomega" 12 | ) 13 | 14 | var _ = Describe("Static File Middleware", func() { 15 | 16 | var ( 17 | request *http.Request 18 | response *httptest.ResponseRecorder 19 | 20 | path string 21 | testPath string 22 | ) 23 | 24 | BeforeEach(func() { 25 | response = httptest.NewRecorder() 26 | request = &http.Request{} 27 | testPath, _ = os.Getwd() 28 | }) 29 | 30 | Describe("handle", func() { 31 | Context("when a valid file is referenced", func() { 32 | It("should return a response", func() { 33 | path = "/static-examples/dist/index.html" 34 | url, _ := url.Parse("/thisstuff") 35 | request.URL = url 36 | resp := NewStaticFile(testPath+path)(response, request) 37 | Expect(resp).To(BeNil()) 38 | Expect(response).ToNot(BeNil()) 39 | Expect(response.Code).To(Equal(200)) 40 | 41 | body, err := ioutil.ReadAll(response.Body) 42 | Expect(err).To(BeNil()) 43 | Expect(body).To(ContainSubstring("Index.html")) 44 | }) 45 | 46 | It("should return a Moved Permanently response", func() { 47 | path = "/static-examples/dist/index.html" 48 | url, _ := url.Parse("/thisstuff") 49 | request.URL = url 50 | resp := NewStaticFile("")(response, request) 51 | Expect(resp).To(BeNil()) 52 | Expect(response).ToNot(BeNil()) 53 | Expect(response.Code).To(Equal(301)) 54 | }) 55 | 56 | It("should return a File Not Found response", func() { 57 | path = "/static-examples/dist/index.html" 58 | url, _ := url.Parse("/thisstuff") 59 | request.URL = url 60 | resp := NewStaticFile(path)(response, request) 61 | Expect(resp).To(BeNil()) 62 | Expect(response).ToNot(BeNil()) 63 | Expect(response.Code).To(Equal(404)) 64 | }) 65 | }) 66 | }) 67 | }) 68 | -------------------------------------------------------------------------------- /middleware_static_filesystem.go: -------------------------------------------------------------------------------- 1 | package rye 2 | 3 | import ( 4 | "net/http" 5 | ) 6 | 7 | type staticFilesystem struct { 8 | path string 9 | stripPrefix string 10 | } 11 | 12 | /* 13 | 
NewStaticFilesystem creates a new handler to serve a filesystem from a path 14 | on the local filesystem. The path should be an absolute path -> i.e., it's 15 | up to the program using Rye to correctly determine what path it should be 16 | serving from. An example is available in the `static_example.go` file which 17 | shows setting up a path relative to the go executable. 18 | 19 | The primary benefit of this is to serve an entire set of files. You can 20 | pre-pend typical Rye middlewares to the chain. The static filesystem 21 | middleware should always be last in a chain, however. The `stripPrefix` allows 22 | you to ignore the prefix on requests so that the proper files will be matched. 23 | 24 | Example use case: 25 | 26 | routes.PathPrefix("/dist/").Handler(middlewareHandler.Handle([]rye.Handler{ 27 | rye.MiddlewareRouteLogger(), 28 | rye.NewStaticFilesystem(pwd+"/dist/", "/dist/"), 29 | })) 30 | 31 | */ 32 | func NewStaticFilesystem(path string, stripPrefix string) func(rw http.ResponseWriter, req *http.Request) *Response { 33 | s := &staticFilesystem{ 34 | path: path, 35 | stripPrefix: stripPrefix, 36 | } 37 | return s.handle 38 | } 39 | 40 | func (s *staticFilesystem) handle(rw http.ResponseWriter, req *http.Request) *Response { 41 | x := http.StripPrefix(s.stripPrefix, http.FileServer(http.Dir(s.path))) 42 | x.ServeHTTP(rw, req) 43 | return nil 44 | } 45 | -------------------------------------------------------------------------------- /middleware_static_filesystem_test.go: -------------------------------------------------------------------------------- 1 | package rye 2 | 3 | import ( 4 | "io/ioutil" 5 | "net/http" 6 | "net/http/httptest" 7 | "os" 8 | 9 | . "github.com/onsi/ginkgo" 10 | . 
"github.com/onsi/gomega"
)

var _ = Describe("Static File Middleware", func() {

	var (
		request  *http.Request
		response *httptest.ResponseRecorder

		path     string
		testPath string
	)

	BeforeEach(func() {
		response = httptest.NewRecorder()
		request = &http.Request{}
		testPath, _ = os.Getwd()
	})

	Describe("handle", func() {
		Context("when a valid file is referenced", func() {
			It("should return a response", func() {
				path = "/static-examples/dist/"

				request, _ = http.NewRequest("GET", "/dist/test.html", nil)

				resp := NewStaticFilesystem(testPath+path, "/dist/")(response, request)
				Expect(resp).To(BeNil())
				Expect(response).ToNot(BeNil())
				Expect(response.Code).To(Equal(200))

				body, err := ioutil.ReadAll(response.Body)
				Expect(err).To(BeNil())
				Expect(body).To(ContainSubstring("Test.html"))
			})

			It("should return Index.html when request is just path", func() {
				path = "/static-examples/dist/"

				request, _ := http.NewRequest("GET", "/dist/", nil)

				resp := NewStaticFilesystem(testPath+path, "/dist/")(response, request)
				Expect(resp).To(BeNil())
				Expect(response).ToNot(BeNil())
				Expect(response.Code).To(Equal(200))

				body, err := ioutil.ReadAll(response.Body)
				Expect(err).To(BeNil())
				Expect(body).To(ContainSubstring("Index.html"))
			})

			It("should return Index.html when strip prefix is empty", func() {
				path = "/static-examples/dist/"

				request, _ := http.NewRequest("GET", "/", nil)

				resp := NewStaticFilesystem(testPath+path, "")(response, request)
				Expect(resp).To(BeNil())
				Expect(response).ToNot(BeNil())
				Expect(response.Code).To(Equal(200))

				body, err := ioutil.ReadAll(response.Body)
				Expect(err).To(BeNil())
				Expect(body).To(ContainSubstring("Index.html"))
			})

			// NOTE(review): this spec reuses the previous spec's description, but
			// it actually asserts a 404 for a file that does not exist.
			It("should return Index.html when strip prefix is empty", func() {
				path = "/static-examples/dist/"

				request, _ := http.NewRequest("GET", "/ASDads.HTML", nil)

				resp := NewStaticFilesystem(testPath+path, "")(response, request)
				Expect(resp).To(BeNil())
				Expect(response).ToNot(BeNil())
				Expect(response.Code).To(Equal(404))
			})

			It("should return test.css on subpath", func() {
				path = "/static-examples/dist/"

				request, _ := http.NewRequest("GET", "/styles/test.css", nil)

				resp := NewStaticFilesystem(testPath+path, "")(response, request)
				Expect(resp).To(BeNil())
				Expect(response).ToNot(BeNil())
				Expect(response.Code).To(Equal(200))

				body, err := ioutil.ReadAll(response.Body)
				Expect(err).To(BeNil())
				Expect(body).To(ContainSubstring("test.css"))
			})
		})
	})
})

--------------------------------------------------------------------------------
/rye.go:
--------------------------------------------------------------------------------
package rye

import (
	"context"
	"encoding/json"
	"errors"
	"net/http"
	"reflect"
	"runtime"
	"strconv"
	"strings"
	"time"

	//log "github.com/sirupsen/logrus"

	"github.com/cactus/go-statsd-client/statsd"
)

//go:generate counterfeiter -o fakes/statsdfakes/fake_statter.go $GOPATH/src/github.com/cactus/go-statsd-client/statsd/client.go Statter
//go:generate perl -pi -e 's/$GOPATH\/src\///g' fakes/statsdfakes/fake_statter.go

// MWHandler struct is used to configure and access rye's basic functionality.
type MWHandler struct {
	Config Config
	// beforeHandlers are registered via Use and run ahead of the handlers
	// passed to Handle, for every chain Handle builds.
	beforeHandlers []Handler
}

// CustomStatter allows the client to log any additional statsD metrics Rye
// computes around the request handler.
// ReportStats is invoked on its own goroutine after each handler that returns
// a non-nil *Response; the error it returns is not inspected by Rye.
30 | type CustomStatter interface { 31 | ReportStats(handlerName string, elapsedTime time.Duration, req *http.Request, resp *Response) error 32 | } 33 | 34 | // Config struct allows you to set a reference to a statsd.Statter and include it's stats rate. 35 | type Config struct { 36 | Statter statsd.Statter 37 | StatRate float32 38 | 39 | // toggle types of stats sent 40 | NoErrStats bool 41 | NoDurationStats bool 42 | NoStatusCodeStats bool 43 | 44 | // Customer Statter for the client 45 | CustomStatter CustomStatter 46 | } 47 | 48 | // JSONStatus is a simple container used for conveying status messages. 49 | type JSONStatus struct { 50 | Message string `json:"message"` 51 | Status string `json:"status"` 52 | } 53 | 54 | // Response struct is utilized by middlewares as a way to share state; 55 | // ie. a middleware can return a *Response as a way to indicate 56 | // that further middleware execution should stop (without an error) or return a 57 | // a hard error by setting `Err` + `StatusCode`. 58 | type Response struct { 59 | Err error 60 | StatusCode int 61 | StopExecution bool 62 | Context context.Context 63 | } 64 | 65 | // Error bubbles a response error providing an implementation of the Error interface. 66 | // It returns the error as a string. 67 | func (r *Response) Error() string { 68 | return r.Err.Error() 69 | } 70 | 71 | // Handler is the primary type that any rye middleware must implement to be called in the Handle() function. 72 | // In order to use this you must return a *rye.Response. 73 | type Handler func(w http.ResponseWriter, r *http.Request) *Response 74 | 75 | // Constructor for new instantiating new rye instances 76 | // It returns a constructed *MWHandler instance. 77 | func NewMWHandler(config Config) *MWHandler { 78 | return &MWHandler{ 79 | Config: config, 80 | } 81 | } 82 | 83 | // Use adds a handler to every request. 
All handlers set up with use 84 | // are fired first and then any route specific handlers are called 85 | func (m *MWHandler) Use(handler Handler) { 86 | m.beforeHandlers = append(m.beforeHandlers, handler) 87 | } 88 | 89 | // The Handle function is the primary way to set up your chain of middlewares to be called by rye. 90 | // It returns a http.HandlerFunc from net/http that can be set as a route in your http server. 91 | func (m *MWHandler) Handle(customHandlers []Handler) http.Handler { 92 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 93 | exit := false 94 | for _, handler := range m.beforeHandlers { 95 | exit, r = m.do(w, r, handler) 96 | if exit { 97 | return 98 | } 99 | } 100 | 101 | for _, handler := range customHandlers { 102 | exit, r = m.do(w, r, handler) 103 | if exit { 104 | return 105 | } 106 | } 107 | }) 108 | } 109 | 110 | func (m *MWHandler) do(w http.ResponseWriter, r *http.Request, handler Handler) (bool, *http.Request) { 111 | var resp *Response 112 | 113 | // Record handler runtime 114 | func() { 115 | statusCode := "2xx" 116 | startTime := time.Now() 117 | 118 | if resp = handler(w, r); resp != nil { 119 | func() { 120 | // Stop execution if it's passed 121 | if resp.StopExecution { 122 | return 123 | } 124 | 125 | // If a context is returned, we will 126 | // replace the current request with a new request 127 | if resp.Context != nil { 128 | r = r.WithContext(resp.Context) 129 | return 130 | } 131 | 132 | // If there's no error but we have a response 133 | if resp.Err == nil { 134 | resp.Err = errors.New("Problem with middleware; neither Err or StopExecution is set") 135 | resp.StatusCode = http.StatusInternalServerError 136 | } 137 | 138 | // Now assume we have an error. 
139 | if m.Config.Statter != nil && resp.StatusCode >= 500 { 140 | go m.reportError() 141 | } 142 | 143 | // Write the error out 144 | WriteJSONStatus(w, "error", resp.Error(), resp.StatusCode) 145 | }() 146 | 147 | if resp.StatusCode > 0 { 148 | statusCode = strconv.Itoa(resp.StatusCode) 149 | } 150 | } 151 | 152 | handlerName := getFuncName(handler) 153 | 154 | if m.Config.Statter != nil { 155 | // Record runtime metric 156 | go m.reportDuration(handlerName, startTime) 157 | 158 | // Record status code metric (default 2xx) 159 | go m.reportStatusCode(handlerName, statusCode) 160 | } 161 | 162 | // If a CustomStatter is set, send the handler metrics to it. 163 | // This allows the client to handle these metrics however it wants. 164 | if m.Config.CustomStatter != nil && resp != nil { 165 | go m.Config.CustomStatter.ReportStats(handlerName, time.Since(startTime), r, resp) 166 | } 167 | }() 168 | 169 | // stop executing rest of the 170 | // handlers if we encounter an error 171 | if resp != nil && (resp.StopExecution || resp.Err != nil) { 172 | return true, r 173 | } 174 | 175 | return false, r 176 | } 177 | 178 | func (m *MWHandler) reportError() { 179 | if m.Config.NoErrStats { 180 | return 181 | } 182 | 183 | m.Config.Statter.Inc("errors", 1, m.Config.StatRate) 184 | } 185 | 186 | func (m *MWHandler) reportDuration(handlerName string, startTime time.Time) { 187 | if m.Config.NoDurationStats { 188 | return 189 | } 190 | 191 | m.Config.Statter.TimingDuration( 192 | "handlers."+handlerName+".runtime", 193 | time.Since(startTime), // delta 194 | m.Config.StatRate, 195 | ) 196 | } 197 | 198 | func (m *MWHandler) reportStatusCode(handlerName string, statusCode string) { 199 | if m.Config.NoStatusCodeStats { 200 | return 201 | } 202 | 203 | m.Config.Statter.Inc( 204 | "handlers."+handlerName+"."+statusCode, 205 | 1, 206 | m.Config.StatRate, 207 | ) 208 | } 209 | 210 | // WriteJSONStatus is a wrapper for WriteJSONResponse that returns a marshalled JSONStatus blob 211 | 
func WriteJSONStatus(rw http.ResponseWriter, status, message string, statusCode int) { 212 | jsonData, _ := json.Marshal(&JSONStatus{ 213 | Message: message, 214 | Status: status, 215 | }) 216 | 217 | WriteJSONResponse(rw, statusCode, jsonData) 218 | } 219 | 220 | // WriteJSONResponse writes data and status code to the ResponseWriter 221 | func WriteJSONResponse(rw http.ResponseWriter, statusCode int, content []byte) { 222 | rw.Header().Set("Content-Type", "application/json") 223 | rw.WriteHeader(statusCode) 224 | rw.Write(content) 225 | } 226 | 227 | // getFuncName uses reflection to determine a given function name 228 | // It returns a string version of the function name (and performs string cleanup) 229 | func getFuncName(i interface{}) string { 230 | fullName := runtime.FuncForPC(reflect.ValueOf(i).Pointer()).Name() 231 | ns := strings.Split(fullName, ".") 232 | 233 | // when we get a method (not a raw function) it comes attached to whatever struct is in its 234 | // method receiver via a function closure, this is not precisely the same as that method itself 235 | // so the compiler appends "-fm" so the name of the closure does not conflict with the actual function 236 | // http://grokbase.com/t/gg/golang-nuts/153jyb5b7p/go-nuts-fm-suffix-in-function-name-what-does-it-mean#20150318ssinqqzrmhx2ep45wjkxsa4rua 237 | return strings.TrimSuffix(ns[len(ns)-1], ")-fm") 238 | } 239 | -------------------------------------------------------------------------------- /rye_suite_test.go: -------------------------------------------------------------------------------- 1 | package rye 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/sirupsen/logrus" 7 | 8 | . "github.com/onsi/ginkgo" 9 | . 
"github.com/onsi/gomega"
)

// TestAPISuite is the single `go test` entry point that runs every Ginkgo
// spec registered in this package.
func TestAPISuite(t *testing.T) {
	// reduce the noise when testing
	logrus.SetLevel(logrus.FatalLevel)

	RegisterFailHandler(Fail)
	RunSpecs(t, "Rye Suite")
}

--------------------------------------------------------------------------------
/rye_test.go:
--------------------------------------------------------------------------------
package rye

import (
	"strconv"

	. "github.com/onsi/ginkgo"
	. "github.com/onsi/gomega"

	"context"
	"errors"
	"fmt"
	"net/http"
	"net/http/httptest"
	"os"
	"time"

	"github.com/InVisionApp/rye/fakes/statsdfakes"
	"github.com/onsi/gomega/types"
)

// Environment variables the test handlers set to record that they ran;
// beforeHandler (below) uses RYE_TEST_BEFORE_ENV_VAR as a counter.
const (
	RYE_TEST_HANDLER_ENV_VAR   = "RYE_TEST_HANDLER_PASS"
	RYE_TEST_HANDLER_2_ENV_VAR = "RYE_TEST_HANDLER_2_PASS"
	RYE_TEST_BEFORE_ENV_VAR    = "RYE_TEST_HANDLER_BEFORE_PASS"
)

// statsInc captures the arguments of one FakeStatter.Inc call.
type statsInc struct {
	Name     string
	Time     int64
	StatRate float32
}

// statsTiming captures the arguments of one FakeStatter.TimingDuration call.
type statsTiming struct {
	Name     string
	Time     time.Duration
	StatRate float32
}

// reportedStats is unbuffered, so each fakeCustomStatter.ReportStats call
// blocks until a spec receives from it.
var reportedStats = make(chan fakeReportedStats)

// fakeReportedStats captures the arguments of one ReportStats call.
type fakeReportedStats struct {
	HandlerName string
	Duration    time.Duration
	Request     *http.Request
	Response    *Response
}

// fakeCustomStatter is a CustomStatter that forwards every call to reportedStats.
type fakeCustomStatter struct{}

func (fcs *fakeCustomStatter) ReportStats(handler string, dur time.Duration, req *http.Request, res *Response) error {
	reportedStats <- fakeReportedStats{handler, dur, req, res}
	return nil
}

var _ = Describe("Rye", func() {

	var (
		request           *http.Request
		response          *httptest.ResponseRecorder
		mwHandler         *MWHandler
		ryeConfig         Config
		fakeStatter       *statsdfakes.FakeStatter
		fakeClientStatter *fakeCustomStatter
		inc               chan statsInc
		timing            chan statsTiming
	)

	const (
		STATRATE float32 = 1
	)

	BeforeEach(func() {
		fakeStatter = &statsdfakes.FakeStatter{}
		fakeClientStatter = &fakeCustomStatter{}
		ryeConfig = Config{
			Statter:  fakeStatter,
			StatRate: STATRATE,
		}
		mwHandler = NewMWHandler(ryeConfig)

		response = httptest.NewRecorder()
		request = &http.Request{
			Header: make(map[string][]string, 0),
		}

		os.Unsetenv(RYE_TEST_HANDLER_ENV_VAR)
		os.Unsetenv(RYE_TEST_BEFORE_ENV_VAR)
		os.Unsetenv(RYE_TEST_HANDLER_2_ENV_VAR)

		// inc has room for the two Inc calls an erroring handler makes ("errors"
		// plus the status-code metric). timing is unbuffered, so a
		// TimingDuration stub blocks until a spec receives from it.
		inc = make(chan statsInc, 2)
		timing = make(chan statsTiming)

		fakeStatter.IncStub = func(name string, time int64, statrate float32) error {
			inc <- statsInc{name, time, statrate}
			return nil
		}

		fakeStatter.TimingDurationStub = func(name string, time time.Duration, statrate float32) error {
			timing <- statsTiming{name, time, statrate}
			return nil
		}
	})

	AfterEach(func() {
		os.Unsetenv(RYE_TEST_HANDLER_ENV_VAR)
	})

	Describe("NewMWHandler", func() {
		Context("when instantiating a mwhandler", func() {
			It("should have correct attributes", func() {
				ryeConfig := Config{
					Statter:  fakeStatter,
					StatRate: STATRATE,
				}
				handler := NewMWHandler(ryeConfig)
				Expect(handler).NotTo(BeNil())
				Expect(handler.Config.Statter).To(Equal(fakeStatter))
				Expect(handler.Config.StatRate).To(Equal(STATRATE))
			})

			It("should have attributes with default values when passed an empty config", func() {
				handler := NewMWHandler(Config{})
				Expect(handler).NotTo(BeNil())
				Expect(handler.Config.Statter).To(BeNil())
				Expect(handler.Config.StatRate).To(Equal(float32(0.0)))
			})
		})
	})

	Describe("Handle", func() {
		Context("when adding a valid handler", func() {
			It("should return valid HandlerFunc", func() {

				h := mwHandler.Handle([]Handler{successHandler})
				h.ServeHTTP(response, request)

				Expect(h).ToNot(BeNil())
				Expect(h).To(BeAssignableToTypeOf(func(http.ResponseWriter, *http.Request) {}))
				Expect(os.Getenv(RYE_TEST_HANDLER_ENV_VAR)).To(Equal("1"))

				Eventually(inc).Should(Receive(Equal(statsInc{"handlers.successHandler.2xx", 1, float32(STATRATE)})))
				Eventually(timing).Should(Receive(HaveTiming("handlers.successHandler.runtime", float32(STATRATE))))
			})
		})

		Context("when adding a global handler it should get called for multiple handler chains", func() {

			It("should execute before handlers and end in success", func() {

				handlerWithGlobals := NewMWHandler(ryeConfig)
				handlerWithGlobals.Use(beforeHandler)
				handlerWithGlobals.Use(beforeHandler)
				handlerWithGlobals.Use(beforeHandler)

				h := handlerWithGlobals.Handle([]Handler{successHandler})
				h.ServeHTTP(response, request)

				Expect(h).ToNot(BeNil())
				Expect(h).To(BeAssignableToTypeOf(func(http.ResponseWriter, *http.Request) {}))
				Expect(os.Getenv(RYE_TEST_HANDLER_ENV_VAR)).To(Equal("1"))
				Expect(os.Getenv(RYE_TEST_BEFORE_ENV_VAR)).To(Equal("3"))
			})

			It("should execute before handlers and multiple Handles should manage their closure correctly", func() {

				handlerWithGlobals := NewMWHandler(ryeConfig)
				handlerWithGlobals.Use(beforeHandler)
				handlerWithGlobals.Use(beforeHandler)
				handlerWithGlobals.Use(beforeHandler)

				h := handlerWithGlobals.Handle([]Handler{successHandler})

				h2 := handlerWithGlobals.Handle([]Handler{success2Handler})

				h.ServeHTTP(response, request)
				h2.ServeHTTP(response, request)

				Expect(h).ToNot(BeNil())
				Expect(h).To(BeAssignableToTypeOf(func(http.ResponseWriter, *http.Request) {}))

				Expect(h2).ToNot(BeNil())
				Expect(h2).To(BeAssignableToTypeOf(func(http.ResponseWriter, *http.Request) {}))

				// Three before handlers ran for each of the two chains.
				before := os.Getenv(RYE_TEST_BEFORE_ENV_VAR)
				handler1 := os.Getenv(RYE_TEST_HANDLER_ENV_VAR)
				handler2 := os.Getenv(RYE_TEST_HANDLER_2_ENV_VAR)

				Expect(before).To(Equal("6"))
				Expect(handler1).To(Equal("1"))
				Expect(handler2).To(Equal("1"))
			})
		})

		Context("when a handler returns a response with StopExecution", func() {
			It("should not execute any further handlers", func() {
				request.Method = "OPTIONS"

				h := mwHandler.Handle([]Handler{stopExecutionHandler, successHandler})
				h.ServeHTTP(response, request)

				Expect(os.Getenv(RYE_TEST_HANDLER_ENV_VAR)).ToNot(Equal("1"))
			})
		})

		Context("when a handler returns a response with StopExecution and StatusCode", func() {
			It("should not execute any further handlers", func() {
				h := mwHandler.Handle([]Handler{stopExecutionWithStatusHandler, successHandler})
				h.ServeHTTP(response, request)

				Eventually(inc).Should(Receive(Equal(statsInc{"handlers.stopExecutionWithStatusHandler.404", 1, float32(STATRATE)})))
				Expect(os.Getenv(RYE_TEST_HANDLER_ENV_VAR)).ToNot(Equal("1"))
			})
		})

		Context("when a before handler returns a response with StopExecution", func() {
			It("should not execute any further handlers", func() {
				request.Method = "OPTIONS"

				mwHandler.beforeHandlers = []Handler{stopExecutionHandler}

				h := mwHandler.Handle([]Handler{successHandler})
				h.ServeHTTP(response, request)

				Expect(os.Getenv(RYE_TEST_HANDLER_ENV_VAR)).ToNot(Equal("1"))
			})
		})

		Context("when a handler returns a response with Context", func() {
			It("should add that new context to the next passed request", func() {
				h := mwHandler.Handle([]Handler{contextHandler, checkContextHandler})
				h.ServeHTTP(response, request)

				Expect(os.Getenv(RYE_TEST_HANDLER_ENV_VAR)).To(Equal("1"))
			})
		})

		Context("when a beforehandler returns a response with Context", func() {
			It("should add that new context to the next passed request", func() {
				mwHandler.beforeHandlers = []Handler{contextHandler}

				h := mwHandler.Handle([]Handler{checkContextHandler})
				h.ServeHTTP(response, request)

				Expect(os.Getenv(RYE_TEST_HANDLER_ENV_VAR)).To(Equal("1"))
			})
		})

		Context("when a handler returns a response with neither error or StopExecution set", func() {
			It("should return a 500 + error message (and stop execution)", func() {
				h := mwHandler.Handle([]Handler{badResponseHandler, successHandler})
				h.ServeHTTP(response, request)

				Expect(response.Code).To(Equal(http.StatusInternalServerError))
				Expect(os.Getenv(RYE_TEST_HANDLER_ENV_VAR)).ToNot(Equal("1"))
			})
		})

		Context("when adding an erroneous handler", func() {
			It("should interrupt handler chain and set a response status code", func() {

				h := mwHandler.Handle([]Handler{failureHandler})
				h.ServeHTTP(response, request)

				Expect(h).ToNot(BeNil())
				Expect(h).To(BeAssignableToTypeOf(func(http.ResponseWriter, *http.Request) {}))
				Expect(response.Code).To(Equal(505))
				// NOTE(review): both metrics are sent from separate goroutines, so
				// the order asserted here is not guaranteed.
				Eventually(inc).Should(Receive(Equal(statsInc{"errors", 1, float32(STATRATE)})))
				Eventually(inc).Should(Receive(Equal(statsInc{"handlers.failureHandler.505", 1, float32(STATRATE)})))
				Eventually(timing).Should(Receive(HaveTiming("handlers.failureHandler.runtime", float32(STATRATE))))
			})
		})

		Context("when the statter is not set", func() {
			It("should not call Inc or TimingDuration", func() {

				ryeConfig := Config{}
				handler := NewMWHandler(ryeConfig)

				h := handler.Handle([]Handler{successHandler})
				h.ServeHTTP(response, request)

				Expect(fakeStatter.IncCallCount()).To(Equal(0))
				Expect(fakeStatter.TimingDurationCallCount()).To(Equal(0))
			})
		})

		Context("when error stats are turned off", func() {
			It("should not call Inc or TimingDuration", func() {
				ryeConfig := Config{
					Statter:    fakeStatter,
					NoErrStats: true,
				}

				handler := NewMWHandler(ryeConfig)
				//use the failureHandler so an error would be reported
				h := handler.Handle([]Handler{failureHandler})
				h.ServeHTTP(response, request)

				// give the stat-reporting goroutines time to run
				time.Sleep(time.Millisecond * 10)

				Expect(fakeStatter.IncCallCount()).To(Equal(1))
				metric, _, _ := fakeStatter.IncArgsForCall(0)
				Expect(metric).ToNot(Equal("errors"))

				Expect(fakeStatter.TimingDurationCallCount()).To(Equal(1))
			})
		})

		Context("when statusCode stats are turned off", func() {
			It("should not call Inc or TimingDuration", func() {
				ryeConfig := Config{
					Statter:           fakeStatter,
					NoStatusCodeStats: true,
				}

				handler := NewMWHandler(ryeConfig)
				//use the failureHandler so an error would be reported
				h := handler.Handle([]Handler{failureHandler})
				h.ServeHTTP(response, request)

				// give the stat-reporting goroutines time to run
				time.Sleep(time.Millisecond * 10)

				Expect(fakeStatter.IncCallCount()).To(Equal(1))
				metric, _, _ := fakeStatter.IncArgsForCall(0)
				Expect(metric).ToNot(ContainSubstring("handlers."))

				Expect(fakeStatter.TimingDurationCallCount()).To(Equal(1))
			})
		})

		Context("when timming stats are turned off", func() {
			It("should not call Inc or TimingDuration", func() {
				ryeConfig := Config{
					Statter:         fakeStatter,
					NoDurationStats: true,
				}

				handler := NewMWHandler(ryeConfig)
				//use the failureHandler so an error would be reported
				h := handler.Handle([]Handler{failureHandler})
				h.ServeHTTP(response, request)

				// give the stat-reporting goroutines time to run
				time.Sleep(time.Millisecond * 10)

				Expect(fakeStatter.TimingDurationCallCount()).To(Equal(0))
				Expect(fakeStatter.IncCallCount()).To(Equal(2))
			})
		})

		Context("when a custom statter is supplied", func() {
			It("should call the ReportStats method", func() {
				ryeConfig := Config{
					Statter:       fakeStatter,
					StatRate:      STATRATE,
					CustomStatter: fakeClientStatter,
				}

				handler := NewMWHandler(ryeConfig)
				h := handler.Handle([]Handler{successWithResponse})
				h.ServeHTTP(response, request)

				Expect(h).ToNot(BeNil())
				Expect(h).To(BeAssignableToTypeOf(func(http.ResponseWriter, *http.Request) {}))

				Eventually(inc).Should(Receive(Equal(statsInc{"handlers.successWithResponse.200", 1, float32(STATRATE)})))
				Eventually(timing).Should(Receive(HaveTiming("handlers.successWithResponse.runtime", float32(STATRATE))))

				var receivedReportedStats fakeReportedStats
				var resp *Response

				Eventually(reportedStats).Should(Receive(&receivedReportedStats))
				Expect(receivedReportedStats.HandlerName).To(Equal("successWithResponse"))
				Expect(receivedReportedStats.Duration.Seconds()/1000 > 0).To(Equal(true))
				Expect(receivedReportedStats.Request).To(BeAssignableToTypeOf(request))
				Expect(receivedReportedStats.Response).To(BeAssignableToTypeOf(resp))
				Expect(receivedReportedStats.Response.StatusCode).To(Equal(200))
			})
		})

		Context("when a custom statter is NOT supplied", func() {
			It("should not call the ReportStats method", func() {
				ryeConfig := Config{
					Statter:  fakeStatter,
					StatRate: STATRATE,
				}

				handler := NewMWHandler(ryeConfig)
				h := handler.Handle([]Handler{successWithResponse})
				h.ServeHTTP(response, request)

				Expect(h).ToNot(BeNil())
				Expect(h).To(BeAssignableToTypeOf(func(http.ResponseWriter, *http.Request) {}))

				Eventually(inc).Should(Receive(Equal(statsInc{"handlers.successWithResponse.200", 1, float32(STATRATE)})))
				Eventually(timing).Should(Receive(HaveTiming("handlers.successWithResponse.runtime", float32(STATRATE))))

				time.Sleep(time.Millisecond * 10)

				// NOTE(review): the assertions below only inspect a freshly declared
				// zero value and never read from reportedStats, so this spec cannot
				// detect an unexpected ReportStats call.
				var receivedReportedStats fakeReportedStats

				Expect(receivedReportedStats.HandlerName).To(Equal(""))
				Expect(receivedReportedStats.Duration.Nanoseconds()).To(Equal(int64(0)))
				Expect(receivedReportedStats.Request).To(BeNil())
				Expect(receivedReportedStats.Response).To(BeNil())
			})
		})
	})

	Describe("getFuncName", func() {
		It("should return the name of the function as a string", func() {
			funcName := getFuncName(testFunc)
			Expect(funcName).To(Equal("testFunc"))
		})
	})

	Describe("Error()", func() {
		Context("when an error is set on Response struct", func() {
			It("should return a string if you call Error()", func() {
				resp := &Response{
					Err: errors.New("some error"),
				}

				Expect(resp.Error()).To(Equal("some error"))
			})
		})
	})
})

// beforeHandler increments the counter in RYE_TEST_BEFORE_ENV_VAR (an unset or
// non-numeric value counts as 0) and lets the chain continue.
func beforeHandler(rw http.ResponseWriter, r *http.Request) *Response {
	counter := os.Getenv(RYE_TEST_BEFORE_ENV_VAR)
	counterInt, err := strconv.Atoi(counter)
	if err != nil {
		counterInt = 0
	}
	counterInt++
	os.Setenv(RYE_TEST_BEFORE_ENV_VAR, strconv.Itoa(counterInt))
	return nil
}

// contextHandler returns a Response whose Context carries "test-val" = "exists".
func contextHandler(rw http.ResponseWriter, r *http.Request) *Response {
	ctx := context.WithValue(r.Context(), "test-val", "exists")
	return &Response{Context: ctx}
}

// checkContextHandler sets RYE_TEST_HANDLER_ENV_VAR only when the context added
// by contextHandler reached this request.
func checkContextHandler(rw http.ResponseWriter, r *http.Request) *Response {
	testVal := r.Context().Value("test-val")
	if testVal == "exists" {
		os.Setenv(RYE_TEST_HANDLER_ENV_VAR, "1")
	}
	return nil
}

// successHandler records that it ran and lets the chain continue.
func successHandler(rw http.ResponseWriter, r *http.Request) *Response {
	os.Setenv(RYE_TEST_HANDLER_ENV_VAR, "1")
	return nil
}
| 458 | func success2Handler(rw http.ResponseWriter, r *http.Request) *Response { 459 | os.Setenv(RYE_TEST_HANDLER_2_ENV_VAR, "1") 460 | return nil 461 | } 462 | // successWithResponse returns a Response with status 200, a background Context, no error, and StopExecution explicitly false. 463 | func successWithResponse(rw http.ResponseWriter, r *http.Request) *Response { 464 | return &Response{ 465 | Context: context.Background(), 466 | StatusCode: 200, 467 | StopExecution: false, 468 | Err: nil, 469 | } 470 | } 471 | 472 | func badResponseHandler(rw http.ResponseWriter, r *http.Request) *Response { 473 | return new(Response) 474 | } 475 | // failureHandler returns a Response carrying a non-nil error and status code 505. 476 | func failureHandler(rw http.ResponseWriter, r *http.Request) *Response { 477 | return &Response{ 478 | Err: fmt.Errorf("Foo"), 479 | StatusCode: 505, 480 | } 481 | } 482 | 483 | func stopExecutionHandler(rw http.ResponseWriter, r *http.Request) *Response { 484 | resp := new(Response) 485 | resp.StopExecution = true 486 | return resp 487 | } 488 | 489 | func stopExecutionWithStatusHandler(rw http.ResponseWriter, r *http.Request) *Response { 490 | return &Response{ 491 | StatusCode: 404, 492 | StopExecution: true, 493 | } 494 | } 495 | 496 | func testFunc() {} 497 | // HaveTiming returns a Gomega matcher that succeeds when a statsTiming has exactly the given Name and StatRate. 498 | func HaveTiming(name string, statrate float32) types.GomegaMatcher { 499 | matchesTiming := func(t statsTiming) bool { 500 | return t.Name == name && t.StatRate == statrate 501 | } 502 | return WithTransform(matchesTiming, BeTrue()) 503 | } 504 | -------------------------------------------------------------------------------- /static-examples/dist/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | Index.html 4 | 5 | 6 | 7 |

8 | Index.html 9 |

10 | 11 | -------------------------------------------------------------------------------- /static-examples/dist/styles/index.css: -------------------------------------------------------------------------------- 1 | h1 { 2 | color: #0000AC; 3 | margin-left: 120px; 4 | margin-top: 120px; 5 | } -------------------------------------------------------------------------------- /static-examples/dist/styles/test.css: -------------------------------------------------------------------------------- 1 | /* test.css */ 2 | h1 { 3 | color: #ff0000; 4 | margin-left: 20px; 5 | margin-top: 20px; 6 | } -------------------------------------------------------------------------------- /static-examples/dist/test.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | Test.html 4 | 5 | 6 | 7 |

8 | Test.html 9 |

10 | 11 | -------------------------------------------------------------------------------- /static-examples/static_example.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "os" 7 | 8 | "github.com/InVisionApp/rye" 9 | log "github.com/sirupsen/logrus" 10 | "github.com/cactus/go-statsd-client/statsd" 11 | "github.com/gorilla/mux" 12 | ) 13 | // main serves a rye middleware chain through a gorilla/mux router on localhost:8181: "/" (GET) runs middlewareFirstHandler then homeHandler, "/dist/" is handled by rye.NewStaticFilesystem rooted at <cwd>/dist/, and "/ui/" by rye.NewStaticFile for <cwd>/dist/index.html. 14 | func main() { 15 | statsdClient, err := statsd.NewBufferedClient("localhost:12345", "my_service", 1.0, 0) 16 | if err != nil { 17 | log.Fatalf("Unable to instantiate statsd client: %v", err) 18 | } 19 | 20 | config := rye.Config{ 21 | Statter: statsdClient, 22 | StatRate: 1.0, 23 | } 24 | 25 | middlewareHandler := rye.NewMWHandler(config) 26 | 27 | pwd, err := os.Getwd() 28 | if err != nil { 29 | log.Fatalf("Unable to get working directory: %v", err) 30 | } 31 | 32 | routes := mux.NewRouter().StrictSlash(true) 33 | 34 | routes.Handle("/", middlewareHandler.Handle([]rye.Handler{ 35 | middlewareFirstHandler, 36 | homeHandler, 37 | })).Methods("GET") 38 | 39 | routes.PathPrefix("/dist/").Handler(middlewareHandler.Handle([]rye.Handler{ 40 | rye.MiddlewareRouteLogger(), 41 | rye.NewStaticFilesystem(pwd+"/dist/", "/dist/"), 42 | })) 43 | 44 | routes.PathPrefix("/ui/").Handler(middlewareHandler.Handle([]rye.Handler{ 45 | rye.MiddlewareRouteLogger(), 46 | rye.NewStaticFile(pwd + "/dist/index.html"), 47 | })) 48 | 49 | log.Infof("API server listening on %v", "localhost:8181") 50 | 51 | srv := &http.Server{ 52 | Addr: "localhost:8181", 53 | Handler: routes, 54 | } 55 | // ListenAndServe blocks until the server fails and always returns a non-nil error; report it instead of letting main return silently with exit status 0. 56 | log.Fatalf("API server stopped: %v", srv.ListenAndServe()) 57 | } 58 | // middlewareFirstHandler logs that it fired and returns a nil *rye.Response. 59 | func middlewareFirstHandler(rw http.ResponseWriter, r *http.Request) *rye.Response { 60 | log.Infof("Middleware handler has fired!") 61 | return nil 62 | } 63 | // homeHandler logs that it fired and writes a fixed string body to rw, then returns a nil *rye.Response. 64 | func homeHandler(rw http.ResponseWriter, r *http.Request) *rye.Response { 65 | log.Infof("Home handler has fired!") 66 | 67 | fmt.Fprint(rw, "This is
the home handler") 68 | return nil 69 | } 70 | --------------------------------------------------------------------------------