├── .github ├── CODEOWNERS ├── FUNDING.yml └── workflows │ └── ci.yml ├── .gitignore ├── .golangci.yml ├── LICENSE ├── README.md ├── basic_auth.go ├── basic_auth_test.go ├── benchmarks.go ├── benchmarks_test.go ├── blackwords.go ├── blackwords_test.go ├── cache_control.go ├── cache_control_test.go ├── depricattion.go ├── depricattion_test.go ├── file_server.go ├── file_server_test.go ├── go.mod ├── go.sum ├── gzip.go ├── gzip_test.go ├── httperrors.go ├── httperrors_test.go ├── logger ├── logger.go ├── logger_test.go └── options.go ├── metrics.go ├── metrics_test.go ├── middleware.go ├── middleware_test.go ├── nocache.go ├── nocache_test.go ├── onlyfrom.go ├── onlyfrom_test.go ├── profiler.go ├── realip ├── real.go └── real_test.go ├── rest.go ├── rest_test.go ├── rewrite.go ├── rewrite_test.go ├── sizelimit.go ├── sizelimit_test.go ├── testdata ├── index.html └── root │ ├── 1 │ ├── f1.html │ └── f2.html │ ├── 2 │ ├── f123.txt │ └── index.html │ ├── index.html │ └── xyz.js ├── throttle.go ├── throttle_test.go ├── trace.go └── trace_test.go /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | # These owners will be the default owners for everything in the repo. 2 | # Unless a later match takes precedence, @umputun will be requested for 3 | # review when someone opens a pull request. 
4 | 5 | * @umputun 6 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | github: [umputun] 2 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: build 2 | 3 | on: 4 | push: 5 | branches: 6 | tags: 7 | pull_request: 8 | 9 | jobs: 10 | build: 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - name: set up go 15 | uses: actions/setup-go@v5 16 | with: 17 | go-version: "1.23" 18 | id: go 19 | 20 | - name: checkout 21 | uses: actions/checkout@v4 22 | 23 | - name: build and test 24 | run: | 25 | go get -v 26 | go test -timeout=60s -race -covermode=atomic -coverprofile=$GITHUB_WORKSPACE/profile.cov_tmp 27 | cat $GITHUB_WORKSPACE/profile.cov_tmp | grep -v "_mock.go" > $GITHUB_WORKSPACE/profile.cov 28 | go build -race 29 | env: 30 | GO111MODULE: "on" 31 | TZ: "America/Chicago" 32 | 33 | - name: golangci-lint 34 | uses: golangci/golangci-lint-action@v6 35 | with: 36 | version: v1.64 37 | 38 | - name: install goveralls 39 | run: | 40 | go install github.com/mattn/goveralls@latest 41 | 42 | - name: submit coverage 43 | run: $(go env GOPATH)/bin/goveralls -service="github" -coverprofile=$GITHUB_WORKSPACE/profile.cov 44 | env: 45 | COVERALLS_TOKEN: ${{ secrets.GITHUB_TOKEN }} -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.exe~ 4 | *.dll 5 | *.so 6 | *.dylib 7 | 8 | # Test binary, build with `go test -c` 9 | *.test 10 | 11 | # Output of the go coverage tool, specifically when used with LiteIDE 12 | *.out 13 | vendor -------------------------------------------------------------------------------- /.golangci.yml: 
-------------------------------------------------------------------------------- 1 | run: 2 | timeout: 5m 3 | tests: false 4 | 5 | linters-settings: 6 | govet: 7 | enable: 8 | - shadow 9 | goconst: 10 | min-len: 2 11 | min-occurrences: 2 12 | misspell: 13 | locale: US 14 | lll: 15 | line-length: 140 16 | gocritic: 17 | enabled-tags: 18 | - performance 19 | - style 20 | - experimental 21 | disabled-checks: 22 | - wrapperFunc 23 | - hugeParam 24 | - rangeValCopy 25 | - singleCaseSwitch 26 | - ifElseChain 27 | 28 | linters: 29 | enable: 30 | - revive 31 | - govet 32 | - unconvert 33 | - staticcheck 34 | - unused 35 | - gosec 36 | - dupl 37 | - misspell 38 | - unparam 39 | - typecheck 40 | - ineffassign 41 | - stylecheck 42 | - gochecknoinits 43 | - copyloopvar 44 | - gocritic 45 | - nakedret 46 | - gosimple 47 | - prealloc 48 | fast: false 49 | disable-all: true 50 | 51 | issues: 52 | exclude-dirs: 53 | - vendor 54 | exclude-rules: 55 | - text: "at least one file in a package should have a package comment" 56 | linters: 57 | - stylecheck 58 | - text: "should have a package comment" 59 | linters: 60 | - revive 61 | - path: _test\.go 62 | linters: 63 | - gosec 64 | - dupl 65 | exclude-use-default: false 66 | 67 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Umputun 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 
all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## REST helpers and middleware [](https://github.com/go-pkgz/rest/actions) [](https://goreportcard.com/report/github.com/go-pkgz/rest) [](https://coveralls.io/github/go-pkgz/rest?branch=master) [](https://godoc.org/github.com/go-pkgz/rest) 2 | 3 | 4 | ## Install and update 5 | 6 | `go get -u github.com/go-pkgz/rest` 7 | 8 | ## Middlewares 9 | 10 | ### AppInfo middleware 11 | 12 | Adds info to every response header: 13 | - App-Name - application name 14 | - App-Version - application version 15 | - Org - organization 16 | - M-Host - host name from instance-level `$MHOST` env 17 | 18 | ### Ping-Pong middleware 19 | 20 | Responds with `pong` on `GET /ping`. Also, responds to anything with `/ping` suffix, like `/v2/ping`. 21 | 22 | Example for both: 23 | 24 | ``` 25 | > http GET https://remark42.radio-t.com/ping 26 | 27 | HTTP/1.1 200 OK 28 | Date: Sun, 15 Jul 2018 19:40:31 GMT 29 | Content-Type: text/plain 30 | Content-Length: 4 31 | Connection: keep-alive 32 | App-Name: remark42 33 | App-Version: master-ed92a0b-20180630-15:59:56 34 | Org: Umputun 35 | 36 | pong 37 | ``` 38 | 39 | ### Health middleware 40 | 41 | Responds with the status 200 if all health checks passed, 503 if any failed. 
Both the health path and the check functions are passed by the consumer. 42 | For production usage, this middleware should be used with a throttler/limiter and, optionally, with some auth middlewares. 43 | 44 | Example of usage: 45 | 46 | ```go 47 | check1 := func(ctx context.Context) (name string, err error) { 48 | // do some check, for example check DB connection 49 | return "check1", nil // all good, passed 50 | } 51 | check2 := func(ctx context.Context) (name string, err error) { 52 | // do some other check, for example ping an external service 53 | return "check2", errors.New("some error") // check failed 54 | } 55 | 56 | router := chi.NewRouter() 57 | router.Use(rest.Health("/health", check1, check2)) 58 | ``` 59 | 60 | example of the actual call and response: 61 | 62 | ``` 63 | > http GET https://example.com/health 64 | 65 | HTTP/1.1 503 Service Unavailable 66 | Date: Sun, 15 Jul 2018 19:40:31 GMT 67 | Content-Type: application/json; charset=utf-8 68 | Content-Length: 36 69 | 70 | [ 71 | {"name":"check1","status":"ok"}, 72 | {"name":"check2","status":"failed","error":"some error"} 73 | ] 74 | ``` 75 | 76 | _this middleware is pretty basic, but can be used for simple health checks. For more complex cases, like async/cached health checks, see [alexliesenfeld/health](https://github.com/alexliesenfeld/health)_ 77 | 78 | ### Logger middleware 79 | 80 | Logs the request, request handling time and response.
Log record fields in order of occurrence: 81 | 82 | - Request's HTTP method 83 | - Requested URL (with sanitized query) 84 | - Remote IP 85 | - Response's HTTP status code 86 | - Response body size 87 | - Request handling time 88 | - Userinfo associated with the request (optional) 89 | - Request subject (optional) 90 | - Request ID (if `X-Request-ID` present) 91 | - Request body (optional) 92 | 93 | _remote IP can be masked with a user-defined function_ 94 | 95 | example: `2019/03/05 17:26:12.976 [INFO] GET - /api/v1/find?site=remark - 8e228e9cfece - 200 (115) - 4.47784618s` 96 | 97 | ### Recoverer middleware 98 | 99 | Recoverer is a middleware that recovers from panics, logs the panic (and a backtrace), 100 | and returns an HTTP 500 (Internal Server Error) status if possible. 101 | It prevents server crashes in case of panic in one of the controllers. 102 | 103 | ### OnlyFrom middleware 104 | 105 | OnlyFrom middleware allows access from a limited list of source IPs. 106 | Such IPs can be defined as a complete IP (like 192.168.1.12), a prefix (192.168.) or a CIDR (192.168.0.0/16). 107 | The middleware will respond with `StatusForbidden` (403) if the request comes from a different IP. 108 | It supports both IPv4 and IPv6 and checks the usual headers like `X-Forwarded-For` and `X-Real-IP` and the remote address. 109 | 110 | _Note: headers should be trusted and set by a proxy, otherwise it is possible to spoof them._ 111 | 112 | ### Metrics middleware 113 | 114 | Metrics middleware responds to GET /metrics with a list of [expvar](https://golang.org/pkg/expvar/). 115 | Optionally, access can be restricted to a list of source IPs. 116 | 117 | ### BlackWords middleware 118 | 119 | BlackWords middleware doesn't allow user-defined words in the request body.
120 | 121 | ### SizeLimit middleware 122 | 123 | SizeLimit middleware checks if the body size is above the limit and, if so, returns `StatusRequestEntityTooLarge` (413) 124 | 125 | ### Trace middleware 126 | 127 | The `Trace` middleware is designed to add request tracing functionality. It looks for the `X-Request-ID` header in 128 | the incoming HTTP request. If not found, a random ID is generated. This trace ID is then set in the response headers 129 | and added to the request's context. 130 | 131 | ### Deprecation middleware 132 | 133 | Adds the HTTP Deprecation response header, see [draft-ietf-httpapi-deprecation-header-02](https://datatracker.ietf.org/doc/html/draft-ietf-httpapi-deprecation-header-02) 134 | 135 | ### BasicAuth middleware 136 | 137 | BasicAuth middleware requires basic auth and matches user & passwd with a client-provided checker. If no basic auth header is present, it returns 138 | `StatusUnauthorized`; if the checker fails, it returns `StatusForbidden`. 139 | 140 | ### Rewrite middleware 141 | 142 | The `Rewrite` middleware is designed to rewrite the URL path based on a given rule, similar to how URL rewriting is done in nginx. It supports regular expressions for pattern matching and prevents multiple rewrites. 143 | 144 | For example, `Rewrite("^/sites/(.*)/settings/$", "/sites/settings/$1")` will change the request's URL from `/sites/id1/settings/` to `/sites/settings/id1` 145 | 146 | ### NoCache middleware 147 | 148 | Sets a number of HTTP headers to prevent a router (handler's) response from being cached by an upstream proxy and/or client. 149 | 150 | ### Headers middleware 151 | 152 | Sets headers (passed as key:value) to requests. E.g. `rest.Headers("Server:MyServer", "X-Blah:Foo")` 153 | 154 | ### Gzip middleware 155 | 156 | Compresses the response with gzip. 157 | 158 | ### RealIP middleware 159 | 160 | RealIP is a middleware that sets an http.Request's RemoteAddr to the result of parsing either the X-Forwarded-For or X-Real-IP headers.
161 | 162 | ### Maybe middleware 163 | 164 | Maybe middleware allows changing the flow of the middleware stack execution depending on the return 165 | value of maybeFn(request). This is useful, for example, to skip a middleware handler if a request does not satisfy the maybeFn logic. 166 | 167 | ### Reject middleware 168 | 169 | Reject is a middleware that rejects requests with a given status code and message based on a user-defined function. 170 | This is useful, for example, to reject requests to a particular resource based on a request header, 171 | or to implement a conditional request handler based on service parameters. 172 | 173 | example with chi router: 174 | 175 | ```go 176 | router := chi.NewRouter() 177 | 178 | rejectFn := func(r *http.Request) (bool) { 179 | return r.Header.Get("X-Request-Id") == "" // reject if no X-Request-Id header 180 | } 181 | 182 | router.Use(rest.Reject(http.StatusBadRequest, "X-Request-Id header is required", rejectFn)) 183 | ``` 184 | 185 | ### BasicAuth middleware family 186 | 187 | The package provides several BasicAuth middleware implementations for different authentication needs: 188 | 189 | #### BasicAuth 190 | The base middleware that requires basic auth and matches user & passwd with a client-provided checker function. 191 | ```go 192 | checkFn := func(user, passwd string) bool { 193 | return user == "admin" && passwd == "secret" 194 | } 195 | router.Use(rest.BasicAuth(checkFn)) 196 | ``` 197 | 198 | #### BasicAuthWithUserPasswd 199 | A simpler version comparing user & password with provided values directly. 200 | ```go 201 | router.Use(rest.BasicAuthWithUserPasswd("admin", "secret")) 202 | ``` 203 | 204 | #### BasicAuthWithBcryptHash 205 | Matches username and bcrypt-hashed password. Useful when storing hashed passwords. 
206 | ```go 207 | hash, err := rest.GenerateBcryptHash("secret") 208 | if err != nil { 209 | // handle error 210 | } 211 | router.Use(rest.BasicAuthWithBcryptHash("admin", hash)) 212 | ``` 213 | 214 | #### BasicAuthWithArgon2Hash 215 | Similar to the bcrypt version but uses an Argon2id hash with a separate salt. Both hash and salt are base64 encoded. 216 | ```go 217 | hash, salt, err := rest.GenerateArgon2Hash("secret") 218 | if err != nil { 219 | // handle error 220 | } 221 | router.Use(rest.BasicAuthWithArgon2Hash("admin", hash, salt)) 222 | ``` 223 | 224 | #### BasicAuthWithPrompt 225 | Similar to BasicAuthWithUserPasswd but adds the browser's authentication prompt by setting the WWW-Authenticate header. 226 | ```go 227 | router.Use(rest.BasicAuthWithPrompt("admin", "secret")) 228 | ``` 229 | 230 | All BasicAuth middlewares: 231 | - Return `StatusUnauthorized` (401) if no auth header is provided 232 | - Return `StatusForbidden` (403) if the credentials check failed (the `...WithPrompt` variants return `StatusUnauthorized` (401) with a `WWW-Authenticate` header instead) 233 | - Add IsAuthorized flag to the request context, retrievable with `rest.IsAuthorized(r.Context())` 234 | - Use constant-time comparison to prevent timing attacks 235 | - Support secure password hashing with bcrypt and Argon2id 236 | 237 | ### Benchmarks middleware 238 | 239 | Benchmarks middleware allows measuring the time of request handling and the number of requests per second, and reports aggregated metrics. 240 | This middleware keeps track of requests in memory and keeps up to 900 data points (15 minutes, one data point per second). 241 | 242 | To retrieve the data, the user should call the `Stats(d duration)` method. 243 | The `duration` is the time window for which the benchmark data should be returned. 244 | It can be any duration from 1s to 15m. Note: all the time data is in microseconds. 245 | 246 | example with chi router: 247 | 248 | ```go 249 | router := chi.NewRouter() 250 | bench := rest.NewBenchmarks() 251 | router.Use(bench.Middleware) 252 | ...
253 | router.Get("/bench", func(w http.ResponseWriter, r *http.Request) { 254 | resp := struct { 255 | OneMin rest.BenchmarkStats `json:"1min"` 256 | FiveMin rest.BenchmarkStats `json:"5min"` 257 | FifteenMin rest.BenchmarkStats `json:"15min"` 258 | }{ 259 | bench.Stats(time.Minute), 260 | bench.Stats(time.Minute * 5), 261 | bench.Stats(time.Minute * 15), 262 | } 263 | render.JSON(w, r, resp) 264 | }) 265 | ``` 266 | 267 | ## Helpers 268 | 269 | - `rest.Wrap` - converts a list of middlewares to nested handler calls (in reverse order) 270 | - `rest.JSON` - map alias, just for convenience `type JSON map[string]interface{}` 271 | - `rest.RenderJSON` - renders json response from `interface{}` 272 | - `rest.RenderJSONFromBytes` - renders json response from `[]byte` 273 | - `rest.RenderJSONWithHTML` - renders json response with html tags and forced `charset=utf-8` 274 | - `rest.SendErrorJSON` - makes `{error: blah, details: blah}` json body and responds with given error code. Also, adds context to the logged message 275 | - `rest.NewErrorLogger` - creates a struct providing a shorter form of the logger call 276 | - `rest.FileServer` - creates a file server for static assets with directory listing disabled 277 | - `realip.Get` - returns the client's IP address 278 | - `rest.ParseFromTo` - parses the "from" and "to" query params of the request in various formats 279 | - `rest.DecodeJSON` - decodes request body to the provided struct 280 | - `rest.EncodeJSON` - encodes response body from the provided struct, sets `Content-Type` to `application/json` and sends the status code 281 | 282 | ## Profiler 283 | 284 | Profiler is a convenient sub-router used for mounting net/http/pprof, e.g. 285 | 286 | ```go 287 | func MyService() http.Handler { 288 | r := chi.NewRouter() 289 | // ..middlewares 290 | r.Mount("/debug", middleware.Profiler()) 291 | // ..routes 292 | return r 293 | } 294 | ``` 295 | 296 | It exposes a bunch of `/pprof/*` endpoints as well as `/vars`.
Builtin support for `onlyIps` allows restricting access, which is important if it runs on a publicly exposed port. However, counting on IP check only is not that reliable way to limit request and for production use it would be better to add some sort of auth (for example provided `BasicAuth` middleware) or run with a separate http server, exposed to internal ip/port only. 297 | 298 | -------------------------------------------------------------------------------- /basic_auth.go: -------------------------------------------------------------------------------- 1 | package rest 2 | 3 | import ( 4 | "context" 5 | "crypto/rand" 6 | "crypto/subtle" 7 | "encoding/base64" 8 | "net/http" 9 | 10 | "golang.org/x/crypto/argon2" 11 | "golang.org/x/crypto/bcrypt" 12 | ) 13 | 14 | const baContextKey = "authorizedWithBasicAuth" 15 | 16 | // BasicAuth middleware requires basic auth and matches user & passwd with client-provided checker 17 | func BasicAuth(checker func(user, passwd string) bool) func(http.Handler) http.Handler { 18 | 19 | return func(h http.Handler) http.Handler { 20 | fn := func(w http.ResponseWriter, r *http.Request) { 21 | 22 | u, p, ok := r.BasicAuth() 23 | if !ok { 24 | w.WriteHeader(http.StatusUnauthorized) 25 | return 26 | } 27 | if !checker(u, p) { 28 | w.WriteHeader(http.StatusForbidden) 29 | return 30 | } 31 | h.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), contextKey(baContextKey), true))) 32 | } 33 | return http.HandlerFunc(fn) 34 | } 35 | } 36 | 37 | // BasicAuthWithUserPasswd middleware requires basic auth and matches user & passwd with client-provided values 38 | func BasicAuthWithUserPasswd(user, passwd string) func(http.Handler) http.Handler { 39 | checkFn := func(reqUser, reqPasswd string) bool { 40 | matchUser := subtle.ConstantTimeCompare([]byte(user), []byte(reqUser)) 41 | matchPass := subtle.ConstantTimeCompare([]byte(passwd), []byte(reqPasswd)) 42 | return matchUser == 1 && matchPass == 1 43 | } 44 | return BasicAuth(checkFn) 45 | 
} 46 | 47 | // BasicAuthWithBcryptHash middleware requires basic auth and matches user & bcrypt hashed password 48 | func BasicAuthWithBcryptHash(user, hashedPassword string) func(http.Handler) http.Handler { 49 | checkFn := func(reqUser, reqPasswd string) bool { 50 | if reqUser != user { 51 | return false 52 | } 53 | err := bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(reqPasswd)) 54 | return err == nil 55 | } 56 | return BasicAuth(checkFn) 57 | } 58 | 59 | // BasicAuthWithArgon2Hash middleware requires basic auth and matches user & argon2 hashed password 60 | // both hashedPassword and salt must be base64 encoded strings 61 | // Uses Argon2id with parameters: t=1, m=64*1024 KB, p=4 threads 62 | func BasicAuthWithArgon2Hash(user, hashedPassword, salt string) func(http.Handler) http.Handler { 63 | checkFn := func(reqUser, reqPasswd string) bool { 64 | if reqUser != user { 65 | return false 66 | } 67 | 68 | saltBytes, err := base64.StdEncoding.DecodeString(salt) 69 | if err != nil { 70 | return false 71 | } 72 | storedHashBytes, err := base64.StdEncoding.DecodeString(hashedPassword) 73 | if err != nil { 74 | return false 75 | } 76 | 77 | hash := argon2.IDKey([]byte(reqPasswd), saltBytes, 1, 64*1024, 4, 32) 78 | return subtle.ConstantTimeCompare(hash, storedHashBytes) == 1 79 | } 80 | return BasicAuth(checkFn) 81 | } 82 | 83 | // IsAuthorized returns true is user authorized. 
84 | // it can be used in handlers to check if BasicAuth middleware was applied 85 | func IsAuthorized(ctx context.Context) bool { 86 | v := ctx.Value(contextKey(baContextKey)) 87 | return v != nil && v.(bool) 88 | } 89 | 90 | // BasicAuthWithPrompt middleware requires basic auth and matches user & passwd with client-provided values 91 | // If the user is not authorized, it will prompt for basic auth 92 | func BasicAuthWithPrompt(user, passwd string) func(http.Handler) http.Handler { 93 | checkFn := func(reqUser, reqPasswd string) bool { 94 | matchUser := subtle.ConstantTimeCompare([]byte(user), []byte(reqUser)) 95 | matchPass := subtle.ConstantTimeCompare([]byte(passwd), []byte(reqPasswd)) 96 | return matchUser == 1 && matchPass == 1 97 | } 98 | 99 | return func(h http.Handler) http.Handler { 100 | fn := func(w http.ResponseWriter, r *http.Request) { 101 | 102 | // extract basic auth from request 103 | u, p, ok := r.BasicAuth() 104 | if ok && checkFn(u, p) { 105 | h.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), contextKey(baContextKey), true))) 106 | return 107 | } 108 | // not authorized, prompt for basic auth 109 | w.Header().Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`) 110 | http.Error(w, "Unauthorized", http.StatusUnauthorized) 111 | } 112 | return http.HandlerFunc(fn) 113 | } 114 | } 115 | 116 | // BasicAuthWithBcryptHashAndPrompt middleware requires basic auth and matches user & bcrypt hashed password 117 | // If the user is not authorized, it will prompt for basic auth 118 | func BasicAuthWithBcryptHashAndPrompt(user, hashedPassword string) func(http.Handler) http.Handler { 119 | checkFn := func(reqUser, reqPasswd string) bool { 120 | if reqUser != user { 121 | return false 122 | } 123 | err := bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(reqPasswd)) 124 | return err == nil 125 | } 126 | 127 | return func(h http.Handler) http.Handler { 128 | fn := func(w http.ResponseWriter, r *http.Request) { 129 | // 
extract basic auth from request 130 | u, p, ok := r.BasicAuth() 131 | if ok && checkFn(u, p) { 132 | h.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), contextKey(baContextKey), true))) 133 | return 134 | } 135 | // not authorized, prompt for basic auth 136 | w.Header().Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`) 137 | http.Error(w, "Unauthorized", http.StatusUnauthorized) 138 | } 139 | return http.HandlerFunc(fn) 140 | } 141 | } 142 | 143 | // GenerateBcryptHash generates a bcrypt hash from a password 144 | func GenerateBcryptHash(password string) (string, error) { 145 | hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) 146 | if err != nil { 147 | return "", err 148 | } 149 | return string(hash), nil 150 | } 151 | 152 | // GenerateArgon2Hash generates an argon2 hash and salt from a password 153 | func GenerateArgon2Hash(password string) (hash, salt string, err error) { 154 | saltBytes := make([]byte, 16) 155 | if _, err := rand.Read(saltBytes); err != nil { 156 | return "", "", err 157 | } 158 | 159 | // using recommended parameters: time=1, memory=64*1024, threads=4, keyLen=32 160 | hashBytes := argon2.IDKey([]byte(password), saltBytes, 1, 64*1024, 4, 32) 161 | 162 | return base64.StdEncoding.EncodeToString(hashBytes), base64.StdEncoding.EncodeToString(saltBytes), nil 163 | } 164 | -------------------------------------------------------------------------------- /basic_auth_test.go: -------------------------------------------------------------------------------- 1 | package rest 2 | 3 | import ( 4 | "encoding/base64" 5 | "fmt" 6 | "net/http" 7 | "net/http/httptest" 8 | "testing" 9 | "time" 10 | 11 | "github.com/stretchr/testify/assert" 12 | "github.com/stretchr/testify/require" 13 | "golang.org/x/crypto/argon2" 14 | "golang.org/x/crypto/bcrypt" 15 | ) 16 | 17 | func TestBasicAuth(t *testing.T) { 18 | 19 | mw := BasicAuth(func(user, passwd string) bool { 20 | return user == "dev" && passwd == "good" 21 
| }) 22 | 23 | ts := httptest.NewServer(mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 24 | t.Logf("request %s", r.URL) 25 | w.WriteHeader(http.StatusOK) 26 | _, err := w.Write([]byte("blah")) 27 | require.NoError(t, err) 28 | assert.True(t, IsAuthorized(r.Context())) 29 | }))) 30 | defer ts.Close() 31 | 32 | u := fmt.Sprintf("%s%s", ts.URL, "/something") 33 | 34 | client := http.Client{Timeout: 5 * time.Second} 35 | 36 | { 37 | req, err := http.NewRequest("GET", u, http.NoBody) 38 | require.NoError(t, err) 39 | resp, err := client.Do(req) 40 | require.NoError(t, err) 41 | assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) 42 | } 43 | 44 | { 45 | req, err := http.NewRequest("GET", u, http.NoBody) 46 | require.NoError(t, err) 47 | req.SetBasicAuth("dev", "good") 48 | resp, err := client.Do(req) 49 | require.NoError(t, err) 50 | assert.Equal(t, http.StatusOK, resp.StatusCode) 51 | } 52 | 53 | { 54 | req, err := http.NewRequest("GET", u, http.NoBody) 55 | require.NoError(t, err) 56 | req.SetBasicAuth("dev", "bad") 57 | resp, err := client.Do(req) 58 | require.NoError(t, err) 59 | assert.Equal(t, http.StatusForbidden, resp.StatusCode) 60 | } 61 | } 62 | 63 | func TestBasicAuthWithUserPasswd(t *testing.T) { 64 | mw := BasicAuthWithUserPasswd("dev", "good") 65 | 66 | ts := httptest.NewServer(mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 67 | t.Logf("request %s", r.URL) 68 | w.WriteHeader(http.StatusOK) 69 | _, err := w.Write([]byte("blah")) 70 | require.NoError(t, err) 71 | assert.True(t, IsAuthorized(r.Context())) 72 | }))) 73 | defer ts.Close() 74 | 75 | u := fmt.Sprintf("%s%s", ts.URL, "/something") 76 | 77 | client := http.Client{Timeout: 5 * time.Second} 78 | 79 | { 80 | req, err := http.NewRequest("GET", u, http.NoBody) 81 | require.NoError(t, err) 82 | resp, err := client.Do(req) 83 | require.NoError(t, err) 84 | assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) 85 | } 86 | 87 | { 88 | req, err := 
http.NewRequest("GET", u, http.NoBody) 89 | require.NoError(t, err) 90 | req.SetBasicAuth("dev", "good") 91 | resp, err := client.Do(req) 92 | require.NoError(t, err) 93 | assert.Equal(t, http.StatusOK, resp.StatusCode) 94 | } 95 | { 96 | req, err := http.NewRequest("GET", u, http.NoBody) 97 | require.NoError(t, err) 98 | req.SetBasicAuth("dev", "bad") 99 | resp, err := client.Do(req) 100 | require.NoError(t, err) 101 | assert.Equal(t, http.StatusForbidden, resp.StatusCode) 102 | } 103 | } 104 | 105 | func TestBasicAuthWithPrompt(t *testing.T) { 106 | mw := BasicAuthWithPrompt("dev", "good") 107 | 108 | ts := httptest.NewServer(mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 109 | t.Logf("request %s", r.URL) 110 | w.WriteHeader(http.StatusOK) 111 | _, err := w.Write([]byte("blah")) 112 | require.NoError(t, err) 113 | assert.True(t, IsAuthorized(r.Context())) 114 | }))) 115 | defer ts.Close() 116 | 117 | u := fmt.Sprintf("%s%s", ts.URL, "/something") 118 | 119 | client := http.Client{Timeout: 5 * time.Second} 120 | 121 | { 122 | req, err := http.NewRequest("GET", u, http.NoBody) 123 | require.NoError(t, err) 124 | resp, err := client.Do(req) 125 | require.NoError(t, err) 126 | assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) 127 | assert.Equal(t, `Basic realm="restricted", charset="UTF-8"`, resp.Header.Get("WWW-Authenticate")) 128 | } 129 | 130 | { 131 | req, err := http.NewRequest("GET", u, http.NoBody) 132 | require.NoError(t, err) 133 | req.SetBasicAuth("dev", "good") 134 | resp, err := client.Do(req) 135 | require.NoError(t, err) 136 | assert.Equal(t, http.StatusOK, resp.StatusCode) 137 | } 138 | { 139 | req, err := http.NewRequest("GET", u, http.NoBody) 140 | require.NoError(t, err) 141 | req.SetBasicAuth("dev", "bad") 142 | resp, err := client.Do(req) 143 | require.NoError(t, err) 144 | assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) 145 | assert.Equal(t, `Basic realm="restricted", charset="UTF-8"`, 
resp.Header.Get("WWW-Authenticate")) 146 | } 147 | } 148 | 149 | func TestBasicAuthWithHash(t *testing.T) { 150 | hashedPassword, err := bcrypt.GenerateFromPassword([]byte("good"), bcrypt.MinCost) 151 | require.NoError(t, err) 152 | t.Logf("hashed password: %s", string(hashedPassword)) 153 | 154 | mw := BasicAuthWithBcryptHash("dev", string(hashedPassword)) 155 | 156 | ts := httptest.NewServer(mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 157 | t.Logf("request %s", r.URL) 158 | w.WriteHeader(http.StatusOK) 159 | _, err := w.Write([]byte("blah")) 160 | require.NoError(t, err) 161 | assert.True(t, IsAuthorized(r.Context())) 162 | }))) 163 | defer ts.Close() 164 | 165 | u := fmt.Sprintf("%s%s", ts.URL, "/something") 166 | client := http.Client{Timeout: 5 * time.Second} 167 | 168 | tests := []struct { 169 | name string 170 | username string 171 | password string 172 | expectedStatus int 173 | }{ 174 | { 175 | name: "no auth provided", 176 | username: "", 177 | password: "", 178 | expectedStatus: http.StatusUnauthorized, 179 | }, 180 | { 181 | name: "correct credentials", 182 | username: "dev", 183 | password: "good", 184 | expectedStatus: http.StatusOK, 185 | }, 186 | { 187 | name: "wrong username", 188 | username: "wrong", 189 | password: "good", 190 | expectedStatus: http.StatusForbidden, 191 | }, 192 | { 193 | name: "wrong password", 194 | username: "dev", 195 | password: "bad", 196 | expectedStatus: http.StatusForbidden, 197 | }, 198 | { 199 | name: "empty password", 200 | username: "dev", 201 | password: "", 202 | expectedStatus: http.StatusForbidden, 203 | }, 204 | } 205 | 206 | for _, tc := range tests { 207 | t.Run(tc.name, func(t *testing.T) { 208 | req, err := http.NewRequest("GET", u, http.NoBody) 209 | require.NoError(t, err) 210 | 211 | if tc.username != "" || tc.password != "" { 212 | req.SetBasicAuth(tc.username, tc.password) 213 | } 214 | 215 | resp, err := client.Do(req) 216 | require.NoError(t, err) 217 | assert.Equal(t, 
tc.expectedStatus, resp.StatusCode) 218 | }) 219 | } 220 | } 221 | 222 | func TestBasicAuthWithArgon2Hash(t *testing.T) { 223 | password := "good" 224 | hash, salt, err := GenerateArgon2Hash(password) 225 | require.NoError(t, err) 226 | t.Logf("hash: %s, salt: %s", hash, salt) 227 | 228 | // verify the returned values are valid base64 229 | _, err = base64.StdEncoding.DecodeString(hash) 230 | require.NoError(t, err, "hash should be valid base64") 231 | _, err = base64.StdEncoding.DecodeString(salt) 232 | require.NoError(t, err, "salt should be valid base64") 233 | 234 | mw := BasicAuthWithArgon2Hash("dev", hash, salt) 235 | 236 | ts := httptest.NewServer(mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 237 | t.Logf("request %s", r.URL) 238 | w.WriteHeader(http.StatusOK) 239 | _, err := w.Write([]byte("blah")) 240 | require.NoError(t, err) 241 | assert.True(t, IsAuthorized(r.Context())) 242 | }))) 243 | defer ts.Close() 244 | 245 | u := fmt.Sprintf("%s%s", ts.URL, "/something") 246 | client := http.Client{Timeout: 5 * time.Second} 247 | 248 | tests := []struct { 249 | name string 250 | username string 251 | password string 252 | expectedStatus int 253 | }{ 254 | { 255 | name: "no auth provided", 256 | username: "", 257 | password: "", 258 | expectedStatus: http.StatusUnauthorized, 259 | }, 260 | { 261 | name: "correct credentials", 262 | username: "dev", 263 | password: "good", 264 | expectedStatus: http.StatusOK, 265 | }, 266 | { 267 | name: "wrong username", 268 | username: "wrong", 269 | password: "good", 270 | expectedStatus: http.StatusForbidden, 271 | }, 272 | { 273 | name: "wrong password", 274 | username: "dev", 275 | password: "bad", 276 | expectedStatus: http.StatusForbidden, 277 | }, 278 | } 279 | 280 | for _, tc := range tests { 281 | t.Run(tc.name, func(t *testing.T) { 282 | req, err := http.NewRequest("GET", u, http.NoBody) 283 | require.NoError(t, err) 284 | 285 | if tc.username != "" || tc.password != "" { 286 | 
req.SetBasicAuth(tc.username, tc.password) 287 | } 288 | 289 | resp, err := client.Do(req) 290 | require.NoError(t, err) 291 | assert.Equal(t, tc.expectedStatus, resp.StatusCode) 292 | }) 293 | } 294 | } 295 | 296 | func TestHashGenerationFunctions(t *testing.T) { 297 | t.Run("bcrypt hash generation", func(t *testing.T) { 298 | hash, err := GenerateBcryptHash("testpassword") 299 | require.NoError(t, err) 300 | require.NotEmpty(t, hash) 301 | 302 | err = bcrypt.CompareHashAndPassword([]byte(hash), []byte("testpassword")) 303 | require.NoError(t, err) 304 | }) 305 | 306 | t.Run("argon2 hash generation", func(t *testing.T) { 307 | hash, salt, err := GenerateArgon2Hash("testpassword") 308 | require.NoError(t, err) 309 | require.NotEmpty(t, hash) 310 | require.NotEmpty(t, salt) 311 | 312 | // verify the values are valid base64 313 | hashBytes, err := base64.StdEncoding.DecodeString(hash) 314 | require.NoError(t, err, "hash should be valid base64") 315 | saltBytes, err := base64.StdEncoding.DecodeString(salt) 316 | require.NoError(t, err, "salt should be valid base64") 317 | 318 | // verify the hash works 319 | newHash := argon2.IDKey([]byte("testpassword"), saltBytes, 1, 64*1024, 4, 32) 320 | require.Equal(t, hashBytes, newHash) 321 | 322 | // test with wrong password 323 | wrongHash := argon2.IDKey([]byte("wrongpassword"), saltBytes, 1, 64*1024, 4, 32) 324 | require.NotEqual(t, hashBytes, wrongHash) 325 | }) 326 | } 327 | 328 | func TestArgon2InvalidInputs(t *testing.T) { 329 | t.Run("invalid base64 salt", func(t *testing.T) { 330 | mw := BasicAuthWithArgon2Hash("dev", "validbase64==", "invalid-base64") 331 | ts := httptest.NewServer(mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 332 | t.Error("Handler should not be called with invalid base64") 333 | }))) 334 | defer ts.Close() 335 | 336 | req, err := http.NewRequest("GET", ts.URL, http.NoBody) 337 | require.NoError(t, err) 338 | req.SetBasicAuth("dev", "password") 339 | 340 | client := 
http.Client{Timeout: 5 * time.Second} 341 | resp, err := client.Do(req) 342 | require.NoError(t, err) 343 | assert.Equal(t, http.StatusForbidden, resp.StatusCode) 344 | }) 345 | 346 | t.Run("invalid base64 hash", func(t *testing.T) { 347 | mw := BasicAuthWithArgon2Hash("dev", "invalid-base64", "validbase64==") 348 | ts := httptest.NewServer(mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 349 | t.Error("Handler should not be called with invalid base64") 350 | }))) 351 | defer ts.Close() 352 | 353 | req, err := http.NewRequest("GET", ts.URL, http.NoBody) 354 | require.NoError(t, err) 355 | req.SetBasicAuth("dev", "password") 356 | 357 | client := http.Client{Timeout: 5 * time.Second} 358 | resp, err := client.Do(req) 359 | require.NoError(t, err) 360 | assert.Equal(t, http.StatusForbidden, resp.StatusCode) 361 | }) 362 | } 363 | 364 | func TestBasicAuthWithBcryptHashAndPrompt(t *testing.T) { 365 | hashedPassword, err := bcrypt.GenerateFromPassword([]byte("good"), bcrypt.MinCost) 366 | require.NoError(t, err) 367 | t.Logf("hashed password: %s", string(hashedPassword)) 368 | 369 | mw := BasicAuthWithBcryptHashAndPrompt("dev", string(hashedPassword)) 370 | 371 | ts := httptest.NewServer(mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 372 | t.Logf("request %s", r.URL) 373 | w.WriteHeader(http.StatusOK) 374 | _, err := w.Write([]byte("blah")) 375 | require.NoError(t, err) 376 | assert.True(t, IsAuthorized(r.Context())) 377 | }))) 378 | defer ts.Close() 379 | 380 | u := fmt.Sprintf("%s%s", ts.URL, "/something") 381 | client := http.Client{Timeout: 5 * time.Second} 382 | 383 | tests := []struct { 384 | name string 385 | username string 386 | password string 387 | expectedStatus int 388 | checkPrompt bool 389 | }{ 390 | { 391 | name: "no auth provided", 392 | username: "", 393 | password: "", 394 | expectedStatus: http.StatusUnauthorized, 395 | checkPrompt: true, 396 | }, 397 | { 398 | name: "correct credentials", 399 | username: "dev", 
400 | password: "good", 401 | expectedStatus: http.StatusOK, 402 | checkPrompt: false, 403 | }, 404 | { 405 | name: "wrong username", 406 | username: "wrong", 407 | password: "good", 408 | expectedStatus: http.StatusUnauthorized, 409 | checkPrompt: true, 410 | }, 411 | { 412 | name: "wrong password", 413 | username: "dev", 414 | password: "bad", 415 | expectedStatus: http.StatusUnauthorized, 416 | checkPrompt: true, 417 | }, 418 | { 419 | name: "empty password", 420 | username: "dev", 421 | password: "", 422 | expectedStatus: http.StatusUnauthorized, 423 | checkPrompt: true, 424 | }, 425 | } 426 | 427 | for _, tc := range tests { 428 | t.Run(tc.name, func(t *testing.T) { 429 | req, err := http.NewRequest("GET", u, http.NoBody) 430 | require.NoError(t, err) 431 | 432 | if tc.username != "" || tc.password != "" { 433 | req.SetBasicAuth(tc.username, tc.password) 434 | } 435 | 436 | resp, err := client.Do(req) 437 | require.NoError(t, err) 438 | assert.Equal(t, tc.expectedStatus, resp.StatusCode) 439 | 440 | if tc.checkPrompt { 441 | assert.Equal(t, `Basic realm="restricted", charset="UTF-8"`, resp.Header.Get("WWW-Authenticate"), 442 | "should include WWW-Authenticate header") 443 | } 444 | }) 445 | } 446 | } 447 | -------------------------------------------------------------------------------- /benchmarks.go: -------------------------------------------------------------------------------- 1 | package rest 2 | 3 | import ( 4 | "container/list" 5 | "net/http" 6 | "sync" 7 | "time" 8 | ) 9 | 10 | var maxTimeRangeDefault = time.Duration(15) * time.Minute 11 | 12 | // Benchmarks is a basic benchmarking middleware collecting and reporting performance metrics 13 | // It keeps track of the requests speeds and counts in 1s benchData buckets ,limiting the number of buckets 14 | // to maxTimeRange. User can request the benchmark for any time duration. This is intended to be used 15 | // for retrieving the benchmark data for the last minute, 5 minutes and up to maxTimeRange. 
16 | type Benchmarks struct { 17 | st time.Time 18 | data *list.List 19 | lock sync.RWMutex 20 | maxTimeRange time.Duration 21 | 22 | nowFn func() time.Time // for testing only 23 | } 24 | 25 | type benchData struct { 26 | // 1s aggregates 27 | requests int 28 | respTime time.Duration 29 | minRespTime time.Duration 30 | maxRespTime time.Duration 31 | ts time.Time 32 | } 33 | 34 | // BenchmarkStats holds the stats for a given interval 35 | type BenchmarkStats struct { 36 | Requests int `json:"requests"` 37 | RequestsSec float64 `json:"requests_sec"` 38 | AverageRespTime int64 `json:"average_resp_time"` 39 | MinRespTime int64 `json:"min_resp_time"` 40 | MaxRespTime int64 `json:"max_resp_time"` 41 | } 42 | 43 | // NewBenchmarks creates a new benchmark middleware 44 | func NewBenchmarks() *Benchmarks { 45 | res := &Benchmarks{ 46 | st: time.Now(), 47 | data: list.New(), 48 | nowFn: time.Now, 49 | maxTimeRange: maxTimeRangeDefault, 50 | } 51 | return res 52 | } 53 | 54 | // WithTimeRange sets the maximum time range for the benchmark to keep data for. 55 | // Default is 15 minutes. The increase of this range will change memory utilization as each second of the range 56 | // kept as benchData aggregate. The default means 15*60 = 900 seconds of data aggregate. 57 | // Larger range allows for longer time periods to be benchmarked. 
58 | func (b *Benchmarks) WithTimeRange(maximum time.Duration) *Benchmarks { 59 | b.lock.Lock() 60 | defer b.lock.Unlock() 61 | b.maxTimeRange = maximum 62 | return b 63 | } 64 | 65 | // Handler calculates 1/5/10m request per second and allows to access those values 66 | func (b *Benchmarks) Handler(next http.Handler) http.Handler { 67 | 68 | fn := func(w http.ResponseWriter, r *http.Request) { 69 | st := b.nowFn() 70 | defer func() { 71 | b.update(time.Since(st)) 72 | }() 73 | next.ServeHTTP(w, r) 74 | } 75 | return http.HandlerFunc(fn) 76 | } 77 | 78 | func (b *Benchmarks) update(reqDuration time.Duration) { 79 | now := b.nowFn().Truncate(time.Second) 80 | 81 | b.lock.Lock() 82 | defer b.lock.Unlock() 83 | 84 | // keep maxTimeRange in the list, drop the rest 85 | for e := b.data.Front(); e != nil; e = e.Next() { 86 | if b.data.Front().Value.(benchData).ts.After(b.nowFn().Add(-b.maxTimeRange)) { 87 | break 88 | } 89 | b.data.Remove(b.data.Front()) 90 | } 91 | 92 | last := b.data.Back() 93 | if last == nil || last.Value.(benchData).ts.Before(now) { 94 | b.data.PushBack(benchData{requests: 1, respTime: reqDuration, ts: now, 95 | minRespTime: reqDuration, maxRespTime: reqDuration}) 96 | return 97 | } 98 | 99 | bd := last.Value.(benchData) 100 | bd.requests++ 101 | bd.respTime += reqDuration 102 | 103 | if bd.minRespTime == 0 || reqDuration < bd.minRespTime { 104 | bd.minRespTime = reqDuration 105 | } 106 | if bd.maxRespTime == 0 || reqDuration > bd.maxRespTime { 107 | bd.maxRespTime = reqDuration 108 | } 109 | 110 | last.Value = bd 111 | } 112 | 113 | // Stats returns the current benchmark stats for the given duration 114 | func (b *Benchmarks) Stats(interval time.Duration) BenchmarkStats { 115 | if interval < time.Second { // minimum interval is 1s due to the bucket size 116 | return BenchmarkStats{} 117 | } 118 | 119 | b.lock.RLock() 120 | defer b.lock.RUnlock() 121 | 122 | var ( 123 | requests int 124 | respTime time.Duration 125 | ) 126 | 127 | now := 
b.nowFn().Truncate(time.Second) 128 | cutoff := now.Add(-interval) 129 | stInterval, fnInterval := time.Time{}, time.Time{} 130 | var minRespTime, maxRespTime time.Duration 131 | count := 0 132 | 133 | for e := b.data.Back(); e != nil && count < int(interval.Seconds()); e = e.Prev() { // reverse order 134 | bd := e.Value.(benchData) 135 | if bd.ts.Before(cutoff) { 136 | break 137 | } 138 | 139 | if minRespTime == 0 || bd.minRespTime < minRespTime { 140 | minRespTime = bd.minRespTime 141 | } 142 | if maxRespTime == 0 || bd.maxRespTime > maxRespTime { 143 | maxRespTime = bd.maxRespTime 144 | } 145 | requests += bd.requests 146 | respTime += bd.respTime 147 | if fnInterval.IsZero() { 148 | fnInterval = bd.ts.Add(time.Second) 149 | } 150 | stInterval = bd.ts 151 | count++ 152 | } 153 | 154 | if requests == 0 { 155 | return BenchmarkStats{} 156 | } 157 | 158 | // ensure we calculate rate based on actual interval 159 | actualInterval := fnInterval.Sub(stInterval) 160 | if actualInterval < time.Second { 161 | actualInterval = time.Second 162 | } 163 | 164 | return BenchmarkStats{ 165 | Requests: requests, 166 | RequestsSec: float64(requests) / actualInterval.Seconds(), 167 | AverageRespTime: respTime.Microseconds() / int64(requests), 168 | MinRespTime: minRespTime.Microseconds(), 169 | MaxRespTime: maxRespTime.Microseconds(), 170 | } 171 | } 172 | -------------------------------------------------------------------------------- /benchmarks_test.go: -------------------------------------------------------------------------------- 1 | package rest 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "sync" 7 | "testing" 8 | "time" 9 | 10 | "github.com/stretchr/testify/assert" 11 | "github.com/stretchr/testify/require" 12 | ) 13 | 14 | func TestBenchmark_Stats(t *testing.T) { 15 | bench := NewBenchmarks() 16 | bench.update(time.Millisecond * 50) 17 | bench.update(time.Millisecond * 150) 18 | bench.update(time.Millisecond * 250) 19 | bench.update(time.Millisecond * 100) 
20 | 21 | { 22 | res := bench.Stats(time.Minute) 23 | t.Logf("%+v", res) 24 | assert.Equal(t, BenchmarkStats{Requests: 4, RequestsSec: 4, AverageRespTime: 137500, 25 | MinRespTime: (time.Millisecond * 50).Microseconds(), MaxRespTime: (time.Millisecond * 250).Microseconds()}, res) 26 | } 27 | 28 | { 29 | res := bench.Stats(time.Second * 5) 30 | t.Logf("%+v", res) 31 | assert.Equal(t, BenchmarkStats{Requests: 4, RequestsSec: 4, AverageRespTime: 137500, 32 | MinRespTime: (time.Millisecond * 50).Microseconds(), MaxRespTime: (time.Millisecond * 250).Microseconds()}, res) 33 | } 34 | 35 | { 36 | res := bench.Stats(time.Millisecond * 999) 37 | t.Logf("%+v", res) 38 | assert.Equal(t, BenchmarkStats{}, res) 39 | } 40 | } 41 | 42 | func TestBenchmark_Stats2s(t *testing.T) { 43 | bench := NewBenchmarks() 44 | bench.update(time.Millisecond * 50) 45 | bench.update(time.Millisecond * 150) 46 | bench.update(time.Millisecond * 250) 47 | time.Sleep(time.Second) 48 | bench.update(time.Millisecond * 100) 49 | 50 | res := bench.Stats(time.Minute) 51 | t.Logf("%+v", res) 52 | assert.Equal(t, BenchmarkStats{Requests: 4, RequestsSec: 2, AverageRespTime: 137500, 53 | MinRespTime: (time.Millisecond * 50).Microseconds(), MaxRespTime: (time.Millisecond * 250).Microseconds()}, res) 54 | } 55 | 56 | func TestBenchmark_WithTimeRange(t *testing.T) { 57 | 58 | nowFn := func(dt time.Time) func() time.Time { 59 | return func() time.Time { return dt } 60 | } 61 | 62 | { 63 | bench := NewBenchmarks().WithTimeRange(time.Minute) 64 | 65 | bench.nowFn = nowFn(time.Date(2018, time.January, 1, 0, 0, 0, 0, time.UTC)) 66 | bench.update(time.Millisecond * 50) 67 | bench.update(time.Millisecond * 150) 68 | bench.update(time.Millisecond * 250) 69 | bench.update(time.Millisecond * 100) 70 | 71 | bench.nowFn = nowFn(time.Date(2018, time.January, 1, 1, 0, 0, 0, time.UTC)) // 1 hour later 72 | bench.update(time.Millisecond * 1000) 73 | 74 | res := bench.Stats(time.Minute) 75 | t.Logf("%+v", res) 76 | 
assert.Equal(t, BenchmarkStats{Requests: 1, RequestsSec: 1, AverageRespTime: 1000000, 77 | MinRespTime: (time.Millisecond * 1000).Microseconds(), MaxRespTime: (time.Millisecond * 1000).Microseconds()}, res) 78 | 79 | res = bench.Stats(time.Hour) 80 | t.Logf("%+v", res) 81 | assert.Equal(t, BenchmarkStats{Requests: 1, RequestsSec: 1, AverageRespTime: 1000000, 82 | MinRespTime: (time.Millisecond * 1000).Microseconds(), MaxRespTime: (time.Millisecond * 1000).Microseconds()}, res) 83 | } 84 | 85 | { 86 | bench := NewBenchmarks().WithTimeRange(time.Hour * 2) 87 | 88 | bench.nowFn = nowFn(time.Date(2018, time.January, 1, 0, 0, 0, 0, time.UTC)) 89 | bench.update(time.Millisecond * 50) 90 | bench.update(time.Millisecond * 150) 91 | bench.update(time.Millisecond * 250) 92 | bench.update(time.Millisecond * 100) 93 | 94 | bench.nowFn = nowFn(time.Date(2018, time.January, 1, 1, 0, 0, 0, time.UTC)) // 1 hour later 95 | bench.update(time.Millisecond * 1000) 96 | 97 | res := bench.Stats(time.Minute) 98 | t.Logf("%+v", res) 99 | assert.Equal(t, BenchmarkStats{Requests: 1, RequestsSec: 1, AverageRespTime: 1000000, 100 | MinRespTime: (time.Millisecond * 1000).Microseconds(), MaxRespTime: (time.Millisecond * 1000).Microseconds()}, res) 101 | 102 | res = bench.Stats(time.Hour) 103 | t.Logf("%+v", res) 104 | assert.Equal(t, BenchmarkStats{Requests: 5, RequestsSec: 0.0013885031935573452, AverageRespTime: 310000, 105 | MinRespTime: (time.Millisecond * 50).Microseconds(), MaxRespTime: (time.Millisecond * 1000).Microseconds()}, res) 106 | } 107 | } 108 | 109 | func TestBenchmark_Cleanup(t *testing.T) { 110 | bench := NewBenchmarks() 111 | for i := 0; i < 1000; i++ { 112 | bench.nowFn = func() time.Time { 113 | return time.Date(2022, 5, 15, 0, 0, 0, 0, time.UTC).Add(time.Duration(i) * time.Second) // every 2s fake time 114 | } 115 | bench.update(time.Millisecond * 50) 116 | } 117 | 118 | { 119 | res := bench.Stats(time.Hour) 120 | t.Logf("%+v", res) 121 | assert.Equal(t, 
BenchmarkStats{Requests: 900, RequestsSec: 1, AverageRespTime: 50000, 122 | MinRespTime: (time.Millisecond * 50).Microseconds(), MaxRespTime: (time.Millisecond * 50).Microseconds()}, res) 123 | } 124 | { 125 | res := bench.Stats(time.Minute) 126 | t.Logf("%+v", res) 127 | assert.Equal(t, BenchmarkStats{Requests: 60, RequestsSec: 1, AverageRespTime: 50000, 128 | MinRespTime: (time.Millisecond * 50).Microseconds(), MaxRespTime: (time.Millisecond * 50).Microseconds()}, res) 129 | } 130 | 131 | assert.Equal(t, 900, bench.data.Len()) 132 | } 133 | 134 | func TestBenchmarks_Handler(t *testing.T) { 135 | handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 136 | _, err := w.Write([]byte("blah blah")) 137 | time.Sleep(time.Millisecond * 50) 138 | require.NoError(t, err) 139 | }) 140 | 141 | bench := NewBenchmarks() 142 | ts := httptest.NewServer(bench.Handler(handler)) 143 | defer ts.Close() 144 | 145 | for i := 0; i < 100; i++ { 146 | resp, err := ts.Client().Get(ts.URL) 147 | require.NoError(t, err) 148 | assert.Equal(t, http.StatusOK, resp.StatusCode) 149 | } 150 | 151 | { 152 | res := bench.Stats(time.Minute) 153 | t.Logf("%+v", res) 154 | assert.Equal(t, 100, res.Requests) 155 | assert.True(t, res.RequestsSec <= 20 && res.RequestsSec >= 10) 156 | assert.InDelta(t, 50000, res.AverageRespTime, 10000) 157 | assert.InDelta(t, 50000, res.MinRespTime, 10000) 158 | assert.InDelta(t, 50000, res.MaxRespTime, 10000) 159 | assert.True(t, res.MaxRespTime >= res.MinRespTime) 160 | } 161 | 162 | { 163 | res := bench.Stats(time.Minute * 15) 164 | t.Logf("%+v", res) 165 | assert.Equal(t, 100, res.Requests) 166 | assert.True(t, res.RequestsSec <= 20 && res.RequestsSec >= 10, res.RequestsSec) 167 | assert.InDelta(t, 50000, res.AverageRespTime, 10000) 168 | assert.InDelta(t, 50000, res.MinRespTime, 10000) 169 | assert.InDelta(t, 50000, res.MaxRespTime, 10000) 170 | assert.True(t, res.MaxRespTime >= res.MinRespTime) 171 | } 172 | } 173 | 174 | func 
TestBenchmark_ConcurrentAccess(t *testing.T) { 175 | bench := NewBenchmarks() 176 | var wg sync.WaitGroup 177 | 178 | // simulate concurrent updates 179 | for i := 0; i < 100; i++ { 180 | wg.Add(1) 181 | go func(i int) { 182 | defer wg.Done() 183 | bench.update(time.Duration(i) * time.Millisecond) 184 | }(i) 185 | } 186 | 187 | // simulate concurrent stats reads while updating 188 | for i := 0; i < 10; i++ { 189 | wg.Add(1) 190 | go func() { 191 | defer wg.Done() 192 | stats := bench.Stats(time.Minute) 193 | require.GreaterOrEqual(t, stats.Requests, 0) 194 | }() 195 | } 196 | 197 | wg.Wait() 198 | 199 | stats := bench.Stats(time.Minute) 200 | assert.Equal(t, 100, stats.Requests) 201 | } 202 | 203 | func TestBenchmark_EdgeCases(t *testing.T) { 204 | bench := NewBenchmarks() 205 | 206 | t.Run("zero duration", func(t *testing.T) { 207 | bench.update(0) 208 | stats := bench.Stats(time.Minute) 209 | assert.Equal(t, int64(0), stats.MinRespTime) 210 | assert.Equal(t, int64(0), stats.MaxRespTime) 211 | }) 212 | 213 | t.Run("very large duration", func(t *testing.T) { 214 | bench.update(time.Hour) 215 | stats := bench.Stats(time.Minute) 216 | assert.Equal(t, time.Hour.Microseconds(), stats.MaxRespTime) 217 | }) 218 | 219 | t.Run("negative stats interval", func(t *testing.T) { 220 | stats := bench.Stats(-time.Minute) 221 | assert.Equal(t, BenchmarkStats{}, stats) 222 | }) 223 | } 224 | 225 | func TestBenchmark_TimeWindowBoundaries(t *testing.T) { 226 | bench := NewBenchmarks() 227 | now := time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC) 228 | 229 | bench.nowFn = func() time.Time { return now } 230 | 231 | // add data points exactly at minute boundaries 232 | for i := 0; i < 120; i++ { 233 | bench.nowFn = func() time.Time { 234 | return now.Add(time.Duration(i) * time.Second) 235 | } 236 | bench.update(time.Millisecond * 50) 237 | } 238 | 239 | tests := []struct { 240 | name string 241 | interval time.Duration 242 | want int // expected number of requests 243 | }{ 244 | {"exact 
minute", time.Minute, 60}, 245 | {"30 seconds", time.Second * 30, 30}, 246 | {"90 seconds", time.Second * 90, 90}, 247 | {"2 minutes", time.Minute * 2, 120}, 248 | } 249 | 250 | for _, tt := range tests { 251 | t.Run(tt.name, func(t *testing.T) { 252 | stats := bench.Stats(tt.interval) 253 | assert.Equal(t, tt.want, stats.Requests, "interval %v should have %d requests", tt.interval, tt.want) 254 | }) 255 | } 256 | } 257 | 258 | func TestBenchmark_CustomTimeRange(t *testing.T) { 259 | tests := []struct { 260 | name string 261 | timeRange time.Duration 262 | dataPoints int 263 | wantKept int 264 | }{ 265 | {"1 minute range", time.Minute, 120, 60}, 266 | {"5 minute range", time.Minute * 5, 400, 300}, 267 | {"custom 45s range", time.Second * 45, 100, 45}, 268 | } 269 | 270 | for _, tt := range tests { 271 | t.Run(tt.name, func(t *testing.T) { 272 | bench := NewBenchmarks().WithTimeRange(tt.timeRange) 273 | now := time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC) 274 | 275 | // Add data points 276 | for i := 0; i < tt.dataPoints; i++ { 277 | bench.nowFn = func() time.Time { 278 | return now.Add(time.Duration(i) * time.Second) 279 | } 280 | bench.update(time.Millisecond * 50) 281 | } 282 | 283 | assert.Equal(t, tt.wantKept, bench.data.Len(), 284 | "should keep only %d data points for %v time range", 285 | tt.wantKept, tt.timeRange) 286 | }) 287 | } 288 | } 289 | 290 | func TestBenchmark_VariableLoad(t *testing.T) { 291 | bench := NewBenchmarks() 292 | now := time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC) 293 | 294 | // simulate variable load pattern 295 | patterns := []struct { 296 | count int 297 | duration time.Duration 298 | }{ 299 | {10, time.Millisecond * 10}, // fast responses 300 | {5, time.Millisecond * 100}, // medium responses 301 | {2, time.Millisecond * 1000}, // slow responses 302 | } 303 | 304 | for i, p := range patterns { 305 | bench.nowFn = func() time.Time { 306 | return now.Add(time.Duration(i) * time.Second) 307 | } 308 | for j := 0; j < p.count; j++ { 309 | 
bench.update(p.duration) 310 | } 311 | } 312 | 313 | stats := bench.Stats(time.Minute) 314 | assert.Equal(t, 17, stats.Requests) // total requests across all patterns 315 | assert.Equal(t, int64(1000*1000), stats.MaxRespTime) // should be the max (1000ms = 1_000_000 microseconds) 316 | assert.Equal(t, int64(10*1000), stats.MinRespTime) // should be the min (10ms = 10_000 microseconds) 317 | } 318 | -------------------------------------------------------------------------------- /blackwords.go: -------------------------------------------------------------------------------- 1 | package rest 2 | 3 | import ( 4 | "bytes" 5 | "io" 6 | "net/http" 7 | "strings" 8 | ) 9 | 10 | // BlackWords middleware doesn't allow some words in the request body 11 | func BlackWords(words ...string) func(http.Handler) http.Handler { 12 | 13 | return func(h http.Handler) http.Handler { 14 | fn := func(w http.ResponseWriter, r *http.Request) { 15 | 16 | if content, err := io.ReadAll(r.Body); err == nil { 17 | body := strings.ToLower(string(content)) 18 | r.Body = io.NopCloser(bytes.NewReader(content)) 19 | 20 | if body != "" { 21 | for _, word := range words { 22 | if strings.Contains(body, strings.ToLower(word)) { 23 | w.WriteHeader(http.StatusForbidden) 24 | RenderJSON(w, JSON{"error": "one of blacklisted words detected"}) 25 | return 26 | } 27 | } 28 | } 29 | } 30 | h.ServeHTTP(w, r) 31 | } 32 | return http.HandlerFunc(fn) 33 | } 34 | } 35 | 36 | // BlackWordsFn middleware uses func to get the list and doesn't allow some words in the request body 37 | func BlackWordsFn(fn func() []string) func(http.Handler) http.Handler { 38 | return BlackWords(fn()...) 
39 | } 40 | -------------------------------------------------------------------------------- /blackwords_test.go: -------------------------------------------------------------------------------- 1 | package rest 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "net/http" 7 | "net/http/httptest" 8 | "testing" 9 | "time" 10 | 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func TestBlackwords(t *testing.T) { 15 | 16 | tbl := []struct { 17 | inp string 18 | code int 19 | }{ 20 | {"", 200}, 21 | {"blah blah body", 200}, 22 | {"blah blah body bad1", 403}, 23 | {"blah blah body bad", 200}, 24 | {"blah bad2 body bad", 403}, 25 | {`{"fld": 123, "aa": {"$where": {"aaa": 567}}}`, 403}, 26 | } 27 | 28 | bwMiddleware := BlackWords("bad1", "bad2", "$where") 29 | ts := httptest.NewServer(bwMiddleware(getTestHandlerBlah())) 30 | defer ts.Close() 31 | 32 | u := fmt.Sprintf("%s%s", ts.URL, "/something") 33 | 34 | client := http.Client{Timeout: 5 * time.Second} 35 | 36 | for n, tt := range tbl { 37 | tt := tt 38 | t.Run(fmt.Sprintf("test-%d", n), func(t *testing.T) { 39 | req, err := http.NewRequest("GET", u, bytes.NewBuffer([]byte(tt.inp))) 40 | assert.Nil(t, err) 41 | 42 | r, err := client.Do(req) 43 | assert.Nil(t, err) 44 | assert.Equal(t, tt.code, r.StatusCode) 45 | }) 46 | } 47 | } 48 | 49 | func TestBlackwordsFn(t *testing.T) { 50 | tbl := []struct { 51 | inp string 52 | code int 53 | }{ 54 | {"", 200}, 55 | {"blah blah body", 200}, 56 | {"blah blah body bad1", 403}, 57 | {"blah blah body bad", 200}, 58 | {"blah bad2 body bad", 403}, 59 | {`{"fld": 123, "aa": {"$where": {"aaa": 567}}}`, 403}, 60 | } 61 | 62 | bwMiddleware := BlackWordsFn(func() []string { 63 | return []string{"bad1", "bad2", "$where"} 64 | }) 65 | 66 | ts := httptest.NewServer(bwMiddleware(getTestHandlerBlah())) 67 | defer ts.Close() 68 | 69 | u := fmt.Sprintf("%s%s", ts.URL, "/something") 70 | 71 | client := http.Client{Timeout: 5 * time.Second} 72 | 73 | for n, tt := range tbl { 74 | tt := tt 75 | 
t.Run(fmt.Sprintf("test-%d", n), func(t *testing.T) { 76 | req, err := http.NewRequest("GET", u, bytes.NewBuffer([]byte(tt.inp))) 77 | assert.Nil(t, err) 78 | 79 | r, err := client.Do(req) 80 | assert.Nil(t, err) 81 | assert.Equal(t, tt.code, r.StatusCode) 82 | }) 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /cache_control.go: -------------------------------------------------------------------------------- 1 | package rest 2 | 3 | import ( 4 | "crypto/sha1" //nolint not used for cryptography 5 | "fmt" 6 | "net/http" 7 | "strings" 8 | "time" 9 | ) 10 | 11 | // CacheControl is a middleware setting cache expiration. Using url+version for etag 12 | func CacheControl(expiration time.Duration, version string) func(http.Handler) http.Handler { 13 | 14 | etag := func(r *http.Request, version string) string { 15 | s := fmt.Sprintf("%s:%s", version, r.URL.String()) 16 | return fmt.Sprintf("%x", sha1.Sum([]byte(s))) //nolint 17 | } 18 | 19 | return func(h http.Handler) http.Handler { 20 | fn := func(w http.ResponseWriter, r *http.Request) { 21 | e := `"` + etag(r, version) + `"` 22 | w.Header().Set("Etag", e) 23 | w.Header().Set("Cache-Control", fmt.Sprintf("max-age=%d, no-cache", int(expiration.Seconds()))) 24 | 25 | if match := r.Header.Get("If-None-Match"); match != "" { 26 | if strings.Contains(match, e) { 27 | w.WriteHeader(http.StatusNotModified) 28 | return 29 | } 30 | } 31 | h.ServeHTTP(w, r) 32 | } 33 | return http.HandlerFunc(fn) 34 | } 35 | } 36 | 37 | // CacheControlDynamic is a middleware setting cache expiration. 
Using url+ func(r) for etag 38 | func CacheControlDynamic(expiration time.Duration, versionFn func(r *http.Request) string) func(http.Handler) http.Handler { 39 | 40 | etag := func(r *http.Request, version string) string { 41 | s := fmt.Sprintf("%s:%s", version, r.URL.String()) 42 | return fmt.Sprintf("%x", sha1.Sum([]byte(s))) //nolint 43 | } 44 | 45 | return func(h http.Handler) http.Handler { 46 | fn := func(w http.ResponseWriter, r *http.Request) { 47 | 48 | e := `"` + etag(r, versionFn(r)) + `"` 49 | w.Header().Set("Etag", e) 50 | w.Header().Set("Cache-Control", fmt.Sprintf("max-age=%d, no-cache", int(expiration.Seconds()))) 51 | 52 | if match := r.Header.Get("If-None-Match"); match != "" { 53 | if strings.Contains(match, e) { 54 | w.WriteHeader(http.StatusNotModified) 55 | return 56 | } 57 | } 58 | h.ServeHTTP(w, r) 59 | } 60 | return http.HandlerFunc(fn) 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /cache_control_test.go: -------------------------------------------------------------------------------- 1 | package rest 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "strconv" 7 | "testing" 8 | "time" 9 | 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestRest_cacheControl(t *testing.T) { 14 | 15 | tbl := []struct { 16 | url string 17 | version string 18 | exp time.Duration 19 | etag string 20 | maxAge int 21 | }{ 22 | {"http://example.com/foo", "v1", time.Hour, "b433be1ea19edaee9dc92ca4b895b6bdf3c058cb", 3600}, 23 | {"http://example.com/foo2", "v1", 10 * time.Hour, "6d8466aef3246c1057452561acddf7ad9d0d99e0", 36000}, 24 | {"http://example.com/foo", "v2", time.Hour, "481700c52aab0dfbca99f3ffc2a4fbb27884c114", 3600}, 25 | {"https://example.com/foo", "v2", time.Hour, "bebd4f1b87f474792c4e75e5affe31fbf67f5778", 3600}, 26 | } 27 | 28 | for i, tt := range tbl { 29 | tt := tt 30 | t.Run(strconv.Itoa(i), func(t *testing.T) { 31 | req := httptest.NewRequest("GET", tt.url, http.NoBody) 
32 | w := httptest.NewRecorder() 33 | 34 | h := CacheControl(tt.exp, tt.version)(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {})) 35 | h.ServeHTTP(w, req) 36 | resp := w.Result() 37 | assert.Equal(t, http.StatusOK, resp.StatusCode) 38 | t.Logf("%+v", resp.Header) 39 | assert.Equal(t, `"`+tt.etag+`"`, resp.Header.Get("Etag")) 40 | assert.Equal(t, `max-age=`+strconv.Itoa(int(tt.exp.Seconds()))+", no-cache", resp.Header.Get("Cache-Control")) 41 | 42 | }) 43 | } 44 | 45 | } 46 | 47 | func TestCacheControlDynamic(t *testing.T) { 48 | tbl := []struct { 49 | url string 50 | version string 51 | exp time.Duration 52 | etag string 53 | maxAge int 54 | }{ 55 | {"http://example.com/foo", "v1", time.Hour, "b433be1ea19edaee9dc92ca4b895b6bdf3c058cb", 3600}, 56 | {"http://example.com/foo2", "v1", 10 * time.Hour, "6d8466aef3246c1057452561acddf7ad9d0d99e0", 36000}, 57 | {"http://example.com/foo", "v2", time.Hour, "481700c52aab0dfbca99f3ffc2a4fbb27884c114", 3600}, 58 | {"https://example.com/foo", "v2", time.Hour, "bebd4f1b87f474792c4e75e5affe31fbf67f5778", 3600}, 59 | } 60 | 61 | for i, tt := range tbl { 62 | tt := tt 63 | t.Run(strconv.Itoa(i), func(t *testing.T) { 64 | req := httptest.NewRequest("GET", tt.url, http.NoBody) 65 | req.Header.Set("key", tt.version) 66 | w := httptest.NewRecorder() 67 | 68 | fn := func(r *http.Request) string { 69 | return r.Header.Get("key") 70 | } 71 | h := CacheControlDynamic(tt.exp, fn)(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {})) 72 | h.ServeHTTP(w, req) 73 | resp := w.Result() 74 | assert.Equal(t, http.StatusOK, resp.StatusCode) 75 | t.Logf("%+v", resp.Header) 76 | assert.Equal(t, `"`+tt.etag+`"`, resp.Header.Get("Etag")) 77 | assert.Equal(t, `max-age=`+strconv.Itoa(int(tt.exp.Seconds()))+", no-cache", resp.Header.Get("Cache-Control")) 78 | 79 | }) 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /depricattion.go: 
-------------------------------------------------------------------------------- 1 | package rest 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "time" 7 | ) 8 | 9 | // Deprecation adds a header 'Deprecation: version="version", date="date" header' 10 | // see https://tools.ietf.org/id/draft-dalal-deprecation-header-00.html 11 | func Deprecation(version string, date time.Time) func(http.Handler) http.Handler { 12 | f := func(h http.Handler) http.Handler { 13 | fn := func(w http.ResponseWriter, r *http.Request) { 14 | headerVal := fmt.Sprintf("version=%q, date=%q", version, date.Format(time.RFC3339)) 15 | w.Header().Set("Deprecation", headerVal) 16 | h.ServeHTTP(w, r) 17 | } 18 | return http.HandlerFunc(fn) 19 | } 20 | return f 21 | } 22 | -------------------------------------------------------------------------------- /depricattion_test.go: -------------------------------------------------------------------------------- 1 | package rest 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "net/http/httptest" 7 | "testing" 8 | "time" 9 | 10 | "github.com/stretchr/testify/assert" 11 | "github.com/stretchr/testify/require" 12 | ) 13 | 14 | func TestDeprecated(t *testing.T) { 15 | deprecated := Deprecation("1.0.2", time.Date(2020, 9, 1, 18, 32, 0, 0, time.UTC)) 16 | ts := httptest.NewServer(deprecated(getTestHandlerBlah())) 17 | defer ts.Close() 18 | 19 | u := fmt.Sprintf("%s%s", ts.URL, "/something") 20 | 21 | client := http.Client{Timeout: 5 * time.Second} 22 | req, err := http.NewRequest("GET", u, http.NoBody) 23 | require.NoError(t, err) 24 | 25 | r, err := client.Do(req) 26 | require.NoError(t, err) 27 | defer r.Body.Close() 28 | assert.Equal(t, 200, r.StatusCode) 29 | 30 | assert.Equal(t, `version="1.0.2", date="2020-09-01T18:32:00Z"`, r.Header.Get("Deprecation")) 31 | } 32 | -------------------------------------------------------------------------------- /file_server.go: -------------------------------------------------------------------------------- 1 | package rest 2 | 
3 | import ( 4 | "fmt" 5 | "io" 6 | "net/http" 7 | "os" 8 | "path/filepath" 9 | "strings" 10 | ) 11 | 12 | // FS provides http.FileServer handler to serve static files from a http.FileSystem, 13 | // prevents directory listing by default and supports spa-friendly mode (off by default) returning /index.html on 404. 14 | // - public defines base path of the url, i.e. for http://example.com/static/* it should be /static 15 | // - local for the local path to the root of the served directory 16 | // - notFound is the reader for the custom 404 html, can be nil for default 17 | type FS struct { 18 | public, root string 19 | notFound io.Reader 20 | isSpa bool 21 | enableListing bool 22 | handler http.HandlerFunc 23 | } 24 | 25 | // NewFileServer creates file server with optional spa mode and optional direcroty listing (disabled by default) 26 | func NewFileServer(public, local string, options ...FsOpt) (*FS, error) { 27 | res := FS{ 28 | public: public, 29 | notFound: nil, 30 | isSpa: false, 31 | enableListing: false, 32 | } 33 | 34 | root, err := filepath.Abs(local) 35 | if err != nil { 36 | return nil, fmt.Errorf("can't get absolute path for %s: %w", local, err) 37 | } 38 | res.root = root 39 | 40 | if _, err = os.Stat(root); os.IsNotExist(err) { 41 | return nil, fmt.Errorf("local path %s doesn't exist: %w", root, err) 42 | } 43 | 44 | for _, opt := range options { 45 | err = opt(&res) 46 | if err != nil { 47 | return nil, err 48 | } 49 | } 50 | 51 | cfs := customFS{ 52 | fs: http.Dir(root), 53 | spa: res.isSpa, 54 | listing: res.enableListing, 55 | } 56 | f := http.StripPrefix(public, http.FileServer(cfs)) 57 | res.handler = func(w http.ResponseWriter, r *http.Request) { f.ServeHTTP(w, r) } 58 | 59 | if !res.enableListing { 60 | h, err := custom404Handler(f, res.notFound) 61 | if err != nil { 62 | return nil, err 63 | } 64 | res.handler = func(w http.ResponseWriter, r *http.Request) { h.ServeHTTP(w, r) } 65 | } 66 | 67 | return &res, nil 68 | } 69 | 70 | // FileServer 
is a shortcut for making FS with listing disabled and the custom noFound reader (can be nil). 71 | // Deprecated: the method is for back-compatibility only and user should use the universal NewFileServer instead 72 | func FileServer(public, local string, notFound io.Reader) (http.Handler, error) { 73 | return NewFileServer(public, local, FsOptCustom404(notFound)) 74 | } 75 | 76 | // FileServerSPA is a shortcut for making FS with SPA-friendly handling of 404, listing disabled and the custom noFound reader (can be nil). 77 | // Deprecated: the method is for back-compatibility only and user should use the universal NewFileServer instead 78 | func FileServerSPA(public, local string, notFound io.Reader) (http.Handler, error) { 79 | return NewFileServer(public, local, FsOptCustom404(notFound), FsOptSPA) 80 | } 81 | 82 | // ServeHTTP makes FileServer compatible with http.Handler interface 83 | func (fs *FS) ServeHTTP(w http.ResponseWriter, r *http.Request) { 84 | fs.handler(w, r) 85 | } 86 | 87 | // FsOpt defines functional option type 88 | type FsOpt func(fs *FS) error 89 | 90 | // FsOptSPA turns on SPA mode returning "/index.html" on not-found 91 | func FsOptSPA(fs *FS) error { 92 | fs.isSpa = true 93 | return nil 94 | } 95 | 96 | // FsOptListing turns on directory listing 97 | func FsOptListing(fs *FS) error { 98 | fs.enableListing = true 99 | return nil 100 | } 101 | 102 | // FsOptCustom404 sets custom 404 reader 103 | func FsOptCustom404(fr io.Reader) FsOpt { 104 | return func(fs *FS) error { 105 | fs.notFound = fr 106 | return nil 107 | } 108 | } 109 | 110 | // customFS wraps http.FileSystem with spa and no-listing optional support 111 | type customFS struct { 112 | fs http.FileSystem 113 | spa bool 114 | listing bool 115 | } 116 | 117 | // Open file on FS, for directory enforce index.html and fail on a missing index 118 | func (cfs customFS) Open(name string) (http.File, error) { 119 | 120 | f, err := cfs.fs.Open(name) 121 | if err != nil { 122 | if cfs.spa { 123 | 
return cfs.fs.Open("/index.html") 124 | } 125 | return nil, err 126 | } 127 | 128 | finfo, err := f.Stat() 129 | if err != nil { 130 | return nil, err 131 | } 132 | 133 | if finfo.IsDir() { 134 | index := strings.TrimSuffix(name, "/") + "/index.html" 135 | if _, err := cfs.fs.Open(index); err == nil { // index.html will be served if found 136 | return f, nil 137 | } 138 | // no index.html in directory 139 | if !cfs.listing { // listing disabled 140 | if _, err := cfs.fs.Open(index); err != nil { 141 | return nil, err 142 | } 143 | } 144 | } 145 | 146 | return f, nil 147 | } 148 | 149 | // respWriter404 intercept Write to provide custom 404 response 150 | type respWriter404 struct { 151 | http.ResponseWriter 152 | status int 153 | msg []byte 154 | } 155 | 156 | func (w *respWriter404) WriteHeader(status int) { 157 | w.status = status 158 | if status == http.StatusNotFound { 159 | w.Header().Set("Content-Type", "text/html; charset=utf-8") 160 | } 161 | w.ResponseWriter.WriteHeader(status) 162 | } 163 | 164 | func (w *respWriter404) Write(p []byte) (n int, err error) { 165 | if w.status != http.StatusNotFound || w.msg == nil { 166 | return w.ResponseWriter.Write(p) 167 | } 168 | _, err = w.ResponseWriter.Write(w.msg) 169 | return len(p), err 170 | } 171 | 172 | func custom404Handler(next http.Handler, notFound io.Reader) (http.Handler, error) { 173 | if notFound == nil { 174 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { next.ServeHTTP(w, r) }), nil 175 | } 176 | 177 | body, err := io.ReadAll(notFound) 178 | if err != nil { 179 | return nil, err 180 | } 181 | 182 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 183 | next.ServeHTTP(&respWriter404{ResponseWriter: w, msg: body}, r) 184 | }), nil 185 | } 186 | -------------------------------------------------------------------------------- /file_server_test.go: -------------------------------------------------------------------------------- 1 | package rest 2 | 3 | import 
( 4 | "bytes" 5 | "io" 6 | "net/http" 7 | "net/http/httptest" 8 | "strconv" 9 | "strings" 10 | "testing" 11 | "time" 12 | 13 | "github.com/stretchr/testify/assert" 14 | "github.com/stretchr/testify/require" 15 | 16 | "github.com/go-pkgz/rest/logger" 17 | ) 18 | 19 | func TestFileServerDefault(t *testing.T) { 20 | fh1, err := NewFileServer("/static", "./testdata/root") 21 | require.NoError(t, err) 22 | 23 | fh2, err := FileServer("/static", "./testdata/root", nil) 24 | require.NoError(t, err) 25 | 26 | ts1 := httptest.NewServer(logger.Logger(fh1)) 27 | defer ts1.Close() 28 | ts2 := httptest.NewServer(logger.Logger(fh2)) 29 | defer ts2.Close() 30 | 31 | client := http.Client{Timeout: 599 * time.Second} 32 | 33 | tbl := []struct { 34 | req string 35 | body string 36 | status int 37 | }{ 38 | {"/static", "testdata/index.html", 200}, 39 | {"/static/index.html", "testdata/index.html", 200}, 40 | {"/static/xyz.js", "testdata/xyz.js", 200}, 41 | {"/static/1/", "", 404}, 42 | {"/static/1/nothing", "", 404}, 43 | {"/static/1/f1.html", "testdata/1/f1.html", 200}, 44 | {"/static/2/", "testdata/2/index.html", 200}, 45 | {"/static/2", "testdata/2/index.html", 200}, 46 | {"/static/2/index.html", "testdata/2/index.html", 200}, 47 | {"/static/2/index", "", 404}, 48 | {"/static/2/f123.txt", "testdata/2/f123.txt", 200}, 49 | {"/static/1/../", "testdata/index.html", 200}, 50 | {"/static/../", "testdata/index.html", 200}, 51 | {"/static/../../", "testdata/index.html", 200}, 52 | {"/static/../../../", "testdata/index.html", 200}, 53 | {"/static/%2e%2e%2f%2e%2e%2f%2e%2e%2f/", "testdata/index.html", 200}, 54 | } 55 | 56 | for i, tt := range tbl { 57 | tt := tt 58 | t.Run(strconv.Itoa(i), func(t *testing.T) { 59 | for _, ts := range []*httptest.Server{ts1, ts2} { 60 | req, err := http.NewRequest("GET", ts.URL+tt.req, http.NoBody) 61 | require.NoError(t, err) 62 | resp, err := client.Do(req) 63 | require.NoError(t, err) 64 | t.Logf("headers: %v", resp.Header) 65 | assert.Equal(t, tt.status, 
resp.StatusCode) 66 | if resp.StatusCode == http.StatusNotFound { 67 | msg, e := io.ReadAll(resp.Body) 68 | require.NoError(t, e) 69 | assert.Equal(t, "404 page not found\n", string(msg)) 70 | return 71 | } 72 | body, err := io.ReadAll(resp.Body) 73 | require.NoError(t, err) 74 | assert.Equal(t, tt.body, string(body)) 75 | } 76 | }) 77 | } 78 | } 79 | 80 | func TestFileServerWithListing(t *testing.T) { 81 | fh, err := NewFileServer("/static", "./testdata/root", FsOptListing) 82 | require.NoError(t, err) 83 | ts := httptest.NewServer(logger.Logger(fh)) 84 | defer ts.Close() 85 | client := http.Client{Timeout: 599 * time.Second} 86 | 87 | { 88 | req, err := http.NewRequest("GET", ts.URL+"/static/1", http.NoBody) 89 | require.NoError(t, err) 90 | resp, err := client.Do(req) 91 | require.NoError(t, err) 92 | assert.Equal(t, http.StatusOK, resp.StatusCode) 93 | msg, err := io.ReadAll(resp.Body) 94 | require.NoError(t, err) 95 | exp := `
96 | f1.html 97 | f2.html 98 |99 | ` 100 | assert.Contains(t, string(msg), exp) 101 | } 102 | 103 | { 104 | req, err := http.NewRequest("GET", ts.URL+"/static/xyz.js", http.NoBody) 105 | require.NoError(t, err) 106 | resp, err := client.Do(req) 107 | require.NoError(t, err) 108 | assert.Equal(t, http.StatusOK, resp.StatusCode) 109 | msg, err := io.ReadAll(resp.Body) 110 | require.NoError(t, err) 111 | assert.Equal(t, "testdata/xyz.js", string(msg)) 112 | assert.True(t, strings.Contains(resp.Header.Get("Content-Type"), "javascript"), resp.Header.Get("Content-Type")) 113 | } 114 | 115 | { 116 | req, err := http.NewRequest("GET", ts.URL+"/static/no-such-thing.html", http.NoBody) 117 | require.NoError(t, err) 118 | resp, err := client.Do(req) 119 | require.NoError(t, err) 120 | assert.Equal(t, http.StatusNotFound, resp.StatusCode) 121 | assert.Equal(t, "text/plain; charset=utf-8", resp.Header.Get("Content-Type")) 122 | } 123 | } 124 | 125 | func TestFileServer_Custom404(t *testing.T) { 126 | nf := FsOptCustom404(bytes.NewBufferString("custom 404")) 127 | fh, err := NewFileServer("/static", "./testdata/root", nf) 128 | require.NoError(t, err) 129 | ts := httptest.NewServer(logger.Logger(fh)) 130 | defer ts.Close() 131 | client := http.Client{Timeout: 599 * time.Second} 132 | 133 | { 134 | req, err := http.NewRequest("GET", ts.URL+"/static/xyz.js", http.NoBody) 135 | require.NoError(t, err) 136 | resp, err := client.Do(req) 137 | require.NoError(t, err) 138 | assert.Equal(t, http.StatusOK, resp.StatusCode) 139 | msg, err := io.ReadAll(resp.Body) 140 | require.NoError(t, err) 141 | assert.Equal(t, "testdata/xyz.js", string(msg)) 142 | assert.True(t, strings.Contains(resp.Header.Get("Content-Type"), "javascript"), resp.Header.Get("Content-Type")) 143 | } 144 | 145 | { 146 | req, err := http.NewRequest("GET", ts.URL+"/static/nofile.js", http.NoBody) 147 | require.NoError(t, err) 148 | resp, err := client.Do(req) 149 | require.NoError(t, err) 150 | assert.Equal(t, 
http.StatusNotFound, resp.StatusCode) 151 | msg, err := io.ReadAll(resp.Body) 152 | require.NoError(t, err) 153 | assert.Equal(t, "custom 404", string(msg)) 154 | assert.Equal(t, "text/html; charset=utf-8", resp.Header.Get("Content-Type")) 155 | } 156 | 157 | { 158 | req, err := http.NewRequest("GET", ts.URL+"/xyz.html", http.NoBody) 159 | require.NoError(t, err) 160 | resp, err := client.Do(req) 161 | require.NoError(t, err) 162 | assert.Equal(t, http.StatusNotFound, resp.StatusCode) 163 | msg, err := io.ReadAll(resp.Body) 164 | require.NoError(t, err) 165 | assert.Equal(t, "custom 404", string(msg)) 166 | assert.Equal(t, "text/html; charset=utf-8", resp.Header.Get("Content-Type")) 167 | } 168 | 169 | { 170 | req, err := http.NewRequest("GET", ts.URL+"/static/xyz.js", http.NoBody) 171 | require.NoError(t, err) 172 | resp, err := client.Do(req) 173 | require.NoError(t, err) 174 | assert.Equal(t, http.StatusOK, resp.StatusCode) 175 | msg, err := io.ReadAll(resp.Body) 176 | require.NoError(t, err) 177 | assert.Equal(t, "testdata/xyz.js", string(msg)) 178 | } 179 | } 180 | 181 | func TestFileServerSPA(t *testing.T) { 182 | fh1, err := NewFileServer("/static", "./testdata/root", FsOptSPA) 183 | require.NoError(t, err) 184 | fh2, err := FileServerSPA("/static", "./testdata/root", nil) 185 | require.NoError(t, err) 186 | 187 | ts1 := httptest.NewServer(logger.Logger(fh1)) 188 | defer ts1.Close() 189 | ts2 := httptest.NewServer(logger.Logger(fh2)) 190 | defer ts2.Close() 191 | client := http.Client{Timeout: 599 * time.Second} 192 | 193 | tbl := []struct { 194 | req string 195 | body string 196 | status int 197 | }{ 198 | {"/static/blah", "testdata/index.html", 200}, 199 | {"/static/blah/foo/123.html", "testdata/index.html", 200}, 200 | {"/static", "testdata/index.html", 200}, 201 | {"/static/index.html", "testdata/index.html", 200}, 202 | {"/static/xyz.js", "testdata/xyz.js", 200}, 203 | {"/static/1/", "", 404}, 204 | {"/static/1/nothing", "testdata/index.html", 200}, 205 
| {"/static/1/f1.html", "testdata/1/f1.html", 200}, 206 | {"/static/2/", "testdata/2/index.html", 200}, 207 | {"/static/2", "testdata/2/index.html", 200}, 208 | {"/static/2/index.html", "testdata/2/index.html", 200}, 209 | {"/static/2/index", "testdata/index.html", 200}, 210 | {"/static/2/f123.txt", "testdata/2/f123.txt", 200}, 211 | {"/static/1/../", "testdata/index.html", 200}, 212 | {"/static/../", "testdata/index.html", 200}, 213 | {"/static/../../", "testdata/index.html", 200}, 214 | {"/static/../../../", "testdata/index.html", 200}, 215 | {"/static/%2e%2e%2f%2e%2e%2f%2e%2e%2f/", "testdata/index.html", 200}, 216 | } 217 | 218 | for i, tt := range tbl { 219 | tt := tt 220 | t.Run(strconv.Itoa(i), func(t *testing.T) { 221 | for _, ts := range []*httptest.Server{ts1, ts2} { 222 | req, err := http.NewRequest("GET", ts.URL+tt.req, http.NoBody) 223 | require.NoError(t, err) 224 | resp, err := client.Do(req) 225 | require.NoError(t, err) 226 | t.Logf("headers: %v", resp.Header) 227 | assert.Equal(t, tt.status, resp.StatusCode) 228 | if resp.StatusCode == http.StatusNotFound { 229 | msg, e := io.ReadAll(resp.Body) 230 | require.NoError(t, e) 231 | assert.Equal(t, "404 page not found\n", string(msg)) 232 | return 233 | } 234 | body, err := io.ReadAll(resp.Body) 235 | require.NoError(t, err) 236 | assert.Equal(t, tt.body, string(body)) 237 | } 238 | }) 239 | } 240 | } 241 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/go-pkgz/rest 2 | 3 | go 1.23.0 4 | 5 | require ( 6 | github.com/stretchr/testify v1.10.0 7 | golang.org/x/crypto v0.37.0 8 | ) 9 | 10 | require ( 11 | github.com/davecgh/go-spew v1.1.1 // indirect 12 | github.com/pmezard/go-difflib v1.0.0 // indirect 13 | golang.org/x/sys v0.32.0 // indirect 14 | gopkg.in/yaml.v3 v3.0.1 // indirect 15 | ) 16 | 
-------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 2 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 4 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 5 | github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= 6 | github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 7 | golang.org/x/crypto v0.35.0 h1:b15kiHdrGCHrP6LvwaQ3c03kgNhhiMgvlhxHQhmg2Xs= 8 | golang.org/x/crypto v0.35.0/go.mod h1:dy7dXNW32cAb/6/PRuTNsix8T+vJAqvuIy5Bli/x0YQ= 9 | golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= 10 | golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= 11 | golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= 12 | golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 13 | golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20= 14 | golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= 15 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 16 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 17 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 18 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 19 | -------------------------------------------------------------------------------- /gzip.go: -------------------------------------------------------------------------------- 1 | package rest 2 | 3 | import ( 4 | "compress/gzip" 5 | "io" 
6 | "net/http" 7 | "strings" 8 | "sync" 9 | ) 10 | 11 | var gzDefaultContentTypes = []string{ 12 | "text/css", 13 | "text/javascript", 14 | "text/xml", 15 | "text/html", 16 | "text/plain", 17 | "application/javascript", 18 | "application/x-javascript", 19 | "application/json", 20 | } 21 | 22 | var gzPool = sync.Pool{ 23 | New: func() interface{} { return gzip.NewWriter(io.Discard) }, 24 | } 25 | 26 | type gzipResponseWriter struct { 27 | io.Writer 28 | http.ResponseWriter 29 | } 30 | 31 | func (w *gzipResponseWriter) WriteHeader(status int) { 32 | w.Header().Del("Content-Length") 33 | w.ResponseWriter.WriteHeader(status) 34 | } 35 | 36 | func (w *gzipResponseWriter) Write(b []byte) (int, error) { 37 | return w.Writer.Write(b) 38 | } 39 | 40 | // Gzip is a middleware compressing response 41 | func Gzip(contentTypes ...string) func(http.Handler) http.Handler { 42 | 43 | gzCts := gzDefaultContentTypes 44 | if len(contentTypes) > 0 { 45 | gzCts = contentTypes 46 | } 47 | 48 | contentType := func(r *http.Request) string { 49 | result := r.Header.Get("Content-type") 50 | if result == "" { 51 | return "application/octet-stream" 52 | } 53 | return result 54 | } 55 | 56 | f := func(next http.Handler) http.Handler { 57 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 58 | if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") { 59 | next.ServeHTTP(w, r) 60 | return 61 | } 62 | 63 | var gzOk bool 64 | ctype := contentType(r) 65 | for _, c := range gzCts { 66 | if strings.HasPrefix(strings.ToLower(ctype), strings.ToLower(c)) { 67 | gzOk = true 68 | break 69 | } 70 | } 71 | 72 | if !gzOk { 73 | next.ServeHTTP(w, r) 74 | return 75 | } 76 | 77 | w.Header().Set("Content-Encoding", "gzip") 78 | gz := gzPool.Get().(*gzip.Writer) 79 | defer gzPool.Put(gz) 80 | 81 | gz.Reset(w) 82 | defer gz.Close() 83 | 84 | next.ServeHTTP(&gzipResponseWriter{ResponseWriter: w, Writer: gz}, r) 85 | }) 86 | } 87 | return f 88 | } 89 | 
-------------------------------------------------------------------------------- /gzip_test.go: -------------------------------------------------------------------------------- 1 | package rest 2 | 3 | import ( 4 | "bytes" 5 | "compress/gzip" 6 | "io" 7 | "net/http" 8 | "net/http/httptest" 9 | "strings" 10 | "testing" 11 | 12 | "github.com/stretchr/testify/assert" 13 | "github.com/stretchr/testify/require" 14 | ) 15 | 16 | func TestGzipCustom(t *testing.T) { 17 | handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 18 | _, err := w.Write([]byte("Lorem Ipsum is simply dummy text of the printing and typesetting industry. " + 19 | "Lorem Ipsum has been the industry’s standard dummy text ever since the 1500s, when an unknown printer took " + 20 | "a galley of type and scrambled it to make a type specimen book. It has survived not only five centuries," + 21 | " but also the leap into electronic typesetting, remaining essentially unchanged. It was popularized" + 22 | " in the 1960s with the release of Letraset sheets containing Lorem Ipsum passages, " + 23 | "and more recently with desktop publishing software like Aldus PageMaker including versions of Lorem Ipsum.")) 24 | require.NoError(t, err) 25 | }) 26 | ts := httptest.NewServer(Gzip("text/plain", "text/html")(handler)) 27 | defer ts.Close() 28 | 29 | client := http.Client{} 30 | 31 | { 32 | req, err := http.NewRequest("GET", ts.URL+"/something", http.NoBody) 33 | require.NoError(t, err) 34 | req.Header.Set("Accept-Encoding", "gzip") 35 | req.Header.Set("Content-Type", "text/plain; charset=utf-8") 36 | resp, err := client.Do(req) 37 | require.NoError(t, err) 38 | assert.Equal(t, 200, resp.StatusCode) 39 | defer resp.Body.Close() 40 | b, err := io.ReadAll(resp.Body) 41 | assert.NoError(t, err) 42 | assert.Equal(t, 357, len(b), "compressed size") 43 | 44 | gzr, err := gzip.NewReader(bytes.NewBuffer(b)) 45 | require.NoError(t, err) 46 | b, err = io.ReadAll(gzr) 47 | require.NoError(t, err) 48 | 
assert.True(t, strings.HasPrefix(string(b), "Lorem Ipsum"), string(b)) 49 | } 50 | 51 | { 52 | req, err := http.NewRequest("GET", ts.URL+"/something", http.NoBody) 53 | require.NoError(t, err) 54 | req.Header.Set("Accept-Encoding", "gzip") 55 | req.Header.Set("Content-Type", "something") 56 | resp, err := client.Do(req) 57 | require.NoError(t, err) 58 | assert.Equal(t, 200, resp.StatusCode) 59 | defer resp.Body.Close() 60 | b, err := io.ReadAll(resp.Body) 61 | assert.NoError(t, err) 62 | assert.Equal(t, 576, len(b), "uncompressed size") 63 | } 64 | 65 | } 66 | 67 | func TestGzipDefault(t *testing.T) { 68 | handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 69 | _, err := w.Write([]byte("Lorem Ipsum is simply dummy text of the printing and typesetting industry. " + 70 | "Lorem Ipsum has been the industry’s standard dummy text ever since the 1500s, when an unknown printer took " + 71 | "a galley of type and scrambled it to make a type specimen book. It has survived not only five centuries," + 72 | " but also the leap into electronic typesetting, remaining essentially unchanged. 
It was popularized" + 73 | " in the 1960s with the release of Letraset sheets containing Lorem Ipsum passages, " + 74 | "and more recently with desktop publishing software like Aldus PageMaker including versions of Lorem Ipsum.")) 75 | require.NoError(t, err) 76 | }) 77 | ts := httptest.NewServer(Gzip()(handler)) 78 | defer ts.Close() 79 | 80 | client := http.Client{} 81 | 82 | { 83 | req, err := http.NewRequest("GET", ts.URL+"/something", http.NoBody) 84 | require.NoError(t, err) 85 | req.Header.Set("Accept-Encoding", "gzip") 86 | req.Header.Set("Content-Type", "text/plain") 87 | resp, err := client.Do(req) 88 | require.NoError(t, err) 89 | assert.Equal(t, 200, resp.StatusCode) 90 | defer resp.Body.Close() 91 | b, err := io.ReadAll(resp.Body) 92 | assert.NoError(t, err) 93 | assert.Equal(t, 357, len(b), "compressed size") 94 | 95 | gzr, err := gzip.NewReader(bytes.NewBuffer(b)) 96 | require.NoError(t, err) 97 | b, err = io.ReadAll(gzr) 98 | require.NoError(t, err) 99 | assert.True(t, strings.HasPrefix(string(b), "Lorem Ipsum"), string(b)) 100 | } 101 | 102 | { 103 | req, err := http.NewRequest("GET", ts.URL+"/something", http.NoBody) 104 | require.NoError(t, err) 105 | resp, err := client.Do(req) 106 | require.Nil(t, err) 107 | assert.Equal(t, 200, resp.StatusCode) 108 | defer resp.Body.Close() 109 | b, err := io.ReadAll(resp.Body) 110 | assert.NoError(t, err) 111 | assert.Equal(t, 576, len(b), "uncompressed size") 112 | } 113 | 114 | { 115 | req, err := http.NewRequest("GET", ts.URL+"/something", http.NoBody) 116 | require.NoError(t, err) 117 | req.Header.Set("Accept-Encoding", "gzip") 118 | req.Header.Set("Content-Type", "something") 119 | resp, err := client.Do(req) 120 | require.NoError(t, err) 121 | assert.Equal(t, 200, resp.StatusCode) 122 | defer resp.Body.Close() 123 | b, err := io.ReadAll(resp.Body) 124 | assert.NoError(t, err) 125 | assert.Equal(t, 576, len(b), "uncompressed size") 126 | } 127 | 128 | } 129 | 
-------------------------------------------------------------------------------- /httperrors.go: -------------------------------------------------------------------------------- 1 | package rest 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "net/http" 7 | "net/url" 8 | "runtime" 9 | "strings" 10 | 11 | "github.com/go-pkgz/rest/logger" 12 | ) 13 | 14 | // ErrorLogger wraps logger.Backend 15 | type ErrorLogger struct { 16 | l logger.Backend 17 | } 18 | 19 | // NewErrorLogger creates ErrorLogger for given Backend 20 | func NewErrorLogger(l logger.Backend) *ErrorLogger { 21 | return &ErrorLogger{l: l} 22 | } 23 | 24 | // Log sends json error message {error: msg} with error code and logging error and caller 25 | func (e *ErrorLogger) Log(w http.ResponseWriter, r *http.Request, httpCode int, err error, msg ...string) { 26 | m := "" 27 | if len(msg) > 0 { 28 | m = strings.Join(msg, ". ") 29 | } 30 | if e.l != nil { 31 | e.l.Logf("%s", errDetailsMsg(r, httpCode, err, m)) 32 | } 33 | renderJSONWithStatus(w, JSON{"error": m}, httpCode) 34 | } 35 | 36 | // SendErrorJSON sends {error: msg} with error code and logging error and caller 37 | func SendErrorJSON(w http.ResponseWriter, r *http.Request, l logger.Backend, code int, err error, msg string) { 38 | if l != nil { 39 | l.Logf("%s", errDetailsMsg(r, code, err, msg)) 40 | } 41 | renderJSONWithStatus(w, JSON{"error": msg}, code) 42 | } 43 | 44 | func errDetailsMsg(r *http.Request, code int, err error, msg string) string { 45 | 46 | q := r.URL.String() 47 | if qun, e := url.QueryUnescape(q); e == nil { 48 | q = qun 49 | } 50 | 51 | srcFileInfo := "" 52 | if pc, file, line, ok := runtime.Caller(2); ok { 53 | fnameElems := strings.Split(file, "/") 54 | funcNameElems := strings.Split(runtime.FuncForPC(pc).Name(), "/") 55 | srcFileInfo = fmt.Sprintf(" [caused by %s:%d %s]", strings.Join(fnameElems[len(fnameElems)-3:], "/"), 56 | line, funcNameElems[len(funcNameElems)-1]) 57 | } 58 | 59 | remoteIP := r.RemoteAddr 60 | if pos := 
strings.Index(remoteIP, ":"); pos >= 0 { 61 | remoteIP = remoteIP[:pos] 62 | } 63 | if err == nil { 64 | err = errors.New("no error") 65 | } 66 | return fmt.Sprintf("%s - %v - %d - %s - %s%s", msg, err, code, remoteIP, q, srcFileInfo) 67 | } 68 | -------------------------------------------------------------------------------- /httperrors_test.go: -------------------------------------------------------------------------------- 1 | package rest 2 | 3 | import ( 4 | "errors" 5 | "io" 6 | "net/http" 7 | "net/http/httptest" 8 | "testing" 9 | 10 | "github.com/stretchr/testify/assert" 11 | "github.com/stretchr/testify/require" 12 | ) 13 | 14 | func TestSendErrorJSON(t *testing.T) { 15 | l := &mockLgr{} 16 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 17 | if r.URL.Path == "/error" { 18 | t.Log("http err request", r.URL) 19 | SendErrorJSON(w, r, l, 500, errors.New("error 500"), "error details 123456") 20 | return 21 | } 22 | w.WriteHeader(404) 23 | })) 24 | defer ts.Close() 25 | 26 | resp, err := http.Get(ts.URL + "/error") 27 | require.Nil(t, err) 28 | defer resp.Body.Close() 29 | 30 | body, err := io.ReadAll(resp.Body) 31 | require.Nil(t, err) 32 | assert.Equal(t, 500, resp.StatusCode) 33 | assert.Equal(t, "application/json; charset=utf-8", resp.Header.Get("content-type")) 34 | assert.Equal(t, `{"error":"error details 123456"}`+"\n", string(body)) 35 | t.Log(l.buf.String()) 36 | } 37 | 38 | func TestErrorDetailsMsg(t *testing.T) { 39 | callerFn := func() { 40 | req, err := http.NewRequest("GET", "https://example.com/test?k1=v1&k2=v2", http.NoBody) 41 | require.Nil(t, err) 42 | req.RemoteAddr = "1.2.3.4" 43 | msg := errDetailsMsg(req, 500, errors.New("error 500"), "error details 123456") 44 | assert.Contains(t, msg, "error details 123456 - error 500 - 500 - 1.2.3.4 - https://example."+ 45 | "com/test?k1=v1&k2=v2 [caused by") 46 | assert.Contains(t, msg, "rest/httperrors_test.go:49 rest.TestErrorDetailsMsg]", msg) 47 | 48 | } 49 
| callerFn() 50 | } 51 | 52 | func TestErrorDetailsMsgNoError(t *testing.T) { 53 | callerFn := func() { 54 | req, err := http.NewRequest("GET", "https://example.com/test?k1=v1&k2=v2", http.NoBody) 55 | require.Nil(t, err) 56 | req.RemoteAddr = "1.2.3.4" 57 | msg := errDetailsMsg(req, 500, nil, "error details 123456") 58 | assert.Contains(t, msg, "error details 123456 - no error - 500 - 1.2.3.4 - https://example.com/test?k1=v1&k2=v2 [caused by") 59 | assert.Contains(t, msg, "rest/httperrors_test.go:61 rest.TestErrorDetailsMsgNoError]", msg) 60 | } 61 | callerFn() 62 | } 63 | 64 | func TestErrorLogger_Log(t *testing.T) { 65 | l := &mockLgr{} 66 | errLogger := NewErrorLogger(l) 67 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 68 | if r.URL.Path == "/error" { 69 | t.Log("http err request", r.URL) 70 | errLogger.Log(w, r, 500, errors.New("error 500"), "error details 123456") 71 | return 72 | } 73 | w.WriteHeader(404) 74 | })) 75 | defer ts.Close() 76 | 77 | resp, err := http.Get(ts.URL + "/error") 78 | require.Nil(t, err) 79 | defer resp.Body.Close() 80 | 81 | body, err := io.ReadAll(resp.Body) 82 | require.Nil(t, err) 83 | assert.Equal(t, 500, resp.StatusCode) 84 | 85 | assert.Equal(t, `{"error":"error details 123456"}`+"\n", string(body)) 86 | assert.Equal(t, "application/json; charset=utf-8", resp.Header.Get("content-type")) 87 | 88 | t.Log(l.buf.String()) 89 | } 90 | -------------------------------------------------------------------------------- /logger/logger.go: -------------------------------------------------------------------------------- 1 | // Package logger implements logging middleware 2 | package logger 3 | 4 | import ( 5 | "bufio" 6 | "bytes" 7 | "fmt" 8 | "io" 9 | "log" 10 | "net" 11 | "net/http" 12 | "net/url" 13 | "regexp" 14 | "strconv" 15 | "strings" 16 | "time" 17 | 18 | "github.com/go-pkgz/rest/realip" 19 | ) 20 | 21 | // Middleware is a logger for rest requests. 
22 | type Middleware struct { 23 | prefix string 24 | logBody bool 25 | maxBodySize int 26 | ipFn func(ip string) string 27 | userFn func(r *http.Request) (string, error) 28 | subjFn func(r *http.Request) (string, error) 29 | log Backend 30 | apacheCombined bool 31 | } 32 | 33 | // Backend is logging backend 34 | type Backend interface { 35 | Logf(format string, args ...interface{}) 36 | } 37 | 38 | type logParts struct { 39 | duration time.Duration 40 | rawURL string 41 | method string 42 | remoteIP string 43 | statusCode int 44 | respSize int 45 | host string 46 | 47 | prefix string 48 | user string 49 | body string 50 | } 51 | 52 | type stdBackend struct{} 53 | 54 | func (s stdBackend) Logf(format string, args ...interface{}) { 55 | log.Printf(format, args...) 56 | } 57 | 58 | // Logger is a default logger middleware with "REST" prefix 59 | func Logger(next http.Handler) http.Handler { 60 | l := New(Prefix("REST")) 61 | return l.Handler(next) 62 | 63 | } 64 | 65 | // New makes rest logger with given options 66 | func New(options ...Option) *Middleware { 67 | res := Middleware{ 68 | prefix: "", 69 | maxBodySize: 1024, 70 | log: stdBackend{}, 71 | } 72 | for _, opt := range options { 73 | opt(&res) 74 | } 75 | return &res 76 | } 77 | 78 | // Handler middleware prints http log 79 | func (l *Middleware) Handler(next http.Handler) http.Handler { 80 | 81 | formater := l.formatDefault 82 | if l.apacheCombined { 83 | formater = l.formatApacheCombined 84 | } 85 | 86 | fn := func(w http.ResponseWriter, r *http.Request) { 87 | ww := newCustomResponseWriter(w) 88 | 89 | user := "" 90 | if l.userFn != nil { 91 | if u, err := l.userFn(r); err == nil { 92 | user = u 93 | } 94 | } 95 | 96 | body := l.getBody(r) 97 | t1 := time.Now() 98 | defer func() { 99 | t2 := time.Now() 100 | 101 | u := *r.URL // shallow copy 102 | u.RawQuery = l.sanitizeQuery(u.RawQuery) 103 | rawurl := u.String() 104 | if unescURL, err := url.QueryUnescape(rawurl); err == nil { 105 | rawurl = unescURL 106 
| } 107 | 108 | remoteIP, err := realip.Get(r) 109 | if err != nil { 110 | remoteIP = "unknown ip" 111 | } 112 | if l.ipFn != nil { // mask ip with ipFn 113 | remoteIP = l.ipFn(remoteIP) 114 | } 115 | 116 | server := r.URL.Hostname() 117 | if server == "" { 118 | server = strings.Split(r.Host, ":")[0] 119 | } 120 | 121 | p := &logParts{ 122 | duration: t2.Sub(t1), 123 | rawURL: rawurl, 124 | method: r.Method, 125 | host: server, 126 | remoteIP: remoteIP, 127 | statusCode: ww.status, 128 | respSize: ww.size, 129 | prefix: l.prefix, 130 | user: user, 131 | body: body, 132 | } 133 | 134 | l.log.Logf(formater(r, p)) 135 | }() 136 | 137 | next.ServeHTTP(ww, r) 138 | } 139 | return http.HandlerFunc(fn) 140 | } 141 | 142 | func (l *Middleware) formatDefault(r *http.Request, p *logParts) string { 143 | var bld strings.Builder 144 | if l.prefix != "" { 145 | _, _ = bld.WriteString(l.prefix) 146 | _, _ = bld.WriteString(" ") 147 | } 148 | 149 | _, _ = bld.WriteString(fmt.Sprintf("%s - %s - %s - %s - %d (%d) - %v", 150 | p.method, p.rawURL, p.host, p.remoteIP, p.statusCode, p.respSize, p.duration)) 151 | 152 | if p.user != "" { 153 | _, _ = bld.WriteString(" - ") 154 | _, _ = bld.WriteString(p.user) 155 | } 156 | 157 | if l.subjFn != nil { 158 | if subj, err := l.subjFn(r); err == nil { 159 | _, _ = bld.WriteString(" - ") 160 | _, _ = bld.WriteString(subj) 161 | } 162 | } 163 | 164 | if traceID := r.Header.Get("X-Request-ID"); traceID != "" { 165 | _, _ = bld.WriteString(" - ") 166 | _, _ = bld.WriteString(traceID) 167 | } 168 | 169 | if p.body != "" { 170 | _, _ = bld.WriteString(" - ") 171 | _, _ = bld.WriteString(p.body) 172 | } 173 | return bld.String() 174 | } 175 | 176 | // 127.0.0.1 - frank [10/Oct/2000:13:55:36 -0700] "GET /apache_pb.gif HTTP/1.0" 200 2326 "http://www.example.com/start.html" "Mozilla/4.08 [en] (Win98; I ;Nav)" 177 | // nolint gosec 178 | func (l *Middleware) formatApacheCombined(r *http.Request, p *logParts) string { 179 | username := "-" 180 | if 
p.user != "" { 181 | username = p.user 182 | } 183 | 184 | var bld strings.Builder 185 | bld.WriteString(p.remoteIP) 186 | bld.WriteString(" - ") 187 | bld.WriteString(username) 188 | bld.WriteString(" [") 189 | bld.WriteString(time.Now().Format("02/Jan/2006:15:04:05 -0700")) 190 | bld.WriteString(`] "`) 191 | bld.WriteString(p.method) 192 | bld.WriteString(" ") 193 | bld.WriteString(p.rawURL) 194 | bld.WriteString(`" `) 195 | bld.WriteString(r.Proto) 196 | bld.WriteString(`" `) 197 | bld.WriteString(strconv.Itoa(p.statusCode)) 198 | bld.WriteString(" ") 199 | bld.WriteString(strconv.Itoa(p.respSize)) 200 | 201 | bld.WriteString(` "`) 202 | bld.WriteString(r.Header.Get("Referer")) 203 | bld.WriteString(`" "`) 204 | bld.WriteString(r.Header.Get("User-Agent")) 205 | bld.WriteString(`"`) 206 | return bld.String() 207 | } 208 | 209 | var reMultWhtsp = regexp.MustCompile(`[\s\p{Zs}]{2,}`) 210 | 211 | func (l *Middleware) getBody(r *http.Request) string { 212 | if !l.logBody { 213 | return "" 214 | } 215 | 216 | reader, body, hasMore, err := peek(r.Body, int64(l.maxBodySize)) 217 | if err != nil { 218 | return "" 219 | } 220 | 221 | // "The Server will close the request body. The ServeHTTP Handler does not need to." 222 | // https://golang.org/pkg/net/http/#Request 223 | // So we can use ioutil.NopCloser() to make io.ReadCloser. 224 | // Note that below assignment is not approved by the docs: 225 | // "Except for reading the body, handlers should not modify the provided Request." 226 | // https://golang.org/pkg/net/http/#Handler 227 | r.Body = io.NopCloser(reader) 228 | 229 | if body != "" { 230 | body = strings.ReplaceAll(body, "\n", " ") 231 | body = reMultWhtsp.ReplaceAllString(body, " ") 232 | } 233 | 234 | if hasMore { 235 | body += "..." 
236 | } 237 | 238 | return body 239 | } 240 | 241 | // peek the first n bytes as string 242 | func peek(r io.Reader, n int64) (reader io.Reader, s string, hasMore bool, err error) { 243 | if n < 0 { 244 | n = 0 245 | } 246 | 247 | buf := new(bytes.Buffer) 248 | _, err = io.CopyN(buf, r, n+1) 249 | if err == io.EOF { 250 | str := buf.String() 251 | return buf, str, false, nil 252 | } 253 | if err != nil { 254 | return r, "", false, err 255 | } 256 | 257 | // one extra byte is successfully read 258 | s = buf.String() 259 | s = s[:len(s)-1] 260 | 261 | return io.MultiReader(buf, r), s, true, nil 262 | } 263 | 264 | var keysToHide = []string{"password", "passwd", "secret", "credentials", "token"} 265 | 266 | // Hide query values for keysToHide. May change order of query params. 267 | // May escape unescaped query params. 268 | func (l *Middleware) sanitizeQuery(rawQuery string) string { 269 | // note that we skip non-nil error further 270 | query, err := url.ParseQuery(rawQuery) 271 | 272 | isHidden := func(key string) bool { 273 | for _, k := range keysToHide { 274 | if strings.EqualFold(k, key) { 275 | return true 276 | } 277 | } 278 | return false 279 | } 280 | 281 | present := false 282 | for key, values := range query { 283 | if isHidden(key) { 284 | present = true 285 | for i := range values { 286 | values[i] = "********" 287 | } 288 | } 289 | } 290 | 291 | // short circuit 292 | if (err == nil) && !present { 293 | return rawQuery 294 | } 295 | 296 | return query.Encode() 297 | } 298 | 299 | // AnonymizeIP is a function to reset the last part of IPv4 to 0. 
300 | // from 123.212.12.78 it will make 123.212.12.0 301 | func AnonymizeIP(ip string) string { 302 | if ip == "" { 303 | return "" 304 | } 305 | 306 | parts := strings.Split(ip, ".") 307 | if len(parts) != 4 { 308 | return ip 309 | } 310 | 311 | return strings.Join(parts[:3], ".") + ".0" 312 | } 313 | 314 | // customResponseWriter is an HTTP response logger that keeps HTTP status code and 315 | // the number of bytes written. 316 | // It implements http.ResponseWriter, http.Flusher and http.Hijacker. 317 | // Note that type assertion from http.ResponseWriter(customResponseWriter) to 318 | // http.Flusher and http.Hijacker is always succeed but underlying http.ResponseWriter 319 | // may not implement them. 320 | type customResponseWriter struct { 321 | http.ResponseWriter 322 | status int 323 | size int 324 | } 325 | 326 | func newCustomResponseWriter(w http.ResponseWriter) *customResponseWriter { 327 | return &customResponseWriter{ 328 | ResponseWriter: w, 329 | status: 200, 330 | } 331 | } 332 | 333 | // WriteHeader implements http.ResponseWriter and saves status 334 | func (c *customResponseWriter) WriteHeader(status int) { 335 | c.status = status 336 | c.ResponseWriter.WriteHeader(status) 337 | } 338 | 339 | // Write implements http.ResponseWriter and tracks number of bytes written 340 | func (c *customResponseWriter) Write(b []byte) (int, error) { 341 | size, err := c.ResponseWriter.Write(b) 342 | c.size += size 343 | return size, err 344 | } 345 | 346 | // Flush implements http.Flusher 347 | func (c *customResponseWriter) Flush() { 348 | if f, ok := c.ResponseWriter.(http.Flusher); ok { 349 | f.Flush() 350 | } 351 | } 352 | 353 | // Hijack implements http.Hijacker 354 | func (c *customResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { 355 | if hj, ok := c.ResponseWriter.(http.Hijacker); ok { 356 | return hj.Hijack() 357 | } 358 | return nil, nil, fmt.Errorf("ResponseWriter does not implement the Hijacker interface") //nolint:golint //capital 
letter is OK here 359 | } 360 | -------------------------------------------------------------------------------- /logger/logger_test.go: -------------------------------------------------------------------------------- 1 | package logger 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | "fmt" 7 | "io" 8 | "net/http" 9 | "net/http/httptest" 10 | "net/url" 11 | "strconv" 12 | "strings" 13 | "testing" 14 | "time" 15 | 16 | "github.com/stretchr/testify/assert" 17 | "github.com/stretchr/testify/require" 18 | ) 19 | 20 | func TestLoggerMinimal(t *testing.T) { 21 | 22 | handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 23 | _, err := w.Write([]byte("blah blah")) 24 | require.NoError(t, err) 25 | }) 26 | 27 | lb := &mockLgr{} 28 | l := New(Prefix("[INFO] REST"), Log(lb)) 29 | 30 | ts := httptest.NewServer(l.Handler(handler)) 31 | defer ts.Close() 32 | 33 | resp, err := http.Post(ts.URL+"/blah", "", bytes.NewBufferString("1234567890 abcdefg")) 34 | require.Nil(t, err) 35 | defer resp.Body.Close() // nolint 36 | assert.Equal(t, 200, resp.StatusCode) 37 | b, err := io.ReadAll(resp.Body) 38 | assert.NoError(t, err) 39 | assert.Equal(t, "blah blah", string(b)) 40 | 41 | s := lb.buf.String() 42 | t.Log(s) 43 | prefix := "[INFO] REST POST - /blah - 127.0.0.1 - 127.0.0.1 - 200 (9) - " 44 | assert.True(t, strings.HasPrefix(s, prefix), s) 45 | _, err = time.ParseDuration(s[len(prefix):]) 46 | assert.NoError(t, err) 47 | 48 | } 49 | 50 | func TestLoggerMinimalLocalhost(t *testing.T) { 51 | 52 | handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 53 | _, err := w.Write([]byte("blah blah")) 54 | require.NoError(t, err) 55 | }) 56 | 57 | lb := &mockLgr{} 58 | l := New(Prefix("[INFO] REST"), Log(lb)) 59 | 60 | ts := httptest.NewServer(l.Handler(handler)) 61 | defer ts.Close() 62 | 63 | port := strings.Split(ts.URL, ":")[2] 64 | resp, err := http.Post("http://localhost:"+port+"/blah", "", bytes.NewBufferString("1234567890 abcdefg")) 65 | 
require.Nil(t, err) 66 | defer resp.Body.Close() // nolint 67 | assert.Equal(t, 200, resp.StatusCode) 68 | b, err := io.ReadAll(resp.Body) 69 | assert.NoError(t, err) 70 | assert.Equal(t, "blah blah", string(b)) 71 | 72 | s := lb.buf.String() 73 | t.Log(s) 74 | prefix := "[INFO] REST POST - /blah - localhost - 127.0.0.1 - 200 (9) - " 75 | assert.True(t, strings.HasPrefix(s, prefix), s) 76 | _, err = time.ParseDuration(s[len(prefix):]) 77 | assert.NoError(t, err) 78 | } 79 | 80 | func TestLogger(t *testing.T) { 81 | 82 | handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 83 | _, err := w.Write([]byte("blah blah")) 84 | require.NoError(t, err) 85 | }) 86 | 87 | lb := &mockLgr{} 88 | l := New(Prefix("[INFO] REST"), WithBody, 89 | Log(lb), 90 | IPfn(func(ip string) string { 91 | return ip + "!masked" 92 | }), 93 | UserFn(func(*http.Request) (string, error) { 94 | return "user", nil 95 | }), 96 | SubjFn(func(*http.Request) (string, error) { 97 | return "subj", nil 98 | }), 99 | ) 100 | 101 | ts := httptest.NewServer(l.Handler(handler)) 102 | defer ts.Close() 103 | 104 | resp, err := http.Post(ts.URL+"/blah?password=secret&key=val&var=123", "", bytes.NewBufferString("1234567890 abcdefg")) 105 | require.Nil(t, err) 106 | assert.Equal(t, 200, resp.StatusCode) 107 | defer resp.Body.Close() // nolint 108 | b, err := io.ReadAll(resp.Body) 109 | assert.NoError(t, err) 110 | assert.Equal(t, "blah blah", string(b)) 111 | 112 | s := lb.buf.String() 113 | t.Log(s) 114 | assert.True(t, strings.Contains(s, "[INFO] REST POST - /blah?key=val&password=********&var=123 - 127.0.0.1 - 127.0.0.1!masked - 200 (9) -"), s) 115 | assert.True(t, strings.HasSuffix(s, "- user - subj - 1234567890 abcdefg")) 116 | } 117 | 118 | func TestLoggerIP(t *testing.T) { 119 | handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 120 | _, err := w.Write([]byte("blah blah")) 121 | require.NoError(t, err) 122 | }) 123 | 124 | lb := &mockLgr{} 125 | l := New(Log(lb), 
Prefix("[INFO] REST"), IPfn(func(ip string) string { return ip + "!masked" })) 126 | 127 | ts := httptest.NewServer(l.Handler(handler)) 128 | defer ts.Close() 129 | 130 | clint := http.Client{} 131 | 132 | req, err := http.NewRequest("GET", ts.URL+"/blah", http.NoBody) 133 | require.NoError(t, err) 134 | resp, err := clint.Do(req) 135 | require.Nil(t, err) 136 | assert.Equal(t, 200, resp.StatusCode) 137 | s := lb.buf.String() 138 | t.Log(s) 139 | assert.True(t, strings.Contains(s, "- 127.0.0.1!masked -")) 140 | 141 | lb.buf.Reset() 142 | req, err = http.NewRequest("GET", ts.URL+"/blah", http.NoBody) 143 | require.NoError(t, err) 144 | req.Header.Set("X-Forwarded-For", "1.2.3.4") 145 | resp, err = clint.Do(req) 146 | require.Nil(t, err) 147 | assert.Equal(t, 200, resp.StatusCode) 148 | s = lb.buf.String() 149 | t.Log(s) 150 | assert.True(t, strings.Contains(s, "- 1.2.3.4!masked -")) 151 | } 152 | 153 | func TestLoggerIPAnon(t *testing.T) { 154 | handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 155 | _, err := w.Write([]byte("blah blah")) 156 | require.NoError(t, err) 157 | }) 158 | 159 | lb := &mockLgr{} 160 | l := New(Log(lb), Prefix("[INFO] REST"), IPfn(AnonymizeIP)) 161 | 162 | ts := httptest.NewServer(l.Handler(handler)) 163 | defer ts.Close() 164 | 165 | clint := http.Client{} 166 | 167 | req, err := http.NewRequest("GET", ts.URL+"/blah", http.NoBody) 168 | require.NoError(t, err) 169 | resp, err := clint.Do(req) 170 | require.Nil(t, err) 171 | assert.Equal(t, 200, resp.StatusCode) 172 | s := lb.buf.String() 173 | t.Log(s) 174 | assert.True(t, strings.Contains(s, "- 127.0.0.0 -"), s) 175 | 176 | lb.buf.Reset() 177 | req, err = http.NewRequest("GET", ts.URL+"/blah", http.NoBody) 178 | require.NoError(t, err) 179 | req.Header.Set("X-Forwarded-For", "1.2.3.4") 180 | resp, err = clint.Do(req) 181 | require.Nil(t, err) 182 | assert.Equal(t, 200, resp.StatusCode) 183 | s = lb.buf.String() 184 | t.Log(s) 185 | assert.True(t, 
strings.Contains(s, "- 1.2.3.0 -"), s) 186 | } 187 | 188 | func TestLoggerTraceID(t *testing.T) { 189 | handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 190 | _, err := w.Write([]byte("blah blah")) 191 | require.NoError(t, err) 192 | }) 193 | 194 | lb := &mockLgr{} 195 | l := New(Prefix("[INFO] REST"), WithBody, 196 | Log(lb), 197 | IPfn(func(ip string) string { 198 | return ip + "!masked" 199 | }), 200 | UserFn(func(*http.Request) (string, error) { 201 | return "user", nil 202 | }), 203 | SubjFn(func(*http.Request) (string, error) { 204 | return "subj", nil 205 | }), 206 | ) 207 | 208 | ts := httptest.NewServer(l.Handler(handler)) 209 | defer ts.Close() 210 | 211 | clint := http.Client{} 212 | req, err := http.NewRequest("GET", ts.URL+"/blah", http.NoBody) 213 | require.NoError(t, err) 214 | req.Header.Set("X-Request-ID", "0000-reqid") 215 | resp, err := clint.Do(req) 216 | require.Nil(t, err) 217 | assert.Equal(t, 200, resp.StatusCode) 218 | 219 | s := lb.buf.String() 220 | t.Log(s) 221 | assert.True(t, strings.HasSuffix(s, "- user - subj - 0000-reqid")) 222 | 223 | req, err = http.NewRequest("POST", ts.URL+"/blah", bytes.NewBufferString("1234567890 abcdefg")) 224 | require.NoError(t, err) 225 | req.Header.Set("X-Request-ID", "11111-reqid") 226 | resp, err = clint.Do(req) 227 | require.Nil(t, err) 228 | assert.Equal(t, 200, resp.StatusCode) 229 | s = lb.buf.String() 230 | t.Log(s) 231 | assert.True(t, strings.HasSuffix(s, "- user - subj - 11111-reqid - 1234567890 abcdefg")) 232 | } 233 | 234 | func TestLoggerMaxBodySize(t *testing.T) { 235 | 236 | handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 237 | body, err := io.ReadAll(r.Body) 238 | assert.NoError(t, err) 239 | assert.Equal(t, "1234567890 abcdefg", string(body)) 240 | _, err = w.Write([]byte("blah blah")) 241 | require.NoError(t, err) 242 | }) 243 | 244 | lb := &mockLgr{} 245 | l := New(Prefix("[INFO] REST"), WithBody, Log(lb), MaxBodySize(10)) 246 | 247 | 
ts := httptest.NewServer(l.Handler(handler)) 248 | defer ts.Close() 249 | 250 | resp, err := http.Post(ts.URL+"/blah", "", bytes.NewBufferString("1234567890 abcdefg")) 251 | require.Nil(t, err) 252 | assert.Equal(t, 200, resp.StatusCode) 253 | defer resp.Body.Close() 254 | b, err := io.ReadAll(resp.Body) 255 | assert.NoError(t, err) 256 | assert.Equal(t, "blah blah", string(b)) 257 | 258 | s := lb.buf.String() 259 | t.Log(s) 260 | assert.True(t, strings.Contains(s, "[INFO] REST POST - /blah - 127.0.0.1 - 127.0.0.1 - 200 (9) -"), s) 261 | assert.True(t, strings.Contains(s, "1234567890..."), s) 262 | } 263 | 264 | func TestLoggerDefault(t *testing.T) { 265 | handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 266 | _, err := w.Write([]byte("blah blah")) 267 | require.NoError(t, err) 268 | }) 269 | 270 | ts := httptest.NewServer(Logger(handler)) 271 | defer ts.Close() 272 | 273 | resp, err := http.Get(ts.URL + "/blah") 274 | require.Nil(t, err) 275 | assert.Equal(t, 200, resp.StatusCode) 276 | defer resp.Body.Close() 277 | b, err := io.ReadAll(resp.Body) 278 | assert.NoError(t, err) 279 | assert.Equal(t, "blah blah", string(b)) 280 | } 281 | 282 | type mockLgr struct { 283 | buf bytes.Buffer 284 | } 285 | 286 | func (m *mockLgr) Logf(format string, args ...interface{}) { 287 | _, _ = m.buf.WriteString(fmt.Sprintf(format, args...)) 288 | } 289 | 290 | func TestGetBody(t *testing.T) { 291 | req, err := http.NewRequest("GET", "http://example.com/request", strings.NewReader("body")) 292 | require.Nil(t, err) 293 | 294 | l := New() 295 | body := l.getBody(req) 296 | assert.Equal(t, "", body) 297 | 298 | l = New(WithBody) 299 | body = l.getBody(req) 300 | assert.Equal(t, "body", body) 301 | } 302 | 303 | func TestPeek(t *testing.T) { 304 | cases := []struct { 305 | body string 306 | n int64 307 | excerpt string 308 | hasMore bool 309 | }{ 310 | {"", -1, "", false}, 311 | {"", 0, "", false}, 312 | {"", 1024, "", false}, 313 | {"123456", -1, "", true}, 
314 | {"123456", 0, "", true}, 315 | {"123456", 4, "1234", true}, 316 | {"123456", 5, "12345", true}, 317 | {"123456", 6, "123456", false}, 318 | {"123456", 7, "123456", false}, 319 | } 320 | 321 | for _, c := range cases { 322 | r, excerpt, hasMore, err := peek(strings.NewReader(c.body), c.n) 323 | if !assert.NoError(t, err) { 324 | continue 325 | } 326 | body, err := io.ReadAll(r) 327 | if !assert.NoError(t, err) { 328 | continue 329 | } 330 | assert.Equal(t, c.body, string(body)) 331 | assert.Equal(t, c.excerpt, excerpt) 332 | assert.Equal(t, c.hasMore, hasMore) 333 | } 334 | 335 | _, _, _, err := peek(errReader{}, 1024) 336 | assert.Error(t, err) 337 | } 338 | 339 | type errReader struct{} 340 | 341 | func (errReader) Read(_ []byte) (n int, err error) { 342 | return 0, errors.New("test error") 343 | } 344 | 345 | func TestSanitizeReqURL(t *testing.T) { 346 | tbl := []struct { 347 | in string 348 | out string 349 | }{ 350 | {"", ""}, 351 | {"xyz=123", "xyz=123"}, 352 | {"foo=bar&foo=buzz", "foo=bar&foo=buzz"}, 353 | {"foo=%2&password=1234", "password=********"}, 354 | {"xyz=123&seCret=asdfghjk", "seCret=********&xyz=123"}, 355 | {"xyz=123&secret=asdfghjk&key=val", "key=val&secret=********&xyz=123"}, 356 | {"xyz=123&secret=asdfghjk&key=val&password=1234", "key=val&password=********&secret=********&xyz=123"}, 357 | {"xyz=тест&passwoRD=1234", "passwoRD=********&xyz=тест"}, 358 | {"xyz=тест&password=1234&bar=buzz", "bar=buzz&password=********&xyz=тест"}, 359 | {"xyz=тест&password=пароль&bar=buzz", "bar=buzz&password=********&xyz=тест"}, 360 | {"xyz=тест&password=пароль&bar=buzz&q=?sss?ccc", "bar=buzz&password=********&q=?sss?ccc&xyz=тест"}, 361 | } 362 | unesc := func(s string) string { 363 | s, _ = url.QueryUnescape(s) 364 | return s 365 | } 366 | var l *Middleware 367 | for i, tt := range tbl { 368 | i := i 369 | tt := tt 370 | t.Run(tt.in, func(t *testing.T) { 371 | assert.Equal(t, tt.out, unesc(l.sanitizeQuery(tt.in)), "check #%d, %s", i, tt.in) 372 | }) 373 | } 
374 | } 375 | 376 | func TestLoggerApacheCombined(t *testing.T) { 377 | 378 | handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 379 | _, err := w.Write([]byte("blah blah")) 380 | require.NoError(t, err) 381 | }) 382 | 383 | lb := &mockLgr{} 384 | l := New(Log(lb), ApacheCombined, 385 | IPfn(func(ip string) string { 386 | return ip + "!masked" 387 | }), 388 | UserFn(func(*http.Request) (string, error) { 389 | return "user", nil 390 | }), 391 | ) 392 | 393 | ts := httptest.NewServer(l.Handler(handler)) 394 | defer ts.Close() 395 | 396 | resp, err := http.Post(ts.URL+"/blah?password=secret&key=val&var=123", "", bytes.NewBufferString("1234567890 abcdefg")) 397 | require.Nil(t, err) 398 | assert.Equal(t, 200, resp.StatusCode) 399 | defer resp.Body.Close() 400 | b, err := io.ReadAll(resp.Body) 401 | assert.NoError(t, err) 402 | assert.Equal(t, "blah blah", string(b)) 403 | 404 | s := lb.buf.String() 405 | t.Log(s) 406 | assert.True(t, strings.HasPrefix(s, "127.0.0.1!masked - user [")) 407 | assert.True(t, strings.HasSuffix(s, ` "POST /blah?key=val&password=********&var=123" HTTP/1.1" 200 9 "" "Go-http-client/1.1"`), s) 408 | } 409 | 410 | func TestAnonymizeIP(t *testing.T) { 411 | tbl := []struct { 412 | inp, out string 413 | }{ 414 | {"12.34.56.78", "12.34.56.0"}, 415 | {"", ""}, 416 | {"", ""}, 417 | {"12.34.56", "12.34.56"}, 418 | } 419 | 420 | for i, tt := range tbl { 421 | t.Run(strconv.Itoa(i), func(t *testing.T) { 422 | assert.Equal(t, tt.out, AnonymizeIP(tt.inp)) 423 | }) 424 | } 425 | } 426 | -------------------------------------------------------------------------------- /logger/options.go: -------------------------------------------------------------------------------- 1 | package logger 2 | 3 | import ( 4 | "net/http" 5 | ) 6 | 7 | // Option func type 8 | type Option func(l *Middleware) 9 | 10 | // WithBody triggers request body logging. 
Body size is limited (default 1k) 11 | func WithBody(l *Middleware) { 12 | l.logBody = true 13 | } 14 | 15 | // MaxBodySize sets size of the logged part of the request body. 16 | func MaxBodySize(maximum int) Option { 17 | return func(l *Middleware) { 18 | if maximum >= 0 { 19 | l.maxBodySize = maximum 20 | } 21 | } 22 | } 23 | 24 | // Prefix sets log line prefix. 25 | func Prefix(prefix string) Option { 26 | return func(l *Middleware) { 27 | l.prefix = prefix 28 | } 29 | } 30 | 31 | // IPfn sets IP masking function. If ipFn is nil then IP address will be logged as is. 32 | func IPfn(ipFn func(ip string) string) Option { 33 | return func(l *Middleware) { 34 | l.ipFn = ipFn 35 | } 36 | } 37 | 38 | // UserFn triggers user name logging if userFn is not nil. 39 | func UserFn(userFn func(r *http.Request) (string, error)) Option { 40 | return func(l *Middleware) { 41 | l.userFn = userFn 42 | } 43 | } 44 | 45 | // SubjFn triggers subject logging if subjFn is not nil. 46 | func SubjFn(subjFn func(r *http.Request) (string, error)) Option { 47 | return func(l *Middleware) { 48 | l.subjFn = subjFn 49 | } 50 | } 51 | 52 | // ApacheCombined sets format to Apache Combined Log. 53 | // See http://httpd.apache.org/docs/2.2/logs.html#combined 54 | func ApacheCombined(l *Middleware) { 55 | l.apacheCombined = true 56 | } 57 | 58 | // Log sets logging backend. 
59 | func Log(log Backend) Option { 60 | return func(l *Middleware) { 61 | l.log = log 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /metrics.go: -------------------------------------------------------------------------------- 1 | package rest 2 | 3 | import ( 4 | "expvar" 5 | "fmt" 6 | "net/http" 7 | "strings" 8 | ) 9 | 10 | // Metrics responds to GET /metrics with list of expvar 11 | func Metrics(onlyIps ...string) func(http.Handler) http.Handler { 12 | return func(h http.Handler) http.Handler { 13 | fn := func(w http.ResponseWriter, r *http.Request) { 14 | if r.Method == "GET" && strings.HasSuffix(strings.ToLower(r.URL.Path), "/metrics") { 15 | if matched, ip, err := matchSourceIP(r, onlyIps); !matched || err != nil { 16 | w.WriteHeader(http.StatusForbidden) 17 | RenderJSON(w, JSON{"error": fmt.Sprintf("ip %s rejected", ip)}) 18 | return 19 | } 20 | expvar.Handler().ServeHTTP(w, r) 21 | return 22 | } 23 | h.ServeHTTP(w, r) 24 | } 25 | 26 | return http.HandlerFunc(fn) 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /metrics_test.go: -------------------------------------------------------------------------------- 1 | package rest 2 | 3 | import ( 4 | "io" 5 | "net/http" 6 | "net/http/httptest" 7 | "strings" 8 | "testing" 9 | 10 | "github.com/stretchr/testify/assert" 11 | "github.com/stretchr/testify/require" 12 | ) 13 | 14 | func TestMetrics(t *testing.T) { 15 | handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 16 | _, err := w.Write([]byte("blah blah")) 17 | require.NoError(t, err) 18 | }) 19 | ts := httptest.NewServer(Metrics("127.0.0.1")(handler)) 20 | defer ts.Close() 21 | 22 | resp, err := http.Get(ts.URL + "/metrics") 23 | require.Nil(t, err) 24 | defer resp.Body.Close() 25 | assert.Equal(t, 200, resp.StatusCode) 26 | 27 | b, err := io.ReadAll(resp.Body) 28 | assert.NoError(t, err) 29 | assert.True(t, strings.Contains(string(b), 
"cmdline")) 30 | assert.True(t, strings.Contains(string(b), "memstats")) 31 | } 32 | 33 | func TestMetricsRejected(t *testing.T) { 34 | handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 35 | _, err := w.Write([]byte("blah blah")) 36 | require.NoError(t, err) 37 | }) 38 | ts := httptest.NewServer(Metrics("1.1.1.1")(handler)) 39 | defer ts.Close() 40 | 41 | resp, err := http.Get(ts.URL + "/metrics") 42 | require.Nil(t, err) 43 | defer resp.Body.Close() 44 | assert.Equal(t, 403, resp.StatusCode) 45 | } 46 | -------------------------------------------------------------------------------- /middleware.go: -------------------------------------------------------------------------------- 1 | package rest 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | "os" 7 | "runtime/debug" 8 | "strings" 9 | 10 | "github.com/go-pkgz/rest/logger" 11 | "github.com/go-pkgz/rest/realip" 12 | ) 13 | 14 | // Wrap converts a list of middlewares to nested calls (in reverse order) 15 | func Wrap(handler http.Handler, mws ...func(http.Handler) http.Handler) http.Handler { 16 | for i := len(mws) - 1; i >= 0; i-- { 17 | handler = mws[i](handler) 18 | } 19 | return handler 20 | } 21 | 22 | 23 | // AppInfo adds custom app-info to the response header 24 | func AppInfo(app, author, version string) func(http.Handler) http.Handler { 25 | f := func(h http.Handler) http.Handler { 26 | fn := func(w http.ResponseWriter, r *http.Request) { 27 | w.Header().Set("Author", author) 28 | w.Header().Set("App-Name", app) 29 | w.Header().Set("App-Version", version) 30 | if mhost := os.Getenv("MHOST"); mhost != "" { 31 | w.Header().Set("Host", mhost) 32 | } 33 | h.ServeHTTP(w, r) 34 | } 35 | return http.HandlerFunc(fn) 36 | } 37 | return f 38 | } 39 | 40 | // Ping middleware response with pong to /ping. 
Stops chain if ping request detected 41 | func Ping(next http.Handler) http.Handler { 42 | fn := func(w http.ResponseWriter, r *http.Request) { 43 | 44 | if r.Method == "GET" && strings.HasSuffix(strings.ToLower(r.URL.Path), "/ping") { 45 | w.Header().Set("Content-Type", "text/plain") 46 | w.WriteHeader(http.StatusOK) 47 | _, _ = w.Write([]byte("pong")) 48 | return 49 | } 50 | next.ServeHTTP(w, r) 51 | } 52 | return http.HandlerFunc(fn) 53 | } 54 | 55 | // Health middleware response with health info and status (200 if healthy). Stops chain if health request detected 56 | // passed checkers implements custom health checks and returns error if health check failed. The check has to return name 57 | // regardless to the error state. 58 | // For production usage this middleware should be used with throttler and, optionally, with BasicAuth middlewares 59 | func Health(path string, checkers ...func(ctx context.Context) (name string, err error)) func(http.Handler) http.Handler { 60 | 61 | type hr struct { 62 | Name string `json:"name"` 63 | Status string `json:"status"` 64 | Error string `json:"error,omitempty"` 65 | } 66 | 67 | return func(h http.Handler) http.Handler { 68 | fn := func(w http.ResponseWriter, r *http.Request) { 69 | if r.Method != "GET" || !strings.EqualFold(r.URL.Path, path) { 70 | h.ServeHTTP(w, r) // not the health check request, continue the chain 71 | return 72 | } 73 | resp := []hr{} 74 | var anyError bool 75 | for _, check := range checkers { 76 | name, err := check(r.Context()) 77 | hh := hr{Name: name, Status: "ok"} 78 | if err != nil { 79 | hh.Status = "failed" 80 | hh.Error = err.Error() 81 | anyError = true 82 | } 83 | resp = append(resp, hh) 84 | } 85 | if anyError { 86 | w.WriteHeader(http.StatusServiceUnavailable) 87 | } else { 88 | w.WriteHeader(http.StatusOK) 89 | } 90 | RenderJSON(w, resp) 91 | } 92 | return http.HandlerFunc(fn) 93 | } 94 | } 95 | 96 | // Recoverer is a middleware that recovers from panics, logs the panic and returns a 
HTTP 500 status if possible. 97 | func Recoverer(l logger.Backend) func(http.Handler) http.Handler { 98 | return func(h http.Handler) http.Handler { 99 | fn := func(w http.ResponseWriter, r *http.Request) { 100 | defer func() { 101 | if rvr := recover(); rvr != nil { 102 | l.Logf("request panic for %s from %s, %v", r.URL.String(), r.RemoteAddr, rvr) 103 | if rvr != http.ErrAbortHandler { 104 | l.Logf(string(debug.Stack())) 105 | } 106 | http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) 107 | } 108 | }() 109 | h.ServeHTTP(w, r) 110 | } 111 | return http.HandlerFunc(fn) 112 | } 113 | } 114 | 115 | // Headers middleware adds headers to request 116 | func Headers(headers ...string) func(http.Handler) http.Handler { 117 | 118 | return func(h http.Handler) http.Handler { 119 | 120 | fn := func(w http.ResponseWriter, r *http.Request) { 121 | for _, h := range headers { 122 | elems := strings.Split(h, ":") 123 | if len(elems) != 2 { 124 | continue 125 | } 126 | r.Header.Set(strings.TrimSpace(elems[0]), strings.TrimSpace(elems[1])) 127 | } 128 | h.ServeHTTP(w, r) 129 | } 130 | return http.HandlerFunc(fn) 131 | } 132 | } 133 | 134 | // Maybe middleware will allow you to change the flow of the middleware stack execution depending on return 135 | // value of maybeFn(request). This is useful for example if you'd like to skip a middleware handler if 136 | // a request does not satisfy the maybeFn logic. 
137 | // borrowed from https://github.com/go-chi/chi/blob/master/middleware/maybe.go 138 | func Maybe(mw func(http.Handler) http.Handler, maybeFn func(r *http.Request) bool) func(http.Handler) http.Handler { 139 | return func(next http.Handler) http.Handler { 140 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 141 | if maybeFn(r) { 142 | mw(next).ServeHTTP(w, r) 143 | } else { 144 | next.ServeHTTP(w, r) 145 | } 146 | }) 147 | } 148 | } 149 | 150 | // RealIP is a middleware that sets a http.Request's RemoteAddr to the results 151 | // of parsing either the X-Forwarded-For or X-Real-IP headers. 152 | // 153 | // This middleware should only be used if user can trust the headers sent with request. 154 | // If reverse proxies are configured to pass along arbitrary header values from the client, 155 | // or if this middleware used without a reverse proxy, malicious clients could set anything 156 | // as X-Forwarded-For header and attack the server in various ways. 157 | func RealIP(h http.Handler) http.Handler { 158 | fn := func(w http.ResponseWriter, r *http.Request) { 159 | if rip, err := realip.Get(r); err == nil { 160 | r.RemoteAddr = rip 161 | } 162 | h.ServeHTTP(w, r) 163 | } 164 | 165 | return http.HandlerFunc(fn) 166 | } 167 | 168 | // Reject is a middleware that conditionally rejects requests with a given status code and message. 169 | // user-defined condition function rejectFn is used to determine if the request should be rejected. 
170 | func Reject(errCode int, errMsg string, rejectFn func(r *http.Request) bool) func(h http.Handler) http.Handler { 171 | return func(h http.Handler) http.Handler { 172 | fn := func(w http.ResponseWriter, r *http.Request) { 173 | if rejectFn(r) { 174 | http.Error(w, errMsg, errCode) 175 | return 176 | } 177 | h.ServeHTTP(w, r) 178 | } 179 | return http.HandlerFunc(fn) 180 | } 181 | } 182 | -------------------------------------------------------------------------------- /middleware_test.go: -------------------------------------------------------------------------------- 1 | package rest 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "errors" 7 | "fmt" 8 | "io" 9 | "net/http" 10 | "net/http/httptest" 11 | "os" 12 | "sync/atomic" 13 | "testing" 14 | "time" 15 | 16 | "github.com/stretchr/testify/assert" 17 | "github.com/stretchr/testify/require" 18 | 19 | "github.com/go-pkgz/rest/realip" 20 | ) 21 | 22 | func TestMiddleware_AppInfo(t *testing.T) { 23 | err := os.Setenv("MHOST", "host1") 24 | assert.NoError(t, err) 25 | 26 | handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 27 | _, err = w.Write([]byte("blah blah")) 28 | require.NoError(t, err) 29 | }) 30 | ts := httptest.NewServer(AppInfo("app-name", "Umputun", "12345")(handler)) 31 | defer ts.Close() 32 | 33 | resp, err := http.Get(ts.URL + "/blah") 34 | require.Nil(t, err) 35 | assert.Equal(t, 200, resp.StatusCode) 36 | defer resp.Body.Close() 37 | 38 | b, err := io.ReadAll(resp.Body) 39 | assert.NoError(t, err) 40 | 41 | assert.Equal(t, "blah blah", string(b)) 42 | assert.Equal(t, "app-name", resp.Header.Get("App-Name")) 43 | assert.Equal(t, "12345", resp.Header.Get("App-Version")) 44 | assert.Equal(t, "Umputun", resp.Header.Get("Author")) 45 | assert.Equal(t, "host1", resp.Header.Get("Host")) 46 | } 47 | 48 | func TestMiddleware_Ping(t *testing.T) { 49 | 50 | handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 51 | _, err := w.Write([]byte("blah blah")) 52 | 
require.NoError(t, err) 53 | }) 54 | ts := httptest.NewServer(Ping(handler)) 55 | defer ts.Close() 56 | 57 | resp, err := http.Get(ts.URL + "/ping") 58 | require.Nil(t, err) 59 | assert.Equal(t, 200, resp.StatusCode) 60 | defer resp.Body.Close() 61 | b, err := io.ReadAll(resp.Body) 62 | assert.NoError(t, err) 63 | assert.Equal(t, "pong", string(b)) 64 | 65 | resp, err = http.Get(ts.URL + "/blah") 66 | require.Nil(t, err) 67 | assert.Equal(t, 200, resp.StatusCode) 68 | defer resp.Body.Close() 69 | b, err = io.ReadAll(resp.Body) 70 | assert.NoError(t, err) 71 | assert.Equal(t, "blah blah", string(b)) 72 | } 73 | 74 | func TestMiddleware_Recoverer(t *testing.T) { 75 | 76 | handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 77 | if r.URL.Path == "/failed" { 78 | panic("oh my!") 79 | } 80 | _, err := w.Write([]byte("blah blah")) 81 | require.NoError(t, err) 82 | }) 83 | l := &mockLgr{} 84 | ts := httptest.NewServer(Recoverer(l)(handler)) 85 | defer ts.Close() 86 | 87 | resp, err := http.Get(ts.URL + "/failed") 88 | s := l.buf.String() 89 | t.Log("->> ", s) 90 | require.NoError(t, err) 91 | assert.Equal(t, 500, resp.StatusCode) 92 | 93 | assert.Contains(t, s, "request panic for /failed from") 94 | assert.Contains(t, s, "oh my!") 95 | assert.Contains(t, s, "goroutine") 96 | assert.Contains(t, s, "github.com/go-pkgz/rest.TestMiddleware_Recoverer") 97 | 98 | resp, err = http.Get(ts.URL + "/blah") 99 | require.NoError(t, err) 100 | assert.Equal(t, 200, resp.StatusCode) 101 | defer resp.Body.Close() 102 | b, err := io.ReadAll(resp.Body) 103 | assert.NoError(t, err) 104 | assert.Equal(t, "blah blah", string(b)) 105 | } 106 | 107 | func TestWrap(t *testing.T) { 108 | handler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { 109 | t.Logf("%s", r.URL.String()) 110 | }) 111 | 112 | mw1 := func(h http.Handler) http.Handler { 113 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 114 | w.Header().Set("X-MW1", "1") 115 | 
h.ServeHTTP(w, r) 116 | }) 117 | } 118 | mw2 := func(h http.Handler) http.Handler { 119 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 120 | w.Header().Set("X-MW2", "2") 121 | h.ServeHTTP(w, r) 122 | }) 123 | } 124 | 125 | t.Run("no middleware", func(t *testing.T) { 126 | h := Wrap(handler) 127 | ts := httptest.NewServer(h) 128 | defer ts.Close() 129 | 130 | resp, err := http.Get(ts.URL + "/something") 131 | require.NoError(t, err) 132 | assert.Equal(t, 200, resp.StatusCode) 133 | assert.Equal(t, "", resp.Header.Get("X-MW1")) 134 | assert.Equal(t, "", resp.Header.Get("X-MW2")) 135 | }) 136 | 137 | t.Run("with middleware", func(t *testing.T) { 138 | h := Wrap(handler, mw1, mw2) 139 | ts := httptest.NewServer(h) 140 | defer ts.Close() 141 | 142 | resp, err := http.Get(ts.URL + "/something") 143 | require.NoError(t, err) 144 | assert.Equal(t, 200, resp.StatusCode) 145 | assert.Equal(t, "1", resp.Header.Get("X-MW1")) 146 | assert.Equal(t, "2", resp.Header.Get("X-MW2")) 147 | }) 148 | } 149 | 150 | func TestHeaders(t *testing.T) { 151 | req := httptest.NewRequest("GET", "/something", http.NoBody) 152 | w := httptest.NewRecorder() 153 | 154 | h := Headers("h1:v1", "bad", "h2:v2")(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {})) 155 | h.ServeHTTP(w, req) 156 | resp := w.Result() 157 | assert.Equal(t, http.StatusOK, resp.StatusCode) 158 | t.Logf("%+v", req.Header) 159 | assert.Equal(t, "v1", req.Header.Get("h1")) 160 | assert.Equal(t, "v2", req.Header.Get("h2")) 161 | assert.Equal(t, 2, len(req.Header)) 162 | } 163 | 164 | func TestMaybe(t *testing.T) { 165 | var count int32 166 | h := Maybe(Headers("h1:v1", "bad", "h2:v2"), func(*http.Request) bool { 167 | return atomic.AddInt32(&count, 1) == 1 168 | })(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {})) 169 | 170 | { 171 | req := httptest.NewRequest("GET", "/something", http.NoBody) 172 | w := httptest.NewRecorder() 173 | 174 | h.ServeHTTP(w, req) 175 | resp := 
w.Result() 176 | assert.Equal(t, http.StatusOK, resp.StatusCode) 177 | t.Logf("%+v", req.Header) 178 | assert.Equal(t, "v1", req.Header.Get("h1")) 179 | assert.Equal(t, "v2", req.Header.Get("h2")) 180 | assert.Equal(t, 2, len(req.Header)) 181 | } 182 | { 183 | req := httptest.NewRequest("GET", "/something", http.NoBody) 184 | w := httptest.NewRecorder() 185 | 186 | h.ServeHTTP(w, req) 187 | resp := w.Result() 188 | assert.Equal(t, http.StatusOK, resp.StatusCode) 189 | t.Logf("%+v", req.Header) 190 | assert.Equal(t, "", req.Header.Get("h1")) 191 | assert.Equal(t, "", req.Header.Get("h2")) 192 | assert.Equal(t, 0, len(req.Header)) 193 | } 194 | } 195 | 196 | func TestRealIP(t *testing.T) { 197 | 198 | handler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { 199 | t.Logf("%v", r) 200 | require.Equal(t, "1.2.3.4", r.RemoteAddr) 201 | adr, err := realip.Get(r) 202 | require.NoError(t, err) 203 | assert.Equal(t, "1.2.3.4", adr) 204 | }) 205 | 206 | ts := httptest.NewServer(RealIP(handler)) 207 | 208 | req, err := http.NewRequest("GET", ts.URL+"/something", http.NoBody) 209 | require.NoError(t, err) 210 | client := http.Client{Timeout: time.Second} 211 | req.Header.Add("X-Real-IP", "1.2.3.4") 212 | _, err = client.Do(req) 213 | require.NoError(t, err) 214 | } 215 | 216 | func TestHealthPassed(t *testing.T) { 217 | handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 218 | _, err := w.Write([]byte("blah blah")) 219 | require.NoError(t, err) 220 | }) 221 | 222 | check1 := func(context.Context) (string, error) { 223 | return "check1", nil 224 | } 225 | check2 := func(context.Context) (string, error) { 226 | return "check2", nil 227 | } 228 | 229 | ts := httptest.NewServer(Health("/health", check1, check2)(handler)) 230 | defer ts.Close() 231 | 232 | resp, err := http.Get(ts.URL + "/health") 233 | require.Nil(t, err) 234 | assert.Equal(t, 200, resp.StatusCode) 235 | defer resp.Body.Close() 236 | b, err := io.ReadAll(resp.Body) 237 | 
assert.NoError(t, err) 238 | assert.Equal(t, `[{"name":"check1","status":"ok"},{"name":"check2","status":"ok"}]`+"\n", string(b)) 239 | 240 | resp, err = http.Get(ts.URL + "/blah") 241 | require.Nil(t, err) 242 | assert.Equal(t, http.StatusOK, resp.StatusCode) 243 | defer resp.Body.Close() 244 | b, err = io.ReadAll(resp.Body) 245 | assert.NoError(t, err) 246 | assert.Equal(t, "blah blah", string(b)) 247 | } 248 | 249 | func TestHealthFailed(t *testing.T) { 250 | handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 251 | _, err := w.Write([]byte("blah blah")) 252 | require.NoError(t, err) 253 | }) 254 | 255 | check1 := func(context.Context) (string, error) { 256 | return "check1", nil 257 | } 258 | check2 := func(context.Context) (string, error) { 259 | return "check2", errors.New("some error") 260 | } 261 | 262 | ts := httptest.NewServer(Health("/health", check1, check2)(handler)) 263 | defer ts.Close() 264 | 265 | resp, err := http.Get(ts.URL + "/health") 266 | require.Nil(t, err) 267 | assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) 268 | defer resp.Body.Close() 269 | b, err := io.ReadAll(resp.Body) 270 | assert.NoError(t, err) 271 | assert.Equal(t, `[{"name":"check1","status":"ok"},{"name":"check2","status":"failed","error":"some error"}]`+"\n", string(b)) 272 | } 273 | 274 | func TestReject(t *testing.T) { 275 | handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 276 | _, err := w.Write([]byte("blah blah")) 277 | require.NoError(t, err) 278 | }) 279 | 280 | rej := Reject(http.StatusForbidden, "no no", func(r *http.Request) bool { 281 | return r.Header.Get("h1") == "v1" 282 | }) 283 | 284 | ts := httptest.NewServer(rej(handler)) 285 | defer ts.Close() 286 | 287 | client := http.Client{Timeout: time.Second} 288 | { // not rejected 289 | req, err := http.NewRequest("GET", ts.URL+"/something", http.NoBody) 290 | require.NoError(t, err) 291 | resp, err := client.Do(req) 292 | require.NoError(t, err) 293 | 
defer resp.Body.Close() 294 | assert.Equal(t, http.StatusOK, resp.StatusCode) 295 | b, err := io.ReadAll(resp.Body) 296 | assert.NoError(t, err) 297 | assert.Equal(t, `blah blah`, string(b)) 298 | } 299 | { // rejected 300 | req, err := http.NewRequest("GET", ts.URL+"/something", http.NoBody) 301 | req.Header.Add("h1", "v1") 302 | require.NoError(t, err) 303 | resp, err := client.Do(req) 304 | require.NoError(t, err) 305 | defer resp.Body.Close() 306 | assert.Equal(t, http.StatusForbidden, resp.StatusCode) 307 | b, err := io.ReadAll(resp.Body) 308 | assert.NoError(t, err) 309 | assert.Equal(t, "no no\n", string(b)) 310 | } 311 | } 312 | 313 | type mockLgr struct { 314 | buf bytes.Buffer 315 | } 316 | 317 | func (m *mockLgr) Logf(format string, args ...interface{}) { 318 | _, _ = m.buf.WriteString(fmt.Sprintf(format+"\n", args...)) 319 | } 320 | -------------------------------------------------------------------------------- /nocache.go: -------------------------------------------------------------------------------- 1 | package rest 2 | 3 | import ( 4 | "net/http" 5 | "time" 6 | ) 7 | 8 | // borrowed from https://github.com/go-chi/chi/blob/master/middleware/nocache.go 9 | 10 | var epoch = time.Unix(0, 0).UTC().Format(time.RFC1123) 11 | 12 | var noCacheHeaders = map[string]string{ 13 | "Expires": epoch, 14 | "Cache-Control": "no-cache, no-store, no-transform, must-revalidate, private, max-age=0", 15 | "Pragma": "no-cache", 16 | "X-Accel-Expires": "0", 17 | } 18 | 19 | var etagHeaders = []string{ 20 | "ETag", 21 | "If-Modified-Since", 22 | "If-Match", 23 | "If-None-Match", 24 | "If-Range", 25 | "If-Unmodified-Since", 26 | } 27 | 28 | // NoCache is a simple piece of middleware that sets a number of HTTP headers to prevent 29 | // a router (or subrouter) from being cached by an upstream proxy and/or client. 
30 | //
31 | // It sets these response headers (X-Accel-Expires targets nginx's proxy cache):
32 | //
33 | //	Expires: Thu, 01 Jan 1970 00:00:00 UTC
34 | //	Cache-Control: no-cache, no-store, no-transform, must-revalidate, private, max-age=0
35 | //	Pragma: no-cache (for HTTP/1.0 proxies/clients)
36 | //	X-Accel-Expires: 0
37 | //
38 | // and deletes the request headers listed in etagHeaders (ETag, If-None-Match,
39 | // If-Modified-Since, ...), so the wrapped handler always sees an unconditional request.
40 | func NoCache(h http.Handler) http.Handler {
41 | 	fn := func(w http.ResponseWriter, r *http.Request) {
42 | 		// Header.Del is a no-op for absent keys, so no presence check is needed
43 | 		for _, v := range etagHeaders {
44 | 			r.Header.Del(v)
45 | 		}
46 | 
47 | 		// set the no-cache response headers before the wrapped handler can write anything
48 | 		for k, v := range noCacheHeaders {
49 | 			w.Header().Set(k, v)
50 | 		}
51 | 
52 | 		h.ServeHTTP(w, r)
53 | 	}
54 | 
55 | 	return http.HandlerFunc(fn)
56 | }
57 | 
--------------------------------------------------------------------------------
/nocache_test.go:
--------------------------------------------------------------------------------
1 | package rest
2 | 
3 | import (
4 | 	"net/http"
5 | 	"net/http/httptest"
6 | 	"testing"
7 | 
8 | 	"github.com/stretchr/testify/assert"
9 | 	"github.com/stretchr/testify/require"
10 | )
11 | 
12 | func TestNoCache(t *testing.T) {
13 | 	rr := httptest.NewRecorder()
14 | 	testHandler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
15 | 		t.Logf("%+v", r)
16 | 		// NoCache must strip the validator headers before the wrapped handler runs
17 | 		assert.Empty(t, r.Header.Get("ETag"))
18 | 		assert.Empty(t, r.Header.Get("If-None-Match"))
19 | 	})
20 | 
21 | 	handler := NoCache(testHandler)
22 | 	req, err := http.NewRequest("GET", "/api/v1/params", http.NoBody)
23 | 	require.NoError(t, err)
24 | 	req.Header.Set("ETag", "123")
25 | 	req.Header.Set("If-None-Match", "xyz")
26 | 	handler.ServeHTTP(rr, req)
27 | 
28 | 	assert.Equal(t, "Thu, 01 Jan 1970 00:00:00 UTC", rr.Header().Get("Expires"))
29 | 	assert.Equal(t, "no-cache", rr.Header().Get("Pragma"))
30 | 	assert.Equal(t, "0", rr.Header().Get("X-Accel-Expires"))
31 | 	assert.Contains(t, rr.Header().Get("Cache-Control"), "no-store")
32 | }
33 | 
--------------------------------------------------------------------------------
/onlyfrom.go:
--------------------------------------------------------------------------------
1 | package rest
2 | 
3 | import (
4 | 	"fmt"
5 | 	"net"
6 | 	"net/http"
7 | 	"strings"
8 | 
9 | 	"github.com/go-pkgz/rest/realip"
10 | )
11 | 
12 | // OnlyFrom middleware allows access for a limited list of source IPs.
13 | // Each entry can be a complete ip (like 192.168.1.12), a CIDR (192.168.0.0/16) or a
14 | // string prefix (192.168.). A complete ip matches only that exact address, so "10.0.0.1"
15 | // does not admit "10.0.0.12"; use CIDR notation for IPv6 ranges.
16 | // With no entries every request passes. The source ip comes from realip.Get; rejected
17 | // requests get 403 and an undeterminable source ip gets 500, both with a JSON error body.
18 | func OnlyFrom(onlyIps ...string) func(http.Handler) http.Handler {
19 | 	return func(h http.Handler) http.Handler {
20 | 		fn := func(w http.ResponseWriter, r *http.Request) {
21 | 			if len(onlyIps) == 0 {
22 | 				// no restrictions if no ips defined
23 | 				h.ServeHTTP(w, r)
24 | 				return
25 | 			}
26 | 			matched, ip, err := matchSourceIP(r, onlyIps)
27 | 			if err != nil {
28 | 				// err is already prefixed with "can't get realip"; renderJSONWithStatus sets the
29 | 				// content type before the status, which WriteHeader followed by RenderJSON would not
30 | 				renderJSONWithStatus(w, JSON{"error": err.Error()}, http.StatusInternalServerError)
31 | 				return
32 | 			}
33 | 			if matched {
34 | 				// matched ip - allow
35 | 				h.ServeHTTP(w, r)
36 | 				return
37 | 			}
38 | 
39 | 			renderJSONWithStatus(w, JSON{"error": fmt.Sprintf("ip %q rejected", ip)}, http.StatusForbidden)
40 | 		}
41 | 		return http.HandlerFunc(fn)
42 | 	}
43 | }
44 | 
45 | // matchSourceIP reports whether the request's real ip matches any of ips, and returns that ip.
46 | // An entry matches if it is a CIDR containing the ip, a complete ip equal to it, or otherwise
47 | // a string prefix of it. The error is non-nil only when the real ip can't be determined.
48 | func matchSourceIP(r *http.Request, ips []string) (result bool, match string, err error) {
49 | 	ip, err := realip.Get(r)
50 | 	if err != nil {
51 | 		return false, "", fmt.Errorf("can't get realip: %w", err) // we can't get ip, so no match
52 | 	}
53 | 	srcIP := net.ParseIP(ip) // parsed once; realip.Get only returns parseable addresses
54 | 	for _, allowed := range ips {
55 | 		if _, cidrnet, cidrErr := net.ParseCIDR(allowed); cidrErr == nil {
56 | 			if cidrnet.Contains(srcIP) {
57 | 				return true, ip, nil
58 | 			}
59 | 			continue
60 | 		}
61 | 		if allowedIP := net.ParseIP(allowed); allowedIP != nil {
62 | 			// a complete address must match exactly, otherwise "1.1.1.1" would admit "1.1.1.12"
63 | 			if allowedIP.Equal(srcIP) {
64 | 				return true, ip, nil
65 | 			}
66 | 			continue
67 | 		}
68 | 		if strings.HasPrefix(ip, allowed) {
69 | 			return true, ip, nil
70 | 		}
71 | 	}
72 | 	return false, ip, nil
73 | }
74 | 
--------------------------------------------------------------------------------
/onlyfrom_test.go:
--------------------------------------------------------------------------------
1 | package rest
2 | 
3 | import (
4 | 	"io"
5 | 	"net/http"
6 | 	"net/http/httptest"
7 | 	"testing"
8 | 
9 | 	"github.com/stretchr/testify/assert"
10 | 	"github.com/stretchr/testify/require"
11 | )
12 | 
13 | func TestOnlyFromAllowedIP(t *testing.T) {
14 | 	handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
15 | 		_, err := w.Write([]byte("blah blah"))
16 | 		require.NoError(t, err)
17 | 	})
18 | 	ts := httptest.NewServer(OnlyFrom("127.0.0.1")(handler))
19 | 	defer ts.Close()
20 | 
21 | 	resp, err := http.Get(ts.URL + "/blah")
22 | 	require.Nil(t, err)
23 | 	defer resp.Body.Close()
24 | 	assert.Equal(t, 200, resp.StatusCode)
25 | 
26 | 	b, err := io.ReadAll(resp.Body)
27 | 	assert.NoError(t, err)
28 | 	assert.Equal(t, "blah blah", string(b))
29 | }
30 | 
31 | func TestOnlyFromAllowedHeaders(t *testing.T) {
32 | 	handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
33 | 		_, err := w.Write([]byte("blah blah"))
34 | 		require.NoError(t, err)
35 | 	})
36 | 	ts := httptest.NewServer(OnlyFrom("1.1.1.1")(handler))
37 | 	defer ts.Close()
38 | 
39 | 	reqWithHeader := func(header string) (*http.Request, error) {
40 | 		req, err := http.NewRequest("GET", ts.URL+"/blah", http.NoBody)
41 | 		if err != nil {
42 | 			return nil, err
43 | 		}
44 | 		req.Header.Set(header, "1.1.1.1")
45 | 		return req, err
46 | 	}
47 | 	client := http.Client{}
48 | 
49 | 	t.Run("X-Real-IP", func(t *testing.T) {
50 | 		req, err := reqWithHeader("X-Real-IP")
51 | 		require.NoError(t, err)
52 | 		resp, err := client.Do(req)
53 | 		require.NoError(t, err)
54 | 		defer resp.Body.Close()
55 | 		assert.Equal(t, 200, resp.StatusCode)
56 | 	})
57 | 
58 | 	t.Run("X-Forwarded-For", func(t *testing.T) {
59 | 		req, err := reqWithHeader("X-Forwarded-For")
60 | 		require.NoError(t, err)
61 | 		resp, err := client.Do(req)
62 | 		require.NoError(t, err)
63 | 		defer resp.Body.Close()
64 | 		assert.Equal(t, 200, resp.StatusCode)
65 | 	})
66 | 
67 | 	t.Run("X-Forwarded-For and X-Real-IP missing", func(t *testing.T) {
68 | 		req, err := reqWithHeader("blah")
69 | 		require.NoError(t, err)
70 | 		resp, err := client.Do(req)
71 | 		require.NoError(t, err)
72 | 		defer resp.Body.Close()
73 | 		assert.Equal(t, 403, resp.StatusCode)
74 | 	})
75 | }
76 | 
77 | func TestOnlyFromAllowedCIDR(t *testing.T) {
78 | 
79 | 	handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
80 | 		_, err := w.Write([]byte("blah blah"))
81 | 		require.NoError(t, err)
82 | 	})
83 | 	ts := httptest.NewServer(OnlyFrom("1.1.1.0/24")(handler))
84 | 	defer ts.Close()
85 | 
86 | 	client := http.Client{}
87 | 	req, err := http.NewRequest("GET", ts.URL+"/blah", http.NoBody)
88 | 	require.NoError(t, err)
89 | 	req.Header.Set("X-Real-IP", "1.1.1.1")
90 | 	resp, err := client.Do(req)
91 | 	require.NoError(t, err)
92 | 	defer resp.Body.Close()
93 | 	assert.Equal(t, 200, resp.StatusCode)
94 | 
95 | 	req.Header.Set("X-Real-IP", "1.1.2.0")
96 | 	resp, err = client.Do(req)
97 | 	require.NoError(t, err)
98 | 	defer resp.Body.Close()
99 | 	assert.Equal(t, 403, resp.StatusCode)
100 | }
101 | 
102 | func TestOnlyFromRejected(t *testing.T) {
103 | 	handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
104 | 		_, err := w.Write([]byte("blah blah"))
105 | 		require.NoError(t, err)
106 | 	})
107 | 	ts := httptest.NewServer(OnlyFrom("127.0.0.2")(handler))
108 | 	defer ts.Close()
109 | 
110 | 	resp, err := http.Get(ts.URL + "/blah")
111 | 	require.Nil(t, err)
112 | 	defer resp.Body.Close()
113 | 	assert.Equal(t, 403, resp.StatusCode)
114 | 	assert.Equal(t, "application/json; charset=utf-8", resp.Header.Get("Content-Type"))
115 | }
116 | 
117 | func TestOnlyFromErrors(t *testing.T) {
118 | 	tests := []struct {
119 | 		name       string
120 | 		remoteAddr string
121 | 		status     int
122 | 	}{
123 | 		{
124 | 			name:       "Invalid RemoteAddr",
125 | 			remoteAddr: "bad-addr",
126 | 			status:     http.StatusInternalServerError,
127 | 		},
128 | 	}
129 | 
130 | 	for _, tt := range tests {
131 | 		t.Run(tt.name, func(t *testing.T) {
132 | 			handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
133 | 				r.RemoteAddr = tt.remoteAddr
134 | 				OnlyFrom("1.1.1.1")(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
135 | 					_, err := w.Write([]byte("blah blah"))
136 | 					require.NoError(t, err)
137 | 				})).ServeHTTP(w, r)
138 | 			})
139 | 
140 | 			ts := httptest.NewServer(handler)
141 | 			defer ts.Close()
142 | 
143 | 			req, err := http.NewRequest("GET", ts.URL+"/blah", http.NoBody)
144 | 			require.NoError(t, err)
145 | 
146 | 			client := http.Client{}
147 | 			resp, err := client.Do(req)
148 | 			require.NoError(t, err)
149 | 			defer resp.Body.Close()
150 | 			assert.Equal(t, tt.status, resp.StatusCode)
151 | 		})
152 | 	}
153 | }
154 | 
155 | func TestOnlyFromExactIPIsNotPrefix(t *testing.T) {
156 | 	handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
157 | 		_, err := w.Write([]byte("blah blah"))
158 | 		require.NoError(t, err)
159 | 	})
160 | 	ts := httptest.NewServer(OnlyFrom("1.1.1.1", "10.0.")(handler))
161 | 	defer ts.Close()
162 | 
163 | 	client := http.Client{}
164 | 	for _, tc := range []struct {
165 | 		ip     string
166 | 		status int
167 | 	}{
168 | 		{ip: "1.1.1.1", status: http.StatusOK},
169 | 		{ip: "1.1.1.12", status: http.StatusForbidden}, // a complete ip is not a string prefix
170 | 		{ip: "10.0.5.6", status: http.StatusOK},        // explicit prefix entries still work
171 | 	} {
172 | 		req, err := http.NewRequest("GET", ts.URL+"/blah", http.NoBody)
173 | 		require.NoError(t, err)
174 | 		req.Header.Set("X-Real-IP", tc.ip)
175 | 		resp, err := client.Do(req)
176 | 		require.NoError(t, err)
177 | 		assert.Equal(t, tc.status, resp.StatusCode, tc.ip)
178 | 		require.NoError(t, resp.Body.Close())
179 | 	}
180 | }
181 | 
--------------------------------------------------------------------------------
/profiler.go:
--------------------------------------------------------------------------------
1 | package rest
2 | 
3 | import (
4 | 	"expvar"
5 | 	"fmt"
6 | 	"net/http"
7 | 	"net/http/pprof"
8 | )
9 | 
10 | // Profiler is a convenient subrouter for net/http/pprof and expvar (/vars), wrapped with NoCache and OnlyFrom(onlyIps...), e.g.
11 | //
12 | //	func MyService() http.Handler {
13 | //		r := chi.NewRouter()
14 | //		// ..middlewares
15 | //		r.Mount("/debug", rest.Profiler())
16 | //		// ..routes
17 | //		return r
18 | //	}
19 | func Profiler(onlyIps ...string) http.Handler {
20 | 	mux := http.NewServeMux()
21 | 	mux.HandleFunc("/pprof/", pprof.Index)
22 | 	mux.HandleFunc("/pprof/cmdline", pprof.Cmdline)
23 | 	mux.HandleFunc("/pprof/profile", pprof.Profile)
24 | 	mux.HandleFunc("/pprof/symbol", pprof.Symbol)
25 | 	mux.HandleFunc("/pprof/trace", pprof.Trace)
26 | 	mux.Handle("/pprof/block", pprof.Handler("block"))
27 | 	mux.Handle("/pprof/heap", pprof.Handler("heap"))
28 | 	mux.Handle("/pprof/goroutine", pprof.Handler("goroutine"))
29 | 	mux.Handle("/pprof/threadcreate", pprof.Handler("threadcreate"))
30 | 	mux.HandleFunc("/vars", expVars)
31 | 
32 | 	return Wrap(mux, NoCache, OnlyFrom(onlyIps...))
33 | }
34 | 
35 | // expVars copied from stdlib expvar.go as it is not public.
36 | func expVars(w http.ResponseWriter, _ *http.Request) { 37 | first := true 38 | w.Header().Set("Content-Type", "application/json") 39 | fmt.Fprintf(w, "{\n") 40 | expvar.Do(func(kv expvar.KeyValue) { 41 | if !first { 42 | fmt.Fprintf(w, ",\n") 43 | } 44 | first = false 45 | fmt.Fprintf(w, "%q: %s", kv.Key, kv.Value) 46 | }) 47 | fmt.Fprintf(w, "\n}\n") 48 | } 49 | -------------------------------------------------------------------------------- /realip/real.go: -------------------------------------------------------------------------------- 1 | // Package realip extracts a real IP address from the request. 2 | package realip 3 | 4 | import ( 5 | "bytes" 6 | "fmt" 7 | "net" 8 | "net/http" 9 | "strings" 10 | ) 11 | 12 | type ipRange struct { 13 | start net.IP 14 | end net.IP 15 | } 16 | 17 | // privateRanges contains the list of private and special-use IP ranges. 18 | // reference: https://www.iana.org/assignments/iana-ipv4-special-registry/iana-ipv4-special-registry.xhtml 19 | var privateRanges = []ipRange{ 20 | // IPv4 Private Ranges 21 | {start: net.ParseIP("10.0.0.0"), end: net.ParseIP("10.255.255.255")}, 22 | {start: net.ParseIP("172.16.0.0"), end: net.ParseIP("172.31.255.255")}, 23 | {start: net.ParseIP("192.168.0.0"), end: net.ParseIP("192.168.255.255")}, 24 | // IPv4 Link-Local 25 | {start: net.ParseIP("169.254.0.0"), end: net.ParseIP("169.254.255.255")}, 26 | // IPv4 Shared Address Space (RFC 6598) 27 | {start: net.ParseIP("100.64.0.0"), end: net.ParseIP("100.127.255.255")}, 28 | // IPv4 Benchmarking (RFC 2544) 29 | {start: net.ParseIP("198.18.0.0"), end: net.ParseIP("198.19.255.255")}, 30 | // IPv6 Unique Local Addresses (ULA) 31 | {start: net.ParseIP("fc00::"), end: net.ParseIP("fdff:ffff:ffff:ffff:ffff:ffff:ffff:ffff")}, 32 | // IPv6 Link-local Addresses 33 | {start: net.ParseIP("fe80::"), end: net.ParseIP("febf:ffff:ffff:ffff:ffff:ffff:ffff:ffff")}, 34 | } 35 | 36 | // Get returns real ip from the given request 37 | // Prioritize public IPs over 
private IPs 38 | func Get(r *http.Request) (string, error) { 39 | var firstIP string 40 | for _, h := range []string{"X-Forwarded-For", "X-Real-Ip"} { 41 | addresses := strings.Split(r.Header.Get(h), ",") 42 | for i := len(addresses) - 1; i >= 0; i-- { 43 | ip := strings.TrimSpace(addresses[i]) 44 | realIP := net.ParseIP(ip) 45 | if firstIP == "" && realIP != nil { 46 | firstIP = ip 47 | } 48 | // Guard against nil realIP 49 | if realIP == nil || !realIP.IsGlobalUnicast() || isPrivateSubnet(realIP) { 50 | continue 51 | } 52 | return ip, nil 53 | } 54 | } 55 | 56 | if firstIP != "" { 57 | return firstIP, nil 58 | } 59 | 60 | // handle RemoteAddr which may be just an IP or IP:port 61 | remoteIP := r.RemoteAddr 62 | 63 | // try to extract host from host:port format 64 | host, _, err := net.SplitHostPort(remoteIP) 65 | if err == nil { 66 | remoteIP = host 67 | } 68 | 69 | // at this point remoteIP could be either: 70 | // 1. the host part extracted from host:port 71 | // 2. yhe original RemoteAddr if it doesn't contain a port 72 | 73 | // try to parse it as a valid IP address 74 | if netIP := net.ParseIP(remoteIP); netIP == nil { 75 | return "", fmt.Errorf("no valid ip found in %q", r.RemoteAddr) 76 | } 77 | 78 | return remoteIP, nil 79 | } 80 | 81 | // isPrivateSubnet - check to see if this ip is in a private subnet 82 | func isPrivateSubnet(ipAddress net.IP) bool { 83 | inRange := func(r ipRange, ipAddress net.IP) bool { // check to see if a given ip address is within a range given 84 | // ensure the IPs are in the same format for comparison 85 | ipAddress = ipAddress.To16() 86 | r.start = r.start.To16() 87 | r.end = r.end.To16() 88 | return bytes.Compare(ipAddress, r.start) >= 0 && bytes.Compare(ipAddress, r.end) <= 0 89 | } 90 | 91 | for _, r := range privateRanges { 92 | if inRange(r, ipAddress) { 93 | return true 94 | } 95 | } 96 | return false 97 | } 98 | -------------------------------------------------------------------------------- /realip/real_test.go: 
-------------------------------------------------------------------------------- 1 | package realip 2 | 3 | import ( 4 | "log" 5 | "net/http" 6 | "net/http/httptest" 7 | "testing" 8 | "time" 9 | 10 | "github.com/stretchr/testify/assert" 11 | "github.com/stretchr/testify/require" 12 | ) 13 | 14 | func TestGetFromHeaders(t *testing.T) { 15 | t.Run("single X-Real-IP", func(t *testing.T) { 16 | req, err := http.NewRequest("GET", "/something", http.NoBody) 17 | assert.NoError(t, err) 18 | req.Header.Add("Something", "1234567") 19 | req.Header.Add("X-Real-IP", "8.8.8.8") 20 | adr, err := Get(req) 21 | require.NoError(t, err) 22 | assert.Equal(t, "8.8.8.8", adr) 23 | }) 24 | t.Run("X-Forwarded-For last public", func(t *testing.T) { 25 | req, err := http.NewRequest("GET", "/something", http.NoBody) 26 | assert.NoError(t, err) 27 | req.Header.Add("Something", "1234567") 28 | req.Header.Add("X-Forwarded-For", "8.8.8.8,1.1.1.2, 30.30.30.1") 29 | adr, err := Get(req) 30 | require.NoError(t, err) 31 | assert.Equal(t, "30.30.30.1", adr) 32 | }) 33 | t.Run("X-Forwarded-For last private", func(t *testing.T) { 34 | req, err := http.NewRequest("GET", "/something", http.NoBody) 35 | assert.NoError(t, err) 36 | req.Header.Add("Something", "1234567") 37 | req.Header.Add("X-Forwarded-For", "8.8.8.8,1.1.1.2,192.168.1.1,10.0.0.65") 38 | adr, err := Get(req) 39 | require.NoError(t, err) 40 | assert.Equal(t, "1.1.1.2", adr) 41 | }) 42 | t.Run("X-Forwarded-For public im the middle", func(t *testing.T) { 43 | req, err := http.NewRequest("GET", "/something", http.NoBody) 44 | assert.NoError(t, err) 45 | req.Header.Add("Something", "1234567") 46 | req.Header.Add("X-Forwarded-For", "192.168.1.1, 8.8.8.8, 10.0.0.65") 47 | adr, err := Get(req) 48 | require.NoError(t, err) 49 | assert.Equal(t, "8.8.8.8", adr) 50 | }) 51 | t.Run("X-Forwarded-For all private", func(t *testing.T) { 52 | req, err := http.NewRequest("GET", "/something", http.NoBody) 53 | assert.NoError(t, err) 54 | 
req.Header.Add("Something", "1234567") 55 | req.Header.Add("X-Forwarded-For", "192.168.1.1,10.0.0.65") 56 | adr, err := Get(req) 57 | require.NoError(t, err) 58 | assert.Equal(t, "10.0.0.65", adr) 59 | }) 60 | t.Run("X-Forwarded-For public, X-Real-IP private", func(t *testing.T) { 61 | req, err := http.NewRequest("GET", "/something", http.NoBody) 62 | assert.NoError(t, err) 63 | req.Header.Add("Something", "1234567") 64 | req.Header.Add("X-Forwarded-For", "30.30.30.1") 65 | req.Header.Add("X-Real-Ip", "10.0.0.1") 66 | adr, err := Get(req) 67 | require.NoError(t, err) 68 | assert.Equal(t, "30.30.30.1", adr) 69 | }) 70 | t.Run("X-Forwarded-For and X-Real-IP public", func(t *testing.T) { 71 | req, err := http.NewRequest("GET", "/something", http.NoBody) 72 | assert.NoError(t, err) 73 | req.Header.Add("Something", "1234567") 74 | req.Header.Add("X-Forwarded-For", "30.30.30.1") 75 | req.Header.Add("X-Real-Ip", "8.8.8.8") 76 | adr, err := Get(req) 77 | require.NoError(t, err) 78 | assert.Equal(t, "30.30.30.1", adr) 79 | }) 80 | t.Run("X-Forwarded-For private and X-Real-IP public", func(t *testing.T) { 81 | req, err := http.NewRequest("GET", "/something", http.NoBody) 82 | assert.NoError(t, err) 83 | req.Header.Add("Something", "1234567") 84 | req.Header.Add("X-Forwarded-For", "10.0.0.2,192.168.1.1") 85 | req.Header.Add("X-Real-Ip", "8.8.8.8") 86 | adr, err := Get(req) 87 | require.NoError(t, err) 88 | assert.Equal(t, "8.8.8.8", adr) 89 | }) 90 | t.Run("RemoteAddr fallback", func(t *testing.T) { 91 | req, err := http.NewRequest("GET", "/something", http.NoBody) 92 | assert.NoError(t, err) 93 | req.RemoteAddr = "192.0.2.1:1234" 94 | adr, err := Get(req) 95 | require.NoError(t, err) 96 | assert.Equal(t, "192.0.2.1", adr) 97 | }) 98 | t.Run("X-Forwarded-For and X-Real-IP missing, no RemoteAddr either", func(t *testing.T) { 99 | req, err := http.NewRequest("GET", "/something", http.NoBody) 100 | assert.NoError(t, err) 101 | ip, err := Get(req) 102 | assert.Error(t, err) 103 | 
assert.Equal(t, "", ip) 104 | }) 105 | t.Run("X-Real-IP IPv6", func(t *testing.T) { 106 | req, err := http.NewRequest("GET", "/something", http.NoBody) 107 | assert.NoError(t, err) 108 | req.Header.Add("X-Real-IP", "2001:0db8:85a3:0000:0000:8a2e:0370:7334") 109 | adr, err := Get(req) 110 | require.NoError(t, err) 111 | assert.Equal(t, "2001:0db8:85a3:0000:0000:8a2e:0370:7334", adr) 112 | }) 113 | t.Run("X-Forwarded-For last IPv6 public", func(t *testing.T) { 114 | req, err := http.NewRequest("GET", "/something", http.NoBody) 115 | assert.NoError(t, err) 116 | req.Header.Add("X-Forwarded-For", "2001:db8::ff00:42:8329,::1,fc00::") 117 | adr, err := Get(req) 118 | require.NoError(t, err) 119 | assert.Equal(t, "2001:db8::ff00:42:8329", adr) 120 | }) 121 | 122 | t.Run("RemoteAddr IPv6 fallback", func(t *testing.T) { 123 | req, err := http.NewRequest("GET", "/something", http.NoBody) 124 | assert.NoError(t, err) 125 | req.RemoteAddr = "[2001:db8::ff00:42:8329]:1234" 126 | adr, err := Get(req) 127 | require.NoError(t, err) 128 | assert.Equal(t, "2001:db8::ff00:42:8329", adr) 129 | }) 130 | } 131 | 132 | func TestGetFromRemoteAddr(t *testing.T) { 133 | ts := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { 134 | log.Printf("%v", r) 135 | adr, err := Get(r) 136 | require.NoError(t, err) 137 | assert.Equal(t, "127.0.0.1", adr) 138 | })) 139 | 140 | req, err := http.NewRequest("GET", ts.URL+"/something", http.NoBody) 141 | require.NoError(t, err) 142 | client := http.Client{Timeout: time.Second} 143 | _, err = client.Do(req) 144 | require.NoError(t, err) 145 | } 146 | 147 | func TestGetWithVariousIPFormats(t *testing.T) { 148 | tests := []struct { 149 | name string 150 | remoteAddr string 151 | headers map[string]string 152 | wantIP string 153 | wantErr bool 154 | }{ 155 | { 156 | name: "IPv4 with port", 157 | remoteAddr: "192.168.1.1:8080", 158 | wantIP: "192.168.1.1", 159 | wantErr: false, 160 | }, 161 | { 162 | name: "IPv4 without port", 
163 | remoteAddr: "127.0.0.1", 164 | wantIP: "127.0.0.1", 165 | wantErr: false, 166 | }, 167 | { 168 | name: "IPv6 with port", 169 | remoteAddr: "[::1]:8080", 170 | wantIP: "::1", 171 | wantErr: false, 172 | }, 173 | { 174 | name: "IPv6 without port", 175 | remoteAddr: "::1", 176 | wantIP: "::1", 177 | wantErr: false, 178 | }, 179 | { 180 | name: "X-Forwarded-For", 181 | remoteAddr: "127.0.0.1:8080", 182 | headers: map[string]string{"X-Forwarded-For": "203.0.113.1"}, 183 | wantIP: "203.0.113.1", 184 | wantErr: false, 185 | }, 186 | { 187 | name: "Invalid IP", 188 | remoteAddr: "invalid-ip", 189 | wantErr: true, 190 | }, 191 | } 192 | 193 | for _, tt := range tests { 194 | t.Run(tt.name, func(t *testing.T) { 195 | req, err := http.NewRequest("GET", "/", nil) 196 | require.NoError(t, err) 197 | 198 | req.RemoteAddr = tt.remoteAddr 199 | for k, v := range tt.headers { 200 | req.Header.Set(k, v) 201 | } 202 | 203 | gotIP, err := Get(req) 204 | if tt.wantErr { 205 | assert.Error(t, err) 206 | return 207 | } 208 | 209 | assert.NoError(t, err) 210 | assert.Equal(t, tt.wantIP, gotIP) 211 | }) 212 | } 213 | } 214 | -------------------------------------------------------------------------------- /rest.go: -------------------------------------------------------------------------------- 1 | // Package rest provides common middlewares and helpers for rest services 2 | package rest 3 | 4 | import ( 5 | "bytes" 6 | "encoding/json" 7 | "fmt" 8 | "net/http" 9 | "time" 10 | ) 11 | 12 | // JSON is a map alias, just for convenience 13 | type JSON map[string]any 14 | 15 | // RenderJSON sends data as json 16 | func RenderJSON(w http.ResponseWriter, data interface{}) { 17 | buf := &bytes.Buffer{} 18 | enc := json.NewEncoder(buf) 19 | enc.SetEscapeHTML(true) 20 | if err := enc.Encode(data); err != nil { 21 | http.Error(w, err.Error(), http.StatusInternalServerError) 22 | return 23 | } 24 | w.Header().Set("Content-Type", "application/json; charset=utf-8") 25 | _, _ = w.Write(buf.Bytes()) 
26 | } 27 | 28 | // RenderJSONFromBytes sends binary data as json 29 | func RenderJSONFromBytes(w http.ResponseWriter, r *http.Request, data []byte) error { 30 | w.Header().Set("Content-Type", "application/json; charset=utf-8") 31 | if _, err := w.Write(data); err != nil { 32 | return fmt.Errorf("failed to send response to %s: %w", r.RemoteAddr, err) 33 | } 34 | return nil 35 | } 36 | 37 | // RenderJSONWithHTML allows html tags and forces charset=utf-8 38 | func RenderJSONWithHTML(w http.ResponseWriter, r *http.Request, v interface{}) error { 39 | 40 | encodeJSONWithHTML := func(v interface{}) ([]byte, error) { 41 | buf := &bytes.Buffer{} 42 | enc := json.NewEncoder(buf) 43 | enc.SetEscapeHTML(false) 44 | if err := enc.Encode(v); err != nil { 45 | return nil, fmt.Errorf("json encoding failed: %w", err) 46 | } 47 | return buf.Bytes(), nil 48 | } 49 | 50 | data, err := encodeJSONWithHTML(v) 51 | if err != nil { 52 | return err 53 | } 54 | return RenderJSONFromBytes(w, r, data) 55 | } 56 | 57 | // renderJSONWithStatus sends data as json and enforces status code 58 | func renderJSONWithStatus(w http.ResponseWriter, data interface{}, code int) { 59 | buf := &bytes.Buffer{} 60 | enc := json.NewEncoder(buf) 61 | enc.SetEscapeHTML(true) 62 | if err := enc.Encode(data); err != nil { 63 | http.Error(w, err.Error(), http.StatusInternalServerError) 64 | return 65 | } 66 | w.Header().Set("Content-Type", "application/json; charset=utf-8") 67 | w.WriteHeader(code) 68 | _, _ = w.Write(buf.Bytes()) 69 | } 70 | 71 | // ParseFromTo parses from and to query params of the request 72 | func ParseFromTo(r *http.Request) (from, to time.Time, err error) { 73 | parseTimeStamp := func(ts string) (time.Time, error) { 74 | formats := []string{ 75 | "2006-01-02T15:04:05.000000000", 76 | "2006-01-02T15:04:05", 77 | "2006-01-02T15:04", 78 | "20060102", 79 | time.RFC3339, 80 | time.RFC3339Nano, 81 | } 82 | 83 | for _, f := range formats { 84 | if t, e := time.Parse(f, ts); e == nil { 85 | return 
t, nil 86 | } 87 | } 88 | return time.Time{}, fmt.Errorf("can't parse date %q", ts) 89 | } 90 | 91 | if from, err = parseTimeStamp(r.URL.Query().Get("from")); err != nil { 92 | return from, to, fmt.Errorf("incorrect from time: %w", err) 93 | } 94 | 95 | if to, err = parseTimeStamp(r.URL.Query().Get("to")); err != nil { 96 | return from, to, fmt.Errorf("incorrect to time: %w", err) 97 | } 98 | return from, to, nil 99 | } 100 | 101 | // DecodeJSON decodes json request from http.Request to given type 102 | func DecodeJSON[T any](r *http.Request, res *T) error { 103 | if err := json.NewDecoder(r.Body).Decode(&res); err != nil { 104 | return fmt.Errorf("decode json: %w", err) 105 | } 106 | return nil 107 | } 108 | 109 | // EncodeJSON encodes given type to http.ResponseWriter and sets status code and content type header 110 | func EncodeJSON[T any](w http.ResponseWriter, status int, v T) error { 111 | w.Header().Set("Content-Type", "application/json; charset=utf-8") 112 | w.WriteHeader(status) 113 | if err := json.NewEncoder(w).Encode(v); err != nil { 114 | return fmt.Errorf("encode json: %w", err) 115 | } 116 | return nil 117 | } 118 | -------------------------------------------------------------------------------- /rest_test.go: -------------------------------------------------------------------------------- 1 | package rest 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "errors" 7 | "io" 8 | "net/http" 9 | "net/http/httptest" 10 | "strconv" 11 | "testing" 12 | "time" 13 | 14 | "github.com/stretchr/testify/assert" 15 | "github.com/stretchr/testify/require" 16 | ) 17 | 18 | func TestRest_RenderJSON(t *testing.T) { 19 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 20 | j := JSON{"key1": 1, "key2": "222"} 21 | RenderJSON(w, j) 22 | })) 23 | defer ts.Close() 24 | 25 | resp, err := http.Get(ts.URL + "/test") 26 | require.NoError(t, err) 27 | assert.Equal(t, 200, resp.StatusCode) 28 | require.NoError(t, err) 29 | body, err := 
io.ReadAll(resp.Body) 30 | require.NoError(t, err) 31 | defer resp.Body.Close() 32 | 33 | assert.Equal(t, `{"key1":1,"key2":"222"}`+"\n", string(body)) 34 | assert.Equal(t, "application/json; charset=utf-8", resp.Header.Get("Content-Type")) 35 | } 36 | 37 | func TestRest_RenderJSONFromBytes(t *testing.T) { 38 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 39 | require.NoError(t, RenderJSONFromBytes(w, r, []byte("some data"))) 40 | })) 41 | defer ts.Close() 42 | 43 | resp, err := http.Get(ts.URL + "/test") 44 | require.NoError(t, err) 45 | assert.Equal(t, 200, resp.StatusCode) 46 | body, err := io.ReadAll(resp.Body) 47 | require.NoError(t, err) 48 | defer resp.Body.Close() 49 | 50 | assert.Equal(t, "some data", string(body)) 51 | assert.Equal(t, "application/json; charset=utf-8", resp.Header.Get("Content-Type")) 52 | } 53 | 54 | func TestRest_RenderJSONWithHTML(t *testing.T) { 55 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 56 | j := JSON{"key1": "val1", "key2": 2.0, "html": `