├── .github ├── CODEOWNERS ├── workflows │ ├── generate-readme.yml │ ├── renovate.yml │ ├── test.yml │ ├── generate-code.yml │ └── tag-semver.yml ├── ISSUE_TEMPLATE │ ├── config.yml │ ├── bug_report.yml │ └── feature_request.yml ├── PULL_REQUEST_TEMPLATE.md ├── SUPPORT.md ├── SECURITY.md ├── CONTRIBUTING.md └── CODE_OF_CONDUCT.md ├── .gitignore ├── generate.go ├── Makefile ├── TODO.md ├── codegen ├── realip_cloudflare.tmpl └── main.go ├── logical.go ├── headers.go ├── context.go ├── .editorconfig ├── debug_test.go ├── LICENSE ├── debug.go ├── headers_test.go ├── logical_test.go ├── realip_cloudflare.go ├── render.go ├── go.mod ├── recoverer_test.go ├── recoverer.go ├── prometheus.go ├── api.go ├── render_test.go ├── logger.go ├── redirect.go ├── api_test.go ├── bind.go ├── security.go ├── server.go ├── goembed.go ├── errors.go ├── README.md ├── realip.go ├── .golangci.yaml ├── realip_test.go ├── go.sum └── auth.go /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | # THIS FILE IS GENERATED! DO NOT EDIT! Maintained by Terraform. 2 | .github/* @lrstanley 3 | LICENSE @lrstanley 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.log 2 | ~* 3 | *.tmp 4 | tmp 5 | *.txt 6 | dist 7 | *.test 8 | *.prof 9 | *.conf 10 | !.golangci.yml 11 | *.yml 12 | *.yaml 13 | !/.github/** 14 | vendor/* 15 | bin 16 | **/.env 17 | -------------------------------------------------------------------------------- /generate.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) Liam Stanley . All rights reserved. Use of 2 | // this source code is governed by the MIT license that can be found in 3 | // the LICENSE file. 
4 | 5 | package chix 6 | 7 | //go:generate go run ./codegen/main.go 8 | -------------------------------------------------------------------------------- /.github/workflows/generate-readme.yml: -------------------------------------------------------------------------------- 1 | name: generate-readme 2 | 3 | on: 4 | push: 5 | branches: [master] 6 | tags: [v*] 7 | schedule: 8 | - cron: "0 13 * * *" 9 | 10 | jobs: 11 | generate: 12 | uses: lrstanley/.github/.github/workflows/generate-readme.yml@master 13 | secrets: inherit 14 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .DEFAULT_GOAL := generate 2 | 3 | license: 4 | curl -sL https://liam.sh/-/gh/g/license-header.sh | bash -s 5 | 6 | up: 7 | go get -u ./... && go mod tidy 8 | go get -u -t ./... && go mod tidy 9 | 10 | generate: license 11 | go generate -x ./... 12 | go test -v ./... 13 | 14 | test: 15 | gofmt -e -s -w . 16 | go vet . 17 | go test -v ./... 18 | -------------------------------------------------------------------------------- /TODO.md: -------------------------------------------------------------------------------- 1 | ## TODO 2 | 3 | - use https://go-chi.io/#/pages/testing 4 | - https://go-chi.io/#/pages/middleware?id=cors 5 | - https://go-chi.io/#/pages/middleware?id=jwt-authentication 6 | - https://go-chi.io/#/pages/middleware?id=http-rate-limiting-middleware 7 | - Add additional recommended middleware to the readme, as a helpful starter 8 | template? 9 | - register a bunch of default middleware via a helper method? 10 | -------------------------------------------------------------------------------- /codegen/realip_cloudflare.tmpl: -------------------------------------------------------------------------------- 1 | // Copyright (c) Liam Stanley . All rights reserved. 
Use of 2 | // this source code is governed by the MIT license that can be found in 3 | // the LICENSE file. 4 | // 5 | // THIS FILE IS AUTO-GENERATED. DO NOT EDIT. 6 | 7 | package chix 8 | 9 | import ( 10 | "net" 11 | 12 | "github.com/lrstanley/go-bogon" 13 | ) 14 | 15 | // cloudflareRanges returns the list of Cloudflare IP ranges. 16 | func cloudflareRanges() []*net.IPNet { 17 | return []*net.IPNet{ 18 | {{- range . }} 19 | bogon.MustCIDR("{{ . }}"), 20 | {{- end }} 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /.github/workflows/renovate.yml: -------------------------------------------------------------------------------- 1 | name: renovate 2 | 3 | on: 4 | workflow_dispatch: 5 | inputs: 6 | force-run: 7 | description: >- 8 | Force a run regardless of the schedule configuration. 9 | required: false 10 | default: false 11 | type: boolean 12 | push: 13 | branches: [master] 14 | schedule: 15 | - cron: "0 5 1,15 * *" 16 | 17 | jobs: 18 | renovate: 19 | uses: lrstanley/.github/.github/workflows/renovate.yml@master 20 | secrets: inherit 21 | with: 22 | force-run: ${{ inputs.force-run == true || github.event_name == 'schedule' }} 23 | -------------------------------------------------------------------------------- /logical.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) Liam Stanley . All rights reserved. Use of 2 | // this source code is governed by the MIT license that can be found in 3 | // the LICENSE file. 4 | 5 | package chix 6 | 7 | import "net/http" 8 | 9 | // UseIf is a conditional middleware that only uses the provided middleware if 10 | // the condition is true, otherwise continues as normal. 
11 | func UseIf(cond bool, handler func(http.Handler) http.Handler) func(http.Handler) http.Handler { 12 | return func(next http.Handler) http.Handler { 13 | if !cond { 14 | return next 15 | } 16 | 17 | return handler(next) 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: test 2 | 3 | on: 4 | pull_request: 5 | branches: [master] 6 | paths-ignore: [".gitignore", "**/*.md", ".github/ISSUE_TEMPLATE/**"] 7 | types: [opened, edited, reopened, synchronize, unlocked] 8 | push: 9 | branches: [master] 10 | paths-ignore: [".gitignore", "**/*.md", ".github/ISSUE_TEMPLATE/**"] 11 | 12 | jobs: 13 | go-test: 14 | uses: lrstanley/.github/.github/workflows/lang-go-test-matrix.yml@master 15 | secrets: inherit 16 | with: { num-minor: 1, num-patch: 2 } 17 | go-lint: 18 | uses: lrstanley/.github/.github/workflows/lang-go-lint.yml@master 19 | secrets: inherit 20 | -------------------------------------------------------------------------------- /.github/workflows/generate-code.yml: -------------------------------------------------------------------------------- 1 | name: generate-code 2 | 3 | on: 4 | push: 5 | branches: [master] 6 | schedule: 7 | - cron: "0 10 * * *" 8 | 9 | jobs: 10 | generate: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: install-go 14 | uses: actions/setup-go@v5 15 | with: 16 | go-version: latest 17 | - uses: actions/checkout@v4 18 | - run: make generate 19 | - uses: EndBug/add-and-commit@a604fba70a846a0ea59e6040ef8a4a4f95015772 20 | with: 21 | message: "chore(gen): codegen changes" 22 | commit: "--signoff" 23 | pathspec_error_handling: "ignore" 24 | add: "*.go" 25 | -------------------------------------------------------------------------------- /headers.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) Liam Stanley . 
All rights reserved. Use of 2 | // this source code is governed by the MIT license that can be found in 3 | // the LICENSE file. 4 | 5 | package chix 6 | 7 | import "net/http" 8 | 9 | // UseHeaders is a convenience handler to set multiple response header key/value 10 | // pairs. Similar to go-chi's SetHeader, but allows for multiple headers to be set 11 | // at once. 12 | func UseHeaders(headers map[string]string) func(next http.Handler) http.Handler { 13 | return func(next http.Handler) http.Handler { 14 | fn := func(w http.ResponseWriter, r *http.Request) { 15 | for k, v := range headers { 16 | w.Header().Set(k, v) 17 | } 18 | 19 | next.ServeHTTP(w, r) 20 | } 21 | return http.HandlerFunc(fn) 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /context.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) Liam Stanley . All rights reserved. Use of 2 | // this source code is governed by the MIT license that can be found in 3 | // the LICENSE file. 4 | 5 | package chix 6 | 7 | // contextKey is a type that prevents key collisions in contexts. Because the 8 | // type is distinct, even if the key name is the same, it will never overlap with 9 | // another package's keys. 10 | type contextKey string 11 | 12 | const ( 13 | contextDebug contextKey = "debug" 14 | contextAuth contextKey = "auth" 15 | contextAuthID contextKey = "auth_id" 16 | contextAuthRoles contextKey = "auth_roles" 17 | contextNextURL contextKey = "next_url" 18 | contextSkipNextURL contextKey = "skip_next_url" 19 | contextIP contextKey = "ip" 20 | 21 | authSessionKey = "_auth" 22 | nextSessionKey = "_next" 23 | ) 24 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | # THIS FILE IS GENERATED! DO NOT EDIT! Maintained by Terraform.
2 | blank_issues_enabled: false 3 | contact_links: 4 | - name: "🙋‍♂️ Ask the community a question!" 5 | about: Have a question, that might not be a bug? Wondering how to solve a problem? Ask away! 6 | url: "https://github.com/lrstanley/chix/discussions/new?category=q-a" 7 | - name: "🎉 Show us what you've made!" 8 | about: Have you built something using chix, and want to show others? Post here! 9 | url: "https://github.com/lrstanley/chix/discussions/new?category=show-and-tell" 10 | - name: "✋ Additional support information" 11 | about: Looking for something else? Check here. 12 | url: "https://github.com/lrstanley/chix/blob/master/.github/SUPPORT.md" 13 | - name: "💬 Discord chat" 14 | about: On-topic and off-topic discussions. 15 | url: "https://liam.sh/chat" 16 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | # THIS FILE IS GENERATED! DO NOT EDIT! Maintained by Terraform. 
2 | # 3 | # editorconfig: https://editorconfig.org/ 4 | # actual source: https://github.com/lrstanley/.github/blob/master/terraform/github-common-files/templates/.editorconfig 5 | # 6 | 7 | root = true 8 | 9 | [*] 10 | charset = utf-8 11 | end_of_line = lf 12 | indent_size = 4 13 | indent_style = space 14 | insert_final_newline = true 15 | trim_trailing_whitespace = true 16 | max_line_length = 100 17 | 18 | [*.tf] 19 | indent_size = 2 20 | 21 | [*.go] 22 | indent_style = tab 23 | indent_size = 4 24 | 25 | [*.md] 26 | trim_trailing_whitespace = false 27 | 28 | [*.{md,py,sh,yml,yaml,cjs,js,ts,vue,css}] 29 | max_line_length = 105 30 | 31 | [*.{yml,yaml,toml}] 32 | indent_size = 2 33 | 34 | [*.json] 35 | indent_size = 2 36 | 37 | [*.html] 38 | max_line_length = 140 39 | indent_size = 2 40 | 41 | [*.{cjs,js,ts,vue,css}] 42 | indent_size = 2 43 | 44 | [Makefile] 45 | indent_style = tab 46 | 47 | [**.min.js] 48 | indent_style = ignore 49 | 50 | [*.bat] 51 | indent_style = tab 52 | -------------------------------------------------------------------------------- /debug_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) Liam Stanley . All rights reserved. Use of 2 | // this source code is governed by the MIT license that can be found in 3 | // the LICENSE file. 
4 | 5 | package chix 6 | 7 | import ( 8 | "net/http" 9 | "net/http/httptest" 10 | "testing" 11 | ) 12 | 13 | func TestUseDebug(t *testing.T) { 14 | tests := []struct { 15 | name string 16 | debug bool 17 | }{ 18 | {name: "debug", debug: true}, 19 | {name: "not debug", debug: false}, 20 | } 21 | for _, tt := range tests { 22 | t.Run(tt.name, func(t *testing.T) { 23 | req := httptest.NewRequest(http.MethodGet, "http://example.com", http.NoBody) 24 | 25 | handler := UseDebug(tt.debug)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 26 | if IsDebug(r) != tt.debug { 27 | t.Errorf("IsDebug() = %v, want %v", IsDebug(r), tt.debug) 28 | } 29 | 30 | if IsDebugCtx(r.Context()) != tt.debug { 31 | t.Errorf("IsDebugCtx() = %v, want %v", IsDebugCtx(r.Context()), tt.debug) 32 | } 33 | })) 34 | 35 | handler.ServeHTTP(httptest.NewRecorder(), req) 36 | }) 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /.github/workflows/tag-semver.yml: -------------------------------------------------------------------------------- 1 | name: tag-semver 2 | 3 | on: 4 | workflow_dispatch: 5 | inputs: 6 | method: 7 | description: "Tagging method to use" 8 | required: true 9 | type: choice 10 | options: [major, minor, patch, alpha, rc, custom] 11 | custom: 12 | description: "Custom tag, if the default doesn't suffice. Must also use method 'custom'." 13 | required: false 14 | type: string 15 | ref: 16 | description: "Git ref to apply tag to (will use default branch if unspecified)." 17 | required: false 18 | type: string 19 | annotation: 20 | description: "Optional annotation to add to the commit." 
21 | required: false 22 | type: string 23 | 24 | jobs: 25 | tag-semver: 26 | uses: lrstanley/.github/.github/workflows/tag-semver.yml@master 27 | secrets: inherit 28 | with: 29 | method: ${{ github.event.inputs.method }} 30 | ref: ${{ github.event.inputs.ref }} 31 | custom: ${{ github.event.inputs.custom }} 32 | annotation: ${{ github.event.inputs.annotation }} 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Liam Stanley 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /debug.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) Liam Stanley . All rights reserved. 
Use of 2 | // this source code is governed by the MIT license that can be found in 3 | // the LICENSE file. 4 | 5 | package chix 6 | 7 | import ( 8 | "context" 9 | "net/http" 10 | ) 11 | 12 | // UseDebug is a middleware that allows passing if debugging is enabled for the 13 | // http server. Use IsDebug to check if debugging is enabled. 14 | func UseDebug(debug bool) func(next http.Handler) http.Handler { 15 | return func(next http.Handler) http.Handler { 16 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 17 | next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), contextDebug, debug))) 18 | }) 19 | } 20 | } 21 | 22 | // IsDebug returns true if debugging for the server is enabled (gets the 23 | // context from the request). 24 | func IsDebug(r *http.Request) bool { 25 | return IsDebugCtx(r.Context()) 26 | } 27 | 28 | // IsDebugCtx returns true if debugging for the server is enabled. 29 | func IsDebugCtx(ctx context.Context) bool { 30 | // If it's not there, return false anyway. 31 | debug, _ := ctx.Value(contextDebug).(bool) 32 | return debug 33 | } 34 | -------------------------------------------------------------------------------- /headers_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) Liam Stanley . All rights reserved. Use of 2 | // this source code is governed by the MIT license that can be found in 3 | // the LICENSE file. 
4 | 5 | package chix 6 | 7 | import ( 8 | "net/http" 9 | "net/http/httptest" 10 | "testing" 11 | ) 12 | 13 | func TestUseHeaders(t *testing.T) { 14 | tests := []struct { 15 | name string 16 | headers map[string]string 17 | }{ 18 | {name: "empty", headers: map[string]string{}}, 19 | {name: "single", headers: map[string]string{"Some-Header": "foo"}}, 20 | {name: "multiple", headers: map[string]string{"Some-Header": "foo", "Another-Header": "bar"}}, 21 | } 22 | for _, tt := range tests { 23 | t.Run(tt.name, func(t *testing.T) { 24 | req := httptest.NewRequest(http.MethodGet, "http://example.com", http.NoBody) 25 | 26 | handler := UseHeaders(tt.headers)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 27 | for k, v := range tt.headers { 28 | if w.Header().Get(k) != v { 29 | t.Errorf("UseHeaders() = %v, want %v", w.Header().Get(k), v) 30 | } 31 | } 32 | })) 33 | 34 | handler.ServeHTTP(httptest.NewRecorder(), req) 35 | }) 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /logical_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) Liam Stanley . All rights reserved. Use of 2 | // this source code is governed by the MIT license that can be found in 3 | // the LICENSE file. 
4 | 5 | package chix 6 | 7 | import ( 8 | "net/http" 9 | "net/http/httptest" 10 | "testing" 11 | ) 12 | 13 | func TestUseIf(t *testing.T) { 14 | tests := []struct { 15 | name string 16 | cond bool 17 | }{ 18 | {name: "true", cond: true}, 19 | {name: "false", cond: false}, 20 | } 21 | for _, tt := range tests { 22 | t.Run(tt.name, func(t *testing.T) { 23 | req := httptest.NewRequest(http.MethodGet, "http://example.com", http.NoBody) 24 | 25 | called := false 26 | setCalledHandler := func(next http.Handler) http.Handler { 27 | called = true 28 | 29 | return next 30 | } 31 | 32 | handler := UseIf(tt.cond, setCalledHandler)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) 33 | handler.ServeHTTP(httptest.NewRecorder(), req) 34 | 35 | if tt.cond && !called { 36 | t.Errorf("UseIf() = %v, want %v", called, tt.cond) 37 | } 38 | 39 | if !tt.cond && called { 40 | t.Errorf("UseIf() = %v, want %v", called, tt.cond) 41 | } 42 | }) 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /realip_cloudflare.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) Liam Stanley . All rights reserved. Use of 2 | // this source code is governed by the MIT license that can be found in 3 | // the LICENSE file. 4 | // 5 | // THIS FILE IS AUTO-GENERATED. DO NOT EDIT. 6 | 7 | package chix 8 | 9 | import ( 10 | "net" 11 | 12 | "github.com/lrstanley/go-bogon" 13 | ) 14 | 15 | // cloudflareRanges returns the list of Cloudflare IP ranges. 
16 | func cloudflareRanges() []*net.IPNet { 17 | return []*net.IPNet{ 18 | bogon.MustCIDR("173.245.48.0/20"), 19 | bogon.MustCIDR("103.21.244.0/22"), 20 | bogon.MustCIDR("103.22.200.0/22"), 21 | bogon.MustCIDR("103.31.4.0/22"), 22 | bogon.MustCIDR("141.101.64.0/18"), 23 | bogon.MustCIDR("108.162.192.0/18"), 24 | bogon.MustCIDR("190.93.240.0/20"), 25 | bogon.MustCIDR("188.114.96.0/20"), 26 | bogon.MustCIDR("197.234.240.0/22"), 27 | bogon.MustCIDR("198.41.128.0/17"), 28 | bogon.MustCIDR("162.158.0.0/15"), 29 | bogon.MustCIDR("104.16.0.0/13"), 30 | bogon.MustCIDR("104.24.0.0/14"), 31 | bogon.MustCIDR("172.64.0.0/13"), 32 | bogon.MustCIDR("131.0.72.0/22"), 33 | bogon.MustCIDR("2400:cb00::/32"), 34 | bogon.MustCIDR("2606:4700::/32"), 35 | bogon.MustCIDR("2803:f800::/32"), 36 | bogon.MustCIDR("2405:b500::/32"), 37 | bogon.MustCIDR("2405:8100::/32"), 38 | bogon.MustCIDR("2a06:98c0::/29"), 39 | bogon.MustCIDR("2c0f:f248::/32"), 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /render.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) Liam Stanley . All rights reserved. Use of 2 | // this source code is governed by the MIT license that can be found in 3 | // the LICENSE file. 4 | 5 | package chix 6 | 7 | import ( 8 | "bytes" 9 | "encoding/json" 10 | "net/http" 11 | "strconv" 12 | 13 | "github.com/apex/log" 14 | ) 15 | 16 | // M is a convenience alias for quickly building a map structure that is going 17 | // out to a responder. Just a short-hand. 18 | type M map[string]any 19 | 20 | // Fields satisfies the log.Fielder interface. 21 | func (m M) Fields() (f log.Fields) { 22 | if m == nil { 23 | return nil 24 | } 25 | 26 | f = make(log.Fields) 27 | for k, v := range m { 28 | f[k] = v 29 | } 30 | 31 | return f 32 | } 33 | 34 | // JSON marshals 'v' to JSON and sets the Content-Type to application/json. 35 | // Note that HTML-sensitive characters in strings ARE escaped (encoding/json's default).
36 | // 37 | // JSON also supports prettification when the origin request has "?pretty=true" 38 | // or similar. 39 | func JSON(w http.ResponseWriter, r *http.Request, status int, v any) { 40 | buf := &bytes.Buffer{} 41 | enc := json.NewEncoder(buf) 42 | 43 | if pretty, _ := strconv.ParseBool(r.FormValue("pretty")); pretty { 44 | enc.SetIndent("", " ") 45 | } 46 | 47 | if err := enc.Encode(v); err != nil { 48 | panic(err) 49 | } 50 | 51 | w.Header().Set("Content-Type", "application/json") 52 | w.WriteHeader(status) 53 | _, _ = w.Write(buf.Bytes()) 54 | } 55 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/lrstanley/chix 2 | 3 | go 1.23.0 4 | 5 | toolchain go1.24.2 6 | 7 | require ( 8 | github.com/apex/log v1.9.0 9 | github.com/go-chi/chi/v5 v5.2.2 10 | github.com/go-playground/form/v4 v4.2.1 11 | github.com/go-playground/validator/v10 v10.26.0 12 | github.com/gorilla/sessions v1.4.0 13 | github.com/lrstanley/go-bogon v1.0.0 14 | github.com/markbates/goth v1.81.0 15 | github.com/prometheus/client_golang v1.22.0 16 | golang.org/x/sync v0.15.0 17 | ) 18 | 19 | require ( 20 | github.com/beorn7/perks v1.0.1 // indirect 21 | github.com/cespare/xxhash/v2 v2.3.0 // indirect 22 | github.com/gabriel-vasile/mimetype v1.4.9 // indirect 23 | github.com/go-playground/locales v0.14.1 // indirect 24 | github.com/go-playground/universal-translator v0.18.1 // indirect 25 | github.com/gorilla/mux v1.8.1 // indirect 26 | github.com/gorilla/securecookie v1.1.2 // indirect 27 | github.com/leodido/go-urn v1.4.0 // indirect 28 | github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect 29 | github.com/pkg/errors v0.9.1 // indirect 30 | github.com/prometheus/client_model v0.6.2 // indirect 31 | github.com/prometheus/common v0.65.0 // indirect 32 | github.com/prometheus/procfs v0.16.1 // indirect 33 | golang.org/x/crypto v0.39.0 
// indirect 34 | golang.org/x/net v0.41.0 // indirect 35 | golang.org/x/oauth2 v0.30.0 // indirect 36 | golang.org/x/sys v0.33.0 // indirect 37 | golang.org/x/text v0.26.0 // indirect 38 | google.golang.org/protobuf v1.36.6 // indirect 39 | ) 40 | -------------------------------------------------------------------------------- /recoverer_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) Liam Stanley . All rights reserved. Use of 2 | // this source code is governed by the MIT license that can be found in 3 | // the LICENSE file. 4 | 5 | package chix 6 | 7 | import ( 8 | "errors" 9 | "net/http" 10 | "net/http/httptest" 11 | "testing" 12 | ) 13 | 14 | func TestUseRecoverer(t *testing.T) { 15 | req := httptest.NewRequest(http.MethodGet, "http://example.com", http.NoBody) 16 | 17 | handler := UseRecoverer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 18 | panic("testing panic") 19 | })) 20 | 21 | rec := httptest.NewRecorder() 22 | handler.ServeHTTP(rec, req) 23 | 24 | if rec.Result().StatusCode != http.StatusInternalServerError { 25 | t.Fatalf("expected status code %d, got %d", http.StatusInternalServerError, rec.Result().StatusCode) 26 | } 27 | } 28 | 29 | func TestUseRecovererAbort(t *testing.T) { 30 | defer func() { 31 | if rvr := recover(); rvr != nil { 32 | if e, ok := rvr.(error); ok && errors.Is(e, http.ErrAbortHandler) { 33 | return 34 | } 35 | t.Fatalf("expected panic of type http.ErrAbortHandler, got %v", rvr) 36 | } else { 37 | t.Fatalf("expected panic of type http.ErrAbortHandler, got nil") 38 | } 39 | }() 40 | 41 | req := httptest.NewRequest(http.MethodGet, "http://example.com", http.NoBody) 42 | 43 | handler := UseRecoverer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 44 | panic(http.ErrAbortHandler) 45 | })) 46 | 47 | handler.ServeHTTP(httptest.NewRecorder(), req) 48 | } 49 | -------------------------------------------------------------------------------- 
/recoverer.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) Liam Stanley . All rights reserved. Use of 2 | // this source code is governed by the MIT license that can be found in 3 | // the LICENSE file. 4 | 5 | package chix 6 | 7 | import ( 8 | "errors" 9 | "fmt" 10 | "net/http" 11 | "runtime/debug" 12 | 13 | "github.com/go-chi/chi/v5/middleware" 14 | ) 15 | 16 | // Deprecated: Recoverer is deprecated, and will be removed in a future release. 17 | // Please use UseRecoverer instead. 18 | func Recoverer(next http.Handler) http.Handler { 19 | return UseRecoverer(next) 20 | } 21 | 22 | // UseRecoverer is a middleware that recovers from panics, and returns a chix.Error 23 | // with HTTP 500 status (Internal Server Error) if possible. If debug is enabled, 24 | // through UseDebug(), a stack trace will be printed to stderr, otherwise to 25 | // standard structured logging. 26 | // 27 | // NOTE: This middleware should be loaded after logging/request-id/use-debug, etc 28 | // middleware, but before the handlers that may panic. 29 | func UseRecoverer(next http.Handler) http.Handler { 30 | fn := func(w http.ResponseWriter, r *http.Request) { 31 | defer func() { 32 | if rvr := recover(); rvr != nil { 33 | if e, ok := rvr.(error); ok && errors.Is(e, http.ErrAbortHandler) { 34 | panic(rvr) 35 | } 36 | 37 | err := fmt.Errorf("panic recovered: %v", rvr) 38 | if IsDebug(r) { 39 | middleware.PrintPrettyStack(rvr) 40 | } else { 41 | Log(r).WithError(err).Error(string(debug.Stack())) 42 | } 43 | 44 | ErrorCode(w, r, http.StatusInternalServerError, err) 45 | } 46 | }() 47 | 48 | next.ServeHTTP(w, r) 49 | } 50 | 51 | return http.HandlerFunc(fn) 52 | } 53 | -------------------------------------------------------------------------------- /prometheus.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) Liam Stanley . All rights reserved. 
Use of 2 | // this source code is governed by the MIT license that can be found in 3 | // the LICENSE file. 4 | 5 | package chix 6 | 7 | import ( 8 | "net/http" 9 | "strconv" 10 | "time" 11 | 12 | "github.com/go-chi/chi/v5" 13 | "github.com/go-chi/chi/v5/middleware" 14 | "github.com/prometheus/client_golang/prometheus" 15 | "github.com/prometheus/client_golang/prometheus/promauto" 16 | ) 17 | 18 | var ( 19 | metricHTTPDuration = promauto.NewHistogramVec( 20 | prometheus.HistogramOpts{ 21 | Name: "http_duration_seconds", 22 | Help: "HTTP request latencies in seconds.", 23 | }, 24 | []string{"method", "path", "status"}, 25 | ) 26 | metricHTTPCount = promauto.NewCounterVec( 27 | prometheus.CounterOpts{ 28 | Name: "http_requests_total", 29 | Help: "Total number of HTTP requests made.", 30 | }, 31 | []string{"method", "path", "status"}, 32 | ) 33 | metricHTTPBytes = promauto.NewCounterVec( 34 | prometheus.CounterOpts{ 35 | Name: "http_response_bytes_total", 36 | Help: "Total number of bytes sent in response to HTTP requests.", 37 | }, 38 | []string{"method", "path", "status"}, 39 | ) 40 | ) 41 | 42 | func UsePrometheus(next http.Handler) http.Handler { 43 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 44 | wrappedWriter := middleware.NewWrapResponseWriter(w, r.ProtoMajor) 45 | 46 | start := time.Now() 47 | next.ServeHTTP(wrappedWriter, r) 48 | elapsed := time.Since(start) 49 | 50 | rctx := chi.RouteContext(r.Context()) 51 | 52 | labels := prometheus.Labels{ 53 | "method": r.Method, 54 | "path": rctx.RoutePattern(), 55 | "status": strconv.Itoa(wrappedWriter.Status()), 56 | } 57 | 58 | metricHTTPDuration.With(labels).Observe(elapsed.Seconds()) 59 | metricHTTPCount.With(labels).Inc() 60 | metricHTTPBytes.With(labels).Add(float64(wrappedWriter.BytesWritten())) 61 | }) 62 | } 63 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: 
-------------------------------------------------------------------------------- 1 | 7 | 8 | ## 🚀 Changes proposed by this PR 9 | 10 | 14 | 15 | 16 | ### 🔗 Related bug reports/feature requests 17 | 18 | 22 | - fixes #(issue) 23 | - closes #(issue) 24 | - relates to #(issue) 25 | - implements #(feature) 26 | 27 | ### 🧰 Type of change 28 | 29 | 30 | - [ ] Bug fix (non-breaking change which fixes an issue). 31 | - [ ] New feature (non-breaking change which adds functionality). 32 | - [ ] Breaking change (fix or feature that causes existing functionality to not work as expected). 33 | - [ ] This change requires (or is) a documentation update. 34 | 35 | ### 📝 Notes to reviewer 36 | 37 | 41 | 42 | ### 🤝 Requirements 43 | 44 | - [ ] ✍ I have read and agree to this project's [Code of Conduct](../../blob/master/.github/CODE_OF_CONDUCT.md). 45 | - [ ] ✍ I have read and agree to this project's [Contribution Guidelines](../../blob/master/.github/CONTRIBUTING.md). 46 | - [ ] ✍ I have read and agree to the [Developer Certificate of Origin](https://developercertificate.org/). 47 | - [ ] 🔎 I have performed a self-review of my own changes. 48 | - [ ] 🎨 My changes follow the style guidelines of this project. 49 | 50 | - [ ] 💬 My changes are properly commented, primarily for hard-to-understand areas. 51 | - [ ] 📝 I have made corresponding changes to the documentation. 52 | - [ ] 🧪 I have included tests (if necessary) for this change. 53 | -------------------------------------------------------------------------------- /.github/SUPPORT.md: -------------------------------------------------------------------------------- 1 | # :raising_hand_man: Support 2 | 3 | This document explains where and how to get help with most of my projects. 4 | Please ensure you read through it thoroughly. 5 | 6 | > :point_right: **Note**: before participating in the community, please read our 7 | > [Code of Conduct][coc].
8 | > By interacting with this repository, organization, or community you agree to 9 | > abide by its terms. 10 | 11 | ## :grey_question: Asking quality questions 12 | 13 | Questions can go to [Github Discussions][discussions] or feel free to join 14 | the Discord [here][chat]. 15 | 16 | Help me help you! Spend time framing questions and add links and resources. 17 | Spending the extra time up front can help save everyone time in the long run. 18 | Here are some tips: 19 | 20 | * Don't fall for the [XY problem][xy]. 21 | * Search to find out if a similar question has been asked or if a similar 22 | issue/bug has been reported. 23 | * Try to define what you need help with: 24 | * Is there something in particular you want to do? 25 | * What problem are you encountering and what steps have you taken to try 26 | and fix it? 27 | * Is there a concept you don't understand? 28 | * Provide sample code, such as a [CodeSandbox][cs] or a simple snippet, if 29 | possible. 30 | * Screenshots can help, but if there's important text such as code or error 31 | messages in them, please also provide those. 32 | * The more time you put into asking your question, the better I and others 33 | can help you. 34 | 35 | ## :old_key: Security 36 | 37 | For any security or vulnerability related disclosure, please follow the 38 | guidelines outlined in our [security policy][security]. 39 | 40 | ## :handshake: Contributions 41 | 42 | See [`CONTRIBUTING.md`][contributing] on how to contribute. 
43 | 44 | 45 | [coc]: https://github.com/lrstanley/chix/blob/master/.github/CODE_OF_CONDUCT.md 46 | [contributing]: https://github.com/lrstanley/chix/blob/master/.github/CONTRIBUTING.md 47 | [discussions]: https://github.com/lrstanley/chix/discussions/categories/q-a 48 | [issues]: https://github.com/lrstanley/chix/issues/new/choose 49 | [license]: https://github.com/lrstanley/chix/blob/master/LICENSE 50 | [pull-requests]: https://github.com/lrstanley/chix/issues/new/choose 51 | [security]: https://github.com/lrstanley/chix/security/policy 52 | [support]: https://github.com/lrstanley/chix/blob/master/.github/SUPPORT.md 53 | 54 | [xy]: https://meta.stackexchange.com/questions/66377/what-is-the-xy-problem/66378#66378 55 | [chat]: https://liam.sh/chat 56 | [cs]: https://codesandbox.io 57 | -------------------------------------------------------------------------------- /api.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) Liam Stanley . All rights reserved. Use of 2 | // this source code is governed by the MIT license that can be found in 3 | // the LICENSE file. 4 | 5 | package chix 6 | 7 | import ( 8 | "net/http" 9 | ) 10 | 11 | // Deprecated: APIVersionMatch is deprecated, and will be removed in a future release. 12 | // Please use UseAPIVersionMatch instead. 13 | func APIVersionMatch(version string) func(next http.Handler) http.Handler { 14 | return UseAPIVersionMatch(version) 15 | } 16 | 17 | // UseAPIVersionMatch is a middleware that checks if the request has the correct 18 | // API version provided in the DefaultAPIVersionHeader. 
19 | func UseAPIVersionMatch(version string) func(next http.Handler) http.Handler { 20 | return func(next http.Handler) http.Handler { 21 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 22 | clientVersion := r.Header.Get(DefaultAPIVersionHeader) 23 | if clientVersion == "" { 24 | _ = Error(w, r, WrapError(ErrAPIVersionMissing, http.StatusPreconditionFailed)) 25 | return 26 | } 27 | 28 | if clientVersion != version { 29 | _ = Error(w, r, WrapError(ErrAPIVersionMismatch, http.StatusPreconditionFailed)) 30 | return 31 | } 32 | 33 | next.ServeHTTP(w, r) 34 | }) 35 | } 36 | } 37 | 38 | var ( 39 | // DefaultAPIVersionHeader is the default header name for the API version. 40 | DefaultAPIVersionHeader = "X-Api-Version" 41 | 42 | // DefaultAPIKeyHeader is the default header where we should look for the 43 | // API key. 44 | DefaultAPIKeyHeader = "X-Api-Key" //nolint:gosec 45 | 46 | // DefaultAPIPrefix is the default prefix for your API. Set to an empty 47 | // string to disable checks that change depending on if the request has 48 | // the provided prefix. 49 | DefaultAPIPrefix = "/api/" 50 | ) 51 | 52 | // UseAPIKeyRequired is a middleware that checks if the request provides one of 53 | // the allowed API keys in the DefaultAPIKeyHeader header. Panics if no keys 54 | // are provided. Returns http.StatusUnauthorized if an invalid key is provided, 55 | // and http.StatusPreconditionFailed if no key header is provided.
56 | func UseAPIKeyRequired(keys []string) func(next http.Handler) http.Handler { 57 | if len(keys) == 0 { 58 | panic(ErrNoAPIKeys) 59 | } 60 | 61 | return func(next http.Handler) http.Handler { 62 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 63 | providedKey := r.Header.Get(DefaultAPIKeyHeader) 64 | 65 | for _, key := range keys { 66 | if providedKey == key { // NOTE(review): not constant-time; consider crypto/subtle.ConstantTimeCompare. 67 | next.ServeHTTP(w, r) 68 | return 69 | } 70 | } 71 | 72 | if providedKey == "" { 73 | _ = Error(w, r, WrapError(ErrAPIKeyMissing, http.StatusPreconditionFailed)) 74 | return 75 | } 76 | 77 | _ = Error(w, r, WrapError(ErrAPIKeyInvalid, http.StatusUnauthorized)) 78 | }) 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /codegen/main.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) Liam Stanley . All rights reserved. Use of 2 | // this source code is governed by the MIT license that can be found in 3 | // the LICENSE file.
4 | 5 | package main 6 | 7 | import ( 8 | "bufio" 9 | "bytes" 10 | "context" 11 | _ "embed" 12 | "errors" 13 | "fmt" 14 | "go/format" 15 | "net" 16 | "net/http" 17 | "os" 18 | "slices" 19 | "text/template" 20 | "time" 21 | ) 22 | 23 | const cloudflareURI = `https://www.cloudflare.com/ips-%s` 24 | 25 | //go:embed realip_cloudflare.tmpl 26 | var cloudflareTmpl string 27 | 28 | // writeTemplatedGoFile renders tmpl with data, formats the result with 29 | // go/format, and writes it to path (creating or truncating the file). 30 | func writeTemplatedGoFile(path string, tmpl *template.Template, data any) error { 31 | buf := &bytes.Buffer{} 32 | if err := tmpl.Execute(buf, data); err != nil { 33 | return errors.New("error executing template: " + err.Error()) 34 | } 35 | 36 | fmtd, err := format.Source(buf.Bytes()) 37 | if err != nil { 38 | return errors.New("error formatting source: " + err.Error()) 39 | } 40 | 41 | // Open (and truncate) the file only once the output is known to be valid. 42 | f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o644) 43 | if err != nil { 44 | return errors.New("error opening file: " + err.Error()) 45 | } 46 | 47 | _, err = f.Write(fmtd) 48 | return errors.Join(err, f.Close()) 49 | } 50 | 51 | func main() { 52 | var cidrs []*net.IPNet 53 | 54 | ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(time.Second*10)) 55 | defer cancelFn() 56 | 57 | for _, version := range []string{"v4", "v6"} { 58 | req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf(cloudflareURI, version), http.NoBody) 59 | if err != nil { 60 | panic(err) 61 | } 62 | 63 | resp, err := http.DefaultClient.Do(req) 64 | if err != nil { 65 | panic(err) 66 | } 67 | 68 | if resp.StatusCode != http.StatusOK { 69 | resp.Body.Close() 70 | panic(fmt.Errorf("unexpected status code: %d", resp.StatusCode)) 71 | } 72 | 73 | scan := bufio.NewScanner(resp.Body) 74 | 75 | var cidr *net.IPNet 76 | 77 | for scan.Scan() { 78 | _, cidr, err = net.ParseCIDR(scan.Text()) 79 | if err != nil { 80 | resp.Body.Close() 81 | panic(err) 82 | } 83 | 84 | // Make sure it doesn't already exist in the list.
85 | if slices.ContainsFunc(cidrs, func(other *net.IPNet) bool { 86 | return cidr.String() == other.String() 87 | }) { 88 | continue 89 | } 90 | 91 | cidrs = append(cidrs, cidr) 92 | fmt.Printf("found CIDR: %s\n", cidr) //nolint:forbidigo 93 | } 94 | _ = resp.Body.Close() 95 | 96 | // A read error (e.g. the deadline expiring mid-body) ends the loop above 97 | // just like EOF does; without this check the CIDR list would be truncated. 98 | if err = scan.Err(); err != nil { 99 | panic(err) 100 | } 101 | } 102 | 103 | if len(cidrs) < 10 { 104 | panic(fmt.Errorf("found %d CIDRs, but expected at least 10", len(cidrs))) 105 | } 106 | 107 | fmt.Printf("found %d CIDRs\n", len(cidrs)) //nolint:forbidigo 108 | 109 | err := writeTemplatedGoFile( 110 | "realip_cloudflare.go", 111 | template.Must(template.New(".").Parse(cloudflareTmpl)), 112 | cidrs, 113 | ) 114 | if err != nil { 115 | panic(err) 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /render_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) Liam Stanley . All rights reserved. Use of 2 | // this source code is governed by the MIT license that can be found in 3 | // the LICENSE file. 4 | 5 | package chix 6 | 7 | import ( 8 | "encoding/json" 9 | "io" 10 | "net/http" 11 | "net/http/httptest" 12 | "reflect" 13 | "testing" 14 | ) 15 | 16 | // testJSONMarshalEqual is a helper function to test if the resulting http body equals 17 | // the input data when marshalled and unmarshalled.
18 | func testJSONMarshalEqual[T any](t *testing.T, input T, body io.ReadCloser, shouldEqual bool) (ok bool) { 19 | var in, out any 20 | 21 | inBytes, err := json.Marshal(input) 22 | if err != nil { 23 | t.Errorf("error marshaling input data: %v", err) 24 | return false 25 | } 26 | 27 | if err = json.Unmarshal(inBytes, &in); err != nil { 28 | t.Errorf("error unmarshaling input data: %v", err) 29 | return false 30 | } 31 | 32 | dec := json.NewDecoder(body) 33 | if err = dec.Decode(&out); err != nil { 34 | t.Errorf("error decoding response body: %v", err) 35 | return false 36 | } 37 | 38 | if reflect.DeepEqual(in, out) && !shouldEqual { 39 | t.Errorf("expected %#v to not equal %#v", in, out) 40 | return false 41 | } else if !reflect.DeepEqual(in, out) && shouldEqual { 42 | t.Errorf("expected %#v to equal %#v", in, out) 43 | return false 44 | } 45 | 46 | return true 47 | } 48 | 49 | func TestJSON(t *testing.T) { 50 | tests := []struct { 51 | name string 52 | data any 53 | headers map[string]string 54 | statusCode int 55 | }{ 56 | { 57 | name: "empty", 58 | data: M{}, 59 | headers: map[string]string{"Content-Type": "application/json"}, 60 | statusCode: http.StatusOK, 61 | }, 62 | { 63 | name: "base-object", 64 | data: M{"foo": "bar"}, 65 | headers: map[string]string{"Content-Type": "application/json"}, 66 | statusCode: http.StatusOK, 67 | }, 68 | { 69 | name: "base-array", 70 | data: []string{"foo", "bar"}, 71 | headers: map[string]string{"Content-Type": "application/json"}, 72 | statusCode: http.StatusOK, 73 | }, 74 | } 75 | 76 | for _, tt := range tests { 77 | t.Run(tt.name, func(t *testing.T) { 78 | req := httptest.NewRequest(http.MethodGet, "http://example.com/?pretty=true", http.NoBody) 79 | 80 | handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 81 | JSON(w, r, tt.statusCode, tt.data) 82 | }) 83 | 84 | rec := httptest.NewRecorder() 85 | handler.ServeHTTP(rec, req) 86 | resp := rec.Result() 87 | 88 | if resp.StatusCode != tt.statusCode { 89 
| t.Errorf("expected status code %d, got %d", tt.statusCode, resp.StatusCode) 90 | } 91 | 92 | for k, v := range tt.headers { 93 | if resp.Header.Get(k) != v { 94 | t.Errorf("expected header %s to be %s, got %s", k, v, resp.Header.Get(k)) 95 | } 96 | } 97 | 98 | _ = testJSONMarshalEqual(t, tt.data, resp.Body, true) 99 | }) 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /.github/SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | # :old_key: Security Policy 3 | 4 | ## :heavy_check_mark: Supported Versions 5 | 6 | The following restrictions apply for versions that are still supported in terms of security and bug fixes: 7 | 8 | * :grey_question: Must be using the latest major/minor version. 9 | * :grey_question: Must be using a supported platform for the repository (e.g. OS, browser, etc), and that platform must 10 | be within its supported versions (for example: don't use a legacy or unsupported version of Ubuntu or 11 | Google Chrome). 12 | * :grey_question: Repository must not be archived (unless the vulnerability is critical, and the repository moderately 13 | popular). 14 | * :heavy_check_mark: 15 | 16 | If one of the above doesn't apply to you, feel free to submit an issue and we can discuss the 17 | issue/vulnerability further. 18 | 19 | 20 | ## :lady_beetle: Reporting a Vulnerability 21 | 22 | Best method of contact: [GPG :key:](https://github.com/lrstanley.gpg) 23 | 24 | * :speech_balloon: [Discord][chat]: message `lrstanley` (`/home/liam#0000`). 25 | * :email: Email: `security@liam.sh` 26 | 27 | Backup contacts (if I am unresponsive after **48h**): [GPG :key:](https://github.com/FM1337.gpg) 28 | * :speech_balloon: [Discord][chat]: message `Allen#7440`. 
29 | * :email: Email: `security@allenlydiard.ca` 30 | 31 | If you feel that this disclosure doesn't include a critical vulnerability and there is no sensitive 32 | information in the disclosure, you don't have to use the GPG key. For all other situations, please 33 | use it. 34 | 35 | ### :stopwatch: Vulnerability disclosure expectations 36 | 37 | * :no_bell: We expect you to not share this information with others, unless: 38 | * The maximum timeline for initial response has been exceeded (shown below). 39 | * The maximum resolution time has been exceeded (shown below). 40 | * :mag_right: We expect you to responsibly investigate this vulnerability -- please do not utilize the 41 | vulnerability beyond the initial findings. 42 | * :stopwatch: Initial response within 48h; however, if the primary contact shown above is unavailable, please 43 | use the backup contacts provided. The maximum timeline for an initial response should be within 44 | 7 days. 45 | * :stopwatch: Depending on the severity of the disclosure, resolution time may be anywhere from 24h to 2 46 | weeks after initial response, though in most cases it will likely be closer to the former. 47 | * If the vulnerability is very low/low in terms of risk, the above timelines **will not apply**. 48 | * :toolbox: Before the release of resolved versions, a [GitHub Security Advisory][advisory-docs] 49 | will be released on the respective repository. [Browse all advisories here][advisory]. 50 | 51 | 52 | [chat]: https://liam.sh/chat 53 | [advisory]: https://github.com/advisories?query=type%3Areviewed+ecosystem%3Ago 54 | [advisory-docs]: https://docs.github.com/en/code-security/repository-security-advisories/creating-a-repository-security-advisory 55 | -------------------------------------------------------------------------------- /logger.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) Liam Stanley . All rights reserved.
Use of 2 | // this source code is governed by the MIT license that can be found in 3 | // the LICENSE file. 4 | 5 | package chix 6 | 7 | import ( 8 | "net/http" 9 | "sync/atomic" 10 | "time" 11 | 12 | "github.com/apex/log" 13 | "github.com/go-chi/chi/v5/middleware" 14 | ) 15 | 16 | // LogHandler is a function type that can be used to add any additional 17 | // custom fields to a request log entry. 18 | type LogHandler func(r *http.Request) M 19 | 20 | var logHandlers atomic.Value // []LogHandler 21 | 22 | // AddLogHandler can be used to inject additional metadata/fields into the 23 | // log context (e.g. authentication information). Register handlers during 24 | // setup: concurrent AddLogHandler calls are not synchronized with each other. 25 | // 26 | // NOTE: the request context will only include entries that were registered 27 | // in the request context prior to the structured logger being loaded. 28 | func AddLogHandler(h LogHandler) { 29 | handlers, _ := logHandlers.Load().([]LogHandler) 30 | 31 | // Copy-on-write: capping the capacity forces append to allocate a new 32 | // backing array instead of writing into one shared with a stored version. 33 | handlers = append(handlers[:len(handlers):len(handlers)], h) 34 | 35 | logHandlers.Store(handlers) 36 | } 37 | 38 | // UseStructuredLogger wraps each request and writes a log entry with 39 | // extra info. UseStructuredLogger also injects a logger into the request 40 | // context that can be used by children middleware business logic. 41 | func UseStructuredLogger(logger log.Interface) func(next http.Handler) http.Handler { 42 | return func(next http.Handler) http.Handler { 43 | fn := func(w http.ResponseWriter, r *http.Request) { 44 | bfields := log.Fields{} 45 | bfields["src"] = "http" 46 | 47 | // RequestID middleware must be loaded before this is loaded into 48 | // the chain.
49 | if id := middleware.GetReqID(r.Context()); id != "" { 50 | bfields["rid"] = id 51 | } 52 | 53 | if ray := r.Header.Get("Cf-Ray"); ray != "" { 54 | bfields["ray_id"] = ray 55 | } 56 | 57 | if country := r.Header.Get("Cf-Ipcountry"); country != "" { 58 | bfields["country"] = country 59 | } 60 | 61 | wrappedWriter := middleware.NewWrapResponseWriter(w, r.ProtoMajor) 62 | 63 | bfields["ip"] = r.RemoteAddr 64 | bfields["host"] = r.Host 65 | bfields["proto"] = r.Proto 66 | bfields["method"] = r.Method 67 | bfields["ua"] = r.Header.Get("User-Agent") 68 | bfields["bytes_in"] = r.Header.Get("Content-Length") 69 | 70 | logEntry := logger.WithFields(bfields) 71 | start := time.Now() 72 | defer func() { 73 | finish := time.Since(start) 74 | 75 | // If log handlers were provided, and they returned a map, 76 | // then we'll use that to add additional fields to the log 77 | // context. 78 | if handlers, ok := logHandlers.Load().([]LogHandler); ok { 79 | var fields M 80 | for _, fn := range handlers { 81 | if fields = fn(r); fields != nil { 82 | logEntry = logEntry.WithFields(fields) 83 | } 84 | } 85 | } 86 | 87 | logEntry.WithFields(log.Fields{ 88 | "code": wrappedWriter.Status(), 89 | "duration_ms": finish.Milliseconds(), 90 | "bytes_out": wrappedWriter.BytesWritten(), 91 | }).Info(r.URL.RequestURI()) 92 | }() 93 | 94 | next.ServeHTTP(wrappedWriter, r.WithContext(log.NewContext(r.Context(), logEntry))) 95 | } 96 | 97 | return http.HandlerFunc(fn) 98 | } 99 | } 100 | 101 | // Log is a helper for obtaining the structured logger from the request 102 | // context. 103 | func Log(r *http.Request) log.Interface { 104 | return log.FromContext(r.Context()) 105 | } 106 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.yml: -------------------------------------------------------------------------------- 1 | # THIS FILE IS GENERATED! DO NOT EDIT! Maintained by Terraform. 
2 | name: "🐞 Submit a bug report" 3 | description: Create a report to help us improve! 4 | title: "bug: [REPLACE ME]" 5 | labels: 6 | - bug 7 | body: 8 | - type: markdown 9 | attributes: 10 | value: | 11 | ### Thanks for submitting a bug report to **chix**! 📋 12 | 13 | - 💬 Make sure to check out the [**discussions**](../discussions) section. If your issue isn't a bug (or you're not sure), and you're looking for help to solve it, please [start a discussion here](../discussions/new?category=q-a) first. 14 | - 🔎 Please [**search**](../labels/bug) to see if someone else has submitted a similar bug report, before making a new report. 15 | 16 | ---------------------------------------- 17 | - type: textarea 18 | id: description 19 | attributes: 20 | label: "🌧 Describe the problem" 21 | description: A clear and concise description of what the problem is. 22 | placeholder: 'Example: "When I attempted to do X, I got X error"' 23 | validations: 24 | required: true 25 | - type: textarea 26 | id: expected 27 | attributes: 28 | label: "⛅ Expected behavior" 29 | description: A clear and concise description of what you expected to happen. 30 | placeholder: 'Example: "I expected X to let me add Y component, and be successful"' 31 | validations: 32 | required: true 33 | - type: textarea 34 | id: reproduce 35 | attributes: 36 | label: "🔄 Minimal reproduction" 37 | description: >- 38 | Steps to reproduce the behavior (including code examples and/or 39 | configuration files if necessary) 40 | placeholder: >- 41 | Example: "1. Click on '....' | 2. Run command with flags --foo --bar, 42 | etc | 3. See error" 43 | - type: input 44 | id: version 45 | attributes: 46 | label: "💠 Version: chix" 47 | description: What version of chix is being used? 
48 | placeholder: 'Examples: "v1.2.3, master branch, commit 1a2b3c"' 49 | validations: 50 | required: true 51 | - type: dropdown 52 | id: os 53 | attributes: 54 | label: "🖥 Version: Operating system" 55 | description: >- 56 | What operating system did this issue occur on (if other, specify in 57 | "Additional context" section)? 58 | options: 59 | - linux/ubuntu 60 | - linux/debian 61 | - linux/centos 62 | - linux/alpine 63 | - linux/other 64 | - windows/10 65 | - windows/11 66 | - windows/other 67 | - macos 68 | - other 69 | validations: 70 | required: true 71 | - type: textarea 72 | id: context 73 | attributes: 74 | label: "⚙ Additional context" 75 | description: >- 76 | Add any other context about the problem here. This includes things 77 | like logs, screenshots, code examples, what was the state when the 78 | bug occurred? 79 | placeholder: > 80 | Examples: "logs, code snippets, screenshots, os/browser version info, 81 | etc" 82 | - type: checkboxes 83 | id: requirements 84 | attributes: 85 | label: "🤝 Requirements" 86 | description: "Please confirm the following:" 87 | options: 88 | - label: >- 89 | I believe the problem I'm facing is a bug, and is not intended 90 | behavior. [Post here if you're not sure](../discussions/new?category=q-a). 91 | required: true 92 | - label: >- 93 | I have confirmed that someone else has not 94 | [submitted a similar bug report](../labels/bug). 95 | required: true 96 | -------------------------------------------------------------------------------- /redirect.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) Liam Stanley . All rights reserved. Use of 2 | // this source code is governed by the MIT license that can be found in 3 | // the LICENSE file. 4 | 5 | package chix 6 | 7 | import ( 8 | "context" 9 | "net/http" 10 | "net/url" 11 | "strings" 12 | ) 13 | 14 | const nextURLExpiration = 86400 // 1 day. 
15 | 16 | // UseNextURL is a middleware that will store the current URL provided via 17 | // the "next" query parameter, as a cookie in the response, for use with 18 | // multi-step authentication flows. This allows the user to be redirected 19 | // back to the original destination after authentication. Must use 20 | // chix.SecureRedirect to redirect the user, which will pick up the url from 21 | // the cookie. 22 | func UseNextURL(next http.Handler) http.Handler { 23 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 24 | if n := r.URL.Query().Get("next"); n != "" { 25 | // Hostname() strips any port and, unlike splitting on the first ":", 26 | // correctly handles bracketed IPv6 literals such as "[::1]:8080" (a naive 27 | // split would reduce every IPv6 host to "["). 28 | host := (&url.URL{Host: r.Host}).Hostname() 29 | 30 | http.SetCookie(w, &http.Cookie{ 31 | Name: nextSessionKey, 32 | Value: n, 33 | Path: "/", 34 | MaxAge: nextURLExpiration, 35 | Domain: host, 36 | HttpOnly: true, 37 | }) 38 | } 39 | next.ServeHTTP(w, r) 40 | }) 41 | } 42 | 43 | // SecureRedirect supports validating that redirect requests fulfill the 44 | // following conditions: 45 | // - Target URL must match one of: 46 | // - Absolute/relative to the same host 47 | // - http or https, with a host matching the requested host (no cross-domain, no port matching). 48 | // - Target URL can be parsed by url.Parse(). 49 | // 50 | // Additionally, if using chix.UseNextURL middleware, and the current session 51 | // has a "next" URL stored, the redirect will be to that URL. This allows 52 | // a multi-step authentication flow to be completed, then redirected to the 53 | // original destination.
54 | func SecureRedirect(w http.ResponseWriter, r *http.Request, status int, fallback string) { 55 | target := fallback 56 | 57 | // Hostname() strips any port and understands bracketed IPv6 literals 58 | // (splitting on the first ":" would reduce every IPv6 host to "["). 59 | reqHost := (&url.URL{Host: r.Host}).Hostname() 60 | 61 | if skip := r.Context().Value(contextSkipNextURL); skip == nil { 62 | n, err := r.Cookie(nextSessionKey) 63 | if err == nil && n.Value != "" { 64 | target = n.Value 65 | 66 | // The stored URL is single-use: clear it whichever branch below is 67 | // taken, so a relative or rejected "next" URL can't keep overriding 68 | // later redirects until the cookie expires. 69 | http.SetCookie(w, &http.Cookie{ 70 | Name: nextSessionKey, 71 | Path: "/", 72 | MaxAge: -1, 73 | Domain: reqHost, 74 | HttpOnly: true, 75 | }) 76 | } 77 | } 78 | 79 | next, err := url.Parse(target) 80 | if err != nil { 81 | http.Redirect(w, r, "/", http.StatusTemporaryRedirect) 82 | return 83 | } 84 | 85 | if next.Scheme == "" && next.Host == "" { 86 | http.Redirect(w, r, next.String(), status) 87 | return 88 | } 89 | 90 | if next.Scheme != "http" && next.Scheme != "https" { 91 | http.Redirect(w, r, "/", http.StatusTemporaryRedirect) 92 | return 93 | } 94 | 95 | // Ports are intentionally not compared. 96 | if !strings.EqualFold(reqHost, next.Hostname()) { 97 | http.Redirect(w, r, "/", http.StatusTemporaryRedirect) 98 | return 99 | } 100 | 101 | http.Redirect(w, r, next.String(), status) 102 | } 103 | 104 | // SkipNextURL returns a shallow copy of r whose context tells SecureRedirect 105 | // to ignore any "next" URL stored by chix.UseNextURL(). This is useful when 106 | // you have to redirect to another source first. 107 | func SkipNextURL(r *http.Request) *http.Request { 108 | return r.WithContext(context.WithValue(r.Context(), contextSkipNextURL, true)) 109 | } 110 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.yml: -------------------------------------------------------------------------------- 1 | # THIS FILE IS GENERATED! DO NOT EDIT! Maintained by Terraform.
2 | name: "💡 Submit a feature request" 3 | description: Suggest an awesome feature for this project! 4 | title: "feature: [REPLACE ME]" 5 | labels: 6 | - enhancement 7 | body: 8 | - type: markdown 9 | attributes: 10 | value: | 11 | ### Thanks for submitting a feature request! 📋 12 | 13 | - 💬 Make sure to check out the [**discussions**](../discussions) section of this repository. Do you have an idea for an improvement, but want to brainstorm it with others first? [Start a discussion here](../discussions/new?category=ideas) first. 14 | - 🔎 Please [**search**](../labels/enhancement) to see if someone else has submitted a similar feature request, before making a new request. 15 | 16 | --------------------------------------------- 17 | - type: textarea 18 | id: describe 19 | attributes: 20 | label: "✨ Describe the feature you'd like" 21 | description: >- 22 | A clear and concise description of what you want to happen, or what 23 | feature you'd like added. 24 | placeholder: 'Example: "It would be cool if X had support for Y"' 25 | validations: 26 | required: true 27 | - type: textarea 28 | id: related 29 | attributes: 30 | label: "🌧 Is your feature request related to a problem?" 31 | description: >- 32 | A clear and concise description of what the problem is. 33 | placeholder: >- 34 | Example: "I'd like to see X feature added, as I frequently have to do Y, 35 | and I think Z would solve that problem" 36 | - type: textarea 37 | id: alternatives 38 | attributes: 39 | label: "🔎 Describe alternatives you've considered" 40 | description: >- 41 | A clear and concise description of any alternative solutions or features 42 | you've considered. 43 | placeholder: >- 44 | Example: "I've considered X and Y, however the potential problems with 45 | those solutions would be [...]" 46 | validations: 47 | required: true 48 | - type: dropdown 49 | id: breaking 50 | attributes: 51 | label: "⚠ If implemented, do you think this feature will be a breaking change to users?" 
52 | description: >- 53 | To the best of your ability, do you think implementing this change 54 | would impact users in a way during an upgrade process? 55 | options: 56 | - "Yes" 57 | - "No" 58 | - "Not sure" 59 | validations: 60 | required: true 61 | - type: textarea 62 | id: context 63 | attributes: 64 | label: "⚙ Additional context" 65 | description: >- 66 | Add any other context or screenshots about the feature request here 67 | (attach if necessary). 68 | placeholder: "Examples: logs, screenshots, etc" 69 | - type: checkboxes 70 | id: requirements 71 | attributes: 72 | label: "🤝 Requirements" 73 | description: "Please confirm the following:" 74 | options: 75 | - label: >- 76 | I have confirmed that someone else has not 77 | [submitted a similar feature request](../labels/enhancement). 78 | required: true 79 | - label: >- 80 | If implemented, I believe this feature will help others, in 81 | addition to solving my problems. 82 | required: true 83 | - label: I have looked into alternative solutions to the best of my ability. 84 | required: true 85 | - label: >- 86 | (optional) I would be willing to contribute to testing this 87 | feature if implemented, or making a PR to implement this 88 | functionality. 89 | required: false 90 | -------------------------------------------------------------------------------- /api_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) Liam Stanley . All rights reserved. Use of 2 | // this source code is governed by the MIT license that can be found in 3 | // the LICENSE file. 
4 | 5 | package chix 6 | 7 | import ( 8 | "net/http" 9 | "net/http/httptest" 10 | "testing" 11 | ) 12 | 13 | func TestUseAPIVersionMatch(t *testing.T) { 14 | tests := []struct { 15 | name string 16 | version string 17 | headers map[string]string 18 | ok bool 19 | statusCode int 20 | }{ 21 | { 22 | name: "empty", 23 | version: "v1", 24 | headers: map[string]string{}, 25 | ok: false, 26 | statusCode: http.StatusPreconditionFailed, 27 | }, 28 | { 29 | name: "mismatch", 30 | version: "v1", 31 | headers: map[string]string{ 32 | DefaultAPIVersionHeader: "v2", 33 | }, 34 | ok: false, 35 | statusCode: http.StatusPreconditionFailed, 36 | }, 37 | { 38 | name: "match", 39 | version: "v1", 40 | headers: map[string]string{ 41 | DefaultAPIVersionHeader: "v1", 42 | }, 43 | ok: true, 44 | statusCode: http.StatusOK, 45 | }, 46 | } 47 | for _, tt := range tests { 48 | t.Run(tt.name, func(t *testing.T) { 49 | req := httptest.NewRequest(http.MethodGet, "http://example.com", http.NoBody) 50 | 51 | for k, v := range tt.headers { 52 | req.Header.Set(k, v) 53 | } 54 | 55 | handler := UseAPIVersionMatch(tt.version)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 56 | if !tt.ok { 57 | t.Error("expected handler to not be invoked, but did") 58 | return 59 | } 60 | 61 | w.WriteHeader(http.StatusOK) 62 | })) 63 | 64 | rec := httptest.NewRecorder() 65 | handler.ServeHTTP(rec, req) 66 | resp := rec.Result() 67 | 68 | if resp.StatusCode != tt.statusCode { 69 | t.Errorf("expected status code %d, got %d", tt.statusCode, resp.StatusCode) 70 | } 71 | }) 72 | } 73 | } 74 | 75 | func TestUseAPIKeyRequired(t *testing.T) { 76 | tests := []struct { 77 | name string 78 | keys []string 79 | headers map[string]string 80 | ok bool 81 | statusCode int 82 | }{ 83 | { 84 | name: "empty", 85 | keys: []string{"R5HKAjpQFKNW4KUHF2M4", "O1YQbFh8x5tTpbg4uVhb"}, 86 | headers: map[string]string{}, 87 | ok: false, 88 | statusCode: http.StatusPreconditionFailed, 89 | }, 90 | { 91 | name: "mismatch", 92 | 
keys: []string{"R5HKAjpQFKNW4KUHF2M4", "O1YQbFh8x5tTpbg4uVhb"}, 93 | headers: map[string]string{DefaultAPIKeyHeader: "qx7zkX6EiONmslV3uIWH"}, 94 | ok: false, 95 | statusCode: http.StatusUnauthorized, 96 | }, 97 | { 98 | name: "match", 99 | keys: []string{"R5HKAjpQFKNW4KUHF2M4", "O1YQbFh8x5tTpbg4uVhb"}, 100 | headers: map[string]string{DefaultAPIKeyHeader: "R5HKAjpQFKNW4KUHF2M4"}, 101 | ok: true, 102 | statusCode: http.StatusOK, 103 | }, 104 | } 105 | for _, tt := range tests { 106 | t.Run(tt.name, func(t *testing.T) { 107 | req := httptest.NewRequest(http.MethodGet, "http://example.com", http.NoBody) 108 | 109 | for k, v := range tt.headers { 110 | req.Header.Set(k, v) 111 | } 112 | 113 | handler := UseAPIKeyRequired(tt.keys)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 114 | if !tt.ok { 115 | t.Error("expected handler to not be invoked, but did") 116 | return 117 | } 118 | 119 | w.WriteHeader(http.StatusOK) 120 | })) 121 | 122 | rec := httptest.NewRecorder() 123 | handler.ServeHTTP(rec, req) 124 | resp := rec.Result() 125 | 126 | if resp.StatusCode != tt.statusCode { 127 | t.Errorf("expected status code %d, got %d", tt.statusCode, resp.StatusCode) 128 | } 129 | }) 130 | } 131 | } 132 | -------------------------------------------------------------------------------- /bind.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) Liam Stanley . All rights reserved. Use of 2 | // this source code is governed by the MIT license that can be found in 3 | // the LICENSE file. 4 | 5 | package chix 6 | 7 | import ( 8 | "encoding/json" 9 | "errors" 10 | "fmt" 11 | "net/http" 12 | "strings" 13 | 14 | "github.com/go-playground/form/v4" 15 | "github.com/go-playground/validator/v10" 16 | ) 17 | 18 | var ( 19 | // DefaultDecoder is the default decoder used by Bind. You can either override 20 | // this, or provide your own. Make sure it is set before Bind is called. 
21 | DefaultDecoder = form.NewDecoder() 22 | 23 | // DefaultDecodeMaxMemory is the maximum amount of memory in bytes that will be 24 | // used for decoding multipart/form-data requests. 25 | DefaultDecodeMaxMemory int64 = 8 << 20 26 | 27 | // DefaultValidator is the default validator used by Bind, when the provided 28 | // struct to the Bind() call doesn't implement Validatable. Set this to nil 29 | // to disable validation using go-playground/validator. 30 | DefaultValidator = validator.New() 31 | ) 32 | 33 | // Validatable is an interface that can be implemented by structs to provide 34 | // custom validation logic. When the value passed to Bind implements it, Bind 35 | // calls Validate() instead of DefaultValidator; the two are not combined. 36 | type Validatable interface { 37 | Validate() error 38 | } 39 | 40 | // Bind decodes the request body to the given struct. Take a look at 41 | // DefaultDecoder to add additional customizations to the default decoder. 42 | // You can add additional customizations by using the Validatable interface, 43 | // with a custom implementation of the Validate() method on v. Alternatively, 44 | // chix also supports the go-playground/validator package, which allows various 45 | // validation methods via struct tags. 46 | // 47 | // At this time the only supported content-types are application/json, 48 | // application/x-www-form-urlencoded and multipart/form-data (form values only, not files), as well as GET/HEAD parameters. 49 | // 50 | // If validation fails, an error that is wrapped with the necessary status code 51 | // will be returned (can just pass to chix.Error() and it will know the appropriate 52 | // HTTP code to return, and if it should be a JSON body or not).
53 | func Bind(r *http.Request, v any) (err error) { 54 | var rerr error 55 | 56 | if err = r.ParseForm(); err != nil { 57 | rerr = fmt.Errorf("error parsing %s parameters, invalid request", r.Method) 58 | goto handle 59 | } 60 | 61 | switch r.Method { 62 | case http.MethodGet, http.MethodHead: 63 | err = DefaultDecoder.Decode(v, r.Form) 64 | case http.MethodPost, http.MethodPut, http.MethodPatch: 65 | switch { 66 | case strings.HasPrefix(r.Header.Get("Content-Type"), "application/json"): 67 | dec := json.NewDecoder(r.Body) 68 | defer r.Body.Close() 69 | err = dec.Decode(v) 70 | case strings.HasPrefix(r.Header.Get("Content-Type"), "multipart/form-data"): 71 | err = r.ParseMultipartForm(DefaultDecodeMaxMemory) 72 | if err == nil { 73 | err = DefaultDecoder.Decode(v, r.MultipartForm.Value) 74 | } 75 | default: 76 | err = DefaultDecoder.Decode(v, r.PostForm) 77 | } 78 | default: 79 | return WrapError(fmt.Errorf("unsupported method %s", r.Method), http.StatusBadRequest) 80 | } 81 | if err != nil { 82 | rerr = fmt.Errorf("error decoding %s request into required format (%T): validate request parameters", r.Method, v) 83 | } 84 | 85 | handle: 86 | if err != nil { 87 | return WrapError(rerr, http.StatusBadRequest) 88 | } 89 | 90 | if v, ok := v.(Validatable); ok { 91 | if err = v.Validate(); err != nil { 92 | return WrapError(err, http.StatusBadRequest) 93 | } 94 | 95 | return nil 96 | } 97 | 98 | if DefaultValidator != nil { 99 | err = DefaultValidator.Struct(v) 100 | if err != nil { 101 | invalidValidationError := &validator.InvalidValidationError{} 102 | if errors.As(err, &invalidValidationError) { 103 | panic(fmt.Errorf("invalid validation error: %w", err)) 104 | } 105 | 106 | // for _, err := range err.(validator.ValidationErrors) {} 107 | return WrapError(err, http.StatusBadRequest) 108 | } 109 | 110 | return nil 111 | } 112 | 113 | return nil 114 | } 115 | -------------------------------------------------------------------------------- /.github/CONTRIBUTING.md: 
-------------------------------------------------------------------------------- 1 | 2 | # :handshake: Contributing 3 | 4 | This document outlines some of the guidelines that we try to adhere to while 5 | working on this project. 6 | 7 | > :point_right: **Note**: before participating in the community, please read our 8 | > [Code of Conduct][coc]. 9 | > By interacting with this repository, organization, or community, you agree to 10 | > abide by our Code of Conduct. 11 | > 12 | > Additionally, if you contribute **any source code** to this repository, you 13 | > agree to the terms of the [Developer Certificate of Origin][dco]. This helps 14 | > ensure that contributions aren't in violation of 3rd party license terms. 15 | 16 | ## :lady_beetle: Issue submission 17 | 18 | When [submitting an issue][issues] or bug report, 19 | please follow these guidelines: 20 | 21 | * Provide as much information as possible (logs, metrics, screenshots, 22 | runtime environment, etc.). 23 | * Ensure that you are running on the latest stable version (tagged), or 24 | when using `master`, provide the specific commit being used. 25 | * Provide the minimal source needed to reproduce the problem. 26 | 27 | ## :bulb: Feature requests 28 | 29 | When [submitting a feature request][issues], please 30 | follow these guidelines: 31 | 32 | * Does this feature benefit others, or just your use case? If the latter, 33 | it will likely be declined, unless it has a broader benefit to others. 34 | * Please include the pros and cons of the feature. 35 | * If possible, describe how the feature would work, and any diagrams/mock 36 | examples of what the feature would look like. 37 | 38 | ## :rocket: Pull requests 39 | 40 | To review what is currently being worked on, or looked into, feel free to head 41 | over to the [open pull requests][pull-requests] or [issues list][issues].
42 | 43 | ## :raised_back_of_hand: Assistance with discussions 44 | 45 | * Take a look at the [open discussions][discussions], and if you would 46 | like to help out other members of the community, it would be much 47 | appreciated! 48 | 49 | ## :pushpin: Guidelines 50 | 51 | ### :test_tube: Language agnostic 52 | 53 | Below are a few guidelines if you would like to contribute: 54 | 55 | * If the feature is large or the bugfix has potential breaking changes, 56 | please open an issue first to ensure the changes go down the best path. 57 | * If possible, break the changes into smaller PRs. Pull requests should be 58 | focused on a specific feature/fix. 59 | * Pull requests will only be accepted with sufficient documentation 60 | describing the new functionality/fixes. 61 | * Keep the code simple where possible. Smaller or more compact code is 62 | not necessarily better. Don't do magic behind the scenes. 63 | * Use the same formatting/styling/structure as existing code. 64 | * Follow idioms and community best practices of the related language, 65 | unless the guidelines above override what the community 66 | recommends. 67 | * Always test your changes: both the features/fixes being implemented, and 68 | the standard way that a user would use the project (not just 69 | your configuration that fixes your issue). 70 | * Only use 3rd party libraries when necessary. If only a small portion of 71 | a library is needed, simply reimplement it within this project to avoid 72 | unnecessary imports. 73 | 74 | ### :hamster: Golang 75 | 76 | * See [golang/go/wiki/CodeReviewComments](https://github.com/golang/go/wiki/CodeReviewComments) 77 | * This project uses [golangci-lint](https://golangci-lint.run/) for 78 | Go-related files. This should be available in any editor that supports 79 | `gopls`; however, you can also run it locally with `golangci-lint run` 80 | after installing it.
81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | ## :clipboard: References 91 | 92 | * [Open Source: How to Contribute](https://opensource.guide/how-to-contribute/) 93 | * [About pull requests](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/about-pull-requests) 94 | * [GitHub Docs](https://docs.github.com/) 95 | 96 | ## :speech_balloon: What to do next? 97 | 98 | * :old_key: Find a vulnerability? Check out our [Security and Disclosure][security] policy. 99 | * :link: Repository [License][license]. 100 | * [Support][support] 101 | * [Code of Conduct][coc]. 102 | 103 | 104 | [coc]: https://github.com/lrstanley/chix/blob/master/.github/CODE_OF_CONDUCT.md 105 | [dco]: https://developercertificate.org/ 106 | [discussions]: https://github.com/lrstanley/chix/discussions 107 | [issues]: https://github.com/lrstanley/chix/issues/new/choose 108 | [license]: https://github.com/lrstanley/chix/blob/master/LICENSE 109 | [pull-requests]: https://github.com/lrstanley/chix/pulls?q=is%3Aopen+is%3Apr 110 | [security]: https://github.com/lrstanley/chix/security/policy 111 | [support]: https://github.com/lrstanley/chix/blob/master/.github/SUPPORT.md 112 | -------------------------------------------------------------------------------- /security.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) Liam Stanley . All rights reserved. Use of 2 | // this source code is governed by the MIT license that can be found in 3 | // the LICENSE file. 4 | 5 | package chix 6 | 7 | import ( 8 | "fmt" 9 | "net/http" 10 | "strings" 11 | "time" 12 | ) 13 | // securityExpires (captured at package init) is the base time for SecurityConfig.ExpiresIn. 14 | var ( 15 | securityExpires = time.Now() 16 | robotsTxt = "User-agent: *\nDisallow: %s\nAllow: /\n" 17 | ) 18 | 19 | // UseRobotsTxt returns a handler that serves a robots.txt file. When custom 20 | // is empty, a default robots.txt is served (disallows DefaultAPIPrefix, allows /).
21 | // 22 | // You can also use go:embed to embed the robots.txt file into your binary. 23 | // Example: 24 | // 25 | // //go:embed your/robots.txt 26 | // var robotsTxt string 27 | // [...] 28 | // router.Use(chix.UseRobotsTxt(robotsTxt)) 29 | func UseRobotsTxt(custom string) func(next http.Handler) http.Handler { 30 | return func(next http.Handler) http.Handler { 31 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 32 | if !strings.HasSuffix(r.URL.Path, "/robots.txt") || (r.Method != http.MethodGet && r.Method != http.MethodHead) { 33 | next.ServeHTTP(w, r) 34 | return 35 | } 36 | 37 | // Headers must be set before WriteHeader, otherwise they are silently dropped. 38 | w.Header().Set("Content-Type", "text/plain; charset=utf-8") 39 | w.WriteHeader(http.StatusOK) 40 | 41 | if r.Method == http.MethodHead { 42 | return 43 | } 44 | 45 | if custom == "" { 46 | _, _ = fmt.Fprintf(w, robotsTxt, DefaultAPIPrefix) 47 | return 48 | } 49 | 50 | _, _ = w.Write([]byte(custom)) 51 | }) 52 | } 53 | } 54 | 55 | // UseSecurityTxt returns a handler that serves a security.txt file at the 56 | // standardized path(s). Only the provided fields will be included in the 57 | // response. 58 | func UseSecurityTxt(config *SecurityConfig) func(next http.Handler) http.Handler { 59 | if config == nil { 60 | panic("SecurityConfig is nil") 61 | } 62 | 63 | return func(next http.Handler) http.Handler { 64 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 65 | if !strings.HasSuffix(r.URL.Path, "/security.txt") || (r.Method != http.MethodGet && r.Method != http.MethodHead) { 66 | next.ServeHTTP(w, r) 67 | return 68 | } 69 | 70 | if r.Method == http.MethodHead { 71 | w.Header().Set("Content-Type", "text/plain; charset=utf-8") 72 | w.WriteHeader(http.StatusOK) 73 | return 74 | } 75 | config.ServeHTTP(w, r) 76 | }) 77 | } 78 | } 79 | 80 | // SecurityConfig configures the security.txt middleware.
81 | type SecurityConfig struct { 82 | // Expires is the time when the content of the security.txt file should 83 | // be considered stale (so security researchers should then not trust it). 84 | // Make sure you update this value periodically and keep your file under 85 | // review. 86 | Expires time.Time 87 | 88 | // ExpiresIn is similar to Expires, but is relative to when the chix package 89 | // was initialized (roughly process start). Ignored when Expires is set. 90 | ExpiresIn time.Duration 91 | 92 | // Contacts contains links or e-mail addresses for people to contact you 93 | // about security issues. Remember to include "https://" for URLs, and 94 | // "mailto:" for e-mails (this will be auto-included if it contains an @ 95 | // character and no ":", i.e. no URI scheme). 96 | Contacts []string 97 | 98 | // KeyLinks contains links to keys which security researchers should use 99 | // to securely talk to you. Remember to include "https://". 100 | KeyLinks []string 101 | 102 | // Languages is a list of language codes that your security team speaks. 103 | Languages []string 104 | 105 | // Acknowledgements contains links to webpages where you say thank you 106 | // to security researchers who have helped you (emitted as the RFC 9116 107 | // "Acknowledgments" field). Remember to include "https://". 108 | Acknowledgements []string 109 | 110 | // Policies contains links to policies detailing what security researchers 111 | // should do when searching for or reporting security issues. Remember 112 | // to include "https://". 113 | Policies []string 114 | 115 | // Canonical contains the URLs for accessing your security.txt file. It 116 | // is important to include this if you are digitally signing the 117 | // security.txt file, so that the location of the security.txt file can 118 | // be digitally signed too.
119 | Canonical []string 120 | } 121 | // ServeHTTP writes the configured fields in security.txt (RFC 9116) format. 122 | func (h *SecurityConfig) ServeHTTP(w http.ResponseWriter, _ *http.Request) { 123 | w.Header().Set("Content-Type", "text/plain; charset=utf-8") 124 | w.WriteHeader(http.StatusOK) 125 | 126 | for _, entry := range h.Contacts { 127 | if strings.Contains(entry, "@") && !strings.Contains(entry, ":") { 128 | entry = "mailto:" + entry 129 | } 130 | 131 | _, _ = w.Write([]byte("Contact: " + entry + "\n")) 132 | } 133 | 134 | for _, entry := range h.KeyLinks { 135 | _, _ = w.Write([]byte("Encryption: " + entry + "\n")) 136 | } 137 | 138 | if len(h.Languages) > 0 { 139 | _, _ = w.Write([]byte("Preferred-Languages: " + strings.Join(h.Languages, ", ") + "\n")) 140 | } 141 | 142 | for _, entry := range h.Acknowledgements { 143 | _, _ = w.Write([]byte("Acknowledgments: " + entry + "\n")) 144 | } 145 | 146 | for _, entry := range h.Policies { 147 | _, _ = w.Write([]byte("Policy: " + entry + "\n")) 148 | } 149 | 150 | for _, entry := range h.Canonical { 151 | _, _ = w.Write([]byte("Canonical: " + entry + "\n")) 152 | } 153 | 154 | if !h.Expires.IsZero() { 155 | _, _ = w.Write([]byte("Expires: " + h.Expires.Format(time.RFC3339) + "\n")) 156 | } else if h.ExpiresIn > 0 { 157 | _, _ = w.Write([]byte("Expires: " + securityExpires.Add(h.ExpiresIn).Format(time.RFC3339) + "\n")) 158 | } 159 | } 160 | -------------------------------------------------------------------------------- /.github/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | 2 | # Code of Conduct 3 | 4 | ## Our Pledge :purple_heart: 5 | 6 | We as members, contributors, and leaders pledge to make participation in our 7 | community a harassment-free experience for everyone, regardless of age, body 8 | size, visible or invisible disability, ethnicity, sex characteristics, gender 9 | identity and expression, level of experience, education, socio-economic status, 10 | nationality, personal appearance, race, caste, color,
religion, or sexual 11 | identity and orientation. 12 | 13 | We pledge to act and interact in ways that contribute to an open, welcoming, 14 | diverse, inclusive, and healthy community. 15 | 16 | ## Our Standards 17 | 18 | Examples of behavior that contributes to a positive environment for our 19 | community include: 20 | 21 | * Demonstrating empathy and kindness toward other people 22 | * Being respectful of differing opinions, viewpoints, and experiences 23 | * Giving and gracefully accepting constructive feedback 24 | * Accepting responsibility and apologizing to those affected by our mistakes, 25 | and learning from the experience 26 | * Focusing on what is best not just for us as individuals, but for the overall 27 | community 28 | 29 | Examples of unacceptable behavior include: 30 | 31 | * The use of sexualized language or imagery, and sexual attention or advances of 32 | any kind 33 | * Trolling, insulting or derogatory comments, and personal or political attacks 34 | * Public or private harassment 35 | * Publishing others' private information, such as a physical or email address, 36 | without their explicit permission 37 | * Other conduct which could reasonably be considered inappropriate in a 38 | professional setting 39 | 40 | ## Enforcement Responsibilities 41 | 42 | Community leaders are responsible for clarifying and enforcing our standards of 43 | acceptable behavior and will take appropriate and fair corrective action in 44 | response to any behavior that they deem inappropriate, threatening, offensive, 45 | or harmful. 46 | 47 | Community leaders have the right and responsibility to remove, edit, or reject 48 | comments, commits, code, wiki edits, issues, and other contributions that are 49 | not aligned to this Code of Conduct, and will communicate reasons for moderation 50 | decisions when appropriate. 
51 | 52 | ## Scope 53 | 54 | This Code of Conduct applies within all community spaces, and also applies when 55 | an individual is officially representing the community in public spaces. 56 | Examples of representing our community include using an official e-mail address, 57 | posting via an official social media account, or acting as an appointed 58 | representative at an online or offline event. 59 | 60 | ## Enforcement 61 | 62 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 63 | reported to the community leaders responsible for enforcement at 64 | disclosure@liam.sh. All complaints will be reviewed and investigated 65 | promptly and fairly. 66 | 67 | All community leaders are obligated to respect the privacy and security of the 68 | reporter of any incident. 69 | 70 | ## Enforcement Guidelines 71 | 72 | Community leaders will follow these Community Impact Guidelines in determining 73 | the consequences for any action they deem in violation of this Code of Conduct: 74 | 75 | ### 1. Correction 76 | 77 | **Community Impact**: Use of inappropriate language or other behavior deemed 78 | unprofessional or unwelcome in the community. 79 | 80 | **Consequence**: A private, written warning from community leaders, providing 81 | clarity around the nature of the violation and an explanation of why the 82 | behavior was inappropriate. A public apology may be requested. 83 | 84 | ### 2. Warning 85 | 86 | **Community Impact**: A violation through a single incident or series of 87 | actions. 88 | 89 | **Consequence**: A warning with consequences for continued behavior. No 90 | interaction with the people involved, including unsolicited interaction with 91 | those enforcing the Code of Conduct, for a specified period of time. This 92 | includes avoiding interactions in community spaces as well as external channels 93 | like social media. Violating these terms may lead to a temporary or permanent 94 | ban. 95 | 96 | ### 3. 
Temporary Ban 97 | 98 | **Community Impact**: A serious violation of community standards, including 99 | sustained inappropriate behavior. 100 | 101 | **Consequence**: A temporary ban from any sort of interaction or public 102 | communication with the community for a specified period of time. No public or 103 | private interaction with the people involved, including unsolicited interaction 104 | with those enforcing the Code of Conduct, is allowed during this period. 105 | Violating these terms may lead to a permanent ban. 106 | 107 | ### 4. Permanent Ban 108 | 109 | **Community Impact**: Demonstrating a pattern of violation of community 110 | standards, including sustained inappropriate behavior, harassment of an 111 | individual, or aggression toward or disparagement of classes of individuals. 112 | 113 | **Consequence**: A permanent ban from any sort of public interaction within the 114 | community. 115 | 116 | ## Attribution 117 | 118 | This Code of Conduct is adapted from the Contributor Covenant, 119 | version 2.1, available [here](https://www.contributor-covenant.org/version/2/1/code_of_conduct.html). 120 | 121 | For answers to common questions about this code of conduct, see the [FAQ](https://www.contributor-covenant.org/faq). 122 | Translations are available at [translations](https://www.contributor-covenant.org/translations). 123 | -------------------------------------------------------------------------------- /server.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) Liam Stanley . All rights reserved. Use of 2 | // this source code is governed by the MIT license that can be found in 3 | // the LICENSE file. 
4 | 5 | package chix 6 | 7 | import ( 8 | "context" 9 | "errors" 10 | "fmt" 11 | "net/http" 12 | "os" 13 | "os/signal" 14 | "syscall" 15 | "time" 16 | 17 | "github.com/apex/log" 18 | "golang.org/x/sync/errgroup" 19 | ) 20 | 21 | const ( 22 | srvDefaultReadTimeout = 15 * time.Second 23 | srvDefaultWriteTimeout = 15 * time.Second 24 | srvDefaultMaxHeaderBytes = 1 << 20 25 | srvCancelTimeout = 10 * time.Second 26 | ) 27 | // Runner is a task executed concurrently with the http server by the Run functions. 28 | type Runner func(ctx context.Context) error 29 | // Invoke binds ctx to r, producing the func() error form expected by errgroup. 30 | func (r Runner) Invoke(ctx context.Context) func() error { 31 | fn := func() error { 32 | return r(ctx) 33 | } 34 | return fn 35 | } 36 | // RunnerInterval runs r every frequency (must be > 0) until ctx is cancelled. 37 | func RunnerInterval(name string, r Runner, frequency time.Duration, runImmediately, exitOnError bool) Runner { 38 | return func(ctx context.Context) error { 39 | logEntry := log.FromContext(ctx).WithField("runner", name) 40 | ctx = log.NewContext(ctx, logEntry) 41 | 42 | var lastRun time.Time 43 | // A failing immediate run always aborts, regardless of exitOnError. 44 | if runImmediately { 45 | lastRun = time.Now() 46 | logEntry.Info("invoking runner") 47 | if err := r(ctx); err != nil { 48 | logEntry.WithError(err).WithDuration(time.Since(lastRun)).Error("invocation failed") 49 | return err 50 | } 51 | logEntry.WithDuration(time.Since(lastRun)).Info("invocation complete") 52 | } 53 | 54 | ticker := time.NewTicker(frequency) 55 | defer ticker.Stop() 56 | 57 | for { 58 | select { 59 | case <-ctx.Done(): 60 | return nil 61 | case <-ticker.C: 62 | lastRun = time.Now() 63 | logEntry.Info("invoking runner") 64 | if err := r(ctx); err != nil { 65 | logEntry.WithError(err).WithDuration(time.Since(lastRun)).Error("invocation failed") 66 | 67 | if exitOnError { 68 | return err 69 | } 70 | continue 71 | } 72 | 73 | logEntry.WithDuration(time.Since(lastRun)).Info("invocation complete") 74 | } 75 | } 76 | } 77 | } 78 | 79 | // Run runs the provided http server, and listens for any termination signals 80 | // (SIGINT, SIGTERM, and SIGQUIT). If runners are provided, those will run 81 | // concurrently.
82 | // 83 | // If the http server or any runner returns an error, all runners will 84 | // terminate (assuming they listen to the provided context), and the first 85 | // known error will be returned. The http server will be gracefully shut down, 86 | // with a timeout of 10 seconds. 87 | func Run(srv *http.Server, runners ...Runner) error { 88 | return RunContext(context.Background(), srv, runners...) 89 | } 90 | 91 | // RunTLS is the same as Run, but allows for TLS to be used. 92 | func RunTLS(srv *http.Server, certFile, keyFile string, runners ...Runner) error { 93 | return RunTLSContext(context.Background(), srv, certFile, keyFile, runners...) 94 | } 95 | 96 | // Deprecated: Use [RunContext] instead. 97 | func RunCtx(ctx context.Context, srv *http.Server, runners ...Runner) error { 98 | return RunContext(ctx, srv, runners...) 99 | } 100 | 101 | // RunContext is the same as Run, but with the provided context that can be used 102 | // to externally cancel all runners and the http server. 103 | func RunContext(ctx context.Context, srv *http.Server, runners ...Runner) error { 104 | serverSetDefaults(srv) 105 | 106 | var g *errgroup.Group 107 | g, ctx = errgroup.WithContext(ctx) 108 | 109 | g.Go(func() error { 110 | return signalListener(ctx) 111 | }) 112 | 113 | g.Go(func() error { 114 | return httpServer(ctx, srv, "", "") 115 | }) 116 | 117 | for _, runner := range runners { 118 | g.Go(runner.Invoke(ctx)) 119 | } 120 | 121 | return g.Wait() 122 | } 123 | 124 | // RunTLSContext is the same as Run, but with the provided context that can be used 125 | // to externally cancel all runners and the http server, and also allows for TLS 126 | // to be used.
127 | func RunTLSContext(ctx context.Context, srv *http.Server, certFile, keyFile string, runners ...Runner) error { 128 | serverSetDefaults(srv) 129 | 130 | var g *errgroup.Group 131 | g, ctx = errgroup.WithContext(ctx) 132 | 133 | g.Go(func() error { 134 | return signalListener(ctx) 135 | }) 136 | 137 | g.Go(func() error { 138 | return httpServer(ctx, srv, certFile, keyFile) 139 | }) 140 | 141 | for _, runner := range runners { 142 | g.Go(runner.Invoke(ctx)) 143 | } 144 | 145 | return g.Wait() 146 | } 147 | // SetServerDefaults, when true, makes the Run functions fill in unset ReadTimeout, WriteTimeout and MaxHeaderBytes. 148 | var SetServerDefaults = true 149 | // serverSetDefaults applies the srvDefault* values to zero-valued fields of srv. 150 | func serverSetDefaults(srv *http.Server) { 151 | if !SetServerDefaults { 152 | return 153 | } 154 | 155 | if srv.ReadTimeout == 0 { 156 | srv.ReadTimeout = srvDefaultReadTimeout 157 | } 158 | 159 | if srv.WriteTimeout == 0 { 160 | srv.WriteTimeout = srvDefaultWriteTimeout 161 | } 162 | 163 | if srv.MaxHeaderBytes == 0 { 164 | srv.MaxHeaderBytes = srvDefaultMaxHeaderBytes 165 | } 166 | } 167 | // signalListener returns an error on SIGINT, SIGTERM or SIGQUIT, or nil once ctx is done. 168 | func signalListener(ctx context.Context) error { 169 | quit := make(chan os.Signal, 1) 170 | signal.Notify(quit, os.Interrupt, syscall.SIGTERM, syscall.SIGQUIT) 171 | defer signal.Stop(quit) // stop relaying signals to quit once we return 172 | select { 173 | case sig := <-quit: 174 | log.FromContext(ctx).WithField("signal", sig).Warn("received signal, starting graceful termination") 175 | return fmt.Errorf("received signal: %v", sig) 176 | case <-ctx.Done(): 177 | return nil 178 | } 179 | } 180 | // httpServer serves srv (TLS when both certFile and keyFile are set) until ctx is done, then shuts it down gracefully. 181 | func httpServer(ctx context.Context, srv *http.Server, certFile, keyFile string) error { 182 | ch := make(chan error, 1) // buffered: the sender must not block if ctx.Done() wins the select 183 | go func() { 184 | var err error 185 | 186 | if certFile != "" && keyFile != "" { 187 | err = srv.ListenAndServeTLS(certFile, keyFile) 188 | } else { 189 | err = srv.ListenAndServe() 190 | } 191 | 192 | if err != nil && !errors.Is(err, http.ErrServerClosed) { 193 | ch <- err 194 | } 195 | close(ch) 196 | }() 197 | 198 | select { 199 | case <-ctx.Done(): 200 | case err := <-ch: 201 | return err 202 | } 203 | // ctx is already done here, so bound the graceful shutdown with a fresh timeout. 204 | ctxTimeout, cancel :=
context.WithTimeout(context.Background(), srvCancelTimeout) 205 | defer cancel() 206 | 207 | return srv.Shutdown(ctxTimeout) 208 | } 209 | -------------------------------------------------------------------------------- /goembed.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) Liam Stanley . All rights reserved. Use of 2 | // this source code is governed by the MIT license that can be found in 3 | // the LICENSE file. 4 | 5 | package chix 6 | 7 | import ( 8 | "context" 9 | "errors" 10 | "fmt" 11 | "io/fs" 12 | "net/http" 13 | "os" 14 | "path" 15 | "path/filepath" 16 | "runtime" 17 | "strings" 18 | 19 | "github.com/apex/log" 20 | ) 21 | 22 | // UseStatic returns a handler that serves static files from the provided embedded 23 | // filesystem, with support for using the direct filesystem when debugging is 24 | // enabled. 25 | // 26 | // Example usage: 27 | // 28 | // //go:embed all:public/dist 29 | // var publicDist embed.FS 30 | // [...]
31 | // router.Mount("/static", chix.UseStatic(ctx, &chix.Static{ 32 | // FS: publicDist, 33 | // Prefix: "/static", 34 | // AllowLocal: true, 35 | // Path: "public/dist", 36 | // })) 37 | func UseStatic(ctx context.Context, config *Static) http.Handler { 38 | logger := log.FromContext(ctx) 39 | 40 | var err error 41 | 42 | if config == nil { 43 | panic("config is nil") 44 | } 45 | 46 | if config.FS == nil { 47 | panic("config.FS is nil") 48 | } 49 | 50 | if config.LocalPath == "" { 51 | config.LocalPath = config.Path 52 | } 53 | 54 | config.Path = strings.Trim(config.Path, "/") 55 | config.LocalPath = strings.Trim(config.LocalPath, "/") 56 | 57 | if config.AllowLocal && config.LocalPath == "" { 58 | panic("config.AllowLocal is true, but config.LocalPath and config.Path is empty") 59 | } 60 | 61 | if config.Path != "" { 62 | config.FS, err = fs.Sub(config.FS, config.Path) 63 | if err != nil { 64 | panic(fmt.Errorf("failed to use subdirectory of filesystem: %w", err)) 65 | } 66 | } 67 | // Candidate local roots: the caller's source directory and the executable's directory (OS paths, hence filepath). 68 | _, srcPath, _, _ := runtime.Caller(1) 69 | srcPath = filepath.Join(filepath.Dir(srcPath), config.LocalPath) 70 | 71 | exePath, err := os.Executable() 72 | if err != nil { 73 | panic(fmt.Errorf("failed to get executable path: %w", err)) 74 | } 75 | exePath = filepath.Join(filepath.Dir(exePath), config.LocalPath) 76 | 77 | cwdLocal, _ := os.Stat(config.LocalPath) // Path to the current working directory. 78 | srcLocal, _ := os.Stat(srcPath) // Path to source file, if it's still on the filesystem. 79 | exeLocal, _ := os.Stat(exePath) // Path to the current executable.
80 | 81 | logger.WithFields(log.Fields{ 82 | "allow_local": config.AllowLocal, 83 | "path": config.Path, 84 | "local_path": config.LocalPath, 85 | "src_path": srcPath, 86 | "exe_path": exePath, 87 | }).Debug("static asset search paths") 88 | // With AllowLocal, the first existing directory wins (cwd, source dir, executable dir); otherwise serve the embedded FS. 89 | switch { 90 | case config.AllowLocal && cwdLocal != nil && cwdLocal.IsDir(): 91 | config.httpFS = http.Dir(config.LocalPath) 92 | logger.WithField("path", config.LocalPath).Debug("registering static assets in current working directory") 93 | case config.AllowLocal && srcLocal != nil && srcLocal.IsDir(): 94 | config.LocalPath = srcPath 95 | config.httpFS = http.Dir(config.LocalPath) 96 | logger.WithField("path", config.LocalPath).Debug("registering static assets in source file directory") 97 | case config.AllowLocal && exeLocal != nil && exeLocal.IsDir(): 98 | config.LocalPath = exePath 99 | config.httpFS = http.Dir(config.LocalPath) 100 | logger.WithField("path", config.LocalPath).Debug("registering static assets in executable directory") 101 | default: 102 | logger.WithField("path", config.Path).Debug("registering embedded static assets") 103 | config.httpFS = http.FS(config.FS) 104 | 105 | _ = fs.WalkDir(config.FS, ".", func(path string, info fs.DirEntry, err error) error { 106 | if err != nil || info.IsDir() { 107 | return nil //nolint:nilerr 108 | } 109 | 110 | logger.Debugf("registering embedded asset: %v", path) 111 | return nil 112 | }) 113 | } 114 | 115 | config.handler = http.FileServer(config.httpFS) 116 | 117 | if config.Prefix != "" { 118 | // Don't wrap the internal handler, as any logic we do, we want the prefix 119 | // to be stripped first.
120 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 121 | r.URL.Path = strings.TrimPrefix(r.URL.Path, config.Prefix) 122 | r.URL.RawPath = strings.TrimPrefix(r.URL.RawPath, config.Prefix) 123 | config.ServeHTTP(w, r) 124 | }) 125 | } 126 | return config 127 | } 128 | 129 | // Static is an http.Handler that serves static files from an embedded filesystem. 130 | // See chix.UseStatic() for more information. 131 | type Static struct { 132 | // FS is the filesystem to serve. 133 | FS fs.FS 134 | 135 | // Prefix is the prefix where the filesystem is mounted on your http router. 136 | Prefix string 137 | 138 | // CatchAll is a boolean that determines if chix.Static is being used as a 139 | // catch-all for not-found routes. If so, it will do extra validations for 140 | // using chix.Error when the route is related to an API endpoint (see 141 | // chix.DefaultAPIPrefix), as well as enforce specific methods. 142 | CatchAll bool 143 | 144 | // AllowLocal is a boolean that, if true, and Static.LocalPath exists, it will 145 | // bypass the provided filesystem and instead use the actual filesystem. 146 | AllowLocal bool 147 | 148 | // LocalPath is the subpath to use when AllowLocal is enabled. If empty, it 149 | // will default to Static.Path. It will check for this sub-directory in the 150 | // current working directory, the caller's source directory, or the executable directory. 151 | LocalPath string 152 | 153 | // Path of the embedded filesystem, instead of the entire filesystem. go:embed 154 | // will include the target that gets embedded, as a prefix to the path. 155 | // 156 | // For example, given "go:embed all:public/dist", mounted at "/static", you 157 | // would normally have to access using "/static/public/dist/". Providing 158 | // path, where path is "public/dist", you can access the same files 159 | // via "/static/". 160 | Path string 161 | 162 | // SPA is a boolean that, if true, will serve a single page application, i.e.
163 | // redirecting all files not found, to the index.html file. 164 | SPA bool 165 | 166 | // Headers is a map of headers to set on the response (e.g. cache headers). 167 | // Example: 168 | // &chix.Static{ 169 | // [...] 170 | // Headers: map[string]string{ 171 | // "Vary": "Accept-Encoding", 172 | // "Cache-Control": "public, max-age=7776000", 173 | // }, 174 | // } 175 | Headers map[string]string 176 | 177 | httpFS http.FileSystem 178 | handler http.Handler 179 | } 180 | // ServeHTTP serves from the configured filesystem, applying CatchAll, Headers and SPA handling. 181 | func (s *Static) ServeHTTP(w http.ResponseWriter, r *http.Request) { 182 | if s.CatchAll { 183 | if strings.HasPrefix(r.URL.Path, DefaultAPIPrefix) { 184 | Error(w, r, WrapCode(http.StatusNotFound)) 185 | return 186 | } 187 | 188 | if r.Method != http.MethodGet && r.Method != http.MethodHead { 189 | Error(w, r, WrapCode(http.StatusMethodNotAllowed)) 190 | return 191 | } 192 | } 193 | 194 | // Handle custom headers, if any. 195 | if s.Headers != nil { 196 | for k, v := range s.Headers { 197 | w.Header().Set(k, v) 198 | } 199 | } 200 | 201 | // Handle SPA, if enabled. 202 | if s.SPA { 203 | if !strings.HasPrefix(r.URL.Path, "/") { 204 | r.URL.Path = "/" + r.URL.Path 205 | } 206 | 207 | f, err := s.httpFS.Open(path.Clean(r.URL.Path)) 208 | if errors.Is(err, fs.ErrNotExist) { 209 | r.URL.Path = "/" 210 | } 211 | if f != nil { 212 | _ = f.Close() 213 | } 214 | } 215 | 216 | s.handler.ServeHTTP(w, r) 217 | } 218 | -------------------------------------------------------------------------------- /errors.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) Liam Stanley . All rights reserved. Use of 2 | // this source code is governed by the MIT license that can be found in 3 | // the LICENSE file.
4 | 5 | package chix 6 | 7 | import ( 8 | "errors" 9 | "fmt" 10 | "net/http" 11 | "strings" 12 | "sync/atomic" 13 | "time" 14 | 15 | "github.com/go-chi/chi/v5/middleware" 16 | ) 17 | 18 | //nolint:lll 19 | var ( 20 | // DefaultMaskError is a flag that can be used to mask errors in the 21 | // default error handler. This only impacts errors from 500 onwards. 22 | // If debug is enabled via the UseDebug middleware, this flag will be 23 | // ignored. 24 | DefaultMaskError = true 25 | 26 | ErrAccessDenied = errors.New("access denied") 27 | ErrAPIKeyInvalid = errors.New("invalid api key provided") 28 | ErrAPIKeyMissing = errors.New("api key not specified") 29 | ErrNoAPIKeys = errors.New("no api keys provided in initialization") 30 | ErrAPIVersionMissing = errors.New("api version not specified") 31 | ErrAPIVersionMismatch = errors.New("server and client version mismatch") 32 | ErrRealIPNoOpts = errors.New("realip: no options specified") 33 | ErrRealIPNoSource = errors.New("realip: no real IP source specified (OptUseXForwardedFor, OptUseXRealIP, or OptUseTrueClientIP, OptUseCFConnectingIP)") 34 | ErrRealIPNoTrusted = errors.New("realip: no trusted proxies or bogon IPs specified") 35 | ErrAuthNotFound = errors.New("auth: no authentiation found") 36 | ErrAuthMissingRole = errors.New("auth: missing necessary role") 37 | ) 38 | 39 | type ErrRealIPInvalidIP struct { 40 | Err error 41 | } 42 | 43 | func (e ErrRealIPInvalidIP) Error() string { 44 | return fmt.Sprintf("realip: invalid IP or range specified: %s", e.Err.Error()) 45 | } 46 | 47 | func (e ErrRealIPInvalidIP) Unwrap() error { 48 | return e.Err 49 | } 50 | 51 | // ErrorResolver is a function that converts an error to a status code. If 52 | // 0 is returned, the originally provided status code will be used. 
Resolvers 53 | // are useful in situations where you want to return a different error when 54 | // the error contains a database-related error (like duplicate key already 55 | // exists, returning a 400 by default), of when you can check if the error 56 | // is due to user input. 57 | type ErrorResolver func(err error) (status int) 58 | 59 | var errorResolvers atomic.Value // []ErrorResolver 60 | 61 | // AddErrorResolver can be used to add additional error resolvers to the 62 | // default error handler. These will not be used if a custom error handler 63 | // is used. 64 | func AddErrorResolver(r ErrorResolver) { 65 | resolvers, ok := errorResolvers.Load().([]ErrorResolver) 66 | if !ok { 67 | resolvers = []ErrorResolver{} 68 | } 69 | 70 | resolvers = append(resolvers, r) 71 | errorResolvers.Store(resolvers) 72 | } 73 | 74 | // ErrorHandler is a function that, depending on the input, will either 75 | // respond to a request with a given response structured based off an error 76 | // or do nothing, if there isn't actually an error. 77 | type ErrorHandler func(w http.ResponseWriter, r *http.Request, err error) (ok bool) 78 | 79 | // Error handles the error (if any). Handler WILL respond to the request 80 | // with a header and a response if there is an error. The return boolean tells 81 | // the caller if the handler has responded to the request or not. If the 82 | // request includes /api/ as the prefix (see DefaultAPIPrefix), the response 83 | // will be JSON. 84 | // 85 | // If you'd like a specific status code to be returned, there are four options: 86 | // 1. Use AddErrorResolver() to add a custom resolver for err -> status code. 87 | // 2. Use WrapError() to wrap the error with a given status code. 88 | // 3. Use WrapCode() to make an error from a given status code (if you don't 89 | // have an error that you can provide). 90 | // 4. If none of the above apply, http.StatusInternalServerError will be returned. 
91 | // 92 | // NOTE: if you override this function, you must call chix.UnwrapError() on the 93 | // error to get the original error, and the status code, if any of the above are 94 | // used. 95 | var Error = defaultErrorHandler 96 | 97 | // defaultErrorHandler is the default ErrorHandler implementation. 98 | func defaultErrorHandler(w http.ResponseWriter, r *http.Request, err error) bool { 99 | if err == nil { 100 | return false 101 | } 102 | 103 | var statusCode int 104 | err, statusCode = UnwrapError(err) 105 | statusText := http.StatusText(statusCode) 106 | 107 | id := middleware.GetReqID(r.Context()) 108 | if id == "" { 109 | id = "-" 110 | } 111 | 112 | if statusCode >= http.StatusInternalServerError { 113 | Log(r).WithError(err) 114 | 115 | if !IsDebug(r) && DefaultMaskError { 116 | err = errors.New("internal server error") 117 | } 118 | } 119 | 120 | if DefaultAPIPrefix != "" && strings.HasPrefix(r.URL.Path, DefaultAPIPrefix) { 121 | JSON(w, r, statusCode, M{ 122 | "error": err.Error(), 123 | "type": statusText, 124 | "code": statusCode, 125 | "request_id": id, 126 | "timestamp": time.Now().UTC().Format(time.RFC3339), 127 | }) 128 | } else { 129 | http.Error(w, fmt.Sprintf( 130 | "%s: %s (id: %s)", statusText, err.Error(), id, 131 | ), statusCode) 132 | } 133 | 134 | return true 135 | } 136 | 137 | // ErrorCode is a helper function for Error() that includes a status code in the 138 | // response. See also chix.WrapError() and chix.WrapCode(). 139 | func ErrorCode(w http.ResponseWriter, r *http.Request, statusCode int, err error) bool { 140 | return Error(w, r, WrapError(err, statusCode)) 141 | } 142 | 143 | // ErrWithStatusCode is an error wrapper that bundles a given status code, that 144 | // can be used by chix.Error() as the response code. See chix.WrapError() and 145 | // chix.WrapCode(). 
146 | type ErrWithStatusCode struct { 147 | Err error 148 | Code int 149 | } 150 | 151 | func (e ErrWithStatusCode) Error() string { 152 | return fmt.Sprintf("%s (status: %d)", e.Err.Error(), e.Code) 153 | } 154 | 155 | func (e ErrWithStatusCode) Unwrap() error { 156 | return e.Err 157 | } 158 | 159 | // UnwrapError is a helper function for retrieving the underlying error and status 160 | // code from an error that has been wrapped. 161 | func UnwrapError(err error) (resultErr error, statusCode int) { //nolint:revive 162 | if err == nil { 163 | return nil, 0 164 | } 165 | 166 | statusCode = http.StatusInternalServerError 167 | 168 | // If the user has wrapped the error, this will override any other code 169 | // we have. 170 | var codeErr *ErrWithStatusCode 171 | if errors.As(err, &codeErr) { 172 | statusCode = codeErr.Code 173 | err = codeErr.Unwrap() 174 | return err, statusCode 175 | } 176 | 177 | // First try with resolvers. 178 | if resolvers, ok := errorResolvers.Load().([]ErrorResolver); ok { 179 | for _, fn := range resolvers { 180 | var code int 181 | if code = fn(err); code != 0 { 182 | statusCode = code 183 | break 184 | } 185 | } 186 | } 187 | return err, statusCode 188 | } 189 | 190 | // WrapError wraps an error with an http status code, which chix.Error() can use 191 | // as the response code. Example usage: 192 | // 193 | // if chix.Error(w, r, chix.WrapError(err, http.StatusBadRequest)) { 194 | // return 195 | // } 196 | // 197 | // if chix.Error(w, r, chix.WrapError(err, 500)) { 198 | // return 199 | // } 200 | func WrapError(err error, code int) error { 201 | return &ErrWithStatusCode{Err: err, Code: code} 202 | } 203 | 204 | // WrapCode is a helper function that returns an error using the status text of the 205 | // given http status code. This is useful if you don't have an explicit error to 206 | // respond with. 
Example usage: 207 | // 208 | // chix.Error(w, r, chix.WrapCode(http.StatusBadRequest)) 209 | // return 210 | // 211 | // chix.Error(w, r, chix.WrapCode(500)) 212 | // return 213 | func WrapCode(code int) error { 214 | out := http.StatusText(code) 215 | if out == "" { 216 | out = fmt.Sprintf("unknown error (%d)", code) 217 | } 218 | 219 | return &ErrWithStatusCode{Err: errors.New(out), Code: code} 220 | } 221 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 6 | ![logo](https://liam.sh/-/gh/svg/lrstanley/chix?icon=logos%3Ago&icon.height=65&layout=left&font=1.1&icon.color=rgba%280%2C+0%2C+0%2C+1%29) 7 | 8 | 9 | 10 | 11 |

12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 |

39 |

40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 |

54 | 55 | 56 | 57 | 58 | ## :link: Table of Contents 59 | 60 | - [Usage](#gear-usage) 61 | - [Features](#sparkles-features) 62 | - [Related Libraries](#zap-related-libraries) 63 | - [Example Projects](#bulb-example-projects) 64 | - [Support & Assistance](#raising_hand_man-support--assistance) 65 | - [Contributing](#handshake-contributing) 66 | - [License](#balance_scale-license) 67 | 68 | 69 | ## :gear: Usage 70 | 71 | 72 | 73 | ```console 74 | go get -u github.com/lrstanley/chix@latest 75 | ``` 76 | 77 | 78 | ## :sparkles: Features 79 | 80 | - `http.Server` wrapper that easily allows starting, and gracefully shutting 81 | down your http server, and other background services, using `errgroup`. 82 | - RealIP middleware (supports whitelisting specific proxies, rather than allowing 83 | any source). 84 | - private IP middleware, restricting endpoints to be internal only. 85 | - Rendering helpers: 86 | - `JSON` (with `?pretty=true` support). 87 | - Auth middleware: 88 | - Uses [markbates/goth](https://github.com/markbates/goth) to support many 89 | different providers. 90 | - Encrypts session cookies, which removes the need for local session storage. 91 | - Uses Go 1.18's generics functionality to provide a custom ID and auth object 92 | resolver. 93 | - No longer have to type assert to your local models! 94 | - Optionally requiring authentication. 95 | - Optionally requiring specific roles. 96 | - Optionally adding authentication info to context for use by children handlers. 97 | - API key validation. 98 | - API version validation. 99 | - Struct/type binding, from get/post data, with support for [go-playground/validator](https://github.com/go-playground/validator). 100 | - Structured logging using [apex/log](https://github.com/apex/log) (same API 101 | as logrus). 102 | - Allows injecting additional metadata into logs. 103 | - Injects logger into context for use by children handlers. 
104 | - Debug middleware: 105 | - Easily let children handlers know if global debug flags are enabled. 106 | - Allows masking errors, unless debugging is enabled. 107 | - Error handler, that automatically handles api-vs-static content responses. 108 | - Supports `ErrorResolver`'s, providing the ability to override status codes 109 | for specific types of errors. 110 | - `go:embed` helpers for mounting an embedded filesystem seamlessly as an http 111 | endpoint. 112 | - Useful for projects that bundle their frontend assets in their binary. 113 | - Supports local filesystem reading, when debugging is enabled (TODO). 114 | - Middleware for robots.txt and security.txt responding. 115 | 116 | ## :zap: Related Libraries 117 | 118 | - [lrstanley/clix](https://github.com/lrstanley/clix) -- go-flags wrapper, that 119 | handles parsing and decoding, with additional helpers. 120 | - [lrstanley/go-query-parser](https://github.com/lrstanley/go-queryparser) -- similar 121 | to that of Google/Github/etc search, a query string parser that allows filters 122 | and tags to be dynamically configured by the end user. 123 | 124 | ## :bulb: Example Projects 125 | 126 | Use these as a reference point for how you might use some of the functionality within 127 | this library, or how you might want to structure your applications. 128 | 129 | - [lrstanley/geoip](https://github.com/lrstanley/geoip) 130 | - [lrstanley/liam.sh](https://github.com/lrstanley/liam.sh) 131 | - [lrstanley/spectrograph](https://github.com/lrstanley/spectrograph) 132 | 133 | 134 | 135 | ## :raising_hand_man: Support & Assistance 136 | 137 | * :heart: Please review the [Code of Conduct](.github/CODE_OF_CONDUCT.md) for 138 | guidelines on ensuring everyone has the best experience interacting with 139 | the community. 140 | * :raising_hand_man: Take a look at the [support](.github/SUPPORT.md) document on 141 | guidelines for tips on how to ask the right questions. 
142 | * :lady_beetle: For all features/bugs/issues/questions/etc, [head over here](https://github.com/lrstanley/chix/issues/new/choose). 143 | 144 | 145 | 146 | 147 | ## :handshake: Contributing 148 | 149 | * :heart: Please review the [Code of Conduct](.github/CODE_OF_CONDUCT.md) for guidelines 150 | on ensuring everyone has the best experience interacting with the 151 | community. 152 | * :clipboard: Please review the [contributing](.github/CONTRIBUTING.md) doc for submitting 153 | issues/a guide on submitting pull requests and helping out. 154 | * :old_key: For anything security related, please review this repositories [security policy](https://github.com/lrstanley/chix/security/policy). 155 | 156 | 157 | 158 | 159 | ## :balance_scale: License 160 | 161 | ``` 162 | MIT License 163 | 164 | Copyright (c) 2022 Liam Stanley 165 | 166 | Permission is hereby granted, free of charge, to any person obtaining a copy 167 | of this software and associated documentation files (the "Software"), to deal 168 | in the Software without restriction, including without limitation the rights 169 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 170 | copies of the Software, and to permit persons to whom the Software is 171 | furnished to do so, subject to the following conditions: 172 | 173 | The above copyright notice and this permission notice shall be included in all 174 | copies or substantial portions of the Software. 175 | 176 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 177 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 178 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 179 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 180 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 181 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 182 | SOFTWARE. 
183 | ``` 184 | 185 | _Also located [here](LICENSE)_ 186 | 187 | -------------------------------------------------------------------------------- /realip.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) Liam Stanley . All rights reserved. Use of 2 | // this source code is governed by the MIT license that can be found in 3 | // the LICENSE file. 4 | 5 | package chix 6 | 7 | import ( 8 | "context" 9 | "fmt" 10 | "net" 11 | "net/http" 12 | "strings" 13 | 14 | "github.com/lrstanley/go-bogon" 15 | ) 16 | 17 | const ( 18 | OptTrustBogon RealIPOptions = 1 << iota // Trust bogon IP ranges (private IP ranges). 19 | OptTrustAny // Trust any proxy (DON'T USE THIS!). 20 | OptTrustCloudflare // Trust Cloudflare's origin IPs. 21 | OptUseXForwardedFor // Allow using the X-Forwarded-For header. 22 | OptUseXRealIP // Allow using the X-Real-IP header. 23 | OptUseTrueClientIP // Allow using the True-Client-IP header. 24 | OptUseCFConnectingIP // Allow using the CF-Connecting-IP header. 25 | 26 | OptDefaultTrust = OptTrustBogon | OptUseXForwardedFor // Default trust options. 27 | 28 | xForwardedFor = "X-Forwarded-For" 29 | xRealIP = "X-Real-IP" 30 | trueClientIP = "True-Client-IP" 31 | ) 32 | 33 | // RealIPOptions is a bitmask of options that can be passed to RealIP. 34 | type RealIPOptions int 35 | 36 | // UseRealIPDefault is a convenience function that wraps RealIP with the default 37 | // options (OptTrustBogon and OptUseXForwardedFor). 38 | func UseRealIPDefault(next http.Handler) http.Handler { 39 | return UseRealIP(nil, OptDefaultTrust)(next) 40 | } 41 | 42 | // UseRealIPCLIOpts is a convenience function that wraps RealIP, with support for 43 | // configuring the middleware via CLI flags. 
You can pass in an array that contains
// a mix of different supported headers, "cloudflare", "*" (or "any", "all") to
// trust anything, "local" (or "localhost", "bogon", "internal") for bogon IPs,
// and anything else gets passed in as allowed CIDRs (or bare IPs). Surrounding
// whitespace is ignored, as are empty entries.
//
// If no options are passed in, or none of them set a flag (e.g. only CIDRs are
// given), the same flags as chix.UseRealIPDefault (OptTrustBogon and
// OptUseXForwardedFor) are used, in addition to any provided CIDRs. Options that
// only add trust (e.g. "local" or "*") without any header option cause UseRealIP
// to panic with ErrRealIPNoSource.
func UseRealIPCLIOpts(options []string) func(next http.Handler) http.Handler {
	if len(options) == 0 {
		return UseRealIPDefault
	}

	var flags RealIPOptions
	var proxies []string

	for _, raw := range options {
		// CLI values are commonly split on "," and may carry stray whitespace.
		option := strings.TrimSpace(raw)
		if option == "" {
			continue
		}

		switch strings.ToLower(option) {
		case "cloudflare", "cf-connecting-ip":
			flags |= OptTrustCloudflare | OptUseCFConnectingIP
		case "x-forwarded-for":
			flags |= OptUseXForwardedFor
		case "x-real-ip":
			flags |= OptUseXRealIP
		case "true-client-ip":
			flags |= OptUseTrueClientIP
		case "*", "any", "all":
			flags |= OptTrustAny
		case "local", "localhost", "bogon", "internal":
			flags |= OptTrustBogon
		default:
			proxies = append(proxies, option)
		}
	}

	if flags == 0 {
		flags = OptDefaultTrust
	}

	return UseRealIP(proxies, flags)
}

// UseRealIP is a middleware that allows passing the real IP address of the client
// only if the request headers that include an override, come from a trusted
// proxy. Pass an optional list of trusted proxies to trust, as well as
// any additional options to control the behavior of the middleware. See the
// related Opt* constants for more information. Will panic if invalid IP's or
// ranges are specified. The trusted slice is never modified.
//
// NOTE: if multiple headers are configured to be trusted, the lookup order is:
//   - CF-Connecting-IP
//   - X-Real-IP
//   - True-Client-IP
//   - X-Forwarded-For
//
// Examples:
//
//	router.Use(chix.UseRealIP([]string{"1.2.3.4", "10.0.0.0/24"}, chix.OptUseXForwardedFor))
//	router.Use(chix.UseRealIP(nil, chix.OptTrustBogon|chix.OptUseXForwardedFor))
func UseRealIP(trusted []string, flags RealIPOptions) func(next http.Handler) http.Handler {
	if flags == 0 {
		panic(ErrRealIPNoOpts)
	}

	// Must provide at least one proxy header type.
	if flags&(OptUseXForwardedFor|OptUseXRealIP|OptUseTrueClientIP|OptUseCFConnectingIP) == 0 {
		panic(ErrRealIPNoSource)
	}

	// Work on a private copy, so appending below can never write into the
	// caller's backing array.
	proxies := make([]string, 0, len(trusted)+2)
	proxies = append(proxies, trusted...)

	// ¯\_(ツ)_/¯. "Any" has to cover both address families: an IPv4 range
	// never contains an IPv6 peer (and vice versa).
	if flags&OptTrustAny != 0 {
		proxies = append(proxies, "0.0.0.0/0", "::/0")
	}

	rip := &realIP{
		trusted: []*net.IPNet{},
	}

	// Add all known bogon IP ranges.
	if flags&OptTrustBogon != 0 {
		rip.trusted = append(rip.trusted, bogon.DefaultRanges()...)
	}

	if flags&OptTrustCloudflare != 0 {
		rip.trusted = append(rip.trusted, cloudflareRanges()...)
	}

	// Parse user-provided CIDRs and/or bare IPs (bare IPs become /32 or /128).
	for _, proxy := range proxies {
		proxy = strings.TrimSpace(proxy)

		if !strings.Contains(proxy, "/") {
			ip := parseIP(proxy)
			if ip == nil {
				panic(&ErrRealIPInvalidIP{Err: &net.ParseError{Type: "IP address", Text: proxy}})
			}

			switch len(ip) {
			case net.IPv4len:
				proxy += "/32"
			case net.IPv6len:
				proxy += "/128"
			}
		}

		_, cidr, err := net.ParseCIDR(proxy)
		if err != nil {
			panic(fmt.Errorf("chix: realip: invalid CIDR %w", err))
		}

		rip.trusted = append(rip.trusted, cidr)
	}

	if len(rip.trusted) == 0 {
		panic(ErrRealIPNoTrusted)
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if clientIP, ok := rip.resolve(r, flags); ok {
				r.RemoteAddr = clientIP
			}

			next.ServeHTTP(w, r)
		})
	}
}

type realIP struct {
	trusted []*net.IPNet
}

// resolve returns the client IP advertised by the highest-priority enabled
// header holding a usable value, but only when the direct peer (r.RemoteAddr)
// is a trusted proxy. ok is false when r.RemoteAddr should be left untouched.
func (rip *realIP) resolve(r *http.Request, flags RealIPOptions) (clientIP string, ok bool) {
	peer, _, err := net.SplitHostPort(strings.TrimSpace(r.RemoteAddr))
	if err != nil || !rip.isTrustedProxy(net.ParseIP(peer)) {
		return "", false
	}

	// Parse enabled headers by most specific (and common) to least.
	if flags&OptUseCFConnectingIP != 0 {
		if ip := parseIP(r.Header.Get("Cf-Connecting-Ip")); ip != nil {
			return ip.String(), true
		}
	}

	if flags&OptUseXRealIP != 0 {
		if ip := parseIP(r.Header.Get(xRealIP)); ip != nil {
			return ip.String(), true
		}
	}

	if flags&OptUseTrueClientIP != 0 {
		if ip := parseIP(r.Header.Get(trueClientIP)); ip != nil {
			return ip.String(), true
		}
	}

	if flags&OptUseXForwardedFor != 0 {
		if value, valid := rip.parseForwardedFor(r.Header.Get(xForwardedFor)); valid && value != "" {
			return value, true
		}
	}

	return "", false
}

// isTrustedProxy will check whether the IP address is included in the trusted
// list according to realIP.trusted. A nil IP is never trusted.
func (rip *realIP) isTrustedProxy(ip net.IP) bool {
	if ip == nil || rip.trusted == nil {
		return false
	}

	for _, cidr := range rip.trusted {
		if cidr.Contains(ip) {
			return true
		}
	}

	return false
}

// parseForwardedFor will parse the X-Forwarded-For header in the proper
// direction (reversed), returning the right-most entry that is not a trusted
// proxy (or the left-most entry, if every hop is trusted). Parsing stops, and
// valid is false, at the first entry that is not an IP address.
func (rip *realIP) parseForwardedFor(value string) (clientIP string, valid bool) {
	if value == "" {
		return "", false
	}

	items := strings.Split(value, ",")

	// X-Forwarded-For is appended to by each proxy. Check IPs in reverse order
	// and stop at the first untrusted one.
	for i := len(items) - 1; i >= 0; i-- {
		raw := strings.TrimSpace(items[i])

		ip := net.ParseIP(raw)
		if ip == nil {
			break
		}

		if i == 0 || !rip.isTrustedProxy(ip) {
			return raw, true
		}
	}

	return "", false
}

// parseIP parses a string representation of an IP (surrounding whitespace is
// ignored) and returns a net.IP in its 4-byte form for IPv4 addresses, the
// 16-byte form otherwise, or nil if the input is invalid.
func parseIP(ip string) net.IP {
	parsedIP := net.ParseIP(strings.TrimSpace(ip))

	if parsedIP != nil {
		if v4 := parsedIP.To4(); v4 != nil {
			return v4
		}
	}

	return parsedIP
}

// UsePrivateIP can be used to allow only private IP's to access specific
// routes. Make sure to register this middleware after UseRealIP, otherwise
// the IP checking may be incorrect. Other requests receive a 403 via chix.Error.
func UsePrivateIP(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if ok, _ := bogon.Is(sanitizeIP(r.RemoteAddr)); ok {
			next.ServeHTTP(w, r)
			return
		}

		_ = Error(w, r, WrapError(ErrAccessDenied, http.StatusForbidden))
	})
}

// UseContextIP can be used to add the request's IP to the context. This is beneficial
// for passing the request context to a request-unaware function/method/service, that
// does not have access to the original request. Ensure that this middleware is
// registered after UseRealIP, otherwise the stored IP may be incorrect.
282 | func UseContextIP(next http.Handler) http.Handler { 283 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 284 | next.ServeHTTP(w, r.WithContext( 285 | context.WithValue( 286 | r.Context(), 287 | contextIP, 288 | parseIP(sanitizeIP(r.RemoteAddr)), 289 | ), 290 | )) 291 | }) 292 | } 293 | 294 | // GetContextIP can be used to retrieve the IP from the context, that was previously 295 | // set by UseContextIP. If no IP was set, nil is returned. 296 | func GetContextIP(ctx context.Context) net.IP { 297 | if ip, ok := ctx.Value(contextIP).(net.IP); ok { 298 | return ip 299 | } 300 | 301 | return nil 302 | } 303 | 304 | func sanitizeIP(input string) (ip string) { 305 | ip, _, err := net.SplitHostPort(strings.TrimSpace(input)) 306 | if err != nil || ip == "" { 307 | ip = input 308 | } 309 | return ip 310 | } 311 | -------------------------------------------------------------------------------- /.golangci.yaml: -------------------------------------------------------------------------------- 1 | # THIS FILE IS GENERATED! DO NOT EDIT! Maintained by Terraform. 
2 | # 3 | # golangci-lint: https://golangci-lint.run/ 4 | # false-positives: https://golangci-lint.run/usage/false-positives/ 5 | # actual source: https://github.com/lrstanley/.github/blob/master/terraform/github-common-files/templates/.golangci.yml 6 | # modified variant of: https://gist.github.com/maratori/47a4d00457a92aa426dbd48a18776322 7 | 8 | version: "2" 9 | 10 | formatters: 11 | enable: [gofumpt] 12 | 13 | issues: 14 | max-issues-per-linter: 0 15 | max-same-issues: 50 16 | 17 | severity: 18 | default: error 19 | rules: 20 | - linters: 21 | - errcheck 22 | - gocritic 23 | severity: warning 24 | 25 | linters: 26 | default: none 27 | enable: 28 | - asasalint # checks for pass []any as any in variadic func(...any) 29 | - asciicheck # checks that your code does not contain non-ASCII identifiers 30 | - bidichk # checks for dangerous unicode character sequences 31 | - bodyclose # checks whether HTTP response body is closed successfully 32 | - canonicalheader # checks whether net/http.Header uses canonical header 33 | - copyloopvar # detects places where loop variables are copied (Go 1.22+) 34 | - depguard # checks if package imports are in a list of acceptable packages 35 | - dupl # tool for code clone detection 36 | - durationcheck # checks for two durations multiplied together 37 | - errcheck # checking for unchecked errors, these unchecked errors can be critical bugs in some cases 38 | - errname # checks that sentinel errors are prefixed with the Err and error types are suffixed with the Error 39 | - errorlint # finds code that will cause problems with the error wrapping scheme introduced in Go 1.13 40 | - exhaustive # checks exhaustiveness of enum switch statements 41 | - exptostd # detects functions from golang.org/x/exp/ that can be replaced by std functions 42 | - fatcontext # detects nested contexts in loops 43 | - forbidigo # forbids identifiers 44 | - funlen # tool for detection of long functions 45 | - gocheckcompilerdirectives # validates go compiler 
directive comments (//go:) 46 | - gochecknoinits # checks that no init functions are present in Go code 47 | - gochecksumtype # checks exhaustiveness on Go "sum types" 48 | - gocognit # computes and checks the cognitive complexity of functions 49 | - goconst # finds repeated strings that could be replaced by a constant 50 | - gocritic # provides diagnostics that check for bugs, performance and style issues 51 | - godot # checks if comments end in a period 52 | - gomoddirectives # manages the use of 'replace', 'retract', and 'excludes' directives in go.mod 53 | - goprintffuncname # checks that printf-like functions are named with f at the end 54 | - gosec # inspects source code for security problems 55 | - govet # reports suspicious constructs, such as Printf calls whose arguments do not align with the format string 56 | - iface # checks the incorrect use of interfaces, helping developers avoid interface pollution 57 | - ineffassign # detects when assignments to existing variables are not used 58 | - intrange # finds places where for loops could make use of an integer range 59 | - loggercheck # checks key value pairs for common logger libraries (kitlog,klog,logr,zap) 60 | - makezero # finds slice declarations with non-zero initial length 61 | - mirror # reports wrong mirror patterns of bytes/strings usage 62 | - misspell # [useless] finds commonly misspelled English words in comments 63 | - musttag # enforces field tags in (un)marshaled structs 64 | - nakedret # finds naked returns in functions greater than a specified function length 65 | - nestif # reports deeply nested if statements 66 | - nilerr # finds the code that returns nil even if it checks that the error is not nil 67 | - nilnesserr # reports that it checks for err != nil, but it returns a different nil value error (powered by nilness and nilerr) 68 | - nilnil # checks that there is no simultaneous return of nil error and an invalid value 69 | - noctx # finds sending http request without context.Context 
70 | - nosprintfhostport # checks for misuse of Sprintf to construct a host with port in a URL 71 | - perfsprint # checks that fmt.Sprintf can be replaced with a faster alternative 72 | - predeclared # finds code that shadows one of Go's predeclared identifiers 73 | - promlinter # checks Prometheus metrics naming via promlint 74 | - reassign # checks that package variables are not reassigned 75 | - recvcheck # checks for receiver type consistency 76 | - revive # fast, configurable, extensible, flexible, and beautiful linter for Go, drop-in replacement of golint 77 | - rowserrcheck # checks whether Err of rows is checked successfully 78 | - sloglint # ensure consistent code style when using log/slog 79 | - sqlclosecheck # checks that sql.Rows and sql.Stmt are closed 80 | - staticcheck # is a go vet on steroids, applying a ton of static analysis checks 81 | - testableexamples # checks if examples are testable (have an expected output) 82 | - testifylint # checks usage of github.com/stretchr/testify 83 | - tparallel # detects inappropriate usage of t.Parallel() method in your Go test codes 84 | - unconvert # removes unnecessary type conversions 85 | - unparam # reports unused function parameters 86 | - unused # checks for unused constants, variables, functions and types 87 | - usestdlibvars # detects the possibility to use variables/constants from the Go standard library 88 | - usetesting # reports uses of functions with replacement inside the testing package 89 | - wastedassign # finds wasted assignment statements 90 | - whitespace # detects leading and trailing whitespace 91 | 92 | settings: 93 | gocognit: 94 | min-complexity: 40 95 | errcheck: 96 | check-type-assertions: true 97 | funlen: 98 | lines: 150 99 | statements: 75 100 | ignore-comments: true 101 | gocritic: 102 | disabled-checks: 103 | - whyNoLint 104 | - hugeParam 105 | - ifElseChain 106 | - singleCaseSwitch 107 | enabled-tags: 108 | - diagnostic 109 | - opinionated 110 | - performance 111 | - style 112 
| settings: 113 | captLocal: 114 | paramsOnly: false 115 | underef: 116 | skipRecvDeref: false 117 | rangeValCopy: 118 | sizeThreshold: 512 119 | depguard: 120 | rules: 121 | "deprecated": 122 | files: ["$all"] 123 | deny: 124 | - pkg: github.com/golang/protobuf 125 | desc: Use google.golang.org/protobuf instead, see https://developers.google.com/protocol-buffers/docs/reference/go/faq#modules 126 | - pkg: github.com/satori/go.uuid 127 | desc: Use github.com/google/uuid instead, satori's package is not maintained 128 | - pkg: github.com/gofrs/uuid$ 129 | desc: Use github.com/gofrs/uuid/v5 or later, it was not a go module before v5 130 | - pkg: github.com/lrstanley/clix$ 131 | desc: Use github.com/lrstanley/clix/v2 instead 132 | - pkg: github.com/lrstanley/chix$ 133 | desc: Use github.com/lrstanley/chix/v2 instead 134 | - pkg: log$ 135 | desc: Use log/slog instead, see https://go.dev/blog/slog 136 | "non-test files": 137 | files: ["!$test"] 138 | deny: 139 | - pkg: math/rand$ 140 | desc: Use math/rand/v2 instead, see https://go.dev/blog/randv2 141 | "incorrect import": 142 | files: ["$test"] 143 | deny: 144 | - pkg: github.com/tj/assert$ 145 | desc: Use github.com/stretchr/testify/assert instead, see 146 | gochecksumtype: 147 | default-signifies-exhaustive: false 148 | exhaustive: 149 | check: 150 | - switch 151 | - map 152 | govet: 153 | disable: 154 | - fieldalignment 155 | enable-all: true 156 | settings: 157 | shadow: 158 | strict: true 159 | perfsprint: 160 | strconcat: false 161 | nakedret: 162 | max-func-lines: 0 163 | nestif: 164 | min-complexity: 10 165 | rowserrcheck: 166 | packages: 167 | - github.com/jmoiron/sqlx 168 | sloglint: 169 | no-global: default 170 | context: scope 171 | msg-style: lowercased 172 | static-msg: true 173 | forbidden-keys: 174 | - time 175 | - level 176 | - source 177 | staticcheck: 178 | checks: 179 | - all 180 | # Incorrect or missing package comment: https://staticcheck.dev/docs/checks/#ST1000 181 | - -ST1000 182 | # Use 
consistent method receiver names: https://staticcheck.dev/docs/checks/#ST1016 183 | - -ST1016 184 | # Omit embedded fields from selector expression: https://staticcheck.dev/docs/checks/#QF1008 185 | - -QF1008 186 | # duplicate struct tags -- used commonly for things like go-flags. 187 | - -SA5008 188 | usetesting: 189 | os-temp-dir: true 190 | exclusions: 191 | warn-unused: true 192 | generated: lax 193 | presets: 194 | - common-false-positives 195 | - std-error-handling 196 | paths: 197 | - ".*\\.gen\\.go$" 198 | - ".*\\.gen_test\\.go$" 199 | rules: 200 | - source: "TODO" 201 | linters: [godot] 202 | - text: "should have a package comment" 203 | linters: [revive] 204 | - text: 'exported \S+ \S+ should have comment( \(or a comment on this block\))? or be unexported' 205 | linters: [revive] 206 | - text: 'package comment should be of the form ".+"' 207 | source: "// ?(nolint|TODO)" 208 | linters: [revive] 209 | - text: 'comment on exported \S+ \S+ should be of the form ".+"' 210 | source: "// ?(nolint|TODO)" 211 | linters: [revive, staticcheck] 212 | - text: 'unexported-return: exported func \S+ returns unexported type \S+ .*' 213 | linters: [revive] 214 | - text: "var-declaration: should drop .* from declaration of .*; it is the zero value" 215 | linters: [revive] 216 | - text: ".*use ALL_CAPS in Go names.*" 217 | linters: [revive, staticcheck] 218 | - text: '.* always receives \S+' 219 | linters: [unparam] 220 | - path: _test\.go 221 | linters: 222 | - bodyclose 223 | - dupl 224 | - funlen 225 | - gocognit 226 | - goconst 227 | - gosec 228 | - noctx 229 | - wrapcheck 230 | -------------------------------------------------------------------------------- /realip_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) Liam Stanley . All rights reserved. Use of 2 | // this source code is governed by the MIT license that can be found in 3 | // the LICENSE file. 
4 | 5 | package chix 6 | 7 | import ( 8 | "net/http" 9 | "net/http/httptest" 10 | "testing" 11 | ) 12 | // testsRealIP holds cases shared by TestUseRealIPCLIOpts (exact RemoteAddr expectations) and FuzzUseRealIPCLIOpts (seed corpus). 13 | var testsRealIP = []struct { 14 | name string 15 | args []string 16 | headers map[string]string 17 | remoteAddr string 18 | wantRealIP string 19 | }{ 20 | { 21 | name: "ipv4:none:bogon:untrusted", 22 | args: []string{"x-forwarded-for", "local"}, 23 | headers: map[string]string{}, 24 | remoteAddr: "1.2.3.4:12345", 25 | wantRealIP: "1.2.3.4:12345", 26 | }, 27 | { 28 | name: "ipv4:x-forwarded-for:bogon:invalid-remote-addr", 29 | args: []string{"x-forwarded-for", "local"}, 30 | headers: map[string]string{"X-Forwarded-For": "1.1.1.1"}, 31 | remoteAddr: "invalid", 32 | wantRealIP: "invalid", 33 | }, 34 | { 35 | name: "ipv4:cf-connecting-ip:cloudflare:trusted", 36 | args: []string{"cloudflare"}, 37 | headers: map[string]string{"CF-Connecting-IP": "1.1.1.1"}, 38 | remoteAddr: "173.245.48.0:12345", 39 | wantRealIP: "1.1.1.1", 40 | }, 41 | { 42 | name: "ipv4:x-forwarded-for:cloudflare:trusted-wrong-header", 43 | args: []string{"cloudflare"}, 44 | headers: map[string]string{"X-Forwarded-For": "1.1.1.1"}, 45 | remoteAddr: "173.245.48.0:12345", 46 | wantRealIP: "173.245.48.0:12345", 47 | }, 48 | { 49 | name: "ipv4:x-forwarded-for:cloudflare:untrusted", 50 | args: []string{"cloudflare"}, 51 | headers: map[string]string{"X-Forwarded-For": "1.1.1.1"}, 52 | remoteAddr: "1.2.3.4:12345", 53 | wantRealIP: "1.2.3.4:12345", 54 | }, 55 | { 56 | name: "ipv4:cf-connecting-ip:cloudflare:untrusted", 57 | args: []string{"cloudflare"}, 58 | headers: map[string]string{"CF-Connecting-IP": "1.1.1.1"}, 59 | remoteAddr: "1.2.3.4:12345", 60 | wantRealIP: "1.2.3.4:12345", 61 | }, 62 | { 63 | name: "ipv4:x-forwarded-for-invalid:bogon:untrusted", 64 | args: []string{"x-forwarded-for", "local"}, 65 | headers: map[string]string{"X-Forwarded-For": "1.1.1.999"}, 66 | remoteAddr: "10.1.2.3:12345", 67 | wantRealIP: "10.1.2.3:12345", 68 | }, 69 | { 70 | name:
"ipv6:x-forwarded-for:bogon:trusted-different-protocol", 71 | args: []string{"x-forwarded-for", "local"}, 72 | headers: map[string]string{"X-Forwarded-For": "1.1.1.1"}, 73 | remoteAddr: "[::1]:12345", 74 | wantRealIP: "1.1.1.1", 75 | }, 76 | { 77 | name: "ipv6:x-forwarded-for:bogon:trusted-same-protocol", 78 | args: []string{"x-forwarded-for", "local"}, 79 | headers: map[string]string{"X-Forwarded-For": "2607:f8b0:4002:c00::8b"}, 80 | remoteAddr: "[::1]:12345", 81 | wantRealIP: "2607:f8b0:4002:c00::8b", 82 | }, 83 | { 84 | name: "ipv6:x-forwarded-for:bogon:untrusted-1", 85 | args: []string{"x-forwarded-for", "local"}, 86 | headers: map[string]string{"X-Forwarded-For": "1.1.1.1"}, 87 | remoteAddr: "[2607:f8b0:4002:c00::8a]:12345", 88 | wantRealIP: "[2607:f8b0:4002:c00::8a]:12345", 89 | }, 90 | { 91 | name: "ipv6:x-forwarded-for:bogon:untrusted-2", 92 | args: []string{"x-forwarded-for", "local"}, 93 | headers: map[string]string{"X-Forwarded-For": "2607:f8b0:4002:c00::8b"}, 94 | remoteAddr: "[2607:f8b0:4002:c00::8a]:12345", 95 | wantRealIP: "[2607:f8b0:4002:c00::8a]:12345", 96 | }, 97 | { 98 | name: "ipv4:x-forwarded-for:bogon:untrusted", 99 | args: []string{"x-forwarded-for", "local"}, 100 | headers: map[string]string{"X-Forwarded-For": "1.1.1.1"}, 101 | remoteAddr: "1.2.3.4:12345", 102 | wantRealIP: "1.2.3.4:12345", 103 | }, 104 | { 105 | name: "ipv4:x-forwarded-for:custom-cidr:untrusted", 106 | args: []string{"x-forwarded-for", "8.8.8.8/32"}, 107 | headers: map[string]string{"X-Forwarded-For": "1.1.1.1"}, 108 | remoteAddr: "1.2.3.4:12345", 109 | wantRealIP: "1.2.3.4:12345", 110 | }, 111 | { 112 | name: "ipv4:x-forwarded-for:custom-cidr-multiple-1:trusted", 113 | args: []string{"x-forwarded-for", "8.0.0.0/8", "9.0.0.0/8"}, 114 | headers: map[string]string{"X-Forwarded-For": "1.1.1.1"}, 115 | remoteAddr: "8.1.2.3:12345", 116 | wantRealIP: "1.1.1.1", 117 | }, 118 | { 119 | name: "ipv4:x-forwarded-for:custom-cidr-multiple-2:trusted", 120 | args:
[]string{"x-forwarded-for", "8.0.0.0/8", "9.0.0.0/8"}, 121 | headers: map[string]string{"X-Forwarded-For": "1.1.1.1"}, 122 | remoteAddr: "9.1.2.3:12345", 123 | wantRealIP: "1.1.1.1", 124 | }, 125 | { 126 | name: "ipv4:x-forwarded-for:custom-cidr-multiple-3:untrusted", 127 | args: []string{"x-forwarded-for", "8.0.0.0/8", "9.0.0.0/8"}, 128 | headers: map[string]string{"X-Forwarded-For": "1.1.1.1"}, 129 | remoteAddr: "10.1.2.3:12345", 130 | wantRealIP: "10.1.2.3:12345", 131 | }, 132 | { 133 | name: "ipv4:x-forwarded-for:custom-cidr:trusted", 134 | args: []string{"x-forwarded-for", "8.8.8.8/32"}, 135 | headers: map[string]string{"X-Forwarded-For": "1.1.1.1"}, 136 | remoteAddr: "8.8.8.8:12345", 137 | wantRealIP: "1.1.1.1", 138 | }, 139 | { 140 | name: "ipv4:x-forwarded-for:one-ip:trusted", 141 | args: []string{"x-forwarded-for", "8.8.8.8"}, 142 | headers: map[string]string{"X-Forwarded-For": "1.1.1.1"}, 143 | remoteAddr: "8.8.8.8:12345", 144 | wantRealIP: "1.1.1.1", 145 | }, 146 | { 147 | name: "ipv6:x-forwarded-for:one-ip:trusted", 148 | args: []string{"x-forwarded-for", "::1"}, 149 | headers: map[string]string{"X-Forwarded-For": "1.1.1.1"}, 150 | remoteAddr: "[::1]:12345", 151 | wantRealIP: "1.1.1.1", 152 | }, 153 | { 154 | name: "ipv4:x-forwarded-for:all:trusted", 155 | args: []string{"x-forwarded-for", "all"}, 156 | headers: map[string]string{"X-Forwarded-For": "1.1.1.1"}, 157 | remoteAddr: "8.8.8.8:12345", 158 | wantRealIP: "1.1.1.1", 159 | }, 160 | { 161 | name: "ipv4:x-forwarded-for:bogon:trusted", 162 | args: []string{"x-forwarded-for", "local"}, 163 | headers: map[string]string{"X-Forwarded-For": "1.1.1.1"}, 164 | remoteAddr: "10.1.1.1:12345", 165 | wantRealIP: "1.1.1.1", 166 | }, 167 | { 168 | name: "ipv4:x-forwarded-for-multiple:bogon:trusted", 169 | args: []string{"x-forwarded-for", "local"}, 170 | headers: map[string]string{"X-Forwarded-For": "1.1.1.1,2.2.2.2"}, 171 | remoteAddr: "10.1.1.1:12345", 172 | wantRealIP: "2.2.2.2", 173 | }, 174 | { 175 | name:
"ipv4:x-forwarded-for:bogon:x-real-ip", 176 | args: []string{"x-forwarded-for", "local"}, 177 | headers: map[string]string{"X-Real-IP": "1.1.1.1"}, 178 | remoteAddr: "1.2.3.4:12345", 179 | wantRealIP: "1.2.3.4:12345", 180 | }, 181 | { 182 | name: "ipv4:x-real-ip:bogon:untrusted", 183 | args: []string{"x-real-ip", "local"}, 184 | headers: map[string]string{"X-Real-IP": "1.1.1.1"}, 185 | remoteAddr: "1.2.3.4:12345", 186 | wantRealIP: "1.2.3.4:12345", 187 | }, 188 | { 189 | name: "ipv4:x-real-ip:bogon:trusted", 190 | args: []string{"x-real-ip", "local"}, 191 | headers: map[string]string{"X-Real-IP": "1.1.1.1"}, 192 | remoteAddr: "10.1.1.1:12345", 193 | wantRealIP: "1.1.1.1", 194 | }, 195 | { 196 | name: "ipv4:true-client-ip:bogon:untrusted", 197 | args: []string{"true-client-ip", "local"}, 198 | headers: map[string]string{"True-Client-IP": "1.1.1.1"}, 199 | remoteAddr: "1.2.3.4:12345", 200 | wantRealIP: "1.2.3.4:12345", 201 | }, 202 | { 203 | name: "ipv4:true-client-ip:bogon:trusted", 204 | args: []string{"true-client-ip", "local"}, 205 | headers: map[string]string{"True-Client-IP": "1.1.1.1"}, 206 | remoteAddr: "10.1.1.1:12345", 207 | wantRealIP: "1.1.1.1", 208 | }, 209 | } 210 | // FuzzUseRealIPCLIOpts feeds arbitrary X-Forwarded-For values through a trust-all UseRealIPCLIOpts to check that the middleware and IP parsing do not panic. 211 | func FuzzUseRealIPCLIOpts(f *testing.F) { 212 | for _, tt := range testsRealIP { 213 | for _, v := range tt.headers { 214 | f.Add(v) 215 | } 216 | f.Add(tt.wantRealIP) 217 | f.Add(tt.remoteAddr) 218 | } 219 | 220 | f.Fuzz(func(t *testing.T, data string) { 221 | req := httptest.NewRequest(http.MethodGet, "http://example.com", http.NoBody) 222 | req.RemoteAddr = "1.2.3.4:12345" 223 | req.Header.Set("X-Forwarded-For", data) 224 | 225 | handler := UseRealIPCLIOpts( 226 | []string{"x-forwarded-for", "all"}, 227 | )(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 228 | _ = parseIP(sanitizeIP(r.RemoteAddr)) 229 | })) 230 | 231 | handler.ServeHTTP(httptest.NewRecorder(), req) 232 | }) 233 | } 234 | // TestUseRealIPCLIOpts checks the RemoteAddr seen by the next handler for every testsRealIP case. 235 | func TestUseRealIPCLIOpts(t *testing.T) { 236 | for _, tt := range
testsRealIP { 237 | t.Run(tt.name, func(t *testing.T) { 238 | req := httptest.NewRequest(http.MethodGet, "http://example.com", http.NoBody) 239 | req.RemoteAddr = tt.remoteAddr 240 | 241 | for k, v := range tt.headers { 242 | req.Header.Set(k, v) 243 | } 244 | 245 | handler := UseRealIPCLIOpts(tt.args)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 246 | if r.RemoteAddr != tt.wantRealIP { 247 | t.Errorf("UseRealIPCLIOpts() = %v, want %v", r.RemoteAddr, tt.wantRealIP) 248 | } 249 | })) 250 | 251 | handler.ServeHTTP(httptest.NewRecorder(), req) 252 | }) 253 | } 254 | } 255 | // TestUsePrivateIP checks that UsePrivateIP only lets private or loopback client addresses through, and that UseContextIP exposes the parsed client IP via GetContextIP. 256 | func TestUsePrivateIP(t *testing.T) { 257 | tests := []struct { 258 | name string 259 | allowed bool 260 | remoteAddr string 261 | statusCode int 262 | }{ 263 | { 264 | name: "ipv4:private", 265 | allowed: true, 266 | remoteAddr: "10.1.2.3:12345", 267 | statusCode: http.StatusOK, 268 | }, 269 | { 270 | name: "ipv4:not-private", 271 | allowed: false, 272 | remoteAddr: "1.1.1.1:12345", 273 | statusCode: http.StatusForbidden, 274 | }, 275 | { 276 | name: "ipv6:private", 277 | allowed: true, 278 | remoteAddr: "[::1]:12345", 279 | statusCode: http.StatusOK, 280 | }, 281 | { 282 | name: "ipv6:not-private", 283 | allowed: false, 284 | remoteAddr: "[2001:4860:4860::8888]:12345", 285 | statusCode: http.StatusForbidden, 286 | }, 287 | } 288 | 289 | for _, tt := range tests { 290 | t.Run(tt.name, func(t *testing.T) { 291 | req := httptest.NewRequest(http.MethodGet, "http://example.com", http.NoBody) 292 | req.RemoteAddr = tt.remoteAddr 293 | 294 | // Also test UseContextIP/GetContextIP.
295 | handler := UseContextIP(UsePrivateIP(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 296 | if !tt.allowed { 297 | t.Errorf("UsePrivateIP() = %v but allowed (true), want %v", r.RemoteAddr, tt.allowed) 298 | } 299 | 300 | if !GetContextIP(r.Context()).Equal(parseIP(sanitizeIP(r.RemoteAddr))) { 301 | t.Errorf("GetContextIP() = %v, want %v", GetContextIP(r.Context()), parseIP(sanitizeIP(r.RemoteAddr))) 302 | } 303 | 304 | w.WriteHeader(http.StatusOK) 305 | }))) 306 | 307 | rec := httptest.NewRecorder() 308 | handler.ServeHTTP(rec, req) 309 | 310 | if res := rec.Result(); res.StatusCode != tt.statusCode { 311 | t.Errorf("UsePrivateIP() returned status %v, want %v", res.StatusCode, tt.statusCode) 312 | } 313 | }) 314 | } 315 | } 316 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/apex/log v1.9.0 h1:FHtw/xuaM8AgmvDDTI9fiwoAL25Sq2cxojnZICUU8l0= 2 | github.com/apex/log v1.9.0/go.mod h1:m82fZlWIuiWzWP04XCTXmnX0xRkYYbCdYn8jbJeLBEA= 3 | github.com/apex/logs v1.0.0/go.mod h1:XzxuLZ5myVHDy9SAmYpamKKRNApGj54PfYLcFrXqDwo= 4 | github.com/aphistic/golf v0.0.0-20180712155816-02c07f170c5a/go.mod h1:3NqKYiepwy8kCu4PNA+aP7WUV72eXWJeP9/r3/K9aLE= 5 | github.com/aphistic/sweet v0.2.0/go.mod h1:fWDlIh/isSE9n6EPsRmC0det+whmX6dJid3stzu0Xys= 6 | github.com/aws/aws-sdk-go v1.20.6/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= 7 | github.com/aybabtme/rgbterm v0.0.0-20170906152045-cc83f3b3ce59/go.mod h1:q/89r3U2H7sSsE2t6Kca0lfwTK8JdoNGS/yzM/4iH5I= 8 | github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= 9 | github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= 10 | github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= 11 | github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= 12 | github.com/davecgh/go-spew v1.1.0/go.mod
h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 13 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 14 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 15 | github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= 16 | github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= 17 | github.com/gabriel-vasile/mimetype v1.4.9 h1:5k+WDwEsD9eTLL8Tz3L0VnmVh9QxGjRmjBvAG7U/oYY= 18 | github.com/gabriel-vasile/mimetype v1.4.9/go.mod h1:WnSQhFKJuBlRyLiKohA/2DtIlPFAbguNaG7QCHcyGok= 19 | github.com/go-chi/chi/v5 v5.2.2 h1:CMwsvRVTbXVytCk1Wd72Zy1LAsAh9GxMmSNWLHCG618= 20 | github.com/go-chi/chi/v5 v5.2.2/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops= 21 | github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= 22 | github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= 23 | github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= 24 | github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= 25 | github.com/go-playground/form/v4 v4.2.1 h1:HjdRDKO0fftVMU5epjPW2SOREcZ6/wLUzEobqUGJuPw= 26 | github.com/go-playground/form/v4 v4.2.1/go.mod h1:q1a2BY+AQUUzhl6xA/6hBetay6dEIhMHjgvJiGo6K7U= 27 | github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= 28 | github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= 29 | github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= 30 | github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= 31 | github.com/go-playground/validator/v10 v10.26.0 h1:SP05Nqhjcvz81uJaRfEV0YBSSSGMc/iMaVtFbr3Sw2k= 32 | github.com/go-playground/validator/v10 v10.26.0/go.mod 
h1:I5QpIEbmr8On7W0TktmJAumgzX4CA1XNl4ZmDuVHKKo= 33 | github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= 34 | github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= 35 | github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= 36 | github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= 37 | github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= 38 | github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= 39 | github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 40 | github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= 41 | github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= 42 | github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= 43 | github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= 44 | github.com/gorilla/sessions v1.4.0 h1:kpIYOp/oi6MG/p5PgxApU8srsSw9tuFbt46Lt7auzqQ= 45 | github.com/gorilla/sessions v1.4.0/go.mod h1:FLWm50oby91+hl7p/wRxDth9bWSuk0qVL2emc7lT5ik= 46 | github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= 47 | github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= 48 | github.com/jpillora/backoff v0.0.0-20180909062703-3050d21c67d7/go.mod h1:2iMrUgbbvHEiQClaW2NsSzMyGHqN+rDFqY705q49KG0= 49 | github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= 50 | github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= 51 | github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= 52 | github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= 53 | github.com/leodido/go-urn v1.4.0 
h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= 54 | github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= 55 | github.com/lrstanley/go-bogon v1.0.0 h1:EhFN3Bu+59u9g8n+xWCzRlbSfNKqAbhsjzW8lcMrG0g= 56 | github.com/lrstanley/go-bogon v1.0.0/go.mod h1:1H1sGTRZ05IO1sQHKLAQQ34v19KrQeYg2Ix9HgJuFXQ= 57 | github.com/markbates/goth v1.81.0 h1:XVcCkeGWokynPV7MXvgb8pd2s3r7DS40P7931w6kdnE= 58 | github.com/markbates/goth v1.81.0/go.mod h1:+6z31QyUms84EHmuBY7iuqYSxyoN3njIgg9iCF/lR1k= 59 | github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= 60 | github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= 61 | github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= 62 | github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= 63 | github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE= 64 | github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= 65 | github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= 66 | github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= 67 | github.com/onsi/gomega v1.5.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= 68 | github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 69 | github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= 70 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 71 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 72 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 73 | github.com/prometheus/client_golang v1.22.0 h1:rb93p9lokFEsctTys46VnV1kLCDpVZ0a/Y92Vm0Zc6Q= 74 | 
github.com/prometheus/client_golang v1.22.0/go.mod h1:R7ljNsLXhuQXYZYtw6GAE9AZg8Y7vEW5scdCXrWRXC0= 75 | github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= 76 | github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE= 77 | github.com/prometheus/common v0.65.0 h1:QDwzd+G1twt//Kwj/Ww6E9FQq1iVMmODnILtW1t2VzE= 78 | github.com/prometheus/common v0.65.0/go.mod h1:0gZns+BLRQ3V6NdaerOhMbwwRbNh9hkGINtQAsP5GS8= 79 | github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg= 80 | github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is= 81 | github.com/rogpeppe/fastuuid v1.1.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= 82 | github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= 83 | github.com/smartystreets/assertions v1.0.0/go.mod h1:kHHU4qYBaI3q23Pp3VPrmWhuIUrLW/7eUrw0BU5VaoM= 84 | github.com/smartystreets/go-aws-auth v0.0.0-20180515143844-0c1422d1fdb9/go.mod h1:SnhjPscd9TpLiy1LpzGSKh3bXCfxxXuqd9xmQJy3slM= 85 | github.com/smartystreets/gunit v1.0.0/go.mod h1:qwPWnhz6pn0NnRBP++URONOVyNkPyr4SauJk4cUOwJs= 86 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 87 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 88 | github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 89 | github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= 90 | github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 91 | github.com/tj/assert v0.0.0-20171129193455-018094318fb0/go.mod h1:mZ9/Rh9oLWpLLDRpvE+3b7gP/C2YyLFYxNmcLnPTMe0= 92 | github.com/tj/assert v0.0.3 h1:Df/BlaZ20mq6kuai7f5z2TvPFiwC3xaWJSDQNiIS3Rk= 93 | github.com/tj/assert v0.0.3/go.mod h1:Ne6X72Q+TB1AteidzQncjw9PabbMp4PBMZ1k+vd1Pvk= 94 | github.com/tj/go-buffer v1.1.0/go.mod 
h1:iyiJpfFcR2B9sXu7KvjbT9fpM4mOelRSDTbntVj52Uc= 95 | github.com/tj/go-elastic v0.0.0-20171221160941-36157cbbebc2/go.mod h1:WjeM0Oo1eNAjXGDx2yma7uG2XoyRZTq1uv3M/o7imD0= 96 | github.com/tj/go-kinesis v0.0.0-20171128231115-08b17f58cb1b/go.mod h1:/yhzCV0xPfx6jb1bBgRFjl5lytqVqZXEaeqWP8lTEao= 97 | github.com/tj/go-spin v1.1.0/go.mod h1:Mg1mzmePZm4dva8Qz60H2lHwmJ2loum4VIrLgVnKwh4= 98 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 99 | golang.org/x/crypto v0.0.0-20190426145343-a29dc8fdc734/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= 100 | golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM= 101 | golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U= 102 | golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= 103 | golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 104 | golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= 105 | golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw= 106 | golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA= 107 | golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= 108 | golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= 109 | golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 110 | golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8= 111 | golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= 112 | golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 113 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 114 | golang.org/x/sys 
v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 115 | golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 116 | golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= 117 | golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= 118 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 119 | golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= 120 | golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M= 121 | golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA= 122 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 123 | google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= 124 | google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= 125 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 126 | gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 127 | gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= 128 | gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= 129 | gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 130 | gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 131 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 132 | gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 133 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 134 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 135 
| -------------------------------------------------------------------------------- /auth.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) Liam Stanley . All rights reserved. Use of 2 | // this source code is governed by the MIT license that can be found in 3 | // the LICENSE file. 4 | 5 | package chix 6 | 7 | import ( 8 | "context" 9 | "encoding/hex" 10 | "fmt" 11 | "net/http" 12 | "strconv" 13 | "strings" 14 | "sync" 15 | 16 | "github.com/go-chi/chi/v5" 17 | "github.com/gorilla/sessions" 18 | "github.com/markbates/goth" 19 | "github.com/markbates/goth/gothic" 20 | ) 21 | 22 | var ( 23 | // DefaulltCookieMaxAge is the max age, in seconds, for the session cookie (30 days). The misspelled name is part of the exported API. 24 | DefaulltCookieMaxAge = 30 * 86400 25 | 26 | gothInit sync.Once 27 | // CookieStoreHook, if non-nil, is called once with the session cookie store after the chix defaults are applied and before it is installed as gothic.Store. 28 | CookieStoreHook = func(_ *sessions.CookieStore) {} 29 | ) 30 | // initGothStore hex-decodes the keys on every call (panicking if either is invalid), but only the first call installs the session cookie store used by gothic. 31 | func initGothStore(authKey, encryptKey string) { 32 | authKeyBytes, err := hex.DecodeString(authKey) 33 | var encryptKeyBytes []byte // Left nil when encryptKey is empty so cookies are signed but not encrypted; an empty non-nil key would make securecookie fail every Encode/Decode. 34 | if err == nil && encryptKey != "" { 35 | encryptKeyBytes, err = hex.DecodeString(encryptKey) 36 | } 37 | if err != nil { 38 | panic(err) 39 | } 40 | 41 | gothInit.Do(func() { 42 | authStore := sessions.NewCookieStore(authKeyBytes, encryptKeyBytes) 43 | authStore.MaxAge(DefaulltCookieMaxAge) 44 | authStore.Options.Path = "/" 45 | authStore.Options.HttpOnly = true 46 | authStore.Options.SameSite = http.SameSiteLaxMode 47 | authStore.Options.Partitioned = true 48 | if CookieStoreHook != nil { 49 | CookieStoreHook(authStore) 50 | } 51 | gothic.Store = authStore 52 | }) 53 | } 54 | // AuthServiceReader is the subset of AuthService without Set: it looks up an identity and its roles by ID. 55 | type AuthServiceReader[Ident any, ID comparable] interface { 56 | Get(context.Context, ID) (*Ident, error) 57 | Roles(context.Context, ID) ([]string, error) 58 | } 59 | 60 | // AuthService is the interface for the authentication service. This will 61 | // need to be implemented to utilize AuthHandler.
62 | type AuthService[Ident any, ID comparable] interface { 63 | Get(context.Context, ID) (*Ident, error) 64 | Set(context.Context, *goth.User) (ID, error) 65 | Roles(context.Context, ID) ([]string, error) 66 | } 67 | 68 | // NewAuthHandler creates a new AuthHandler. authKey is used to validate the 69 | // session cookie. encryptKey is used to encrypt the session cookie. 70 | // 71 | // It is recommended to use an authentication key with 32 or 64 bytes. The 72 | // encryption key, if set, must be either 16, 24, or 32 bytes to select 73 | // AES-128, AES-192, or AES-256 modes. Provide the keys in hexadecimal string 74 | // format. The following link can be used to generate a random key: 75 | // - https://go.dev/play/p/xwcJmQNU8ku 76 | // 77 | // The following endpoints are implemented: 78 | // - GET: /self - returns the current user authentication info. 79 | // - GET: /providers - returns a list of all available providers. 80 | // - GET: /providers/{provider} - initiates the provider authentication. 81 | // - GET: /providers/{provider}/callback - redirect target from the provider. 82 | // - GET: /logout - logs the user out.
83 | func NewAuthHandler[Ident any, ID comparable]( 84 | auth AuthService[Ident, ID], 85 | authKey, encryptKey string, 86 | ) *AuthHandler[Ident, ID] { 87 | initGothStore(authKey, encryptKey) 88 | 89 | h := &AuthHandler[Ident, ID]{ 90 | Auth: auth, 91 | Ident: new(Ident), 92 | ID: new(ID), 93 | errorHandler: Error, 94 | } 95 | 96 | router := chi.NewRouter() 97 | router.With(UseAuthContext(auth), UseAuthRequired[Ident]).Get("/self", h.self) 98 | router.Get("/providers", h.providers) 99 | router.Get("/providers/{provider}", h.provider) 100 | router.Get("/providers/{provider}/callback", h.callback) 101 | router.Get("/logout", h.logout) 102 | h.router = router 103 | 104 | AddLogHandler(func(r *http.Request) M { 105 | id := getAuthIDFromSession[ID](r) 106 | if id == nil { 107 | return M{"user_id": nil} 108 | } 109 | return M{"user_id": *id} 110 | }) 111 | 112 | return h 113 | } 114 | 115 | // AuthHandler wraps all authentication logic for oauth calls. 116 | type AuthHandler[Ident any, ID comparable] struct { 117 | Auth AuthService[Ident, ID] 118 | Ident *Ident 119 | ID *ID 120 | router http.Handler 121 | errorHandler ErrorHandler 122 | } 123 | 124 | // SetErrorHandler sets the error handler for AuthHandler. This error handler will 125 | // only be used for errors that occur within the callback process, NOT for middleware, 126 | // in which chix.Error() will still be used. 127 | func (h *AuthHandler[Ident, ID]) SetErrorHandler(handler ErrorHandler) { 128 | h.errorHandler = handler 129 | } 130 | 131 | // ServeHTTP implements http.Handler.
132 | func (h *AuthHandler[Ident, ID]) ServeHTTP(w http.ResponseWriter, r *http.Request) { 133 | h.router.ServeHTTP(w, r) 134 | } 135 | // providers responds with the names of all registered goth providers. 136 | func (h *AuthHandler[Ident, ID]) providers(w http.ResponseWriter, r *http.Request) { 137 | providers := goth.GetProviders() 138 | data := make([]string, 0, len(providers)) // Non-nil: with encoding/json (presumably what JSON uses) a nil slice would marshal as null instead of []. 139 | for _, p := range providers { 140 | data = append(data, p.Name()) 141 | } 142 | 143 | JSON(w, r, http.StatusOK, M{"providers": data}) 144 | } 145 | // provider starts authentication with the provider named by the {provider} URL parameter. 146 | func (h *AuthHandler[Ident, ID]) provider(w http.ResponseWriter, r *http.Request) { 147 | gothic.BeginAuthHandler(w, gothic.GetContextWithProvider(r, chi.URLParam(r, "provider"))) 148 | } 149 | // callback completes provider authentication, persists the user via Auth.Set, stores the returned ID in the session, and redirects to "/". 150 | func (h *AuthHandler[Ident, ID]) callback(w http.ResponseWriter, r *http.Request) { 151 | guser, err := gothic.CompleteUserAuth(w, gothic.GetContextWithProvider(r, chi.URLParam(r, "provider"))) 152 | if err != nil { 153 | h.errorHandler(w, r, err) 154 | return 155 | } 156 | 157 | id, err := h.Auth.Set(r.Context(), &guser) 158 | if err != nil { 159 | h.errorHandler(w, r, err) 160 | return 161 | } 162 | 163 | if err = gothic.StoreInSession(authSessionKey, fmt.Sprintf("%v", id), r, w); err != nil { 164 | h.errorHandler(w, r, err) 165 | return 166 | } 167 | SecureRedirect(w, r, http.StatusTemporaryRedirect, "/") 168 | } 169 | // logout calls gothic.Logout (ignoring its error) and redirects to "/". 170 | func (h *AuthHandler[Ident, ID]) logout(w http.ResponseWriter, r *http.Request) { 171 | _ = gothic.Logout(w, r) 172 | SecureRedirect(w, r, http.StatusFound, "/") 173 | } 174 | // self responds with the identity stored in the request context. 175 | func (h *AuthHandler[Ident, ID]) self(w http.ResponseWriter, r *http.Request) { 176 | JSON(w, r, http.StatusOK, M{"auth": IdentFromContext[Ident](r.Context())}) 177 | } 178 | 179 | // Deprecated: use [IdentFromContext] instead. 180 | func (h *AuthHandler[Ident, ID]) FromContext(ctx context.Context) (auth *Ident) { 181 | return IdentFromContext[Ident](ctx) 182 | } 183 | 184 | // Deprecated: use [RolesFromContext] instead.
185 | func (h *AuthHandler[Ident, ID]) RolesFromContext(ctx context.Context) (roles AuthRoles) { 186 | return RolesFromContext(ctx) 187 | } 188 | 189 | // Deprecated: use [UseAuthContext] instead. 190 | func (h *AuthHandler[Ident, ID]) AddToContext(next http.Handler) http.Handler { 191 | return UseAuthContext(h.Auth)(next) 192 | } 193 | 194 | // Deprecated: use [UseAuthRequired] instead. 195 | func (h *AuthHandler[Ident, ID]) AuthRequired(next http.Handler) http.Handler { 196 | return UseAuthRequired[Ident](next) 197 | } 198 | 199 | // Deprecated: use [UseRoleRequired] instead. 200 | func (h *AuthHandler[Ident, ID]) RoleRequired(role string) func(http.Handler) http.Handler { 201 | return UseRoleRequired[ID](role) 202 | } 203 | // BasicAuthService must be implemented to use NewBasicAuthHandler. BasicAuth receives the username and password from the request; Get and Roles presumably receive the username that login stores in the session. 204 | type BasicAuthService[Ident any] interface { 205 | BasicAuth(context.Context, string, string) (*Ident, error) 206 | Get(context.Context, string) (*Ident, error) 207 | Roles(context.Context, string) ([]string, error) 208 | } 209 | 210 | // NewBasicAuthHandler creates a new BasicAuthHandler. authKey is used to validate the 211 | // session cookie. encryptKey is used to encrypt the session cookie. 212 | // 213 | // It is recommended to use an authentication key with 32 or 64 bytes. The 214 | // encryption key, if set, must be either 16, 24, or 32 bytes to select 215 | // AES-128, AES-192, or AES-256 modes. Provide the keys in hexadecimal string 216 | // format. The following link can be used to generate a random key: 217 | // - https://go.dev/play/p/xwcJmQNU8ku 218 | // 219 | // The following endpoints are implemented: 220 | // - GET: /self - returns the current user authentication info. 221 | // - GET: /login - authenticates the request via HTTP basic auth and stores the username in the session. 222 | // - GET: /logout - logs the user out.
223 | func NewBasicAuthHandler[Ident any]( 224 | auth BasicAuthService[Ident], 225 | authKey, encryptKey string, 226 | ) *BasicAuthHandler[Ident] { 227 | initGothStore(authKey, encryptKey) 228 | 229 | h := &BasicAuthHandler[Ident]{ 230 | Auth: auth, 231 | Ident: new(Ident), 232 | errorHandler: Error, 233 | } 234 | 235 | router := chi.NewRouter() 236 | router.With(UseAuthContext(auth), UseAuthRequired[Ident]).Get("/self", h.self) 237 | router.Get("/login", h.login) 238 | router.Get("/logout", h.logout) 239 | h.router = router 240 | 241 | AddLogHandler(func(r *http.Request) M { 242 | id := getAuthIDFromSession[string](r) 243 | if id == nil { 244 | return M{"user_id": nil} 245 | } 246 | return M{"user_id": *id} 247 | }) 248 | 249 | return h 250 | } 251 | 252 | // BasicAuthHandler wraps all authentication logic for basic auth calls. 253 | type BasicAuthHandler[Ident any] struct { 254 | Auth BasicAuthService[Ident] 255 | Ident *Ident 256 | router http.Handler 257 | errorHandler ErrorHandler 258 | } 259 | 260 | // SetErrorHandler sets the error handler for BasicAuthHandler. This error handler will 261 | // only be used for errors that occur within the callback process, NOT for middleware, 262 | // in which chix.Error() will still be used. 263 | func (h *BasicAuthHandler[Ident]) SetErrorHandler(handler ErrorHandler) { 264 | h.errorHandler = handler 265 | } 266 | 267 | // ServeHTTP implements http.Handler. 268 | func (h *BasicAuthHandler[Ident]) ServeHTTP(w http.ResponseWriter, r *http.Request) { 269 | h.router.ServeHTTP(w, r) 270 | } 271 | // login authenticates the request via HTTP basic auth and stores the username in the session. NOTE(review): /login is not wrapped in UseAuthContext here, so the already-logged-in check below only succeeds if an upstream middleware populated the context; confirm this is intended. 272 | func (h *BasicAuthHandler[Ident]) login(w http.ResponseWriter, r *http.Request) { 273 | // Check if they've already logged in.
274 | _, ok := r.Context().Value(contextAuth).(*Ident) 275 | if ok { 276 | SecureRedirect(w, r, http.StatusTemporaryRedirect, "/") 277 | return 278 | } 279 | 280 | user, pass, ok := r.BasicAuth() 281 | if !ok { 282 | w.Header().Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`) 283 | _ = Error(w, r, WrapError(ErrAuthNotFound, http.StatusUnauthorized)) 284 | return 285 | } 286 | 287 | _, err := h.Auth.BasicAuth(r.Context(), user, pass) 288 | if err != nil { 289 | _ = Error(w, r, WrapError(err, http.StatusUnauthorized)) 290 | return 291 | } 292 | 293 | if err = gothic.StoreInSession(authSessionKey, user, r, w); err != nil { 294 | h.errorHandler(w, r, err) 295 | return 296 | } 297 | SecureRedirect(w, r, http.StatusTemporaryRedirect, "/") 298 | } 299 | 300 | func (h *BasicAuthHandler[Ident]) logout(w http.ResponseWriter, r *http.Request) { 301 | _ = gothic.Logout(w, r) 302 | SecureRedirect(w, r, http.StatusFound, "/") 303 | } 304 | 305 | func (h *BasicAuthHandler[Ident]) self(w http.ResponseWriter, r *http.Request) { 306 | JSON(w, r, http.StatusOK, M{"auth": IdentFromContext[Ident](r.Context())}) 307 | } 308 | 309 | // OverrideContextAuth overrides the authentication information in the request, 310 | // and returns a new context with the updated information. This is useful for 311 | // when you want to temporarily override the authentication information in the 312 | // request, such as when you want to impersonate another user, or for mocking in 313 | // tests. 314 | func OverrideContextAuth[Ident any, ID comparable](parent context.Context, id ID, ident *Ident, roles []string) context.Context { 315 | ctx := context.WithValue(parent, contextAuth, ident) 316 | ctx = context.WithValue(ctx, contextAuthID, id) 317 | ctx = context.WithValue(ctx, contextAuthRoles, roles) 318 | return ctx 319 | } 320 | 321 | // UseAuthContext adds the user authentication info to the request context, using 322 | // the cookie session information.
If used more than once in the same request 323 | // middleware chain, it will be a no-op. 324 | func UseAuthContext[Ident any, ID comparable, Service AuthServiceReader[Ident, ID]](auth Service) func(next http.Handler) http.Handler { 325 | return func(next http.Handler) http.Handler { 326 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 327 | _, ok := r.Context().Value(contextAuth).(*Ident) 328 | if ok { // Already in the context. 329 | next.ServeHTTP(w, r) 330 | return 331 | } 332 | 333 | id := getAuthIDFromSession[ID](r) 334 | if id == nil { 335 | next.ServeHTTP(w, r) 336 | return 337 | } 338 | 339 | ident, err := auth.Get(r.Context(), *id) 340 | if err != nil { 341 | Log(r).WithError(err).WithField("user_id", *id).Warn("failed to get ident from session (but id set)") 342 | next.ServeHTTP(w, r) 343 | return 344 | } 345 | 346 | ctx := context.WithValue(r.Context(), contextAuth, ident) 347 | ctx = context.WithValue(ctx, contextAuthID, *id) 348 | 349 | roles, err := auth.Roles(r.Context(), *id) 350 | if err != nil { 351 | Log(r).WithError(err).WithField("user_id", *id).Warn("failed to get roles from session (but id set)") 352 | } else { 353 | ctx = context.WithValue(ctx, contextAuthRoles, roles) 354 | } 355 | 356 | next.ServeHTTP(w, r.WithContext(ctx)) 357 | }) 358 | } 359 | } 360 | 361 | // UseRoleRequired is a middleware that requires the user to have the given roles, 362 | // provided via AuthService or BasicAuthService. Note that this requires the 363 | // [UseAuthContext] middleware to be loaded prior to this middleware. 
364 | func UseRoleRequired[ID comparable](role string) func(http.Handler) http.Handler { 365 | return func(next http.Handler) http.Handler { 366 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 367 | id := getAuthIDFromSession[ID](r) 368 | if id == nil { 369 | if role == "anonymous" { 370 | next.ServeHTTP(w, r) 371 | return 372 | } 373 | 374 | _ = Error(w, r, WrapError(ErrAuthMissingRole, http.StatusUnauthorized)) 375 | return 376 | } 377 | 378 | for _, roleName := range RolesFromContext(r.Context()) { 379 | if roleName == role { 380 | next.ServeHTTP(w, r) 381 | return 382 | } 383 | } 384 | 385 | _ = Error(w, r, WrapError(ErrAuthMissingRole, http.StatusUnauthorized)) 386 | }) 387 | } 388 | } 389 | 390 | // getAuthIDFromSession returns the ID from the session cookie. Behind the scenes, 391 | // this converts the string stored in session cookies, to the ID type provided 392 | // by the caller. Only basic types are currently supported. 393 | func getAuthIDFromSession[ID comparable](r *http.Request) *ID { 394 | key, _ := gothic.GetFromSession(authSessionKey, r) 395 | if key == "" { 396 | return nil 397 | } 398 | 399 | var id ID 400 | var v any 401 | var err error 402 | 403 | switch any(&id).(type) { 404 | case *string: 405 | v = key 406 | case *int: 407 | v, err = strconv.Atoi(key) 408 | case *int64: 409 | v, err = strconv.ParseInt(key, 10, 64) 410 | case *float64: 411 | v, err = strconv.ParseFloat(key, 64) 412 | case *uint: 413 | v, err = strconv.ParseUint(key, 10, 64) 414 | case *uint16: 415 | v, err = strconv.ParseUint(key, 10, 16) 416 | case *uint32: 417 | v, err = strconv.ParseUint(key, 10, 32) 418 | case *uint64: 419 | v, err = strconv.ParseUint(key, 10, 64) 420 | default: 421 | panic("unsupported ID type") 422 | } 423 | if err != nil { 424 | return nil 425 | } 426 | 427 | id, _ = v.(ID) 428 | return &id 429 | } 430 | 431 | // RolesFromContext returns the user roles from the request context, if any. 
432 | // Note that this will only work if the [UseAuthContext] middleware has been 433 | // loaded, and the user is authenticated. 434 | func RolesFromContext(ctx context.Context) (roles AuthRoles) { 435 | roles, _ = ctx.Value(contextAuthRoles).([]string) 436 | return roles 437 | } 438 | 439 | // IDFromContext returns the user ID from the request context, if any. Note that 440 | // this will only work if the [UseAuthContext] middleware has been loaded, and the 441 | // user is authenticated. 442 | // 443 | // Returns 0 if the user is not authenticated or the ID was not found in the 444 | // context. 445 | func IDFromContext[ID comparable](ctx context.Context) (id ID) { 446 | id, _ = ctx.Value(contextAuthID).(ID) 447 | return id 448 | } 449 | 450 | // IdentFromContext returns the ident from the request context, if any. Note that 451 | // this will only work if the [UseAuthContext] middleware has been loaded, and the 452 | // user is authenticated. Provided Ident type MUST match what is used in AuthHandler. 453 | // 454 | // Returns nil if the user is not authenticated or the ident was not found in the 455 | // context. 456 | func IdentFromContext[Ident any](ctx context.Context) (auth *Ident) { 457 | auth, _ = ctx.Value(contextAuth).(*Ident) 458 | return auth 459 | } 460 | 461 | // AuthRoles provides helper methods for working with roles. 462 | type AuthRoles []string 463 | 464 | // Has returns true if the given role is present for the authenticated identity 465 | // in the context. 466 | func (r AuthRoles) Has(role string) bool { 467 | if len(r) == 0 { 468 | return false 469 | } 470 | 471 | for _, r := range r { 472 | if strings.EqualFold(r, role) { 473 | return true 474 | } 475 | } 476 | 477 | return false 478 | } 479 | 480 | // UseAuthRequired is a middleware that requires the user to be authenticated. 481 | // Note that this requires the [UseAuthContext] middleware to be loaded prior to 482 | // this middleware. 
483 | func UseAuthRequired[Ident any](next http.Handler) http.Handler { 484 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 485 | _, ok := r.Context().Value(contextAuth).(*Ident) 486 | if ok { // Already in the context. 487 | next.ServeHTTP(w, r) 488 | return 489 | } 490 | 491 | _ = Error(w, r, WrapCode(http.StatusUnauthorized)) 492 | }) 493 | } 494 | --------------------------------------------------------------------------------