├── .github └── workflows │ └── ci.yml ├── .gitignore ├── .ruby-version ├── CHANGELOG.md ├── CONTRIBUTING.md ├── MIT-LICENSE ├── Makefile ├── README.md ├── Rakefile ├── cmd └── thrust │ └── main.go ├── exe └── thrust ├── go.mod ├── go.sum ├── internal ├── cache_handler.go ├── cache_handler_test.go ├── cacheable_response.go ├── cacheable_response_test.go ├── config.go ├── config_test.go ├── fixtures │ ├── image.jpg │ └── loremipsum.txt ├── handler.go ├── handler_test.go ├── logging_middleware.go ├── logging_middleware_test.go ├── memory_cache.go ├── memory_cache_test.go ├── proxy_handler.go ├── sendfile_handler.go ├── sendfile_handler_test.go ├── server.go ├── service.go ├── testing.go ├── upstream_process.go ├── upstream_process_test.go ├── variant.go └── variant_test.go ├── lib ├── thruster.rb └── thruster │ └── version.rb ├── rakelib └── package.rake └── thruster.gemspec /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: Go 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | pull_request: 7 | branches: [ "main" ] 8 | 9 | jobs: 10 | build: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v3 14 | 15 | - name: Set up Go 16 | uses: actions/setup-go@v4 17 | with: 18 | go-version: '1.24.x' 19 | 20 | - name: Install dependencies 21 | run: go mod download 22 | 23 | - name: Build 24 | run: go build -v ./... 25 | 26 | - name: Test 27 | run: go test -v ./... 
28 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore Go build artifacts 2 | /bin 3 | /dist 4 | 5 | # Ignore Gem packaging artifacts 6 | /pkg 7 | /exe/ 8 | !/exe/thrust 9 | -------------------------------------------------------------------------------- /.ruby-version: -------------------------------------------------------------------------------- 1 | 3.3.0 2 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | ## v0.1.13 / 2025-04-21 2 | 3 | * Update deps to address CVEs (#74) 4 | * Build with Go 1.24.2 5 | 6 | ## v0.1.12 / 2025-03-10 7 | 8 | * Build with Go 1.23.7 (#69) 9 | 10 | ## v0.1.11 / 2025-02-11 11 | 12 | * Build with Go 1.23.6 (#67) 13 | * Allow disabling compression with env var (#56) 14 | 15 | ## v0.1.10 / 2025-01-06 16 | 17 | * Avoid runtime glibc dependency in dist builds 18 | 19 | ## v0.1.9 / 2024-11-13 20 | 21 | * Build with Go 1.23.3 22 | 23 | ## v0.1.8 / 2024-08-06 24 | 25 | * Only forward X-Forwarded-* by default when not using TLS 26 | 27 | ## v0.1.7 / 2024-07-11 28 | 29 | * Preserve existing X-Forwarded-* headers when present 30 | 31 | ## v0.1.6 / 2024-07-10 32 | 33 | * Properly handle an empty TLS_DOMAIN value 34 | 35 | ## v0.1.5 / 2024-07-09 36 | 37 | * Fix bug where replacing existing cache items could lead to a crash during 38 | eviction 39 | * Accept comma-separated `TLS_DOMAIN` to support multiple domains (#28) 40 | * Populate `X-Forwarded-For`, `X-Forwarded-Host` and `X-Forwarded-Proto` 41 | headers (#29) 42 | 43 | ## v0.1.4 / 2024-04-26 44 | 45 | * [BREAKING] Rename the `SSL_DOMAIN` env to `TLS_DOMAIN` (#13) 46 | * Set `stdin` in upstream process (#18) 47 | 48 | ## v0.1.3 / 2024-03-21 49 | 50 | * Disable transparent proxy compression (#11) 51 | 52 | ## v0.1.2 / 2024-03-19 
53 | 54 | * Don't cache `Range` requests 55 | 56 | ## v0.1.1 / 2024-03-18 57 | 58 | * Ensure `Content-Length` set correctly in `X-Sendfile` responses (#10) 59 | 60 | ## v0.1.0 / 2024-03-07 61 | 62 | * Build with Go 1.22.1 63 | * Use stdlib `MaxBytesHandler` for request size limiting 64 | 65 | ## v0.0.3 / 2024-03-06 66 | 67 | * Support additional ACME providers 68 | * Respond with `413`, not `400` when blocking oversized requests 69 | * Allow prefixing env vars with `THRUSTER_` to avoid naming clashes 70 | * Additional debug-level logging 71 | 72 | ## v0.0.2 / 2024-02-28 73 | 74 | * Support `Vary` header in HTTP caching 75 | * Return `X-Cache` `bypass` when not caching request 76 | 77 | ## v0.0.1 / 2024-02-14 78 | 79 | * Initial version 80 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Thruster 2 | 3 | Thruster is a Go application. It is packaged as a Ruby gem to make it easy to use with Rails applications. 4 | 5 | 6 | ## Running the tests 7 | 8 | You can run the test suite using the Makefile: 9 | 10 | make test 11 | 12 | You can also run individual tests using Go's test runner. For example: 13 | 14 | go test -v -run ^TestVariantMatches_multiple_headers$ ./... 15 | 16 | 17 | ## Running & building the application 18 | 19 | You can run the application using `go run`: 20 | 21 | go run ./cmd/thrust 22 | 23 | You can also build for the current environment using the Makefile: 24 | 25 | make build 26 | 27 | This will create a binary in the `bin/` directory. 28 | 29 | To build binaries for all supported architectures and operating systems, use: 30 | 31 | make dist 32 | 33 | This will create a `dist/` directory with binaries for each platform. 34 | 35 | 36 | ## Publishing a release 37 | 38 | In order to ship the platform-specific binary inside a gem, we actually build 39 | multiple gems, one for each platform. 
The `rake package` task will build all the 40 | necessary gems. 41 | 42 | The complete steps for releasing a new version are: 43 | 44 | - Update the version & changelog: 45 | - [ ] update `lib/thruster/version.rb` 46 | - [ ] update `CHANGELOG.md` 47 | - [ ] commit and create a git tag 48 | 49 | - Build the native gems: 50 | - [ ] `rake clobber` (to clean up any old packages) 51 | - [ ] `rake package` 52 | 53 | - Push gems: 54 | - [ ] `for g in pkg/*.gem ; do gem push $g ; done` 55 | - [ ] `git push` 56 | 57 | -------------------------------------------------------------------------------- /MIT-LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 37signals, LLC 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining 4 | a copy of this software and associated documentation files (the 5 | "Software"), to deal in the Software without restriction, including 6 | without limitation the rights to use, copy, modify, merge, publish, 7 | distribute, sublicense, and/or sell copies of the Software, and to 8 | permit persons to whom the Software is furnished to do so, subject to 9 | the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be 12 | included in all copies or substantial portions of the Software. 13 | 14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 15 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 16 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 17 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 18 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 19 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 20 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
21 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: build dist test bench clean 2 | 3 | PLATFORMS = linux darwin 4 | ARCHITECTURES = amd64 arm64 5 | 6 | build: 7 | go build -o bin/ ./cmd/... 8 | # Cross-compile every platform/arch pair. `|| exit 1` makes any failed build fail the target; otherwise the loop's status is only that of the last build. 9 | dist: 10 | @for platform in $(PLATFORMS); do \ 11 | for arch in $(ARCHITECTURES); do \ 12 | GOOS=$$platform GOARCH=$$arch CGO_ENABLED=0 go build -trimpath -o dist/thrust-$$platform-$$arch ./cmd/... || exit 1; \ 13 | done \ 14 | done 15 | 16 | test: 17 | go test ./... 18 | 19 | bench: 20 | go test -bench=. -benchmem -run=^# ./... 21 | 22 | clean: 23 | rm -rf bin dist 24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Thruster 2 | 3 | Thruster is an HTTP/2 proxy for simple production-ready deployments of Rails 4 | applications. It runs alongside the Puma webserver to provide a few additional 5 | features that help your app run efficiently and safely on the open Internet: 6 | 7 | - HTTP/2 support 8 | - Automatic TLS certificate management with Let's Encrypt 9 | - Basic HTTP caching of public assets 10 | - X-Sendfile support and compression, to efficiently serve static files 11 | 12 | Thruster aims to be as zero-config as possible. It has no configuration file, 13 | and most features are automatically enabled with sensible defaults. The goal is 14 | that simply running your Puma server with Thruster should be enough to get a 15 | production-ready setup. 16 | 17 | The only exception to this is TLS provisioning: in order for Thruster to 18 | provision TLS certificates, it needs to know which domain those certificates 19 | should be for. So to use TLS, you need to set the `TLS_DOMAIN` environment 20 | variable. If you don't set this variable, Thruster will run in HTTP-only mode.
21 | 22 | Thruster also wraps the Puma process so that you can use it without managing 23 | multiple processes yourself. This is particularly useful when running in a 24 | containerized environment, where you typically won't have a process manager 25 | available to coordinate the processes. Instead you can use Thruster as your 26 | `CMD`, and it will manage Puma for you. 27 | 28 | Thruster was originally created for the [ONCE](https://once.com) project, where 29 | we wanted a no-fuss way to serve a Rails application from a single container, 30 | directly on the open Internet. We've since found it useful for simple 31 | deployments of other Rails applications. 32 | 33 | 34 | ## Installation 35 | 36 | Thruster is distributed as a Ruby gem. Because Thruster is written in Go, we 37 | provide several pre-built platform-specific binaries. Installing the gem will 38 | automatically fetch the appropriate binary for your platform. 39 | 40 | To install it, add it to your application's Gemfile: 41 | 42 | ```ruby 43 | gem 'thruster' 44 | ``` 45 | 46 | Or install it globally: 47 | 48 | ```sh 49 | $ gem install thruster 50 | ``` 51 | 52 | 53 | ## Usage 54 | 55 | To run your Puma application inside Thruster, prefix your usual command string 56 | with `thrust`. For example: 57 | 58 | ```sh 59 | $ thrust bin/rails server 60 | ``` 61 | 62 | Or with automatic TLS: 63 | 64 | ```sh 65 | $ TLS_DOMAIN=myapp.example.com thrust bin/rails server 66 | ``` 67 | 68 | 69 | ## Custom configuration 70 | 71 | In most cases, Thruster should work out of the box with no additional 72 | configuration. But if you need to customize its behavior, there are a few 73 | environment variables that you can set. 74 | 75 | | Variable Name | Description | Default Value | 76 | |-----------------------------|---------------------------------------------------------|---------------| 77 | | `TLS_DOMAIN` | Comma-separated list of domain names to use for TLS provisioning. If not set, TLS will be disabled. 
| None | 78 | | `TARGET_PORT` | The port that your Puma server should run on. Thruster will set `PORT` to this value when starting your server. | 3000 | 79 | | `CACHE_SIZE` | The size of the HTTP cache in bytes. | 64MB | 80 | | `MAX_CACHE_ITEM_SIZE` | The maximum size of a single item in the HTTP cache in bytes. | 1MB | 81 | | `GZIP_COMPRESSION_ENABLED` | Whether to enable gzip compression for static assets. Set to `0` or `false` to disable. | Enabled | 82 | | `X_SENDFILE_ENABLED` | Whether to enable X-Sendfile support. Set to `0` or `false` to disable. | Enabled | 83 | | `MAX_REQUEST_BODY` | The maximum size of a request body in bytes. Requests larger than this size will be refused; `0` means no maximum size is enforced. | `0` | 84 | | `STORAGE_PATH` | The path to store Thruster's internal state. Provisioned TLS certificates will be stored here, so that they will not need to be requested every time your application is started. | `./storage/thruster` | 85 | | `BAD_GATEWAY_PAGE` | Path to an HTML file to serve when the backend server returns a 502 Bad Gateway error. If there is no file at the specific path, Thruster will serve an empty 502 response instead. Because Thruster boots very quickly, a custom page can be a useful way to show that your application is starting up. | `./public/502.html` | 86 | | `HTTP_PORT` | The port to listen on for HTTP traffic. | 80 | 87 | | `HTTPS_PORT` | The port to listen on for HTTPS traffic. | 443 | 88 | | `HTTP_IDLE_TIMEOUT` | The maximum time in seconds that a client can be idle before the connection is closed. | 60 | 89 | | `HTTP_READ_TIMEOUT` | The maximum time in seconds that a client can take to send the request headers and body. | 30 | 90 | | `HTTP_WRITE_TIMEOUT` | The maximum time in seconds during which the client must read the response. | 30 | 91 | | `ACME_DIRECTORY` | The URL of the ACME directory to use for TLS certificate provisioning. 
| `https://acme-v02.api.letsencrypt.org/directory` (Let's Encrypt production) | 92 | | `EAB_KID` | The EAB key identifier to use when provisioning TLS certificates, if required. | None | 93 | | `EAB_HMAC_KEY` | The Base64-encoded EAB HMAC key to use when provisioning TLS certificates, if required. | None | 94 | | `FORWARD_HEADERS` | Whether to forward X-Forwarded-* headers from the client. | Disabled when running with TLS; enabled otherwise | 95 | | `DEBUG` | Set to `1` or `true` to enable debug logging. | Disabled | 96 | 97 | To prevent naming clashes with your application's own environment variables, 98 | Thruster's environment variables can optionally be prefixed with `THRUSTER_`. 99 | For example, `TLS_DOMAIN` can also be written as `THRUSTER_TLS_DOMAIN`. Whenever 100 | a prefixed variable is set, it will take precedence over the unprefixed version. 101 | -------------------------------------------------------------------------------- /Rakefile: -------------------------------------------------------------------------------- 1 | require "bundler/setup" 2 | require "bundler/gem_tasks" 3 | -------------------------------------------------------------------------------- /cmd/thrust/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "log/slog" 6 | "os" 7 | 8 | "github.com/basecamp/thruster/internal" 9 | ) 10 | // setLogger installs a JSON slog handler on stdout, dropping records below level, as the default logger. 11 | func setLogger(level slog.Level) { 12 | slog.SetDefault(slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: level}))) 13 | } 14 | // main loads the configuration, sets up logging, and exits with the status code returned by the service. 15 | func main() { 16 | config, err := internal.NewConfig() 17 | if err != nil { 18 | fmt.Fprintf(os.Stderr, "ERROR: %s\n", err) // logging is not set up yet; report on stderr, keeping stdout for the JSON log stream 19 | os.Exit(1) 20 | } 21 | 22 | setLogger(config.LogLevel) 23 | 24 | service := internal.NewService(config) 25 | os.Exit(service.Run()) 26 | } 27 | -------------------------------------------------------------------------------- /exe/thrust:
-------------------------------------------------------------------------------- 1 | #! /usr/bin/env ruby 2 | 3 | PLATFORM = [ :cpu, :os ].map { |m| Gem::Platform.local.send(m) }.join("-") 4 | EXECUTABLE = File.expand_path(File.join(__dir__, PLATFORM, "thrust")) 5 | 6 | if File.exist?(EXECUTABLE) 7 | exec(EXECUTABLE, *ARGV) 8 | else 9 | STDERR.puts("ERROR: Unsupported platform: #{PLATFORM}") 10 | exit 1 11 | end 12 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/basecamp/thruster 2 | 3 | go 1.24.2 4 | 5 | require ( 6 | github.com/klauspost/compress v1.17.4 7 | github.com/stretchr/testify v1.8.4 8 | golang.org/x/crypto v0.37.0 9 | ) 10 | 11 | require ( 12 | github.com/davecgh/go-spew v1.1.1 // indirect 13 | github.com/kr/text v0.2.0 // indirect 14 | github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect 15 | github.com/pmezard/go-difflib v1.0.0 // indirect 16 | golang.org/x/net v0.39.0 // indirect 17 | golang.org/x/text v0.24.0 // indirect 18 | gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect 19 | gopkg.in/yaml.v3 v3.0.1 // indirect 20 | ) 21 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= 2 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 3 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4= 5 | github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= 6 | github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= 7 | github.com/kr/text v0.1.0/go.mod 
h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= 8 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 9 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 10 | github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= 11 | github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= 12 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 13 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 14 | github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= 15 | github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= 16 | golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= 17 | golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= 18 | golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY= 19 | golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E= 20 | golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= 21 | golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= 22 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 23 | gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= 24 | gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 25 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 26 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 27 | -------------------------------------------------------------------------------- /internal/cache_handler.go: 
-------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "log/slog" 5 | "net/http" 6 | "time" 7 | ) 8 | // CacheKey identifies a cached response; keys are produced by Variant.CacheKey. 9 | type CacheKey uint64 10 | // Cache stores encoded responses by key; Set also receives the time at which the entry should expire. 11 | type Cache interface { 12 | Get(key CacheKey) ([]byte, bool) 13 | Set(key CacheKey, value []byte, expiresAt time.Time) 14 | } 15 | // CacheHandler serves GET/HEAD responses from a Cache and stores cacheable upstream responses in it. 16 | type CacheHandler struct { 17 | cache Cache 18 | next http.Handler 19 | maxBodySize int 20 | } 21 | // NewCacheHandler wraps next; responses whose bodies exceed maxBodySize bytes are passed through without being cached. 22 | func NewCacheHandler(cache Cache, maxBodySize int, next http.Handler) *CacheHandler { 23 | return &CacheHandler{ 24 | cache: cache, 25 | next: next, 26 | maxBodySize: maxBodySize, 27 | } 28 | } 29 | // ServeHTTP replies from the cache when a stored variant matches the request; otherwise it forwards to next and caches the response when CacheStatus allows it. 30 | func (h *CacheHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 31 | variant := NewVariant(r) 32 | response, key, found := h.fetchFromCache(r, variant) 33 | // A hit's response headers are applied to the variant; if this request turns out to be a different variant, that variant's key is looked up instead. 34 | if found { 35 | variant.SetResponseHeader(response.HttpHeader) 36 | if !variant.Matches(response.VariantHeader) { 37 | response, key, found = h.fetchFromCache(r, variant) 38 | } 39 | } 40 | 41 | if found { 42 | response.WriteCachedResponse(w, r) 43 | return 44 | } 45 | 46 | if !h.shouldCacheRequest(r) { 47 | slog.Debug("Bypassing cache for request", "path", r.URL.Path, "method", r.Method) 48 | w.Header().Set("X-Cache", "bypass") 49 | h.next.ServeHTTP(w, r) 50 | return 51 | } 52 | 53 | cr := NewCacheableResponse(w, h.maxBodySize) 54 | h.next.ServeHTTP(cr, r) 55 | 56 | cacheable, expires := cr.CacheStatus() 57 | if cacheable { 58 | variant.SetResponseHeader(cr.HttpHeader) 59 | cr.VariantHeader = variant.VariantHeader() 60 | 61 | encoded, err := cr.ToBuffer() 62 | if err != nil { 63 | slog.Error("Failed to encode response for caching", "path", r.URL.Path, "error", err) 64 | } else { 65 | h.cache.Set(key, encoded, expires) 66 | slog.Debug("Added response to cache", "path", r.URL.Path, "key", key, "expires", expires, "size", len(encoded)) 67 | } 68 | } 69 | } 70 | 71 | // Private 72 | 73 | func (h *CacheHandler) fetchFromCache(r *http.Request, variant *Variant)
(CacheableResponse, CacheKey, bool) { 74 | key := variant.CacheKey() 75 | cached, found := h.cache.Get(key) 76 | // A corrupt entry is logged and reported as a miss rather than served. 77 | if found { 78 | response, err := CacheableResponseFromBuffer(cached) 79 | if err != nil { 80 | slog.Error("Failed to decode cached response", "path", r.URL.Path, "error", err) 81 | return CacheableResponse{}, key, false 82 | } 83 | 84 | return response, key, true 85 | } 86 | 87 | return CacheableResponse{}, key, false 88 | } 89 | // shouldCacheRequest reports whether r may use the caching path: only GET/HEAD requests that are neither protocol upgrades nor Range requests qualify. 90 | func (h *CacheHandler) shouldCacheRequest(r *http.Request) bool { 91 | allowedMethod := r.Method == http.MethodGet || r.Method == http.MethodHead 92 | isUpgrade := r.Header.Get("Connection") == "Upgrade" || r.Header.Get("Upgrade") != "" // any Upgrade header bypasses: upgrades need to hijack the connection, and CacheableResponse is not an http.Hijacker 93 | isRange := r.Header.Get("Range") != "" 94 | 95 | return allowedMethod && !isUpgrade && !isRange 96 | } 97 | -------------------------------------------------------------------------------- /internal/cache_handler_test.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "net/http/httptest" 7 | "testing" 8 | "time" 9 | 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestCacheHandler_caching(t *testing.T) { 14 | tests := map[string]struct { 15 | req *http.Request 16 | cacheControl string 17 | expectedResponses []string 18 | expectedHits []string 19 | expectedCacheLength int 20 | }{ 21 | "cacheable": { 22 | httptest.NewRequest("GET", "http://example.com", nil), 23 | "public, max-age=60", 24 | []string{"Hello 1", "Hello 1", "Hello 1"}, 25 | []string{"miss", "hit", "hit"}, 26 | 1, 27 | }, 28 | "cacheable with s-max-age": { 29 | httptest.NewRequest("GET", "http://example.com", nil), 30 | "public, s-max-age=60", 31 | []string{"Hello 1", "Hello 1", "Hello 1"}, 32 | []string{"miss", "hit", "hit"}, 33 | 1, 34 | }, 35 | "uncacheable response": { 36 | httptest.NewRequest("GET", "http://example.com", nil), 37 | "private", 38 | []string{"Hello 1", "Hello 2", "Hello 3"}, 39
| []string{"miss", "miss", "miss"}, 40 | 0, 41 | }, 42 | "uncacheable request": { 43 | httptest.NewRequest("POST", "http://example.com", nil), 44 | "public, max-age=60", 45 | []string{"Hello 1", "Hello 2", "Hello 3"}, 46 | []string{"bypass", "bypass", "bypass"}, 47 | 0, 48 | }, 49 | } 50 | 51 | for name, tc := range tests { 52 | t.Run(name, func(t *testing.T) { 53 | cache := newTestCache() 54 | counter := 0 55 | hits := []string{} 56 | 57 | handler := NewCacheHandler(cache, 1024, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 58 | counter++ 59 | w.Header().Set("Cache-Control", tc.cacheControl) 60 | fmt.Fprintf(w, "Hello %d", counter) 61 | })) 62 | 63 | for _, expectedResponse := range tc.expectedResponses { 64 | w := httptest.NewRecorder() 65 | handler.ServeHTTP(w, tc.req) 66 | hits = append(hits, w.Result().Header.Get("X-Cache")) 67 | 68 | assert.Equal(t, expectedResponse, w.Body.String()) 69 | } 70 | 71 | assert.Equal(t, tc.expectedHits, hits) 72 | assert.Equal(t, tc.expectedCacheLength, len(cache.items)) 73 | }) 74 | } 75 | } 76 | 77 | func TestCacheHandler_keying(t *testing.T) { 78 | tests := map[string]struct { 79 | paths []string 80 | methods []string 81 | expectedHits []string 82 | }{ 83 | "path": { 84 | []string{"http://example.com/one", "http://example.com/two", "http://example.com/three", "http://example.com/three"}, 85 | []string{http.MethodGet, http.MethodGet, http.MethodGet, http.MethodGet}, 86 | []string{"miss", "miss", "miss", "hit"}, 87 | }, 88 | "query string": { 89 | []string{"http://example.com?name=kevin", "http://example.com?name=kevin", "http://example.com?name=bob"}, 90 | []string{http.MethodGet, http.MethodGet, http.MethodGet}, 91 | []string{"miss", "hit", "miss"}, 92 | }, 93 | "query string ordering": { 94 | []string{"http://example.com?a=1&b=2", "http://example.com?a=1&b=2", "http://example.com?b=2&a=1"}, 95 | []string{http.MethodGet, http.MethodGet, http.MethodGet}, 96 | []string{"miss", "hit", "hit"}, 97 | }, 98 | 
"method": { 99 | []string{"http://example.com/one", "http://example.com/one", "http://example.com/one"}, 100 | []string{http.MethodGet, http.MethodHead, http.MethodPost}, 101 | []string{"miss", "miss", "bypass"}, 102 | }, 103 | } 104 | 105 | for name, tc := range tests { 106 | t.Run(name, func(t *testing.T) { 107 | cache := newTestCache() 108 | handler := NewCacheHandler(cache, 1024, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 109 | w.Header().Set("Cache-Control", "public, max-age=60") 110 | w.Write([]byte("Hello")) 111 | })) 112 | 113 | hits := []string{} 114 | 115 | for i, url := range tc.paths { 116 | w := httptest.NewRecorder() 117 | r := httptest.NewRequest(tc.methods[i], url, nil) 118 | handler.ServeHTTP(w, r) 119 | 120 | hits = append(hits, w.Result().Header.Get("X-Cache")) 121 | } 122 | 123 | assert.Equal(t, tc.expectedHits, hits) 124 | }) 125 | } 126 | } 127 | 128 | func TestCacheHandler_vary_header(t *testing.T) { 129 | cache := newTestCache() 130 | handler := NewCacheHandler(cache, 1024, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 131 | contentType := r.Header.Get("Accept") 132 | w.Header().Set("Vary", "Accept") 133 | w.Header().Set("Cache-Control", "public, max-age=600") 134 | w.Header().Set("Content-Type", contentType) 135 | w.Write([]byte(contentType)) 136 | })) 137 | 138 | doReq := func(accept string, other string) *httptest.ResponseRecorder { 139 | w := httptest.NewRecorder() 140 | r := httptest.NewRequest("GET", "http://example.com", nil) 141 | r.Header.Set("Accept", accept) 142 | r.Header.Set("Other", other) 143 | handler.ServeHTTP(w, r) 144 | return w 145 | } 146 | 147 | resp := doReq("application/json", "a") 148 | assert.Equal(t, "application/json", resp.Body.String()) 149 | assert.Equal(t, "miss", resp.Header().Get("X-Cache")) 150 | 151 | resp = doReq("application/json", "b") 152 | assert.Equal(t, "application/json", resp.Body.String()) 153 | assert.Equal(t, "hit", resp.Header().Get("X-Cache")) 154 | 155 
| resp = doReq("text/plain", "a") 156 | assert.Equal(t, "text/plain", resp.Body.String()) 157 | assert.Equal(t, "miss", resp.Header().Get("X-Cache")) 158 | 159 | resp = doReq("text/plain", "a") 160 | assert.Equal(t, "text/plain", resp.Body.String()) 161 | assert.Equal(t, "hit", resp.Header().Get("X-Cache")) 162 | 163 | resp = doReq("application/json", "b") 164 | assert.Equal(t, "application/json", resp.Body.String()) 165 | assert.Equal(t, "hit", resp.Header().Get("X-Cache")) 166 | } 167 | 168 | func TestCacheHandler_range_requests_are_not_cached(t *testing.T) { 169 | cache := newTestCache() 170 | 171 | handler := NewCacheHandler(cache, 1024, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 172 | w.Header().Set("Cache-Control", "public, max-age=60") 173 | http.ServeFile(w, r, fixturePath("image.jpg")) 174 | })) 175 | 176 | w := httptest.NewRecorder() 177 | r := httptest.NewRequest("GET", "/", nil) 178 | r.Header.Set("Range", "bytes=0-1") 179 | handler.ServeHTTP(w, r) 180 | 181 | assert.Equal(t, http.StatusPartialContent, w.Code) 182 | assert.Equal(t, "2", w.Header().Get("Content-Length")) 183 | assert.Equal(t, fixtureContent("image.jpg")[:2], w.Body.Bytes()) 184 | assert.Equal(t, "bypass", w.Header().Get("X-Cache")) 185 | 186 | w = httptest.NewRecorder() 187 | r = httptest.NewRequest("GET", "/", nil) 188 | r.Header.Set("Range", "bytes=2-5") 189 | handler.ServeHTTP(w, r) 190 | 191 | assert.Equal(t, http.StatusPartialContent, w.Code) 192 | assert.Equal(t, "4", w.Header().Get("Content-Length")) 193 | assert.Equal(t, fixtureContent("image.jpg")[2:6], w.Body.Bytes()) 194 | assert.Equal(t, "bypass", w.Header().Get("X-Cache")) 195 | } 196 | 197 | func BenchmarkCacheHandler_retrieving(b *testing.B) { 198 | cache := NewMemoryCache(1*MB, 1*MB) 199 | 200 | handler := NewCacheHandler(cache, 1024, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 201 | w.Header().Set("Cache-Control", "public, max-age=600") 202 | w.Write([]byte("Hello")) 203 | })) 204 
| 205 | for i := 0; i < b.N; i++ { 206 | w := httptest.NewRecorder() 207 | r := httptest.NewRequest("GET", "/", nil) 208 | handler.ServeHTTP(w, r) 209 | } 210 | } 211 | 212 | // Mocks 213 | 214 | type testCache struct { 215 | items map[CacheKey][]byte 216 | } 217 | 218 | func newTestCache() *testCache { 219 | return &testCache{items: make(map[CacheKey][]byte)} 220 | } 221 | 222 | func (t *testCache) Get(key CacheKey) ([]byte, bool) { 223 | item, found := t.items[key] 224 | return item, found 225 | } 226 | 227 | func (t *testCache) Set(key CacheKey, value []byte, expiresAt time.Time) { 228 | t.items[key] = value 229 | } 230 | -------------------------------------------------------------------------------- /internal/cacheable_response.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "bytes" 5 | "encoding/gob" 6 | "io" 7 | "net/http" 8 | "regexp" 9 | "strconv" 10 | "strings" 11 | "time" 12 | ) 13 | // Cache-Control directive matchers used by CacheStatus. 14 | var ( 15 | publicExp = regexp.MustCompile(`\bpublic\b`) 16 | noCacheExpt = regexp.MustCompile(`\bno-cache\b`) 17 | sMaxAgeExp = regexp.MustCompile(`\bs-max-?age=(\d+)\b`) // the standard directive is s-maxage (RFC 9111); the legacy s-max-age spelling is still accepted 18 | maxAgeExp = regexp.MustCompile(`\bmax-age=(\d+)\b`) 19 | ) 20 | // CacheableResponse is an http.ResponseWriter that passes a response through to the client while recording its status, headers and size-limited body; the exported fields are what gets gob-encoded into the cache. 21 | type CacheableResponse struct { 22 | StatusCode int 23 | HttpHeader http.Header 24 | Body []byte 25 | VariantHeader http.Header 26 | 27 | responseWriter http.ResponseWriter 28 | stasher *stashingWriter 29 | headersWritten bool 30 | } 31 | // NewCacheableResponse wraps w; bodies longer than maxBodyLength bytes still reach the client but are not retained for caching. 32 | func NewCacheableResponse(w http.ResponseWriter, maxBodyLength int) *CacheableResponse { 33 | return &CacheableResponse{ 34 | StatusCode: http.StatusOK, 35 | HttpHeader: http.Header{}, 36 | 37 | responseWriter: w, 38 | stasher: NewStashingWriter(maxBodyLength, w), 39 | } 40 | } 41 | // CacheableResponseFromBuffer decodes a response previously encoded by ToBuffer. 42 | func CacheableResponseFromBuffer(b []byte) (CacheableResponse, error) { 43 | var cr CacheableResponse 44 | decoder := gob.NewDecoder(bytes.NewReader(b)) 45 | err := decoder.Decode(&cr) 46 | 47 | return cr, err 48 | } 49 | // ToBuffer gob-encodes the response together with the body captured so far. 50 |
func (c *CacheableResponse) ToBuffer() ([]byte, error) { 51 | c.Body = c.stasher.Body() 52 | 53 | var b bytes.Buffer 54 | encoder := gob.NewEncoder(&b) 55 | err := encoder.Encode(c) 56 | 57 | return b.Bytes(), err 58 | } 59 | 60 | func (c *CacheableResponse) Header() http.Header { 61 | return c.HttpHeader 62 | } 63 | // Write sends the data to the client, retaining it for the cache while it stays under the size limit. 64 | func (c *CacheableResponse) Write(bytes []byte) (int, error) { 65 | if !c.headersWritten { 66 | c.WriteHeader(http.StatusOK) 67 | } 68 | return c.stasher.Write(bytes) 69 | } 70 | // WriteHeader records the status, strips Set-Cookie when the response is cacheable, and writes the headers through marked X-Cache: miss. 71 | func (c *CacheableResponse) WriteHeader(statusCode int) { 72 | c.StatusCode = statusCode 73 | c.scrubHeaders() 74 | c.copyHeaders(c.responseWriter, false, c.StatusCode) 75 | c.headersWritten = true 76 | } 77 | // CacheStatus reports whether the response may be stored and, if so, when it expires. It requires a body within the size limit, a 2xx/3xx status other than 304, no Vary: *, and a Cache-Control that is public, has neither no-cache nor no-store, and sets a positive s-maxage or max-age. 78 | func (c *CacheableResponse) CacheStatus() (bool, time.Time) { 79 | if c.stasher.Overflowed() { 80 | return false, time.Time{} 81 | } 82 | 83 | if c.StatusCode < 200 || c.StatusCode > 399 || c.StatusCode == http.StatusNotModified { 84 | return false, time.Time{} 85 | } 86 | 87 | if strings.Contains(c.HttpHeader.Get("Vary"), "*") { 88 | return false, time.Time{} 89 | } 90 | 91 | cc := c.HttpHeader.Get("Cache-Control") 92 | // no-store forbids storing the response at all, even when it is marked public. 93 | if !publicExp.MatchString(cc) || noCacheExpt.MatchString(cc) || strings.Contains(cc, "no-store") { 94 | return false, time.Time{} 95 | } 96 | // s-maxage targets shared caches like this one, so it takes precedence over max-age. 97 | matches := sMaxAgeExp.FindStringSubmatch(cc) 98 | if len(matches) != 2 { 99 | matches = maxAgeExp.FindStringSubmatch(cc) 100 | } 101 | if len(matches) != 2 { 102 | return false, time.Time{} 103 | } 104 | 105 | maxAge, err := strconv.Atoi(matches[1]) 106 | if err != nil || maxAge <= 0 { 107 | return false, time.Time{} 108 | } 109 | 110 | return true, time.Now().Add(time.Duration(maxAge) * time.Second) 111 | } 112 | // WriteCachedResponse replays the stored response to w marked X-Cache: hit, or sends 304 Not Modified when the request's If-None-Match lists the stored ETag. 113 | func (c *CacheableResponse) WriteCachedResponse(w http.ResponseWriter, r *http.Request) { 114 | if c.wasNotModified(r) { 115 | c.copyHeaders(w, true, http.StatusNotModified) 116 | } else { 117 | c.copyHeaders(w, true, c.StatusCode) 118 | io.Copy(w, bytes.NewReader(c.Body)) 119 | } 120 | } 121 | 122 | //
 Private 123 | 124 | func (c *CacheableResponse) wasNotModified(r *http.Request) bool { 125 | storedEtag := c.HttpHeader.Get("Etag") 126 | if storedEtag == "" { 127 | return false 128 | } 129 | 130 | ifNoneMatch := strings.Split(r.Header.Get("If-None-Match"), ",") 131 | for _, etag := range ifNoneMatch { 132 | if strings.TrimSpace(etag) == storedEtag { 133 | return true 134 | } 135 | } 136 | 137 | return false 138 | } 139 | // copyHeaders copies the stored headers to w, adds the X-Cache marker, and writes statusCode. 140 | func (c *CacheableResponse) copyHeaders(w http.ResponseWriter, wasHit bool, statusCode int) { 141 | for k, v := range c.HttpHeader { 142 | w.Header()[k] = v 143 | } 144 | 145 | if wasHit { 146 | w.Header().Set("X-Cache", "hit") 147 | } else { 148 | w.Header().Set("X-Cache", "miss") 149 | } 150 | 151 | w.WriteHeader(statusCode) 152 | } 153 | // scrubHeaders removes Set-Cookie from cacheable responses so one client's cookies are never replayed to others. 154 | func (c *CacheableResponse) scrubHeaders() { 155 | cacheable, _ := c.CacheStatus() 156 | 157 | if cacheable { 158 | c.HttpHeader.Del("Set-Cookie") 159 | } 160 | } 161 | // stashingWriter forwards every write to dest while buffering up to limit bytes; once any write would exceed the limit, Body returns nil. 162 | type stashingWriter struct { 163 | limit int 164 | dest io.Writer 165 | buffer bytes.Buffer 166 | overflowed bool 167 | } 168 | 169 | func NewStashingWriter(limit int, dest io.Writer) *stashingWriter { 170 | return &stashingWriter{ 171 | limit: limit, 172 | dest: dest, 173 | } 174 | } 175 | 176 | func (w *stashingWriter) Write(p []byte) (int, error) { 177 | if w.buffer.Len()+len(p) > w.limit { 178 | w.overflowed = true 179 | } else { 180 | w.buffer.Write(p) 181 | } 182 | 183 | return w.dest.Write(p) 184 | } 185 | 186 | func (w *stashingWriter) Body() []byte { 187 | if w.overflowed { 188 | return nil 189 | } 190 | return w.buffer.Bytes() 191 | } 192 | 193 | func (w *stashingWriter) Overflowed() bool { 194 | return w.overflowed 195 | } 196 | -------------------------------------------------------------------------------- /internal/cacheable_response_test.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "bytes" 5 | "net/http" 6 | "net/http/httptest"
7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | func TestCacheableResponse_cache_headers(t *testing.T) { 14 | tests := map[string]struct { 15 | cacheControl string 16 | cacheable bool 17 | }{ 18 | "public, with max-age": { 19 | cacheControl: "public, max-age=60", 20 | cacheable: true, 21 | }, 22 | 23 | "public, with s-max-age": { 24 | cacheControl: "public, s-max-age=60", 25 | cacheable: true, 26 | }, 27 | 28 | "public, with max-age of zero": { 29 | cacheControl: "public, max-age=0", 30 | cacheable: false, 31 | }, 32 | 33 | "public, with no max-age": { 34 | cacheControl: "public", 35 | cacheable: false, 36 | }, 37 | 38 | "private, with max-age": { 39 | cacheControl: "private, max-age=60", 40 | cacheable: false, 41 | }, 42 | 43 | "max-age, but no public specified": { 44 | cacheControl: "max-age=60", 45 | cacheable: false, 46 | }, 47 | 48 | "public, with max-age, but also no-cache": { 49 | cacheControl: "public, max-age=60, no-cache", 50 | cacheable: false, 51 | }, 52 | } 53 | 54 | for name, test := range tests { 55 | t.Run(name, func(t *testing.T) { 56 | rec := httptest.NewRecorder() 57 | cr := NewCacheableResponse(rec, 1024) 58 | cr.Header().Set("Cache-Control", test.cacheControl) 59 | 60 | cacheable, _ := cr.CacheStatus() 61 | assert.Equal(t, test.cacheable, cacheable) 62 | }) 63 | } 64 | } 65 | 66 | func TestCacheableResponse_does_not_cache_items_with_wildcard_vary_header(t *testing.T) { 67 | rec := httptest.NewRecorder() 68 | cr := NewCacheableResponse(rec, 1024) 69 | cr.Header().Set("Cache-Control", "public, max-age=60") 70 | cr.Header().Set("Vary", "*") 71 | 72 | cacheable, _ := cr.CacheStatus() 73 | assert.False(t, cacheable) 74 | } 75 | 76 | func TestCacheableResponse_does_not_cache_items_where_body_too_large(t *testing.T) { 77 | rec := httptest.NewRecorder() 78 | cr := NewCacheableResponse(rec, 10) 79 | cr.Header().Set("Cache-Control", "public, max-age=60") 80 | 
cr.Write([]byte("12345678901234567890")) 81 | 82 | cacheable, _ := cr.CacheStatus() 83 | assert.False(t, cacheable) 84 | } 85 | 86 | func TestCacheableResponse_does_not_cache_304_responses(t *testing.T) { 87 | rec := httptest.NewRecorder() 88 | cr := NewCacheableResponse(rec, 1024) 89 | cr.Header().Set("Cache-Control", "public, max-age=60") 90 | cr.WriteHeader(http.StatusNotModified) 91 | 92 | cacheable, _ := cr.CacheStatus() 93 | assert.False(t, cacheable) 94 | } 95 | 96 | func TestCacheableResponse_writes_response_to_writer(t *testing.T) { 97 | w := httptest.NewRecorder() 98 | cr := NewCacheableResponse(w, 1024) 99 | cr.Header().Set("Cache-Control", "public, max-age=60") 100 | cr.WriteHeader(http.StatusCreated) 101 | cr.Write([]byte("Hello World")) 102 | 103 | assert.Equal(t, http.StatusCreated, w.Code) 104 | assert.Equal(t, "Hello World", w.Body.String()) 105 | assert.Equal(t, "public, max-age=60", w.Header().Get("Cache-Control")) 106 | assert.Equal(t, "miss", w.Header().Get("X-Cache")) 107 | } 108 | 109 | func TestCacheableResponse_writes_response_to_writer_even_when_too_large_to_cache(t *testing.T) { 110 | w := httptest.NewRecorder() 111 | cr := NewCacheableResponse(w, 10) 112 | cr.Header().Set("Cache-Control", "public, max-age=60") 113 | cr.WriteHeader(http.StatusCreated) 114 | cr.Write([]byte("12345678901234567890")) 115 | 116 | assert.Equal(t, http.StatusCreated, w.Code) 117 | assert.Equal(t, "12345678901234567890", w.Body.String()) 118 | assert.Equal(t, "public, max-age=60", w.Header().Get("Cache-Control")) 119 | assert.Equal(t, "miss", w.Header().Get("X-Cache")) 120 | } 121 | 122 | func TestCacheableResponse_write_cached_response(t *testing.T) { 123 | rec := httptest.NewRecorder() 124 | cr := NewCacheableResponse(rec, 1024) 125 | cr.Header().Set("Cache-Control", "public, max-age=60") 126 | cr.WriteHeader(http.StatusCreated) 127 | cr.Write([]byte("Hello World")) 128 | 129 | cr.ToBuffer() // Ensure the body is saved 130 | 131 | w := httptest.NewRecorder() 
132 | r := httptest.NewRequest(http.MethodGet, "/", nil) 133 | cr.WriteCachedResponse(w, r) 134 | 135 | assert.Equal(t, http.StatusCreated, w.Code) 136 | assert.Equal(t, "Hello World", w.Body.String()) 137 | assert.Equal(t, "public, max-age=60", w.Header().Get("Cache-Control")) 138 | assert.Equal(t, "hit", w.Header().Get("X-Cache")) 139 | } 140 | 141 | func TestCacheableResponse_conditional_response(t *testing.T) { 142 | etag := `"deadbeef"` 143 | 144 | rec := httptest.NewRecorder() 145 | cr := NewCacheableResponse(rec, 1024) 146 | cr.Header().Set("Etag", etag) 147 | cr.WriteHeader(http.StatusOK) 148 | cr.Write([]byte("Hello World")) 149 | 150 | cr.ToBuffer() // Ensure the body is saved 151 | 152 | w := httptest.NewRecorder() 153 | r := httptest.NewRequest(http.MethodGet, "/", nil) 154 | r.Header.Set("If-None-Match", etag) 155 | cr.WriteCachedResponse(w, r) 156 | 157 | assert.Equal(t, http.StatusNotModified, w.Code) 158 | assert.Equal(t, "", w.Body.String()) 159 | assert.Equal(t, etag, w.Header().Get("Etag")) 160 | assert.Equal(t, "hit", w.Header().Get("X-Cache")) 161 | 162 | w = httptest.NewRecorder() 163 | r = httptest.NewRequest(http.MethodGet, "/", nil) 164 | r.Header.Set("If-None-Match", "\"another\", \"deadbeef\"") 165 | cr.WriteCachedResponse(w, r) 166 | 167 | assert.Equal(t, http.StatusNotModified, w.Code) 168 | assert.Equal(t, "", w.Body.String()) 169 | assert.Equal(t, etag, w.Header().Get("Etag")) 170 | assert.Equal(t, "hit", w.Header().Get("X-Cache")) 171 | } 172 | 173 | func TestCacheableResponse_conditional_response_none_match(t *testing.T) { 174 | rec := httptest.NewRecorder() 175 | cr := NewCacheableResponse(rec, 1024) 176 | cr.Header().Set("Etag", "ffffffff") 177 | cr.WriteHeader(http.StatusOK) 178 | cr.Write([]byte("Hello World")) 179 | 180 | cr.ToBuffer() // Ensure the body is saved 181 | 182 | w := httptest.NewRecorder() 183 | r := httptest.NewRequest(http.MethodGet, "/", nil) 184 | r.Header.Set("If-None-Match", "deadbeef") 185 | 
cr.WriteCachedResponse(w, r) 186 | 187 | assert.Equal(t, http.StatusOK, w.Code) 188 | assert.Equal(t, "Hello World", w.Body.String()) 189 | assert.Equal(t, "ffffffff", w.Header().Get("Etag")) 190 | assert.Equal(t, "hit", w.Header().Get("X-Cache")) 191 | } 192 | 193 | func TestCacheableResponse_conditional_response_no_etag_in_request(t *testing.T) { 194 | rec := httptest.NewRecorder() 195 | cr := NewCacheableResponse(rec, 1024) 196 | cr.Header().Set("Etag", "ffffffff") 197 | cr.WriteHeader(http.StatusOK) 198 | cr.Write([]byte("Hello World")) 199 | 200 | cr.ToBuffer() // Ensure the body is saved 201 | 202 | w := httptest.NewRecorder() 203 | r := httptest.NewRequest(http.MethodGet, "/", nil) 204 | cr.WriteCachedResponse(w, r) 205 | 206 | assert.Equal(t, http.StatusOK, w.Code) 207 | assert.Equal(t, "Hello World", w.Body.String()) 208 | assert.Equal(t, "ffffffff", w.Header().Get("Etag")) 209 | assert.Equal(t, "hit", w.Header().Get("X-Cache")) 210 | } 211 | 212 | func TestCacheableResponse_conditional_response_no_etag_in_response(t *testing.T) { 213 | rec := httptest.NewRecorder() 214 | cr := NewCacheableResponse(rec, 1024) 215 | cr.WriteHeader(http.StatusOK) 216 | cr.Write([]byte("Hello World")) 217 | 218 | cr.ToBuffer() // Ensure the body is saved 219 | 220 | w := httptest.NewRecorder() 221 | r := httptest.NewRequest(http.MethodGet, "/", nil) 222 | r.Header.Set("If-None-Match", "deadbeef") 223 | cr.WriteCachedResponse(w, r) 224 | 225 | assert.Equal(t, http.StatusOK, w.Code) 226 | assert.Equal(t, "Hello World", w.Body.String()) 227 | assert.Empty(t, w.Header().Get("Etag")) 228 | assert.Equal(t, "hit", w.Header().Get("X-Cache")) 229 | } 230 | 231 | func TestCacheableResponse_scrubs_cookies_from_cacheable_responses(t *testing.T) { 232 | rec := httptest.NewRecorder() 233 | cr := NewCacheableResponse(rec, 1024) 234 | cr.Header().Set("Cache-Control", "public, max-age=60") 235 | cr.Header().Set("Set-Cookie", "user=1234; Path=/; HttpOnly") 236 | cr.WriteHeader(http.StatusOK) 
237 | 238 | w := httptest.NewRecorder() 239 | r := httptest.NewRequest(http.MethodGet, "/", nil) 240 | 241 | cr.WriteCachedResponse(w, r) 242 | 243 | assert.Equal(t, http.StatusOK, w.Code) 244 | assert.Empty(t, w.Header().Get("Set-Cookie")) 245 | } 246 | 247 | func TestCacheableResponse_does_not_scrub_cookies_from_non_cacheable_responses(t *testing.T) { 248 | rec := httptest.NewRecorder() 249 | cr := NewCacheableResponse(rec, 1024) 250 | cr.Header().Set("Set-Cookie", "user=1234; Path=/; HttpOnly") 251 | cr.WriteHeader(http.StatusOK) 252 | 253 | w := httptest.NewRecorder() 254 | r := httptest.NewRequest(http.MethodGet, "/", nil) 255 | 256 | cr.WriteCachedResponse(w, r) 257 | 258 | assert.Equal(t, http.StatusOK, w.Code) 259 | assert.Equal(t, "user=1234; Path=/; HttpOnly", w.Header().Get("Set-Cookie")) 260 | } 261 | 262 | func TestCacheableResponse_serialization(t *testing.T) { 263 | rec := httptest.NewRecorder() 264 | cr := NewCacheableResponse(rec, 1024) 265 | cr.Header().Set("Cache-Control", "public, max-age=60") 266 | cr.WriteHeader(http.StatusCreated) 267 | cr.Write([]byte("Hello World")) 268 | 269 | saved, err := cr.ToBuffer() 270 | assert.NoError(t, err) 271 | 272 | restored, err := CacheableResponseFromBuffer(saved) 273 | assert.NoError(t, err) 274 | 275 | assert.Equal(t, cr.StatusCode, restored.StatusCode) 276 | assert.Equal(t, cr.Header(), restored.Header()) 277 | assert.Equal(t, cr.Body, restored.Body) 278 | } 279 | 280 | func TestStashingWriter_writing_within_limit(t *testing.T) { 281 | writer := &bytes.Buffer{} 282 | sw := NewStashingWriter(10, writer) 283 | 284 | written, err := sw.Write([]byte("12345")) 285 | require.NoError(t, err) 286 | assert.Equal(t, 5, written) 287 | 288 | assert.Equal(t, "12345", writer.String()) 289 | assert.Equal(t, []byte("12345"), sw.Body()) 290 | assert.False(t, sw.Overflowed()) 291 | } 292 | 293 | func TestStashingWriter_writing_over_limit(t *testing.T) { 294 | writer := &bytes.Buffer{} 295 | sw := NewStashingWriter(10, 
writer) 296 | 297 | written, err := sw.Write([]byte("12345678901234567890")) 298 | require.NoError(t, err) 299 | assert.Equal(t, 20, written) 300 | 301 | assert.Equal(t, "12345678901234567890", writer.String()) 302 | assert.Nil(t, sw.Body()) 303 | assert.True(t, sw.Overflowed()) 304 | } 305 | 306 | func TestStashingWriter_writing_over_limit_in_small_pieces(t *testing.T) { 307 | writer := &bytes.Buffer{} 308 | sw := NewStashingWriter(10, writer) 309 | 310 | for i := 0; i < 10; i++ { 311 | written, err := sw.Write([]byte("12")) 312 | require.NoError(t, err) 313 | assert.Equal(t, 2, written) 314 | } 315 | 316 | assert.Equal(t, "12121212121212121212", writer.String()) 317 | assert.Nil(t, sw.Body()) 318 | assert.True(t, sw.Overflowed()) 319 | } 320 | -------------------------------------------------------------------------------- /internal/config.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "errors" 5 | "log/slog" 6 | "os" 7 | "strconv" 8 | "strings" 9 | "time" 10 | 11 | "golang.org/x/crypto/acme" 12 | ) 13 | 14 | const ( 15 | KB = 1024 16 | MB = 1024 * KB 17 | 18 | ENV_PREFIX = "THRUSTER_" 19 | 20 | defaultTargetPort = 3000 21 | 22 | defaultCacheSize = 64 * MB 23 | defaultMaxCacheItemSizeBytes = 1 * MB 24 | defaultMaxRequestBody = 0 25 | 26 | defaultACMEDirectoryURL = acme.LetsEncryptURL 27 | defaultStoragePath = "./storage/thruster" 28 | defaultBadGatewayPage = "./public/502.html" 29 | 30 | defaultHttpPort = 80 31 | defaultHttpsPort = 443 32 | defaultHttpIdleTimeout = 60 * time.Second 33 | defaultHttpReadTimeout = 30 * time.Second 34 | defaultHttpWriteTimeout = 30 * time.Second 35 | 36 | defaultLogLevel = slog.LevelInfo 37 | ) 38 | 39 | type Config struct { 40 | TargetPort int 41 | UpstreamCommand string 42 | UpstreamArgs []string 43 | 44 | CacheSizeBytes int 45 | MaxCacheItemSizeBytes int 46 | XSendfileEnabled bool 47 | GzipCompressionEnabled bool 48 | MaxRequestBody int 49 | 50 | 
TLSDomains []string 51 | ACMEDirectoryURL string 52 | EAB_KID string 53 | EAB_HMACKey string 54 | StoragePath string 55 | BadGatewayPage string 56 | 57 | HttpPort int 58 | HttpsPort int 59 | HttpIdleTimeout time.Duration 60 | HttpReadTimeout time.Duration 61 | HttpWriteTimeout time.Duration 62 | 63 | ForwardHeaders bool 64 | 65 | LogLevel slog.Level 66 | } 67 | 68 | func NewConfig() (*Config, error) { 69 | if len(os.Args) < 2 { 70 | return nil, errors.New("missing upstream command") 71 | } 72 | 73 | logLevel := defaultLogLevel 74 | if getEnvBool("DEBUG", false) { 75 | logLevel = slog.LevelDebug 76 | } 77 | 78 | config := &Config{ 79 | TargetPort: getEnvInt("TARGET_PORT", defaultTargetPort), 80 | UpstreamCommand: os.Args[1], 81 | UpstreamArgs: os.Args[2:], 82 | 83 | CacheSizeBytes: getEnvInt("CACHE_SIZE", defaultCacheSize), 84 | MaxCacheItemSizeBytes: getEnvInt("MAX_CACHE_ITEM_SIZE", defaultMaxCacheItemSizeBytes), 85 | XSendfileEnabled: getEnvBool("X_SENDFILE_ENABLED", true), 86 | GzipCompressionEnabled: getEnvBool("GZIP_COMPRESSION_ENABLED", true), 87 | MaxRequestBody: getEnvInt("MAX_REQUEST_BODY", defaultMaxRequestBody), 88 | 89 | TLSDomains: getEnvStrings("TLS_DOMAIN", []string{}), 90 | ACMEDirectoryURL: getEnvString("ACME_DIRECTORY", defaultACMEDirectoryURL), 91 | EAB_KID: getEnvString("EAB_KID", ""), 92 | EAB_HMACKey: getEnvString("EAB_HMAC_KEY", ""), 93 | StoragePath: getEnvString("STORAGE_PATH", defaultStoragePath), 94 | BadGatewayPage: getEnvString("BAD_GATEWAY_PAGE", defaultBadGatewayPage), 95 | 96 | HttpPort: getEnvInt("HTTP_PORT", defaultHttpPort), 97 | HttpsPort: getEnvInt("HTTPS_PORT", defaultHttpsPort), 98 | HttpIdleTimeout: getEnvDuration("HTTP_IDLE_TIMEOUT", defaultHttpIdleTimeout), 99 | HttpReadTimeout: getEnvDuration("HTTP_READ_TIMEOUT", defaultHttpReadTimeout), 100 | HttpWriteTimeout: getEnvDuration("HTTP_WRITE_TIMEOUT", defaultHttpWriteTimeout), 101 | 102 | LogLevel: logLevel, 103 | } 104 | 105 | config.ForwardHeaders = 
getEnvBool("FORWARD_HEADERS", !config.HasTLS()) 106 | 107 | return config, nil 108 | } 109 | 110 | func (c *Config) HasTLS() bool { 111 | return len(c.TLSDomains) > 0 112 | } 113 | 114 | func findEnv(key string) (string, bool) { 115 | value, ok := os.LookupEnv(ENV_PREFIX + key) 116 | if ok { 117 | return value, true 118 | } 119 | 120 | value, ok = os.LookupEnv(key) 121 | if ok { 122 | return value, true 123 | } 124 | 125 | return "", false 126 | } 127 | 128 | func getEnvString(key, defaultValue string) string { 129 | value, ok := findEnv(key) 130 | if ok { 131 | return value 132 | } 133 | 134 | return defaultValue 135 | } 136 | 137 | func getEnvStrings(key string, defaultValue []string) []string { 138 | value, ok := findEnv(key) 139 | if ok { 140 | items := strings.Split(value, ",") 141 | result := []string{} 142 | 143 | for _, item := range items { 144 | item = strings.TrimSpace(item) 145 | if item != "" { 146 | result = append(result, item) 147 | } 148 | } 149 | 150 | return result 151 | } 152 | 153 | return defaultValue 154 | } 155 | 156 | func getEnvInt(key string, defaultValue int) int { 157 | value, ok := findEnv(key) 158 | if !ok { 159 | return defaultValue 160 | } 161 | 162 | intValue, err := strconv.Atoi(value) 163 | if err != nil { 164 | return defaultValue 165 | } 166 | 167 | return intValue 168 | } 169 | 170 | func getEnvDuration(key string, defaultValue time.Duration) time.Duration { 171 | value, ok := findEnv(key) 172 | if !ok { 173 | return defaultValue 174 | } 175 | 176 | intValue, err := strconv.Atoi(value) 177 | if err != nil { 178 | return defaultValue 179 | } 180 | 181 | return time.Duration(intValue) * time.Second 182 | } 183 | 184 | func getEnvBool(key string, defaultValue bool) bool { 185 | value, ok := findEnv(key) 186 | if !ok { 187 | return defaultValue 188 | } 189 | 190 | boolValue, err := strconv.ParseBool(value) 191 | if err != nil { 192 | return defaultValue 193 | } 194 | 195 | return boolValue 196 | } 197 | 
--------------------------------------------------------------------------------
/internal/config_test.go:
--------------------------------------------------------------------------------
package internal

import (
	"log/slog"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestConfig_tls covers how TLS_DOMAIN is parsed and how it drives the
// default for ForwardHeaders (on without TLS, off with TLS) and its
// FORWARD_HEADERS override.
//
// NOTE(review): usingProgramArgs / usingEnvVar are test helpers defined
// outside this file; presumably they set os.Args and environment variables
// for the duration of the test. Confirm against testing.go.
func TestConfig_tls(t *testing.T) {
	t.Run("with no ENV", func(t *testing.T) {
		usingProgramArgs(t, "thruster", "echo", "hello")

		c, err := NewConfig()
		require.NoError(t, err)

		assert.Equal(t, []string{}, c.TLSDomains)
		assert.False(t, c.HasTLS())
		assert.True(t, c.ForwardHeaders)
	})

	// A set-but-empty TLS_DOMAIN must behave like no TLS at all.
	t.Run("with an empty TLS_DOMAIN", func(t *testing.T) {
		usingProgramArgs(t, "thruster", "echo", "hello")
		usingEnvVar(t, "TLS_DOMAIN", "")

		c, err := NewConfig()
		require.NoError(t, err)

		assert.Equal(t, []string{}, c.TLSDomains)
		assert.False(t, c.HasTLS())
		assert.True(t, c.ForwardHeaders)
	})

	t.Run("with single TLS_DOMAIN", func(t *testing.T) {
		usingProgramArgs(t, "thruster", "echo", "hello")
		usingEnvVar(t, "TLS_DOMAIN", "example.com")

		c, err := NewConfig()
		require.NoError(t, err)

		assert.Equal(t, []string{"example.com"}, c.TLSDomains)
		assert.True(t, c.HasTLS())
		assert.False(t, c.ForwardHeaders)
	})

	t.Run("with multiple TLS_DOMAIN", func(t *testing.T) {
		usingProgramArgs(t, "thruster", "echo", "hello")
		usingEnvVar(t, "TLS_DOMAIN", "example.com, example.io")

		c, err := NewConfig()
		require.NoError(t, err)

		assert.Equal(t, []string{"example.com", "example.io"}, c.TLSDomains)
		assert.True(t, c.HasTLS())
		assert.False(t, c.ForwardHeaders)
	})

	// Blank entries and surrounding whitespace are discarded.
	t.Run("with TLS_DOMAIN containing whitespace", func(t *testing.T) {
		usingProgramArgs(t, "thruster", "echo", "hello")
		usingEnvVar(t, "TLS_DOMAIN", " , example.com, example.io,")

		c, err := NewConfig()
		require.NoError(t, err)

		assert.Equal(t, []string{"example.com", "example.io"}, c.TLSDomains)
		assert.True(t, c.HasTLS())
		assert.False(t, c.ForwardHeaders)
	})

	t.Run("overriding with FORWARD_HEADERS when using TLS", func(t *testing.T) {
		usingProgramArgs(t, "thruster", "echo", "hello")
		usingEnvVar(t, "TLS_DOMAIN", "example.com")
		usingEnvVar(t, "FORWARD_HEADERS", "true")

		c, err := NewConfig()
		require.NoError(t, err)

		assert.Equal(t, []string{"example.com"}, c.TLSDomains)
		assert.True(t, c.HasTLS())
		assert.True(t, c.ForwardHeaders)
	})

	t.Run("overriding with FORWARD_HEADERS when not using TLS", func(t *testing.T) {
		usingProgramArgs(t, "thruster", "echo", "hello")
		usingEnvVar(t, "FORWARD_HEADERS", "false")

		c, err := NewConfig()
		require.NoError(t, err)

		assert.Empty(t, c.TLSDomains)
		assert.False(t, c.HasTLS())
		assert.False(t, c.ForwardHeaders)
	})
}

// TestConfig_defaults checks the values NewConfig uses when no relevant
// environment variables are set.
func TestConfig_defaults(t *testing.T) {
	usingProgramArgs(t, "thruster", "echo", "hello")

	c, err := NewConfig()
	require.NoError(t, err)

	assert.Equal(t, 3000, c.TargetPort)
	assert.Equal(t, "echo", c.UpstreamCommand)
	assert.Equal(t, defaultCacheSize, c.CacheSizeBytes)
	assert.Equal(t, slog.LevelInfo, c.LogLevel)
}

// TestConfig_override_defaults_with_env_vars checks that unprefixed variables
// override defaults. HTTP_READ_TIMEOUT is given in whole seconds.
func TestConfig_override_defaults_with_env_vars(t *testing.T) {
	usingProgramArgs(t, "thruster", "echo", "hello")
	usingEnvVar(t, "TARGET_PORT", "4000")
	usingEnvVar(t, "CACHE_SIZE", "256")
	usingEnvVar(t, "HTTP_READ_TIMEOUT", "5")
	usingEnvVar(t, "X_SENDFILE_ENABLED", "0")
	usingEnvVar(t, "GZIP_COMPRESSION_ENABLED", "0")
	usingEnvVar(t, "DEBUG", "1")
	usingEnvVar(t, "ACME_DIRECTORY", "https://acme-staging-v02.api.letsencrypt.org/directory")

	c, err := NewConfig()
	require.NoError(t, err)

	assert.Equal(t, 4000, c.TargetPort)
	assert.Equal(t, 256, c.CacheSizeBytes)
	assert.Equal(t, 5*time.Second, c.HttpReadTimeout)
	assert.Equal(t, false, c.XSendfileEnabled)
	assert.Equal(t, false, c.GzipCompressionEnabled)
	assert.Equal(t, slog.LevelDebug, c.LogLevel)
	assert.Equal(t, "https://acme-staging-v02.api.letsencrypt.org/directory", c.ACMEDirectoryURL)
}

// TestConfig_override_defaults_with_env_vars_using_prefix checks that the
// THRUSTER_-prefixed forms are recognised as well.
func TestConfig_override_defaults_with_env_vars_using_prefix(t *testing.T) {
	usingProgramArgs(t, "thruster", "echo", "hello")
	usingEnvVar(t, "THRUSTER_TARGET_PORT", "4000")
	usingEnvVar(t, "THRUSTER_CACHE_SIZE", "256")
	usingEnvVar(t, "THRUSTER_HTTP_READ_TIMEOUT", "5")
	usingEnvVar(t, "THRUSTER_X_SENDFILE_ENABLED", "0")
	usingEnvVar(t, "THRUSTER_DEBUG", "1")

	c, err := NewConfig()
	require.NoError(t, err)

	assert.Equal(t, 4000, c.TargetPort)
	assert.Equal(t, 256, c.CacheSizeBytes)
	assert.Equal(t, 5*time.Second, c.HttpReadTimeout)
	assert.Equal(t, false, c.XSendfileEnabled)
	assert.Equal(t, slog.LevelDebug, c.LogLevel)
}

// When both forms are set, the THRUSTER_-prefixed value wins.
func TestConfig_prefixed_variables_take_precedence_over_non_prefixed(t *testing.T) {
	usingProgramArgs(t, "thruster", "echo", "hello")
	usingEnvVar(t, "TARGET_PORT", "3000")
	usingEnvVar(t, "THRUSTER_TARGET_PORT", "4000")

	c, err := NewConfig()
	require.NoError(t, err)

	assert.Equal(t, 4000, c.TargetPort)
}

// Running with only the program name (no upstream command) is an error.
func TestConfig_return_error_when_no_upstream_command(t *testing.T) {
	usingProgramArgs(t, "thruster")

	_, err := NewConfig()
	require.Error(t, err)
}

--------------------------------------------------------------------------------
/internal/fixtures/image.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basecamp/thruster/3d8aec355ae6f40caf3d7ceb25fc5731f1150195/internal/fixtures/image.jpg -------------------------------------------------------------------------------- /internal/fixtures/loremipsum.txt: -------------------------------------------------------------------------------- 1 | Lorem ipsum dolor sit amet, consectetur adipiscing elit. Etiam sed mi quis 2 | mauris sodales pulvinar. Aenean fringilla fermentum nulla, et porttitor orci 3 | rutrum vitae. Donec faucibus mollis placerat. Aenean eu nibh tincidunt, 4 | sollicitudin libero in, vulputate urna. Fusce rutrum, leo eget blandit 5 | elementum, turpis purus ullamcorper est, at ultricies mi magna eget dui. Proin 6 | at ante eget elit rhoncus blandit eu eget eros. Nulla facilisi. Pellentesque a 7 | lorem a risus porttitor auctor non eget urna. Phasellus a lobortis sem. In 8 | pretium pretium mauris, ut egestas magna vulputate nec. Vestibulum ante ipsum 9 | primis in faucibus orci luctus et ultrices posuere cubilia curae; Etiam lobortis 10 | tortor lorem, eget feugiat lorem varius vitae. Pellentesque habitant morbi 11 | tristique senectus et netus et malesuada fames ac turpis egestas. Suspendisse 12 | gravida, erat nec porttitor tempus, lectus lectus cursus quam, ac pellentesque 13 | ante ex eu ante. Ut venenatis et ante et accumsan. Sed lobortis eros non sem 14 | eleifend suscipit. 15 | 16 | Quisque tempor nibh turpis, id tempor purus ornare non. Vivamus eget varius 17 | tellus. Fusce ac viverra velit. Donec cursus lectus quis nunc ornare, quis 18 | semper metus feugiat. Ut eget porttitor mi, sed semper nulla. Quisque non massa 19 | sit amet risus condimentum aliquam. Nam eu faucibus est. Sed ut massa orci. 20 | Curabitur vitae congue risus. Ut vestibulum finibus purus, vitae placerat ipsum 21 | commodo suscipit. Praesent quis dapibus lacus, non tincidunt ex. 22 | 23 | Donec ac porttitor nunc. Maecenas ut nulla eget felis interdum pharetra vel quis 24 | augue. 
In nunc urna, rhoncus at dolor id, pellentesque lobortis dolor. Sed et 25 | orci turpis. Nunc pellentesque mi id felis facilisis, vitae porttitor erat 26 | efficitur. Curabitur vel pretium tortor. Vestibulum volutpat lectus nec mauris 27 | auctor, non lacinia sapien vulputate. Aliquam erat volutpat. Duis luctus ornare 28 | diam accumsan varius. Quisque id luctus lacus. Nulla scelerisque, quam in 29 | fermentum mattis, elit libero viverra neque, ac ullamcorper diam orci id erat. 30 | Nulla ligula neque, rutrum sed dapibus sed, scelerisque non elit. Suspendisse 31 | ullamcorper elit tellus, lobortis gravida sem aliquam in. Aliquam consectetur 32 | viverra tortor nec ullamcorper. Orci varius natoque penatibus et magnis dis 33 | parturient montes, nascetur ridiculus mus. 34 | 35 | Nulla et purus et sem blandit gravida ac et purus. Nullam placerat turpis 36 | lectus, at pharetra turpis pulvinar vel. Vivamus porttitor metus malesuada 37 | sapien faucibus, at efficitur justo bibendum. Proin tincidunt molestie posuere. 38 | Donec elit turpis, interdum id fringilla vel, tempor eu orci. Etiam pellentesque 39 | lacus et dui ullamcorper, et vulputate nunc suscipit. Pellentesque scelerisque 40 | gravida pharetra. Suspendisse tristique nisl vitae pharetra ultricies. Donec 41 | placerat enim magna, tempor gravida diam sodales eget. Morbi elit nisi, 42 | scelerisque condimentum mattis in, porta sit amet elit. 43 | 44 | Duis risus urna, eleifend suscipit arcu sit amet, vehicula elementum augue. 45 | Nulla hendrerit elit mauris. Sed vel ipsum a mauris vestibulum pellentesque sit 46 | amet pellentesque elit. Maecenas venenatis ligula at felis eleifend, at 47 | ultricies odio fringilla. Nulla tempus ullamcorper maximus. Aenean hendrerit 48 | bibendum pharetra. Nullam rhoncus, dui sed venenatis maximus, dolor sapien 49 | aliquet diam, venenatis egestas metus neque sed velit. Cras eros mi, pharetra ut 50 | pellentesque a, rutrum quis ligula. 
Vestibulum velit lectus, sagittis non 51 | egestas in, tempus ac dolor. Maecenas vel tincidunt lectus. Aliquam porttitor 52 | vel urna non condimentum. Aliquam non diam et dui ultrices fringilla. Nulla at 53 | eros molestie, fermentum diam vitae, volutpat eros. 54 | -------------------------------------------------------------------------------- /internal/handler.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "log/slog" 5 | "net/http" 6 | "net/url" 7 | 8 | "github.com/klauspost/compress/gzhttp" 9 | ) 10 | 11 | // HandlerOptions configures the handler chain built by NewHandler. 12 | type HandlerOptions struct { 13 | badGatewayPage string // path to a custom 502 page; an empty 502 is sent if it cannot be read 14 | cache Cache 15 | maxCacheableResponseBody int 16 | maxRequestBody int // request body limit in bytes; values <= 0 disable the limit 17 | targetUrl *url.URL // upstream server that requests are proxied to 18 | xSendfileEnabled bool 19 | gzipCompressionEnabled bool 20 | forwardHeaders bool // honour inbound X-Forwarded-* headers (X-Forwarded-For is appended to) instead of discarding them 21 | } 22 | 23 | // NewHandler builds the request pipeline. Each wrapper is applied around the 24 | // previous one, so a request passes through: logging, the request body limit 25 | // (when maxRequestBody > 0), gzip compression (when enabled), X-Sendfile 26 | // handling, the response cache, and finally the reverse proxy to targetUrl. 27 | func NewHandler(options HandlerOptions) http.Handler { 28 | handler := NewProxyHandler(options.targetUrl, options.badGatewayPage, options.forwardHeaders) 29 | handler = NewCacheHandler(options.cache, options.maxCacheableResponseBody, handler) 30 | handler = NewSendfileHandler(options.xSendfileEnabled, handler) 31 | if options.gzipCompressionEnabled { 32 | handler = gzhttp.GzipHandler(handler) 33 | } 34 | 35 | if options.maxRequestBody > 0 { 36 | handler = http.MaxBytesHandler(handler, int64(options.maxRequestBody)) 37 | } 38 | 39 | handler = NewLoggingMiddleware(slog.Default(), handler) 40 | 41 | return handler 42 | } 43 | -------------------------------------------------------------------------------- /internal/handler_test.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "bytes" 5 | "net/http" 6 | "net/http/httptest" 7 | "net/url" 8 | "strconv" 9 | "testing" 10 | 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func TestHandlerGzipCompression_when_proxying(t *testing.T) { 15 | upstream := httptest.NewServer(http.HandlerFunc(func(w
http.ResponseWriter, r *http.Request) { 16 | w.Header().Set("Content-Length", strconv.FormatInt(fixtureLength("loremipsum.txt"), 10)) 17 | w.Write(fixtureContent("loremipsum.txt")) 18 | })) 19 | defer upstream.Close() 20 | 21 | h := NewHandler(handlerOptions(upstream.URL)) 22 | 23 | w := httptest.NewRecorder() 24 | r := httptest.NewRequest("GET", "/", nil) 25 | r.Header.Set("Accept-Encoding", "gzip") 26 | h.ServeHTTP(w, r) 27 | 28 | assert.Equal(t, http.StatusOK, w.Code) 29 | assert.Contains(t, w.Header().Get("Content-Type"), "text/plain") 30 | assert.Equal(t, "gzip", w.Header().Get("Content-Encoding")) 31 | 32 | transferredSize, _ := strconv.ParseInt(w.Header().Get("Content-Length"), 10, 64) 33 | assert.Less(t, transferredSize, fixtureLength("loremipsum.txt")) 34 | } 35 | 36 | func TestNotHandlerGzipCompression_when_disabled(t *testing.T) { 37 | fixtureLength := strconv.FormatInt(fixtureLength("loremipsum.txt"), 10) 38 | 39 | upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 40 | w.Header().Set("Content-Length", fixtureLength) 41 | w.Write(fixtureContent("loremipsum.txt")) 42 | })) 43 | defer upstream.Close() 44 | 45 | options := handlerOptions(upstream.URL) 46 | options.gzipCompressionEnabled = false 47 | h := NewHandler(options) 48 | 49 | w := httptest.NewRecorder() 50 | r := httptest.NewRequest("GET", "/", nil) 51 | r.Header.Set("Accept-Encoding", "gzip") 52 | h.ServeHTTP(w, r) 53 | 54 | assert.Equal(t, http.StatusOK, w.Code) 55 | assert.Contains(t, w.Header().Get("Content-Type"), "text/plain") 56 | assert.Empty(t, w.Header().Get("Content-Encoding")) 57 | assert.Equal(t, fixtureLength, w.Header().Get("Content-Length")) 58 | } 59 | 60 | func TestHandlerGzipCompression_is_not_applied_when_not_requested(t *testing.T) { 61 | fixtureLength := strconv.FormatInt(fixtureLength("loremipsum.txt"), 10) 62 | 63 | upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 64 |
w.Header().Set("Content-Length", fixtureLength) 65 | w.Write(fixtureContent("loremipsum.txt")) 66 | })) 67 | defer upstream.Close() 68 | 69 | h := NewHandler(handlerOptions(upstream.URL)) 70 | 71 | w := httptest.NewRecorder() 72 | r := httptest.NewRequest("GET", "/", nil) 73 | h.ServeHTTP(w, r) 74 | 75 | assert.Equal(t, http.StatusOK, w.Code) 76 | assert.Contains(t, w.Header().Get("Content-Type"), "text/plain") 77 | assert.Empty(t, w.Header().Get("Content-Encoding")) 78 | assert.Equal(t, fixtureLength, w.Header().Get("Content-Length")) 79 | } 80 | 81 | func TestHandlerGzipCompression_does_not_compress_images(t *testing.T) { 82 | fixtureLength := strconv.FormatInt(fixtureLength("image.jpg"), 10) 83 | 84 | upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 85 | w.Header().Set("Content-Type", "image/jpg") 86 | w.Header().Set("Content-Length", fixtureLength) 87 | w.Write(fixtureContent("image.jpg")) 88 | })) 89 | defer upstream.Close() 90 | 91 | h := NewHandler(handlerOptions(upstream.URL)) 92 | 93 | w := httptest.NewRecorder() 94 | r := httptest.NewRequest("GET", "/", nil) 95 | r.Header.Set("Accept-Encoding", "gzip") 96 | h.ServeHTTP(w, r) 97 | 98 | assert.Equal(t, http.StatusOK, w.Code) 99 | assert.Contains(t, w.Header().Get("Content-Type"), "image/jpg") 100 | assert.NotEqual(t, "gzip", w.Header().Get("Content-Encoding")) 101 | assert.Equal(t, fixtureLength, w.Header().Get("Content-Length")) 102 | } 103 | 104 | func TestHandlerGzipCompression_when_sendfile(t *testing.T) { 105 | upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 106 | assert.Equal(t, "X-Sendfile", r.Header.Get("X-Sendfile-Type")) 107 | 108 | w.Header().Set("X-Sendfile", fixturePath("loremipsum.txt")) 109 | })) 110 | defer upstream.Close() 111 | 112 | h := NewHandler(handlerOptions(upstream.URL)) 113 | 114 | w := httptest.NewRecorder() 115 | r := httptest.NewRequest("GET", "/", nil) 116 |
r.Header.Set("Accept-Encoding", "gzip") 117 | h.ServeHTTP(w, r) 118 | 119 | assert.Equal(t, http.StatusOK, w.Code) 120 | assert.Contains(t, w.Header().Get("Content-Type"), "text/plain") 121 | assert.Equal(t, "gzip", w.Header().Get("Content-Encoding")) 122 | 123 | transferredSize, _ := strconv.ParseInt(w.Header().Get("Content-Length"), 10, 64) 124 | assert.Less(t, transferredSize, fixtureLength("loremipsum.txt")) 125 | } 126 | 127 | func TestHandler_do_not_request_compressed_responses_from_upstream_unless_client_does(t *testing.T) { 128 | upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 129 | acceptsGzip := r.Header.Get("Accept-Encoding") == "gzip" 130 | shouldAcceptGzip := r.URL.Path == "/compressed" 131 | 132 | assert.Equal(t, shouldAcceptGzip, acceptsGzip) 133 | if acceptsGzip { 134 | w.Header().Set("Content-Encoding", "gzip") 135 | } 136 | })) 137 | defer upstream.Close() 138 | 139 | h := NewHandler(handlerOptions(upstream.URL)) 140 | 141 | w := httptest.NewRecorder() 142 | r := httptest.NewRequest("GET", "/plain", nil) 143 | h.ServeHTTP(w, r) 144 | assert.Equal(t, http.StatusOK, w.Code) 145 | assert.Empty(t, w.Header().Get("Content-Encoding")) 146 | 147 | w = httptest.NewRecorder() 148 | r = httptest.NewRequest("GET", "/compressed", nil) 149 | r.Header.Set("Accept-Encoding", "gzip") 150 | h.ServeHTTP(w, r) 151 | assert.Equal(t, http.StatusOK, w.Code) 152 | assert.Equal(t, "gzip", w.Header().Get("Content-Encoding")) 153 | } 154 | 155 | func TestHandlerMaxRequestBody(t *testing.T) { 156 | upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) 157 | defer upstream.Close() 158 | 159 | options := handlerOptions(upstream.URL) 160 | options.maxRequestBody = 10 161 | h := NewHandler(options) 162 | 163 | w := httptest.NewRecorder() 164 | r := httptest.NewRequest("POST", "/", bytes.NewReader([]byte("Hello"))) 165 | h.ServeHTTP(w, r) 166 | assert.Equal(t, http.StatusOK, w.Code) 167 |
168 | w = httptest.NewRecorder() 169 | r = httptest.NewRequest("POST", "/", bytes.NewReader([]byte("This one is too long"))) 170 | h.ServeHTTP(w, r) 171 | assert.Equal(t, http.StatusRequestEntityTooLarge, w.Code) 172 | 173 | options.maxRequestBody = 0 174 | h = NewHandler(options) 175 | 176 | w = httptest.NewRecorder() 177 | r = httptest.NewRequest("POST", "/", bytes.NewReader([]byte("This one is still long"))) 178 | h.ServeHTTP(w, r) 179 | assert.Equal(t, http.StatusOK, w.Code) 180 | } 181 | 182 | func TestHandlerPreserveInboundHostHeaderWhenProxying(t *testing.T) { 183 | upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 184 | assert.Equal(t, "example.org", r.Host) 185 | })) 186 | defer upstream.Close() 187 | 188 | h := NewHandler(handlerOptions(upstream.URL)) 189 | 190 | w := httptest.NewRecorder() 191 | r := httptest.NewRequest("GET", "http://example.org", nil) 192 | h.ServeHTTP(w, r) 193 | assert.Equal(t, http.StatusOK, w.Code) 194 | } 195 | 196 | func TestHandlerAppendInboundXFFHeaderWhenProxying(t *testing.T) { 197 | upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 198 | assert.Equal(t, "0.0.0.0, 0.0.0.1", r.Header.Get("X-Forwarded-For")) 199 | })) 200 | defer upstream.Close() 201 | 202 | h := NewHandler(handlerOptions(upstream.URL)) 203 | 204 | w := httptest.NewRecorder() 205 | r := httptest.NewRequest("GET", "http://example.org", nil) 206 | r.RemoteAddr = "0.0.0.1:1234" 207 | r.Header.Set("X-Forwarded-For", "0.0.0.0") 208 | h.ServeHTTP(w, r) 209 | assert.Equal(t, http.StatusOK, w.Code) 210 | } 211 | 212 | func TestHandlerXForwardedHeadersWhenProxying(t *testing.T) { 213 | upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 214 | assert.Equal(t, "1.2.3.4", r.Header.Get("X-Forwarded-For")) 215 | assert.Equal(t, "example.org", r.Header.Get("X-Forwarded-Host")) 216 | assert.Equal(t, "https", r.Header.Get("X-Forwarded-Proto")) 217 | })) 218 | defer upstream.Close() 219 | 220 | h :=
NewHandler(handlerOptions(upstream.URL)) 221 | 222 | w := httptest.NewRecorder() 223 | r := httptest.NewRequest("GET", "https://example.org", nil) 224 | r.RemoteAddr = "1.2.3.4:1234" 225 | h.ServeHTTP(w, r) 226 | assert.Equal(t, http.StatusOK, w.Code) 227 | } 228 | 229 | func TestHandlerXForwardedHeadersForwardsExistingHeadersWhenForwardingEnabled(t *testing.T) { 230 | upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 231 | assert.Equal(t, "4.3.2.1, 1.2.3.4", r.Header.Get("X-Forwarded-For")) 232 | assert.Equal(t, "other.example.com", r.Header.Get("X-Forwarded-Host")) 233 | assert.Equal(t, "https", r.Header.Get("X-Forwarded-Proto")) 234 | })) 235 | defer upstream.Close() 236 | 237 | h := NewHandler(handlerOptions(upstream.URL)) 238 | 239 | w := httptest.NewRecorder() 240 | r := httptest.NewRequest("GET", "http://example.org", nil) 241 | r.Header.Set("X-Forwarded-For", "4.3.2.1") 242 | r.Header.Set("X-Forwarded-Proto", "https") 243 | r.Header.Set("X-Forwarded-Host", "other.example.com") 244 | r.RemoteAddr = "1.2.3.4:1234" 245 | h.ServeHTTP(w, r) 246 | assert.Equal(t, http.StatusOK, w.Code) 247 | } 248 | 249 | func TestHandlerXForwardedHeadersDropsExistingHeadersWhenForwardingNotEnabled(t *testing.T) { 250 | upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 251 | assert.Equal(t, "1.2.3.4", r.Header.Get("X-Forwarded-For")) 252 | assert.Equal(t, "example.org", r.Header.Get("X-Forwarded-Host")) 253 | assert.Equal(t, "http", r.Header.Get("X-Forwarded-Proto")) 254 | })) 255 | defer upstream.Close() 256 | 257 | options := handlerOptions(upstream.URL) 258 | options.forwardHeaders = false 259 | h := NewHandler(options) 260 | 261 | w := httptest.NewRecorder() 262 | r := httptest.NewRequest("GET", "http://example.org", nil) 263 | r.Header.Set("X-Forwarded-For", "4.3.2.1") 264 | r.Header.Set("X-Forwarded-Proto", "https") 265 | r.Header.Set("X-Forwarded-Host", "other.example.com") 266 | r.RemoteAddr = "1.2.3.4:1234" 267 | h.ServeHTTP(w, r) 268 | assert.Equal(t, http.StatusOK, w.Code) 269 | } 270 | 271 | // Helpers 272 | 273
| func handlerOptions(targetUrl string) HandlerOptions { 274 | url, _ := url.Parse(targetUrl) 275 | 276 | return HandlerOptions{ 277 | cache: NewMemoryCache(defaultCacheSize, defaultMaxCacheItemSizeBytes), 278 | targetUrl: url, 279 | xSendfileEnabled: true, 280 | gzipCompressionEnabled: true, 281 | maxCacheableResponseBody: 1024, 282 | badGatewayPage: "", 283 | forwardHeaders: true, 284 | } 285 | } 286 | -------------------------------------------------------------------------------- /internal/logging_middleware.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "bufio" 5 | "errors" 6 | "log/slog" 7 | "net" 8 | "net/http" 9 | "time" 10 | ) 11 | 12 | type LoggingMiddleware struct { 13 | logger *slog.Logger 14 | next http.Handler 15 | } 16 | 17 | func NewLoggingMiddleware(logger *slog.Logger, next http.Handler) *LoggingMiddleware { 18 | return &LoggingMiddleware{ 19 | logger: logger, 20 | next: next, 21 | } 22 | } 23 | 24 | func (h *LoggingMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) { 25 | writer := newResponseWriter(w) 26 | 27 | started := time.Now() 28 | h.next.ServeHTTP(writer, r) 29 | elapsed := time.Since(started) 30 | 31 | userAgent := r.Header.Get("User-Agent") 32 | reqContent := r.Header.Get("Content-Type") 33 | respContent := writer.Header().Get("Content-Type") 34 | cache := writer.Header().Get("X-Cache") 35 | remoteAddr := r.Header.Get("X-Forwarded-For") 36 | if remoteAddr == "" { 37 | remoteAddr = r.RemoteAddr 38 | } 39 | 40 | h.logger.Info("Request", 41 | "path", r.URL.Path, 42 | "status", writer.statusCode, 43 | "dur", elapsed.Milliseconds(), 44 | "method", r.Method, 45 | "req_content_length", r.ContentLength, 46 | "req_content_type", reqContent, 47 | "resp_content_length", writer.bytesWritten, 48 | "resp_content_type", respContent, 49 | "remote_addr", remoteAddr, 50 | "user_agent", userAgent, 51 | "cache", cache, 52 | "query", r.URL.RawQuery) 53 | } 54 | 55 |
type responseWriter struct { 56 | http.ResponseWriter 57 | statusCode int // status reported in the request log; 200 unless WriteHeader or Hijack changes it 58 | bytesWritten int64 // body bytes accepted by the wrapped writer 59 | } 60 | 61 | func newResponseWriter(w http.ResponseWriter) *responseWriter { 62 | return &responseWriter{w, http.StatusOK, 0} 63 | } 64 | 65 | // WriteHeader is used to capture the status code 66 | func (r *responseWriter) WriteHeader(statusCode int) { 67 | r.statusCode = statusCode 68 | r.ResponseWriter.WriteHeader(statusCode) 69 | } 70 | 71 | // Write is used to capture the amount of data written 72 | func (r *responseWriter) Write(b []byte) (int, error) { 73 | bytesWritten, err := r.ResponseWriter.Write(b) 74 | r.bytesWritten += int64(bytesWritten) 75 | return bytesWritten, err 76 | } 77 | 78 | func (r *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { 79 | hijacker, ok := r.ResponseWriter.(http.Hijacker) 80 | if !ok { 81 | return nil, nil, errors.New("ResponseWriter does not implement http.Hijacker") 82 | } 83 | 84 | con, rw, err := hijacker.Hijack() 85 | if err == nil { 86 | r.statusCode = http.StatusSwitchingProtocols 87 | } 88 | return con, rw, err 89 | } 90 | 91 | // Flush implements http.Flusher so that code which type-asserts for it (rather 92 | // than using http.ResponseController) can still stream through this wrapper. 93 | // It is a no-op when the wrapped writer cannot flush. 94 | func (r *responseWriter) Flush() { 95 | _ = http.NewResponseController(r.ResponseWriter).Flush() 96 | } 97 | 98 | // Unwrap returns the wrapped writer; http.ResponseController uses it to reach 99 | // optional methods (write deadlines, full duplex, ...) that this type does not 100 | // implement itself. 101 | func (r *responseWriter) Unwrap() http.ResponseWriter { 102 | return r.ResponseWriter 103 | } 104 | -------------------------------------------------------------------------------- /internal/logging_middleware_test.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "fmt" 7 | "log/slog" 8 | "net/http" 9 | "net/http/httptest" 10 | "strings" 11 | "testing" 12 | 13 | "github.com/stretchr/testify/assert" 14 | "github.com/stretchr/testify/require" 15 | ) 16 | 17 | func TestMiddleware_LoggingMiddleware(t *testing.T) { 18 | out := &strings.Builder{} 19 | logger := slog.New(slog.NewJSONHandler(out, nil)) 20 | middleware := NewLoggingMiddleware(logger, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 21 | w.Header().Set("X-Cache", "miss") 22 | w.Header().Set("Content-Type", "text/html") 23 | w.WriteHeader(http.StatusCreated) 24 | fmt.Fprintln(w, "goodbye") 25 | })) 26 | 27 |
req := httptest.NewRequest("POST", "/somepath?q=ok", bytes.NewReader([]byte("hello"))) 28 | req.Header.Set("X-Forwarded-For", "192.168.1.1") 29 | req.Header.Set("User-Agent", "Robot/1") 30 | req.Header.Set("Content-Type", "application/json") 31 | 32 | middleware.ServeHTTP(httptest.NewRecorder(), req) 33 | 34 | logline := struct { 35 | Path string `json:"path"` 36 | Method string `json:"method"` 37 | Status int `json:"status"` 38 | RemoteAddr string `json:"remote_addr"` 39 | UserAgent string `json:"user_agent"` 40 | ReqContentLength int64 `json:"req_content_length"` 41 | ReqContentType string `json:"req_content_type"` 42 | RespContentLength int64 `json:"resp_content_length"` 43 | RespContentType string `json:"resp_content_type"` 44 | Query string `json:"query"` 45 | Cache string `json:"cache"` 46 | }{} 47 | 48 | err := json.NewDecoder(strings.NewReader(out.String())).Decode(&logline) 49 | require.NoError(t, err) 50 | 51 | assert.Equal(t, "/somepath", logline.Path) 52 | assert.Equal(t, "POST", logline.Method) 53 | assert.Equal(t, http.StatusCreated, logline.Status) 54 | assert.Equal(t, "192.168.1.1", logline.RemoteAddr) 55 | assert.Equal(t, "Robot/1", logline.UserAgent) 56 | assert.Equal(t, "application/json", logline.ReqContentType) 57 | assert.Equal(t, "text/html", logline.RespContentType) 58 | assert.Equal(t, "q=ok", logline.Query) 59 | assert.Equal(t, int64(5), logline.ReqContentLength) 60 | assert.Equal(t, int64(8), logline.RespContentLength) 61 | assert.Equal(t, "miss", logline.Cache) 62 | } 63 | 64 | func TestMiddleware_LoggingMiddleware_supports_flushing(t *testing.T) { 65 | logger := slog.New(slog.NewJSONHandler(&strings.Builder{}, nil)) 66 | middleware := NewLoggingMiddleware(logger, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 67 | fmt.Fprint(w, "partial") 68 | assert.NoError(t, http.NewResponseController(w).Flush()) 69 | })) 70 | 71 | recorder := httptest.NewRecorder() 72 | middleware.ServeHTTP(recorder, httptest.NewRequest("GET", "/", nil)) 73 | 74 | assert.True(t, recorder.Flushed) 75 | assert.Equal(t, "partial", recorder.Body.String()) 76 | } 77 | -------------------------------------------------------------------------------- /internal/memory_cache.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "log/slog" 5 | "math/rand" 6 | "sync" 7 | "time" 8 | ) 9 | 10 | type GetCurrentTime func() time.Time 11 | 12 | type MemoryCacheEntry struct { 13 | lastAccessedAt time.Time 14 | expiresAt time.Time 15 | value []byte 16 | } 17 | 18 | type
MemoryCacheEntryMap map[CacheKey]*MemoryCacheEntry 19 | type MemoryCacheKeyList []CacheKey 20 | 21 | type MemoryCache struct { 22 | sync.Mutex 23 | capacity int 24 | maxItemSize int 25 | size int 26 | keys MemoryCacheKeyList 27 | items MemoryCacheEntryMap 28 | getCurrentTime GetCurrentTime 29 | } 30 | 31 | func NewMemoryCache(capacity, maxItemSize int) *MemoryCache { 32 | return &MemoryCache{ 33 | capacity: capacity, 34 | maxItemSize: maxItemSize, 35 | size: 0, 36 | keys: MemoryCacheKeyList{}, 37 | items: MemoryCacheEntryMap{}, 38 | getCurrentTime: time.Now, 39 | } 40 | } 41 | 42 | func (c *MemoryCache) Set(key CacheKey, value []byte, expiresAt time.Time) { 43 | c.Lock() 44 | defer c.Unlock() 45 | 46 | itemSize := len(value) 47 | if itemSize > c.maxItemSize || itemSize > c.capacity { 48 | slog.Debug("Cache: item is too large to store", "len", itemSize) 49 | return 50 | } 51 | 52 | limit := c.capacity - itemSize 53 | for c.size > limit { 54 | slog.Debug("Cache: evicting item to make space", "current_size", c.size, "need_size", limit) 55 | c.evictOldestItem() 56 | } 57 | 58 | existingItem, ok := c.items[key] 59 | if ok { 60 | c.size -= len(existingItem.value) 61 | } else { 62 | c.keys = append(c.keys, key) 63 | } 64 | 65 | c.items[key] = &MemoryCacheEntry{ 66 | lastAccessedAt: c.getCurrentTime(), 67 | expiresAt: expiresAt, 68 | value: value, 69 | } 70 | 71 | c.size += itemSize 72 | 73 | slog.Debug("Cache: added item", "key", key, "size", itemSize, "expires_at", expiresAt) 74 | } 75 | 76 | func (c *MemoryCache) Get(key CacheKey) ([]byte, bool) { 77 | c.Lock() 78 | defer c.Unlock() 79 | 80 | now := c.getCurrentTime() 81 | 82 | item, ok := c.items[key] 83 | if !ok || item.expiresAt.Before(now) { 84 | return nil, false 85 | } 86 | 87 | item.lastAccessedAt = now 88 | return item.value, true 89 | } 90 | 91 | func (c *MemoryCache) evictOldestItem() { 92 | var oldestKey CacheKey 93 | var oldestIndex int 94 | var oldest time.Time 95 | 96 | now := c.getCurrentTime() 97 |
// Pick 5 random items and evict the oldest one. On average we'll evict items 99 | // in the oldest 20%, which is good enough and is much faster than scanning 100 | // through them all. 101 | // 102 | // If we find an expired item while looking, that's a better choice to evict, 103 | // so we can choose it immediately. 104 | for i := 0; i < 5; i++ { 105 | index := rand.Intn(len(c.keys)) 106 | key := c.keys[index] 107 | v := c.items[key] 108 | 109 | if v.expiresAt.Before(now) { 110 | oldestKey = key 111 | oldestIndex = index 112 | break 113 | } 114 | 115 | if v.lastAccessedAt.Before(oldest) || oldest.IsZero() { 116 | oldest = v.lastAccessedAt 117 | oldestKey = key 118 | oldestIndex = index 119 | } 120 | } 121 | 122 | c.keys[oldestIndex] = c.keys[len(c.keys)-1] 123 | c.keys = c.keys[:len(c.keys)-1] 124 | 125 | c.size -= len(c.items[oldestKey].value) 126 | delete(c.items, oldestKey) 127 | } 128 | -------------------------------------------------------------------------------- /internal/memory_cache_test.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | "time" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestMemoryCache_store_and_retrieve(t *testing.T) { 12 | c := NewMemoryCache(32*MB, 1*MB) 13 | c.Set(1, []byte("hello world"), time.Now().Add(30*time.Second)) 14 | 15 | read, ok := c.Get(1) 16 | assert.True(t, ok) 17 | assert.Equal(t, []byte("hello world"), read) 18 | } 19 | 20 | func TestMemoryCache_storing_updates_existing_value(t *testing.T) { 21 | c := NewMemoryCache(32*MB, 1*MB) 22 | c.Set(1, []byte("first"), time.Now().Add(30*time.Second)) 23 | c.Set(1, []byte("second"), time.Now().Add(30*time.Second)) 24 | 25 | read, ok := c.Get(1) 26 | assert.True(t, ok) 27 | assert.Equal(t, []byte("second"), read) 28 | } 29 | 30 | func TestMemoryCache_storing_existing_value_keeps_keys_and_size_correct(t *testing.T) { 31 | c := NewMemoryCache(32*MB, 1*MB)
32 | c.Set(1, []byte("first"), time.Now().Add(30*time.Second)) 33 | c.Set(1, []byte("second"), time.Now().Add(30*time.Second)) 34 | 35 | assert.Equal(t, 1, len(c.keys)) 36 | assert.Equal(t, 6, c.size) 37 | } 38 | 39 | func TestMemoryCache_expiry(t *testing.T) { 40 | c := NewMemoryCache(32*MB, 1*MB) 41 | now := time.Date(2023, 1, 22, 17, 30, 0, 0, time.UTC) 42 | 43 | c.getCurrentTime = func() time.Time { return now } 44 | c.Set(1, []byte("hello world"), now.Add(1*time.Second)) 45 | 46 | read, ok := c.Get(1) 47 | assert.True(t, ok) 48 | assert.Equal(t, []byte("hello world"), read) 49 | 50 | c.getCurrentTime = func() time.Time { return now.Add(2 * time.Second) } 51 | 52 | _, ok = c.Get(1) 53 | assert.False(t, ok) 54 | } 55 | 56 | func TestMemoryCache_does_not_store_items_over_cache_limit(t *testing.T) { 57 | c := NewMemoryCache(3*KB, 50*KB) 58 | 59 | payload := make([]byte, 10*KB) 60 | c.Set(1, payload, time.Now().Add(1*time.Hour)) 61 | 62 | _, ok := c.Get(1) 63 | assert.False(t, ok) 64 | } 65 | 66 | func TestMemoryCache_of_size_zero_does_not_store_items(t *testing.T) { 67 | c := NewMemoryCache(0, 1*KB) 68 | 69 | c.Set(1, []byte("There's nowhere to store this"), time.Now().Add(1*time.Hour)) 70 | 71 | _, ok := c.Get(1) 72 | assert.False(t, ok) 73 | } 74 | 75 | func TestMemoryCache_items_are_evicted_to_make_space(t *testing.T) { 76 | maxCacheSize := 10 * KB 77 | c := NewMemoryCache(maxCacheSize, 1*KB) 78 | 79 | for i := CacheKey(0); i < 20; i++ { 80 | payload := bytes.Repeat([]byte{byte(i)}, 1*KB) 81 | c.Set(i, payload, time.Now().Add(1*time.Hour)) 82 | 83 | retrieved, ok := c.Get(i) 84 | assert.True(t, ok) 85 | assert.Equal(t, payload, retrieved) 86 | } 87 | 88 | assert.Equal(t, maxCacheSize, c.size) 89 | } 90 | 91 | func TestMemoryCache_does_not_store_items_over_item_limit(t *testing.T) { 92 | c := NewMemoryCache(50*KB, 3*KB) 93 | 94 | payload := make([]byte, 10*KB) 95 | c.Set(1, payload, time.Now().Add(1*time.Hour)) 96 | 97 | _, ok := c.Get(1) 98 | assert.False(t,
ok) 99 | } 100 | 101 | func BenchmarkCache_populating_small_objects(b *testing.B) { 102 | c := NewMemoryCache(32*MB, 1*MB) 103 | payload := make([]byte, KB) 104 | expires := time.Now().Add(1 * time.Hour) 105 | 106 | for i := CacheKey(0); i < CacheKey(b.N); i++ { 107 | c.Set(i, payload, expires) 108 | c.Get(i) 109 | } 110 | } 111 | 112 | func BenchmarkCache_populating_large_objects(b *testing.B) { 113 | c := NewMemoryCache(32*MB, 1*MB) 114 | payload := make([]byte, 512*KB) 115 | expires := time.Now().Add(1 * time.Hour) 116 | 117 | for i := CacheKey(0); i < CacheKey(b.N); i++ { 118 | c.Set(i, payload, expires) 119 | c.Get(i) 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /internal/proxy_handler.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "errors" 5 | "log/slog" 6 | "net/http" 7 | "net/http/httputil" 8 | "net/url" 9 | "os" 10 | ) 11 | 12 | func NewProxyHandler(targetUrl *url.URL, badGatewayPage string, forwardHeaders bool) http.Handler { 13 | return &httputil.ReverseProxy{ 14 | Rewrite: func(r *httputil.ProxyRequest) { 15 | r.SetURL(targetUrl) 16 | r.Out.Host = r.In.Host 17 | setXForwarded(r, forwardHeaders) 18 | }, 19 | ErrorHandler: ProxyErrorHandler(badGatewayPage), 20 | Transport: createProxyTransport(), 21 | } 22 | } 23 | 24 | func ProxyErrorHandler(badGatewayPage string) func(w http.ResponseWriter, r *http.Request, err error) { 25 | content, err := os.ReadFile(badGatewayPage) 26 | if err != nil { 27 | slog.Debug("No custom 502 page found", "path", badGatewayPage) 28 | content = nil 29 | } 30 | 31 | return func(w http.ResponseWriter, r *http.Request, err error) { 32 | slog.Info("Unable to proxy request", "path", r.URL.Path, "error", err) 33 | 34 | if isRequestEntityTooLarge(err) { 35 | w.WriteHeader(http.StatusRequestEntityTooLarge) 36 | return 37 | } 38 | 39 | if content != nil { 40 | w.Header().Set("Content-Type",
"text/html") 41 | w.WriteHeader(http.StatusBadGateway) 42 | w.Write(content) 43 | } else { 44 | w.WriteHeader(http.StatusBadGateway) 45 | } 46 | } 47 | } 48 | 49 | func setXForwarded(r *httputil.ProxyRequest, forwardHeaders bool) { 50 | if forwardHeaders { 51 | r.Out.Header["X-Forwarded-For"] = r.In.Header["X-Forwarded-For"] 52 | } 53 | 54 | r.SetXForwarded() 55 | 56 | if forwardHeaders { 57 | // Preserve original headers if we had them 58 | if r.In.Header.Get("X-Forwarded-Host") != "" { 59 | r.Out.Header.Set("X-Forwarded-Host", r.In.Header.Get("X-Forwarded-Host")) 60 | } 61 | if r.In.Header.Get("X-Forwarded-Proto") != "" { 62 | r.Out.Header.Set("X-Forwarded-Proto", r.In.Header.Get("X-Forwarded-Proto")) 63 | } 64 | } 65 | } 66 | 67 | func isRequestEntityTooLarge(err error) bool { 68 | var maxBytesError *http.MaxBytesError 69 | return errors.As(err, &maxBytesError) 70 | } 71 | 72 | func createProxyTransport() *http.Transport { 73 | // The default transport requests compressed responses even if the client 74 | // didn't. If it receives a compressed response but the client wants 75 | // uncompressed, the transport decompresses the response transparently. 76 | // 77 | // Although that seems helpful, it doesn't play well with X-Sendfile 78 | // responses, as it may result in us being handed a reference to a file on 79 | // disk that is already compressed, and we'd have to similarly decompress it 80 | // before serving it to the client. This is wasteful, especially since there 81 | // was probably an uncompressed version of it on disk already. It's also a bit 82 | // fiddly to do on the fly without the ability to seek around in the 83 | // uncompressed content. 84 | // 85 | // Compression between us and the upstream server is likely to be of limited 86 | // use anyway, since we're only proxying from localhost. Given that fact --
87 | // and the fact that most clients *will* request compressed responses anyway, 88 | // which makes all of this moot -- our best option is to disable this 89 | // on-the-fly compression. 90 | 91 | transport := http.DefaultTransport.(*http.Transport).Clone() 92 | transport.DisableCompression = true 93 | 94 | return transport 95 | } 96 | -------------------------------------------------------------------------------- /internal/sendfile_handler.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "bufio" 5 | "errors" 6 | "log/slog" 7 | "net" 8 | "net/http" 9 | "os" 10 | "strconv" 11 | ) 12 | 13 | type SendfileHandler struct { 14 | enabled bool 15 | next http.Handler 16 | } 17 | 18 | func NewSendfileHandler(enabled bool, next http.Handler) *SendfileHandler { 19 | return &SendfileHandler{ 20 | enabled: enabled, 21 | next: next, 22 | } 23 | } 24 | 25 | func (h *SendfileHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 26 | if h.enabled { 27 | r.Header.Set("X-Sendfile-Type", "X-Sendfile") 28 | w = &sendfileWriter{w, r, false, false} 29 | } else { 30 | r.Header.Del("X-Sendfile-Type") 31 | } 32 | 33 | h.next.ServeHTTP(w, r) 34 | } 35 | 36 | type sendfileWriter struct { 37 | w http.ResponseWriter 38 | r *http.Request 39 | headerWritten bool 40 | sendingFile bool 41 | } 42 | 43 | func (w *sendfileWriter) Header() http.Header { 44 | return w.w.Header() 45 | } 46 | 47 | func (w *sendfileWriter) Write(b []byte) (int, error) { 48 | if !w.headerWritten { 49 | w.WriteHeader(http.StatusOK) 50 | } 51 | 52 | if w.sendingFile { 53 | return 0, http.ErrBodyNotAllowed 54 | } 55 | 56 | return w.w.Write(b) 57 | } 58 | 59 | func (w *sendfileWriter) WriteHeader(statusCode int) { 60 | filename := w.sendingFilename() 61 | w.w.Header().Del("X-Sendfile") 62 | 63 | w.sendingFile = filename != "" 64 | w.headerWritten = true 65 | 66 | if w.sendingFile { 67 | w.serveFile(filename) 68 | } else { 69 |
w.w.WriteHeader(statusCode) 70 | } 71 | } 72 | 73 | func (w *sendfileWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { 74 | hijacker, ok := w.w.(http.Hijacker) 75 | if !ok { 76 | return nil, nil, errors.New("ResponseWriter does not implement http.Hijacker") 77 | } 78 | 79 | return hijacker.Hijack() 80 | } 81 | 82 | func (w *sendfileWriter) sendingFilename() string { 83 | return w.w.Header().Get("X-Sendfile") 84 | } 85 | 86 | func (w *sendfileWriter) serveFile(filename string) { 87 | slog.Debug("X-Sendfile sending file", "path", filename) 88 | 89 | w.setContentLength(filename) 90 | http.ServeFile(w.w, w.r, filename) 91 | } 92 | 93 | func (w *sendfileWriter) setContentLength(filename string) { 94 | // In most cases, `http.ServeFile` will set this for us. However, it will not 95 | // set it if the response also has a `Content-Encoding`. 96 | // (https://github.com/golang/go/commit/fdc21f3eafe94490e55e0bf018490b3aa9ba2383) 97 | // 98 | // If we don't set (or at least clear) the header in that case, we'll pass 99 | // through the `Content-Length` of the upstream's response, which can lead to 100 | // us serving an incomplete response. 101 | // 102 | // In particular, this happens when Rails is serving a gzipped asset via 103 | // `X-Sendfile`, which it does using `Content-Encoding: gzip` and 104 | // `Content-Length: 0`. 105 |
106 | fi, err := os.Stat(filename) 107 | if err != nil { 108 | w.w.Header().Del("Content-Length") 109 | } else { 110 | w.w.Header().Set("Content-Length", strconv.FormatInt(fi.Size(), 10)) 111 | } 112 | } 113 | 114 | // Flush commits the response header first if needed (running the same 115 | // X-Sendfile check as Write), then flushes the wrapped writer when it supports 116 | // flushing. 117 | func (w *sendfileWriter) Flush() { 118 | if !w.headerWritten { 119 | w.WriteHeader(http.StatusOK) 120 | } 121 | 122 | _ = http.NewResponseController(w.w).Flush() 123 | } 124 | 125 | // Unwrap returns the wrapped writer so http.ResponseController can reach 126 | // optional methods that sendfileWriter does not implement itself. 127 | func (w *sendfileWriter) Unwrap() http.ResponseWriter { 128 | return w.w 129 | } 130 | -------------------------------------------------------------------------------- /internal/sendfile_handler_test.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "strconv" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func TestSendfileHandler(t *testing.T) { 13 | upstream := func(w http.ResponseWriter, r *http.Request) { 14 | assert.Equal(t, "X-Sendfile", r.Header.Get("X-Sendfile-Type")) 15 | 16 | w.Header().Set("X-Sendfile", fixturePath("image.jpg")) 17 | w.Write([]byte("This body should not be seen")) 18 | } 19 | 20 | h := NewSendfileHandler(true, http.HandlerFunc(upstream)) 21 | 22 | w := httptest.NewRecorder() 23 | r := httptest.NewRequest("GET", "/", nil) 24 | h.ServeHTTP(w, r) 25 | 26 | assert.Equal(t, http.StatusOK, w.Code) 27 | assert.Equal(t, "image/jpeg", w.Header().Get("Content-Type")) 28 | assert.Equal(t, strconv.FormatInt(fixtureLength("image.jpg"), 10), w.Header().Get("Content-Length")) 29 | assert.Equal(t, fixtureContent("image.jpg"), w.Body.Bytes()) 30 | } 31 | 32 | func TestSendfileHandler_sends_correct_content_length_when_content_encoding_present(t *testing.T) { 33 | upstream := func(w http.ResponseWriter, r *http.Request) { 34 | assert.Equal(t, "X-Sendfile", r.Header.Get("X-Sendfile-Type")) 35 | 36 | w.Header().Set("Content-Encoding", "gzip") 37 | w.Header().Set("Content-Length", "0") 38 | w.Header().Set("X-Sendfile", fixturePath("image.jpg")) 39 | w.WriteHeader(http.StatusOK) 40 | } 41 | 42 | h := NewSendfileHandler(true, http.HandlerFunc(upstream)) 43 | 44 | w := httptest.NewRecorder() 45 | r := httptest.NewRequest("GET", "/", nil) 46 | h.ServeHTTP(w, r) 47 | 48 | assert.Equal(t,
http.StatusOK, w.Code) 49 | assert.Equal(t, "image/jpeg", w.Header().Get("Content-Type")) 50 | assert.Equal(t, fixtureContent("image.jpg"), w.Body.Bytes()) 51 | assert.Equal(t, strconv.FormatInt(fixtureLength("image.jpg"), 10), w.Header().Get("Content-Length")) 52 | } 53 | 54 | func TestSendFileHandler_when_no_x_sendfile_present(t *testing.T) { 55 | upstream := func(w http.ResponseWriter, r *http.Request) { 56 | assert.Equal(t, "X-Sendfile", r.Header.Get("X-Sendfile-Type")) 57 | 58 | w.Header().Set("Content-Type", "application/custom") 59 | w.WriteHeader(http.StatusTeapot) 60 | w.Write([]byte("This body should be seen")) 61 | } 62 | 63 | h := NewSendfileHandler(true, http.HandlerFunc(upstream)) 64 | 65 | w := httptest.NewRecorder() 66 | r := httptest.NewRequest("GET", "/", nil) 67 | h.ServeHTTP(w, r) 68 | 69 | assert.Equal(t, http.StatusTeapot, w.Code) 70 | assert.Equal(t, "application/custom", w.Header().Get("Content-Type")) 71 | assert.Equal(t, "This body should be seen", w.Body.String()) 72 | } 73 | 74 | func TestSendFileHandler_when_not_enabled(t *testing.T) { 75 | upstream := func(w http.ResponseWriter, r *http.Request) { 76 | assert.Equal(t, "", r.Header.Get("X-Sendfile-Type")) 77 | 78 | w.Header().Set("Content-Type", "application/custom") 79 | w.Header().Set("X-Sendfile", fixturePath("image.jpg")) 80 | w.WriteHeader(http.StatusTeapot) 81 | w.Write([]byte("This body should be seen")) 82 | } 83 | 84 | h := NewSendfileHandler(false, http.HandlerFunc(upstream)) 85 | 86 | w := httptest.NewRecorder() 87 | r := httptest.NewRequest("GET", "/", nil) 88 | h.ServeHTTP(w, r) 89 | 90 | assert.Equal(t, http.StatusTeapot, w.Code) 91 | assert.Equal(t, "application/custom", w.Header().Get("Content-Type")) 92 | assert.Equal(t, "This body should be seen", w.Body.String()) 93 | } 94 | 95 | func TestSendfileHandler_flush_serves_file(t *testing.T) { 96 | upstream := func(w http.ResponseWriter, r *http.Request) { 97 | w.Header().Set("X-Sendfile", fixturePath("image.jpg")) 98 | assert.NoError(t, http.NewResponseController(w).Flush()) 99 | } 100 | 101 | h := NewSendfileHandler(true, http.HandlerFunc(upstream)) 102 | 103 | w := httptest.NewRecorder() 104 | r := httptest.NewRequest("GET", "/", nil) 105 | h.ServeHTTP(w, r) 106 | 107 | assert.Equal(t, http.StatusOK, w.Code) 108 | assert.True(t, w.Flushed) 109 | assert.Equal(t, fixtureContent("image.jpg"), w.Body.Bytes()) 110 | } 111 | -------------------------------------------------------------------------------- /internal/server.go: -------------------------------------------------------------------------------- 1 | package
internal 2 | 3 | import ( 4 | "context" 5 | "encoding/base64" 6 | "errors" 7 | "fmt" 8 | "log/slog" 9 | "net" 10 | "net/http" 11 | "time" 12 | 13 | "golang.org/x/crypto/acme" 14 | "golang.org/x/crypto/acme/autocert" 15 | ) 16 | 17 | type Server struct { 18 | config *Config 19 | handler http.Handler 20 | httpServer *http.Server 21 | httpsServer *http.Server 22 | } 23 | 24 | func NewServer(config *Config, handler http.Handler) *Server { 25 | return &Server{ 26 | handler: handler, 27 | config: config, 28 | } 29 | } 30 | 31 | // Start launches the listeners in background goroutines and returns 32 | // immediately. With TLS configured, HTTPS serves the handler while plain HTTP 33 | // goes through the autocert HTTP handler, falling back to a redirect to HTTPS. 34 | // Listener failures (for example, a port already in use) are logged. 35 | func (s *Server) Start() { 36 | httpAddress := fmt.Sprintf(":%d", s.config.HttpPort) 37 | httpsAddress := fmt.Sprintf(":%d", s.config.HttpsPort) 38 | 39 | if s.config.HasTLS() { 40 | manager := s.certManager() 41 | 42 | s.httpServer = s.defaultHttpServer(httpAddress) 43 | s.httpServer.Handler = manager.HTTPHandler(http.HandlerFunc(httpRedirectHandler)) 44 | 45 | s.httpsServer = s.defaultHttpServer(httpsAddress) 46 | s.httpsServer.TLSConfig = manager.TLSConfig() 47 | s.httpsServer.Handler = s.handler 48 | 49 | go logServeError("http", httpAddress, s.httpServer.ListenAndServe) 50 | go logServeError("https", httpsAddress, func() error { return s.httpsServer.ListenAndServeTLS("", "") }) 51 | 52 | slog.Info("Server started", "http", httpAddress, "https", httpsAddress, "tls_domain", s.config.TLSDomains) 53 | } else { 54 | s.httpsServer = nil 55 | s.httpServer = s.defaultHttpServer(httpAddress) 56 | s.httpServer.Handler = s.handler 57 | 58 | go logServeError("http", httpAddress, s.httpServer.ListenAndServe) 59 | 60 | slog.Info("Server started", "http", httpAddress) 61 | } 62 | } 63 | 64 | // Stop gracefully shuts down the servers created by Start. Both shutdowns 65 | // share a single 5 second deadline; shutdown errors (such as that deadline 66 | // expiring) are logged. 67 | func (s *Server) Stop() { 68 | ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) 69 | defer cancel() 70 | defer slog.Info("Server stopped") 71 | 72 | slog.Info("Server stopping") 73 | 74 | if err := s.httpServer.Shutdown(ctx); err != nil { 75 | slog.Error("Error stopping server", "server", "http", "error", err) 76 | } 77 | if s.httpsServer != nil { 78 | if err := s.httpsServer.Shutdown(ctx); err != nil { 79 | slog.Error("Error stopping server", "server", "https", "error", err) 80 | } 81 | } 82 | } 83 | 84 | func (s *Server) certManager() *autocert.Manager { 85 | client := &acme.Client{DirectoryURL: s.config.ACMEDirectoryURL} 86 | binding := s.externalAccountBinding() 87 | 88 | slog.Debug("TLS:
initializing", "directory", client.DirectoryURL, "using_eab", binding != nil) 89 | 90 | return &autocert.Manager{ 91 | Cache: autocert.DirCache(s.config.StoragePath), 92 | Client: client, 93 | ExternalAccountBinding: binding, 94 | HostPolicy: autocert.HostWhitelist(s.config.TLSDomains...), 95 | Prompt: autocert.AcceptTOS, 96 | } 97 | } 98 | 99 | func (s *Server) externalAccountBinding() *acme.ExternalAccountBinding { 100 | if s.config.EAB_KID == "" || s.config.EAB_HMACKey == "" { 101 | return nil 102 | } 103 | 104 | key, err := base64.RawURLEncoding.DecodeString(s.config.EAB_HMACKey) 105 | if err != nil { 106 | slog.Error("Error decoding EAB_HMACKey", "error", err) 107 | return nil 108 | } 109 | 110 | return &acme.ExternalAccountBinding{ 111 | KID: s.config.EAB_KID, 112 | Key: key, 113 | } 114 | } 115 | 116 | func (s *Server) defaultHttpServer(addr string) *http.Server { 117 | return &http.Server{ 118 | Addr: addr, 119 | IdleTimeout: s.config.HttpIdleTimeout, 120 | ReadTimeout: s.config.HttpReadTimeout, 121 | WriteTimeout: s.config.HttpWriteTimeout, 122 | } 123 | } 124 | 125 | // logServeError runs serve (a blocking ListenAndServe call) and logs the error 126 | // it returns. http.ErrServerClosed is the normal result of Stop and is ignored. 127 | func logServeError(name, addr string, serve func() error) { 128 | if err := serve(); err != nil && !errors.Is(err, http.ErrServerClosed) { 129 | slog.Error("Server failed", "server", name, "address", addr, "error", err) 130 | } 131 | } 132 | 133 | func httpRedirectHandler(w http.ResponseWriter, r *http.Request) { 134 | w.Header().Set("Connection", "close") 135 | 136 | host, _, err := net.SplitHostPort(r.Host) 137 | if err != nil { 138 | host = r.Host 139 | } 140 | 141 | url := "https://" + host + r.URL.RequestURI() 142 | http.Redirect(w, r, url, http.StatusMovedPermanently) 143 | } 144 | -------------------------------------------------------------------------------- /internal/service.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "fmt" 5 | "log/slog" 6 | "net/url" 7 | "os" 8 | ) 9 | 10 | type Service struct { 11 | config *Config 12 | } 13 | 14 | func NewService(config *Config) *Service { 15 | return &Service{ 16 | config: config, 17 | } 18 | } 19 | 20 | func (s *Service) Run() int { 21 | handlerOptions := HandlerOptions{ 22 | cache: s.cache(), 23 | targetUrl:
s.targetUrl(), 24 | xSendfileEnabled: s.config.XSendfileEnabled, 25 | gzipCompressionEnabled: s.config.GzipCompressionEnabled, 26 | maxCacheableResponseBody: s.config.MaxCacheItemSizeBytes, 27 | maxRequestBody: s.config.MaxRequestBody, 28 | badGatewayPage: s.config.BadGatewayPage, 29 | forwardHeaders: s.config.ForwardHeaders, 30 | } 31 | 32 | handler := NewHandler(handlerOptions) 33 | server := NewServer(s.config, handler) 34 | upstream := NewUpstreamProcess(s.config.UpstreamCommand, s.config.UpstreamArgs...) 35 | 36 | server.Start() 37 | defer server.Stop() 38 | 39 | s.setEnvironment() 40 | 41 | exitCode, err := upstream.Run() 42 | if err != nil { 43 | slog.Error("Failed to start wrapped process", "command", s.config.UpstreamCommand, "args", s.config.UpstreamArgs, "error", err) 44 | return 1 45 | } 46 | 47 | return exitCode 48 | } 49 | 50 | // Private 51 | 52 | func (s *Service) cache() Cache { 53 | return NewMemoryCache(s.config.CacheSizeBytes, s.config.MaxCacheItemSizeBytes) 54 | } 55 | 56 | func (s *Service) targetUrl() *url.URL { 57 | url, _ := url.Parse(fmt.Sprintf("http://localhost:%d", s.config.TargetPort)) 58 | return url 59 | } 60 | 61 | func (s *Service) setEnvironment() { 62 | // Set PORT to be inherited by the upstream process.
63 | os.Setenv("PORT", fmt.Sprintf("%d", s.config.TargetPort)) 64 | } 65 | -------------------------------------------------------------------------------- /internal/testing.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "os" 5 | "path" 6 | "testing" 7 | ) 8 | 9 | func fixturePath(name string) string { 10 | return path.Join("fixtures", name) 11 | } 12 | 13 | func fixtureContent(name string) []byte { 14 | result, _ := os.ReadFile(fixturePath(name)) 15 | return result 16 | } 17 | 18 | func fixtureLength(name string) int64 { 19 | info, _ := os.Stat(fixturePath(name)) 20 | return info.Size() 21 | } 22 | 23 | func usingEnvVar(t *testing.T, key, value string) { 24 | old, found := os.LookupEnv(key) 25 | os.Setenv(key, value) 26 | 27 | t.Cleanup(func() { 28 | if found { 29 | os.Setenv(key, old) 30 | } else { 31 | os.Unsetenv(key) 32 | } 33 | }) 34 | } 35 | 36 | func usingProgramArgs(t *testing.T, args ...string) { 37 | old := os.Args 38 | os.Args = args 39 | 40 | t.Cleanup(func() { 41 | os.Args = old 42 | }) 43 | } 44 | -------------------------------------------------------------------------------- /internal/upstream_process.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "errors" 5 | "log/slog" 6 | "os" 7 | "os/exec" 8 | "os/signal" 9 | "syscall" 10 | ) 11 | 12 | type UpstreamProcess struct { 13 | Started chan struct{} 14 | cmd *exec.Cmd 15 | } 16 | 17 | func NewUpstreamProcess(name string, arg ...string) *UpstreamProcess { 18 | return &UpstreamProcess{ 19 | Started: make(chan struct{}, 1), 20 | cmd: exec.Command(name, arg...), 21 | } 22 | } 23 | 24 | func (p *UpstreamProcess) Run() (int, error) { 25 | p.cmd.Stdin = os.Stdin 26 | p.cmd.Stdout = os.Stdout 27 | p.cmd.Stderr = os.Stderr 28 | 29 | err := p.cmd.Start() 30 | if err != nil { 31 | return 0, err 32 | } 33 | 34 | p.Started <- struct{}{} 35 | 36 | go 
p.handleSignals() 37 | err = p.cmd.Wait() 38 | 39 | var exitErr *exec.ExitError 40 | if errors.As(err, &exitErr) { 41 | return exitErr.ExitCode(), nil 42 | } 43 | 44 | return 0, err 45 | } 46 | 47 | func (p *UpstreamProcess) Signal(sig os.Signal) error { 48 | return p.cmd.Process.Signal(sig) 49 | } 50 | 51 | func (p *UpstreamProcess) handleSignals() { 52 | ch := make(chan os.Signal, 1) 53 | signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM) 54 | 55 | sig := <-ch 56 | slog.Info("Relaying signal to upstream process", "signal", sig.String()) 57 | p.Signal(sig) 58 | } 59 | -------------------------------------------------------------------------------- /internal/upstream_process_test.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "syscall" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestUpstreamProcess(t *testing.T) { 11 | t.Run("return exit code on exit", func(t *testing.T) { 12 | p := NewUpstreamProcess("false") 13 | exitCode, err := p.Run() 14 | 15 | assert.NoError(t, err) 16 | assert.Equal(t, 1, exitCode) 17 | }) 18 | 19 | t.Run("signal a process to stop it", func(t *testing.T) { 20 | var exitCode int 21 | var err error 22 | 23 | p := NewUpstreamProcess("sleep", "10") 24 | 25 | go func() { 26 | exitCode, err = p.Run() 27 | }() 28 | 29 | <-p.Started 30 | p.Signal(syscall.SIGTERM) 31 | 32 | assert.NoError(t, err) 33 | assert.Equal(t, 0, exitCode) 34 | }) 35 | } 36 | -------------------------------------------------------------------------------- /internal/variant.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "hash/fnv" 5 | "net/http" 6 | "slices" 7 | "strings" 8 | ) 9 | 10 | type Variant struct { 11 | r *http.Request 12 | headerNames []string 13 | } 14 | 15 | func NewVariant(r *http.Request) *Variant { 16 | return &Variant{r: r} 17 | } 18 | 19 | func (v *Variant) 
SetResponseHeader(header http.Header) { 20 | v.headerNames = v.parseVaryHeader(header) 21 | } 22 | 23 | func (v *Variant) CacheKey() CacheKey { 24 | hash := fnv.New64() 25 | hash.Write([]byte(v.r.Method)) 26 | hash.Write([]byte(v.r.URL.Path)) 27 | hash.Write([]byte(v.r.URL.Query().Encode())) 28 | 29 | for _, name := range v.headerNames { 30 | hash.Write([]byte(name + "=" + v.r.Header.Get(name))) 31 | } 32 | 33 | return CacheKey(hash.Sum64()) 34 | } 35 | 36 | func (v *Variant) Matches(responseHeader http.Header) bool { 37 | for _, name := range v.headerNames { 38 | if responseHeader.Get(name) != v.r.Header.Get(name) { 39 | return false 40 | } 41 | } 42 | return true 43 | } 44 | 45 | func (v *Variant) VariantHeader() http.Header { 46 | requestHeader := http.Header{} 47 | for _, name := range v.headerNames { 48 | requestHeader.Set(name, v.r.Header.Get(name)) 49 | } 50 | return requestHeader 51 | } 52 | 53 | // Private 54 | 55 | func (v *Variant) parseVaryHeader(responseHeader http.Header) []string { 56 | list := responseHeader.Get("Vary") 57 | if list == "" { 58 | return []string{} 59 | } 60 | 61 | names := strings.Split(list, ",") 62 | for i, name := range names { 63 | names[i] = http.CanonicalHeaderKey(strings.TrimSpace(name)) 64 | } 65 | slices.Sort(names) 66 | 67 | return names 68 | } 69 | -------------------------------------------------------------------------------- /internal/variant_test.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestVariantCacheKey(t *testing.T) { 12 | key1 := NewVariant(httptest.NewRequest("GET", "/home", nil)).CacheKey() 13 | key2 := NewVariant(httptest.NewRequest("GET", "/home", nil)).CacheKey() 14 | key3 := NewVariant(httptest.NewRequest("GET", "/home?a=b", nil)).CacheKey() 15 | key4 := NewVariant(httptest.NewRequest("POST", "/home?a=b", 
nil)).CacheKey() 16 | 17 | assert.Equal(t, key1, key2) 18 | assert.NotEqual(t, key1, key3) 19 | assert.NotEqual(t, key3, key4) 20 | } 21 | 22 | func TestVariantCacheKey_includes_variant_header_fields(t *testing.T) { 23 | r1 := httptest.NewRequest("GET", "/home", nil) 24 | r2 := httptest.NewRequest("GET", "/home", nil) 25 | r2.Header.Set("Accept-Encoding", "gzip") 26 | 27 | v1 := NewVariant(r1) 28 | v2 := NewVariant(r2) 29 | 30 | assert.Equal(t, v1.CacheKey(), v2.CacheKey()) 31 | 32 | v1.SetResponseHeader(http.Header{"Vary": []string{"Accept-Encoding"}}) 33 | v2.SetResponseHeader(http.Header{"Vary": []string{"Accept-Encoding"}}) 34 | 35 | assert.NotEqual(t, v1.CacheKey(), v2.CacheKey()) 36 | } 37 | 38 | func TestVariantMatches(t *testing.T) { 39 | r := httptest.NewRequest("GET", "/home", nil) 40 | r.Header.Set("Accept-Encoding", "gzip") 41 | 42 | v := NewVariant(r) 43 | v.SetResponseHeader(http.Header{"Vary": []string{"Accept-Encoding"}}) 44 | 45 | assert.True(t, v.Matches(http.Header{"Accept-Encoding": []string{"gzip"}, "Accept": []string{"text/plain"}})) 46 | assert.False(t, v.Matches(http.Header{"Accept-Encoding": []string{"deflate"}})) 47 | } 48 | 49 | func TestVariantMatches_multiple_headers(t *testing.T) { 50 | r := httptest.NewRequest("GET", "/home", nil) 51 | r.Header.Set("Accept-Encoding", "gzip") 52 | r.Header.Set("Accept", "text/plain") 53 | 54 | v := NewVariant(r) 55 | v.SetResponseHeader(http.Header{"Vary": []string{"Accept-Encoding, Accept"}}) 56 | 57 | assert.True(t, v.Matches(http.Header{"Accept-Encoding": []string{"gzip"}, "Accept": []string{"text/plain"}})) 58 | assert.False(t, v.Matches(http.Header{"Accept-Encoding": []string{"gzip"}, "Accept": []string{"text/html"}})) 59 | } 60 | 61 | func TestVariantMatches_missing_headers(t *testing.T) { 62 | r := httptest.NewRequest("GET", "/home", nil) 63 | r.Header.Set("Accept-Encoding", "gzip") 64 | 65 | v := NewVariant(r) 66 | v.SetResponseHeader(http.Header{"Vary": []string{"Accept-Encoding, Accept"}}) 67 
| 68 | assert.True(t, v.Matches(http.Header{"Accept-Encoding": []string{"gzip"}})) 69 | assert.False(t, v.Matches(http.Header{"Accept-Encoding": []string{"gzip"}, "Accept": []string{"text/html"}})) 70 | } 71 | -------------------------------------------------------------------------------- /lib/thruster.rb: -------------------------------------------------------------------------------- 1 | module Thruster 2 | end 3 | 4 | require_relative "thruster/version" 5 | -------------------------------------------------------------------------------- /lib/thruster/version.rb: -------------------------------------------------------------------------------- 1 | module Thruster 2 | VERSION = "0.1.13" 3 | end 4 | -------------------------------------------------------------------------------- /rakelib/package.rake: -------------------------------------------------------------------------------- 1 | require "rubygems/package_task" 2 | 3 | NATIVE_PLATFORMS = { 4 | "arm64-darwin" => "dist/thrust-darwin-arm64", 5 | "x86_64-darwin" => "dist/thrust-darwin-amd64", 6 | "x86_64-linux" => "dist/thrust-linux-amd64", 7 | "aarch64-linux" => "dist/thrust-linux-arm64", 8 | } 9 | 10 | BASE_GEMSPEC = Bundler.load_gemspec("thruster.gemspec") 11 | 12 | gem_path = Gem::PackageTask.new(BASE_GEMSPEC).define 13 | desc "Build the ruby gem" 14 | task "gem:ruby" => [ gem_path ] 15 | 16 | desc "Build native executables" 17 | namespace :build do 18 | task :native do 19 | system("make dist") 20 | end 21 | end 22 | task :gem => "build:native" 23 | 24 | NATIVE_PLATFORMS.each do |platform, executable| 25 | BASE_GEMSPEC.dup.tap do |gemspec| 26 | exedir = File.join(gemspec.bindir, platform) 27 | exepath = File.join(exedir, "thrust") 28 | 29 | gemspec.platform = platform 30 | gemspec.files << exepath 31 | 32 | gem_path = Gem::PackageTask.new(gemspec).define 33 | desc "Build the #{platform} gem" 34 | task "gem:#{platform}" => [gem_path] 35 | 36 | directory exedir 37 | file exepath => [ exedir ] do 38 | 
FileUtils.cp executable, exepath 39 | FileUtils.chmod(0755, exepath ) 40 | end 41 | 42 | CLOBBER.add(exedir) 43 | end 44 | end 45 | 46 | CLOBBER.add("dist") 47 | -------------------------------------------------------------------------------- /thruster.gemspec: -------------------------------------------------------------------------------- 1 | require_relative "lib/thruster/version" 2 | 3 | Gem::Specification.new do |s| 4 | s.name = "thruster" 5 | s.version = Thruster::VERSION 6 | s.summary = "Zero-config HTTP/2 proxy" 7 | s.description = "A zero-config HTTP/2 proxy for lightweight production deployments" 8 | s.authors = [ "Kevin McConnell" ] 9 | s.email = "kevin@37signals.com" 10 | s.homepage = "https://github.com/basecamp/thruster" 11 | s.license = "MIT" 12 | 13 | s.metadata = { 14 | "homepage_uri" => s.homepage, 15 | "rubygems_mfa_required" => "true" 16 | } 17 | 18 | s.files = Dir[ "{lib}/**/*", "MIT-LICENSE", "README.md" ] 19 | s.bindir = "exe" 20 | s.executables << "thrust" 21 | end 22 | --------------------------------------------------------------------------------