├── .github └── workflows │ ├── go-cross.yml │ └── pr.yml ├── .gitignore ├── .golangci.yml ├── LICENSE ├── Makefile ├── README.md ├── buffer ├── buffer.go ├── buffer_test.go ├── options.go ├── retry_test.go └── threshold.go ├── cbreaker ├── cbreaker.go ├── cbreaker_test.go ├── effect.go ├── fallback.go ├── options.go ├── predicates.go ├── predicates_test.go ├── ratio.go └── ratio_test.go ├── connlimit ├── connlimit.go ├── connlimit_test.go └── options.go ├── forward ├── example_test.go ├── fwd.go ├── fwd_test.go ├── fwd_websocket_test.go ├── headers.go ├── middlewares.go ├── rewrite.go └── rewrite_test.go ├── go.mod ├── go.sum ├── internal └── holsterv4 │ ├── LICENSE │ ├── README.md │ ├── clock │ ├── README.md │ ├── clock.go │ ├── clock_mutex.go │ ├── duration.go │ ├── duration_test.go │ ├── frozen.go │ ├── frozen_test.go │ ├── go19.go │ ├── interface.go │ ├── rfc822.go │ ├── rfc822_test.go │ ├── system.go │ └── system_test.go │ └── collections │ ├── README.md │ ├── priority_queue.go │ ├── priority_queue_test.go │ ├── ttlmap.go │ └── ttlmap_test.go ├── memmetrics ├── anomaly.go ├── anomaly_test.go ├── counter.go ├── counter_test.go ├── histogram.go ├── histogram_test.go ├── options.go ├── ratio.go ├── ratio_test.go ├── roundtrip.go └── roundtrip_test.go ├── ratelimit ├── bucket.go ├── bucket_test.go ├── bucketset.go ├── bucketset_test.go ├── options.go ├── tokenlimiter.go └── tokenlimiter_test.go ├── roundrobin ├── RequestRewriteListener.go ├── options.go ├── rebalancer.go ├── rebalancer_test.go ├── rr.go ├── rr_test.go ├── stickycookie │ ├── aes_value.go │ ├── cookie_value.go │ ├── fallback_value.go │ ├── fallback_value_test.go │ ├── hash_value.go │ └── raw_value.go ├── stickysessions.go └── stickysessions_test.go ├── stream ├── options.go ├── stream.go ├── stream_test.go └── threshold.go ├── testutils └── utils.go ├── trace ├── options.go ├── trace.go └── trace_test.go └── utils ├── auth.go ├── auth_test.go ├── dumpreq.go ├── dumpreq_test.go ├── handler.go ├── 
handler_test.go ├── log.go ├── netutils.go ├── netutils_test.go └── source.go /.github/workflows/go-cross.yml: -------------------------------------------------------------------------------- 1 | name: Go Matrix 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - master 8 | pull_request: 9 | branches: 10 | - main 11 | - master 12 | 13 | jobs: 14 | 15 | cross: 16 | name: Go 17 | runs-on: ${{ matrix.os }} 18 | env: 19 | CGO_ENABLED: 0 20 | 21 | strategy: 22 | matrix: 23 | go-version: [ stable, oldstable ] 24 | os: [ubuntu-latest, macos-latest] 25 | # TODO ignore windows but need to be added in the future 26 | # os: [ubuntu-latest, macos-latest, windows-latest] 27 | 28 | steps: 29 | - uses: actions/checkout@v4 30 | - uses: actions/setup-go@v5 31 | with: 32 | go-version: ${{ matrix.go-version }} 33 | 34 | - name: Test 35 | run: go test -v -cover ./... 36 | 37 | -------------------------------------------------------------------------------- /.github/workflows/pr.yml: -------------------------------------------------------------------------------- 1 | name: Main 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - master 8 | pull_request: 9 | branches: 10 | - main 11 | - master 12 | 13 | jobs: 14 | 15 | main: 16 | name: Main Process 17 | runs-on: ubuntu-latest 18 | env: 19 | GO_VERSION: stable 20 | GOLANGCI_LINT_VERSION: v2.0.1 21 | 22 | steps: 23 | - uses: actions/checkout@v4 24 | - uses: actions/setup-go@v5 25 | with: 26 | go-version: ${{ env.GO_VERSION }} 27 | 28 | - name: Check and get dependencies 29 | run: | 30 | go mod tidy 31 | git diff --exit-code go.mod 32 | git diff --exit-code go.sum 33 | 34 | # https://golangci-lint.run/usage/install#other-ci 35 | - name: Install golangci-lint ${{ env.GOLANGCI_LINT_VERSION }} 36 | run: | 37 | curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(go env GOPATH)/bin ${GOLANGCI_LINT_VERSION} 38 | golangci-lint --version 39 | 40 | - name: Make 41 | run: make 42 | 43 | 
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | 6 | # Folders 7 | _obj 8 | _test 9 | 10 | # Architecture specific extensions/prefixes 11 | *.[568vq] 12 | [568vq].out 13 | 14 | *.cgo1.go 15 | *.cgo2.c 16 | _cgo_defun.c 17 | _cgo_gotypes.go 18 | _cgo_export.* 19 | 20 | _testmain.go 21 | 22 | *.exe 23 | *.test 24 | *.prof 25 | .idea/ 26 | 27 | flymake_* 28 | 29 | vendor/ -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | version: "2" 2 | 3 | formatters: 4 | enable: 5 | - gci 6 | - gofumpt 7 | settings: 8 | gofumpt: 9 | extra-rules: false 10 | 11 | linters: 12 | default: all 13 | disable: 14 | - bodyclose # Too many false positives: https://github.com/timakin/bodyclose/issues/30 15 | - cyclop # duplicate of gocyclo 16 | - dupl 17 | - err113 18 | - exhaustive 19 | - exhaustruct 20 | - forcetypeassert 21 | - gochecknoglobals # TODO(ldez) should be used on the project 22 | - ireturn 23 | - lll 24 | - mnd 25 | - musttag 26 | - nestif # TODO(ldez) should be used on the project 27 | - nilnil 28 | - nlreturn 29 | - noctx 30 | - nonamedreturns 31 | - paralleltest 32 | - prealloc 33 | - rowserrcheck # not relevant (SQL) 34 | - sqlclosecheck # not relevant (SQL) 35 | - tagliatelle 36 | - testpackage 37 | - tparallel 38 | - varnamelen 39 | - wrapcheck 40 | - wsl # TODO(ldez) should be used on the project 41 | settings: 42 | depguard: 43 | rules: 44 | main: 45 | deny: 46 | - pkg: github.com/instana/testify 47 | desc: not allowed 48 | - pkg: github.com/pkg/errors 49 | desc: Should be replaced by standard lib errors package 50 | funlen: 51 | lines: -1 52 | statements: 50 53 | goconst: 54 | min-len: 5 55 | min-occurrences: 3 56 |
gocritic: 57 | disabled-checks: 58 | - sloppyReassign 59 | - rangeValCopy 60 | - octalLiteral 61 | - paramTypeCombine # already handled by gofumpt.extra-rules 62 | - httpNoBody 63 | - unnamedResult 64 | - deferInLoop # TODO(ldez) should be used on the project 65 | enabled-tags: 66 | - diagnostic 67 | - style 68 | - performance 69 | settings: 70 | hugeParam: 71 | sizeThreshold: 100 72 | gocyclo: 73 | min-complexity: 15 74 | godox: 75 | keywords: 76 | - FIXME 77 | govet: 78 | disable: 79 | - fieldalignment 80 | - shadow 81 | enable-all: true 82 | misspell: 83 | locale: US 84 | perfsprint: 85 | err-error: true 86 | errorf: true 87 | sprintf1: true 88 | strconcat: false 89 | testifylint: 90 | disable: 91 | - go-require 92 | exclusions: 93 | warn-unused: true 94 | presets: 95 | - comments 96 | - std-error-handling 97 | rules: 98 | - linters: 99 | - canonicalheader 100 | - funlen 101 | - goconst 102 | - gosec 103 | path: .*_test.go 104 | - linters: 105 | - gosec 106 | path: testutils/.+ 107 | - path: cbreaker/cbreaker_test.go # TODO(ldez) must be fixed 108 | text: 'statsNetErrors - threshold always receives 0.6' 109 | - path: buffer/buffer.go # TODO(ldez) must be fixed 110 | text: (cognitive|cyclomatic) complexity \d+ of func `\(\*Buffer\)\.ServeHTTP` is high 111 | - path: buffer/buffer.go # TODO(ldez) must be fixed 112 | text: Function 'ServeHTTP' has too many statements 113 | - path: memmetrics/ratio_test.go # TODO(ldez) must be fixed 114 | text: 'float-compare: use assert\.InEpsilon \(or InDelta\)' 115 | - path: memmetrics/roundtrip_test.go # TODO(ldez) must be fixed 116 | text: 'float-compare: use assert\.InEpsilon \(or InDelta\)' 117 | - path: memmetrics/anomaly_test.go # TODO(ldez) must be fixed 118 | text: 'float-compare: use assert\.InEpsilon \(or InDelta\)' 119 | - path: (.+)\.go$ # TODO(ldez) must be fixed 120 | text: 'SA1019: http.CloseNotifier has been deprecated' 121 | - path: (.+)\.go$ # TODO(ldez) must be fixed 122 | text: 'exported: func name will be used
as roundrobin.RoundRobinRequestRewriteListener by other packages' 123 | paths: 124 | - internal/holsterv4 125 | issues: 126 | max-issues-per-linter: 0 127 | max-same-issues: 0 128 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: default clean checks test test-verbose 2 | 3 | export GO111MODULE=on 4 | 5 | default: clean checks test 6 | 7 | test: clean 8 | go test -race -cover -count 1 ./... 9 | 10 | test-verbose: clean 11 | go test -v -race -cover ./... 12 | 13 | clean: 14 | find . -name 'flymake_*' -delete 15 | rm -f cover.out 16 | 17 | checks: 18 | golangci-lint run 19 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Oxy [![Build Status](https://travis-ci.org/vulcand/oxy.svg?branch=master)](https://travis-ci.org/vulcand/oxy) 2 | ===== 3 | 4 | Oxy is a Go library with HTTP handlers that enhance HTTP standard library: 5 | 6 | * [Buffer](https://pkg.go.dev/github.com/vulcand/oxy/buffer) retries and buffers requests and responses 7 | * [Stream](https://pkg.go.dev/github.com/vulcand/oxy/stream) passes-through requests, supports chunked encoding with configurable flush interval 8 | * [Forward](https://pkg.go.dev/github.com/vulcand/oxy/forward) forwards requests to remote location and rewrites headers 9 | * [Roundrobin](https://pkg.go.dev/github.com/vulcand/oxy/roundrobin) is a round-robin load balancer 10 | * [Circuit Breaker](https://pkg.go.dev/github.com/vulcand/oxy/cbreaker) Hystrix-style circuit breaker 11 | * [Connlimit](https://pkg.go.dev/github.com/vulcand/oxy/connlimit) Simultaneous connections limiter 12 | * [Ratelimit](https://pkg.go.dev/github.com/vulcand/oxy/ratelimit) Rate limiter (based on tokenbucket algo) 13 | * [Trace](https://pkg.go.dev/github.com/vulcand/oxy/trace) Structured request and
response logger 14 | 15 | It is designed to be fully compatible with the HTTP standard library, easy to customize and reuse. 16 | 17 | Status 18 | ------ 19 | 20 | * Initial design is completed 21 | * Covered by tests 22 | * Used as a reverse proxy engine in [Vulcand](https://github.com/vulcand/vulcand) 23 | 24 | Quickstart 25 | ----------- 26 | 27 | Every handler is ``http.Handler``, so writing and plugging in a middleware is easy. Let us write a simple reverse proxy as an example: 28 | 29 | Simple reverse proxy 30 | ==================== 31 | 32 | ```go 33 | 34 | import ( 35 | "net/http" 36 | "github.com/vulcand/oxy/v2/forward" 37 | "github.com/vulcand/oxy/v2/testutils" 38 | ) 39 | 40 | // Forwards incoming requests to whatever location URL points to, adds proper forwarding headers 41 | fwd := forward.New(false) 42 | 43 | redirect := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 44 | // let us forward this request to another server 45 | req.URL = testutils.MustParseRequestURI("http://localhost:63450") 46 | fwd.ServeHTTP(w, req) 47 | }) 48 | 49 | // that's it! our reverse proxy is ready! 50 | s := &http.Server{ 51 | Addr: ":8080", 52 | Handler: redirect, 53 | } 54 | s.ListenAndServe() 55 | ``` 56 | 57 | As a next step, let us add a round-robin load balancer: 58 | 59 | 60 | ```go 61 | 62 | import ( 63 | "net/http" 64 | "github.com/vulcand/oxy/v2/forward" 65 | "github.com/vulcand/oxy/v2/roundrobin" 66 | ) 67 | 68 | // Forwards incoming requests to whatever location URL points to, adds proper forwarding headers 69 | fwd := forward.New(false) 70 | lb, _ := roundrobin.New(fwd) 71 | 72 | lb.UpsertServer(url1) 73 | lb.UpsertServer(url2) 74 | 75 | s := &http.Server{ 76 | Addr: ":8080", 77 | Handler: lb, 78 | } 79 | s.ListenAndServe() 80 | ``` 81 | 82 | What if we want to handle retries and replay the request in case of errors?
`buffer` handler will help: 83 | 84 | 85 | ```go 86 | 87 | import ( 88 | "net/http" 89 | "github.com/vulcand/oxy/v2/forward" 90 | "github.com/vulcand/oxy/v2/buffer" 91 | "github.com/vulcand/oxy/v2/roundrobin" 92 | ) 93 | 94 | // Forwards incoming requests to whatever location URL points to, adds proper forwarding headers 95 | 96 | fwd := forward.New(false) 97 | lb, _ := roundrobin.New(fwd) 98 | 99 | // buffer reads the request body and replays the request if the forwarder returned a status 100 | // corresponding to a network error (e.g. Gateway Timeout) 101 | buffer, _ := buffer.New(lb, buffer.Retry(`IsNetworkError() && Attempts() < 2`)) 102 | 103 | lb.UpsertServer(url1) 104 | lb.UpsertServer(url2) 105 | 106 | // that's it! our reverse proxy is ready! 107 | s := &http.Server{ 108 | Addr: ":8080", 109 | Handler: buffer, 110 | } 111 | s.ListenAndServe() 112 | ``` 113 | -------------------------------------------------------------------------------- /buffer/options.go: -------------------------------------------------------------------------------- 1 | package buffer 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/vulcand/oxy/v2/utils" 7 | ) 8 | 9 | // Option represents an option you can pass to New. 10 | type Option func(b *Buffer) error 11 | 12 | // Logger defines the logger used by Buffer. 13 | func Logger(l utils.Logger) Option { 14 | return func(b *Buffer) error { 15 | b.log = l 16 | return nil 17 | } 18 | } 19 | 20 | // Verbose additional debug information. 21 | func Verbose(verbose bool) Option { 22 | return func(b *Buffer) error { 23 | b.verbose = verbose 24 | return nil 25 | } 26 | } 27 | 28 | // Cond Conditional setter.
29 | // ex: Cond(a > 4, MemRequestBodyBytes(a)) 30 | func Cond(condition bool, setter Option) Option { 31 | if !condition { 32 | // NoOp setter 33 | return func(*Buffer) error { 34 | return nil 35 | } 36 | } 37 | return setter 38 | } 39 | 40 | // Retry provides a predicate that allows buffer middleware to replay the request 41 | // if it matches certain condition, e.g. returns special error code. Available functions are: 42 | // 43 | // Attempts() - limits the amount of retry attempts 44 | // ResponseCode() - returns http response code 45 | // IsNetworkError() - tests if response code is related to networking error 46 | // 47 | // Example of the predicate: 48 | // 49 | // `Attempts() <= 2 && ResponseCode() == 502`. 50 | func Retry(predicate string) Option { 51 | return func(b *Buffer) error { 52 | p, err := parseExpression(predicate) 53 | if err != nil { 54 | return err 55 | } 56 | b.retryPredicate = p 57 | return nil 58 | } 59 | } 60 | 61 | // ErrorHandler sets error handler of the server. 62 | func ErrorHandler(h utils.ErrorHandler) Option { 63 | return func(b *Buffer) error { 64 | b.errHandler = h 65 | return nil 66 | } 67 | } 68 | 69 | // MaxRequestBodyBytes sets the maximum request body size in bytes. 70 | func MaxRequestBodyBytes(m int64) Option { 71 | return func(b *Buffer) error { 72 | if m < 0 { 73 | return fmt.Errorf("max bytes should be >= 0 got %d", m) 74 | } 75 | b.maxRequestBodyBytes = m 76 | return nil 77 | } 78 | } 79 | 80 | // MemRequestBodyBytes bytes sets the maximum request body to be stored in memory 81 | // buffer middleware will serialize the excess to disk. 82 | func MemRequestBodyBytes(m int64) Option { 83 | return func(b *Buffer) error { 84 | if m < 0 { 85 | return fmt.Errorf("mem bytes should be >= 0 got %d", m) 86 | } 87 | b.memRequestBodyBytes = m 88 | return nil 89 | } 90 | } 91 | 92 | // MaxResponseBodyBytes sets the maximum response body size in bytes. 
93 | func MaxResponseBodyBytes(m int64) Option { 94 | return func(b *Buffer) error { 95 | if m < 0 { 96 | return fmt.Errorf("max bytes should be >= 0 got %d", m) 97 | } 98 | b.maxResponseBodyBytes = m 99 | return nil 100 | } 101 | } 102 | 103 | // MemResponseBodyBytes sets the maximum response body size to be stored in memory; the 104 | // buffer middleware will serialize the excess to disk. It fails if m is negative. 105 | func MemResponseBodyBytes(m int64) Option { 106 | return func(b *Buffer) error { 107 | if m < 0 { 108 | return fmt.Errorf("mem bytes should be >= 0 got %d", m) 109 | } 110 | b.memResponseBodyBytes = m 111 | return nil 112 | } 113 | } 114 | -------------------------------------------------------------------------------- /buffer/retry_test.go: -------------------------------------------------------------------------------- 1 | package buffer 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | "github.com/stretchr/testify/require" 10 | "github.com/vulcand/oxy/v2/forward" 11 | "github.com/vulcand/oxy/v2/roundrobin" 12 | "github.com/vulcand/oxy/v2/testutils" 13 | ) 14 | 15 | func TestBuffer_success(t *testing.T) { 16 | srv := testutils.NewHandler(func(w http.ResponseWriter, _ *http.Request) { 17 | _, _ = w.Write([]byte("hello")) 18 | }) 19 | t.Cleanup(srv.Close) 20 | 21 | lb, rt := newBufferMiddleware(t, `IsNetworkError() && Attempts() <= 2`) 22 | 23 | proxy := httptest.NewServer(rt) 24 | t.Cleanup(proxy.Close) 25 | 26 | require.NoError(t, lb.UpsertServer(testutils.MustParseRequestURI(srv.URL))) 27 | 28 | re, body, err := testutils.Get(proxy.URL) 29 | require.NoError(t, err) 30 | assert.Equal(t, http.StatusOK, re.StatusCode) 31 | assert.Equal(t, "hello", string(body)) 32 | } 33 | 34 | func TestBuffer_retryOnError(t *testing.T) { 35 | srv := testutils.NewHandler(func(w http.ResponseWriter, _ *http.Request) { 36 | _, _ = w.Write([]byte("hello")) 37 | }) 38 | t.Cleanup(srv.Close) 39 | 40 | lb, rt := newBufferMiddleware(t,
`IsNetworkError() && Attempts() <= 2`) 41 | 42 | proxy := httptest.NewServer(rt) 43 | t.Cleanup(proxy.Close) 44 | 45 | require.NoError(t, lb.UpsertServer(testutils.MustParseRequestURI("http://localhost:64321"))) 46 | require.NoError(t, lb.UpsertServer(testutils.MustParseRequestURI(srv.URL))) 47 | 48 | re, body, err := testutils.Get(proxy.URL, testutils.Body("some request parameters")) 49 | require.NoError(t, err) 50 | assert.Equal(t, http.StatusOK, re.StatusCode) 51 | assert.Equal(t, "hello", string(body)) 52 | } 53 | 54 | func TestBuffer_retryExceedAttempts(t *testing.T) { 55 | srv := testutils.NewHandler(func(w http.ResponseWriter, _ *http.Request) { 56 | _, _ = w.Write([]byte("hello")) 57 | }) 58 | t.Cleanup(srv.Close) 59 | 60 | lb, rt := newBufferMiddleware(t, `IsNetworkError() && Attempts() <= 2`) 61 | 62 | proxy := httptest.NewServer(rt) 63 | t.Cleanup(proxy.Close) 64 | 65 | require.NoError(t, lb.UpsertServer(testutils.MustParseRequestURI("http://localhost:64321"))) 66 | require.NoError(t, lb.UpsertServer(testutils.MustParseRequestURI("http://localhost:64322"))) 67 | require.NoError(t, lb.UpsertServer(testutils.MustParseRequestURI("http://localhost:64323"))) 68 | require.NoError(t, lb.UpsertServer(testutils.MustParseRequestURI(srv.URL))) 69 | 70 | re, _, err := testutils.Get(proxy.URL) 71 | require.NoError(t, err) 72 | assert.Equal(t, http.StatusBadGateway, re.StatusCode) 73 | } 74 | 75 | func newBufferMiddleware(t *testing.T, p string) (*roundrobin.RoundRobin, *Buffer) { 76 | t.Helper() 77 | 78 | // forwarder will proxy the request to whatever destination the request URL points to 79 | fwd := forward.New(false) 80 | 81 | // load balancer will round-robin requests across the servers each test upserts 82 | lb, err := roundrobin.New(fwd) 83 | require.NoError(t, err) 84 | 85 | // buffer middleware retries through lb while predicate p matches; MemRequestBodyBytes(1) makes request bodies spill to disk 86 | st, err := New(lb, Retry(p), MemRequestBodyBytes(1)) 87 | require.NoError(t, err) 88 | 89 | return lb, st 90 | } 91 |
-------------------------------------------------------------------------------- /buffer/threshold.go: -------------------------------------------------------------------------------- 1 | package buffer 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | 7 | "github.com/vulcand/predicate" 8 | ) 9 | 10 | type hpredicate func(*context) bool 11 | 12 | // IsValidExpression reports whether expr parses as a valid retry predicate expression. 13 | func IsValidExpression(expr string) bool { 14 | _, err := parseExpression(expr) 15 | return err == nil 16 | } 17 | 18 | // parseExpression parses a Go-syntax boolean expression (e.g. `IsNetworkError() && Attempts() <= 2`) into an hpredicate, using the operators and functions registered below; it fails if the expression does not parse or does not evaluate to an hpredicate. 19 | func parseExpression(in string) (hpredicate, error) { 20 | p, err := predicate.NewParser(predicate.Def{ 21 | Operators: predicate.Operators{ 22 | AND: and, 23 | OR: or, 24 | EQ: eq, 25 | NEQ: neq, 26 | LT: lt, 27 | GT: gt, 28 | LE: le, 29 | GE: ge, 30 | }, 31 | Functions: map[string]interface{}{ 32 | "RequestMethod": requestMethod, 33 | "IsNetworkError": isNetworkError, 34 | "Attempts": attempts, 35 | "ResponseCode": responseCode, 36 | }, 37 | }) 38 | if err != nil { 39 | return nil, err 40 | } 41 | out, err := p.Parse(in) 42 | if err != nil { 43 | return nil, err 44 | } 45 | pr, ok := out.(hpredicate) 46 | if !ok { 47 | return nil, fmt.Errorf("expected predicate, got %T", out) 48 | } 49 | return pr, nil 50 | } 51 | 52 | // isNetworkError returns a predicate that is true when the last attempt's response code is 502 Bad Gateway or 504 Gateway Timeout, which are treated as network errors. 53 | func isNetworkError() hpredicate { 54 | return func(c *context) bool { 55 | return c.responseCode == http.StatusBadGateway || c.responseCode == http.StatusGatewayTimeout 56 | } 57 | } 58 | 59 | // and returns predicate by joining the passed predicates with logical 'and'. 60 | func and(fns ...hpredicate) hpredicate { 61 | return func(c *context) bool { 62 | for _, fn := range fns { 63 | if !fn(c) { 64 | return false 65 | } 66 | } 67 | return true 68 | } 69 | } 70 | 71 | // or returns predicate by joining the passed predicates with logical 'or'.
72 | func or(fns ...hpredicate) hpredicate { 73 | return func(c *context) bool { 74 | for _, fn := range fns { 75 | if fn(c) { 76 | return true 77 | } 78 | } 79 | return false 80 | } 81 | } 82 | 83 | // not creates negation of the passed predicate. 84 | func not(p hpredicate) hpredicate { 85 | return func(c *context) bool { 86 | return !p(c) 87 | } 88 | } 89 | 90 | // eq returns predicate that tests for equality of the value of the mapper and the constant. 91 | func eq(m interface{}, value interface{}) (hpredicate, error) { 92 | switch mapper := m.(type) { 93 | case toString: 94 | return stringEQ(mapper, value) 95 | case toInt: 96 | return intEQ(mapper, value) 97 | } 98 | return nil, fmt.Errorf("unsupported argument: %T", m) 99 | } 100 | 101 | // neq returns predicate that tests for inequality of the value of the mapper and the constant. 102 | func neq(m interface{}, value interface{}) (hpredicate, error) { 103 | p, err := eq(m, value) 104 | if err != nil { 105 | return nil, err 106 | } 107 | return not(p), nil 108 | } 109 | 110 | // lt returns predicate that tests that value of the mapper function is less than the constant. 111 | func lt(m interface{}, value interface{}) (hpredicate, error) { 112 | switch mapper := m.(type) { 113 | case toInt: 114 | return intLT(mapper, value) 115 | default: 116 | return nil, fmt.Errorf("unsupported argument: %T", m) 117 | } 118 | } 119 | 120 | // le returns predicate that tests that value of the mapper function is less or equal than the constant. 121 | func le(m interface{}, value interface{}) (hpredicate, error) { 122 | l, err := lt(m, value) 123 | if err != nil { 124 | return nil, err 125 | } 126 | e, err := eq(m, value) 127 | if err != nil { 128 | return nil, err 129 | } 130 | return func(c *context) bool { 131 | return l(c) || e(c) 132 | }, nil 133 | } 134 | 135 | // gt returns predicate that tests that value of the mapper function is greater than the constant. 
136 | func gt(m interface{}, value interface{}) (hpredicate, error) { 137 | switch mapper := m.(type) { 138 | case toInt: 139 | return intGT(mapper, value) 140 | default: 141 | return nil, fmt.Errorf("unsupported argument: %T", m) 142 | } 143 | } 144 | 145 | // ge returns predicate that tests that value of the mapper function is less or equal than the constant. 146 | func ge(m interface{}, value interface{}) (hpredicate, error) { 147 | g, err := gt(m, value) 148 | if err != nil { 149 | return nil, err 150 | } 151 | e, err := eq(m, value) 152 | if err != nil { 153 | return nil, err 154 | } 155 | return func(c *context) bool { 156 | return g(c) || e(c) 157 | }, nil 158 | } 159 | 160 | func stringEQ(m toString, val interface{}) (hpredicate, error) { 161 | value, ok := val.(string) 162 | if !ok { 163 | return nil, fmt.Errorf("expected string, got %T", val) 164 | } 165 | return func(c *context) bool { 166 | return m(c) == value 167 | }, nil 168 | } 169 | 170 | func intEQ(m toInt, val interface{}) (hpredicate, error) { 171 | value, ok := val.(int) 172 | if !ok { 173 | return nil, fmt.Errorf("expected int, got %T", val) 174 | } 175 | return func(c *context) bool { 176 | return m(c) == value 177 | }, nil 178 | } 179 | 180 | func intLT(m toInt, val interface{}) (hpredicate, error) { 181 | value, ok := val.(int) 182 | if !ok { 183 | return nil, fmt.Errorf("expected int, got %T", val) 184 | } 185 | return func(c *context) bool { 186 | return m(c) < value 187 | }, nil 188 | } 189 | 190 | func intGT(m toInt, val interface{}) (hpredicate, error) { 191 | value, ok := val.(int) 192 | if !ok { 193 | return nil, fmt.Errorf("expected int, got %T", val) 194 | } 195 | return func(c *context) bool { 196 | return m(c) > value 197 | }, nil 198 | } 199 | 200 | type context struct { 201 | r *http.Request 202 | attempt int 203 | responseCode int 204 | } 205 | 206 | type toString func(c *context) string 207 | 208 | type toInt func(c *context) int 209 | 210 | // RequestMethod returns mapper of 
the request to its method e.g. POST. 211 | func requestMethod() toString { 212 | return func(c *context) string { 213 | return c.r.Method 214 | } 215 | } 216 | 217 | // attempts returns a mapper of the request context to its attempt counter (exposed to expressions as Attempts()). 218 | func attempts() toInt { 219 | return func(c *context) int { 220 | return c.attempt 221 | } 222 | } 223 | 224 | // responseCode returns a mapper of the request context to the last response code (exposed to expressions as ResponseCode()); it yields 0 if there was no response code. 225 | func responseCode() toInt { 226 | return func(c *context) int { 227 | return c.responseCode 228 | } 229 | } 230 | -------------------------------------------------------------------------------- /cbreaker/effect.go: -------------------------------------------------------------------------------- 1 | package cbreaker 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | "io" 7 | "net/http" 8 | "net/url" 9 | "strings" 10 | 11 | "github.com/vulcand/oxy/v2/utils" 12 | ) 13 | 14 | // SideEffect is an action that can be executed, e.g. when the CircuitBreaker enters a state (see OnTripped and OnStandby). 15 | type SideEffect interface { 16 | Exec() error 17 | } 18 | 19 | // Webhook describes the HTTP request sent by a WebhookSideEffect: target URL, method, optional headers, and either form values or a raw body. 20 | type Webhook struct { 21 | URL string 22 | Method string 23 | Headers http.Header 24 | Form url.Values 25 | Body []byte 26 | } 27 | 28 | // WebhookSideEffect is a SideEffect that sends the configured Webhook request. 29 | type WebhookSideEffect struct { 30 | w Webhook 31 | 32 | log utils.Logger 33 | } 34 | 35 | // NewWebhookSideEffectsWithLogger creates a new WebhookSideEffect that logs through l; it returns an error if w.Method is empty or if url.Parse rejects w.URL. 36 | func NewWebhookSideEffectsWithLogger(w Webhook, l utils.Logger) (*WebhookSideEffect, error) { 37 | if w.Method == "" { 38 | return nil, errors.New("supply method") 39 | } 40 | _, err := url.Parse(w.URL) 41 | if err != nil { 42 | return nil, err 43 | } 44 | 45 | return &WebhookSideEffect{w: w, log: l}, nil 46 | } 47 | 48 | // NewWebhookSideEffect creates a new WebhookSideEffect that uses a no-op logger.
49 | func NewWebhookSideEffect(w Webhook) (*WebhookSideEffect, error) { 50 | return NewWebhookSideEffectsWithLogger(w, &utils.NoopLogger{}) 51 | } 52 | 53 | func (w *WebhookSideEffect) getBody() io.Reader { 54 | if len(w.w.Form) != 0 { 55 | return strings.NewReader(w.w.Form.Encode()) 56 | } 57 | if len(w.w.Body) != 0 { 58 | return bytes.NewBuffer(w.w.Body) 59 | } 60 | return nil 61 | } 62 | 63 | // Exec executes the side effect: it sends the webhook request (an encoded Form takes precedence over Body and sets an application/x-www-form-urlencoded Content-Type) via http.DefaultClient, reads the whole response body and logs it at debug level. NOTE(review): http.DefaultClient has no timeout, so an unresponsive webhook endpoint can block Exec indefinitely; consider a client with a timeout. 64 | func (w *WebhookSideEffect) Exec() error { 65 | r, err := http.NewRequest(w.w.Method, w.w.URL, w.getBody()) 66 | if err != nil { 67 | return err 68 | } 69 | if len(w.w.Headers) != 0 { 70 | utils.CopyHeaders(r.Header, w.w.Headers) 71 | } 72 | if len(w.w.Form) != 0 { 73 | r.Header.Set("Content-Type", "application/x-www-form-urlencoded") 74 | } 75 | re, err := http.DefaultClient.Do(r) 76 | if err != nil { 77 | return err 78 | } 79 | if re.Body != nil { 80 | defer func() { _ = re.Body.Close() }() 81 | } 82 | body, err := io.ReadAll(re.Body) 83 | if err != nil { 84 | return err 85 | } 86 | w.log.Debug("%v got response: (%s): %s", w, re.Status, string(body)) 87 | return nil 88 | } 89 | -------------------------------------------------------------------------------- /cbreaker/fallback.go: -------------------------------------------------------------------------------- 1 | package cbreaker 2 | 3 | import ( 4 | "net/http" 5 | "net/url" 6 | "strconv" 7 | 8 | "github.com/vulcand/oxy/v2/utils" 9 | ) 10 | 11 | // Response describes the static response written by a ResponseFallback. 12 | type Response struct { 13 | StatusCode int 14 | ContentType string 15 | Body []byte 16 | } 17 | 18 | // ResponseFallback is an http.Handler that always writes the configured Response. 19 | type ResponseFallback struct { 20 | r Response 21 | 22 | debug bool 23 | log utils.Logger 24 | } 25 | 26 | // NewResponseFallback creates a new ResponseFallback for r, applying options in order and returning the first option error.
27 | func NewResponseFallback(r Response, options ...ResponseFallbackOption) (*ResponseFallback, error) { 28 | rf := &ResponseFallback{r: r, log: &utils.NoopLogger{}} 29 | 30 | for _, s := range options { 31 | if err := s(rf); err != nil { 32 | return nil, err 33 | } 34 | } 35 | 36 | return rf, nil 37 | } 38 | 39 | func (f *ResponseFallback) ServeHTTP(w http.ResponseWriter, req *http.Request) { 40 | if f.debug { 41 | dump := utils.DumpHTTPRequest(req) 42 | f.log.Debug("vulcand/oxy/fallback/response: begin ServeHttp on request: %s", dump) 43 | defer f.log.Debug("vulcand/oxy/fallback/response: completed ServeHttp on request: %s", dump) 44 | } 45 | 46 | if f.r.ContentType != "" { 47 | w.Header().Set("Content-Type", f.r.ContentType) 48 | } 49 | w.Header().Set("Content-Length", strconv.Itoa(len(f.r.Body))) 50 | w.WriteHeader(f.r.StatusCode) 51 | _, err := w.Write(f.r.Body) 52 | if err != nil { 53 | f.log.Error("vulcand/oxy/fallback/response: failed to write response, err: %v", err) 54 | } 55 | } 56 | 57 | // Redirect redirect model. 58 | type Redirect struct { 59 | URL string 60 | PreservePath bool 61 | } 62 | 63 | // RedirectFallback fallback redirect handler. 64 | type RedirectFallback struct { 65 | r Redirect 66 | 67 | u *url.URL 68 | 69 | debug bool 70 | log utils.Logger 71 | } 72 | 73 | // NewRedirectFallback creates a new RedirectFallback. 
74 | func NewRedirectFallback(r Redirect, options ...RedirectFallbackOption) (*RedirectFallback, error) { 75 | u, err := url.ParseRequestURI(r.URL) 76 | if err != nil { 77 | return nil, err 78 | } 79 | 80 | rf := &RedirectFallback{r: r, u: u, log: &utils.NoopLogger{}} 81 | 82 | for _, s := range options { 83 | if err := s(rf); err != nil { 84 | return nil, err 85 | } 86 | } 87 | 88 | return rf, nil 89 | } 90 | 91 | // ServeHTTP replies with 302 Found and a Location header pointing at the configured URL. 92 | // When PreservePath is set, the request path is appended to the configured URL's path 93 | // (before any query or fragment of the configured URL), keeping its escaped form. 94 | func (f *RedirectFallback) ServeHTTP(w http.ResponseWriter, req *http.Request) { 95 | if f.debug { 96 | dump := utils.DumpHTTPRequest(req) 97 | f.log.Debug("vulcand/oxy/fallback/redirect: begin ServeHttp on request: %s", dump) 98 | defer f.log.Debug("vulcand/oxy/fallback/redirect: completed ServeHttp on request: %s", dump) 99 | } 100 | 101 | location := f.u.String() 102 | if f.r.PreservePath { 103 | // Work on a copy of the parsed URL: concatenating onto f.u.String() would put the 104 | // path after any query/fragment and would emit the decoded (unescaped) request path. 105 | target := *f.u 106 | target.Path = f.u.Path + req.URL.Path 107 | target.RawPath = f.u.EscapedPath() + req.URL.EscapedPath() 108 | location = target.String() 109 | } 110 | 111 | w.Header().Set("Location", location) 112 | w.WriteHeader(http.StatusFound) 113 | _, err := w.Write([]byte(http.StatusText(http.StatusFound))) 114 | if err != nil { 115 | f.log.Error("vulcand/oxy/fallback/redirect: failed to write response, err: %v", err) 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /cbreaker/options.go: -------------------------------------------------------------------------------- 1 | package cbreaker 2 | 3 | import ( 4 | "net/http" 5 | "time" 6 | 7 | "github.com/vulcand/oxy/v2/utils" 8 | ) 9 | 10 | // Option represents an option you can pass to New. 11 | type Option func(*CircuitBreaker) error 12 | 13 | // Logger defines the logger used by CircuitBreaker. 14 | func Logger(l utils.Logger) Option { 15 | return func(c *CircuitBreaker) error { 16 | c.log = l 17 | return nil 18 | } 19 | } 20 | 21 | // Verbose enables additional debug information.
22 | func Verbose(verbose bool) Option { 23 | return func(c *CircuitBreaker) error { 24 | c.verbose = verbose 25 | return nil 26 | } 27 | } 28 | 29 | // FallbackDuration is how long the CircuitBreaker will remain in the Tripped 30 | // state before trying to recover. 31 | func FallbackDuration(d time.Duration) Option { 32 | return func(c *CircuitBreaker) error { 33 | c.fallbackDuration = d 34 | return nil 35 | } 36 | } 37 | 38 | // RecoveryDuration is how long the CircuitBreaker will take to ramp up 39 | // requests during the Recovering state. 40 | func RecoveryDuration(d time.Duration) Option { 41 | return func(c *CircuitBreaker) error { 42 | c.recoveryDuration = d 43 | return nil 44 | } 45 | } 46 | 47 | // CheckPeriod is how long the CircuitBreaker will wait between successive 48 | // checks of the breaker condition. 49 | func CheckPeriod(d time.Duration) Option { 50 | return func(c *CircuitBreaker) error { 51 | c.checkPeriod = d 52 | return nil 53 | } 54 | } 55 | 56 | // OnTripped sets a SideEffect to run when entering the Tripped state. 57 | // Only one SideEffect can be set for this hook. 58 | func OnTripped(s SideEffect) Option { 59 | return func(c *CircuitBreaker) error { 60 | c.onTripped = s 61 | return nil 62 | } 63 | } 64 | 65 | // OnStandby sets a SideEffect to run when entering the Standby state. 66 | // Only one SideEffect can be set for this hook. 67 | func OnStandby(s SideEffect) Option { 68 | return func(c *CircuitBreaker) error { 69 | c.onStandby = s 70 | return nil 71 | } 72 | } 73 | 74 | // Fallback defines the http.Handler that the CircuitBreaker should route 75 | // requests to when it prevents a request from taking its normal path. 76 | func Fallback(h http.Handler) Option { 77 | return func(c *CircuitBreaker) error { 78 | c.fallback = h 79 | return nil 80 | } 81 | } 82 | 83 | // ResponseFallbackOption represents an option you can pass to NewResponseFallback. 
84 | type ResponseFallbackOption func(*ResponseFallback) error 85 | 86 | // ResponseFallbackLogger defines the logger used by ResponseFallback. 87 | func ResponseFallbackLogger(l utils.Logger) ResponseFallbackOption { 88 | return func(c *ResponseFallback) error { 89 | c.log = l 90 | return nil 91 | } 92 | } 93 | 94 | // ResponseFallbackDebug additional debug information. 95 | func ResponseFallbackDebug(debug bool) ResponseFallbackOption { 96 | return func(c *ResponseFallback) error { 97 | c.debug = debug 98 | return nil 99 | } 100 | } 101 | 102 | // RedirectFallbackOption represents an option you can pass to NewRedirectFallback. 103 | type RedirectFallbackOption func(*RedirectFallback) error 104 | 105 | // RedirectFallbackLogger defines the logger used by ResponseFallback. 106 | func RedirectFallbackLogger(l utils.Logger) RedirectFallbackOption { 107 | return func(c *RedirectFallback) error { 108 | c.log = l 109 | return nil 110 | } 111 | } 112 | 113 | // RedirectFallbackDebug additional debug information. 
114 | func RedirectFallbackDebug(debug bool) RedirectFallbackOption { 115 | return func(c *RedirectFallback) error { 116 | c.debug = debug 117 | return nil 118 | } 119 | } 120 | -------------------------------------------------------------------------------- /cbreaker/predicates_test.go: -------------------------------------------------------------------------------- 1 | package cbreaker 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | "github.com/stretchr/testify/require" 8 | "github.com/vulcand/oxy/v2/internal/holsterv4/clock" 9 | "github.com/vulcand/oxy/v2/memmetrics" 10 | ) 11 | 12 | func Test_parseExpression_tripped(t *testing.T) { 13 | testCases := []struct { 14 | expression string 15 | metrics *memmetrics.RTMetrics 16 | expected bool 17 | }{ 18 | { 19 | expression: "NetworkErrorRatio() > 0.5", 20 | metrics: statsNetErrors(0.6), 21 | expected: true, 22 | }, 23 | { 24 | expression: "NetworkErrorRatio() < 0.5", 25 | metrics: statsNetErrors(0.6), 26 | expected: false, 27 | }, 28 | { 29 | expression: "LatencyAtQuantileMS(50.0) > 50", 30 | metrics: statsLatencyAtQuantile(50, clock.Millisecond*51), 31 | expected: true, 32 | }, 33 | { 34 | expression: "LatencyAtQuantileMS(50.0) < 50", 35 | metrics: statsLatencyAtQuantile(50, clock.Millisecond*51), 36 | expected: false, 37 | }, 38 | { 39 | expression: "ResponseCodeRatio(500, 600, 0, 600) > 0.5", 40 | metrics: statsResponseCodes(statusCode{Code: 200, Count: 5}, statusCode{Code: 500, Count: 6}), 41 | expected: true, 42 | }, 43 | { 44 | expression: "ResponseCodeRatio(500, 600, 0, 600) > 0.5", 45 | metrics: statsResponseCodes(statusCode{Code: 200, Count: 5}, statusCode{Code: 500, Count: 4}), 46 | expected: false, 47 | }, 48 | { 49 | // quantile not defined 50 | expression: "LatencyAtQuantileMS(40.0) > 50", 51 | metrics: statsNetErrors(0.6), 52 | expected: false, 53 | }, 54 | } 55 | 56 | for _, test := range testCases { 57 | t.Run(test.expression, func(t *testing.T) { 58 | t.Parallel() 59 | 
60 | p, err := parseExpression(test.expression) 61 | require.NoError(t, err) 62 | require.NotNil(t, p) 63 | 64 | assert.Equal(t, test.expected, p(&CircuitBreaker{metrics: test.metrics})) 65 | }) 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /cbreaker/ratio.go: -------------------------------------------------------------------------------- 1 | package cbreaker 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | "github.com/vulcand/oxy/v2/internal/holsterv4/clock" 8 | "github.com/vulcand/oxy/v2/utils" 9 | ) 10 | 11 | // ratioController allows passing portions traffic back to the endpoints, 12 | // increasing the amount of passed requests using linear function: 13 | // 14 | // allowedRequestsRatio = 0.5 * (Now() - Start())/Duration 15 | type ratioController struct { 16 | duration time.Duration 17 | start clock.Time 18 | allowed int 19 | denied int 20 | 21 | log utils.Logger 22 | } 23 | 24 | func newRatioController(rampUp time.Duration, log utils.Logger) *ratioController { 25 | return &ratioController{ 26 | duration: rampUp, 27 | start: clock.Now().UTC(), 28 | log: log, 29 | } 30 | } 31 | 32 | func (r *ratioController) String() string { 33 | return fmt.Sprintf("RatioController(target=%f, current=%f, allowed=%d, denied=%d)", r.targetRatio(), r.computeRatio(r.allowed, r.denied), r.allowed, r.denied) 34 | } 35 | 36 | func (r *ratioController) allowRequest() bool { 37 | r.log.Debug("%v", r) 38 | t := r.targetRatio() 39 | // This condition answers the question - would we satisfy the target ratio if we allow this request? 
40 | e := r.computeRatio(r.allowed+1, r.denied) 41 | if e < t { 42 | r.allowed++ 43 | r.log.Debug("%v allowed", r) 44 | return true 45 | } 46 | r.denied++ 47 | r.log.Debug("%v denied", r) 48 | return false 49 | } 50 | 51 | func (r *ratioController) computeRatio(allowed, denied int) float64 { 52 | if denied+allowed == 0 { 53 | return 0 54 | } 55 | return float64(allowed) / float64(denied+allowed) 56 | } 57 | 58 | func (r *ratioController) targetRatio() float64 { 59 | // Here's why it's 0.5: 60 | // We are watching the following ratio: 61 | // ratio = a / (a + d) 62 | // We can notice, that once we get to 0.5: 63 | // 0.5 = a / (a + d) 64 | // we can evaluate that a = d 65 | // that means equilibrium, where we would allow all the requests 66 | // after this point to achieve ratio of 1 (that can never be reached unless d is 0) 67 | // so we stop from there 68 | multiplier := 0.5 / float64(r.duration) 69 | return multiplier * float64(clock.Now().UTC().Sub(r.start)) 70 | } 71 | -------------------------------------------------------------------------------- /cbreaker/ratio_test.go: -------------------------------------------------------------------------------- 1 | package cbreaker 2 | 3 | import ( 4 | "math" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | "github.com/vulcand/oxy/v2/internal/holsterv4/clock" 9 | "github.com/vulcand/oxy/v2/testutils" 10 | "github.com/vulcand/oxy/v2/utils" 11 | ) 12 | 13 | func Test_ratioController_rampUp(t *testing.T) { 14 | testutils.FreezeTime(t) 15 | 16 | duration := 10 * clock.Second 17 | rc := newRatioController(duration, &utils.NoopLogger{}) 18 | 19 | allowed, denied := 0, 0 20 | for range duration / clock.Millisecond { 21 | ratio := sendRequest(&allowed, &denied, rc) 22 | expected := rc.targetRatio() 23 | diff := math.Abs(expected - ratio) 24 | t.Log("Ratio", ratio) 25 | t.Log("Expected", expected) 26 | t.Log("Diff", diff) 27 | assert.EqualValues(t, 0, round(diff, 0.5, 1)) //nolint:testifylint // the rounding is 
already handled. 28 | clock.Advance(clock.Millisecond) 29 | } 30 | } 31 | 32 | func sendRequest(allowed, denied *int, rc *ratioController) float64 { 33 | if rc.allowRequest() { 34 | *allowed++ 35 | } else { 36 | *denied++ 37 | } 38 | if *allowed+*denied == 0 { 39 | return 0 40 | } 41 | return float64(*allowed) / float64(*allowed+*denied) 42 | } 43 | 44 | func round(val float64, roundOn float64, places int) float64 { 45 | pow := math.Pow(10, float64(places)) 46 | digit := pow * val 47 | _, div := math.Modf(digit) 48 | var round float64 49 | if div >= roundOn { 50 | round = math.Ceil(digit) 51 | } else { 52 | round = math.Floor(digit) 53 | } 54 | return round / pow 55 | } 56 | -------------------------------------------------------------------------------- /connlimit/connlimit.go: -------------------------------------------------------------------------------- 1 | // Package connlimit provides control over simultaneous connections coming from the same source 2 | package connlimit 3 | 4 | import ( 5 | "errors" 6 | "fmt" 7 | "net/http" 8 | "sync" 9 | 10 | "github.com/vulcand/oxy/v2/utils" 11 | ) 12 | 13 | // ConnLimiter tracks concurrent connection per token 14 | // and is capable of rejecting connections if they are failed. 15 | type ConnLimiter struct { 16 | mutex *sync.Mutex 17 | extract utils.SourceExtractor 18 | connections map[string]int64 19 | maxConnections int64 20 | totalConnections int64 21 | next http.Handler 22 | 23 | errHandler utils.ErrorHandler 24 | 25 | verbose bool 26 | log utils.Logger 27 | } 28 | 29 | // New creates a new ConnLimiter. 
30 | func New(next http.Handler, extract utils.SourceExtractor, maxConnections int64, options ...Option) (*ConnLimiter, error) { 31 | if extract == nil { 32 | return nil, errors.New("extract function can not be nil") 33 | } 34 | 35 | cl := &ConnLimiter{ 36 | mutex: &sync.Mutex{}, 37 | extract: extract, 38 | maxConnections: maxConnections, 39 | connections: make(map[string]int64), 40 | next: next, 41 | log: &utils.NoopLogger{}, 42 | } 43 | 44 | for _, o := range options { 45 | if err := o(cl); err != nil { 46 | return nil, err 47 | } 48 | } 49 | 50 | if cl.errHandler == nil { 51 | cl.errHandler = &ConnErrHandler{ 52 | debug: cl.verbose, 53 | log: cl.log, 54 | } 55 | } 56 | 57 | return cl, nil 58 | } 59 | 60 | // Wrap sets the next handler to be called by connection limiter handler. 61 | func (cl *ConnLimiter) Wrap(h http.Handler) { 62 | cl.next = h 63 | } 64 | 65 | func (cl *ConnLimiter) ServeHTTP(w http.ResponseWriter, r *http.Request) { 66 | token, amount, err := cl.extract.Extract(r) 67 | if err != nil { 68 | cl.log.Error("failed to extract source of the connection: %v", err) 69 | cl.errHandler.ServeHTTP(w, r, err) 70 | return 71 | } 72 | if err := cl.acquire(token, amount); err != nil { 73 | cl.log.Debug("limiting request source %s: %v", token, err) 74 | cl.errHandler.ServeHTTP(w, r, err) 75 | return 76 | } 77 | 78 | defer cl.release(token, amount) 79 | 80 | cl.next.ServeHTTP(w, r) 81 | } 82 | 83 | func (cl *ConnLimiter) acquire(token string, amount int64) error { 84 | cl.mutex.Lock() 85 | defer cl.mutex.Unlock() 86 | 87 | connections := cl.connections[token] 88 | if connections >= cl.maxConnections { 89 | return &MaxConnError{max: cl.maxConnections} 90 | } 91 | 92 | cl.connections[token] += amount 93 | cl.totalConnections += amount 94 | return nil 95 | } 96 | 97 | func (cl *ConnLimiter) release(token string, amount int64) { 98 | cl.mutex.Lock() 99 | defer cl.mutex.Unlock() 100 | 101 | cl.connections[token] -= amount 102 | cl.totalConnections -= amount 103 | 104 
| // Otherwise it would grow forever 105 | if cl.connections[token] == 0 { 106 | delete(cl.connections, token) 107 | } 108 | } 109 | 110 | // MaxConnError maximum connections reached error. 111 | type MaxConnError struct { 112 | max int64 113 | } 114 | 115 | func (m *MaxConnError) Error() string { 116 | return fmt.Sprintf("max connections reached: %d", m.max) 117 | } 118 | 119 | // ConnErrHandler connection limiter error handler. 120 | type ConnErrHandler struct { 121 | debug bool 122 | log utils.Logger 123 | } 124 | 125 | func (e *ConnErrHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) { 126 | if e.debug { 127 | dump := utils.DumpHTTPRequest(req) 128 | e.log.Debug("vulcand/oxy/connlimit: begin ServeHttp on request: %s", dump) 129 | defer e.log.Debug("vulcand/oxy/connlimit: completed ServeHttp on request: %s", dump) 130 | } 131 | 132 | //nolint:errorlint // must be changed 133 | if _, ok := err.(*MaxConnError); ok { 134 | w.WriteHeader(http.StatusTooManyRequests) 135 | _, _ = w.Write([]byte(err.Error())) 136 | return 137 | } 138 | utils.DefaultHandler.ServeHTTP(w, req, err) 139 | } 140 | -------------------------------------------------------------------------------- /connlimit/connlimit_test.go: -------------------------------------------------------------------------------- 1 | package connlimit 2 | 3 | import ( 4 | "errors" 5 | "net/http" 6 | "net/http/httptest" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/require" 11 | "github.com/vulcand/oxy/v2/testutils" 12 | "github.com/vulcand/oxy/v2/utils" 13 | ) 14 | 15 | // We've hit the limit and were able to proceed once the request has completed. 
16 | func TestConnLimiter_hitLimitAndRelease(t *testing.T) { 17 | wait := make(chan bool) 18 | proceed := make(chan bool) 19 | finish := make(chan bool) 20 | 21 | handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 22 | t.Logf("%v", req.Header) 23 | if req.Header.Get("Wait") != "" { 24 | proceed <- true 25 | <-wait 26 | } 27 | _, _ = w.Write([]byte("hello")) 28 | }) 29 | 30 | cl, err := New(handler, headerLimit, 1) 31 | require.NoError(t, err) 32 | 33 | srv := httptest.NewServer(cl) 34 | t.Cleanup(srv.Close) 35 | 36 | go func() { 37 | re, _, errGet := testutils.Get(srv.URL, testutils.Header("Limit", "a"), testutils.Header("wait", "yes")) 38 | require.NoError(t, errGet) 39 | assert.Equal(t, http.StatusOK, re.StatusCode) 40 | finish <- true 41 | }() 42 | 43 | <-proceed 44 | 45 | re, _, err := testutils.Get(srv.URL, testutils.Header("Limit", "a")) 46 | require.NoError(t, err) 47 | assert.Equal(t, http.StatusTooManyRequests, re.StatusCode) 48 | 49 | // request from another source succeeds 50 | re, _, err = testutils.Get(srv.URL, testutils.Header("Limit", "b")) 51 | require.NoError(t, err) 52 | assert.Equal(t, http.StatusOK, re.StatusCode) 53 | 54 | // Once the first request finished, next one succeeds 55 | close(wait) 56 | <-finish 57 | 58 | re, _, err = testutils.Get(srv.URL, testutils.Header("Limit", "a")) 59 | require.NoError(t, err) 60 | assert.Equal(t, http.StatusOK, re.StatusCode) 61 | } 62 | 63 | // We've hit the limit and were able to proceed once the request has completed. 
64 | func TestConnLimiter_customHandlers(t *testing.T) { 65 | handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 66 | _, _ = w.Write([]byte("hello")) 67 | }) 68 | 69 | errHandler := utils.ErrorHandlerFunc(func(w http.ResponseWriter, _ *http.Request, _ error) { 70 | w.WriteHeader(http.StatusTeapot) 71 | _, _ = w.Write([]byte(http.StatusText(http.StatusTeapot))) 72 | }) 73 | 74 | l, err := New(handler, headerLimit, 0, ErrorHandler(errHandler)) 75 | require.NoError(t, err) 76 | 77 | srv := httptest.NewServer(l) 78 | t.Cleanup(srv.Close) 79 | 80 | re, _, err := testutils.Get(srv.URL, testutils.Header("Limit", "a")) 81 | require.NoError(t, err) 82 | assert.Equal(t, http.StatusTeapot, re.StatusCode) 83 | } 84 | 85 | // We've hit the limit and were able to proceed once the request has completed. 86 | func TestConnLimiter_faultyExtract(t *testing.T) { 87 | handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 88 | _, _ = w.Write([]byte("hello")) 89 | }) 90 | 91 | l, err := New(handler, faultyExtract, 1) 92 | require.NoError(t, err) 93 | 94 | srv := httptest.NewServer(l) 95 | t.Cleanup(srv.Close) 96 | 97 | re, _, err := testutils.Get(srv.URL) 98 | require.NoError(t, err) 99 | assert.Equal(t, http.StatusInternalServerError, re.StatusCode) 100 | } 101 | 102 | func headerLimiter(req *http.Request) (string, int64, error) { 103 | return req.Header.Get("Limit"), 1, nil 104 | } 105 | 106 | func faultyExtractor(_ *http.Request) (string, int64, error) { 107 | return "", -1, errors.New("oops") 108 | } 109 | 110 | var headerLimit = utils.ExtractorFunc(headerLimiter) 111 | 112 | var faultyExtract = utils.ExtractorFunc(faultyExtractor) 113 | -------------------------------------------------------------------------------- /connlimit/options.go: -------------------------------------------------------------------------------- 1 | package connlimit 2 | 3 | import ( 4 | "github.com/vulcand/oxy/v2/utils" 5 | ) 6 | 7 | // Option represents an option 
you can pass to New. 8 | type Option func(l *ConnLimiter) error 9 | 10 | // Logger defines the logger used by ConnLimiter. 11 | func Logger(l utils.Logger) Option { 12 | return func(cl *ConnLimiter) error { 13 | cl.log = l 14 | return nil 15 | } 16 | } 17 | 18 | // Verbose additional debug information. 19 | func Verbose(verbose bool) Option { 20 | return func(cl *ConnLimiter) error { 21 | cl.verbose = verbose 22 | return nil 23 | } 24 | } 25 | 26 | // ErrorHandler sets error handler of the server. 27 | func ErrorHandler(h utils.ErrorHandler) Option { 28 | return func(cl *ConnLimiter) error { 29 | cl.errHandler = h 30 | return nil 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /forward/example_test.go: -------------------------------------------------------------------------------- 1 | package forward 2 | 3 | import ( 4 | "crypto/tls" 5 | "fmt" 6 | "io" 7 | "net/http" 8 | "net/http/httptest" 9 | "net/url" 10 | ) 11 | 12 | func ExampleNew_customErrHandler() { 13 | f := New(true) 14 | f.ErrorHandler = func(w http.ResponseWriter, _ *http.Request, _ error) { 15 | w.WriteHeader(http.StatusTeapot) 16 | _, _ = w.Write([]byte(http.StatusText(http.StatusTeapot))) 17 | } 18 | 19 | proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 20 | req.URL, _ = url.ParseRequestURI("http://localhost:63450") 21 | f.ServeHTTP(w, req) 22 | })) 23 | defer proxy.Close() 24 | 25 | resp, err := http.Get(proxy.URL) 26 | if err != nil { 27 | fmt.Println(err) 28 | return 29 | } 30 | 31 | body, err := io.ReadAll(resp.Body) 32 | if err != nil { 33 | fmt.Println(err) 34 | return 35 | } 36 | 37 | fmt.Println(resp.StatusCode) 38 | fmt.Println(string(body)) 39 | 40 | // output: 41 | // 418 42 | // I'm a teapot 43 | } 44 | 45 | func ExampleNew_responseModifier() { 46 | srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 47 | _, _ = w.Write([]byte("hello")) 48 | })) 49 | 
defer srv.Close() 50 | 51 | f := New(true) 52 | f.ModifyResponse = func(resp *http.Response) error { 53 | resp.Header.Add("X-Test", "CUSTOM") 54 | return nil 55 | } 56 | 57 | proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 58 | req.URL, _ = url.ParseRequestURI(srv.URL) 59 | f.ServeHTTP(w, req) 60 | })) 61 | defer proxy.Close() 62 | 63 | resp, err := http.Get(proxy.URL) 64 | if err != nil { 65 | fmt.Println(err) 66 | return 67 | } 68 | 69 | fmt.Println(resp.StatusCode) 70 | fmt.Println(resp.Header.Get("X-Test")) 71 | 72 | // Output: 73 | // 200 74 | // CUSTOM 75 | } 76 | 77 | func ExampleNew_customTransport() { 78 | srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 79 | _, _ = w.Write([]byte("hello")) 80 | })) 81 | defer srv.Close() 82 | 83 | f := New(true) 84 | 85 | f.Transport = &http.Transport{ 86 | TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, 87 | } 88 | 89 | proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 90 | req.URL, _ = url.ParseRequestURI(srv.URL) 91 | f.ServeHTTP(w, req) 92 | })) 93 | defer proxy.Close() 94 | 95 | resp, err := http.Get(proxy.URL) 96 | if err != nil { 97 | fmt.Println(err) 98 | return 99 | } 100 | 101 | body, err := io.ReadAll(resp.Body) 102 | if err != nil { 103 | fmt.Println(err) 104 | return 105 | } 106 | 107 | fmt.Println(resp.StatusCode) 108 | fmt.Println(string(body)) 109 | 110 | // Output: 111 | // 200 112 | // hello 113 | } 114 | 115 | func ExampleNewStateListener() { 116 | srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 117 | _, _ = w.Write([]byte("hello")) 118 | })) 119 | defer srv.Close() 120 | 121 | f := New(true) 122 | f.ModifyResponse = func(resp *http.Response) error { 123 | resp.Header.Add("X-Test", "CUSTOM") 124 | return nil 125 | } 126 | 127 | stateLn := NewStateListener(f, func(u *url.URL, i int) { 128 | fmt.Println(u.Hostname(), i) 129 
| }) 130 | 131 | proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 132 | req.URL, _ = url.ParseRequestURI(srv.URL) 133 | stateLn.ServeHTTP(w, req) 134 | })) 135 | defer proxy.Close() 136 | 137 | resp, err := http.Get(proxy.URL) 138 | if err != nil { 139 | fmt.Println(err) 140 | return 141 | } 142 | 143 | fmt.Println(resp.StatusCode) 144 | 145 | // Output: 146 | // 127.0.0.1 0 147 | // 127.0.0.1 1 148 | // 200 149 | } 150 | -------------------------------------------------------------------------------- /forward/fwd.go: -------------------------------------------------------------------------------- 1 | // Package forward creates a pre-configured httputil.ReverseProxy. 2 | package forward 3 | 4 | import ( 5 | "net/http" 6 | "net/http/httputil" 7 | "net/url" 8 | 9 | "github.com/vulcand/oxy/v2/utils" 10 | ) 11 | 12 | // New creates a new ReverseProxy. 13 | func New(passHostHeader bool) *httputil.ReverseProxy { 14 | h := NewHeaderRewriter() 15 | 16 | return &httputil.ReverseProxy{ 17 | Director: func(request *http.Request) { 18 | modifyRequest(request) 19 | 20 | h.Rewrite(request) 21 | 22 | if !passHostHeader { 23 | request.Host = request.URL.Host 24 | } 25 | }, 26 | ErrorHandler: utils.DefaultHandler.ServeHTTP, 27 | } 28 | } 29 | 30 | // Modify the request to handle the target URL. 31 | func modifyRequest(outReq *http.Request) { 32 | u := getURLFromRequest(outReq) 33 | 34 | outReq.URL.Path = u.Path 35 | outReq.URL.RawPath = u.RawPath 36 | outReq.URL.RawQuery = u.RawQuery 37 | outReq.RequestURI = "" // Outgoing request should not have RequestURI 38 | 39 | outReq.Proto = "HTTP/1.1" 40 | outReq.ProtoMajor = 1 41 | outReq.ProtoMinor = 1 42 | } 43 | 44 | func getURLFromRequest(req *http.Request) *url.URL { 45 | // If the Request was created by Go via a real HTTP request, 46 | // RequestURI will contain the original query string. 
47 | // If the Request was created in code, 48 | // RequestURI will be empty, and we will use the URL object instead 49 | u := req.URL 50 | if req.RequestURI != "" { 51 | parsedURL, err := url.ParseRequestURI(req.RequestURI) 52 | if err == nil { 53 | return parsedURL 54 | } 55 | } 56 | return u 57 | } 58 | -------------------------------------------------------------------------------- /forward/fwd_test.go: -------------------------------------------------------------------------------- 1 | package forward 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "net/url" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/require" 11 | "github.com/vulcand/oxy/v2/testutils" 12 | ) 13 | 14 | func TestDefaultErrHandler(t *testing.T) { 15 | f := New(true) 16 | 17 | proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 18 | req.URL = testutils.MustParseRequestURI("http://localhost:63450") 19 | f.ServeHTTP(w, req) 20 | })) 21 | t.Cleanup(proxy.Close) 22 | 23 | resp, err := http.Get(proxy.URL) 24 | require.NoError(t, err) 25 | assert.Equal(t, http.StatusBadGateway, resp.StatusCode) 26 | } 27 | 28 | func TestXForwardedHostHeader(t *testing.T) { 29 | tests := []struct { 30 | Description string 31 | PassHostHeader bool 32 | TargetURL string 33 | ProxyfiedURL string 34 | ExpectedXForwardedHost string 35 | }{ 36 | { 37 | Description: "XForwardedHost without PassHostHeader", 38 | PassHostHeader: false, 39 | TargetURL: "http://xforwardedhost.com", 40 | ProxyfiedURL: "http://backend.com", 41 | ExpectedXForwardedHost: "xforwardedhost.com", 42 | }, 43 | { 44 | Description: "XForwardedHost with PassHostHeader", 45 | PassHostHeader: true, 46 | TargetURL: "http://xforwardedhost.com", 47 | ProxyfiedURL: "http://backend.com", 48 | ExpectedXForwardedHost: "xforwardedhost.com", 49 | }, 50 | } 51 | 52 | for _, test := range tests { 53 | t.Run(test.Description, func(t *testing.T) { 54 | t.Parallel() 55 | 56 
| f := New(true) 57 | 58 | r, err := http.NewRequest(http.MethodGet, test.TargetURL, nil) 59 | require.NoError(t, err) 60 | 61 | backendURL, err := url.Parse(test.ProxyfiedURL) 62 | require.NoError(t, err) 63 | r.URL = backendURL 64 | 65 | f.Director(r) 66 | require.Equal(t, test.ExpectedXForwardedHost, r.Header.Get(XForwardedHost)) 67 | }) 68 | } 69 | } 70 | 71 | func TestForwardedProto(t *testing.T) { 72 | var proto string 73 | srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 74 | proto = req.Header.Get(XForwardedProto) 75 | _, _ = w.Write([]byte("hello")) 76 | })) 77 | t.Cleanup(srv.Close) 78 | 79 | f := New(true) 80 | 81 | proxy := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 82 | req.URL = testutils.MustParseRequestURI(srv.URL) 83 | f.ServeHTTP(w, req) 84 | })) 85 | proxy.StartTLS() 86 | t.Cleanup(proxy.Close) 87 | 88 | re, _, err := testutils.Get(proxy.URL) 89 | require.NoError(t, err) 90 | 91 | assert.Equal(t, http.StatusOK, re.StatusCode) 92 | assert.Equal(t, "https", proto) 93 | } 94 | -------------------------------------------------------------------------------- /forward/headers.go: -------------------------------------------------------------------------------- 1 | package forward 2 | 3 | // X-* Header names. 4 | const ( 5 | XForwardedProto = "X-Forwarded-Proto" 6 | XForwardedFor = "X-Forwarded-For" 7 | XForwardedHost = "X-Forwarded-Host" 8 | XForwardedPort = "X-Forwarded-Port" 9 | XForwardedServer = "X-Forwarded-Server" 10 | XRealIP = "X-Real-Ip" 11 | ) 12 | 13 | // Headers names. 
const (
	Connection         = "Connection"
	KeepAlive          = "Keep-Alive"
	ProxyAuthenticate  = "Proxy-Authenticate"
	ProxyAuthorization = "Proxy-Authorization"
	Te                 = "Te" // canonicalized version of "TE"
	Trailers           = "Trailers"
	TransferEncoding   = "Transfer-Encoding"
	Upgrade            = "Upgrade"
	ContentLength      = "Content-Length"
)

// WebSocket Header names.
const (
	SecWebsocketKey        = "Sec-Websocket-Key"
	SecWebsocketVersion    = "Sec-Websocket-Version"
	SecWebsocketExtensions = "Sec-Websocket-Extensions"
	SecWebsocketAccept     = "Sec-Websocket-Accept"
)

// XHeaders lists the X-* forwarding headers. HeaderRewriter.Rewrite strips
// them from incoming requests when TrustForwardHeader is false.
var XHeaders = []string{
	XForwardedProto,
	XForwardedFor,
	XForwardedHost,
	XForwardedPort,
	XForwardedServer,
	XRealIP,
}

// HopHeaders lists hop-by-hop headers, which a proxy is expected not to forward
// to the backend (RFC 2616 §13.5.1, since superseded by RFC 7230 §6.1).
// The list was copied from net/http/httputil's reverseproxy.go.
//
// NOTE(review): httputil's current list uses "Trailer" (per RFC 2616 erratum
// 4522) where this one uses "Trailers", and it also includes "Proxy-Connection".
// The values are kept as-is for compatibility; confirm before changing them.
var HopHeaders = []string{
	Connection,
	KeepAlive,
	ProxyAuthenticate,
	ProxyAuthorization,
	Te, // canonicalized version of "TE"
	Trailers,
	TransferEncoding,
	Upgrade,
}

// WebsocketDialHeaders lists the WebSocket headers relevant when dialing a
// backend. Its consumer is not visible here; presumably it is used to filter
// request headers. TODO confirm.
var WebsocketDialHeaders = []string{
	Upgrade,
	Connection,
	SecWebsocketKey,
	SecWebsocketVersion,
	SecWebsocketExtensions,
	SecWebsocketAccept,
}

// WebsocketUpgradeHeaders lists the WebSocket headers relevant to an upgrade
// response. Its consumer is not visible here. TODO confirm.
var WebsocketUpgradeHeaders = []string{
	Upgrade,
	Connection,
	SecWebsocketAccept,
	SecWebsocketExtensions,
}

--------------------------------------------------------------------------------
/forward/middlewares.go:
--------------------------------------------------------------------------------
package forward

import (
	"net/http"
	"net/url"
)

// Connection states reported to a URLForwardingStateListener.
9 | const ( 10 | StateConnected = iota 11 | StateDisconnected 12 | ) 13 | 14 | // URLForwardingStateListener URL forwarding state listener. 15 | type URLForwardingStateListener func(*url.URL, int) 16 | 17 | // StateListener listens on state change for urls. 18 | type StateListener struct { 19 | next http.Handler 20 | stateListener URLForwardingStateListener 21 | } 22 | 23 | // NewStateListener creates a new StateListener middleware. 24 | func NewStateListener(next http.Handler, stateListener URLForwardingStateListener) *StateListener { 25 | return &StateListener{next: next, stateListener: stateListener} 26 | } 27 | 28 | func (s *StateListener) ServeHTTP(rw http.ResponseWriter, req *http.Request) { 29 | s.stateListener(req.URL, StateConnected) 30 | s.next.ServeHTTP(rw, req) 31 | s.stateListener(req.URL, StateDisconnected) 32 | } 33 | -------------------------------------------------------------------------------- /forward/rewrite.go: -------------------------------------------------------------------------------- 1 | package forward 2 | 3 | import ( 4 | "net" 5 | "net/http" 6 | "os" 7 | "strings" 8 | 9 | "github.com/vulcand/oxy/v2/utils" 10 | ) 11 | 12 | // NewHeaderRewriter creates a new HeaderRewriter middleware. 13 | func NewHeaderRewriter() *HeaderRewriter { 14 | h, err := os.Hostname() 15 | if err != nil { 16 | h = "localhost" 17 | } 18 | return &HeaderRewriter{TrustForwardHeader: true, Hostname: h} 19 | } 20 | 21 | // HeaderRewriter is responsible for removing hop-by-hop headers and setting forwarding headers. 22 | type HeaderRewriter struct { 23 | TrustForwardHeader bool 24 | Hostname string 25 | } 26 | 27 | // clean up IP in case if it is ipv6 address and it has {zone} information in it, like "[fe80::d806:a55d:eb1b:49cc%vEthernet (vmxnet3 Ethernet Adapter - Virtual Switch)]:64692". 28 | func ipv6fix(clientIP string) string { 29 | return strings.Split(clientIP, "%")[0] 30 | } 31 | 32 | // Rewrite request headers. 
33 | func (rw *HeaderRewriter) Rewrite(req *http.Request) { 34 | if !rw.TrustForwardHeader { // untrusted: drop any client-supplied XHeaders before setting our own 35 | utils.RemoveHeaders(req.Header, XHeaders...) 36 | } 37 | 38 | if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { // XRealIP is skipped when RemoteAddr is not in host:port form 39 | clientIP = ipv6fix(clientIP) 40 | 41 | if req.Header.Get(XRealIP) == "" { // never overwrite an existing XRealIP value 42 | req.Header.Set(XRealIP, clientIP) 43 | } 44 | } 45 | 46 | xfProto := req.Header.Get(XForwardedProto) 47 | if xfProto == "" { // infer the scheme from whether the inbound connection used TLS 48 | if req.TLS != nil { 49 | req.Header.Set(XForwardedProto, "https") 50 | } else { 51 | req.Header.Set(XForwardedProto, "http") 52 | } 53 | } 54 | 55 | if xfPort := req.Header.Get(XForwardedPort); xfPort == "" { // must run after XForwardedProto is set: forwardedPort reads it 56 | req.Header.Set(XForwardedPort, forwardedPort(req)) 57 | } 58 | 59 | if xfHost := req.Header.Get(XForwardedHost); xfHost == "" && req.Host != "" { 60 | req.Header.Set(XForwardedHost, req.Host) 61 | } 62 | 63 | if rw.Hostname != "" { // XForwardedServer is always overwritten, unlike the headers above 64 | req.Header.Set(XForwardedServer, rw.Hostname) 65 | } 66 | } 67 | 68 | func forwardedPort(req *http.Request) string { // precedence: explicit port in Host, then https/wss XForwardedProto, then TLS, else "80" 69 | if req == nil { 70 | return "" 71 | } 72 | 73 | if _, port, err := net.SplitHostPort(req.Host); err == nil && port != "" { 74 | return port 75 | } 76 | 77 | if req.Header.Get(XForwardedProto) == "https" || req.Header.Get(XForwardedProto) == "wss" { 78 | return "443" 79 | } 80 | 81 | if req.TLS != nil { 82 | return "443" 83 | } 84 | 85 | return "80" 86 | } 87 | -------------------------------------------------------------------------------- /forward/rewrite_test.go: -------------------------------------------------------------------------------- 1 | package forward 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func Test_ipv6fix(t *testing.T) { 10 | testCases := []struct { 11 | desc string 12 | clientIP string 13 | expected string 14 | }{ 15 | { 16 | desc: "empty", 17 | clientIP: "", 18 | expected: "", 19 | }, 20 | { 21 | desc: "ipv4 localhost", 22 | clientIP: "127.0.0.1", 23 | expected: "127.0.0.1", 24 | }, 25 | { 26 | desc:
"ipv4", 27 | clientIP: "10.13.14.15", 28 | expected: "10.13.14.15", 29 | }, 30 | { 31 | desc: "ipv6 zone", 32 | clientIP: `fe80::d806:a55d:eb1b:49cc%vEthernet (vmxnet3 Ethernet Adapter - Virtual Switch)`, 33 | expected: "fe80::d806:a55d:eb1b:49cc", 34 | }, 35 | { 36 | desc: "ipv6 medium", 37 | clientIP: `fe80::1`, 38 | expected: "fe80::1", 39 | }, 40 | { 41 | desc: "ipv6 small", 42 | clientIP: `2000::`, 43 | expected: "2000::", 44 | }, 45 | { 46 | desc: "ipv6", 47 | clientIP: `2001:3452:4952:2837::`, 48 | expected: "2001:3452:4952:2837::", 49 | }, 50 | } 51 | 52 | for _, test := range testCases { 53 | t.Run(test.desc, func(t *testing.T) { 54 | t.Parallel() 55 | 56 | actual := ipv6fix(test.clientIP) 57 | assert.Equal(t, test.expected, actual) 58 | }) 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/vulcand/oxy/v2 2 | 3 | go 1.23.0 4 | 5 | require ( 6 | github.com/HdrHistogram/hdrhistogram-go v1.1.2 7 | github.com/gorilla/websocket v1.5.3 8 | github.com/mailgun/multibuf v0.1.2 9 | github.com/segmentio/fasthash v1.0.3 10 | github.com/stretchr/testify v1.10.0 11 | github.com/vulcand/predicate v1.2.0 12 | golang.org/x/net v0.37.0 13 | ) 14 | 15 | require ( 16 | github.com/davecgh/go-spew v1.1.1 // indirect 17 | github.com/gravitational/trace v1.1.16-0.20220114165159-14a9a7dd6aaf // indirect 18 | github.com/jonboulle/clockwork v0.4.0 // indirect 19 | github.com/pmezard/go-difflib v1.0.0 // indirect 20 | github.com/sirupsen/logrus v1.9.3 // indirect 21 | golang.org/x/crypto v0.36.0 // indirect 22 | golang.org/x/sys v0.31.0 // indirect 23 | golang.org/x/term v0.30.0 // indirect 24 | gopkg.in/yaml.v3 v3.0.1 // indirect 25 | ) 26 | -------------------------------------------------------------------------------- /internal/holsterv4/README.md:
-------------------------------------------------------------------------------- 1 | # What is this? 2 | 3 | This is a vendored copy of two packages (`clock` and `collections`) from the 4 | github.com/mailgun/holster@v4.2.5 module. 5 | 6 | The `clock` package was completely copied over and the following modifications 7 | were made: 8 | 9 | * pkg/errors was replaced with the stdlib errors package / fmt.Errorf's %w; 10 | * import names changed in blackbox test packages; 11 | * a small race condition in the testing logic was fixed using the provided 12 | mutex. 13 | 14 | The `collections` package only contains the priority_queue and ttlmap and 15 | corresponding test files. The only changes made to those files were to adjust 16 | the package names to use the vendored packages. 17 | 18 | ## Why 19 | 20 | TL;DR: holster is a utility repo with many dependencies, and even with graph 21 | pruning, using it in oxy can transitively impact oxy users in negative ways by 22 | forcing version bumps (at the least). 23 | 24 | Full details can be found here: https://github.com/vulcand/oxy/pull/223 25 | -------------------------------------------------------------------------------- /internal/holsterv4/clock/README.md: -------------------------------------------------------------------------------- 1 | # Clock 2 | 3 | An (almost) drop-in replacement for the system `time` package. It provides a way 4 | to make scheduled calls, timers and tickers deterministic in tests. By default 5 | it forwards all calls to the system `time` package. In tests, however, it is 6 | possible to enable the frozen clock mode, and advance time manually to make 7 | scheduled events trigger at certain moments. 8 | 9 | # Usage 10 | 11 | ```go 12 | package foo 13 | 14 | import ( 15 | "testing" 16 | 17 | "github.com/stretchr/testify/assert" 18 | "github.com/vulcand/oxy/v2/internal/holsterv4/clock" 19 | ) 20 | 21 | func TestSleep(t *testing.T) { 22 | // Freeze switches the clock package to the frozen clock mode.
You need to 23 | // advance time manually from now on. Note that all scheduled events, timers 24 | // and tickers created before this call keep operating in real time. 25 | // 26 | // The initial time is set to now here, but you can set any datetime. 27 | clock.Freeze(clock.Now()) 28 | // Do not forget to revert the effect of Freeze at the end of the test. 29 | defer clock.Unfreeze() 30 | 31 | var fired bool 32 | 33 | clock.AfterFunc(100*clock.Millisecond, func() { 34 | fired = true 35 | }) 36 | clock.Advance(93*clock.Millisecond) 37 | 38 | // Advance fires all events, timers and tickers that are scheduled for the 39 | // elapsed period. Note that scheduled functions are called from within 40 | // Advance, unlike the system time package, which calls them in their own 41 | // goroutine. The timer above is due at 100ms, so at 99ms it has not fired yet. 42 | assert.Equal(t, 99*clock.Millisecond, clock.Advance(6*clock.Millisecond)) 43 | assert.False(t, fired) 44 | assert.Equal(t, 100*clock.Millisecond, clock.Advance(1*clock.Millisecond)) 45 | assert.True(t, fired) 46 | } 47 | ``` 48 | -------------------------------------------------------------------------------- /internal/holsterv4/clock/clock.go: -------------------------------------------------------------------------------- 1 | //go:build !holster_test_mode 2 | 3 | // Package clock provides the same functions as the system package time. In 4 | // production it forwards all calls to the system time package, but in tests 5 | // the time can be frozen by calling Freeze function and from that point it has 6 | // to be advanced manually with Advance function making all scheduled calls 7 | // deterministic. 8 | // 9 | // The functions provided by the package have the same parameters and return 10 | // values as their system counterparts with a few exceptions. Where either 11 | // *time.Timer or *time.Ticker is returned by a system function, the clock 12 | // package counterpart returns clock.Timer or clock.Ticker interface 13 | // respectively.
The interfaces provide API as respective structs except C is 14 | // not a channel, but a function that returns <-chan time.Time. 15 | package clock 16 | 17 | import "time" 18 | 19 | var ( // unsynchronized here; the holster_test_mode build (clock_mutex.go) guards these with a mutex 20 | frozenAt time.Time 21 | realtime = &systemTime{} 22 | provider Clock = realtime 23 | ) 24 | 25 | // Freeze switches all time-related functions to deterministic timers that are 26 | // triggered by the Advance function. It is supposed to be used in tests only. 27 | // It returns an Unfreezer so it can be a one-liner in tests: 28 | // defer clock.Freeze(clock.Now()).Unfreeze() 29 | func Freeze(now time.Time) Unfreezer { 30 | frozenAt = now.UTC() 31 | provider = &frozenTime{now: now} 32 | return Unfreezer{} 33 | } 34 | 35 | type Unfreezer struct{} // returned by Freeze; its Unfreeze method reverts the freeze 36 | 37 | func (u Unfreezer) Unfreeze() { // same as the package-level Unfreeze 38 | Unfreeze() 39 | } 40 | 41 | // Unfreeze reverses effect of Freeze. 42 | func Unfreeze() { 43 | provider = realtime 44 | } 45 | 46 | // Realtime returns a clock provider wrapping the SDK's time package. It is 47 | // supposed to be used in tests when time is frozen to schedule test timeouts. 48 | func Realtime() Clock { 49 | return realtime 50 | } 51 | 52 | // Advance moves the deterministic time forward by d, firing timers along the 53 | // way in their natural order. It returns how much time has passed since 54 | // Freeze, so tests can assert on the return value to make explicit where they 55 | // stand on the deterministic time scale. It panics if time is not frozen. 56 | func Advance(d time.Duration) time.Duration { 57 | ft, ok := provider.(*frozenTime) 58 | if !ok { 59 | panic("Freeze time first!") 60 | } 61 | ft.advance(d) 62 | return Now().UTC().Sub(frozenAt) 63 | } 64 | 65 | // Wait4Scheduled blocks until either there are count or more scheduled events, or 66 | // the timeout elapses. It returns true if the wait condition has been met 67 | // before the timeout expired, false otherwise.
68 | func Wait4Scheduled(count int, timeout time.Duration) bool { 69 | return provider.Wait4Scheduled(count, timeout) // the real-time provider panics ("Not supported"); freeze first 70 | } 71 | 72 | // Now see time.Now. 73 | func Now() time.Time { 74 | return provider.Now() 75 | } 76 | 77 | // Sleep see time.Sleep. 78 | func Sleep(d time.Duration) { 79 | provider.Sleep(d) 80 | } 81 | 82 | // After see time.After. 83 | func After(d time.Duration) <-chan time.Time { 84 | return provider.After(d) 85 | } 86 | 87 | // NewTimer see time.NewTimer. 88 | func NewTimer(d time.Duration) Timer { 89 | return provider.NewTimer(d) 90 | } 91 | 92 | // AfterFunc see time.AfterFunc. 93 | func AfterFunc(d time.Duration, f func()) Timer { 94 | return provider.AfterFunc(d, f) 95 | } 96 | 97 | // NewTicker see time.NewTicker. 98 | func NewTicker(d time.Duration) Ticker { 99 | return provider.NewTicker(d) 100 | } 101 | 102 | // Tick see time.Tick. 103 | func Tick(d time.Duration) <-chan time.Time { 104 | return provider.Tick(d) 105 | } 106 | -------------------------------------------------------------------------------- /internal/holsterv4/clock/clock_mutex.go: -------------------------------------------------------------------------------- 1 | //go:build holster_test_mode 2 | 3 | // Package clock provides the same functions as the system package time. In 4 | // production it forwards all calls to the system time package, but in tests 5 | // the time can be frozen by calling Freeze function and from that point it has 6 | // to be advanced manually with Advance function making all scheduled calls 7 | // deterministic. 8 | // 9 | // The functions provided by the package have the same parameters and return 10 | // values as their system counterparts with a few exceptions. Where either 11 | // *time.Timer or *time.Ticker is returned by a system function, the clock 12 | // package counterpart returns clock.Timer or clock.Ticker interface 13 | // respectively.
The interfaces provide API as respective structs except C is 14 | // not a channel, but a function that returns <-chan time.Time. 15 | package clock 16 | 17 | import ( 18 | "sync" 19 | "time" 20 | ) 21 | 22 | var ( 23 | frozenAt time.Time 24 | realtime = &systemTime{} 25 | provider Clock = realtime 26 | rwMutex = sync.RWMutex{} // guards frozenAt and provider 27 | ) 28 | 29 | // Freeze makes all time-related functions generate deterministic timers that are 30 | // triggered by the Advance function. It is supposed to be used in tests only. It returns 31 | // an Unfreezer so it can be a one-liner in tests: defer clock.Freeze(clock.Now()).Unfreeze() 32 | func Freeze(now time.Time) Unfreezer { 33 | rwMutex.Lock() 34 | defer rwMutex.Unlock() 35 | frozenAt = now.UTC() 36 | provider = &frozenTime{now: now} 37 | return Unfreezer{} 38 | } 39 | 40 | type Unfreezer struct{} 41 | 42 | func (u Unfreezer) Unfreeze() { 43 | Unfreeze() 44 | } 45 | 46 | // Unfreeze reverses effect of Freeze. 47 | func Unfreeze() { 48 | rwMutex.Lock() 49 | defer rwMutex.Unlock() 50 | provider = realtime 51 | } 52 | 53 | // Realtime returns the provider backed by the SDK's time package, e.g. to schedule test timeouts while time is frozen. 54 | func Realtime() Clock { 55 | return realtime 56 | } 57 | 58 | // Advance moves the deterministic time forward by d, firing timers along the way in 59 | // their natural order. It returns how much time has passed since Freeze, so tests 60 | // can assert on it to make explicit where they stand on the deterministic time scale.
61 | func Advance(d time.Duration) time.Duration { 62 | rwMutex.RLock() 63 | ft, ok := provider.(*frozenTime) 64 | start := frozenAt // read under the lock: Freeze writes it while holding rwMutex 65 | rwMutex.RUnlock() 66 | if !ok { 67 | panic("Freeze time first!") 68 | } 69 | ft.advance(d) 70 | return Now().UTC().Sub(start) 71 | } 72 | 73 | // Wait4Scheduled blocks until either there are count or more scheduled events, or 74 | // the timeout elapses. It returns true if the wait condition has been met 75 | // before the timeout expired, false otherwise. 76 | func Wait4Scheduled(count int, timeout time.Duration) bool { 77 | rwMutex.RLock() 78 | p := provider 79 | rwMutex.RUnlock() // release before blocking: a pending Freeze/Unfreeze would otherwise stall every RLock, Advance included 80 | return p.Wait4Scheduled(count, timeout) 81 | } 82 | 83 | // Now see time.Now. 84 | func Now() time.Time { 85 | rwMutex.RLock() 86 | defer rwMutex.RUnlock() 87 | return provider.Now() 88 | } 89 | 90 | // Sleep see time.Sleep. 91 | func Sleep(d time.Duration) { 92 | rwMutex.RLock() 93 | p := provider 94 | rwMutex.RUnlock() // release before sleeping so Freeze/Unfreeze cannot deadlock behind a sleeper 95 | p.Sleep(d) 96 | } 97 | 98 | // After see time.After. 99 | func After(d time.Duration) <-chan time.Time { 100 | rwMutex.RLock() 101 | defer rwMutex.RUnlock() 102 | return provider.After(d) 103 | } 104 | 105 | // NewTimer see time.NewTimer. 106 | func NewTimer(d time.Duration) Timer { 107 | rwMutex.RLock() 108 | defer rwMutex.RUnlock() 109 | return provider.NewTimer(d) 110 | } 111 | 112 | // AfterFunc see time.AfterFunc. 113 | func AfterFunc(d time.Duration, f func()) Timer { 114 | rwMutex.RLock() 115 | defer rwMutex.RUnlock() 116 | return provider.AfterFunc(d, f) 117 | } 118 | 119 | // NewTicker see time.NewTicker. 120 | func NewTicker(d time.Duration) Ticker { 121 | rwMutex.RLock() 122 | defer rwMutex.RUnlock() 123 | return provider.NewTicker(d) 124 | } 125 | 126 | // Tick see time.Tick.
127 | func Tick(d time.Duration) <-chan time.Time { 128 | rwMutex.RLock() 129 | defer rwMutex.RUnlock() 130 | return provider.Tick(d) 131 | } 132 | -------------------------------------------------------------------------------- /internal/holsterv4/clock/duration.go: -------------------------------------------------------------------------------- 1 | package clock 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | ) 7 | 8 | // DurationJSON wraps a Duration so that it is encoded in JSON as a string such as "42s". 9 | type DurationJSON struct { 10 | Duration Duration 11 | } 12 | 13 | // NewDurationJSON builds a DurationJSON from a Duration, a nanosecond count (float64, 14 | // int64 or int), or a string/[]byte accepted by ParseDuration. 15 | func NewDurationJSON(v interface{}) (DurationJSON, error) { 16 | switch v := v.(type) { 17 | case Duration: 18 | return DurationJSON{Duration: v}, nil 19 | case float64: 20 | return DurationJSON{Duration: Duration(v)}, nil 21 | case int64: 22 | return DurationJSON{Duration: Duration(v)}, nil 23 | case int: 24 | return DurationJSON{Duration: Duration(v)}, nil 25 | case []byte: 26 | duration, err := ParseDuration(string(v)) 27 | if err != nil { 28 | return DurationJSON{}, fmt.Errorf("while parsing []byte: %w", err) 29 | } 30 | return DurationJSON{Duration: duration}, nil 31 | case string: 32 | duration, err := ParseDuration(v) 33 | if err != nil { 34 | return DurationJSON{}, fmt.Errorf("while parsing string: %w", err) 35 | } 36 | return DurationJSON{Duration: duration}, nil 37 | default: 38 | return DurationJSON{}, fmt.Errorf("bad type %T", v) 39 | } 40 | } 41 | 42 | // NewDurationJSONOrPanic is like NewDurationJSON but panics on error. 43 | func NewDurationJSONOrPanic(v interface{}) DurationJSON { 44 | d, err := NewDurationJSON(v) 45 | if err != nil { 46 | panic(err) 47 | } 48 | return d 49 | } 50 | 51 | // MarshalJSON encodes the duration as its String form, e.g. "42s". 52 | func (d DurationJSON) MarshalJSON() ([]byte, error) { 53 | return json.Marshal(d.Duration.String()) 54 | } 55 | 56 | // UnmarshalJSON accepts any JSON value NewDurationJSON understands (numbers decode as 57 | // float64 nanoseconds). A JSON null is a no-op, per the encoding/json Unmarshaler convention. 58 | func (d *DurationJSON) UnmarshalJSON(b []byte) error { 59 | if string(b) == "null" { 60 | return nil 61 | } 62 | var v interface{} 63 | var err error 64 | 65 | if err = json.Unmarshal(b, &v); err != nil { 66 | return err 67 | } 68 | 69 | *d, err = NewDurationJSON(v) 70 | return err 71 | } 72 | 73 | // String returns the String form of the wrapped Duration. 74 | func (d DurationJSON) String() string { 75 | return d.Duration.String() 76 | } 77 | --------------------------------------------------------------------------------
-------------------------------------------------------------------------------- /internal/holsterv4/clock/duration_test.go: -------------------------------------------------------------------------------- 1 | package clock_test 2 | 3 | import ( 4 | "encoding/json" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/suite" 8 | "github.com/vulcand/oxy/v2/internal/holsterv4/clock" 9 | ) 10 | 11 | type DurationSuite struct { 12 | suite.Suite 13 | } 14 | 15 | func TestDurationSuite(t *testing.T) { 16 | suite.Run(t, new(DurationSuite)) 17 | } 18 | 19 | func (s *DurationSuite) TestNewOk() { 20 | for _, v := range []interface{}{ 21 | 42 * clock.Second, 22 | int(42000000000), 23 | int64(42000000000), 24 | 42000000000., 25 | "42s", 26 | []byte("42s"), 27 | } { 28 | d, err := clock.NewDurationJSON(v) 29 | s.Nil(err) 30 | s.Equal(42*clock.Second, d.Duration) 31 | } 32 | } 33 | 34 | func (s *DurationSuite) TestNewError() { 35 | for _, tc := range []struct { 36 | v interface{} 37 | errMsg string 38 | }{{ 39 | v: "foo", 40 | errMsg: "while parsing string: time: invalid duration \"foo\"", 41 | }, { 42 | v: []byte("foo"), 43 | errMsg: "while parsing []byte: time: invalid duration \"foo\"", 44 | }, { 45 | v: true, 46 | errMsg: "bad type bool", 47 | }} { 48 | d, err := clock.NewDurationJSON(tc.v) 49 | s.Equal(tc.errMsg, err.Error()) 50 | s.Equal(clock.DurationJSON{}, d) 51 | } 52 | } 53 | 54 | func (s *DurationSuite) TestUnmarshal() { 55 | for _, v := range []string{ 56 | `{"foo": 42000000000}`, 57 | `{"foo": 0.42e11}`, 58 | `{"foo": "42s"}`, 59 | } { 60 | var withDuration struct { 61 | Foo clock.DurationJSON `json:"foo"` 62 | } 63 | err := json.Unmarshal([]byte(v), &withDuration) 64 | s.Nil(err) 65 | s.Equal(42*clock.Second, withDuration.Foo.Duration) 66 | } 67 | } 68 | 69 | func (s *DurationSuite) TestMarshalling() { 70 | d, err := clock.NewDurationJSON(42 * clock.Second) 71 | s.Nil(err) 72 | encoded, err := d.MarshalJSON() 73 | s.Nil(err) 74 | var decoded clock.DurationJSON 75 |
err = decoded.UnmarshalJSON(encoded) 76 | s.Nil(err) 77 | s.Equal(d, decoded) 78 | s.Equal("42s", decoded.String()) 79 | } 80 | -------------------------------------------------------------------------------- /internal/holsterv4/clock/frozen.go: -------------------------------------------------------------------------------- 1 | package clock 2 | 3 | import ( 4 | "errors" 5 | "sync" 6 | "time" 7 | ) 8 | 9 | type frozenTime struct { 10 | mu sync.Mutex 11 | now time.Time 12 | timers []*frozenTimer 13 | waiter *waiter // pending Wait4Scheduled call, if any; guarded by mu 14 | } 15 | 16 | type waiter struct { 17 | count int 18 | signalCh chan struct{} 19 | } 20 | 21 | func (ft *frozenTime) Now() time.Time { 22 | ft.mu.Lock() 23 | defer ft.mu.Unlock() 24 | return ft.now 25 | } 26 | 27 | func (ft *frozenTime) Sleep(d time.Duration) { 28 | <-ft.NewTimer(d).C() 29 | } 30 | 31 | func (ft *frozenTime) After(d time.Duration) <-chan time.Time { 32 | return ft.NewTimer(d).C() 33 | } 34 | 35 | func (ft *frozenTime) NewTimer(d time.Duration) Timer { 36 | return ft.AfterFunc(d, nil) 37 | } 38 | 39 | func (ft *frozenTime) AfterFunc(d time.Duration, f func()) Timer { 40 | t := &frozenTimer{ 41 | ft: ft, 42 | when: ft.Now().Add(d), 43 | f: f, 44 | } 45 | if f == nil { 46 | t.c = make(chan time.Time, 1) 47 | } 48 | ft.startTimer(t) 49 | return t 50 | } 51 | 52 | func (ft *frozenTime) advance(d time.Duration) { 53 | ft.mu.Lock() 54 | defer ft.mu.Unlock() 55 | 56 | ft.now = ft.now.Add(d) 57 | for t := ft.nextExpired(); t != nil; t = ft.nextExpired() { 58 | // Send the timer expiration time to the timer channel if it is 59 | // defined. But make sure not to block on the send if the channel is 60 | // full. This behavior will make a ticker skip beats if its readers are 61 | // not fast enough. 62 | if t.c != nil { 63 | select { 64 | case t.c <- t.when: 65 | default: 66 | } 67 | } 68 | // If it is a ticking timer then schedule next tick, otherwise mark it 69 | // as stopped.
70 | if t.interval != 0 { 71 | t.when = t.when.Add(t.interval) 72 | t.stopped = false 73 | ft.unlockedStartTimer(t) 74 | } 75 | // If a function is associated with the timer then call it, but release 76 | // the lock for the duration of the call: the lock is not re-entrant and 77 | // the function may need to start another timer or ticker (which takes 78 | // the lock). 79 | if t.f != nil { 80 | func() { 81 | ft.mu.Unlock() 82 | defer ft.mu.Lock() 83 | t.f() 84 | }() 85 | } 86 | } 87 | } 88 | 89 | func (ft *frozenTime) stopTimer(t *frozenTimer) bool { 90 | ft.mu.Lock() 91 | defer ft.mu.Unlock() 92 | 93 | if t.stopped { 94 | return false 95 | } 96 | for i, curr := range ft.timers { 97 | if curr == t { 98 | t.stopped = true 99 | copy(ft.timers[i:], ft.timers[i+1:]) 100 | lastIdx := len(ft.timers) - 1 101 | ft.timers[lastIdx] = nil 102 | ft.timers = ft.timers[:lastIdx] 103 | return true 104 | } 105 | } 106 | return false 107 | } 108 | 109 | func (ft *frozenTime) nextExpired() *frozenTimer { 110 | if len(ft.timers) == 0 { 111 | return nil 112 | } 113 | t := ft.timers[0] 114 | if ft.now.Before(t.when) { 115 | return nil 116 | } 117 | copy(ft.timers, ft.timers[1:]) 118 | lastIdx := len(ft.timers) - 1 119 | ft.timers[lastIdx] = nil 120 | ft.timers = ft.timers[:lastIdx] 121 | t.stopped = true 122 | return t 123 | } 124 | 125 | func (ft *frozenTime) startTimer(t *frozenTimer) { 126 | ft.mu.Lock() 127 | defer ft.mu.Unlock() 128 | ft.unlockedStartTimer(t) 129 | if ft.waiter == nil { 130 | return 131 | } 132 | if len(ft.timers) >= ft.waiter.count { 133 | // Signal the waiter once and detach it, so a later timer can never close the same channel twice. 134 | close(ft.waiter.signalCh) 135 | ft.waiter = nil 136 | } 137 | } 138 | 139 | func (ft *frozenTime) unlockedStartTimer(t *frozenTimer) { 140 | pos := 0 141 | for _, curr := range ft.timers { 142 | if t.when.Before(curr.when) { 143 | break 144 | } 145 | pos++ 146 | } 147 | ft.timers = append(ft.timers, nil) 148 | copy(ft.timers[pos+1:], ft.timers[pos:]) 149 | ft.timers[pos] = t 150 | } 151 | 152 | type frozenTimer struct { 153 | ft
*frozenTime 154 | when time.Time 155 | interval time.Duration 156 | stopped bool 157 | c chan time.Time 158 | f func() 159 | } 160 | 161 | func (t *frozenTimer) C() <-chan time.Time { 162 | return t.c 163 | } 164 | 165 | func (t *frozenTimer) Stop() bool { 166 | return t.ft.stopTimer(t) 167 | } 168 | 169 | func (t *frozenTimer) Reset(d time.Duration) bool { 170 | active := t.ft.stopTimer(t) 171 | t.when = t.ft.Now().Add(d) 172 | t.ft.startTimer(t) 173 | return active 174 | } 175 | 176 | type frozenTicker struct { 177 | t *frozenTimer 178 | } 179 | 180 | func (t *frozenTicker) C() <-chan time.Time { 181 | return t.t.C() 182 | } 183 | 184 | func (t *frozenTicker) Stop() { 185 | t.t.Stop() 186 | } 187 | 188 | func (ft *frozenTime) NewTicker(d time.Duration) Ticker { 189 | if d <= 0 { 190 | panic(errors.New("non-positive interval for NewTicker")) 191 | } 192 | t := &frozenTimer{ 193 | ft: ft, 194 | when: ft.Now().Add(d), 195 | interval: d, 196 | c: make(chan time.Time, 1), 197 | } 198 | ft.startTimer(t) 199 | return &frozenTicker{t} 200 | } 201 | 202 | func (ft *frozenTime) Tick(d time.Duration) <-chan time.Time { 203 | if d <= 0 { 204 | return nil 205 | } 206 | return ft.NewTicker(d).C() 207 | } 208 | 209 | func (ft *frozenTime) Wait4Scheduled(count int, timeout time.Duration) bool { 210 | ft.mu.Lock() 211 | if len(ft.timers) >= count { 212 | ft.mu.Unlock() 213 | return true 214 | } 215 | if ft.waiter != nil { 216 | panic("Concurrent call") 217 | } 218 | // Keep a local reference: startTimer closes w.signalCh and then detaches w from ft. 219 | w := &waiter{count, make(chan struct{})} 220 | ft.waiter = w 221 | ft.mu.Unlock() 222 | 223 | select { 224 | case <-w.signalCh: 225 | return true 226 | case <-time.After(timeout): 227 | } 228 | ft.mu.Lock() 229 | defer ft.mu.Unlock() 230 | if ft.waiter != w { 231 | return true // startTimer signaled and detached w while the timeout was firing 232 | } 233 | ft.waiter = nil 234 | return false 235 | } 236 | -------------------------------------------------------------------------------- /internal/holsterv4/clock/go19.go: -------------------------------------------------------------------------------- 1 | //go:build go1.9
2 | 3 | // This file introduces aliases to allow using the clock package as a 4 | // drop-in replacement of the standard time package. 5 | 6 | package clock 7 | 8 | import "time" 9 | 10 | type ( 11 | Time = time.Time 12 | Duration = time.Duration 13 | Location = time.Location 14 | 15 | Weekday = time.Weekday 16 | Month = time.Month 17 | 18 | ParseError = time.ParseError 19 | ) 20 | 21 | const ( 22 | Nanosecond = time.Nanosecond 23 | Microsecond = time.Microsecond 24 | Millisecond = time.Millisecond 25 | Second = time.Second 26 | Minute = time.Minute 27 | Hour = time.Hour 28 | 29 | Sunday = time.Sunday 30 | Monday = time.Monday 31 | Tuesday = time.Tuesday 32 | Wednesday = time.Wednesday 33 | Thursday = time.Thursday 34 | Friday = time.Friday 35 | Saturday = time.Saturday 36 | 37 | January = time.January 38 | February = time.February 39 | March = time.March 40 | April = time.April 41 | May = time.May 42 | June = time.June 43 | July = time.July 44 | August = time.August 45 | September = time.September 46 | October = time.October 47 | November = time.November 48 | December = time.December 49 | 50 | ANSIC = time.ANSIC 51 | UnixDate = time.UnixDate 52 | RubyDate = time.RubyDate 53 | RFC822 = time.RFC822 54 | RFC822Z = time.RFC822Z 55 | RFC850 = time.RFC850 56 | RFC1123 = time.RFC1123 57 | RFC1123Z = time.RFC1123Z 58 | RFC3339 = time.RFC3339 59 | RFC3339Nano = time.RFC3339Nano 60 | Kitchen = time.Kitchen 61 | Stamp = time.Stamp 62 | StampMilli = time.StampMilli 63 | StampMicro = time.StampMicro 64 | StampNano = time.StampNano 65 | ) 66 | 67 | var ( 68 | UTC = time.UTC 69 | Local = time.Local 70 | ) 71 | 72 | func Date(year int, month Month, day, hour, min, sec, nsec int, loc *Location) Time { 73 | return time.Date(year, month, day, hour, min, sec, nsec, loc) 74 | } 75 | 76 | func FixedZone(name string, offset int) *Location { 77 | return time.FixedZone(name, offset) 78 | } 79 | 80 | func LoadLocation(name string) (*Location, error) { 81 | return time.LoadLocation(name)
82 | } 83 | 84 | func Parse(layout, value string) (Time, error) { 85 | return time.Parse(layout, value) 86 | } 87 | 88 | func ParseDuration(s string) (Duration, error) { 89 | return time.ParseDuration(s) 90 | } 91 | 92 | func ParseInLocation(layout, value string, loc *Location) (Time, error) { 93 | return time.ParseInLocation(layout, value, loc) 94 | } 95 | 96 | func Unix(sec int64, nsec int64) Time { 97 | return time.Unix(sec, nsec) 98 | } 99 | 100 | func Since(t Time) Duration { 101 | return Now().Sub(t) // via Now(): the holster_test_mode build only reads provider under its lock 102 | } 103 | 104 | func Until(t Time) Duration { 105 | return t.Sub(Now()) // via Now(), for the same locking reason as Since 106 | } 107 | -------------------------------------------------------------------------------- /internal/holsterv4/clock/interface.go: -------------------------------------------------------------------------------- 1 | package clock 2 | 3 | import "time" 4 | 5 | // Timer see time.Timer. 6 | type Timer interface { 7 | C() <-chan time.Time 8 | Stop() bool 9 | Reset(d time.Duration) bool 10 | } 11 | 12 | // Ticker see time.Ticker. 13 | type Ticker interface { 14 | C() <-chan time.Time 15 | Stop() 16 | } 17 | 18 | // NewStoppedTimer returns a stopped timer. Call Reset to get it ticking. 19 | func NewStoppedTimer() Timer { 20 | t := NewTimer(42 * time.Hour) 21 | t.Stop() 22 | return t 23 | } 24 | 25 | // Clock is an interface that mimics the one of the SDK time package.
26 | type Clock interface { 27 | Now() time.Time 28 | Sleep(d time.Duration) 29 | After(d time.Duration) <-chan time.Time 30 | NewTimer(d time.Duration) Timer 31 | AfterFunc(d time.Duration, f func()) Timer 32 | NewTicker(d time.Duration) Ticker 33 | Tick(d time.Duration) <-chan time.Time 34 | Wait4Scheduled(n int, timeout time.Duration) bool // the system (real-time) implementation panics 35 | } 36 | -------------------------------------------------------------------------------- /internal/holsterv4/clock/rfc822.go: -------------------------------------------------------------------------------- 1 | package clock 2 | 3 | import ( 4 | "strconv" 5 | "time" 6 | ) 7 | 8 | var datetimeLayouts = [48]string{ // tried in order by ParseRFC822Time 9 | // Day first month 2nd abbreviated. 10 | "Mon, 2 Jan 2006 15:04:05 MST", 11 | "Mon, 2 Jan 2006 15:04:05 -0700", 12 | "Mon, 2 Jan 2006 15:04:05 -0700 (MST)", 13 | "2 Jan 2006 15:04:05 MST", 14 | "2 Jan 2006 15:04:05 -0700", 15 | "2 Jan 2006 15:04:05 -0700 (MST)", 16 | "Mon, 2 Jan 2006 15:04 MST", 17 | "Mon, 2 Jan 2006 15:04 -0700", 18 | "Mon, 2 Jan 2006 15:04 -0700 (MST)", 19 | "2 Jan 2006 15:04 MST", 20 | "2 Jan 2006 15:04 -0700", 21 | "2 Jan 2006 15:04 -0700 (MST)", 22 | 23 | // Month first day 2nd abbreviated. 24 | "Mon, Jan 2 2006 15:04:05 MST", 25 | "Mon, Jan 2 2006 15:04:05 -0700", 26 | "Mon, Jan 2 2006 15:04:05 -0700 (MST)", 27 | "Jan 2 2006 15:04:05 MST", 28 | "Jan 2 2006 15:04:05 -0700", 29 | "Jan 2 2006 15:04:05 -0700 (MST)", 30 | "Mon, Jan 2 2006 15:04 MST", 31 | "Mon, Jan 2 2006 15:04 -0700", 32 | "Mon, Jan 2 2006 15:04 -0700 (MST)", 33 | "Jan 2 2006 15:04 MST", 34 | "Jan 2 2006 15:04 -0700", 35 | "Jan 2 2006 15:04 -0700 (MST)", 36 | 37 | // Day first month 2nd not abbreviated.
38 | "Mon, 2 January 2006 15:04:05 MST", 39 | "Mon, 2 January 2006 15:04:05 -0700", 40 | "Mon, 2 January 2006 15:04:05 -0700 (MST)", 41 | "2 January 2006 15:04:05 MST", 42 | "2 January 2006 15:04:05 -0700", 43 | "2 January 2006 15:04:05 -0700 (MST)", 44 | "Mon, 2 January 2006 15:04 MST", 45 | "Mon, 2 January 2006 15:04 -0700", 46 | "Mon, 2 January 2006 15:04 -0700 (MST)", 47 | "2 January 2006 15:04 MST", 48 | "2 January 2006 15:04 -0700", 49 | "2 January 2006 15:04 -0700 (MST)", 50 | 51 | // Month first day 2nd not abbreviated. 52 | "Mon, January 2 2006 15:04:05 MST", 53 | "Mon, January 2 2006 15:04:05 -0700", 54 | "Mon, January 2 2006 15:04:05 -0700 (MST)", 55 | "January 2 2006 15:04:05 MST", 56 | "January 2 2006 15:04:05 -0700", 57 | "January 2 2006 15:04:05 -0700 (MST)", 58 | "Mon, January 2 2006 15:04 MST", 59 | "Mon, January 2 2006 15:04 -0700", 60 | "Mon, January 2 2006 15:04 -0700 (MST)", 61 | "January 2 2006 15:04 MST", 62 | "January 2 2006 15:04 -0700", 63 | "January 2 2006 15:04 -0700 (MST)", 64 | } 65 | 66 | // RFC822Time allows seamless JSON encoding/decoding of RFC 822 formatted timestamps. 67 | // See https://www.ietf.org/rfc/rfc822.txt section 5. 68 | type RFC822Time struct { 69 | Time 70 | } 71 | 72 | // NewRFC822Time creates RFC822Time from a standard Time. The created value is 73 | // truncated down to second precision because RFC822 does not allow for better. 74 | func NewRFC822Time(t Time) RFC822Time { 75 | return RFC822Time{Time: t.Truncate(Second)} 76 | } 77 | 78 | // ParseRFC822Time parses an RFC822 time string by trying each of datetimeLayouts in order; if none match, the last layout's error is returned. 79 | func ParseRFC822Time(s string) (Time, error) { 80 | var t time.Time 81 | var err error 82 | for _, layout := range datetimeLayouts { 83 | t, err = Parse(layout, s) 84 | if err == nil { 85 | return t, err 86 | } 87 | } 88 | return t, err 89 | } 90 | 91 | // NewRFC822TimeFromUnix creates RFC822Time from a Unix timestamp (seconds from Epoch), in UTC.
92 | func NewRFC822TimeFromUnix(timestamp int64) RFC822Time { 93 | return RFC822Time{Time: Unix(timestamp, 0).UTC()} 94 | } 95 | 96 | // MarshalJSON encodes the time as a quoted string in the RFC1123 layout. 97 | func (t RFC822Time) MarshalJSON() ([]byte, error) { 98 | return []byte(strconv.Quote(t.Format(RFC1123))), nil 99 | } 100 | 101 | // UnmarshalJSON decodes a quoted string in any layout ParseRFC822Time accepts. 102 | // A JSON null is a no-op, per the encoding/json Unmarshaler convention. 103 | func (t *RFC822Time) UnmarshalJSON(s []byte) error { 104 | if string(s) == "null" { 105 | return nil 106 | } 107 | q, err := strconv.Unquote(string(s)) 108 | if err != nil { 109 | return err 110 | } 111 | parsed, err := ParseRFC822Time(q) 112 | if err != nil { 113 | return err 114 | } 115 | t.Time = parsed 116 | return nil 117 | } 118 | 119 | // String formats the time using the RFC1123 layout. 120 | func (t RFC822Time) String() string { 121 | return t.Format(RFC1123) 122 | } 123 | 124 | // StringWithOffset formats the time using the RFC1123Z layout (numeric zone offset). 125 | func (t RFC822Time) StringWithOffset() string { 126 | return t.Format(RFC1123Z) 127 | } 128 | -------------------------------------------------------------------------------- /internal/holsterv4/clock/system.go: -------------------------------------------------------------------------------- 1 | package clock 2 | 3 | import "time" 4 | 5 | type systemTime struct{} // Clock implementation that forwards every call to the standard time package 6 | 7 | func (st *systemTime) Now() time.Time { 8 | return time.Now() 9 | } 10 | 11 | func (st *systemTime) Sleep(d time.Duration) { 12 | time.Sleep(d) 13 | } 14 | 15 | func (st *systemTime) After(d time.Duration) <-chan time.Time { 16 | return time.After(d) 17 | } 18 | 19 | type systemTimer struct { 20 | t *time.Timer 21 | } 22 | 23 | func (st *systemTime) NewTimer(d time.Duration) Timer { 24 | t := time.NewTimer(d) 25 | return &systemTimer{t} 26 | } 27 | 28 | func (st *systemTime) AfterFunc(d time.Duration, f func()) Timer { 29 | t := time.AfterFunc(d, f) 30 | return &systemTimer{t} 31 | } 32 | 33 | func (t *systemTimer) C() <-chan time.Time { // nil for timers created via AfterFunc, as with time.AfterFunc 34 | return t.t.C 35 | } 36 | 37 | func (t *systemTimer) Stop() bool { 38 | return t.t.Stop() 39 | } 40 | 41 | func (t *systemTimer) Reset(d time.Duration) bool { 42 | return t.t.Reset(d) 43 | } 44 | 45 | type systemTicker struct { 46 | t *time.Ticker 47 | } 48 | 49 | func (t *systemTicker) C() <-chan time.Time { 50 | return t.t.C 51 | }
52 | 53 | func (t *systemTicker) Stop() { 54 | t.t.Stop() 55 | } 56 | 57 | func (st *systemTime) NewTicker(d time.Duration) Ticker { 58 | t := time.NewTicker(d) 59 | return &systemTicker{t} 60 | } 61 | 62 | func (st *systemTime) Tick(d time.Duration) <-chan time.Time { 63 | return time.Tick(d) 64 | } 65 | 66 | func (st *systemTime) Wait4Scheduled(count int, timeout time.Duration) bool { // only a frozen clock tracks scheduled events 67 | panic("Not supported") 68 | } 69 | -------------------------------------------------------------------------------- /internal/holsterv4/clock/system_test.go: -------------------------------------------------------------------------------- 1 | package clock 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestSleep(t *testing.T) { 11 | start := Now() 12 | 13 | // When 14 | Sleep(100 * time.Millisecond) 15 | 16 | // Then 17 | if Now().Sub(start) < 100*time.Millisecond { 18 | assert.Fail(t, "Sleep did not last long enough") 19 | } 20 | } 21 | 22 | func TestAfter(t *testing.T) { 23 | start := Now() 24 | 25 | // When 26 | end := <-After(100 * time.Millisecond) 27 | 28 | // Then 29 | if end.Sub(start) < 100*time.Millisecond { 30 | assert.Fail(t, "Sleep did not last long enough") 31 | } 32 | } 33 | 34 | func TestAfterFunc(t *testing.T) { 35 | start := Now() 36 | endCh := make(chan time.Time, 1) 37 | 38 | // When 39 | AfterFunc(100*time.Millisecond, func() { endCh <- time.Now() }) 40 | 41 | // Then 42 | end := <-endCh 43 | if end.Sub(start) < 100*time.Millisecond { 44 | assert.Fail(t, "Sleep did not last long enough") 45 | } 46 | } 47 | 48 | func TestNewTimer(t *testing.T) { 49 | start := Now() 50 | 51 | // When 52 | timer := NewTimer(100 * time.Millisecond) 53 | 54 | // Then 55 | end := <-timer.C() 56 | if end.Sub(start) < 100*time.Millisecond { 57 | assert.Fail(t, "Sleep did not last long enough") 58 | } 59 | } 60 | 61 | func TestTimerStop(t *testing.T) { 62 | timer := NewTimer(50 * time.Millisecond) 63 | 64 | // When 65 |
active := timer.Stop() 66 | 67 | // Then 68 | assert.True(t, active) 69 | time.Sleep(100 * time.Millisecond) // outlive the 50ms timer; a bare 100 is only 100ns 70 | select { 71 | case <-timer.C(): 72 | assert.Fail(t, "Timer should not have fired") 73 | default: 74 | } 75 | } 76 | 77 | func TestTimerReset(t *testing.T) { 78 | t.Skip("fail on the CI for darwin") 79 | start := time.Now() 80 | timer := NewTimer(300 * time.Millisecond) 81 | 82 | // When 83 | timer.Reset(100 * time.Millisecond) 84 | 85 | // Then 86 | end := <-timer.C() 87 | if end.Sub(start) >= 150*time.Millisecond { 88 | assert.Fail(t, "Waited too long", end.Sub(start).String()) 89 | } 90 | } 91 | 92 | func TestNewTicker(t *testing.T) { 93 | start := Now() 94 | 95 | // When 96 | timer := NewTicker(100 * time.Millisecond) 97 | 98 | // Then 99 | end := <-timer.C() 100 | if end.Sub(start) < 100*time.Millisecond { 101 | assert.Fail(t, "Sleep did not last long enough") 102 | } 103 | end = <-timer.C() 104 | if end.Sub(start) < 200*time.Millisecond { 105 | assert.Fail(t, "Sleep did not last long enough") 106 | } 107 | 108 | timer.Stop() 109 | time.Sleep(150 * time.Millisecond) // longer than one tick period; a bare 150 is only 150ns 110 | select { 111 | case <-timer.C(): 112 | assert.Fail(t, "Ticker should not have fired") 113 | default: 114 | } 115 | } 116 | 117 | func TestTick(t *testing.T) { 118 | start := Now() 119 | 120 | // When 121 | ch := Tick(100 * time.Millisecond) 122 | 123 | // Then 124 | end := <-ch 125 | if end.Sub(start) < 100*time.Millisecond { 126 | assert.Fail(t, "Sleep did not last long enough") 127 | } 128 | end = <-ch 129 | if end.Sub(start) < 200*time.Millisecond { 130 | assert.Fail(t, "Sleep did not last long enough") 131 | } 132 | } 133 | 134 | func TestNewStoppedTimer(t *testing.T) { 135 | timer := NewStoppedTimer() 136 | 137 | // When/Then 138 | select { 139 | case <-timer.C(): 140 | assert.Fail(t, "Timer should not have fired") 141 | default: 142 | } 143 | assert.False(t, timer.Stop()) 144 | } 145 | --------------------------------------------------------------------------------
/internal/holsterv4/collections/README.md: -------------------------------------------------------------------------------- 1 | ## Priority Queue 2 | Provides a Priority Queue implementation as described [here](https://en.wikipedia.org/wiki/Priority_queue) 3 | 4 | ```go 5 | queue := collections.NewPriorityQueue() 6 | 7 | queue.Push(&collections.PQItem{ 8 | Value: "thing3", 9 | Priority: 3, 10 | }) 11 | 12 | queue.Push(&collections.PQItem{ 13 | Value: "thing1", 14 | Priority: 1, 15 | }) 16 | 17 | queue.Push(&collections.PQItem{ 18 | Value: "thing2", 19 | Priority: 2, 20 | }) 21 | 22 | // Pops item off the queue according to the priority instead of the Push() order 23 | item := queue.Pop() 24 | 25 | fmt.Printf("Item: %s", item.Value.(string)) 26 | 27 | // Output: Item: thing1 28 | ``` 29 | -------------------------------------------------------------------------------- /internal/holsterv4/collections/priority_queue.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Mailgun Technologies Inc 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | package collections 17 | 18 | import ( 19 | "container/heap" 20 | ) 21 | 22 | // An PQItem is something we manage in a priority queue. 23 | type PQItem struct { 24 | Value interface{} 25 | Priority int // The priority of the item in the queue. 26 | // The index is needed by update and is maintained by the heap.Interface methods. 
27 | index int // The index of the item in the heap. 28 | } 29 | 30 | // Implements a PriorityQueue 31 | type PriorityQueue struct { 32 | impl *pqImpl 33 | } 34 | 35 | func NewPriorityQueue() *PriorityQueue { 36 | mh := &pqImpl{} 37 | heap.Init(mh) 38 | return &PriorityQueue{impl: mh} 39 | } 40 | 41 | func (p PriorityQueue) Len() int { return p.impl.Len() } 42 | 43 | func (p *PriorityQueue) Push(el *PQItem) { 44 | heap.Push(p.impl, el) 45 | } 46 | 47 | func (p *PriorityQueue) Pop() *PQItem { 48 | el := heap.Pop(p.impl) 49 | return el.(*PQItem) 50 | } 51 | 52 | func (p *PriorityQueue) Peek() *PQItem { 53 | return (*p.impl)[0] 54 | } 55 | 56 | // Modifies the priority and value of an Item in the queue. 57 | func (p *PriorityQueue) Update(el *PQItem, priority int) { 58 | heap.Remove(p.impl, el.index) 59 | el.Priority = priority 60 | heap.Push(p.impl, el) 61 | } 62 | 63 | func (p *PriorityQueue) Remove(el *PQItem) { 64 | heap.Remove(p.impl, el.index) 65 | } 66 | 67 | // Actual Implementation using heap.Interface 68 | type pqImpl []*PQItem 69 | 70 | func (mh pqImpl) Len() int { return len(mh) } 71 | 72 | func (mh pqImpl) Less(i, j int) bool { 73 | return mh[i].Priority < mh[j].Priority 74 | } 75 | 76 | func (mh pqImpl) Swap(i, j int) { 77 | mh[i], mh[j] = mh[j], mh[i] 78 | mh[i].index = i 79 | mh[j].index = j 80 | } 81 | 82 | func (mh *pqImpl) Push(x interface{}) { 83 | n := len(*mh) 84 | item := x.(*PQItem) 85 | item.index = n 86 | *mh = append(*mh, item) 87 | } 88 | 89 | func (mh *pqImpl) Pop() interface{} { 90 | old := *mh 91 | n := len(old) 92 | item := old[n-1] 93 | item.index = -1 // for safety 94 | *mh = old[0 : n-1] 95 | return item 96 | } 97 | -------------------------------------------------------------------------------- /internal/holsterv4/collections/priority_queue_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Mailgun Technologies Inc 3 | 4 | Licensed under the Apache License, Version 2.0 
(the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | package collections_test 17 | 18 | import ( 19 | "fmt" 20 | "testing" 21 | 22 | "github.com/stretchr/testify/assert" 23 | "github.com/vulcand/oxy/v2/internal/holsterv4/collections" 24 | ) 25 | 26 | func toPtr(i int) interface{} { 27 | return &i 28 | } 29 | 30 | func toInt(i interface{}) int { 31 | return *(i.(*int)) 32 | } 33 | 34 | func TestPeek(t *testing.T) { 35 | mh := collections.NewPriorityQueue() 36 | 37 | el := &collections.PQItem{ 38 | Value: toPtr(1), 39 | Priority: 5, 40 | } 41 | 42 | mh.Push(el) 43 | assert.Equal(t, 1, toInt(mh.Peek().Value)) 44 | assert.Equal(t, 1, mh.Len()) 45 | 46 | el = &collections.PQItem{ 47 | Value: toPtr(2), 48 | Priority: 1, 49 | } 50 | mh.Push(el) 51 | assert.Equal(t, 2, mh.Len()) 52 | assert.Equal(t, 2, toInt(mh.Peek().Value)) 53 | assert.Equal(t, 2, toInt(mh.Peek().Value)) 54 | assert.Equal(t, 2, mh.Len()) 55 | 56 | el = mh.Pop() 57 | 58 | assert.Equal(t, 2, toInt(el.Value)) 59 | assert.Equal(t, 1, mh.Len()) 60 | assert.Equal(t, 1, toInt(mh.Peek().Value)) 61 | 62 | mh.Pop() 63 | assert.Equal(t, 0, mh.Len()) 64 | } 65 | 66 | func TestUpdate(t *testing.T) { 67 | mh := collections.NewPriorityQueue() 68 | x := &collections.PQItem{ 69 | Value: toPtr(1), 70 | Priority: 4, 71 | } 72 | y := &collections.PQItem{ 73 | Value: toPtr(2), 74 | Priority: 3, 75 | } 76 | z := &collections.PQItem{ 77 | Value: toPtr(3), 78 | Priority: 8, 79 | } 80 | mh.Push(x) 81 | mh.Push(y) 82 | mh.Push(z) 83 | assert.Equal(t, 2, 
toInt(mh.Peek().Value)) 84 | 85 | mh.Update(z, 1) 86 | assert.Equal(t, 3, toInt(mh.Peek().Value)) 87 | 88 | mh.Update(x, 0) 89 | assert.Equal(t, 1, toInt(mh.Peek().Value)) 90 | } 91 | 92 | func ExampleNewPriorityQueue() { 93 | queue := collections.NewPriorityQueue() 94 | 95 | queue.Push(&collections.PQItem{ 96 | Value: "thing3", 97 | Priority: 3, 98 | }) 99 | 100 | queue.Push(&collections.PQItem{ 101 | Value: "thing1", 102 | Priority: 1, 103 | }) 104 | 105 | queue.Push(&collections.PQItem{ 106 | Value: "thing2", 107 | Priority: 2, 108 | }) 109 | 110 | // Pops item off the queue according to the priority instead of the Push() order 111 | item := queue.Pop() 112 | 113 | fmt.Printf("Item: %s", item.Value.(string)) 114 | 115 | // Output: Item: thing1 116 | } 117 | -------------------------------------------------------------------------------- /internal/holsterv4/collections/ttlmap.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2017 Mailgun Technologies Inc 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | package collections 17 | 18 | import ( 19 | "fmt" 20 | "sync" 21 | "time" 22 | 23 | "github.com/vulcand/oxy/v2/internal/holsterv4/clock" 24 | ) 25 | 26 | type TTLMap struct { 27 | // Optionally specifies a callback function to be 28 | // executed when an entry has expired 29 | OnExpire func(key string, i interface{}) 30 | 31 | capacity int 32 | elements map[string]*mapElement 33 | expiryTimes *PriorityQueue 34 | mutex *sync.RWMutex 35 | } 36 | 37 | type mapElement struct { 38 | key string 39 | value interface{} 40 | heapEl *PQItem 41 | } 42 | 43 | func NewTTLMap(capacity int) *TTLMap { 44 | if capacity <= 0 { 45 | capacity = 0 46 | } 47 | 48 | return &TTLMap{ 49 | capacity: capacity, 50 | elements: make(map[string]*mapElement), 51 | expiryTimes: NewPriorityQueue(), 52 | mutex: &sync.RWMutex{}, 53 | } 54 | } 55 | 56 | func (m *TTLMap) Set(key string, value interface{}, ttlSeconds int) error { 57 | expiryTime, err := m.toEpochSeconds(ttlSeconds) 58 | if err != nil { 59 | return err 60 | } 61 | m.mutex.Lock() 62 | defer m.mutex.Unlock() 63 | return m.set(key, value, expiryTime) 64 | } 65 | 66 | func (m *TTLMap) Len() int { 67 | m.mutex.RLock() 68 | defer m.mutex.RUnlock() 69 | return len(m.elements) 70 | } 71 | 72 | func (m *TTLMap) Get(key string) (interface{}, bool) { 73 | value, mapEl, expired := m.lockNGet(key) 74 | if mapEl == nil { 75 | return nil, false 76 | } 77 | if expired { 78 | m.lockNDel(mapEl) 79 | return nil, false 80 | } 81 | return value, true 82 | } 83 | 84 | func (m *TTLMap) Increment(key string, value int, ttlSeconds int) (int, error) { 85 | expiryTime, err := m.toEpochSeconds(ttlSeconds) 86 | if err != nil { 87 | return 0, err 88 | } 89 | 90 | m.mutex.Lock() 91 | defer m.mutex.Unlock() 92 | 93 | mapEl, expired := m.get(key) 94 | if mapEl == nil || expired { 95 | m.set(key, value, expiryTime) 96 | return value, nil 97 | } 98 | 99 | currentValue, ok := mapEl.value.(int) 100 | if !ok { 101 | return 0, fmt.Errorf("Expected existing value 
to be integer, got %T", mapEl.value) 102 | } 103 | 104 | currentValue += value 105 | m.set(key, currentValue, expiryTime) 106 | return currentValue, nil 107 | } 108 | 109 | func (m *TTLMap) GetInt(key string) (int, bool, error) { 110 | valueI, exists := m.Get(key) 111 | if !exists { 112 | return 0, false, nil 113 | } 114 | value, ok := valueI.(int) 115 | if !ok { 116 | return 0, false, fmt.Errorf("Expected existing value to be integer, got %T", valueI) 117 | } 118 | return value, true, nil 119 | } 120 | 121 | func (m *TTLMap) set(key string, value interface{}, expiryTime int) error { 122 | if mapEl, ok := m.elements[key]; ok { 123 | mapEl.value = value 124 | m.expiryTimes.Update(mapEl.heapEl, expiryTime) 125 | return nil 126 | } 127 | 128 | if len(m.elements) >= m.capacity { 129 | m.freeSpace(1) 130 | } 131 | heapEl := &PQItem{ 132 | Priority: expiryTime, 133 | } 134 | mapEl := &mapElement{ 135 | key: key, 136 | value: value, 137 | heapEl: heapEl, 138 | } 139 | heapEl.Value = mapEl 140 | m.elements[key] = mapEl 141 | m.expiryTimes.Push(heapEl) 142 | return nil 143 | } 144 | 145 | func (m *TTLMap) lockNGet(key string) (value interface{}, mapEl *mapElement, expired bool) { 146 | m.mutex.RLock() 147 | defer m.mutex.RUnlock() 148 | 149 | mapEl, expired = m.get(key) 150 | value = nil 151 | if mapEl != nil { 152 | value = mapEl.value 153 | } 154 | return value, mapEl, expired 155 | } 156 | 157 | func (m *TTLMap) get(key string) (*mapElement, bool) { 158 | mapEl, ok := m.elements[key] 159 | if !ok { 160 | return nil, false 161 | } 162 | now := int(clock.Now().Unix()) 163 | expired := mapEl.heapEl.Priority <= now 164 | return mapEl, expired 165 | } 166 | 167 | func (m *TTLMap) lockNDel(mapEl *mapElement) { 168 | m.mutex.Lock() 169 | defer m.mutex.Unlock() 170 | 171 | // Map element could have been updated. Now that we have a lock 172 | // retrieve it again and check if it is still expired. 
173 | var ok bool 174 | if mapEl, ok = m.elements[mapEl.key]; !ok { 175 | return 176 | } 177 | now := int(clock.Now().Unix()) 178 | if mapEl.heapEl.Priority > now { 179 | return 180 | } 181 | 182 | if m.OnExpire != nil { 183 | m.OnExpire(mapEl.key, mapEl.value) 184 | } 185 | 186 | delete(m.elements, mapEl.key) 187 | m.expiryTimes.Remove(mapEl.heapEl) 188 | } 189 | 190 | func (m *TTLMap) freeSpace(count int) { 191 | removed := m.RemoveExpired(count) 192 | if removed >= count { 193 | return 194 | } 195 | m.RemoveLastUsed(count - removed) 196 | } 197 | 198 | func (m *TTLMap) RemoveExpired(iterations int) int { 199 | removed := 0 200 | now := int(clock.Now().Unix()) 201 | for i := 0; i < iterations; i += 1 { 202 | if len(m.elements) == 0 { 203 | break 204 | } 205 | heapEl := m.expiryTimes.Peek() 206 | if heapEl.Priority > now { 207 | break 208 | } 209 | m.expiryTimes.Pop() 210 | mapEl := heapEl.Value.(*mapElement) 211 | delete(m.elements, mapEl.key) 212 | removed += 1 213 | } 214 | return removed 215 | } 216 | 217 | func (m *TTLMap) RemoveLastUsed(iterations int) { 218 | for i := 0; i < iterations; i += 1 { 219 | if len(m.elements) == 0 { 220 | return 221 | } 222 | heapEl := m.expiryTimes.Pop() 223 | mapEl := heapEl.Value.(*mapElement) 224 | delete(m.elements, mapEl.key) 225 | } 226 | } 227 | 228 | func (m *TTLMap) toEpochSeconds(ttlSeconds int) (int, error) { 229 | if ttlSeconds <= 0 { 230 | return 0, fmt.Errorf("ttlSeconds should be >= 0, got %d", ttlSeconds) 231 | } 232 | return int(clock.Now().Add(time.Second * time.Duration(ttlSeconds)).Unix()), nil 233 | } 234 | -------------------------------------------------------------------------------- /memmetrics/anomaly.go: -------------------------------------------------------------------------------- 1 | package memmetrics 2 | 3 | import ( 4 | "math" 5 | "sort" 6 | "time" 7 | ) 8 | 9 | // SplitLatencies provides simple anomaly detection for requests latencies. 
10 | // it splits values into good or bad category based on the threshold and the median value. 11 | // If all values are not far from the median, it will return all values in 'good' set. 12 | // Precision is the smallest value to consider, e.g. if set to millisecond, microseconds will be ignored. 13 | func SplitLatencies(values []time.Duration, precision time.Duration) (good map[time.Duration]bool, bad map[time.Duration]bool) { 14 | // Find the max latency M and then map each latency L to the ratio L/M and then call SplitFloat64 15 | v2r := map[float64]time.Duration{} 16 | ratios := make([]float64, len(values)) 17 | m := maxTime(values) 18 | for i, v := range values { 19 | ratio := float64(v/precision+1) / float64(m/precision+1) // +1 is to avoid division by 0 20 | v2r[ratio] = v 21 | ratios[i] = ratio 22 | } 23 | good, bad = make(map[time.Duration]bool), make(map[time.Duration]bool) 24 | // Note that multiplier makes this function way less sensitive than ratios detector, this is to avoid noise. 25 | vgood, vbad := SplitFloat64(2, 0, ratios) 26 | for r := range vgood { 27 | good[v2r[r]] = true 28 | } 29 | for r := range vbad { 30 | bad[v2r[r]] = true 31 | } 32 | return good, bad 33 | } 34 | 35 | // SplitRatios provides simple anomaly detection for ratio values, that are all in the range [0, 1] 36 | // it splits values into good or bad category based on the threshold and the median value. 37 | // If all values are not far from the median, it will return all values in 'good' set. 38 | func SplitRatios(values []float64) (good map[float64]bool, bad map[float64]bool) { 39 | return SplitFloat64(1.5, 0, values) 40 | } 41 | 42 | // SplitFloat64 provides simple anomaly detection for skewed data sets with no particular distribution. 43 | // In essence it applies the formula if(v > median(values) + threshold * medianAbsoluteDeviation) -> anomaly 44 | // There's a corner case where there are just 2 values, so by definition there's no value that exceeds the threshold. 
45 | // This case is solved by introducing additional value that we know is good, e.g. 0. That helps to improve the detection results 46 | // on such data sets. 47 | func SplitFloat64(threshold, sentinel float64, values []float64) (good map[float64]bool, bad map[float64]bool) { 48 | good, bad = make(map[float64]bool), make(map[float64]bool) 49 | var newValues []float64 50 | if len(values)%2 == 0 { 51 | newValues = make([]float64, len(values)+1) 52 | copy(newValues, values) 53 | // Add a sentinel endpoint so we can distinguish outliers better 54 | newValues[len(newValues)-1] = sentinel 55 | } else { 56 | newValues = values 57 | } 58 | 59 | m := median(newValues) 60 | mAbs := medianAbsoluteDeviation(newValues) 61 | for _, v := range values { 62 | if v > (m+mAbs)*threshold { 63 | bad[v] = true 64 | } else { 65 | good[v] = true 66 | } 67 | } 68 | return good, bad 69 | } 70 | 71 | func median(values []float64) float64 { 72 | vals := make([]float64, len(values)) 73 | copy(vals, values) 74 | sort.Float64s(vals) 75 | l := len(vals) 76 | if l%2 != 0 { 77 | return vals[l/2] 78 | } 79 | return (vals[l/2-1] + vals[l/2]) / 2.0 80 | } 81 | 82 | func medianAbsoluteDeviation(values []float64) float64 { 83 | m := median(values) 84 | distances := make([]float64, len(values)) 85 | for i, v := range values { 86 | distances[i] = math.Abs(v - m) 87 | } 88 | return median(distances) 89 | } 90 | 91 | func maxTime(vals []time.Duration) time.Duration { 92 | val := vals[0] 93 | for _, v := range vals { 94 | if v > val { 95 | val = v 96 | } 97 | } 98 | return val 99 | } 100 | -------------------------------------------------------------------------------- /memmetrics/anomaly_test.go: -------------------------------------------------------------------------------- 1 | package memmetrics 2 | 3 | import ( 4 | "strconv" 5 | "testing" 6 | "time" 7 | 8 | "github.com/stretchr/testify/assert" 9 | "github.com/vulcand/oxy/v2/internal/holsterv4/clock" 10 | ) 11 | 12 | func TestMedian(t *testing.T) { 13 
| testCases := []struct { 14 | desc string 15 | values []float64 16 | expected float64 17 | }{ 18 | { 19 | desc: "2 values", 20 | values: []float64{0.1, 0.2}, 21 | expected: (float64(0.1) + float64(0.2)) / 2.0, 22 | }, 23 | { 24 | desc: "3 values", 25 | values: []float64{0.3, 0.2, 0.5}, 26 | expected: 0.3, 27 | }, 28 | } 29 | 30 | for _, test := range testCases { 31 | t.Run(test.desc, func(t *testing.T) { 32 | t.Parallel() 33 | 34 | actual := median(test.values) 35 | assert.Equal(t, test.expected, actual) 36 | }) 37 | } 38 | } 39 | 40 | func TestSplitRatios(t *testing.T) { 41 | testCases := []struct { 42 | values []float64 43 | good []float64 44 | bad []float64 45 | }{ 46 | { 47 | values: []float64{0, 0}, 48 | good: []float64{0}, 49 | bad: []float64{}, 50 | }, 51 | 52 | { 53 | values: []float64{0, 1}, 54 | good: []float64{0}, 55 | bad: []float64{1}, 56 | }, 57 | { 58 | values: []float64{0.1, 0.1}, 59 | good: []float64{0.1}, 60 | bad: []float64{}, 61 | }, 62 | 63 | { 64 | values: []float64{0.15, 0.1}, 65 | good: []float64{0.15, 0.1}, 66 | bad: []float64{}, 67 | }, 68 | { 69 | values: []float64{0.01, 0.01}, 70 | good: []float64{0.01}, 71 | bad: []float64{}, 72 | }, 73 | { 74 | values: []float64{0.012, 0.01, 1}, 75 | good: []float64{0.012, 0.01}, 76 | bad: []float64{1}, 77 | }, 78 | { 79 | values: []float64{0, 0, 1, 1}, 80 | good: []float64{0}, 81 | bad: []float64{1}, 82 | }, 83 | { 84 | values: []float64{0, 0.1, 0.1, 0}, 85 | good: []float64{0}, 86 | bad: []float64{0.1}, 87 | }, 88 | { 89 | values: []float64{0, 0.01, 0.1, 0}, 90 | good: []float64{0}, 91 | bad: []float64{0.01, 0.1}, 92 | }, 93 | { 94 | values: []float64{0, 0.01, 0.02, 1}, 95 | good: []float64{0, 0.01, 0.02}, 96 | bad: []float64{1}, 97 | }, 98 | { 99 | values: []float64{0, 0, 0, 0, 0, 0.01, 0.02, 1}, 100 | good: []float64{0}, 101 | bad: []float64{0.01, 0.02, 1}, 102 | }, 103 | } 104 | 105 | for ind, test := range testCases { 106 | t.Run(strconv.Itoa(ind), func(t *testing.T) { 107 | t.Parallel() 108 | 
109 | good, bad := SplitRatios(test.values) 110 | 111 | vgood := make(map[float64]bool, len(test.good)) 112 | for _, v := range test.good { 113 | vgood[v] = true 114 | } 115 | 116 | vbad := make(map[float64]bool, len(test.bad)) 117 | for _, v := range test.bad { 118 | vbad[v] = true 119 | } 120 | 121 | assert.Equal(t, vgood, good) 122 | assert.Equal(t, vbad, bad) 123 | }) 124 | } 125 | } 126 | 127 | func TestSplitLatencies(t *testing.T) { 128 | testCases := []struct { 129 | values []int 130 | good []int 131 | bad []int 132 | }{ 133 | { 134 | values: []int{0, 0}, 135 | good: []int{0}, 136 | bad: []int{}, 137 | }, 138 | { 139 | values: []int{1, 2}, 140 | good: []int{1, 2}, 141 | bad: []int{}, 142 | }, 143 | { 144 | values: []int{1, 2, 4}, 145 | good: []int{1, 2, 4}, 146 | bad: []int{}, 147 | }, 148 | { 149 | values: []int{8, 8, 18}, 150 | good: []int{8}, 151 | bad: []int{18}, 152 | }, 153 | { 154 | values: []int{32, 28, 11, 26, 19, 51, 25, 39, 28, 26, 8, 97}, 155 | good: []int{32, 28, 11, 26, 19, 51, 25, 39, 28, 26, 8}, 156 | bad: []int{97}, 157 | }, 158 | { 159 | values: []int{1, 2, 4, 40}, 160 | good: []int{1, 2, 4}, 161 | bad: []int{40}, 162 | }, 163 | { 164 | values: []int{40, 60, 1000}, 165 | good: []int{40, 60}, 166 | bad: []int{1000}, 167 | }, 168 | } 169 | 170 | for ind, test := range testCases { 171 | t.Run(strconv.Itoa(ind), func(t *testing.T) { 172 | t.Parallel() 173 | 174 | values := make([]time.Duration, len(test.values)) 175 | for i, d := range test.values { 176 | values[i] = clock.Millisecond * time.Duration(d) 177 | } 178 | 179 | good, bad := SplitLatencies(values, clock.Millisecond) 180 | 181 | vgood := make(map[time.Duration]bool, len(test.good)) 182 | for _, v := range test.good { 183 | vgood[time.Duration(v)*clock.Millisecond] = true 184 | } 185 | assert.Equal(t, vgood, good) 186 | 187 | vbad := make(map[time.Duration]bool, len(test.bad)) 188 | for _, v := range test.bad { 189 | vbad[time.Duration(v)*clock.Millisecond] = true 190 | } 191 | 
assert.Equal(t, vbad, bad) 192 | }) 193 | } 194 | } 195 | -------------------------------------------------------------------------------- /memmetrics/counter.go: -------------------------------------------------------------------------------- 1 | package memmetrics 2 | 3 | import ( 4 | "errors" 5 | "time" 6 | 7 | "github.com/vulcand/oxy/v2/internal/holsterv4/clock" 8 | ) 9 | 10 | type rcOption func(*RollingCounter) error 11 | 12 | // RollingCounter Calculates in memory failure rate of an endpoint using rolling window of a predefined size. 13 | type RollingCounter struct { 14 | resolution time.Duration 15 | values []int 16 | countedBuckets int // how many samples in different buckets have we collected so far 17 | lastBucket int // last recorded bucket 18 | lastUpdated clock.Time 19 | } 20 | 21 | // NewCounter creates a counter with fixed amount of buckets that are rotated every resolution period. 22 | // E.g. 10 buckets with 1 second means that every new second the bucket is refreshed, so it maintains 10 seconds rolling window. 23 | // By default, creates a bucket with 10 buckets and 1 second resolution. 24 | func NewCounter(buckets int, resolution time.Duration, options ...rcOption) (*RollingCounter, error) { 25 | if buckets <= 0 { 26 | return nil, errors.New("buckets should be >= 0") 27 | } 28 | if resolution < clock.Second { 29 | return nil, errors.New("resolution should be larger than a second") 30 | } 31 | 32 | rc := &RollingCounter{ 33 | lastBucket: -1, 34 | resolution: resolution, 35 | 36 | values: make([]int, buckets), 37 | } 38 | 39 | for _, o := range options { 40 | if err := o(rc); err != nil { 41 | return nil, err 42 | } 43 | } 44 | 45 | return rc, nil 46 | } 47 | 48 | // Append appends a counter. 49 | func (c *RollingCounter) Append(o *RollingCounter) error { 50 | c.Inc(int(o.Count())) 51 | return nil 52 | } 53 | 54 | // Clone clones a counter. 
55 | func (c *RollingCounter) Clone() *RollingCounter { 56 | c.cleanup() 57 | other := &RollingCounter{ 58 | resolution: c.resolution, 59 | values: make([]int, len(c.values)), 60 | lastBucket: c.lastBucket, 61 | lastUpdated: c.lastUpdated, 62 | } 63 | copy(other.values, c.values) 64 | return other 65 | } 66 | 67 | // Reset resets a counter. 68 | func (c *RollingCounter) Reset() { 69 | c.lastBucket = -1 70 | c.countedBuckets = 0 71 | c.lastUpdated = clock.Time{} 72 | for i := range c.values { 73 | c.values[i] = 0 74 | } 75 | } 76 | 77 | // CountedBuckets gets counted buckets. 78 | func (c *RollingCounter) CountedBuckets() int { 79 | return c.countedBuckets 80 | } 81 | 82 | // Count counts. 83 | func (c *RollingCounter) Count() int64 { 84 | c.cleanup() 85 | return c.sum() 86 | } 87 | 88 | // Resolution gets resolution. 89 | func (c *RollingCounter) Resolution() time.Duration { 90 | return c.resolution 91 | } 92 | 93 | // Buckets gets buckets. 94 | func (c *RollingCounter) Buckets() int { 95 | return len(c.values) 96 | } 97 | 98 | // WindowSize gets windows size. 99 | func (c *RollingCounter) WindowSize() time.Duration { 100 | return time.Duration(len(c.values)) * c.resolution 101 | } 102 | 103 | // Inc increments counter. 104 | func (c *RollingCounter) Inc(v int) { 105 | c.cleanup() 106 | c.incBucketValue(v) 107 | } 108 | 109 | func (c *RollingCounter) incBucketValue(v int) { 110 | now := clock.Now().UTC() 111 | bucket := c.getBucket(now) 112 | c.values[bucket] += v 113 | c.lastUpdated = now 114 | // Update usage stats if we haven't collected enough data 115 | if c.countedBuckets < len(c.values) { 116 | // Only update if we have advanced to the next bucket and not incremented the value 117 | // in the current bucket. 118 | if c.lastBucket != bucket { 119 | c.lastBucket = bucket 120 | c.countedBuckets++ 121 | } 122 | } 123 | } 124 | 125 | // Returns the number in the moving window bucket that this slot occupies. 
126 | func (c *RollingCounter) getBucket(t time.Time) int { 127 | return int(t.Truncate(c.resolution).Unix() % int64(len(c.values))) 128 | } 129 | 130 | // Reset buckets that were not updated. 131 | func (c *RollingCounter) cleanup() { 132 | now := clock.Now().UTC() 133 | for i := 0; i < len(c.values); i++ { 134 | checkPoint := now.Add(time.Duration(-1*i) * c.resolution) 135 | if checkPoint.Truncate(c.resolution).After(c.lastUpdated.Truncate(c.resolution)) { 136 | c.values[c.getBucket(checkPoint)] = 0 137 | } else { 138 | break 139 | } 140 | } 141 | } 142 | 143 | func (c *RollingCounter) sum() int64 { 144 | out := int64(0) 145 | for _, v := range c.values { 146 | out += int64(v) 147 | } 148 | return out 149 | } 150 | -------------------------------------------------------------------------------- /memmetrics/counter_test.go: -------------------------------------------------------------------------------- 1 | package memmetrics 2 | 3 | import ( 4 | "math" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | "github.com/stretchr/testify/require" 9 | "github.com/vulcand/oxy/v2/internal/holsterv4/clock" 10 | "github.com/vulcand/oxy/v2/testutils" 11 | ) 12 | 13 | func TestRollingCounter_Clone_expired(t *testing.T) { 14 | testutils.FreezeTime(t) 15 | 16 | cnt, err := NewCounter(3, clock.Second) 17 | require.NoError(t, err) 18 | 19 | cnt.Inc(1) 20 | 21 | clock.Advance(clock.Second) 22 | cnt.Inc(1) 23 | 24 | clock.Advance(clock.Second) 25 | cnt.Inc(1) 26 | 27 | clock.Advance(clock.Second) 28 | out := cnt.Clone() 29 | 30 | assert.EqualValues(t, 2, out.Count()) 31 | } 32 | 33 | func TestRollingCounter_cleanup(t *testing.T) { 34 | testutils.FreezeTime(t) 35 | 36 | cnt, err := NewCounter(10, clock.Second) 37 | require.NoError(t, err) 38 | 39 | cnt.Inc(1) 40 | 41 | for i := range 9 { 42 | clock.Advance(clock.Second) 43 | cnt.Inc(int(math.Pow10(i + 1))) 44 | } 45 | 46 | assert.EqualValues(t, 1111111111, cnt.Count()) 47 | assert.Equal(t, []int{1000, 10000, 100000, 
1000000, 10000000, 100000000, 1000000000, 1, 10, 100}, cnt.values) 48 | 49 | clock.Advance(9 * clock.Second) 50 | 51 | assert.EqualValues(t, 1000000000, cnt.Count()) 52 | assert.Equal(t, []int{0, 0, 0, 0, 0, 0, 1000000000, 0, 0, 0}, cnt.values) 53 | } 54 | -------------------------------------------------------------------------------- /memmetrics/histogram.go: -------------------------------------------------------------------------------- 1 | package memmetrics 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "time" 7 | 8 | "github.com/HdrHistogram/hdrhistogram-go" 9 | "github.com/vulcand/oxy/v2/internal/holsterv4/clock" 10 | ) 11 | 12 | // HDRHistogram is a tiny wrapper around github.com/HdrHistogram/hdrhistogram-go that provides convenience functions for measuring http latencies. 13 | type HDRHistogram struct { 14 | // lowest trackable value 15 | low int64 16 | // highest trackable value 17 | high int64 18 | // significant figures 19 | sigfigs int 20 | 21 | h *hdrhistogram.Histogram 22 | } 23 | 24 | // NewHDRHistogram creates a new HDRHistogram. 25 | func NewHDRHistogram(low, high int64, sigfigs int) (h *HDRHistogram, err error) { 26 | defer func() { 27 | if msg := recover(); msg != nil { 28 | err = fmt.Errorf("%s", msg) 29 | } 30 | }() 31 | return &HDRHistogram{ 32 | low: low, 33 | high: high, 34 | sigfigs: sigfigs, 35 | h: hdrhistogram.New(low, high, sigfigs), 36 | }, nil 37 | } 38 | 39 | // Export exports a HDRHistogram. 40 | func (h *HDRHistogram) Export() *HDRHistogram { 41 | var hist *hdrhistogram.Histogram 42 | if h.h != nil { 43 | snapshot := h.h.Export() 44 | hist = hdrhistogram.Import(snapshot) 45 | } 46 | return &HDRHistogram{low: h.low, high: h.high, sigfigs: h.sigfigs, h: hist} 47 | } 48 | 49 | // LatencyAtQuantile sets latency at quantile with microsecond precision. 
50 | func (h *HDRHistogram) LatencyAtQuantile(q float64) time.Duration { 51 | return time.Duration(h.ValueAtQuantile(q)) * clock.Microsecond 52 | } 53 | 54 | // RecordLatencies Records latencies with microsecond precision. 55 | func (h *HDRHistogram) RecordLatencies(d time.Duration, n int64) error { 56 | return h.RecordValues(int64(d/clock.Microsecond), n) 57 | } 58 | 59 | // Reset resets a HDRHistogram. 60 | func (h *HDRHistogram) Reset() { 61 | h.h.Reset() 62 | } 63 | 64 | // ValueAtQuantile sets value at quantile. 65 | func (h *HDRHistogram) ValueAtQuantile(q float64) int64 { 66 | return h.h.ValueAtQuantile(q) 67 | } 68 | 69 | // RecordValues sets record values. 70 | func (h *HDRHistogram) RecordValues(v, n int64) error { 71 | return h.h.RecordValues(v, n) 72 | } 73 | 74 | // Merge merges a HDRHistogram. 75 | func (h *HDRHistogram) Merge(other *HDRHistogram) error { 76 | if other == nil { 77 | return errors.New("other is nil") 78 | } 79 | h.h.Merge(other.h) 80 | return nil 81 | } 82 | 83 | type rhOption func(r *RollingHDRHistogram) error 84 | 85 | // RollingHDRHistogram holds multiple histograms and rotates every period. 86 | // It provides resulting histogram as a result of a call of 'Merged' function. 87 | type RollingHDRHistogram struct { 88 | idx int 89 | lastRoll clock.Time 90 | period time.Duration 91 | bucketCount int 92 | low int64 93 | high int64 94 | sigfigs int 95 | buckets []*HDRHistogram 96 | } 97 | 98 | // NewRollingHDRHistogram created a new RollingHDRHistogram. 
99 | func NewRollingHDRHistogram(low, high int64, sigfigs int, period time.Duration, bucketCount int, options ...rhOption) (*RollingHDRHistogram, error) { 100 | rh := &RollingHDRHistogram{ 101 | bucketCount: bucketCount, 102 | period: period, 103 | low: low, 104 | high: high, 105 | sigfigs: sigfigs, 106 | } 107 | 108 | for _, o := range options { 109 | if err := o(rh); err != nil { 110 | return nil, err 111 | } 112 | } 113 | 114 | buckets := make([]*HDRHistogram, rh.bucketCount) 115 | for i := range buckets { 116 | h, err := NewHDRHistogram(low, high, sigfigs) 117 | if err != nil { 118 | return nil, err 119 | } 120 | buckets[i] = h 121 | } 122 | rh.buckets = buckets 123 | return rh, nil 124 | } 125 | 126 | // Export exports a RollingHDRHistogram. 127 | func (r *RollingHDRHistogram) Export() *RollingHDRHistogram { 128 | export := &RollingHDRHistogram{} 129 | export.idx = r.idx 130 | export.lastRoll = r.lastRoll 131 | export.period = r.period 132 | export.bucketCount = r.bucketCount 133 | export.low = r.low 134 | export.high = r.high 135 | export.sigfigs = r.sigfigs 136 | 137 | exportBuckets := make([]*HDRHistogram, len(r.buckets)) 138 | for i, hist := range r.buckets { 139 | exportBuckets[i] = hist.Export() 140 | } 141 | export.buckets = exportBuckets 142 | 143 | return export 144 | } 145 | 146 | // Append appends a RollingHDRHistogram. 147 | func (r *RollingHDRHistogram) Append(o *RollingHDRHistogram) error { 148 | if r.bucketCount != o.bucketCount || r.period != o.period || r.low != o.low || r.high != o.high || r.sigfigs != o.sigfigs { 149 | return errors.New("can't merge") 150 | } 151 | 152 | for i := range r.buckets { 153 | if err := r.buckets[i].Merge(o.buckets[i]); err != nil { 154 | return err 155 | } 156 | } 157 | return nil 158 | } 159 | 160 | // Reset resets a RollingHDRHistogram. 
161 | func (r *RollingHDRHistogram) Reset() { 162 | r.idx = 0 163 | r.lastRoll = clock.Now().UTC() 164 | for _, b := range r.buckets { 165 | b.Reset() 166 | } 167 | } 168 | 169 | func (r *RollingHDRHistogram) rotate() { 170 | r.idx = (r.idx + 1) % len(r.buckets) 171 | r.buckets[r.idx].Reset() 172 | } 173 | 174 | // Merged gets merged histogram. 175 | func (r *RollingHDRHistogram) Merged() (*HDRHistogram, error) { 176 | m, err := NewHDRHistogram(r.low, r.high, r.sigfigs) 177 | if err != nil { 178 | return m, err 179 | } 180 | for _, h := range r.buckets { 181 | if errMerge := m.Merge(h); errMerge != nil { 182 | return nil, errMerge 183 | } 184 | } 185 | return m, nil 186 | } 187 | 188 | func (r *RollingHDRHistogram) getHist() *HDRHistogram { 189 | if clock.Now().UTC().Sub(r.lastRoll) >= r.period { 190 | r.rotate() 191 | r.lastRoll = clock.Now().UTC() 192 | } 193 | return r.buckets[r.idx] 194 | } 195 | 196 | // RecordLatencies sets records latencies. 197 | func (r *RollingHDRHistogram) RecordLatencies(v time.Duration, n int64) error { 198 | return r.getHist().RecordLatencies(v, n) 199 | } 200 | 201 | // RecordValues sets record values. 
202 | func (r *RollingHDRHistogram) RecordValues(v, n int64) error { 203 | return r.getHist().RecordValues(v, n) 204 | } 205 | -------------------------------------------------------------------------------- /memmetrics/histogram_test.go: -------------------------------------------------------------------------------- 1 | package memmetrics 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/HdrHistogram/hdrhistogram-go" 7 | "github.com/stretchr/testify/assert" 8 | "github.com/stretchr/testify/require" 9 | "github.com/vulcand/oxy/v2/internal/holsterv4/clock" 10 | "github.com/vulcand/oxy/v2/testutils" 11 | ) 12 | 13 | func TestHDRHistogram_Merge(t *testing.T) { 14 | a, err := NewHDRHistogram(1, 3600000, 2) 15 | require.NoError(t, err) 16 | 17 | require.NoError(t, a.RecordValues(1, 2)) 18 | 19 | b, err := NewHDRHistogram(1, 3600000, 2) 20 | require.NoError(t, err) 21 | 22 | require.NoError(t, b.RecordValues(2, 1)) 23 | 24 | err = a.Merge(b) 25 | require.NoError(t, err) 26 | 27 | assert.EqualValues(t, 1, a.ValueAtQuantile(50)) 28 | assert.EqualValues(t, 2, a.ValueAtQuantile(100)) 29 | } 30 | 31 | func TestHDRHistogram_Merge_nil(t *testing.T) { 32 | a, err := NewHDRHistogram(1, 3600000, 1) 33 | require.NoError(t, err) 34 | 35 | require.Error(t, a.Merge(nil)) 36 | } 37 | 38 | func TestHDRHistogram_rotation(t *testing.T) { 39 | testutils.FreezeTime(t) 40 | 41 | h, err := NewRollingHDRHistogram( 42 | 1, // min value 43 | 3600000, // max value 44 | 3, // significant figures 45 | clock.Second, 46 | 2, // 2 histograms in a window 47 | ) 48 | 49 | require.NoError(t, err) 50 | require.NotNil(t, h) 51 | 52 | err = h.RecordValues(5, 1) 53 | require.NoError(t, err) 54 | 55 | m, err := h.Merged() 56 | require.NoError(t, err) 57 | assert.EqualValues(t, 5, m.ValueAtQuantile(100)) 58 | 59 | clock.Advance(clock.Second) 60 | require.NoError(t, h.RecordValues(2, 1)) 61 | require.NoError(t, h.RecordValues(1, 1)) 62 | 63 | m, err = h.Merged() 64 | require.NoError(t, err) 65 | 
assert.EqualValues(t, 5, m.ValueAtQuantile(100)) 66 | 67 | // rotate, this means that the old value would evaporate 68 | clock.Advance(clock.Second) 69 | 70 | require.NoError(t, h.RecordValues(1, 1)) 71 | 72 | m, err = h.Merged() 73 | require.NoError(t, err) 74 | assert.EqualValues(t, 2, m.ValueAtQuantile(100)) 75 | } 76 | 77 | func TestHDRHistogram_Reset(t *testing.T) { 78 | testutils.FreezeTime(t) 79 | 80 | h, err := NewRollingHDRHistogram( 81 | 1, // min value 82 | 3600000, // max value 83 | 3, // significant figures 84 | clock.Second, 85 | 2, // 2 histograms in a window 86 | ) 87 | 88 | require.NoError(t, err) 89 | require.NotNil(t, h) 90 | 91 | require.NoError(t, h.RecordValues(5, 1)) 92 | 93 | m, err := h.Merged() 94 | require.NoError(t, err) 95 | assert.EqualValues(t, 5, m.ValueAtQuantile(100)) 96 | 97 | clock.Advance(clock.Second) 98 | require.NoError(t, h.RecordValues(2, 1)) 99 | require.NoError(t, h.RecordValues(1, 1)) 100 | 101 | m, err = h.Merged() 102 | require.NoError(t, err) 103 | assert.EqualValues(t, 5, m.ValueAtQuantile(100)) 104 | 105 | h.Reset() 106 | 107 | require.NoError(t, h.RecordValues(5, 1)) 108 | 109 | m, err = h.Merged() 110 | require.NoError(t, err) 111 | assert.EqualValues(t, 5, m.ValueAtQuantile(100)) 112 | 113 | clock.Advance(clock.Second) 114 | require.NoError(t, h.RecordValues(2, 1)) 115 | require.NoError(t, h.RecordValues(1, 1)) 116 | 117 | m, err = h.Merged() 118 | require.NoError(t, err) 119 | assert.EqualValues(t, 5, m.ValueAtQuantile(100)) 120 | } 121 | 122 | func TestHDRHistogram_Export_returnsNewCopy(t *testing.T) { 123 | // Create HDRHistogram instance 124 | a := HDRHistogram{ 125 | low: 1, 126 | high: 2, 127 | sigfigs: 3, 128 | h: hdrhistogram.New(0, 1, 2), 129 | } 130 | 131 | // Get a copy and modify the original 132 | b := a.Export() 133 | a.low = 11 134 | a.high = 12 135 | a.sigfigs = 4 136 | a.h = nil 137 | 138 | // Assert the copy has not been modified 139 | assert.EqualValues(t, 1, b.low) 140 | assert.EqualValues(t, 
2, b.high) 141 | assert.Equal(t, 3, b.sigfigs) 142 | require.NotNil(t, b.h) 143 | } 144 | 145 | func TestRollingHDRHistogram_Export_returnsNewCopy(t *testing.T) { 146 | origTime := clock.Now() 147 | testutils.FreezeTime(t) 148 | 149 | a := RollingHDRHistogram{ 150 | idx: 1, 151 | lastRoll: origTime, 152 | period: 2 * clock.Second, 153 | bucketCount: 3, 154 | low: 4, 155 | high: 5, 156 | sigfigs: 1, 157 | buckets: []*HDRHistogram{}, 158 | } 159 | 160 | b := a.Export() 161 | a.idx = 11 162 | a.lastRoll = clock.Now().Add(1 * clock.Minute) 163 | a.period = 12 * clock.Second 164 | a.bucketCount = 13 165 | a.low = 14 166 | a.high = 15 167 | a.sigfigs = 1 168 | a.buckets = nil 169 | 170 | assert.Equal(t, 1, b.idx) 171 | assert.Equal(t, origTime, b.lastRoll) 172 | assert.Equal(t, 2*clock.Second, b.period) 173 | assert.Equal(t, 3, b.bucketCount) 174 | assert.Equal(t, int64(4), b.low) 175 | assert.EqualValues(t, 5, b.high) 176 | assert.NotNil(t, b.buckets) 177 | } 178 | -------------------------------------------------------------------------------- /memmetrics/options.go: -------------------------------------------------------------------------------- 1 | package memmetrics 2 | 3 | // RTOption represents an option you can pass to NewRTMetrics. 4 | type RTOption func(r *RTMetrics) error 5 | 6 | // RTCounter set a builder function for Counter. 7 | func RTCounter(fn NewCounterFn) RTOption { 8 | return func(r *RTMetrics) error { 9 | r.newCounter = fn 10 | return nil 11 | } 12 | } 13 | 14 | // RTHistogram set a builder function for RollingHDRHistogram. 15 | func RTHistogram(fn NewRollingHistogramFn) RTOption { 16 | return func(r *RTMetrics) error { 17 | r.newHist = fn 18 | return nil 19 | } 20 | } 21 | 22 | // RatioOption represents an option you can pass to NewRatioCounter. 
23 | type RatioOption func(r *RatioCounter) error 24 | -------------------------------------------------------------------------------- /memmetrics/ratio.go: -------------------------------------------------------------------------------- 1 | package memmetrics 2 | 3 | import "time" 4 | 5 | // RatioCounter calculates a ratio of a/a+b over a rolling window of predefined buckets. 6 | type RatioCounter struct { 7 | a *RollingCounter 8 | b *RollingCounter 9 | } 10 | 11 | // NewRatioCounter creates a new RatioCounter. 12 | func NewRatioCounter(buckets int, resolution time.Duration, options ...RatioOption) (*RatioCounter, error) { 13 | rc := &RatioCounter{} 14 | 15 | for _, o := range options { 16 | if err := o(rc); err != nil { 17 | return nil, err 18 | } 19 | } 20 | 21 | a, err := NewCounter(buckets, resolution) 22 | if err != nil { 23 | return nil, err 24 | } 25 | 26 | b, err := NewCounter(buckets, resolution) 27 | if err != nil { 28 | return nil, err 29 | } 30 | 31 | rc.a = a 32 | rc.b = b 33 | return rc, nil 34 | } 35 | 36 | // Reset resets the counter. 37 | func (r *RatioCounter) Reset() { 38 | r.a.Reset() 39 | r.b.Reset() 40 | } 41 | 42 | // IsReady returns true if the counter is ready. 43 | func (r *RatioCounter) IsReady() bool { 44 | return r.a.countedBuckets+r.b.countedBuckets >= len(r.a.values) 45 | } 46 | 47 | // CountA gets count A. 48 | func (r *RatioCounter) CountA() int64 { 49 | return r.a.Count() 50 | } 51 | 52 | // CountB gets count B. 53 | func (r *RatioCounter) CountB() int64 { 54 | return r.b.Count() 55 | } 56 | 57 | // Resolution gets resolution. 58 | func (r *RatioCounter) Resolution() time.Duration { 59 | return r.a.Resolution() 60 | } 61 | 62 | // Buckets gets buckets. 63 | func (r *RatioCounter) Buckets() int { 64 | return r.a.Buckets() 65 | } 66 | 67 | // WindowSize gets windows size. 68 | func (r *RatioCounter) WindowSize() time.Duration { 69 | return r.a.WindowSize() 70 | } 71 | 72 | // ProcessedCount gets processed count. 
73 | func (r *RatioCounter) ProcessedCount() int64 { 74 | return r.CountA() + r.CountB() 75 | } 76 | 77 | // Ratio gets ratio. 78 | func (r *RatioCounter) Ratio() float64 { 79 | a := r.a.Count() 80 | b := r.b.Count() 81 | // No data, return ok 82 | if a+b == 0 { 83 | return 0 84 | } 85 | return float64(a) / float64(a+b) 86 | } 87 | 88 | // IncA increments counter A. 89 | func (r *RatioCounter) IncA(v int) { 90 | r.a.Inc(v) 91 | } 92 | 93 | // IncB increments counter B. 94 | func (r *RatioCounter) IncB(v int) { 95 | r.b.Inc(v) 96 | } 97 | 98 | // TestMeter a test meter. 99 | type TestMeter struct { 100 | Rate float64 101 | NotReady bool 102 | WindowSize time.Duration 103 | } 104 | 105 | // GetWindowSize gets windows size. 106 | func (tm *TestMeter) GetWindowSize() time.Duration { 107 | return tm.WindowSize 108 | } 109 | 110 | // IsReady returns true if the meter is ready. 111 | func (tm *TestMeter) IsReady() bool { 112 | return !tm.NotReady 113 | } 114 | 115 | // GetRate gets rate. 116 | func (tm *TestMeter) GetRate() float64 { 117 | return tm.Rate 118 | } 119 | -------------------------------------------------------------------------------- /memmetrics/ratio_test.go: -------------------------------------------------------------------------------- 1 | package memmetrics 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | "github.com/stretchr/testify/require" 8 | "github.com/vulcand/oxy/v2/internal/holsterv4/clock" 9 | "github.com/vulcand/oxy/v2/testutils" 10 | ) 11 | 12 | func TestNewRatioCounter_invalidParams(t *testing.T) { 13 | testutils.FreezeTime(t) 14 | 15 | // Bad buckets count 16 | _, err := NewRatioCounter(0, clock.Second) 17 | require.Error(t, err) 18 | 19 | // Too precise resolution 20 | _, err = NewRatioCounter(10, clock.Millisecond) 21 | require.Error(t, err) 22 | } 23 | 24 | func TestNewRatioCounter_notReady(t *testing.T) { 25 | testutils.FreezeTime(t) 26 | 27 | // No data 28 | fr, err := NewRatioCounter(10, clock.Second) 29 
| require.NoError(t, err) 30 | 31 | assert.False(t, fr.IsReady()) 32 | assert.Equal(t, 0.0, fr.Ratio()) 33 | 34 | // Not enough data 35 | fr, err = NewRatioCounter(10, clock.Second) 36 | require.NoError(t, err) 37 | 38 | fr.CountA() 39 | assert.False(t, fr.IsReady()) 40 | } 41 | 42 | func TestRatioCounter_noB(t *testing.T) { 43 | testutils.FreezeTime(t) 44 | 45 | fr, err := NewRatioCounter(1, clock.Second) 46 | require.NoError(t, err) 47 | 48 | fr.IncA(1) 49 | 50 | assert.True(t, fr.IsReady()) 51 | assert.Equal(t, 1.0, fr.Ratio()) 52 | } 53 | 54 | func TestRatioCounter_noA(t *testing.T) { 55 | testutils.FreezeTime(t) 56 | 57 | fr, err := NewRatioCounter(1, clock.Second) 58 | require.NoError(t, err) 59 | 60 | fr.IncB(1) 61 | 62 | assert.True(t, fr.IsReady()) 63 | assert.Equal(t, 0.0, fr.Ratio()) 64 | } 65 | 66 | // Make sure that data is properly calculated over several buckets. 67 | func TestRatioCounter_multipleBuckets(t *testing.T) { 68 | testutils.FreezeTime(t) 69 | 70 | fr, err := NewRatioCounter(3, clock.Second) 71 | require.NoError(t, err) 72 | 73 | fr.IncB(1) 74 | clock.Advance(clock.Second) 75 | fr.IncA(1) 76 | 77 | clock.Advance(clock.Second) 78 | fr.IncA(1) 79 | 80 | assert.True(t, fr.IsReady()) 81 | assert.Equal(t, float64(2)/float64(3), fr.Ratio()) 82 | } 83 | 84 | // Make sure that data is properly calculated over several buckets 85 | // When we overwrite old data when the window is rolling. 
86 | func TestRatioCounter_overwriteBuckets(t *testing.T) { 87 | testutils.FreezeTime(t) 88 | 89 | fr, err := NewRatioCounter(3, clock.Second) 90 | require.NoError(t, err) 91 | 92 | fr.IncB(1) 93 | 94 | clock.Advance(clock.Second) 95 | fr.IncA(1) 96 | 97 | clock.Advance(clock.Second) 98 | fr.IncA(1) 99 | 100 | // This time we should overwrite the old data points 101 | clock.Advance(clock.Second) 102 | fr.IncA(1) 103 | fr.IncB(2) 104 | 105 | assert.True(t, fr.IsReady()) 106 | assert.Equal(t, float64(3)/float64(5), fr.Ratio()) 107 | } 108 | 109 | // Make sure we cleanup the data after periods of inactivity 110 | // So it does not mess up the stats. 111 | func TestRatioCounter_inactiveBuckets(t *testing.T) { 112 | testutils.FreezeTime(t) 113 | 114 | fr, err := NewRatioCounter(3, clock.Second) 115 | require.NoError(t, err) 116 | 117 | fr.IncB(1) 118 | 119 | clock.Advance(clock.Second) 120 | fr.IncA(1) 121 | 122 | clock.Advance(clock.Second) 123 | fr.IncA(1) 124 | 125 | // This time we should overwrite the old data points with new data 126 | clock.Advance(clock.Second) 127 | fr.IncA(1) 128 | fr.IncB(2) 129 | 130 | // Jump to the last bucket and change the data 131 | clock.Advance(clock.Second * 2) 132 | fr.IncB(1) 133 | 134 | assert.True(t, fr.IsReady()) 135 | assert.Equal(t, float64(1)/float64(4), fr.Ratio()) 136 | } 137 | 138 | func TestRatioCounter_longPeriodsOfInactivity(t *testing.T) { 139 | testutils.FreezeTime(t) 140 | 141 | fr, err := NewRatioCounter(2, clock.Second) 142 | require.NoError(t, err) 143 | 144 | fr.IncB(1) 145 | 146 | clock.Advance(clock.Second) 147 | fr.IncA(1) 148 | 149 | assert.True(t, fr.IsReady()) 150 | assert.Equal(t, 0.5, fr.Ratio()) 151 | 152 | // This time we should overwrite all data points 153 | clock.Advance(100 * clock.Second) 154 | fr.IncA(1) 155 | assert.Equal(t, 1.0, fr.Ratio()) 156 | } 157 | 158 | func TestRatioCounter_Reset(t *testing.T) { 159 | testutils.FreezeTime(t) 160 | 161 | fr, err := NewRatioCounter(1, clock.Second) 162 | 
require.NoError(t, err) 163 | 164 | fr.IncB(1) 165 | fr.IncA(1) 166 | 167 | assert.True(t, fr.IsReady()) 168 | assert.Equal(t, 0.5, fr.Ratio()) 169 | 170 | // Reset the counter 171 | fr.Reset() 172 | assert.False(t, fr.IsReady()) 173 | 174 | // Now add some stats 175 | fr.IncA(2) 176 | 177 | // We are game again! 178 | assert.True(t, fr.IsReady()) 179 | assert.Equal(t, 1.0, fr.Ratio()) 180 | } 181 | -------------------------------------------------------------------------------- /memmetrics/roundtrip_test.go: -------------------------------------------------------------------------------- 1 | package memmetrics 2 | 3 | import ( 4 | "runtime" 5 | "sync" 6 | "testing" 7 | "time" 8 | 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/require" 11 | "github.com/vulcand/oxy/v2/internal/holsterv4/clock" 12 | "github.com/vulcand/oxy/v2/testutils" 13 | ) 14 | 15 | func TestNewRTMetrics_defaults(t *testing.T) { 16 | testutils.FreezeTime(t) 17 | 18 | rr, err := NewRTMetrics() 19 | require.NoError(t, err) 20 | require.NotNil(t, rr) 21 | 22 | rr.Record(200, clock.Second) 23 | rr.Record(502, 2*clock.Second) 24 | rr.Record(200, clock.Second) 25 | rr.Record(200, clock.Second) 26 | 27 | assert.EqualValues(t, 1, rr.NetworkErrorCount()) 28 | assert.EqualValues(t, 4, rr.TotalCount()) 29 | assert.Equal(t, map[int]int64{502: 1, 200: 3}, rr.StatusCodesCounts()) 30 | assert.Equal(t, float64(1)/float64(4), rr.NetworkErrorRatio()) 31 | assert.Equal(t, 1.0/3.0, rr.ResponseCodeRatio(500, 503, 200, 300)) 32 | 33 | h, err := rr.LatencyHistogram() 34 | require.NoError(t, err) 35 | assert.Equal(t, 2, int(h.LatencyAtQuantile(100)/clock.Second)) 36 | 37 | rr.Reset() 38 | assert.EqualValues(t, 0, rr.NetworkErrorCount()) 39 | assert.EqualValues(t, 0, rr.TotalCount()) 40 | assert.Equal(t, map[int]int64{}, rr.StatusCodesCounts()) 41 | assert.Equal(t, float64(0), rr.NetworkErrorRatio()) 42 | assert.Equal(t, float64(0), rr.ResponseCodeRatio(500, 503, 200, 300)) 43 | 44 | h, err = 
rr.LatencyHistogram() 45 | require.NoError(t, err) 46 | assert.Equal(t, time.Duration(0), h.LatencyAtQuantile(100)) 47 | } 48 | 49 | func TestRTMetrics_Append(t *testing.T) { 50 | testutils.FreezeTime(t) 51 | 52 | rr, err := NewRTMetrics() 53 | require.NoError(t, err) 54 | require.NotNil(t, rr) 55 | 56 | rr.Record(200, clock.Second) 57 | rr.Record(502, 2*clock.Second) 58 | rr.Record(200, clock.Second) 59 | rr.Record(200, clock.Second) 60 | 61 | rr2, err := NewRTMetrics() 62 | require.NoError(t, err) 63 | require.NotNil(t, rr2) 64 | 65 | rr2.Record(200, 3*clock.Second) 66 | rr2.Record(501, 3*clock.Second) 67 | rr2.Record(200, 3*clock.Second) 68 | rr2.Record(200, 3*clock.Second) 69 | 70 | require.NoError(t, rr2.Append(rr)) 71 | assert.Equal(t, map[int]int64{501: 1, 502: 1, 200: 6}, rr2.StatusCodesCounts()) 72 | assert.EqualValues(t, 1, rr2.NetworkErrorCount()) 73 | 74 | h, err := rr2.LatencyHistogram() 75 | require.NoError(t, err) 76 | assert.EqualValues(t, 3, h.LatencyAtQuantile(100)/clock.Second) 77 | } 78 | 79 | func TestRTMetrics_concurrentRecords(t *testing.T) { 80 | // This test asserts a race condition which requires parallelism 81 | runtime.GOMAXPROCS(100) 82 | 83 | rr, err := NewRTMetrics() 84 | require.NoError(t, err) 85 | 86 | for code := range 100 { 87 | for range 10 { 88 | go func(statusCode int) { 89 | _ = rr.recordStatusCode(statusCode) 90 | }(code) 91 | } 92 | } 93 | } 94 | 95 | func TestRTMetric_Export_returnsNewCopy(t *testing.T) { 96 | a := RTMetrics{ 97 | statusCodes: map[int]*RollingCounter{}, 98 | statusCodesLock: sync.RWMutex{}, 99 | histogram: &RollingHDRHistogram{}, 100 | histogramLock: sync.RWMutex{}, 101 | } 102 | 103 | var err error 104 | a.total, err = NewCounter(1, clock.Second) 105 | require.NoError(t, err) 106 | 107 | a.netErrors, err = NewCounter(1, clock.Second) 108 | require.NoError(t, err) 109 | 110 | a.newCounter = func() (*RollingCounter, error) { 111 | return NewCounter(counterBuckets, counterResolution) 112 | } 113 | a.newHist 
= func() (*RollingHDRHistogram, error) { 114 | return NewRollingHDRHistogram(histMin, histMax, histSignificantFigures, histPeriod, histBuckets) 115 | } 116 | 117 | b := a.Export() 118 | a.total = nil 119 | a.netErrors = nil 120 | a.statusCodes = nil 121 | a.histogram = nil 122 | a.newCounter = nil 123 | a.newHist = nil 124 | 125 | assert.NotNil(t, b.total) 126 | assert.NotNil(t, b.netErrors) 127 | assert.NotNil(t, b.statusCodes) 128 | assert.NotNil(t, b.histogram) 129 | assert.NotNil(t, b.newCounter) 130 | assert.NotNil(t, b.newHist) 131 | 132 | // a and b should have different locks 133 | locksSucceed := make(chan bool) 134 | go func() { 135 | a.statusCodesLock.Lock() 136 | b.statusCodesLock.Lock() 137 | a.histogramLock.Lock() 138 | b.histogramLock.Lock() 139 | locksSucceed <- true 140 | }() 141 | 142 | for { 143 | select { 144 | case <-locksSucceed: 145 | return 146 | case <-clock.After(10 * clock.Second): 147 | t.FailNow() 148 | } 149 | } 150 | } 151 | -------------------------------------------------------------------------------- /ratelimit/bucket.go: -------------------------------------------------------------------------------- 1 | package ratelimit 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "time" 7 | 8 | "github.com/vulcand/oxy/v2/internal/holsterv4/clock" 9 | ) 10 | 11 | // UndefinedDelay default delay. 12 | const UndefinedDelay = -1 13 | 14 | // rate defines token bucket parameters. 15 | type rate struct { 16 | period time.Duration 17 | average int64 18 | burst int64 19 | } 20 | 21 | func (r *rate) String() string { 22 | return fmt.Sprintf("rate(%v/%v, burst=%v)", r.average, r.period, r.burst) 23 | } 24 | 25 | // tokenBucket Implements token bucket algorithm (http://en.wikipedia.org/wiki/Token_bucket) 26 | type tokenBucket struct { 27 | // The time period controlled by the bucket in nanoseconds. 28 | period time.Duration 29 | // The number of nanoseconds that takes to add one more token to the total 30 | // number of available tokens. 
It effectively caches the value that could 31 | // have been otherwise deduced from refillRate. 32 | timePerToken time.Duration 33 | // The maximum number of tokens that can be accumulate in the bucket. 34 | burst int64 35 | // The number of tokens available for consumption at the moment. It can 36 | // nether be larger then capacity. 37 | availableTokens int64 38 | // Tells when tokensAvailable was updated the last time. 39 | lastRefresh clock.Time 40 | // The number of tokens consumed the last time. 41 | lastConsumed int64 42 | } 43 | 44 | // newTokenBucket crates a `tokenBucket` instance for the specified `Rate`. 45 | func newTokenBucket(rate *rate) *tokenBucket { 46 | period := rate.period 47 | if period == 0 { 48 | period = clock.Nanosecond 49 | } 50 | 51 | return &tokenBucket{ 52 | period: period, 53 | timePerToken: time.Duration(int64(period) / rate.average), 54 | burst: rate.burst, 55 | lastRefresh: clock.Now().UTC(), 56 | availableTokens: rate.burst, 57 | } 58 | } 59 | 60 | // consume makes an attempt to consume the specified number of tokens from the 61 | // bucket. If there are enough tokens available then `0, nil` is returned; if 62 | // tokens to consume is larger than the burst size, then an error is returned 63 | // and the delay is not defined; otherwise returned a none zero delay that tells 64 | // how much time the caller needs to wait until the desired number of tokens 65 | // will become available for consumption. 66 | func (tb *tokenBucket) consume(tokens int64) (time.Duration, error) { 67 | tb.updateAvailableTokens() 68 | tb.lastConsumed = 0 69 | if tokens > tb.burst { 70 | return UndefinedDelay, errors.New("requested tokens larger than max tokens") 71 | } 72 | if tb.availableTokens < tokens { 73 | return tb.timeTillAvailable(tokens), nil 74 | } 75 | tb.availableTokens -= tokens 76 | tb.lastConsumed = tokens 77 | return 0, nil 78 | } 79 | 80 | // rollback reverts effect of the most recent consumption. 
If the most recent 81 | // `consume` resulted in an error or a burst overflow, and therefore did not 82 | // modify the number of available tokens, then `rollback` won't do that either. 83 | // It is safe to call this method multiple times, for the second and all 84 | // following calls have no effect. 85 | func (tb *tokenBucket) rollback() { 86 | tb.availableTokens += tb.lastConsumed 87 | tb.lastConsumed = 0 88 | } 89 | 90 | // update modifies `average` and `burst` fields of the token bucket according 91 | // to the provided `Rate`. 92 | func (tb *tokenBucket) update(rate *rate) error { 93 | if rate.period != tb.period { 94 | return fmt.Errorf("period mismatch: %v != %v", tb.period, rate.period) 95 | } 96 | tb.timePerToken = time.Duration(int64(tb.period) / rate.average) 97 | tb.burst = rate.burst 98 | if tb.availableTokens > rate.burst { 99 | tb.availableTokens = rate.burst 100 | } 101 | return nil 102 | } 103 | 104 | // timeTillAvailable returns the number of nanoseconds that we need to 105 | // wait until the specified number of tokens becomes available for consumption. 106 | func (tb *tokenBucket) timeTillAvailable(tokens int64) time.Duration { 107 | missingTokens := tokens - tb.availableTokens 108 | return time.Duration(missingTokens) * tb.timePerToken 109 | } 110 | 111 | // updateAvailableTokens updates the number of tokens available for consumption. 112 | // It is calculated based on the refill rate, the time passed since last refresh, 113 | // and is limited by the bucket capacity. 
114 | func (tb *tokenBucket) updateAvailableTokens() { 115 | now := clock.Now().UTC() 116 | timePassed := now.Sub(tb.lastRefresh) 117 | 118 | if tb.timePerToken == 0 { 119 | return 120 | } 121 | 122 | tokens := tb.availableTokens + int64(timePassed/tb.timePerToken) 123 | // If we haven't added any tokens that means that not enough time has passed, 124 | // in this case do not adjust last refill checkpoint, otherwise it will be 125 | // always moving in time in case of frequent requests that exceed the rate 126 | if tokens != tb.availableTokens { 127 | tb.lastRefresh = now 128 | tb.availableTokens = tokens 129 | } 130 | if tb.availableTokens > tb.burst { 131 | tb.availableTokens = tb.burst 132 | } 133 | } 134 | -------------------------------------------------------------------------------- /ratelimit/bucketset.go: -------------------------------------------------------------------------------- 1 | package ratelimit 2 | 3 | import ( 4 | "fmt" 5 | "sort" 6 | "strings" 7 | "time" 8 | ) 9 | 10 | // TokenBucketSet represents a set of TokenBucket covering different time periods. 11 | type TokenBucketSet struct { 12 | buckets map[time.Duration]*tokenBucket 13 | maxPeriod time.Duration 14 | } 15 | 16 | // NewTokenBucketSet creates a `TokenBucketSet` from the specified `rates`. 17 | func NewTokenBucketSet(rates *RateSet) *TokenBucketSet { 18 | tbs := new(TokenBucketSet) 19 | // In the majority of cases we will have only one bucket. 20 | tbs.buckets = make(map[time.Duration]*tokenBucket, len(rates.m)) 21 | for _, rate := range rates.m { 22 | newBucket := newTokenBucket(rate) 23 | tbs.buckets[rate.period] = newBucket 24 | tbs.maxPeriod = maxDuration(tbs.maxPeriod, rate.period) 25 | } 26 | return tbs 27 | } 28 | 29 | // Update brings the buckets in the set in accordance with the provided `rates`. 30 | func (tbs *TokenBucketSet) Update(rates *RateSet) { 31 | // Update existing buckets and delete those that have no corresponding spec. 
32 | for _, bucket := range tbs.buckets { 33 | if rate, ok := rates.m[bucket.period]; ok { 34 | _ = bucket.update(rate) 35 | } else { 36 | delete(tbs.buckets, bucket.period) 37 | } 38 | } 39 | // Add missing buckets. 40 | for _, rate := range rates.m { 41 | if _, ok := tbs.buckets[rate.period]; !ok { 42 | newBucket := newTokenBucket(rate) 43 | tbs.buckets[rate.period] = newBucket 44 | } 45 | } 46 | // Identify the maximum period in the set 47 | tbs.maxPeriod = 0 48 | for _, bucket := range tbs.buckets { 49 | tbs.maxPeriod = maxDuration(tbs.maxPeriod, bucket.period) 50 | } 51 | } 52 | 53 | // Consume consume tokens. 54 | func (tbs *TokenBucketSet) Consume(tokens int64) (time.Duration, error) { 55 | var maxDelay time.Duration = UndefinedDelay 56 | var firstErr error 57 | for _, tokenBucket := range tbs.buckets { 58 | // We keep calling `Consume` even after a error is returned for one of 59 | // buckets because that allows us to simplify the rollback procedure, 60 | // that is to just call `Rollback` for all buckets. 61 | delay, err := tokenBucket.consume(tokens) 62 | if firstErr == nil { 63 | if err != nil { 64 | firstErr = err 65 | } else { 66 | maxDelay = maxDuration(maxDelay, delay) 67 | } 68 | } 69 | } 70 | // If we could not make ALL buckets consume tokens for whatever reason, 71 | // then rollback consumption for all of them. 72 | if firstErr != nil || maxDelay > 0 { 73 | for _, tokenBucket := range tbs.buckets { 74 | tokenBucket.rollback() 75 | } 76 | } 77 | return maxDelay, firstErr 78 | } 79 | 80 | // GetMaxPeriod returns the max period. 81 | func (tbs *TokenBucketSet) GetMaxPeriod() time.Duration { 82 | return tbs.maxPeriod 83 | } 84 | 85 | // debugState returns string that reflects the current state of all buckets in 86 | // this set. It is intended to be used for debugging and testing only. 
87 | func (tbs *TokenBucketSet) debugState() string { 88 | periods := make([]int64, 0, len(tbs.buckets)) 89 | for period := range tbs.buckets { 90 | periods = append(periods, int64(period)) 91 | } 92 | sort.Slice(periods, func(i, j int) bool { return periods[i] < periods[j] }) 93 | bucketRepr := make([]string, 0, len(tbs.buckets)) 94 | for _, period := range periods { 95 | bucket := tbs.buckets[time.Duration(period)] 96 | bucketRepr = append(bucketRepr, fmt.Sprintf("{%v: %v}", bucket.period, bucket.availableTokens)) 97 | } 98 | return strings.Join(bucketRepr, ", ") 99 | } 100 | 101 | func maxDuration(x, y time.Duration) time.Duration { 102 | if x > y { 103 | return x 104 | } 105 | return y 106 | } 107 | -------------------------------------------------------------------------------- /ratelimit/options.go: -------------------------------------------------------------------------------- 1 | package ratelimit 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/vulcand/oxy/v2/utils" 7 | ) 8 | 9 | // TokenLimiterOption token limiter option type. 10 | type TokenLimiterOption func(l *TokenLimiter) error 11 | 12 | // ErrorHandler sets error handler of the server. 13 | func ErrorHandler(h utils.ErrorHandler) TokenLimiterOption { 14 | return func(cl *TokenLimiter) error { 15 | cl.errHandler = h 16 | return nil 17 | } 18 | } 19 | 20 | // ExtractRates sets the rate extractor. 21 | func ExtractRates(e RateExtractor) TokenLimiterOption { 22 | return func(cl *TokenLimiter) error { 23 | cl.extractRates = e 24 | return nil 25 | } 26 | } 27 | 28 | // Capacity sets the capacity. 29 | func Capacity(capacity int) TokenLimiterOption { 30 | return func(cl *TokenLimiter) error { 31 | if capacity <= 0 { 32 | return fmt.Errorf("bad capacity: %v", capacity) 33 | } 34 | cl.capacity = capacity 35 | return nil 36 | } 37 | } 38 | 39 | // Logger defines the logger the TokenLimiter will use. 
40 | func Logger(l utils.Logger) TokenLimiterOption { 41 | return func(tl *TokenLimiter) error { 42 | tl.log = l 43 | return nil 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /ratelimit/tokenlimiter.go: -------------------------------------------------------------------------------- 1 | // Package ratelimit Tokenbucket based request rate limiter 2 | package ratelimit 3 | 4 | import ( 5 | "errors" 6 | "fmt" 7 | "net/http" 8 | "sync" 9 | "time" 10 | 11 | "github.com/vulcand/oxy/v2/internal/holsterv4/clock" 12 | "github.com/vulcand/oxy/v2/internal/holsterv4/collections" 13 | "github.com/vulcand/oxy/v2/utils" 14 | ) 15 | 16 | // DefaultCapacity default capacity. 17 | const DefaultCapacity = 65536 18 | 19 | // RateSet maintains a set of rates. It can contain only one rate per period at a time. 20 | type RateSet struct { 21 | m map[time.Duration]*rate 22 | } 23 | 24 | // NewRateSet crates an empty `RateSet` instance. 25 | func NewRateSet() *RateSet { 26 | rs := new(RateSet) 27 | rs.m = make(map[time.Duration]*rate) 28 | return rs 29 | } 30 | 31 | // Add adds a rate to the set. If there is a rate with the same period in the 32 | // set then the new rate overrides the old one. 33 | func (rs *RateSet) Add(period time.Duration, average int64, burst int64) error { 34 | if period <= 0 { 35 | return fmt.Errorf("invalid period: %v", period) 36 | } 37 | if average <= 0 { 38 | return fmt.Errorf("invalid average: %v", average) 39 | } 40 | if burst <= 0 { 41 | return fmt.Errorf("invalid burst: %v", burst) 42 | } 43 | rs.m[period] = &rate{period: period, average: average, burst: burst} 44 | return nil 45 | } 46 | 47 | func (rs *RateSet) String() string { 48 | return fmt.Sprint(rs.m) 49 | } 50 | 51 | // RateExtractor rate extractor. 52 | type RateExtractor interface { 53 | Extract(r *http.Request) (*RateSet, error) 54 | } 55 | 56 | // RateExtractorFunc rate extractor function type. 
57 | type RateExtractorFunc func(r *http.Request) (*RateSet, error) 58 | 59 | // Extract extract from request. 60 | func (e RateExtractorFunc) Extract(r *http.Request) (*RateSet, error) { 61 | return e(r) 62 | } 63 | 64 | // TokenLimiter implements rate limiting middleware. 65 | type TokenLimiter struct { 66 | defaultRates *RateSet 67 | extract utils.SourceExtractor 68 | extractRates RateExtractor 69 | mutex sync.Mutex 70 | bucketSets *collections.TTLMap 71 | errHandler utils.ErrorHandler 72 | capacity int 73 | next http.Handler 74 | 75 | log utils.Logger 76 | } 77 | 78 | // New constructs a `TokenLimiter` middleware instance. 79 | func New(next http.Handler, extract utils.SourceExtractor, defaultRates *RateSet, opts ...TokenLimiterOption) (*TokenLimiter, error) { 80 | if defaultRates == nil || len(defaultRates.m) == 0 { 81 | return nil, errors.New("provide default rates") 82 | } 83 | if extract == nil { 84 | return nil, errors.New("provide extract function") 85 | } 86 | tl := &TokenLimiter{ 87 | next: next, 88 | defaultRates: defaultRates, 89 | extract: extract, 90 | 91 | log: &utils.NoopLogger{}, 92 | } 93 | 94 | for _, o := range opts { 95 | if err := o(tl); err != nil { 96 | return nil, err 97 | } 98 | } 99 | setDefaults(tl) 100 | tl.bucketSets = collections.NewTTLMap(tl.capacity) 101 | return tl, nil 102 | } 103 | 104 | // Wrap sets the next handler to be called by token limiter handler. 
105 | func (tl *TokenLimiter) Wrap(next http.Handler) { 106 | tl.next = next 107 | } 108 | 109 | func (tl *TokenLimiter) ServeHTTP(w http.ResponseWriter, req *http.Request) { 110 | source, amount, err := tl.extract.Extract(req) 111 | if err != nil { 112 | tl.errHandler.ServeHTTP(w, req, err) 113 | return 114 | } 115 | 116 | if err := tl.consumeRates(req, source, amount); err != nil { 117 | tl.log.Warn("limiting request %v %v, limit: %v", req.Method, req.URL, err) 118 | tl.errHandler.ServeHTTP(w, req, err) 119 | return 120 | } 121 | 122 | tl.next.ServeHTTP(w, req) 123 | } 124 | 125 | func (tl *TokenLimiter) consumeRates(req *http.Request, source string, amount int64) error { 126 | tl.mutex.Lock() 127 | defer tl.mutex.Unlock() 128 | 129 | effectiveRates := tl.resolveRates(req) 130 | bucketSetI, exists := tl.bucketSets.Get(source) 131 | var bucketSet *TokenBucketSet 132 | 133 | if exists { 134 | bucketSet = bucketSetI.(*TokenBucketSet) 135 | bucketSet.Update(effectiveRates) 136 | } else { 137 | bucketSet = NewTokenBucketSet(effectiveRates) 138 | // We set ttl as 10 times rate period. E.g. if rate is 100 requests/second per client ip 139 | // the counters for this ip will expire after 10 seconds of inactivity 140 | err := tl.bucketSets.Set(source, bucketSet, int(bucketSet.maxPeriod/clock.Second)*10+1) 141 | if err != nil { 142 | return err 143 | } 144 | } 145 | delay, err := bucketSet.Consume(amount) 146 | if err != nil { 147 | return err 148 | } 149 | if delay > 0 { 150 | return &MaxRateError{Delay: delay} 151 | } 152 | return nil 153 | } 154 | 155 | // effectiveRates retrieves rates to be applied to the request. 156 | func (tl *TokenLimiter) resolveRates(req *http.Request) *RateSet { 157 | // If configuration mapper is not specified for this instance, then return 158 | // the default bucket specs. 
159 | if tl.extractRates == nil { 160 | return tl.defaultRates 161 | } 162 | 163 | rates, err := tl.extractRates.Extract(req) 164 | if err != nil { 165 | tl.log.Error("Failed to retrieve rates: %v", err) 166 | return tl.defaultRates 167 | } 168 | 169 | // If the returned rate set is empty then used the default one. 170 | if len(rates.m) == 0 { 171 | return tl.defaultRates 172 | } 173 | 174 | return rates 175 | } 176 | 177 | // MaxRateError max rate error. 178 | type MaxRateError struct { 179 | Delay time.Duration 180 | } 181 | 182 | func (m *MaxRateError) Error() string { 183 | return fmt.Sprintf("max rate reached: retry-in %v", m.Delay) 184 | } 185 | 186 | // RateErrHandler error handler. 187 | type RateErrHandler struct{} 188 | 189 | func (e *RateErrHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) { 190 | //nolint:errorlint // must be changed 191 | if rerr, ok := err.(*MaxRateError); ok { 192 | w.Header().Set("Retry-After", fmt.Sprintf("%.0f", rerr.Delay.Seconds())) 193 | w.Header().Set("X-Retry-In", rerr.Delay.String()) 194 | w.WriteHeader(http.StatusTooManyRequests) 195 | _, _ = w.Write([]byte(err.Error())) 196 | return 197 | } 198 | utils.DefaultHandler.ServeHTTP(w, req, err) 199 | } 200 | 201 | var defaultErrHandler = &RateErrHandler{} 202 | 203 | func setDefaults(tl *TokenLimiter) { 204 | if tl.capacity <= 0 { 205 | tl.capacity = DefaultCapacity 206 | } 207 | if tl.errHandler == nil { 208 | tl.errHandler = defaultErrHandler 209 | } 210 | } 211 | -------------------------------------------------------------------------------- /roundrobin/RequestRewriteListener.go: -------------------------------------------------------------------------------- 1 | package roundrobin 2 | 3 | import "net/http" 4 | 5 | // RequestRewriteListener function to rewrite request. 
6 | type RequestRewriteListener func(oldReq *http.Request, newReq *http.Request) 7 | -------------------------------------------------------------------------------- /roundrobin/options.go: -------------------------------------------------------------------------------- 1 | package roundrobin 2 | 3 | import ( 4 | "errors" 5 | "time" 6 | 7 | "github.com/vulcand/oxy/v2/utils" 8 | ) 9 | 10 | // RebalancerOption represents an option you can pass to NewRebalancer. 11 | type RebalancerOption func(*Rebalancer) error 12 | 13 | // RebalancerBackoff sets a beck off duration. 14 | func RebalancerBackoff(d time.Duration) RebalancerOption { 15 | return func(r *Rebalancer) error { 16 | r.backoffDuration = d 17 | return nil 18 | } 19 | } 20 | 21 | // RebalancerMeter sets a Meter builder function. 22 | func RebalancerMeter(newMeter NewMeterFn) RebalancerOption { 23 | return func(r *Rebalancer) error { 24 | r.newMeter = newMeter 25 | return nil 26 | } 27 | } 28 | 29 | // RebalancerErrorHandler is a functional argument that sets error handler of the server. 30 | func RebalancerErrorHandler(h utils.ErrorHandler) RebalancerOption { 31 | return func(r *Rebalancer) error { 32 | r.errHandler = h 33 | return nil 34 | } 35 | } 36 | 37 | // RebalancerStickySession sets a sticky session. 38 | func RebalancerStickySession(stickySession *StickySession) RebalancerOption { 39 | return func(r *Rebalancer) error { 40 | r.stickySession = stickySession 41 | return nil 42 | } 43 | } 44 | 45 | // RebalancerRequestRewriteListener is a functional argument that sets error handler of the server. 46 | func RebalancerRequestRewriteListener(rrl RequestRewriteListener) RebalancerOption { 47 | return func(r *Rebalancer) error { 48 | r.requestRewriteListener = rrl 49 | return nil 50 | } 51 | } 52 | 53 | // RebalancerLogger defines the logger used by Rebalancer. 
54 | func RebalancerLogger(l utils.Logger) RebalancerOption { 55 | return func(rb *Rebalancer) error { 56 | rb.log = l 57 | return nil 58 | } 59 | } 60 | 61 | // RebalancerDebug additional debug information. 62 | func RebalancerDebug(debug bool) RebalancerOption { 63 | return func(rb *Rebalancer) error { 64 | rb.debug = debug 65 | return nil 66 | } 67 | } 68 | 69 | // ServerOption provides various options for server, e.g. weight. 70 | type ServerOption func(*server) error 71 | 72 | // Weight is an optional functional argument that sets weight of the server. 73 | func Weight(w int) ServerOption { 74 | return func(s *server) error { 75 | if w < 0 { 76 | return errors.New("Weight should be >= 0") 77 | } 78 | s.weight = w 79 | return nil 80 | } 81 | } 82 | 83 | // LBOption provides options for load balancer. 84 | type LBOption func(*RoundRobin) error 85 | 86 | // ErrorHandler is a functional argument that sets error handler of the server. 87 | func ErrorHandler(h utils.ErrorHandler) LBOption { 88 | return func(s *RoundRobin) error { 89 | s.errHandler = h 90 | return nil 91 | } 92 | } 93 | 94 | // EnableStickySession enable sticky session. 95 | func EnableStickySession(stickySession *StickySession) LBOption { 96 | return func(s *RoundRobin) error { 97 | s.stickySession = stickySession 98 | return nil 99 | } 100 | } 101 | 102 | // RoundRobinRequestRewriteListener is a functional argument that sets error handler of the server. 103 | func RoundRobinRequestRewriteListener(rrl RequestRewriteListener) LBOption { 104 | return func(s *RoundRobin) error { 105 | s.requestRewriteListener = rrl 106 | return nil 107 | } 108 | } 109 | 110 | // Logger defines the logger the RoundRobin will use. 111 | func Logger(l utils.Logger) LBOption { 112 | return func(r *RoundRobin) error { 113 | r.log = l 114 | return nil 115 | } 116 | } 117 | 118 | // Verbose additional debug information. 
119 | func Verbose(verbose bool) LBOption { 120 | return func(r *RoundRobin) error { 121 | r.verbose = verbose 122 | return nil 123 | } 124 | } 125 | -------------------------------------------------------------------------------- /roundrobin/rr_test.go: -------------------------------------------------------------------------------- 1 | package roundrobin 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | "github.com/stretchr/testify/require" 10 | "github.com/vulcand/oxy/v2/forward" 11 | "github.com/vulcand/oxy/v2/testutils" 12 | "github.com/vulcand/oxy/v2/utils" 13 | ) 14 | 15 | func TestRoundRobin_noServers(t *testing.T) { 16 | fwd := forward.New(false) 17 | 18 | lb, err := New(fwd) 19 | require.NoError(t, err) 20 | 21 | proxy := httptest.NewServer(lb) 22 | t.Cleanup(proxy.Close) 23 | 24 | re, _, err := testutils.Get(proxy.URL) 25 | require.NoError(t, err) 26 | assert.Equal(t, http.StatusInternalServerError, re.StatusCode) 27 | } 28 | 29 | func TestRoundRobin_RemoveServer_badServer(t *testing.T) { 30 | lb, err := New(nil) 31 | require.NoError(t, err) 32 | 33 | require.Error(t, lb.RemoveServer(testutils.MustParseRequestURI("http://google.com"))) 34 | } 35 | 36 | func TestRoundRobin_customErrHandler(t *testing.T) { 37 | errHandler := utils.ErrorHandlerFunc(func(w http.ResponseWriter, _ *http.Request, _ error) { 38 | w.WriteHeader(http.StatusTeapot) 39 | _, _ = w.Write([]byte(http.StatusText(http.StatusTeapot))) 40 | }) 41 | 42 | fwd := forward.New(false) 43 | 44 | lb, err := New(fwd, ErrorHandler(errHandler)) 45 | require.NoError(t, err) 46 | 47 | proxy := httptest.NewServer(lb) 48 | t.Cleanup(proxy.Close) 49 | 50 | re, _, err := testutils.Get(proxy.URL) 51 | require.NoError(t, err) 52 | assert.Equal(t, http.StatusTeapot, re.StatusCode) 53 | } 54 | 55 | func TestRoundRobin_oneServer(t *testing.T) { 56 | a := testutils.NewResponder(t, "a") 57 | 58 | fwd := forward.New(false) 59 | 60 | lb, err := 
New(fwd) 61 | require.NoError(t, err) 62 | 63 | require.NoError(t, lb.UpsertServer(testutils.MustParseRequestURI(a.URL))) 64 | 65 | proxy := httptest.NewServer(lb) 66 | t.Cleanup(proxy.Close) 67 | 68 | assert.Equal(t, []string{"a", "a", "a"}, seq(t, proxy.URL, 3)) 69 | } 70 | 71 | func TestRoundRobin__imple(t *testing.T) { 72 | a := testutils.NewResponder(t, "a") 73 | b := testutils.NewResponder(t, "b") 74 | 75 | fwd := forward.New(false) 76 | 77 | lb, err := New(fwd) 78 | require.NoError(t, err) 79 | 80 | require.NoError(t, lb.UpsertServer(testutils.MustParseRequestURI(a.URL))) 81 | require.NoError(t, lb.UpsertServer(testutils.MustParseRequestURI(b.URL))) 82 | 83 | proxy := httptest.NewServer(lb) 84 | t.Cleanup(proxy.Close) 85 | 86 | assert.Equal(t, []string{"a", "b", "a"}, seq(t, proxy.URL, 3)) 87 | } 88 | 89 | func TestRoundRobin_removeServer(t *testing.T) { 90 | a := testutils.NewResponder(t, "a") 91 | b := testutils.NewResponder(t, "b") 92 | 93 | fwd := forward.New(false) 94 | 95 | lb, err := New(fwd) 96 | require.NoError(t, err) 97 | 98 | require.NoError(t, lb.UpsertServer(testutils.MustParseRequestURI(a.URL))) 99 | require.NoError(t, lb.UpsertServer(testutils.MustParseRequestURI(b.URL))) 100 | 101 | proxy := httptest.NewServer(lb) 102 | t.Cleanup(proxy.Close) 103 | 104 | assert.Equal(t, []string{"a", "b", "a"}, seq(t, proxy.URL, 3)) 105 | 106 | err = lb.RemoveServer(testutils.MustParseRequestURI(a.URL)) 107 | require.NoError(t, err) 108 | 109 | assert.Equal(t, []string{"b", "b", "b"}, seq(t, proxy.URL, 3)) 110 | } 111 | 112 | func TestRoundRobin_upsertSame(t *testing.T) { 113 | a := testutils.NewResponder(t, "a") 114 | 115 | fwd := forward.New(false) 116 | 117 | lb, err := New(fwd) 118 | require.NoError(t, err) 119 | 120 | require.NoError(t, lb.UpsertServer(testutils.MustParseRequestURI(a.URL))) 121 | require.NoError(t, lb.UpsertServer(testutils.MustParseRequestURI(a.URL))) 122 | 123 | proxy := httptest.NewServer(lb) 124 | t.Cleanup(proxy.Close) 125 | 126 | 
assert.Equal(t, []string{"a", "a", "a"}, seq(t, proxy.URL, 3)) 127 | } 128 | 129 | func TestRoundRobin_upsertWeight(t *testing.T) { 130 | a := testutils.NewResponder(t, "a") 131 | b := testutils.NewResponder(t, "b") 132 | 133 | fwd := forward.New(false) 134 | 135 | lb, err := New(fwd) 136 | require.NoError(t, err) 137 | 138 | require.NoError(t, lb.UpsertServer(testutils.MustParseRequestURI(a.URL))) 139 | require.NoError(t, lb.UpsertServer(testutils.MustParseRequestURI(b.URL))) 140 | 141 | proxy := httptest.NewServer(lb) 142 | t.Cleanup(proxy.Close) 143 | 144 | assert.Equal(t, []string{"a", "b", "a"}, seq(t, proxy.URL, 3)) 145 | 146 | require.NoError(t, lb.UpsertServer(testutils.MustParseRequestURI(b.URL), Weight(3))) 147 | 148 | assert.Equal(t, []string{"b", "b", "a", "b"}, seq(t, proxy.URL, 4)) 149 | } 150 | 151 | func TestRoundRobin_weighted(t *testing.T) { 152 | require.NoError(t, SetDefaultWeight(0)) 153 | defer func() { _ = SetDefaultWeight(1) }() 154 | 155 | a := testutils.NewResponder(t, "a") 156 | b := testutils.NewResponder(t, "b") 157 | z := testutils.NewResponder(t, "z") 158 | 159 | fwd := forward.New(false) 160 | 161 | lb, err := New(fwd) 162 | require.NoError(t, err) 163 | 164 | require.NoError(t, lb.UpsertServer(testutils.MustParseRequestURI(a.URL), Weight(3))) 165 | require.NoError(t, lb.UpsertServer(testutils.MustParseRequestURI(b.URL), Weight(2))) 166 | require.NoError(t, lb.UpsertServer(testutils.MustParseRequestURI(z.URL), Weight(0))) 167 | 168 | proxy := httptest.NewServer(lb) 169 | t.Cleanup(proxy.Close) 170 | 171 | assert.Equal(t, []string{"a", "a", "b", "a", "b", "a"}, seq(t, proxy.URL, 6)) 172 | 173 | w, ok := lb.ServerWeight(testutils.MustParseRequestURI(a.URL)) 174 | assert.Equal(t, 3, w) 175 | assert.True(t, ok) 176 | 177 | w, ok = lb.ServerWeight(testutils.MustParseRequestURI(b.URL)) 178 | assert.Equal(t, 2, w) 179 | assert.True(t, ok) 180 | 181 | w, ok = lb.ServerWeight(testutils.MustParseRequestURI(z.URL)) 182 | assert.Equal(t, 0, w) 
183 | assert.True(t, ok) 184 | 185 | w, ok = lb.ServerWeight(testutils.MustParseRequestURI("http://caramba:4000")) 186 | assert.Equal(t, -1, w) 187 | assert.False(t, ok) 188 | } 189 | 190 | func TestRoundRobinRequestRewriteListener(t *testing.T) { 191 | testutils.NewResponder(t, "a") 192 | testutils.NewResponder(t, "b") 193 | 194 | fwd := forward.New(false) 195 | 196 | lb, err := New(fwd, 197 | RoundRobinRequestRewriteListener(func(_ *http.Request, _ *http.Request) {})) 198 | require.NoError(t, err) 199 | 200 | assert.NotNil(t, lb.requestRewriteListener) 201 | } 202 | 203 | func seq(t *testing.T, url string, repeat int) []string { 204 | t.Helper() 205 | 206 | var out []string 207 | for range repeat { 208 | _, body, err := testutils.Get(url) 209 | require.NoError(t, err) 210 | out = append(out, string(body)) 211 | } 212 | return out 213 | } 214 | -------------------------------------------------------------------------------- /roundrobin/stickycookie/aes_value.go: -------------------------------------------------------------------------------- 1 | package stickycookie 2 | 3 | import ( 4 | "crypto/aes" 5 | "crypto/cipher" 6 | "crypto/rand" 7 | "encoding/base64" 8 | "encoding/binary" 9 | "errors" 10 | "fmt" 11 | "io" 12 | "net/url" 13 | "strconv" 14 | "strings" 15 | "time" 16 | 17 | "github.com/vulcand/oxy/v2/internal/holsterv4/clock" 18 | ) 19 | 20 | // AESValue manages hashed sticky value. 21 | type AESValue struct { 22 | block cipher.AEAD 23 | ttl time.Duration 24 | } 25 | 26 | // NewAESValue takes a fixed-size key and returns an CookieValue or an error. 27 | // Key size must be exactly one of 16, 24, or 32 bytes to select AES-128, AES-192, or AES-256. 
28 | func NewAESValue(key []byte, ttl time.Duration) (*AESValue, error) { 29 | block, err := aes.NewCipher(key) 30 | if err != nil { 31 | return nil, err 32 | } 33 | 34 | gcm, err := cipher.NewGCM(block) 35 | if err != nil { 36 | return nil, err 37 | } 38 | 39 | return &AESValue{block: gcm, ttl: ttl}, nil 40 | } 41 | 42 | // Get hashes the sticky value. 43 | func (v *AESValue) Get(raw *url.URL) string { 44 | base := raw.String() 45 | if v.ttl > 0 { 46 | base = fmt.Sprintf("%s|%d", base, clock.Now().UTC().Add(v.ttl).Unix()) 47 | } 48 | 49 | // Nonce is the 64bit nanosecond-resolution time, plus 32bits of crypto/rand, for 96bits (12Bytes). 50 | // Theoretically, if 2^32 calls were made in 1 nanoseconds, there might be a repeat. 51 | // Adds ~765ns, and 4B heap in 1 alloc 52 | nonce := make([]byte, 12) 53 | binary.PutVarint(nonce, clock.Now().UnixNano()) 54 | 55 | rpend := make([]byte, 4) 56 | if _, err := io.ReadFull(rand.Reader, rpend); err != nil { 57 | // This is a near-impossible error condition on Linux systems. 58 | // An error here means rand.Reader (and thus getrandom(2), and thus /dev/urandom) returned 59 | // less than 4 bytes of data. /dev/urandom is guaranteed to always return the number of 60 | // bytes requested up to 512 bytes on modern kernels. Behavior on non-Linux systems 61 | // varies, of course. 62 | panic(err) 63 | } 64 | 65 | for i := range 4 { 66 | nonce[i+8] = rpend[i] 67 | } 68 | 69 | obfuscated := v.block.Seal(nil, nonce, []byte(base), nil) 70 | // We append the 12byte nonce onto the end of the message 71 | obfuscated = append(obfuscated, nonce...) 72 | obfuscatedStr := base64.RawURLEncoding.EncodeToString(obfuscated) 73 | 74 | return obfuscatedStr 75 | } 76 | 77 | // FindURL gets url from array that match the value. 
78 | func (v *AESValue) FindURL(raw string, urls []*url.URL) (*url.URL, error) { 79 | rawURL, err := v.fromValue(raw) 80 | if err != nil { 81 | return nil, err 82 | } 83 | 84 | for _, u := range urls { 85 | ok, err := areURLEqual(rawURL, u) 86 | if err != nil { 87 | return nil, err 88 | } 89 | 90 | if ok { 91 | return u, nil 92 | } 93 | } 94 | 95 | return nil, nil 96 | } 97 | 98 | func (v *AESValue) fromValue(obfuscatedStr string) (string, error) { 99 | obfuscated, err := base64.RawURLEncoding.DecodeString(obfuscatedStr) 100 | if err != nil { 101 | return "", err 102 | } 103 | 104 | // The first len-12 bytes is the ciphertext, the last 12 bytes is the nonce 105 | n := len(obfuscated) - 12 106 | if n <= 0 { 107 | // Protect against range errors causing panics 108 | return "", errors.New("post-base64-decoded string is too short") 109 | } 110 | 111 | nonce := obfuscated[n:] 112 | obfuscated = obfuscated[:n] 113 | 114 | raw, err := v.block.Open(nil, nonce, obfuscated, nil) 115 | if err != nil { 116 | return "", err 117 | } 118 | 119 | if v.ttl > 0 { 120 | rawParts := strings.Split(string(raw), "|") 121 | if len(rawParts) < 2 { 122 | return "", fmt.Errorf("TTL set but cookie doesn't contain an expiration: '%s'", raw) 123 | } 124 | 125 | // validate the ttl 126 | i, err := strconv.ParseInt(rawParts[1], 10, 64) 127 | if err != nil { 128 | return "", err 129 | } 130 | 131 | if clock.Now().UTC().After(clock.Unix(i, 0).UTC()) { 132 | strTime := clock.Unix(i, 0).UTC().String() 133 | return "", fmt.Errorf("TTL expired: '%s' (%s)", raw, strTime) 134 | } 135 | 136 | raw = []byte(rawParts[0]) 137 | } 138 | 139 | return string(raw), nil 140 | } 141 | -------------------------------------------------------------------------------- /roundrobin/stickycookie/cookie_value.go: -------------------------------------------------------------------------------- 1 | package stickycookie 2 | 3 | import "net/url" 4 | 5 | // CookieValue interface to manage the sticky cookie value format. 
6 | // It will be used by the load balancer to generate the sticky cookie value and to retrieve the matching url. 7 | type CookieValue interface { 8 | // Get converts raw value to an expected sticky format. 9 | Get(raw *url.URL) string 10 | 11 | // FindURL gets url from array that match the value. 12 | FindURL(raw string, urls []*url.URL) (*url.URL, error) 13 | } 14 | 15 | // areURLEqual compare a string to a url and check if the string is the same as the url value. 16 | func areURLEqual(normalized string, u *url.URL) (bool, error) { 17 | u1, err := url.Parse(normalized) 18 | if err != nil { 19 | return false, err 20 | } 21 | 22 | return u1.Scheme == u.Scheme && u1.Host == u.Host && u1.Path == u.Path, nil 23 | } 24 | -------------------------------------------------------------------------------- /roundrobin/stickycookie/fallback_value.go: -------------------------------------------------------------------------------- 1 | package stickycookie 2 | 3 | import ( 4 | "errors" 5 | "net/url" 6 | ) 7 | 8 | // FallbackValue manages hashed sticky value. 9 | type FallbackValue struct { 10 | from CookieValue 11 | to CookieValue 12 | } 13 | 14 | // NewFallbackValue creates a new FallbackValue. 15 | func NewFallbackValue(from CookieValue, to CookieValue) (*FallbackValue, error) { 16 | if from == nil || to == nil { 17 | return nil, errors.New("from and to are mandatory") 18 | } 19 | 20 | return &FallbackValue{from: from, to: to}, nil 21 | } 22 | 23 | // Get hashes the sticky value. 24 | func (v *FallbackValue) Get(raw *url.URL) string { 25 | return v.to.Get(raw) 26 | } 27 | 28 | // FindURL gets url from array that match the value. 29 | // If it is a symmetric algorithm, it decodes the URL, otherwise it compares the ciphered values. 
30 | func (v *FallbackValue) FindURL(raw string, urls []*url.URL) (*url.URL, error) { 31 | findURL, err := v.from.FindURL(raw, urls) 32 | if findURL != nil { 33 | return findURL, err 34 | } 35 | 36 | return v.to.FindURL(raw, urls) 37 | } 38 | -------------------------------------------------------------------------------- /roundrobin/stickycookie/hash_value.go: -------------------------------------------------------------------------------- 1 | package stickycookie 2 | 3 | import ( 4 | "net/url" 5 | "strconv" 6 | 7 | "github.com/segmentio/fasthash/fnv1a" 8 | ) 9 | 10 | // HashValue manages hashed sticky value. 11 | type HashValue struct { 12 | // Salt secret to anonymize the hashed cookie 13 | Salt string 14 | } 15 | 16 | // Get hashes the sticky value. 17 | func (v *HashValue) Get(raw *url.URL) string { 18 | return v.hash(raw.String()) 19 | } 20 | 21 | // FindURL gets url from array that match the value. 22 | func (v *HashValue) FindURL(raw string, urls []*url.URL) (*url.URL, error) { 23 | for _, u := range urls { 24 | if raw == v.hash(normalized(u)) { 25 | return u, nil 26 | } 27 | } 28 | 29 | return nil, nil 30 | } 31 | 32 | func (v *HashValue) hash(input string) string { 33 | return strconv.FormatUint(fnv1a.HashString64(v.Salt+input), 16) 34 | } 35 | 36 | func normalized(u *url.URL) string { 37 | normalized := url.URL{Scheme: u.Scheme, Host: u.Host, Path: u.Path} 38 | return normalized.String() 39 | } 40 | -------------------------------------------------------------------------------- /roundrobin/stickycookie/raw_value.go: -------------------------------------------------------------------------------- 1 | package stickycookie 2 | 3 | import ( 4 | "net/url" 5 | ) 6 | 7 | // RawValue is a no-op that returns the raw strings as-is. 8 | type RawValue struct{} 9 | 10 | // Get returns the raw value. 11 | func (v *RawValue) Get(raw *url.URL) string { 12 | return raw.String() 13 | } 14 | 15 | // FindURL gets url from array that match the value. 
16 | func (v *RawValue) FindURL(raw string, urls []*url.URL) (*url.URL, error) { 17 | for _, u := range urls { 18 | ok, err := areURLEqual(raw, u) 19 | if err != nil { 20 | return nil, err 21 | } 22 | 23 | if ok { 24 | return u, nil 25 | } 26 | } 27 | 28 | return nil, nil 29 | } 30 | -------------------------------------------------------------------------------- /roundrobin/stickysessions.go: -------------------------------------------------------------------------------- 1 | package roundrobin 2 | 3 | import ( 4 | "errors" 5 | "net/http" 6 | "net/url" 7 | "time" 8 | 9 | "github.com/vulcand/oxy/v2/roundrobin/stickycookie" 10 | ) 11 | 12 | // CookieOptions has all the options one would like to set on the affinity cookie. 13 | type CookieOptions struct { 14 | HTTPOnly bool 15 | Secure bool 16 | 17 | Path string 18 | Domain string 19 | Expires time.Time 20 | 21 | MaxAge int 22 | SameSite http.SameSite 23 | } 24 | 25 | // StickySession is a mixin for load balancers that implements layer 7 (http cookie) session affinity. 26 | type StickySession struct { 27 | cookieName string 28 | cookieValue stickycookie.CookieValue 29 | options CookieOptions 30 | } 31 | 32 | // NewStickySession creates a new StickySession. 33 | func NewStickySession(cookieName string) *StickySession { 34 | return &StickySession{cookieName: cookieName, cookieValue: &stickycookie.RawValue{}} 35 | } 36 | 37 | // NewStickySessionWithOptions creates a new StickySession whilst allowing for options to 38 | // shape its affinity cookie such as "httpOnly" or "secure". 39 | func NewStickySessionWithOptions(cookieName string, options CookieOptions) *StickySession { 40 | return &StickySession{cookieName: cookieName, options: options, cookieValue: &stickycookie.RawValue{}} 41 | } 42 | 43 | // SetCookieValue set the CookieValue for the StickySession. 
44 | func (s *StickySession) SetCookieValue(value stickycookie.CookieValue) *StickySession { 45 | s.cookieValue = value 46 | return s 47 | } 48 | 49 | // GetBackend returns the backend URL stored in the sticky cookie, iff the backend is still in the valid list of servers. 50 | func (s *StickySession) GetBackend(req *http.Request, servers []*url.URL) (*url.URL, bool, error) { 51 | cookie, err := req.Cookie(s.cookieName) 52 | if err != nil { 53 | if errors.Is(err, http.ErrNoCookie) { 54 | return nil, false, nil 55 | } 56 | 57 | return nil, false, err 58 | } 59 | 60 | server, err := s.cookieValue.FindURL(cookie.Value, servers) 61 | 62 | return server, server != nil, err 63 | } 64 | 65 | // StickBackend creates and sets the cookie. 66 | func (s *StickySession) StickBackend(backend *url.URL, w http.ResponseWriter) { 67 | opt := s.options 68 | 69 | cp := "/" 70 | if opt.Path != "" { 71 | cp = opt.Path 72 | } 73 | 74 | cookie := &http.Cookie{ 75 | Name: s.cookieName, 76 | Value: s.cookieValue.Get(backend), 77 | Path: cp, 78 | Domain: opt.Domain, 79 | Expires: opt.Expires, 80 | MaxAge: opt.MaxAge, 81 | Secure: opt.Secure, 82 | HttpOnly: opt.HTTPOnly, 83 | SameSite: opt.SameSite, 84 | } 85 | http.SetCookie(w, cookie) 86 | } 87 | -------------------------------------------------------------------------------- /stream/options.go: -------------------------------------------------------------------------------- 1 | package stream 2 | 3 | import ( 4 | "github.com/vulcand/oxy/v2/utils" 5 | ) 6 | 7 | // Option represents an option you can pass to New. 8 | type Option func(s *Stream) error 9 | 10 | // Logger defines the logger used by Stream. 11 | func Logger(l utils.Logger) Option { 12 | return func(s *Stream) error { 13 | s.log = l 14 | return nil 15 | } 16 | } 17 | 18 | // Verbose additional debug information. 
19 | func Verbose(verbose bool) Option { 20 | return func(s *Stream) error { 21 | s.verbose = verbose 22 | return nil 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /stream/stream.go: -------------------------------------------------------------------------------- 1 | /* 2 | Package stream provides http.Handler middleware that passes-through the entire request 3 | 4 | Stream works around several limitations caused by buffering implementations, but 5 | also introduces certain risks. 6 | 7 | Workarounds for buffering limitations: 8 | 1. Streaming really large chunks of data (large file transfers, or streaming videos, 9 | etc.) 10 | 11 | 2. Streaming (chunking) sparse data. For example, an implementation might 12 | send a health check or a heart beat over a long-lived connection. This 13 | does not play well with buffering. 14 | 15 | Risks: 16 | 1. Connections could survive for very long periods of time. 17 | 18 | 2. There is no easy way to enforce limits on size/time of a connection. 19 | 20 | Examples of a streaming middleware: 21 | 22 | // sample HTTP handler. 23 | handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 24 | w.Write([]byte("hello")) 25 | }) 26 | 27 | // Stream will literally pass through to the next handler without ANY buffering 28 | // or validation of the data. 29 | stream.New(handler) 30 | */ 31 | package stream 32 | 33 | import ( 34 | "net/http" 35 | 36 | "github.com/vulcand/oxy/v2/utils" 37 | ) 38 | 39 | // DefaultMaxBodyBytes No limit by default. 40 | const DefaultMaxBodyBytes = -1 41 | 42 | // Stream is responsible for buffering requests and responses 43 | // It buffers large requests and responses to disk,. 44 | type Stream struct { 45 | maxRequestBodyBytes int64 46 | 47 | maxResponseBodyBytes int64 48 | 49 | next http.Handler 50 | 51 | verbose bool 52 | log utils.Logger 53 | } 54 | 55 | // New returns a new streamer middleware. 
New() function supports optional functional arguments. 56 | func New(next http.Handler, setters ...Option) (*Stream, error) { 57 | strm := &Stream{ 58 | next: next, 59 | 60 | maxRequestBodyBytes: DefaultMaxBodyBytes, 61 | 62 | maxResponseBodyBytes: DefaultMaxBodyBytes, 63 | 64 | log: &utils.NoopLogger{}, 65 | } 66 | for _, s := range setters { 67 | if err := s(strm); err != nil { 68 | return nil, err 69 | } 70 | } 71 | return strm, nil 72 | } 73 | 74 | // Wrap sets the next handler to be called by stream handler. 75 | func (s *Stream) Wrap(next http.Handler) error { 76 | s.next = next 77 | return nil 78 | } 79 | 80 | func (s *Stream) ServeHTTP(w http.ResponseWriter, req *http.Request) { 81 | if s.verbose { 82 | dump := utils.DumpHTTPRequest(req) 83 | s.log.Debug("vulcand/oxy/stream: begin ServeHttp on request: %s", dump) 84 | defer s.log.Debug("vulcand/oxy/stream: completed ServeHttp on request: %s", dump) 85 | } 86 | 87 | s.next.ServeHTTP(w, req) 88 | } 89 | -------------------------------------------------------------------------------- /stream/threshold.go: -------------------------------------------------------------------------------- 1 | package stream 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | 7 | "github.com/vulcand/predicate" 8 | ) 9 | 10 | type hpredicate func(*context) bool 11 | 12 | // IsValidExpression check if it's a valid expression. 13 | func IsValidExpression(expr string) bool { 14 | _, err := parseExpression(expr) 15 | return err == nil 16 | } 17 | 18 | // Parses expression in the go language into Failover predicates. 
19 | func parseExpression(in string) (hpredicate, error) { 20 | p, err := predicate.NewParser(predicate.Def{ 21 | Operators: predicate.Operators{ 22 | AND: and, 23 | OR: or, 24 | EQ: eq, 25 | NEQ: neq, 26 | LT: lt, 27 | GT: gt, 28 | LE: le, 29 | GE: ge, 30 | }, 31 | Functions: map[string]interface{}{ 32 | "RequestMethod": requestMethod, 33 | "IsNetworkError": isNetworkError, 34 | "Attempts": attempts, 35 | "ResponseCode": responseCode, 36 | }, 37 | }) 38 | if err != nil { 39 | return nil, err 40 | } 41 | out, err := p.Parse(in) 42 | if err != nil { 43 | return nil, err 44 | } 45 | pr, ok := out.(hpredicate) 46 | if !ok { 47 | return nil, fmt.Errorf("expected predicate, got %T", out) 48 | } 49 | return pr, nil 50 | } 51 | 52 | // IsNetworkError returns a predicate that returns true if last attempt ended with network error. 53 | func isNetworkError() hpredicate { 54 | return func(c *context) bool { 55 | return c.responseCode == http.StatusBadGateway || c.responseCode == http.StatusGatewayTimeout 56 | } 57 | } 58 | 59 | // and returns predicate by joining the passed predicates with logical 'and'. 60 | func and(fns ...hpredicate) hpredicate { 61 | return func(c *context) bool { 62 | for _, fn := range fns { 63 | if !fn(c) { 64 | return false 65 | } 66 | } 67 | return true 68 | } 69 | } 70 | 71 | // or returns predicate by joining the passed predicates with logical 'or'. 72 | func or(fns ...hpredicate) hpredicate { 73 | return func(c *context) bool { 74 | for _, fn := range fns { 75 | if fn(c) { 76 | return true 77 | } 78 | } 79 | return false 80 | } 81 | } 82 | 83 | // not creates negation of the passed predicate. 84 | func not(p hpredicate) hpredicate { 85 | return func(c *context) bool { 86 | return !p(c) 87 | } 88 | } 89 | 90 | // eq returns predicate that tests for equality of the value of the mapper and the constant. 
91 | func eq(m interface{}, value interface{}) (hpredicate, error) { 92 | switch mapper := m.(type) { 93 | case toString: 94 | return stringEQ(mapper, value) 95 | case toInt: 96 | return intEQ(mapper, value) 97 | } 98 | return nil, fmt.Errorf("unsupported argument: %T", m) 99 | } 100 | 101 | // neq returns predicate that tests for inequality of the value of the mapper and the constant. 102 | func neq(m interface{}, value interface{}) (hpredicate, error) { 103 | p, err := eq(m, value) 104 | if err != nil { 105 | return nil, err 106 | } 107 | return not(p), nil 108 | } 109 | 110 | // lt returns predicate that tests that value of the mapper function is less than the constant. 111 | func lt(m interface{}, value interface{}) (hpredicate, error) { 112 | switch mapper := m.(type) { 113 | case toInt: 114 | return intLT(mapper, value) 115 | default: 116 | return nil, fmt.Errorf("unsupported argument: %T", m) 117 | } 118 | } 119 | 120 | // le returns predicate that tests that value of the mapper function is less or equal than the constant. 121 | func le(m interface{}, value interface{}) (hpredicate, error) { 122 | l, err := lt(m, value) 123 | if err != nil { 124 | return nil, err 125 | } 126 | e, err := eq(m, value) 127 | if err != nil { 128 | return nil, err 129 | } 130 | return func(c *context) bool { 131 | return l(c) || e(c) 132 | }, nil 133 | } 134 | 135 | // gt returns predicate that tests that value of the mapper function is greater than the constant. 136 | func gt(m interface{}, value interface{}) (hpredicate, error) { 137 | switch mapper := m.(type) { 138 | case toInt: 139 | return intGT(mapper, value) 140 | default: 141 | return nil, fmt.Errorf("unsupported argument: %T", m) 142 | } 143 | } 144 | 145 | // ge returns predicate that tests that value of the mapper function is less or equal than the constant. 
146 | func ge(m interface{}, value interface{}) (hpredicate, error) { 147 | g, err := gt(m, value) 148 | if err != nil { 149 | return nil, err 150 | } 151 | e, err := eq(m, value) 152 | if err != nil { 153 | return nil, err 154 | } 155 | return func(c *context) bool { 156 | return g(c) || e(c) 157 | }, nil 158 | } 159 | 160 | func stringEQ(m toString, val interface{}) (hpredicate, error) { 161 | value, ok := val.(string) 162 | if !ok { 163 | return nil, fmt.Errorf("expected string, got %T", val) 164 | } 165 | return func(c *context) bool { 166 | return m(c) == value 167 | }, nil 168 | } 169 | 170 | func intEQ(m toInt, val interface{}) (hpredicate, error) { 171 | value, ok := val.(int) 172 | if !ok { 173 | return nil, fmt.Errorf("expected int, got %T", val) 174 | } 175 | return func(c *context) bool { 176 | return m(c) == value 177 | }, nil 178 | } 179 | 180 | func intLT(m toInt, val interface{}) (hpredicate, error) { 181 | value, ok := val.(int) 182 | if !ok { 183 | return nil, fmt.Errorf("expected int, got %T", val) 184 | } 185 | return func(c *context) bool { 186 | return m(c) < value 187 | }, nil 188 | } 189 | 190 | func intGT(m toInt, val interface{}) (hpredicate, error) { 191 | value, ok := val.(int) 192 | if !ok { 193 | return nil, fmt.Errorf("expected int, got %T", val) 194 | } 195 | return func(c *context) bool { 196 | return m(c) > value 197 | }, nil 198 | } 199 | 200 | type context struct { 201 | r *http.Request 202 | attempt int 203 | responseCode int 204 | } 205 | 206 | type toString func(c *context) string 207 | 208 | type toInt func(c *context) int 209 | 210 | // RequestMethod returns mapper of the request to its method e.g. POST. 211 | func requestMethod() toString { 212 | return func(c *context) string { 213 | return c.r.Method 214 | } 215 | } 216 | 217 | // Attempts returns mapper of the request to the number of proxy attempts. 
218 | func attempts() toInt { 219 | return func(c *context) int { 220 | return c.attempt 221 | } 222 | } 223 | 224 | // ResponseCode returns mapper of the request to the last response code, returns 0 if there was no response code. 225 | func responseCode() toInt { 226 | return func(c *context) int { 227 | return c.responseCode 228 | } 229 | } 230 | -------------------------------------------------------------------------------- /testutils/utils.go: -------------------------------------------------------------------------------- 1 | package testutils 2 | 3 | import ( 4 | "crypto/tls" 5 | "errors" 6 | "io" 7 | "net/http" 8 | "net/http/httptest" 9 | "net/url" 10 | "strings" 11 | "testing" 12 | 13 | "github.com/vulcand/oxy/v2/internal/holsterv4/clock" 14 | "github.com/vulcand/oxy/v2/utils" 15 | ) 16 | 17 | // NewHandler creates a new Server. 18 | func NewHandler(handler http.HandlerFunc) *httptest.Server { 19 | return httptest.NewServer(handler) 20 | } 21 | 22 | // NewResponder creates a new Server with response. 23 | func NewResponder(t *testing.T, response string) *httptest.Server { 24 | t.Helper() 25 | 26 | server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 27 | _, _ = w.Write([]byte(response)) 28 | })) 29 | 30 | t.Cleanup(server.Close) 31 | return server 32 | } 33 | 34 | // MustParseRequestURI is the version of url.ParseRequestURI that panics if incorrect, helpful to shorten the tests. 35 | func MustParseRequestURI(uri string) *url.URL { 36 | out, err := url.ParseRequestURI(uri) 37 | if err != nil { 38 | panic(err) 39 | } 40 | return out 41 | } 42 | 43 | // ReqOpts request options. 44 | type ReqOpts struct { 45 | Host string 46 | Method string 47 | Body string 48 | Headers http.Header 49 | Auth *utils.BasicAuth 50 | } 51 | 52 | // ReqOption request option type. 53 | type ReqOption func(o *ReqOpts) error 54 | 55 | // Method sets request method. 
56 | func Method(m string) ReqOption { 57 | return func(o *ReqOpts) error { 58 | o.Method = m 59 | return nil 60 | } 61 | } 62 | 63 | // Host sets request host. 64 | func Host(h string) ReqOption { 65 | return func(o *ReqOpts) error { 66 | o.Host = h 67 | return nil 68 | } 69 | } 70 | 71 | // Body sets request body. 72 | func Body(b string) ReqOption { 73 | return func(o *ReqOpts) error { 74 | o.Body = b 75 | return nil 76 | } 77 | } 78 | 79 | // Header sets request header. 80 | func Header(name, val string) ReqOption { 81 | return func(o *ReqOpts) error { 82 | if o.Headers == nil { 83 | o.Headers = make(http.Header) 84 | } 85 | o.Headers.Add(name, val) 86 | return nil 87 | } 88 | } 89 | 90 | // Headers sets request headers. 91 | func Headers(h http.Header) ReqOption { 92 | return func(o *ReqOpts) error { 93 | if o.Headers == nil { 94 | o.Headers = make(http.Header) 95 | } 96 | utils.CopyHeaders(o.Headers, h) 97 | return nil 98 | } 99 | } 100 | 101 | // BasicAuth sets request basic auth. 102 | func BasicAuth(username, password string) ReqOption { 103 | return func(o *ReqOpts) error { 104 | o.Auth = &utils.BasicAuth{ 105 | Username: username, 106 | Password: password, 107 | } 108 | return nil 109 | } 110 | } 111 | 112 | // MakeRequest create and do a request. 
113 | func MakeRequest(uri string, opts ...ReqOption) (*http.Response, []byte, error) { 114 | o := &ReqOpts{} 115 | for _, s := range opts { 116 | if err := s(o); err != nil { 117 | return nil, nil, err 118 | } 119 | } 120 | 121 | if o.Method == "" { 122 | o.Method = http.MethodGet 123 | } 124 | 125 | request, err := http.NewRequest(o.Method, uri, strings.NewReader(o.Body)) 126 | if err != nil { 127 | return nil, nil, err 128 | } 129 | 130 | if o.Headers != nil { 131 | utils.CopyHeaders(request.Header, o.Headers) 132 | } 133 | 134 | if o.Auth != nil { 135 | request.Header.Set("Authorization", o.Auth.String()) 136 | } 137 | 138 | if o.Host != "" { 139 | request.Host = o.Host 140 | } 141 | 142 | var tr *http.Transport 143 | if strings.HasPrefix(uri, "https") { 144 | tr = &http.Transport{ 145 | DisableKeepAlives: true, 146 | TLSClientConfig: &tls.Config{ 147 | InsecureSkipVerify: true, 148 | ServerName: request.Host, 149 | }, 150 | } 151 | } else { 152 | tr = &http.Transport{ 153 | DisableKeepAlives: true, 154 | } 155 | } 156 | 157 | client := &http.Client{ 158 | Transport: tr, 159 | CheckRedirect: func(_ *http.Request, _ []*http.Request) error { 160 | return errors.New("no redirects") 161 | }, 162 | } 163 | response, err := client.Do(request) 164 | if err == nil { 165 | bodyBytes, errRead := io.ReadAll(response.Body) 166 | return response, bodyBytes, errRead 167 | } 168 | return response, nil, err 169 | } 170 | 171 | // Get do a GET request. 172 | func Get(uri string, opts ...ReqOption) (*http.Response, []byte, error) { 173 | opts = append(opts, Method(http.MethodGet)) 174 | return MakeRequest(uri, opts...) 175 | } 176 | 177 | // Post do a POST request. 178 | func Post(uri string, opts ...ReqOption) (*http.Response, []byte, error) { 179 | opts = append(opts, Method(http.MethodPost)) 180 | return MakeRequest(uri, opts...) 181 | } 182 | 183 | // FreezeTime to the predetermined time. Returns a function that should be 184 | // deferred to unfreeze time. 
Meant for testing. 185 | func FreezeTime(t *testing.T) { 186 | t.Helper() 187 | 188 | clock.Freeze(clock.Date(2012, 3, 4, 5, 6, 7, 0, clock.UTC)) 189 | 190 | t.Cleanup(clock.Unfreeze) 191 | } 192 | -------------------------------------------------------------------------------- /trace/options.go: -------------------------------------------------------------------------------- 1 | package trace 2 | 3 | import "github.com/vulcand/oxy/v2/utils" 4 | 5 | // Option is a functional option setter for Tracer. 6 | type Option func(*Tracer) error 7 | 8 | // ErrorHandler is a functional argument that sets error handler of the server. 9 | func ErrorHandler(h utils.ErrorHandler) Option { 10 | return func(t *Tracer) error { 11 | t.errHandler = h 12 | return nil 13 | } 14 | } 15 | 16 | // RequestHeaders adds request headers to capture. 17 | func RequestHeaders(headers ...string) Option { 18 | return func(t *Tracer) error { 19 | t.reqHeaders = append(t.reqHeaders, headers...) 20 | return nil 21 | } 22 | } 23 | 24 | // ResponseHeaders adds response headers to capture. 25 | func ResponseHeaders(headers ...string) Option { 26 | return func(t *Tracer) error { 27 | t.respHeaders = append(t.respHeaders, headers...) 28 | return nil 29 | } 30 | } 31 | 32 | // Logger defines the logger the tracer will use. 
33 | func Logger(l utils.Logger) Option { 34 | return func(t *Tracer) error { 35 | t.log = l 36 | return nil 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /trace/trace_test.go: -------------------------------------------------------------------------------- 1 | package trace 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "crypto/tls" 7 | "encoding/json" 8 | "fmt" 9 | "net/http" 10 | "net/http/httptest" 11 | "net/url" 12 | "testing" 13 | 14 | "github.com/stretchr/testify/assert" 15 | "github.com/stretchr/testify/require" 16 | "github.com/vulcand/oxy/v2/testutils" 17 | "github.com/vulcand/oxy/v2/utils" 18 | ) 19 | 20 | func TestTracer_simple(t *testing.T) { 21 | handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 22 | w.Header().Set("Content-Length", "5") 23 | _, _ = w.Write([]byte("hello")) 24 | }) 25 | 26 | trace := &bytes.Buffer{} 27 | tr, err := New(handler, trace) 28 | require.NoError(t, err) 29 | 30 | srv := httptest.NewServer(tr) 31 | t.Cleanup(srv.Close) 32 | 33 | re, _, err := testutils.MakeRequest(srv.URL+"/hello", testutils.Method(http.MethodPost), testutils.Body("123456")) 34 | require.NoError(t, err) 35 | assert.Equal(t, http.StatusOK, re.StatusCode) 36 | 37 | var r *Record 38 | require.NoError(t, json.Unmarshal(trace.Bytes(), &r)) 39 | 40 | assert.Equal(t, http.MethodPost, r.Request.Method) 41 | assert.Equal(t, "/hello", r.Request.URL) 42 | assert.Equal(t, http.StatusOK, r.Response.Code) 43 | assert.EqualValues(t, 6, r.Request.BodyBytes) 44 | assert.NotEqual(t, float64(0), r.Response.Roundtrip) 45 | assert.EqualValues(t, 5, r.Response.BodyBytes) 46 | } 47 | 48 | func TestTracer_captureHeaders(t *testing.T) { 49 | respHeaders := http.Header{ 50 | "X-Re-1": []string{"6", "7"}, 51 | "X-Re-2": []string{"2", "3"}, 52 | } 53 | 54 | handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 55 | utils.CopyHeaders(w.Header(), respHeaders) 56 | _, _ = 
w.Write([]byte("hello")) 57 | }) 58 | 59 | trace := &bytes.Buffer{} 60 | tr, err := New(handler, trace, RequestHeaders("X-Req-B", "X-Req-A"), ResponseHeaders("X-Re-1", "X-Re-2")) 61 | require.NoError(t, err) 62 | 63 | srv := httptest.NewServer(tr) 64 | t.Cleanup(srv.Close) 65 | 66 | reqHeaders := http.Header{"X-Req-A": []string{"1", "2"}, "X-Req-B": []string{"3", "4"}} 67 | re, _, err := testutils.Get(srv.URL+"/hello", testutils.Headers(reqHeaders)) 68 | require.NoError(t, err) 69 | assert.Equal(t, http.StatusOK, re.StatusCode) 70 | 71 | var r *Record 72 | require.NoError(t, json.Unmarshal(trace.Bytes(), &r)) 73 | 74 | assert.Equal(t, reqHeaders, r.Request.Headers) 75 | assert.Equal(t, respHeaders, r.Response.Headers) 76 | } 77 | 78 | func TestTracer_TLS(t *testing.T) { 79 | handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 80 | _, _ = w.Write([]byte("hello")) 81 | }) 82 | 83 | trace := &bytes.Buffer{} 84 | tr, err := New(handler, trace) 85 | require.NoError(t, err) 86 | 87 | srv := httptest.NewUnstartedServer(tr) 88 | srv.StartTLS() 89 | t.Cleanup(srv.Close) 90 | 91 | config := &tls.Config{ 92 | InsecureSkipVerify: true, 93 | } 94 | 95 | u, err := url.Parse(srv.URL) 96 | require.NoError(t, err) 97 | 98 | conn, err := tls.Dial("tcp", u.Host, config) 99 | require.NoError(t, err) 100 | 101 | _, _ = fmt.Fprint(conn, "GET / HTTP/1.0\r\n\r\n") 102 | status, err := bufio.NewReader(conn).ReadString('\n') 103 | require.NoError(t, err) 104 | assert.Equal(t, "HTTP/1.0 200 OK\r\n", status) 105 | state := conn.ConnectionState() 106 | _ = conn.Close() 107 | 108 | var r *Record 109 | require.NoError(t, json.Unmarshal(trace.Bytes(), &r)) 110 | assert.Equal(t, versionToString(state.Version), r.Request.TLS.Version) 111 | } 112 | -------------------------------------------------------------------------------- /utils/auth.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | 
"encoding/base64" 5 | "fmt" 6 | "strings" 7 | ) 8 | 9 | // BasicAuth basic auth information. 10 | type BasicAuth struct { 11 | Username string 12 | Password string 13 | } 14 | 15 | func (ba *BasicAuth) String() string { 16 | encoded := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", ba.Username, ba.Password))) 17 | return fmt.Sprintf("Basic %s", encoded) 18 | } 19 | 20 | // ParseAuthHeader creates a new BasicAuth from header values. 21 | func ParseAuthHeader(header string) (*BasicAuth, error) { 22 | values := strings.Fields(header) 23 | if len(values) != 2 { 24 | return nil, fmt.Errorf("failed to parse header '%s'", header) 25 | } 26 | 27 | authType := strings.ToLower(values[0]) 28 | if authType != "basic" { 29 | return nil, fmt.Errorf("expected basic auth type, got '%s'", authType) 30 | } 31 | 32 | encodedString := values[1] 33 | decodedString, err := base64.StdEncoding.DecodeString(encodedString) 34 | if err != nil { 35 | return nil, fmt.Errorf("failed to parse header '%s', base64 failed: %w", header, err) 36 | } 37 | 38 | values = strings.SplitN(string(decodedString), ":", 2) 39 | if len(values) != 2 { 40 | return nil, fmt.Errorf("failed to parse header '%s', expected separator ':'", header) 41 | } 42 | return &BasicAuth{Username: values[0], Password: values[1]}, nil 43 | } 44 | -------------------------------------------------------------------------------- /utils/auth_test.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | "github.com/stretchr/testify/require" 8 | ) 9 | 10 | // Just to make sure we don't panic, return err and not 11 | // username and pass and cover the function. 
12 | func TestParseBadHeaders(t *testing.T) { 13 | headers := []string{ 14 | // just empty string 15 | "", 16 | // missing auth type 17 | "justplainstring", 18 | // unknown auth type 19 | "Whut justplainstring", 20 | // invalid base64 21 | "Basic Shmasic", 22 | // random encoded string 23 | "Basic YW55IGNhcm5hbCBwbGVhcw==", 24 | } 25 | for _, h := range headers { 26 | _, err := ParseAuthHeader(h) 27 | require.Error(t, err) 28 | } 29 | } 30 | 31 | // Just to make sure we don't panic, return err and not 32 | // username and pass and cover the function. 33 | func TestParseSuccess(t *testing.T) { 34 | headers := []struct { 35 | Header string 36 | Expected BasicAuth 37 | }{ 38 | { 39 | "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==", 40 | BasicAuth{Username: "Aladdin", Password: "open sesame"}, 41 | }, 42 | // Make sure that String() produces valid header 43 | { 44 | (&BasicAuth{Username: "Alice", Password: "Here's bob"}).String(), 45 | BasicAuth{Username: "Alice", Password: "Here's bob"}, 46 | }, 47 | // empty pass 48 | { 49 | "Basic QWxhZGRpbjo=", 50 | BasicAuth{Username: "Aladdin", Password: ""}, 51 | }, 52 | } 53 | 54 | for _, h := range headers { 55 | request, err := ParseAuthHeader(h.Header) 56 | require.NoError(t, err) 57 | assert.Equal(t, h.Expected.Username, request.Username) 58 | assert.Equal(t, h.Expected.Password, request.Password) 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /utils/dumpreq.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "crypto/tls" 5 | "encoding/json" 6 | "fmt" 7 | "mime/multipart" 8 | "net/http" 9 | "net/url" 10 | ) 11 | 12 | // SerializableHTTPRequest serializable HTTP request. 
13 | type SerializableHTTPRequest struct { 14 | Method string 15 | URL *url.URL 16 | Proto string // "HTTP/1.0" 17 | ProtoMajor int // 1 18 | ProtoMinor int // 0 19 | Header http.Header 20 | ContentLength int64 21 | TransferEncoding []string 22 | Host string 23 | Form url.Values 24 | PostForm url.Values 25 | MultipartForm *multipart.Form 26 | Trailer http.Header 27 | RemoteAddr string 28 | RequestURI string 29 | TLS *tls.ConnectionState 30 | } 31 | 32 | // Clone clone a request. 33 | func Clone(r *http.Request) *SerializableHTTPRequest { 34 | if r == nil { 35 | return nil 36 | } 37 | 38 | rc := new(SerializableHTTPRequest) 39 | rc.Method = r.Method 40 | rc.URL = r.URL 41 | rc.Proto = r.Proto 42 | rc.ProtoMajor = r.ProtoMajor 43 | rc.ProtoMinor = r.ProtoMinor 44 | rc.Header = r.Header 45 | rc.ContentLength = r.ContentLength 46 | rc.Host = r.Host 47 | rc.RemoteAddr = r.RemoteAddr 48 | rc.RequestURI = r.RequestURI 49 | return rc 50 | } 51 | 52 | // ToJSON serializes to JSON. 53 | func (s *SerializableHTTPRequest) ToJSON() string { 54 | jsonVal, err := json.Marshal(s) 55 | if err != nil || jsonVal == nil { 56 | return fmt.Sprintf("error marshaling SerializableHTTPRequest to json: %s", err) 57 | } 58 | return string(jsonVal) 59 | } 60 | 61 | // DumpHTTPRequest dump a HTTP request to JSON. 
62 | func DumpHTTPRequest(req *http.Request) string { 63 | return Clone(req).ToJSON() 64 | } 65 | -------------------------------------------------------------------------------- /utils/dumpreq_test.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "net/http" 5 | "net/url" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | type readCloserTestImpl struct{} 12 | 13 | func (r *readCloserTestImpl) Read(_ []byte) (n int, err error) { 14 | return 0, nil 15 | } 16 | 17 | func (r *readCloserTestImpl) Close() error { 18 | return nil 19 | } 20 | 21 | // Just to make sure we don't panic, return err and not 22 | // username and pass and cover the function. 23 | func TestHttpReqToString(t *testing.T) { 24 | req := &http.Request{ 25 | URL: &url.URL{Host: "localhost:2374", Path: "/unittest"}, 26 | Method: http.MethodDelete, 27 | Cancel: make(chan struct{}), 28 | Body: &readCloserTestImpl{}, 29 | } 30 | 31 | assert.NotEmpty(t, DumpHTTPRequest(req)) 32 | } 33 | -------------------------------------------------------------------------------- /utils/handler.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "io" 7 | "net" 8 | "net/http" 9 | ) 10 | 11 | // StatusClientClosedRequest non-standard HTTP status code for client disconnection. 12 | const StatusClientClosedRequest = 499 13 | 14 | // StatusClientClosedRequestText non-standard HTTP status for client disconnection. 15 | const StatusClientClosedRequestText = "Client Closed Request" 16 | 17 | // ErrorHandler error handler. 18 | type ErrorHandler interface { 19 | ServeHTTP(w http.ResponseWriter, req *http.Request, err error) 20 | } 21 | 22 | // DefaultHandler default error handler. 23 | var DefaultHandler ErrorHandler = &StdHandler{log: &NoopLogger{}} 24 | 25 | // StdHandler Standard error handler. 
26 | type StdHandler struct { 27 | log Logger 28 | } 29 | 30 | func (e *StdHandler) ServeHTTP(w http.ResponseWriter, _ *http.Request, err error) { 31 | statusCode := http.StatusInternalServerError 32 | 33 | //nolint:errorlint // must be changed 34 | if e, ok := err.(net.Error); ok { 35 | if e.Timeout() { 36 | statusCode = http.StatusGatewayTimeout 37 | } else { 38 | statusCode = http.StatusBadGateway 39 | } 40 | } else if errors.Is(err, io.EOF) { 41 | statusCode = http.StatusBadGateway 42 | } else if errors.Is(err, context.Canceled) { 43 | statusCode = StatusClientClosedRequest 44 | } 45 | 46 | w.WriteHeader(statusCode) 47 | _, _ = w.Write([]byte(statusText(statusCode))) 48 | 49 | e.log.Debug("'%d %s' caused by: %v", statusCode, statusText(statusCode), err) 50 | } 51 | 52 | func statusText(statusCode int) string { 53 | if statusCode == StatusClientClosedRequest { 54 | return StatusClientClosedRequestText 55 | } 56 | return http.StatusText(statusCode) 57 | } 58 | 59 | // ErrorHandlerFunc error handler function type. 60 | type ErrorHandlerFunc func(http.ResponseWriter, *http.Request, error) 61 | 62 | // ServeHTTP calls f(w, r). 
63 | func (f ErrorHandlerFunc) ServeHTTP(w http.ResponseWriter, r *http.Request, err error) { 64 | f(w, r, err) 65 | } 66 | -------------------------------------------------------------------------------- /utils/handler_test.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "bytes" 5 | "net/http" 6 | "net/http/httptest" 7 | "strings" 8 | "testing" 9 | 10 | "github.com/stretchr/testify/assert" 11 | "github.com/stretchr/testify/require" 12 | ) 13 | 14 | func TestDefaultHandlerErrors(t *testing.T) { 15 | srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 16 | h := w.(http.Hijacker) 17 | conn, _, _ := h.Hijack() 18 | conn.Close() 19 | })) 20 | t.Cleanup(srv.Close) 21 | 22 | request, err := http.NewRequest(http.MethodGet, srv.URL, strings.NewReader("")) 23 | require.NoError(t, err) 24 | 25 | _, err = http.DefaultTransport.RoundTrip(request) 26 | 27 | w := NewBufferWriter(NopWriteCloser(&bytes.Buffer{}), &NoopLogger{}) 28 | 29 | DefaultHandler.ServeHTTP(w, nil, err) 30 | 31 | assert.Equal(t, http.StatusBadGateway, w.Code) 32 | } 33 | -------------------------------------------------------------------------------- /utils/log.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | // Logger the logger interface. 4 | type Logger interface { 5 | Debug(msg string, args ...any) 6 | Info(msg string, args ...any) 7 | Warn(msg string, args ...any) 8 | Error(msg string, args ...any) 9 | } 10 | 11 | // NoopLogger a noop logger. 12 | type NoopLogger struct{} 13 | 14 | // Debug noop. 15 | func (*NoopLogger) Debug(string, ...interface{}) {} 16 | 17 | // Info noop. 18 | func (*NoopLogger) Info(string, ...interface{}) {} 19 | 20 | // Warn noop. 21 | func (*NoopLogger) Warn(string, ...interface{}) {} 22 | 23 | // Error noop. 
24 | func (*NoopLogger) Error(string, ...interface{}) {} 25 | -------------------------------------------------------------------------------- /utils/netutils.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "bufio" 5 | "fmt" 6 | "io" 7 | "net" 8 | "net/http" 9 | "net/url" 10 | "reflect" 11 | ) 12 | 13 | // ProxyWriter calls recorder, used to debug logs. 14 | type ProxyWriter struct { 15 | w http.ResponseWriter 16 | code int 17 | length int64 18 | 19 | log Logger 20 | } 21 | 22 | // NewProxyWriter creates a new ProxyWriter. 23 | func NewProxyWriter(w http.ResponseWriter) *ProxyWriter { 24 | return NewProxyWriterWithLogger(w, &NoopLogger{}) 25 | } 26 | 27 | // NewProxyWriterWithLogger creates a new ProxyWriter. 28 | func NewProxyWriterWithLogger(w http.ResponseWriter, l Logger) *ProxyWriter { 29 | return &ProxyWriter{ 30 | w: w, 31 | log: l, 32 | } 33 | } 34 | 35 | // StatusCode gets status code. 36 | func (p *ProxyWriter) StatusCode() int { 37 | if p.code == 0 { 38 | // per contract standard lib will set this to http.StatusOK if not set 39 | // by user, here we avoid the confusion by mirroring this logic 40 | return http.StatusOK 41 | } 42 | return p.code 43 | } 44 | 45 | // GetLength gets content length. 46 | func (p *ProxyWriter) GetLength() int64 { 47 | return p.length 48 | } 49 | 50 | // Header gets response header. 51 | func (p *ProxyWriter) Header() http.Header { 52 | return p.w.Header() 53 | } 54 | 55 | func (p *ProxyWriter) Write(buf []byte) (int, error) { 56 | p.length += int64(len(buf)) 57 | return p.w.Write(buf) 58 | } 59 | 60 | // WriteHeader writes status code. 61 | func (p *ProxyWriter) WriteHeader(code int) { 62 | p.code = code 63 | p.w.WriteHeader(code) 64 | } 65 | 66 | // Flush flush the writer. 
67 | func (p *ProxyWriter) Flush() { 68 | if f, ok := p.w.(http.Flusher); ok { 69 | f.Flush() 70 | } 71 | } 72 | 73 | // CloseNotify returns a channel that receives at most a single value (true) 74 | // when the client connection has gone away. 75 | func (p *ProxyWriter) CloseNotify() <-chan bool { 76 | if cn, ok := p.w.(http.CloseNotifier); ok { 77 | return cn.CloseNotify() 78 | } 79 | p.log.Debug("Upstream ResponseWriter of type %v does not implement http.CloseNotifier. Returning dummy channel.", reflect.TypeOf(p.w)) 80 | return make(<-chan bool) 81 | } 82 | 83 | // Hijack lets the caller take over the connection. 84 | func (p *ProxyWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { 85 | if hi, ok := p.w.(http.Hijacker); ok { 86 | return hi.Hijack() 87 | } 88 | p.log.Debug("Upstream ResponseWriter of type %v does not implement http.Hijacker. Returning dummy channel.", reflect.TypeOf(p.w)) 89 | return nil, nil, fmt.Errorf("the response writer that was wrapped in this proxy, does not implement http.Hijacker. It is of type: %v", reflect.TypeOf(p.w)) 90 | } 91 | 92 | // NewBufferWriter creates a new BufferWriter. 93 | func NewBufferWriter(w io.WriteCloser, l Logger) *BufferWriter { 94 | return &BufferWriter{ 95 | W: w, 96 | H: make(http.Header), 97 | log: l, 98 | } 99 | } 100 | 101 | // BufferWriter buffer writer. 102 | type BufferWriter struct { 103 | H http.Header 104 | Code int 105 | W io.WriteCloser 106 | log Logger 107 | } 108 | 109 | // Close close the writer. 110 | func (b *BufferWriter) Close() error { 111 | return b.W.Close() 112 | } 113 | 114 | // Header gets response header. 115 | func (b *BufferWriter) Header() http.Header { 116 | return b.H 117 | } 118 | 119 | func (b *BufferWriter) Write(buf []byte) (int, error) { 120 | return b.W.Write(buf) 121 | } 122 | 123 | // WriteHeader writes status code. 
124 | func (b *BufferWriter) WriteHeader(code int) { 125 | b.Code = code 126 | } 127 | 128 | // CloseNotify returns a channel that receives at most a single value (true) 129 | // when the client connection has gone away. 130 | func (b *BufferWriter) CloseNotify() <-chan bool { 131 | if cn, ok := b.W.(http.CloseNotifier); ok { 132 | return cn.CloseNotify() 133 | } 134 | b.log.Warn("Upstream ResponseWriter of type %v does not implement http.CloseNotifier. Returning dummy channel.", reflect.TypeOf(b.W)) 135 | 136 | return make(<-chan bool) 137 | } 138 | 139 | // Hijack lets the caller take over the connection. 140 | func (b *BufferWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { 141 | if hi, ok := b.W.(http.Hijacker); ok { 142 | return hi.Hijack() 143 | } 144 | b.log.Debug("Upstream ResponseWriter of type %v does not implement http.Hijacker. Returning dummy channel.", reflect.TypeOf(b.W)) 145 | 146 | return nil, nil, fmt.Errorf("the response writer that was wrapped in this proxy, does not implement http.Hijacker. It is of type: %v", reflect.TypeOf(b.W)) 147 | } 148 | 149 | type nopWriteCloser struct { 150 | io.Writer 151 | } 152 | 153 | func (*nopWriteCloser) Close() error { return nil } 154 | 155 | // NopWriteCloser returns a WriteCloser with a no-op Close method wrapping 156 | // the provided Writer w. 157 | func NopWriteCloser(w io.Writer) io.WriteCloser { 158 | return &nopWriteCloser{Writer: w} 159 | } 160 | 161 | // CopyURL provides update safe copy by avoiding shallow copying User field. 162 | func CopyURL(i *url.URL) *url.URL { 163 | out := *i 164 | if i.User != nil { 165 | u := *i.User 166 | out.User = &u 167 | } 168 | return &out 169 | } 170 | 171 | // CopyHeaders copies http headers from source to destination, it 172 | // does not override, but adds multiple headers. 173 | func CopyHeaders(dst http.Header, src http.Header) { 174 | for k, vv := range src { 175 | dst[k] = append(dst[k], vv...) 
176 | } 177 | } 178 | 179 | // HasHeaders determines whether any of the header names is present in the http headers. 180 | func HasHeaders(names []string, headers http.Header) bool { 181 | for _, h := range names { 182 | if headers.Get(h) != "" { 183 | return true 184 | } 185 | } 186 | return false 187 | } 188 | 189 | // RemoveHeaders removes the header with the given names from the headers map. 190 | func RemoveHeaders(headers http.Header, names ...string) { 191 | for _, h := range names { 192 | headers.Del(h) 193 | } 194 | } 195 | -------------------------------------------------------------------------------- /utils/netutils_test.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "net/http" 5 | "net/url" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | // Make sure copy does it right, so the copied url is safe to alter without modifying the other. 12 | func TestCopyUrl(t *testing.T) { 13 | userinfo := url.UserPassword("foo", "secret") 14 | 15 | urlA := &url.URL{ 16 | Scheme: "http", 17 | Host: "localhost:5000", 18 | Path: "/upstream", 19 | Opaque: "opaque", 20 | RawQuery: "a=1&b=2", 21 | Fragment: "#hello", 22 | User: userinfo, 23 | } 24 | 25 | urlB := CopyURL(urlA) 26 | assert.Equal(t, urlA, urlB) 27 | 28 | *userinfo = *url.User("bar") 29 | 30 | assert.Equal(t, urlA.User, userinfo) 31 | assert.NotEqual(t, urlA.User, urlB.User) 32 | 33 | urlB.Scheme = "https" 34 | 35 | assert.NotEqual(t, urlA, urlB) 36 | } 37 | 38 | // Make sure copy headers is not shallow and copies all headers. 
39 | func TestCopyHeaders(t *testing.T) { 40 | source, destination := make(http.Header), make(http.Header) 41 | source.Add("a", "b") 42 | source.Add("c", "d") 43 | 44 | CopyHeaders(destination, source) 45 | 46 | assert.Equal(t, "b", destination.Get("a")) 47 | assert.Equal(t, "d", destination.Get("c")) 48 | 49 | // make sure that altering source does not affect the destination 50 | source.Del("a") 51 | 52 | assert.Empty(t, source.Get("a")) 53 | assert.Equal(t, "b", destination.Get("a")) 54 | } 55 | 56 | func TestHasHeaders(t *testing.T) { 57 | source := make(http.Header) 58 | source.Add("a", "b") 59 | source.Add("c", "d") 60 | 61 | assert.True(t, HasHeaders([]string{"a", "f"}, source)) 62 | assert.False(t, HasHeaders([]string{"i", "j"}, source)) 63 | } 64 | 65 | func TestRemoveHeaders(t *testing.T) { 66 | source := make(http.Header) 67 | source.Add("a", "b") 68 | source.Add("a", "m") 69 | source.Add("c", "d") 70 | 71 | RemoveHeaders(source, "a") 72 | 73 | assert.Empty(t, source.Get("a")) 74 | assert.Equal(t, "d", source.Get("c")) 75 | } 76 | 77 | //nolint:intrange // benchmarks 78 | func BenchmarkCopyHeaders(b *testing.B) { 79 | dstHeaders := make([]http.Header, 0, b.N) 80 | sourceHeaders := make([]http.Header, 0, b.N) 81 | for n := 0; n < b.N; n++ { 82 | // example from a reverse proxy merging headers 83 | d := http.Header{} 84 | d.Add("Request-Id", "1bd36bcc-a0d1-4fc7-aedc-20bbdefa27c5") 85 | dstHeaders = append(dstHeaders, d) 86 | 87 | s := http.Header{} 88 | s.Add("Content-Length", "374") 89 | s.Add("Context-Type", "text/html; charset=utf-8") 90 | s.Add("Etag", `"op14g6ae"`) 91 | s.Add("Last-Modified", "Wed, 26 Apr 2017 18:24:06 GMT") 92 | s.Add("Server", "Caddy") 93 | s.Add("Date", "Fri, 28 Apr 2017 15:54:01 GMT") 94 | s.Add("Accept-Ranges", "bytes") 95 | sourceHeaders = append(sourceHeaders, s) 96 | } 97 | b.ResetTimer() 98 | 99 | for n := 0; n < b.N; n++ { 100 | CopyHeaders(dstHeaders[n], sourceHeaders[n]) 101 | } 102 | } 103 | 
-------------------------------------------------------------------------------- /utils/source.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "strings" 7 | ) 8 | 9 | // SourceExtractor extracts the source from the request, e.g. that may be client ip, or particular header that 10 | // identifies the source. amount stands for amount of connections the source consumes, usually 1 for connection limiters 11 | // error should be returned when source can not be identified. 12 | type SourceExtractor interface { 13 | Extract(req *http.Request) (token string, amount int64, err error) 14 | } 15 | 16 | // ExtractorFunc extractor function type. 17 | type ExtractorFunc func(req *http.Request) (token string, amount int64, err error) 18 | 19 | // Extract extract from request. 20 | func (f ExtractorFunc) Extract(req *http.Request) (string, int64, error) { 21 | return f(req) 22 | } 23 | 24 | // ExtractSource extract source function type. 25 | type ExtractSource func(req *http.Request) 26 | 27 | // NewExtractor creates a new SourceExtractor. 
28 | func NewExtractor(variable string) (SourceExtractor, error) { 29 | if variable == "client.ip" { 30 | return ExtractorFunc(extractClientIP), nil 31 | } 32 | if variable == "request.host" { 33 | return ExtractorFunc(extractHost), nil 34 | } 35 | if strings.HasPrefix(variable, "request.header.") { 36 | header := strings.TrimPrefix(variable, "request.header.") 37 | if header == "" { 38 | return nil, fmt.Errorf("wrong header: %s", header) 39 | } 40 | return makeHeaderExtractor(header), nil 41 | } 42 | return nil, fmt.Errorf("unsupported limiting variable: '%s'", variable) 43 | } 44 | 45 | func extractClientIP(req *http.Request) (string, int64, error) { 46 | vals := strings.SplitN(req.RemoteAddr, ":", 2) 47 | if vals[0] == "" { 48 | return "", 0, fmt.Errorf("failed to parse client IP: %v", req.RemoteAddr) 49 | } 50 | return vals[0], 1, nil 51 | } 52 | 53 | func extractHost(req *http.Request) (string, int64, error) { 54 | return req.Host, 1, nil 55 | } 56 | 57 | func makeHeaderExtractor(header string) SourceExtractor { 58 | return ExtractorFunc(func(req *http.Request) (string, int64, error) { 59 | return req.Header.Get(header), 1, nil 60 | }) 61 | } 62 | --------------------------------------------------------------------------------