├── .editorconfig ├── .gitattributes ├── .github ├── FUNDING.yml ├── ISSUE_TEMPLATE.md ├── stale.yml └── workflows │ ├── checks.yml │ └── echo.yml ├── .gitignore ├── CHANGELOG.md ├── LICENSE ├── Makefile ├── README.md ├── _fixture ├── _fixture │ └── README.md ├── certs │ ├── README.md │ ├── cert.pem │ └── key.pem ├── favicon.ico ├── folder │ └── index.html ├── images │ └── walle.png └── index.html ├── bind.go ├── bind_test.go ├── binder.go ├── binder_external_test.go ├── binder_test.go ├── codecov.yml ├── context.go ├── context_fs.go ├── context_fs_test.go ├── context_test.go ├── echo.go ├── echo_fs.go ├── echo_fs_test.go ├── echo_test.go ├── go.mod ├── go.sum ├── group.go ├── group_fs.go ├── group_fs_test.go ├── group_test.go ├── ip.go ├── ip_test.go ├── json.go ├── json_test.go ├── log.go ├── middleware ├── basic_auth.go ├── basic_auth_test.go ├── body_dump.go ├── body_dump_test.go ├── body_limit.go ├── body_limit_test.go ├── compress.go ├── compress_test.go ├── context_timeout.go ├── context_timeout_test.go ├── cors.go ├── cors_test.go ├── csrf.go ├── csrf_test.go ├── decompress.go ├── decompress_test.go ├── extractor.go ├── extractor_test.go ├── key_auth.go ├── key_auth_test.go ├── logger.go ├── logger_test.go ├── method_override.go ├── method_override_test.go ├── middleware.go ├── middleware_test.go ├── proxy.go ├── proxy_test.go ├── rate_limiter.go ├── rate_limiter_test.go ├── recover.go ├── recover_test.go ├── redirect.go ├── redirect_test.go ├── request_id.go ├── request_id_test.go ├── request_logger.go ├── request_logger_test.go ├── rewrite.go ├── rewrite_test.go ├── secure.go ├── secure_test.go ├── slash.go ├── slash_test.go ├── static.go ├── static_other.go ├── static_test.go ├── static_windows.go ├── timeout.go ├── timeout_test.go ├── util.go └── util_test.go ├── renderer.go ├── renderer_test.go ├── response.go ├── response_test.go ├── router.go └── router_test.go /.editorconfig: 
-------------------------------------------------------------------------------- 1 | # EditorConfig coding styles definitions. For more information about the 2 | # properties used in this file, please see the EditorConfig documentation: 3 | # http://editorconfig.org/ 4 | 5 | # indicate this is the root of the project 6 | root = true 7 | 8 | [*] 9 | charset = utf-8 10 | 11 | end_of_line = LF 12 | insert_final_newline = true 13 | trim_trailing_whitespace = true 14 | 15 | indent_style = space 16 | indent_size = 2 17 | 18 | [Makefile] 19 | indent_style = tab 20 | 21 | [*.md] 22 | trim_trailing_whitespace = false 23 | 24 | [*.go] 25 | indent_style = tab 26 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Automatically normalize line endings for all text-based files 2 | # http://git-scm.com/docs/gitattributes#_end_of_line_conversion 3 | * text=auto 4 | 5 | # For the following file types, normalize line endings to LF on check-in and 6 | # prevent conversion to CRLF when they are checked out (this is required in 7 | # order to prevent newline-related issues) 8 | .* text eol=lf 9 | *.go text eol=lf 10 | *.yml text eol=lf 11 | *.html text eol=lf 12 | *.css text eol=lf 13 | *.js text eol=lf 14 | *.json text eol=lf 15 | LICENSE text eol=lf 16 | 17 | # Exclude `website` and `cookbook` from GitHub's language statistics 18 | # https://github.com/github/linguist#using-gitattributes 19 | cookbook/* linguist-documentation 20 | website/* linguist-documentation 21 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: [labstack] 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 |
ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # Replace with a single Otechie username 12 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 13 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ### Issue Description 2 | 3 | ### Working code to debug 4 | 5 | ```go 6 | package main 7 | 8 | import ( 9 | "github.com/labstack/echo/v4" 10 | "net/http" 11 | "net/http/httptest" 12 | "testing" 13 | ) 14 | 15 | func TestExample(t *testing.T) { 16 | e := echo.New() 17 | 18 | e.GET("/", func(c echo.Context) error { 19 | return c.String(http.StatusOK, "Hello, World!") 20 | }) 21 | 22 | req := httptest.NewRequest(http.MethodGet, "/", nil) 23 | rec := httptest.NewRecorder() 24 | 25 | e.ServeHTTP(rec, req) 26 | 27 | if rec.Code != http.StatusOK { 28 | t.Errorf("got %d, want %d", rec.Code, http.StatusOK) 29 | } 30 | } 31 | ``` 32 | 33 | ### Version/commit 34 | -------------------------------------------------------------------------------- /.github/stale.yml: -------------------------------------------------------------------------------- 1 | # Number of days of inactivity before an issue becomes stale 2 | daysUntilStale: 60 3 | # Number of days of inactivity before a stale issue is closed 4 | daysUntilClose: 30 5 | # Issues with these labels will never be considered stale 6 | exemptLabels: 7 | - pinned 8 | - security 9 | - bug 10 | - enhancement 11 | # Label to use when marking an issue as stale 12 | staleLabel: stale 13 | # Comment to post when marking an issue as stale. 
Set to `false` to disable 14 | markComment: > 15 | This issue has been automatically marked as stale because it has not had 16 | recent activity. It will be closed within a month if no further activity occurs. 17 | Thank you for your contributions. 18 | # Comment to post when closing a stale issue. Set to `false` to disable 19 | closeComment: false 20 | -------------------------------------------------------------------------------- /.github/workflows/checks.yml: -------------------------------------------------------------------------------- 1 | name: Run checks 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | branches: 9 | - master 10 | workflow_dispatch: 11 | 12 | permissions: 13 | contents: read # to fetch code (actions/checkout) 14 | 15 | env: 16 | # run static analysis only with the latest Go version 17 | LATEST_GO_VERSION: "1.24" 18 | 19 | jobs: 20 | check: 21 | runs-on: ubuntu-latest 22 | steps: 23 | - name: Checkout Code 24 | uses: actions/checkout@v4 25 | 26 | - name: Set up Go ${{ matrix.go }} 27 | uses: actions/setup-go@v5 28 | with: 29 | go-version: ${{ env.LATEST_GO_VERSION }} 30 | check-latest: true 31 | 32 | - name: Run golint 33 | run: | 34 | go install golang.org/x/lint/golint@latest 35 | golint -set_exit_status ./... 36 | 37 | - name: Run staticcheck 38 | run: | 39 | go install honnef.co/go/tools/cmd/staticcheck@latest 40 | staticcheck ./... 41 | 42 | - name: Run govulncheck 43 | run: | 44 | go version 45 | go install golang.org/x/vuln/cmd/govulncheck@latest 46 | govulncheck ./... 
47 | 48 | 49 | -------------------------------------------------------------------------------- /.github/workflows/echo.yml: -------------------------------------------------------------------------------- 1 | name: Run Tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | branches: 9 | - master 10 | workflow_dispatch: 11 | 12 | permissions: 13 | contents: read # to fetch code (actions/checkout) 14 | 15 | env: 16 | # run coverage and benchmarks only with the latest Go version 17 | LATEST_GO_VERSION: "1.24" 18 | 19 | jobs: 20 | test: 21 | strategy: 22 | matrix: 23 | os: [ubuntu-latest, macos-latest, windows-latest] 24 | # Each major Go release is supported until there are two newer major releases. https://golang.org/doc/devel/release.html#policy 25 | # Echo tests with the last four major releases (unless there are pressing vulnerabilities) 26 | # As we depend on `golang.org/x/` libraries which only support the last 2 Go releases, we could have situations where 27 | # we deviate from the last-four-major-releases promise. 28 | go: ["1.21", "1.22", "1.23", "1.24"] 29 | name: ${{ matrix.os }} @ Go ${{ matrix.go }} 30 | runs-on: ${{ matrix.os }} 31 | steps: 32 | - name: Checkout Code 33 | uses: actions/checkout@v4 34 | 35 | - name: Set up Go ${{ matrix.go }} 36 | uses: actions/setup-go@v5 37 | with: 38 | go-version: ${{ matrix.go }} 39 | 40 | - name: Run Tests 41 | run: go test -race --coverprofile=coverage.coverprofile --covermode=atomic ./...
42 | 43 | - name: Upload coverage to Codecov 44 | if: success() && matrix.go == env.LATEST_GO_VERSION && matrix.os == 'ubuntu-latest' 45 | uses: codecov/codecov-action@v3 46 | with: 47 | token: 48 | fail_ci_if_error: false 49 | 50 | benchmark: 51 | needs: test 52 | name: Benchmark comparison 53 | runs-on: ubuntu-latest 54 | steps: 55 | - name: Checkout Code (Previous) 56 | uses: actions/checkout@v4 57 | with: 58 | ref: ${{ github.base_ref }} 59 | path: previous 60 | 61 | - name: Checkout Code (New) 62 | uses: actions/checkout@v4 63 | with: 64 | path: new 65 | 66 | - name: Set up Go ${{ matrix.go }} 67 | uses: actions/setup-go@v5 68 | with: 69 | go-version: ${{ env.LATEST_GO_VERSION }} 70 | 71 | - name: Install Dependencies 72 | run: go install golang.org/x/perf/cmd/benchstat@latest 73 | 74 | - name: Run Benchmark (Previous) 75 | run: | 76 | cd previous 77 | go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt 78 | 79 | - name: Run Benchmark (New) 80 | run: | 81 | cd new 82 | go test -run="-" -bench=".*" -count=8 ./... 
> benchmark.txt 83 | 84 | - name: Run Benchstat 85 | run: | 86 | benchstat previous/benchmark.txt new/benchmark.txt 87 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | coverage.txt 3 | _test 4 | vendor 5 | .idea 6 | *.iml 7 | *.out 8 | .vscode 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2021 LabStack 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | PKG := "github.com/labstack/echo" 2 | PKG_LIST := $(shell go list ${PKG}/...) 
3 | 4 | tag: 5 | @git tag `grep -P '^\tversion = ' echo.go|cut -f2 -d'"'` 6 | @git tag|grep -v ^v 7 | 8 | .DEFAULT_GOAL := check 9 | check: lint vet race ## Check project 10 | 11 | init: 12 | @go install golang.org/x/lint/golint@latest 13 | @go install honnef.co/go/tools/cmd/staticcheck@latest 14 | 15 | lint: ## Lint the files 16 | @staticcheck ${PKG_LIST} 17 | @golint -set_exit_status ${PKG_LIST} 18 | 19 | vet: ## Vet the files 20 | @go vet ${PKG_LIST} 21 | 22 | test: ## Run tests 23 | @go test -short ${PKG_LIST} 24 | 25 | race: ## Run tests with data race detector 26 | @go test -race ${PKG_LIST} 27 | 28 | benchmark: ## Run benchmarks 29 | @go test -run="-" -bench=".*" ${PKG_LIST} 30 | 31 | help: ## Display this help screen 32 | @grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' 33 | 34 | goversion ?= "1.21" 35 | test_version: ## Run tests inside Docker with given version (defaults to 1.21 oldest supported). 
Example: make test_version goversion=1.21 36 | @docker run --rm -it -v $(shell pwd):/project golang:$(goversion) /bin/sh -c "cd /project && make init check" 37 | -------------------------------------------------------------------------------- /_fixture/_fixture/README.md: -------------------------------------------------------------------------------- 1 | This directory is used for the static middleware test -------------------------------------------------------------------------------- /_fixture/certs/README.md: -------------------------------------------------------------------------------- 1 | To generate a valid certificate and private key use the following command: 2 | 3 | ```bash 4 | # In OpenSSL ≥ 1.1.1 5 | openssl req -x509 -newkey rsa:4096 -sha256 -days 9999 -nodes \ 6 | -keyout key.pem -out cert.pem -subj "/CN=localhost" \ 7 | -addext "subjectAltName=DNS:localhost,IP:127.0.0.1,IP:::1" 8 | ``` 9 | 10 | To check a certificate use the following command: 11 | ```bash 12 | openssl x509 -in cert.pem -text 13 | ``` 14 | -------------------------------------------------------------------------------- /_fixture/certs/cert.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIFODCCAyCgAwIBAgIUaTvDluaMf+VJgYHQ0HFTS3yuCHYwDQYJKoZIhvcNAQEL 3 | BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTIxMDIyNzIxMzQ0MVoXDTQ4MDcx 4 | NDIxMzQ0MVowFDESMBAGA1UEAwwJbG9jYWxob3N0MIICIjANBgkqhkiG9w0BAQEF 5 | AAOCAg8AMIICCgKCAgEAnqyyAAnWFH2TH7Epj5yfZxYrBvizydZe1Wo/1WpGR2IK 6 | QT+qIul5sEKX/ERqEOXsawSrL3fw9cuSM8Z2vD/57ZZdoSR7XIdVaMDEQenJ968a 7 | HObu4D27uBQwIwrM5ELgnd+fC4gis64nIu+2GSfHumZXi7lLW7DbNm8oWkMqI6tY 8 | 2s2wx2hwGYNVJrwSn4WGnkzhQ5U5mkcsLELMx7GR0Qnv6P7sNGZVeqMU7awkcSpR 9 | crKR1OUP7XCJkEq83WLHSx50+QZv7LiyDmGnujHevRbdSHlcFfHZtaufYat+qICe 10 | S3XADwRQe/0VSsmja6u3DAHy7VmL8PNisAdkopQZrhiI9OvGrpGZffs9zn+s/jeX 11 | N1bqVDihCMiEjqXMlHx2oj3AXrZTFxb7y7Ap9C07nf70lpxQWW9SjMYRF98JBiHF 12 | 
eJbQkNVkmz6T8ielQbX0l46F2SGK98oyFCGNIAZBUdj5CcS1E6w/lk4t58/em0k7 13 | 3wFC5qg0g0wfIbNSmxljBNxnaBYUqyaaAJJhpaEoOebm4RYV58hQ0FbMfpnLnSh4 14 | dYStsk6i1PumWoa7D45DTtxF3kH7TB3YOB5aWaNGAPQC1m4Qcd23YB5Rd/ABirSp 15 | ux6/cFGosjSfJ/G+G0RhNUpmcbDJvFSOhD2WCuieVhCTAzp+VPIA9bSqD+InlT0C 16 | AwEAAaOBgTB/MB0GA1UdDgQWBBQZyM//SvzYKokQZI/0MVGb6PkH+zAfBgNVHSME 17 | GDAWgBQZyM//SvzYKokQZI/0MVGb6PkH+zAPBgNVHRMBAf8EBTADAQH/MCwGA1Ud 18 | EQQlMCOCCWxvY2FsaG9zdIcEfwAAAYcQAAAAAAAAAAAAAAAAAAAAATANBgkqhkiG 19 | 9w0BAQsFAAOCAgEAKGAJQmQ/KLw8iMb5QsyxxAonVjJ1eDAhNM3GWdHpM0/GFamO 20 | vVtATLQQldwDiZJvrsCQPEc8ctZ2Utvg/StLQ3+rZpsvt0+gcUlLJK61qguwYqb2 21 | +T7VK5s7V/OyI/tsuboOW50Pka9vQHV+Z0aM06Yu+HNDAq/UTpEOb/3MQvZd6Ooy 22 | PTpZtFb/+5jIQa1dIsfFWmpBxF0+wUd9GEkX3j7nekwoZfJ8Ze4GWYERZbOFpDAQ 23 | rIHdthH5VJztnpQJmaKqzgIOF+Rurwlp5ecSC33xNNjDaYtuf/fiWnoKGhHVSBhT 24 | 61+0yxn3rTgh/Dsm95xY00rSX6lmcvI+kRNTUc8GGPz0ajBH6xyY7bNhfMjmnSW/ 25 | C/XTEDbTAhT7ndWC5vvzp7ZU0TvN+WY6A0f2kxSnnrEk6QRUvRtKkjAkmAFz8exi 26 | ttBBW0I3E5HNIC5CYRimq/9z+3clM/P1KbNblwuC65bL+PZ+nzFnn5hFaK9eLPol 27 | OwZQXv7IvAw8GfgLTrEUT7eBCQwe1IqesA7NTxF1BVwmNUb2XamvQZ7ly67QybRw 28 | 0uJq80XjpVjBWYTTQy1dsnC2OTKdqGsV9TVIDR+UGfIG9cxL70pEbiSH2AX+IDCy 29 | i3kNIvpXgBliAyOjW6Hj1fv6dNfAat/hqEfnquWkfvcs3HNrG/InwpwNAUs= 30 | -----END CERTIFICATE----- 31 | -------------------------------------------------------------------------------- /_fixture/certs/key.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN PRIVATE KEY----- 2 | MIIJQwIBADANBgkqhkiG9w0BAQEFAASCCS0wggkpAgEAAoICAQCerLIACdYUfZMf 3 | sSmPnJ9nFisG+LPJ1l7Vaj/VakZHYgpBP6oi6XmwQpf8RGoQ5exrBKsvd/D1y5Iz 4 | xna8P/ntll2hJHtch1VowMRB6cn3rxoc5u7gPbu4FDAjCszkQuCd358LiCKzrici 5 | 77YZJ8e6ZleLuUtbsNs2byhaQyojq1jazbDHaHAZg1UmvBKfhYaeTOFDlTmaRyws 6 | QszHsZHRCe/o/uw0ZlV6oxTtrCRxKlFyspHU5Q/tcImQSrzdYsdLHnT5Bm/suLIO 7 | Yae6Md69Ft1IeVwV8dm1q59hq36ogJ5LdcAPBFB7/RVKyaNrq7cMAfLtWYvw82Kw 8 | B2SilBmuGIj068aukZl9+z3Of6z+N5c3VupUOKEIyISOpcyUfHaiPcBetlMXFvvL 9 | 
sCn0LTud/vSWnFBZb1KMxhEX3wkGIcV4ltCQ1WSbPpPyJ6VBtfSXjoXZIYr3yjIU 10 | IY0gBkFR2PkJxLUTrD+WTi3nz96bSTvfAULmqDSDTB8hs1KbGWME3GdoFhSrJpoA 11 | kmGloSg55ubhFhXnyFDQVsx+mcudKHh1hK2yTqLU+6ZahrsPjkNO3EXeQftMHdg4 12 | HlpZo0YA9ALWbhBx3bdgHlF38AGKtKm7Hr9wUaiyNJ8n8b4bRGE1SmZxsMm8VI6E 13 | PZYK6J5WEJMDOn5U8gD1tKoP4ieVPQIDAQABAoICAEHF2CsH6MOpofi7GT08cR7s 14 | I33KTcxWngzc9ATk/qjMTO/rEf1Sxmx3zkR1n3nNtQhPcR5GG43nin0HwWQbKOCB 15 | OeJ4GuKp/o9jiHbCEEQpQyvD1jUBofSV+bYs3e2ogy8t6OGA1tGgWPy0XMlkoff0 16 | QEnczw3864FO5m0z9h2/Ax//r02ZTw5kUEG0KAwT709jEuVO0AfRhM/8CKKmSola 17 | EyaDtSmrWbdyLlSuzJRUNFrVBno3UTjdM0iqkks6jN3ojBhFwNNhY/1uIXafAXNk 18 | LOnD1JYMIHCb6X809VWnqvYgozIWWb5rlA3iM2mITmId1LLqMYX5fWj2R5LUzSek 19 | H+XG+F9FIouTaL1ACoXr0zyeY5N5YJdyXYa1tThdW+axX9ZrnPgeiQrmxzKPIyb7 20 | LLlVtNBQUg/t5tX80KyYjkNUu4j3oq/uBYPi0m//ovwMyi9bSbbyPT+cDXuXX5Bc 21 | oY7wyn3evXX0c1R7vdJLZLkLu+ctVex/9hvMjeW/mMasDjLnqY7pF3Skct1SX5N2 22 | U8YVU9bGvFpLEwM9lmi/T7bcv+zbmGPlfTsZiFrCsixPLn7sX7y5M4L8au8O0jh0 23 | nHm/8rWVg1Qw0Hobg3tA8FjeMa8Sr2fYmkNLVKFzhuJLxknTJLaUbX5CymNqWP4H 24 | OctvfSY0nSZ1eQpBkQaJAoIBAQDTb/NhYCfaJBLXHVMy/VYd7kWGZ+I87artcE/l 25 | 8u0pJ8XOP4kp0otFIumpHUFodysAeP6HrI79MuJB40fy91HzWZC+NrPufFFFuZ0z 26 | Ld1o3Y5nAeoZmMlf1F12Oe3OQZy7nm9eNNkfeoVtKqDv4FhAqk+aoMor86HscKsR 27 | C6HlZFdGc7kX0ylrQAXPq9KLhcvUU9oAUpbqTbhYK83IebRJgFDG45HkVo9SUHpF 28 | dmCFSb91eZpRGpdfNLCuLiSu52TebayaUCnceeAt8SyeiChJ/TwWmRRDJS0QUv6h 29 | s3Wdp+cx9ANoujA4XzAs8Fld5IZ4bcG5jjwD62/tJyWrCC5DAoIBAQDAHfHjrYCK 30 | GHBrMj+MA7cK7fCJUn/iJLSLGgo2ANYF5oq9gaCwHCtKIyB9DN/KiY0JpJ6PWg+Q 31 | 9Difq23YXiJjNEBS5EFTu9UwWAr1RhSAegrfHxm0sDbcAx31NtDYvBsADCWQYmzc 32 | KPfBshf5K4g/VCIj2VzC2CE6kNtdhqLU6AV2Pi1Tl1S82xWoAjHy91tDmlFQNWCj 33 | B2ZnZ7tY9zuwDfeBBOVCPHICgl5Q4PrY1KEWEXiNxgbtkNmOPAsY9WSqgOsP9pWK 34 | J924gdCCvovINzZtgRisxKth6Fkhra+VCsheg9SWvgR09Deo6CCoSwYxOSb0cjh2 35 | oyX5Rb1kJ7Z/AoIBAQCX2iNVoBV/GcFeNXV3fXLH9ESCj0FwuNC1zp/TanDhyerK 36 | gd8k5k2Xzcc66gP73vpHUJ6dGlVni4/r+ivGV9HHkF/f/LGlaiuEhBZel2YY1mZb 37 | 
nIhg8dZOuNqW+mvMYlsKdHNPmW0GqpwBF0iWfu1jI+4gA7Kvdj6o7RIvH8eaVEJK 38 | GvqoHcP1fvmteJ2yDtmhGMfMy4QPqtnmmS8l+CJ/V2SsMuyorXIpkBsAoFAZ6ilT 39 | WY53CT4F5nWt4v39j7pl9SatfT1TV0SmOjvtb6Rf3zu0jyR6RMzkmHa/839ZRylI 40 | OxPntzDCi7qxy7yjLmlVPJ6RgZGgzwqHrEHlX+65AoIBAQCEzu6d3x5B2N02LZli 41 | eFr8MjqbI64GLiulEY5HgNJzZ8k3cjocJI0Ehj36VIEMaYRXSzbVkIO8SCgwsPiR 42 | n5mUDNX+t441jV62Odbxcc3Qdw226rABieOSupDmKEu92GOt57e8FV5939BOVYhf 43 | FunsJYQoViXbCEAIVYVgJSfBmNfVwuvgonfQyn8xErtm4/pyRGa71PqGGSKAj2Qi 44 | /16CuVUFGtZFsLV76JW8wZqHdI4bTF6TW3cEmaLbwcRGL7W0bMSS13rO8/pBh3QW 45 | PhUxhoGYt6rQHHEBkPa04nXDyZ10QRwgTSGVnBIyMK4KyTpxorm8OI2x7dzdcomX 46 | iCCPAoIBAETwfr2JKPb/AzrKhhbZgU+sLVn3WH/nb68VheNEmGOzsqXaSHCR2NOq 47 | /ow7bawjc8yUIhBRzokR4F/7jGolOmfdq0MYFb6/YokssKfv1ugxBhmvOxpZ6F6E 48 | cERJ8Ex/ffQU053gLR/0ammddVuS1GR5I/jEdP0lJVh0xapoZNUlT5dWYCgo20hY 49 | ZAmKpU+veyUn+5Li0pmm959vnLK5LJzEA5mpz3w1QPPtVwQs05dwmEV3CRAcCeeh 50 | 8sXp49WNCSW4I3BxuTZzRV845SGIFhZwgVV42PTp2LPKl2p6E7Bk8xpUCCvBpALp 51 | QmA5yIMx+u2Jpr7fUsXEXEPTEhvjff0= 52 | -----END PRIVATE KEY----- 53 | -------------------------------------------------------------------------------- /_fixture/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/labstack/echo/98ca08e7dd64075b858e758d6693bf9799340756/_fixture/favicon.ico -------------------------------------------------------------------------------- /_fixture/folder/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Echo 6 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /_fixture/images/walle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/labstack/echo/98ca08e7dd64075b858e758d6693bf9799340756/_fixture/images/walle.png -------------------------------------------------------------------------------- /_fixture/index.html: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Echo 6 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /binder_external_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors 3 | 4 | // These examples live in the external echo_test package so they exercise only the exported API, as user code would. 5 | package echo_test 6 | 7 | import ( 8 | "encoding/base64" 9 | "fmt" 10 | "github.com/labstack/echo/v4" 11 | "log" 12 | "net/http" 13 | "net/http/httptest" 14 | ) 15 | 16 | func ExampleValueBinder_BindErrors() { 17 | // example route function that binds query params to different destinations and returns all bind errors in one go 18 | routeFunc := func(c echo.Context) error { 19 | var opts struct { 20 | Active bool 21 | IDs []int64 22 | } 23 | length := int64(50) // default length is 50 24 | 25 | b := echo.QueryParamsBinder(c) 26 | 27 | errs := b.Int64("length", &length). 28 | Int64s("ids", &opts.IDs). 29 | Bool("active", &opts.Active).
30 | BindErrors() // returns all errors 31 | if errs != nil { 32 | for _, err := range errs { 33 | bErr := err.(*echo.BindingError) 34 | log.Printf("in case you want to access what field: %s values: %v failed", bErr.Field, bErr.Values) 35 | } 36 | return fmt.Errorf("%v fields failed to bind", len(errs)) 37 | } 38 | fmt.Printf("active = %v, length = %v, ids = %v", opts.Active, length, opts.IDs) 39 | 40 | return c.JSON(http.StatusOK, opts) 41 | } 42 | 43 | e := echo.New() 44 | c := e.NewContext( 45 | httptest.NewRequest(http.MethodGet, "/api/endpoint?active=true&length=25&ids=1&ids=2&ids=3", nil), 46 | httptest.NewRecorder(), 47 | ) 48 | 49 | _ = routeFunc(c) 50 | 51 | // Output: active = true, length = 25, ids = [1 2 3] 52 | } 53 | 54 | func ExampleValueBinder_BindError() { 55 | // example route function that binds query params to different destinations and stops binding on first bind error 56 | failFastRouteFunc := func(c echo.Context) error { 57 | var opts struct { 58 | Active bool 59 | IDs []int64 60 | } 61 | length := int64(50) // default length is 50 62 | 63 | // create binder that stops binding at first error 64 | b := echo.QueryParamsBinder(c) 65 | 66 | err := b.Int64("length", &length). 67 | Int64s("ids", &opts.IDs). 68 | Bool("active", &opts.Active).
69 | BindError() // returns first binding error 70 | if err != nil { 71 | bErr := err.(*echo.BindingError) 72 | return fmt.Errorf("my own custom error for field: %s values: %v", bErr.Field, bErr.Values) 73 | } 74 | fmt.Printf("active = %v, length = %v, ids = %v\n", opts.Active, length, opts.IDs) 75 | 76 | return c.JSON(http.StatusOK, opts) 77 | } 78 | 79 | e := echo.New() 80 | c := e.NewContext( 81 | httptest.NewRequest(http.MethodGet, "/api/endpoint?active=true&length=25&ids=1&ids=2&ids=3", nil), 82 | httptest.NewRecorder(), 83 | ) 84 | 85 | _ = failFastRouteFunc(c) 86 | 87 | // Output: active = true, length = 25, ids = [1 2 3] 88 | } 89 | 90 | func ExampleValueBinder_CustomFunc() { 91 | // example route function that binds query params using custom function closure 92 | routeFunc := func(c echo.Context) error { 93 | length := int64(50) // default length is 50 94 | var binary []byte 95 | 96 | b := echo.QueryParamsBinder(c) 97 | errs := b.Int64("length", &length). 98 | CustomFunc("base64", func(values []string) []error { 99 | if len(values) == 0 { 100 | return nil 101 | } 102 | decoded, err := base64.URLEncoding.DecodeString(values[0]) 103 | if err != nil { 104 | // in this example we use only first param value but url could contain multiple params in reality and 105 | // therefore in theory produce multiple binding errors 106 | return []error{echo.NewBindingError("base64", values[0:1], "failed to decode base64", err)} 107 | } 108 | binary = decoded 109 | return nil 110 | }).
111 | BindErrors() // returns all errors 112 | 113 | if errs != nil { 114 | for _, err := range errs { 115 | bErr := err.(*echo.BindingError) 116 | log.Printf("in case you want to access what field: %s values: %v failed", bErr.Field, bErr.Values) 117 | } 118 | return fmt.Errorf("%v fields failed to bind", len(errs)) 119 | } 120 | fmt.Printf("length = %v, base64 = %s", length, binary) 121 | 122 | return c.JSON(http.StatusOK, "ok") 123 | } 124 | 125 | e := echo.New() 126 | c := e.NewContext( 127 | httptest.NewRequest(http.MethodGet, "/api/endpoint?length=25&base64=SGVsbG8gV29ybGQ%3D", nil), 128 | httptest.NewRecorder(), 129 | ) 130 | _ = routeFunc(c) 131 | 132 | // Output: length = 25, base64 = Hello World 133 | } 134 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | project: 4 | default: 5 | threshold: 1% 6 | patch: 7 | default: 8 | threshold: 1% 9 | 10 | comment: 11 | require_changes: true -------------------------------------------------------------------------------- /context_fs.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors 3 | 4 | package echo 5 | 6 | import ( 7 | "errors" 8 | "io" 9 | "io/fs" 10 | "net/http" 11 | "path/filepath" 12 | ) 13 | // File serves the named file from the Echo instance's Filesystem (c.echo.Filesystem). 14 | func (c *context) File(file string) error { 15 | return fsFile(c, file, c.echo.Filesystem) 16 | } 17 | 18 | // FileFS serves file from given file system. 19 | // 20 | // When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory")` to create sub fs which uses necessary 21 | // prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths 22 | // including `assets/images` as their prefix.
23 | func (c *context) FileFS(file string, filesystem fs.FS) error { 24 | return fsFile(c, file, filesystem) 25 | } 26 | 27 | // fsFile serves file from filesystem using http.ServeContent. When file is a directory, its index page 28 | // (indexPage) is served instead. ErrNotFound is returned when the file or the directory index cannot be 29 | // opened; errors from Stat are returned unchanged. 30 | func fsFile(c Context, file string, filesystem fs.FS) error { 31 | f, err := filesystem.Open(file) 32 | if err != nil { 33 | return ErrNotFound 34 | } 35 | defer f.Close() 36 | 37 | fi, err := f.Stat() 38 | if err != nil { 39 | return err // fi is not usable when Stat fails (it may be nil), so it must not be dereferenced below 40 | } 41 | if fi.IsDir() { 42 | file = filepath.ToSlash(filepath.Join(file, indexPage)) // ToSlash is necessary for Windows. fs.Open and os.Open are different in that aspect. 43 | f, err = filesystem.Open(file) 44 | if err != nil { 45 | return ErrNotFound 46 | } 47 | defer f.Close() 48 | if fi, err = f.Stat(); err != nil { 49 | return err 50 | } 51 | } 52 | ff, ok := f.(io.ReadSeeker) 53 | if !ok { 54 | return errors.New("file does not implement io.ReadSeeker") 55 | } 56 | http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), ff) 57 | return nil 58 | } 59 | -------------------------------------------------------------------------------- /context_fs_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors 3 | 4 | package echo 5 | 6 | import ( 7 | "github.com/stretchr/testify/assert" 8 | "io/fs" 9 | "net/http" 10 | "net/http/httptest" 11 | "os" 12 | "testing" 13 | ) 14 | 15 | func TestContext_File(t *testing.T) { 16 | var testCases = []struct { 17 | name string 18 | whenFile string 19 | whenFS fs.FS 20 | expectStatus int 21 | expectStartsWith []byte 22 | expectError string 23 | }{ 24 | { 25 | name: "ok, from default file system", 26 | whenFile: "_fixture/images/walle.png", 27 | whenFS: nil, 28 | expectStatus: http.StatusOK, 29 | expectStartsWith: []byte{0x89, 0x50, 0x4e}, 30 | }, 31 | { 32 | name: "ok, from custom file system", 33 | whenFile: "walle.png", 34 | whenFS: os.DirFS("_fixture/images"), 35 | expectStatus: http.StatusOK, 36 | expectStartsWith: []byte{0x89, 0x50, 0x4e}, 37 | }, 38 | { 39 |
name: "nok, not existent file", 40 | whenFile: "not.png", 41 | whenFS: os.DirFS("_fixture/images"), 42 | expectStatus: http.StatusOK, 43 | expectStartsWith: nil, 44 | expectError: "code=404, message=Not Found", 45 | }, 46 | } 47 | 48 | for _, tc := range testCases { 49 | t.Run(tc.name, func(t *testing.T) { 50 | e := New() 51 | if tc.whenFS != nil { 52 | e.Filesystem = tc.whenFS 53 | } 54 | 55 | handler := func(ec Context) error { 56 | return ec.(*context).File(tc.whenFile) 57 | } 58 | 59 | req := httptest.NewRequest(http.MethodGet, "/match.png", nil) 60 | rec := httptest.NewRecorder() 61 | c := e.NewContext(req, rec) 62 | 63 | err := handler(c) 64 | 65 | assert.Equal(t, tc.expectStatus, rec.Code) 66 | if tc.expectError != "" { 67 | assert.EqualError(t, err, tc.expectError) 68 | } else { 69 | assert.NoError(t, err) 70 | } 71 | 72 | body := rec.Body.Bytes() 73 | if len(body) > len(tc.expectStartsWith) { 74 | body = body[:len(tc.expectStartsWith)] 75 | } 76 | assert.Equal(t, tc.expectStartsWith, body) 77 | }) 78 | } 79 | } 80 | 81 | func TestContext_FileFS(t *testing.T) { 82 | var testCases = []struct { 83 | name string 84 | whenFile string 85 | whenFS fs.FS 86 | expectStatus int 87 | expectStartsWith []byte 88 | expectError string 89 | }{ 90 | { 91 | name: "ok", 92 | whenFile: "walle.png", 93 | whenFS: os.DirFS("_fixture/images"), 94 | expectStatus: http.StatusOK, 95 | expectStartsWith: []byte{0x89, 0x50, 0x4e}, 96 | }, 97 | { 98 | name: "nok, not existent file", 99 | whenFile: "not.png", 100 | whenFS: os.DirFS("_fixture/images"), 101 | expectStatus: http.StatusOK, 102 | expectStartsWith: nil, 103 | expectError: "code=404, message=Not Found", 104 | }, 105 | } 106 | 107 | for _, tc := range testCases { 108 | t.Run(tc.name, func(t *testing.T) { 109 | e := New() 110 | 111 | handler := func(ec Context) error { 112 | return ec.(*context).FileFS(tc.whenFile, tc.whenFS) 113 | } 114 | 115 | req := httptest.NewRequest(http.MethodGet, "/match.png", nil) 116 | rec := 
httptest.NewRecorder() 117 | c := e.NewContext(req, rec) 118 | 119 | err := handler(c) 120 | 121 | assert.Equal(t, tc.expectStatus, rec.Code) 122 | if tc.expectError != "" { 123 | assert.EqualError(t, err, tc.expectError) 124 | } else { 125 | assert.NoError(t, err) 126 | } 127 | 128 | body := rec.Body.Bytes() 129 | if len(body) > len(tc.expectStartsWith) { 130 | body = body[:len(tc.expectStartsWith)] 131 | } 132 | assert.Equal(t, tc.expectStartsWith, body) 133 | }) 134 | } 135 | } 136 | -------------------------------------------------------------------------------- /echo_fs.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors 3 | 4 | package echo 5 | 6 | import ( 7 | "fmt" 8 | "io/fs" 9 | "net/http" 10 | "net/url" 11 | "os" 12 | "path/filepath" 13 | "strings" 14 | ) 15 | 16 | type filesystem struct { 17 | // Filesystem is file system used by Static and File handlers to access files. 18 | // Defaults to newDefaultFS(), which opens names with os.Open (pre v4.7.0 behaviour), not os.DirFS("."). 19 | // 20 | // When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory")` to create sub fs which uses necessary 21 | // prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths 22 | // including `assets/images` as their prefix. 23 | Filesystem fs.FS 24 | } 25 | 26 | func createFilesystem() filesystem { 27 | return filesystem{ 28 | Filesystem: newDefaultFS(), 29 | } 30 | } 31 | 32 | // Static registers a new route with path prefix to serve static files from the provided root directory. 33 | func (e *Echo) Static(pathPrefix, fsRoot string) *Route { 34 | subFs := MustSubFS(e.Filesystem, fsRoot) 35 | return e.Add( 36 | http.MethodGet, 37 | pathPrefix+"*", 38 | StaticDirectoryHandler(subFs, false), 39 | ) 40 | } 41 | 42 | // StaticFS registers a new route with path prefix to serve static files from the provided file system.
43 | // 44 | // When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary 45 | // prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths 46 | // including `assets/images` as their prefix. 47 | func (e *Echo) StaticFS(pathPrefix string, filesystem fs.FS) *Route { 48 | return e.Add( 49 | http.MethodGet, 50 | pathPrefix+"*", 51 | StaticDirectoryHandler(filesystem, false), 52 | ) 53 | } 54 | 55 | // StaticDirectoryHandler creates handler function to serve files from provided file system 56 | // When disablePathUnescaping is set then file name from path is not unescaped and is served as is. 57 | func StaticDirectoryHandler(fileSystem fs.FS, disablePathUnescaping bool) HandlerFunc { 58 | return func(c Context) error { 59 | p := c.Param("*") 60 | if !disablePathUnescaping { // when router is already unescaping we do not want to do is twice 61 | tmpPath, err := url.PathUnescape(p) 62 | if err != nil { 63 | return fmt.Errorf("failed to unescape path variable: %w", err) 64 | } 65 | p = tmpPath 66 | } 67 | 68 | // fs.FS.Open() already assumes that file names are relative to FS root path and considers name with prefix `/` as invalid 69 | name := filepath.ToSlash(filepath.Clean(strings.TrimPrefix(p, "/"))) 70 | fi, err := fs.Stat(fileSystem, name) 71 | if err != nil { 72 | return ErrNotFound 73 | } 74 | 75 | // If the request is for a directory and does not end with "/" 76 | p = c.Request().URL.Path // path must not be empty. 77 | if fi.IsDir() && len(p) > 0 && p[len(p)-1] != '/' { 78 | // Redirect to ends with "/" 79 | return c.Redirect(http.StatusMovedPermanently, sanitizeURI(p+"/")) 80 | } 81 | return fsFile(c, name, fileSystem) 82 | } 83 | } 84 | 85 | // FileFS registers a new route with path to serve file from the provided file system. 
86 | func (e *Echo) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) *Route { 87 | return e.GET(path, StaticFileHandler(file, filesystem), m...) 88 | } 89 | 90 | // StaticFileHandler creates handler function to serve file from provided file system 91 | func StaticFileHandler(file string, filesystem fs.FS) HandlerFunc { 92 | return func(c Context) error { 93 | return fsFile(c, file, filesystem) 94 | } 95 | } 96 | 97 | // defaultFS exists to preserve pre v4.7.0 behaviour where files were open by `os.Open`. 98 | // v4.7 introduced `echo.Filesystem` field which is Go1.16+ `fs.Fs` interface. 99 | // Difference between `os.Open` and `fs.Open` is that FS does not allow opening path that start with `.`, `..` or `/` 100 | // etc. For example previously you could have `../images` in your application but `fs := os.DirFS("./")` would not 101 | // allow you to use `fs.Open("../images")` and this would break all old applications that rely on being able to 102 | // traverse up from current executable run path. 103 | // NB: private because you really should use fs.FS implementation instances 104 | type defaultFS struct { 105 | fs fs.FS 106 | prefix string 107 | } 108 | 109 | func newDefaultFS() *defaultFS { 110 | dir, _ := os.Getwd() 111 | return &defaultFS{ 112 | prefix: dir, 113 | fs: nil, 114 | } 115 | } 116 | 117 | func (fs defaultFS) Open(name string) (fs.File, error) { 118 | if fs.fs == nil { 119 | return os.Open(name) 120 | } 121 | return fs.fs.Open(name) 122 | } 123 | 124 | func subFS(currentFs fs.FS, root string) (fs.FS, error) { 125 | root = filepath.ToSlash(filepath.Clean(root)) // note: fs.FS operates only with slashes. `ToSlash` is necessary for Windows 126 | if dFS, ok := currentFs.(*defaultFS); ok { 127 | // we need to make exception for `defaultFS` instances as it interprets root prefix differently from fs.FS. 
128 | // fs.Fs.Open does not like relative paths ("./", "../") and absolute paths at all but prior echo.Filesystem we 129 | // were able to use paths like `./myfile.log`, `/etc/hosts` and these would work fine with `os.Open` but not with fs.Fs 130 | if !filepath.IsAbs(root) { 131 | root = filepath.Join(dFS.prefix, root) 132 | } 133 | return &defaultFS{ 134 | prefix: root, 135 | fs: os.DirFS(root), 136 | }, nil 137 | } 138 | return fs.Sub(currentFs, root) 139 | } 140 | 141 | // MustSubFS creates sub FS from current filesystem or panic on failure. 142 | // Panic happens when `fsRoot` contains invalid path according to `fs.ValidPath` rules. 143 | // 144 | // MustSubFS is helpful when dealing with `embed.FS` because for example `//go:embed assets/images` embeds files with 145 | // paths including `assets/images` as their prefix. In that case use `fs := echo.MustSubFS(fs, "rootDirectory") to 146 | // create sub fs which uses necessary prefix for directory path. 147 | func MustSubFS(currentFs fs.FS, fsRoot string) fs.FS { 148 | subFs, err := subFS(currentFs, fsRoot) 149 | if err != nil { 150 | panic(fmt.Errorf("can not create sub FS, invalid root given, err: %w", err)) 151 | } 152 | return subFs 153 | } 154 | 155 | func sanitizeURI(uri string) string { 156 | // double slash `\\`, `//` or even `\/` is absolute uri for browsers and by redirecting request to that uri 157 | // we are vulnerable to open redirect attack. 
so replace all slashes from the beginning with single slash 158 | if len(uri) > 1 && (uri[0] == '\\' || uri[0] == '/') && (uri[1] == '\\' || uri[1] == '/') { 159 | uri = "/" + strings.TrimLeft(uri, `/\`) 160 | } 161 | return uri 162 | } 163 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/labstack/echo/v4 2 | 3 | go 1.23.0 4 | 5 | require ( 6 | github.com/labstack/gommon v0.4.2 7 | github.com/stretchr/testify v1.10.0 8 | github.com/valyala/fasttemplate v1.2.2 9 | golang.org/x/crypto v0.38.0 10 | golang.org/x/net v0.40.0 11 | golang.org/x/time v0.11.0 12 | ) 13 | 14 | require ( 15 | github.com/davecgh/go-spew v1.1.1 // indirect 16 | github.com/mattn/go-colorable v0.1.14 // indirect 17 | github.com/mattn/go-isatty v0.0.20 // indirect 18 | github.com/pmezard/go-difflib v1.0.0 // indirect 19 | github.com/valyala/bytebufferpool v1.0.0 // indirect 20 | golang.org/x/sys v0.33.0 // indirect 21 | golang.org/x/text v0.25.0 // indirect 22 | gopkg.in/yaml.v3 v3.0.1 // indirect 23 | ) 24 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 2 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/labstack/gommon v0.4.2 h1:F8qTUNXgG1+6WQmqoUWnz8WiEU60mXVVw0P4ht1WRA0= 4 | github.com/labstack/gommon v0.4.2/go.mod h1:QlUFxVM+SNXhDL/Z7YhocGIBYOiwB0mXm1+1bAPHPyU= 5 | github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= 6 | github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= 7 | github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= 8 | github.com/mattn/go-isatty v0.0.20/go.mod 
h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= 9 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 10 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 11 | github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= 12 | github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 13 | github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= 14 | github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= 15 | github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo= 16 | github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= 17 | golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8= 18 | golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw= 19 | golang.org/x/net v0.40.0 h1:79Xs7wF06Gbdcg4kdCCIQArK11Z1hr5POQ6+fIYHNuY= 20 | golang.org/x/net v0.40.0/go.mod h1:y0hY0exeL2Pku80/zKK7tpntoX23cqL3Oa6njdgRtds= 21 | golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 22 | golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= 23 | golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= 24 | golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4= 25 | golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA= 26 | golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0= 27 | golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= 28 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 29 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 30 | gopkg.in/yaml.v3 v3.0.1 
h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 31 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 32 | -------------------------------------------------------------------------------- /group.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors 3 | 4 | package echo 5 | 6 | import ( 7 | "net/http" 8 | ) 9 | 10 | // Group is a set of sub-routes for a specified route. It can be used for inner 11 | // routes that share a common middleware or functionality that should be separate 12 | // from the parent echo instance while still inheriting from it. 13 | type Group struct { 14 | common 15 | host string 16 | prefix string 17 | echo *Echo 18 | middleware []MiddlewareFunc 19 | } 20 | 21 | // Use implements `Echo#Use()` for sub-routes within the Group. 22 | func (g *Group) Use(middleware ...MiddlewareFunc) { 23 | g.middleware = append(g.middleware, middleware...) 24 | if len(g.middleware) == 0 { 25 | return 26 | } 27 | // group level middlewares are different from Echo `Pre` and `Use` middlewares (those are global). Group level middlewares 28 | // are only executed if they are added to the Router with route. 29 | // So we register catch all route (404 is a safe way to emulate route match) for this group and now during routing the 30 | // Router would find route to match our request path and therefore guarantee the middleware(s) will get executed. 31 | g.RouteNotFound("", NotFoundHandler) 32 | g.RouteNotFound("/*", NotFoundHandler) 33 | } 34 | 35 | // CONNECT implements `Echo#CONNECT()` for sub-routes within the Group. 36 | func (g *Group) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { 37 | return g.Add(http.MethodConnect, path, h, m...) 38 | } 39 | 40 | // DELETE implements `Echo#DELETE()` for sub-routes within the Group. 
41 | func (g *Group) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { 42 | return g.Add(http.MethodDelete, path, h, m...) 43 | } 44 | 45 | // GET implements `Echo#GET()` for sub-routes within the Group. 46 | func (g *Group) GET(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { 47 | return g.Add(http.MethodGet, path, h, m...) 48 | } 49 | 50 | // HEAD implements `Echo#HEAD()` for sub-routes within the Group. 51 | func (g *Group) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { 52 | return g.Add(http.MethodHead, path, h, m...) 53 | } 54 | 55 | // OPTIONS implements `Echo#OPTIONS()` for sub-routes within the Group. 56 | func (g *Group) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { 57 | return g.Add(http.MethodOptions, path, h, m...) 58 | } 59 | 60 | // PATCH implements `Echo#PATCH()` for sub-routes within the Group. 61 | func (g *Group) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { 62 | return g.Add(http.MethodPatch, path, h, m...) 63 | } 64 | 65 | // POST implements `Echo#POST()` for sub-routes within the Group. 66 | func (g *Group) POST(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { 67 | return g.Add(http.MethodPost, path, h, m...) 68 | } 69 | 70 | // PUT implements `Echo#PUT()` for sub-routes within the Group. 71 | func (g *Group) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { 72 | return g.Add(http.MethodPut, path, h, m...) 73 | } 74 | 75 | // TRACE implements `Echo#TRACE()` for sub-routes within the Group. 76 | func (g *Group) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { 77 | return g.Add(http.MethodTrace, path, h, m...) 78 | } 79 | 80 | // Any implements `Echo#Any()` for sub-routes within the Group. 81 | func (g *Group) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route { 82 | routes := make([]*Route, len(methods)) 83 | for i, m := range methods { 84 | routes[i] = g.Add(m, path, handler, middleware...) 
85 | } 86 | return routes 87 | } 88 | 89 | // Match implements `Echo#Match()` for sub-routes within the Group. 90 | func (g *Group) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route { 91 | routes := make([]*Route, len(methods)) 92 | for i, m := range methods { 93 | routes[i] = g.Add(m, path, handler, middleware...) 94 | } 95 | return routes 96 | } 97 | 98 | // Group creates a new sub-group with prefix and optional sub-group-level middleware. 99 | func (g *Group) Group(prefix string, middleware ...MiddlewareFunc) (sg *Group) { 100 | m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware)) 101 | m = append(m, g.middleware...) 102 | m = append(m, middleware...) 103 | sg = g.echo.Group(g.prefix+prefix, m...) 104 | sg.host = g.host 105 | return 106 | } 107 | 108 | // File implements `Echo#File()` for sub-routes within the Group. 109 | func (g *Group) File(path, file string) { 110 | g.file(path, file, g.GET) 111 | } 112 | 113 | // RouteNotFound implements `Echo#RouteNotFound()` for sub-routes within the Group. 114 | // 115 | // Example: `g.RouteNotFound("/*", func(c echo.Context) error { return c.NoContent(http.StatusNotFound) })` 116 | func (g *Group) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { 117 | return g.Add(RouteNotFound, path, h, m...) 118 | } 119 | 120 | // Add implements `Echo#Add()` for sub-routes within the Group. 121 | func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route { 122 | // Combine into a new slice to avoid accidentally passing the same slice for 123 | // multiple routes, which would lead to later add() calls overwriting the 124 | // middleware from earlier calls. 125 | m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware)) 126 | m = append(m, g.middleware...) 127 | m = append(m, middleware...) 128 | return g.echo.add(g.host, method, g.prefix+path, handler, m...) 
129 | } 130 | -------------------------------------------------------------------------------- /group_fs.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors 3 | 4 | package echo 5 | 6 | import ( 7 | "io/fs" 8 | "net/http" 9 | ) 10 | 11 | // Static implements `Echo#Static()` for sub-routes within the Group. 12 | func (g *Group) Static(pathPrefix, fsRoot string) { 13 | subFs := MustSubFS(g.echo.Filesystem, fsRoot) 14 | g.StaticFS(pathPrefix, subFs) 15 | } 16 | 17 | // StaticFS implements `Echo#StaticFS()` for sub-routes within the Group. 18 | // 19 | // When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary 20 | // prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths 21 | // including `assets/images` as their prefix. 22 | func (g *Group) StaticFS(pathPrefix string, filesystem fs.FS) { 23 | g.Add( 24 | http.MethodGet, 25 | pathPrefix+"*", 26 | StaticDirectoryHandler(filesystem, false), 27 | ) 28 | } 29 | 30 | // FileFS implements `Echo#FileFS()` for sub-routes within the Group. 31 | func (g *Group) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) *Route { 32 | return g.GET(path, StaticFileHandler(file, filesystem), m...) 
33 | } 34 | -------------------------------------------------------------------------------- /group_fs_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors 3 | 4 | package echo 5 | 6 | import ( 7 | "github.com/stretchr/testify/assert" 8 | "io/fs" 9 | "net/http" 10 | "net/http/httptest" 11 | "os" 12 | "testing" 13 | ) 14 | 15 | func TestGroup_FileFS(t *testing.T) { 16 | var testCases = []struct { 17 | name string 18 | whenPath string 19 | whenFile string 20 | whenFS fs.FS 21 | givenURL string 22 | expectCode int 23 | expectStartsWith []byte 24 | }{ 25 | { 26 | name: "ok", 27 | whenPath: "/walle", 28 | whenFS: os.DirFS("_fixture/images"), 29 | whenFile: "walle.png", 30 | givenURL: "/assets/walle", 31 | expectCode: http.StatusOK, 32 | expectStartsWith: []byte{0x89, 0x50, 0x4e}, 33 | }, 34 | { 35 | name: "nok, requesting invalid path", 36 | whenPath: "/walle", 37 | whenFS: os.DirFS("_fixture/images"), 38 | whenFile: "walle.png", 39 | givenURL: "/assets/walle.png", 40 | expectCode: http.StatusNotFound, 41 | expectStartsWith: []byte(`{"message":"Not Found"}`), 42 | }, 43 | { 44 | name: "nok, serving not existent file from filesystem", 45 | whenPath: "/walle", 46 | whenFS: os.DirFS("_fixture/images"), 47 | whenFile: "not-existent.png", 48 | givenURL: "/assets/walle", 49 | expectCode: http.StatusNotFound, 50 | expectStartsWith: []byte(`{"message":"Not Found"}`), 51 | }, 52 | } 53 | 54 | for _, tc := range testCases { 55 | t.Run(tc.name, func(t *testing.T) { 56 | e := New() 57 | g := e.Group("/assets") 58 | g.FileFS(tc.whenPath, tc.whenFile, tc.whenFS) 59 | 60 | req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) 61 | rec := httptest.NewRecorder() 62 | 63 | e.ServeHTTP(rec, req) 64 | 65 | assert.Equal(t, tc.expectCode, rec.Code) 66 | 67 | body := rec.Body.Bytes() 68 | if len(body) > len(tc.expectStartsWith) { 69 | body = 
body[:len(tc.expectStartsWith)] 70 | } 71 | assert.Equal(t, tc.expectStartsWith, body) 72 | }) 73 | } 74 | } 75 | 76 | func TestGroup_StaticPanic(t *testing.T) { 77 | var testCases = []struct { 78 | name string 79 | givenRoot string 80 | }{ 81 | { 82 | name: "panics for ../", 83 | givenRoot: "../images", 84 | }, 85 | { 86 | name: "panics for /", 87 | givenRoot: "/images", 88 | }, 89 | } 90 | 91 | for _, tc := range testCases { 92 | t.Run(tc.name, func(t *testing.T) { 93 | e := New() 94 | e.Filesystem = os.DirFS("./") 95 | 96 | g := e.Group("/assets") 97 | 98 | assert.Panics(t, func() { 99 | g.Static("/images", tc.givenRoot) 100 | }) 101 | }) 102 | } 103 | } 104 | -------------------------------------------------------------------------------- /group_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors 3 | 4 | package echo 5 | 6 | import ( 7 | "net/http" 8 | "net/http/httptest" 9 | "os" 10 | "testing" 11 | 12 | "github.com/stretchr/testify/assert" 13 | ) 14 | 15 | // TODO: Fix me 16 | func TestGroup(t *testing.T) { 17 | g := New().Group("/group") 18 | h := func(Context) error { return nil } 19 | g.CONNECT("/", h) 20 | g.DELETE("/", h) 21 | g.GET("/", h) 22 | g.HEAD("/", h) 23 | g.OPTIONS("/", h) 24 | g.PATCH("/", h) 25 | g.POST("/", h) 26 | g.PUT("/", h) 27 | g.TRACE("/", h) 28 | g.Any("/", h) 29 | g.Match([]string{http.MethodGet, http.MethodPost}, "/", h) 30 | g.Static("/static", "/tmp") 31 | g.File("/walle", "_fixture/images//walle.png") 32 | } 33 | 34 | func TestGroupFile(t *testing.T) { 35 | e := New() 36 | g := e.Group("/group") 37 | g.File("/walle", "_fixture/images/walle.png") 38 | expectedData, err := os.ReadFile("_fixture/images/walle.png") 39 | assert.Nil(t, err) 40 | req := httptest.NewRequest(http.MethodGet, "/group/walle", nil) 41 | rec := httptest.NewRecorder() 42 | e.ServeHTTP(rec, req) 43 | assert.Equal(t, 
http.StatusOK, rec.Code) 44 | assert.Equal(t, expectedData, rec.Body.Bytes()) 45 | } 46 | 47 | func TestGroupRouteMiddleware(t *testing.T) { 48 | // Ensure middleware slices are not re-used 49 | e := New() 50 | g := e.Group("/group") 51 | h := func(Context) error { return nil } 52 | m1 := func(next HandlerFunc) HandlerFunc { 53 | return func(c Context) error { 54 | return next(c) 55 | } 56 | } 57 | m2 := func(next HandlerFunc) HandlerFunc { 58 | return func(c Context) error { 59 | return next(c) 60 | } 61 | } 62 | m3 := func(next HandlerFunc) HandlerFunc { 63 | return func(c Context) error { 64 | return next(c) 65 | } 66 | } 67 | m4 := func(next HandlerFunc) HandlerFunc { 68 | return func(c Context) error { 69 | return c.NoContent(404) 70 | } 71 | } 72 | m5 := func(next HandlerFunc) HandlerFunc { 73 | return func(c Context) error { 74 | return c.NoContent(405) 75 | } 76 | } 77 | g.Use(m1, m2, m3) 78 | g.GET("/404", h, m4) 79 | g.GET("/405", h, m5) 80 | 81 | c, _ := request(http.MethodGet, "/group/404", e) 82 | assert.Equal(t, 404, c) 83 | c, _ = request(http.MethodGet, "/group/405", e) 84 | assert.Equal(t, 405, c) 85 | } 86 | 87 | func TestGroupRouteMiddlewareWithMatchAny(t *testing.T) { 88 | // Ensure middleware and match any routes do not conflict 89 | e := New() 90 | g := e.Group("/group") 91 | m1 := func(next HandlerFunc) HandlerFunc { 92 | return func(c Context) error { 93 | return next(c) 94 | } 95 | } 96 | m2 := func(next HandlerFunc) HandlerFunc { 97 | return func(c Context) error { 98 | return c.String(http.StatusOK, c.Path()) 99 | } 100 | } 101 | h := func(c Context) error { 102 | return c.String(http.StatusOK, c.Path()) 103 | } 104 | g.Use(m1) 105 | g.GET("/help", h, m2) 106 | g.GET("/*", h, m2) 107 | g.GET("", h, m2) 108 | e.GET("unrelated", h, m2) 109 | e.GET("*", h, m2) 110 | 111 | _, m := request(http.MethodGet, "/group/help", e) 112 | assert.Equal(t, "/group/help", m) 113 | _, m = request(http.MethodGet, "/group/help/other", e) 114 | assert.Equal(t, 
"/group/*", m) 115 | _, m = request(http.MethodGet, "/group/404", e) 116 | assert.Equal(t, "/group/*", m) 117 | _, m = request(http.MethodGet, "/group", e) 118 | assert.Equal(t, "/group", m) 119 | _, m = request(http.MethodGet, "/other", e) 120 | assert.Equal(t, "/*", m) 121 | _, m = request(http.MethodGet, "/", e) 122 | assert.Equal(t, "/*", m) 123 | 124 | } 125 | 126 | func TestGroup_RouteNotFound(t *testing.T) { 127 | var testCases = []struct { 128 | name string 129 | whenURL string 130 | expectRoute interface{} 131 | expectCode int 132 | }{ 133 | { 134 | name: "404, route to static not found handler /group/a/c/xx", 135 | whenURL: "/group/a/c/xx", 136 | expectRoute: "GET /group/a/c/xx", 137 | expectCode: http.StatusNotFound, 138 | }, 139 | { 140 | name: "404, route to path param not found handler /group/a/:file", 141 | whenURL: "/group/a/echo.exe", 142 | expectRoute: "GET /group/a/:file", 143 | expectCode: http.StatusNotFound, 144 | }, 145 | { 146 | name: "404, route to any not found handler /group/*", 147 | whenURL: "/group/b/echo.exe", 148 | expectRoute: "GET /group/*", 149 | expectCode: http.StatusNotFound, 150 | }, 151 | { 152 | name: "200, route /group/a/c/df to /group/a/c/df", 153 | whenURL: "/group/a/c/df", 154 | expectRoute: "GET /group/a/c/df", 155 | expectCode: http.StatusOK, 156 | }, 157 | } 158 | 159 | for _, tc := range testCases { 160 | t.Run(tc.name, func(t *testing.T) { 161 | e := New() 162 | g := e.Group("/group") 163 | 164 | okHandler := func(c Context) error { 165 | return c.String(http.StatusOK, c.Request().Method+" "+c.Path()) 166 | } 167 | notFoundHandler := func(c Context) error { 168 | return c.String(http.StatusNotFound, c.Request().Method+" "+c.Path()) 169 | } 170 | 171 | g.GET("/", okHandler) 172 | g.GET("/a/c/df", okHandler) 173 | g.GET("/a/b*", okHandler) 174 | g.PUT("/*", okHandler) 175 | 176 | g.RouteNotFound("/a/c/xx", notFoundHandler) // static 177 | g.RouteNotFound("/a/:file", notFoundHandler) // param 178 | 
g.RouteNotFound("/*", notFoundHandler) // any 179 | 180 | req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) 181 | rec := httptest.NewRecorder() 182 | 183 | e.ServeHTTP(rec, req) 184 | 185 | assert.Equal(t, tc.expectCode, rec.Code) 186 | assert.Equal(t, tc.expectRoute, rec.Body.String()) 187 | }) 188 | } 189 | } 190 | 191 | func TestGroup_RouteNotFoundWithMiddleware(t *testing.T) { 192 | var testCases = []struct { 193 | name string 194 | givenCustom404 bool 195 | whenURL string 196 | expectBody interface{} 197 | expectCode int 198 | }{ 199 | { 200 | name: "ok, custom 404 handler is called with middleware", 201 | givenCustom404: true, 202 | whenURL: "/group/test3", 203 | expectBody: "GET /group/*", 204 | expectCode: http.StatusNotFound, 205 | }, 206 | { 207 | name: "ok, default group 404 handler is called with middleware", 208 | givenCustom404: false, 209 | whenURL: "/group/test3", 210 | expectBody: "{\"message\":\"Not Found\"}\n", 211 | expectCode: http.StatusNotFound, 212 | }, 213 | { 214 | name: "ok, (no slash) default group 404 handler is called with middleware", 215 | givenCustom404: false, 216 | whenURL: "/group", 217 | expectBody: "{\"message\":\"Not Found\"}\n", 218 | expectCode: http.StatusNotFound, 219 | }, 220 | } 221 | for _, tc := range testCases { 222 | t.Run(tc.name, func(t *testing.T) { 223 | 224 | okHandler := func(c Context) error { 225 | return c.String(http.StatusOK, c.Request().Method+" "+c.Path()) 226 | } 227 | notFoundHandler := func(c Context) error { 228 | return c.String(http.StatusNotFound, c.Request().Method+" "+c.Path()) 229 | } 230 | 231 | e := New() 232 | e.GET("/test1", okHandler) 233 | e.RouteNotFound("/*", notFoundHandler) 234 | 235 | g := e.Group("/group") 236 | g.GET("/test1", okHandler) 237 | 238 | middlewareCalled := false 239 | g.Use(func(next HandlerFunc) HandlerFunc { 240 | return func(c Context) error { 241 | middlewareCalled = true 242 | return next(c) 243 | } 244 | }) 245 | if tc.givenCustom404 { 246 | 
g.RouteNotFound("/*", notFoundHandler) 247 | } 248 | 249 | req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) 250 | rec := httptest.NewRecorder() 251 | 252 | e.ServeHTTP(rec, req) 253 | 254 | assert.True(t, middlewareCalled) 255 | assert.Equal(t, tc.expectCode, rec.Code) 256 | assert.Equal(t, tc.expectBody, rec.Body.String()) 257 | }) 258 | } 259 | } 260 | -------------------------------------------------------------------------------- /json.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors 3 | 4 | package echo 5 | 6 | import ( 7 | "encoding/json" 8 | "fmt" 9 | "net/http" 10 | ) 11 | 12 | // DefaultJSONSerializer implements JSON encoding using encoding/json. 13 | type DefaultJSONSerializer struct{} 14 | 15 | // Serialize converts an interface into a json and writes it to the response. 16 | // You can optionally use the indent parameter to produce pretty JSONs. 17 | func (d DefaultJSONSerializer) Serialize(c Context, i interface{}, indent string) error { 18 | enc := json.NewEncoder(c.Response()) 19 | if indent != "" { 20 | enc.SetIndent("", indent) 21 | } 22 | return enc.Encode(i) 23 | } 24 | 25 | // Deserialize reads a JSON from a request body and converts it into an interface. 
26 | func (d DefaultJSONSerializer) Deserialize(c Context, i interface{}) error { 27 | err := json.NewDecoder(c.Request().Body).Decode(i) 28 | if ute, ok := err.(*json.UnmarshalTypeError); ok { 29 | return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Unmarshal type error: expected=%v, got=%v, field=%v, offset=%v", ute.Type, ute.Value, ute.Field, ute.Offset)).SetInternal(err) 30 | } else if se, ok := err.(*json.SyntaxError); ok { 31 | return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Syntax error: offset=%v, error=%v", se.Offset, se.Error())).SetInternal(err) 32 | } 33 | return err 34 | } 35 | -------------------------------------------------------------------------------- /json_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors 3 | 4 | package echo 5 | 6 | import ( 7 | "github.com/stretchr/testify/assert" 8 | "net/http" 9 | "net/http/httptest" 10 | "strings" 11 | "testing" 12 | ) 13 | 14 | // Note this test is deliberately simple as there's not a lot to test. 15 | // Just need to ensure it writes JSONs. The heavy work is done by the context methods.
16 | func TestDefaultJSONCodec_Encode(t *testing.T) { 17 | e := New() 18 | req := httptest.NewRequest(http.MethodPost, "/", nil) 19 | rec := httptest.NewRecorder() 20 | c := e.NewContext(req, rec).(*context) 21 | 22 | // Echo 23 | assert.Equal(t, e, c.Echo()) 24 | 25 | // Request 26 | assert.NotNil(t, c.Request()) 27 | 28 | // Response 29 | assert.NotNil(t, c.Response()) 30 | 31 | //-------- 32 | // Default JSON encoder 33 | //-------- 34 | 35 | enc := new(DefaultJSONSerializer) 36 | 37 | err := enc.Serialize(c, user{1, "Jon Snow"}, "") 38 | if assert.NoError(t, err) { 39 | assert.Equal(t, userJSON+"\n", rec.Body.String()) 40 | } 41 | 42 | req = httptest.NewRequest(http.MethodPost, "/", nil) 43 | rec = httptest.NewRecorder() 44 | c = e.NewContext(req, rec).(*context) 45 | err = enc.Serialize(c, user{1, "Jon Snow"}, " ") 46 | if assert.NoError(t, err) { 47 | assert.Equal(t, userJSONPretty+"\n", rec.Body.String()) 48 | } 49 | } 50 | 51 | // Note this test is deliberately simple as there's not a lot to test. 52 | // Just need to ensure it reads JSONs. The heavy work is done by the context methods.
53 | func TestDefaultJSONCodec_Decode(t *testing.T) { 54 | e := New() 55 | req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) 56 | rec := httptest.NewRecorder() 57 | c := e.NewContext(req, rec).(*context) 58 | 59 | // Echo 60 | assert.Equal(t, e, c.Echo()) 61 | 62 | // Request 63 | assert.NotNil(t, c.Request()) 64 | 65 | // Response 66 | assert.NotNil(t, c.Response()) 67 | 68 | //-------- 69 | // Default JSON decoder 70 | //-------- 71 | 72 | enc := new(DefaultJSONSerializer) 73 | 74 | var u = user{} 75 | err := enc.Deserialize(c, &u) 76 | if assert.NoError(t, err) { 77 | assert.Equal(t, u, user{ID: 1, Name: "Jon Snow"}) 78 | } 79 | 80 | var userUnmarshalSyntaxError = user{} 81 | req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(invalidContent)) 82 | rec = httptest.NewRecorder() 83 | c = e.NewContext(req, rec).(*context) 84 | err = enc.Deserialize(c, &userUnmarshalSyntaxError) 85 | assert.IsType(t, &HTTPError{}, err) 86 | assert.EqualError(t, err, "code=400, message=Syntax error: offset=1, error=invalid character 'i' looking for beginning of value, internal=invalid character 'i' looking for beginning of value") 87 | 88 | var userUnmarshalTypeError = struct { 89 | ID string `json:"id"` 90 | Name string `json:"name"` 91 | }{} 92 | 93 | req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) 94 | rec = httptest.NewRecorder() 95 | c = e.NewContext(req, rec).(*context) 96 | err = enc.Deserialize(c, &userUnmarshalTypeError) 97 | assert.IsType(t, &HTTPError{}, err) 98 | assert.EqualError(t, err, "code=400, message=Unmarshal type error: expected=string, got=number, field=id, offset=7, internal=json: cannot unmarshal number into Go struct field .id of type string") 99 | 100 | } 101 | -------------------------------------------------------------------------------- /log.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | //
SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors 3 | 4 | package echo 5 | 6 | import ( 7 | "github.com/labstack/gommon/log" 8 | "io" 9 | ) 10 | 11 | // Logger defines the logging interface. 12 | type Logger interface { 13 | Output() io.Writer 14 | SetOutput(w io.Writer) 15 | Prefix() string 16 | SetPrefix(p string) 17 | Level() log.Lvl 18 | SetLevel(v log.Lvl) 19 | SetHeader(h string) 20 | Print(i ...interface{}) 21 | Printf(format string, args ...interface{}) 22 | Printj(j log.JSON) 23 | Debug(i ...interface{}) 24 | Debugf(format string, args ...interface{}) 25 | Debugj(j log.JSON) 26 | Info(i ...interface{}) 27 | Infof(format string, args ...interface{}) 28 | Infoj(j log.JSON) 29 | Warn(i ...interface{}) 30 | Warnf(format string, args ...interface{}) 31 | Warnj(j log.JSON) 32 | Error(i ...interface{}) 33 | Errorf(format string, args ...interface{}) 34 | Errorj(j log.JSON) 35 | Fatal(i ...interface{}) 36 | Fatalj(j log.JSON) 37 | Fatalf(format string, args ...interface{}) 38 | Panic(i ...interface{}) 39 | Panicj(j log.JSON) 40 | Panicf(format string, args ...interface{}) 41 | } 42 | -------------------------------------------------------------------------------- /middleware/basic_auth.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors 3 | 4 | package middleware 5 | 6 | import ( 7 | "encoding/base64" 8 | "net/http" 9 | "strconv" 10 | "strings" 11 | 12 | "github.com/labstack/echo/v4" 13 | ) 14 | 15 | // BasicAuthConfig defines the config for BasicAuth middleware. 16 | type BasicAuthConfig struct { 17 | // Skipper defines a function to skip middleware. 18 | Skipper Skipper 19 | 20 | // Validator is a function to validate BasicAuth credentials. 21 | // Required. 22 | Validator BasicAuthValidator 23 | 24 | // Realm is a string to define realm attribute of BasicAuth. 25 | // Default value "Restricted". 
26 | Realm string 27 | } 28 | 29 | // BasicAuthValidator defines a function to validate BasicAuth credentials. 30 | // The function should return a boolean indicating whether the credentials are valid, 31 | // and an error if any error occurs during the validation process. 32 | type BasicAuthValidator func(string, string, echo.Context) (bool, error) 33 | 34 | const ( 35 | basic = "basic" 36 | defaultRealm = "Restricted" 37 | ) 38 | 39 | // DefaultBasicAuthConfig is the default BasicAuth middleware config. 40 | var DefaultBasicAuthConfig = BasicAuthConfig{ 41 | Skipper: DefaultSkipper, 42 | Realm: defaultRealm, 43 | } 44 | 45 | // BasicAuth returns an BasicAuth middleware. 46 | // 47 | // For valid credentials it calls the next handler. 48 | // For missing or invalid credentials, it sends "401 - Unauthorized" response. 49 | func BasicAuth(fn BasicAuthValidator) echo.MiddlewareFunc { 50 | c := DefaultBasicAuthConfig 51 | c.Validator = fn 52 | return BasicAuthWithConfig(c) 53 | } 54 | 55 | // BasicAuthWithConfig returns an BasicAuth middleware with config. 56 | // See `BasicAuth()`. 
57 | func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc { 58 | // Defaults 59 | if config.Validator == nil { 60 | panic("echo: basic-auth middleware requires a validator function") 61 | } 62 | if config.Skipper == nil { 63 | config.Skipper = DefaultBasicAuthConfig.Skipper 64 | } 65 | if config.Realm == "" { 66 | config.Realm = defaultRealm 67 | } 68 | 69 | return func(next echo.HandlerFunc) echo.HandlerFunc { 70 | return func(c echo.Context) error { 71 | if config.Skipper(c) { 72 | return next(c) 73 | } 74 | 75 | auth := c.Request().Header.Get(echo.HeaderAuthorization) 76 | l := len(basic) 77 | 78 | if len(auth) > l+1 && strings.EqualFold(auth[:l], basic) { 79 | // Invalid base64 shouldn't be treated as error 80 | // instead should be treated as invalid client input 81 | b, err := base64.StdEncoding.DecodeString(auth[l+1:]) 82 | if err != nil { 83 | return echo.NewHTTPError(http.StatusBadRequest).SetInternal(err) 84 | } 85 | 86 | cred := string(b) 87 | for i := 0; i < len(cred); i++ { 88 | if cred[i] == ':' { 89 | // Verify credentials 90 | valid, err := config.Validator(cred[:i], cred[i+1:], c) 91 | if err != nil { 92 | return err 93 | } else if valid { 94 | return next(c) 95 | } 96 | break 97 | } 98 | } 99 | } 100 | 101 | realm := defaultRealm 102 | if config.Realm != defaultRealm { 103 | realm = strconv.Quote(config.Realm) 104 | } 105 | 106 | // Need to return `401` for browsers to pop-up login box. 
107 | c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm="+realm) 108 | return echo.ErrUnauthorized 109 | } 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /middleware/basic_auth_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors 3 | 4 | package middleware 5 | 6 | import ( 7 | "encoding/base64" 8 | "errors" 9 | "net/http" 10 | "net/http/httptest" 11 | "strings" 12 | "testing" 13 | 14 | "github.com/labstack/echo/v4" 15 | "github.com/stretchr/testify/assert" 16 | ) 17 | 18 | func TestBasicAuth(t *testing.T) { 19 | e := echo.New() 20 | 21 | mockValidator := func(u, p string, c echo.Context) (bool, error) { 22 | if u == "joe" && p == "secret" { 23 | return true, nil 24 | } 25 | return false, nil 26 | } 27 | 28 | tests := []struct { 29 | name string 30 | authHeader string 31 | expectedCode int 32 | expectedAuth string 33 | skipperResult bool 34 | expectedErr bool 35 | expectedErrMsg string 36 | }{ 37 | { 38 | name: "Valid credentials", 39 | authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")), 40 | expectedCode: http.StatusOK, 41 | }, 42 | { 43 | name: "Case-insensitive header scheme", 44 | authHeader: strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")), 45 | expectedCode: http.StatusOK, 46 | }, 47 | { 48 | name: "Invalid credentials", 49 | authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password")), 50 | expectedCode: http.StatusUnauthorized, 51 | expectedAuth: basic + ` realm="someRealm"`, 52 | expectedErr: true, 53 | expectedErrMsg: "Unauthorized", 54 | }, 55 | { 56 | name: "Invalid base64 string", 57 | authHeader: basic + " invalidString", 58 | expectedCode: http.StatusBadRequest, 59 | expectedErr: true, 60 | expectedErrMsg: "Bad Request", 61 | }, 62 | { 
63 | name: "Missing Authorization header", 64 | expectedCode: http.StatusUnauthorized, 65 | expectedErr: true, 66 | expectedErrMsg: "Unauthorized", 67 | }, 68 | { 69 | name: "Invalid Authorization header", 70 | authHeader: base64.StdEncoding.EncodeToString([]byte("invalid")), 71 | expectedCode: http.StatusUnauthorized, 72 | expectedErr: true, 73 | expectedErrMsg: "Unauthorized", 74 | }, 75 | { 76 | name: "Skipped Request", 77 | authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:skip")), 78 | expectedCode: http.StatusOK, 79 | skipperResult: true, 80 | }, 81 | } 82 | 83 | for _, tt := range tests { 84 | t.Run(tt.name, func(t *testing.T) { 85 | 86 | req := httptest.NewRequest(http.MethodGet, "/", nil) 87 | res := httptest.NewRecorder() 88 | c := e.NewContext(req, res) 89 | 90 | if tt.authHeader != "" { 91 | req.Header.Set(echo.HeaderAuthorization, tt.authHeader) 92 | } 93 | 94 | h := BasicAuthWithConfig(BasicAuthConfig{ 95 | Validator: mockValidator, 96 | Realm: "someRealm", 97 | Skipper: func(c echo.Context) bool { 98 | return tt.skipperResult 99 | }, 100 | })(func(c echo.Context) error { 101 | return c.String(http.StatusOK, "test") 102 | }) 103 | 104 | err := h(c) 105 | 106 | if tt.expectedErr { 107 | var he *echo.HTTPError 108 | errors.As(err, &he) 109 | assert.Equal(t, tt.expectedCode, he.Code) 110 | if tt.expectedAuth != "" { 111 | assert.Equal(t, tt.expectedAuth, res.Header().Get(echo.HeaderWWWAuthenticate)) 112 | } 113 | } else { 114 | assert.NoError(t, err) 115 | assert.Equal(t, tt.expectedCode, res.Code) 116 | } 117 | }) 118 | } 119 | } 120 | -------------------------------------------------------------------------------- /middleware/body_dump.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors 3 | 4 | package middleware 5 | 6 | import ( 7 | "bufio" 8 | "bytes" 9 | "errors" 10 | "io" 11 | "net" 12 | 
"net/http" 13 | 14 | "github.com/labstack/echo/v4" 15 | ) 16 | 17 | // BodyDumpConfig defines the config for BodyDump middleware. 18 | type BodyDumpConfig struct { 19 | // Skipper defines a function to skip middleware. 20 | Skipper Skipper 21 | 22 | // Handler receives request and response payload. 23 | // Required. 24 | Handler BodyDumpHandler 25 | } 26 | 27 | // BodyDumpHandler receives the request and response payload. 28 | type BodyDumpHandler func(echo.Context, []byte, []byte) 29 | 30 | type bodyDumpResponseWriter struct { 31 | io.Writer 32 | http.ResponseWriter 33 | } 34 | 35 | // DefaultBodyDumpConfig is the default BodyDump middleware config. 36 | var DefaultBodyDumpConfig = BodyDumpConfig{ 37 | Skipper: DefaultSkipper, 38 | } 39 | 40 | // BodyDump returns a BodyDump middleware. 41 | // 42 | // BodyDump middleware captures the request and response payload and calls the 43 | // registered handler. 44 | func BodyDump(handler BodyDumpHandler) echo.MiddlewareFunc { 45 | c := DefaultBodyDumpConfig 46 | c.Handler = handler 47 | return BodyDumpWithConfig(c) 48 | } 49 | 50 | // BodyDumpWithConfig returns a BodyDump middleware with config. 51 | // See: `BodyDump()`. 
52 | func BodyDumpWithConfig(config BodyDumpConfig) echo.MiddlewareFunc { 53 | // Defaults 54 | if config.Handler == nil { 55 | panic("echo: body-dump middleware requires a handler function") 56 | } 57 | if config.Skipper == nil { 58 | config.Skipper = DefaultBodyDumpConfig.Skipper 59 | } 60 | 61 | return func(next echo.HandlerFunc) echo.HandlerFunc { 62 | return func(c echo.Context) (err error) { 63 | if config.Skipper(c) { 64 | return next(c) 65 | } 66 | 67 | // Request 68 | reqBody := []byte{} 69 | if c.Request().Body != nil { // Read 70 | reqBody, _ = io.ReadAll(c.Request().Body) 71 | } 72 | c.Request().Body = io.NopCloser(bytes.NewBuffer(reqBody)) // Reset 73 | 74 | // Response 75 | resBody := new(bytes.Buffer) 76 | mw := io.MultiWriter(c.Response().Writer, resBody) 77 | writer := &bodyDumpResponseWriter{Writer: mw, ResponseWriter: c.Response().Writer} 78 | c.Response().Writer = writer 79 | 80 | if err = next(c); err != nil { 81 | c.Error(err) 82 | } 83 | 84 | // Callback 85 | config.Handler(c, reqBody, resBody.Bytes()) 86 | 87 | return 88 | } 89 | } 90 | } 91 | 92 | func (w *bodyDumpResponseWriter) WriteHeader(code int) { 93 | w.ResponseWriter.WriteHeader(code) 94 | } 95 | 96 | func (w *bodyDumpResponseWriter) Write(b []byte) (int, error) { 97 | return w.Writer.Write(b) 98 | } 99 | 100 | func (w *bodyDumpResponseWriter) Flush() { 101 | err := http.NewResponseController(w.ResponseWriter).Flush() 102 | if err != nil && errors.Is(err, http.ErrNotSupported) { 103 | panic(errors.New("response writer flushing is not supported")) 104 | } 105 | } 106 | 107 | func (w *bodyDumpResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { 108 | return http.NewResponseController(w.ResponseWriter).Hijack() 109 | } 110 | 111 | func (w *bodyDumpResponseWriter) Unwrap() http.ResponseWriter { 112 | return w.ResponseWriter 113 | } 114 | -------------------------------------------------------------------------------- /middleware/body_dump_test.go: 
-------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors 3 | 4 | package middleware 5 | 6 | import ( 7 | "errors" 8 | "io" 9 | "net/http" 10 | "net/http/httptest" 11 | "strings" 12 | "testing" 13 | 14 | "github.com/labstack/echo/v4" 15 | "github.com/stretchr/testify/assert" 16 | ) 17 | 18 | func TestBodyDump(t *testing.T) { 19 | e := echo.New() 20 | hw := "Hello, World!" 21 | req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(hw)) 22 | rec := httptest.NewRecorder() 23 | c := e.NewContext(req, rec) 24 | h := func(c echo.Context) error { 25 | body, err := io.ReadAll(c.Request().Body) 26 | if err != nil { 27 | return err 28 | } 29 | return c.String(http.StatusOK, string(body)) 30 | } 31 | 32 | requestBody := "" 33 | responseBody := "" 34 | mw := BodyDump(func(c echo.Context, reqBody, resBody []byte) { 35 | requestBody = string(reqBody) 36 | responseBody = string(resBody) 37 | }) 38 | 39 | if assert.NoError(t, mw(h)(c)) { 40 | assert.Equal(t, requestBody, hw) 41 | assert.Equal(t, responseBody, hw) 42 | assert.Equal(t, http.StatusOK, rec.Code) 43 | assert.Equal(t, hw, rec.Body.String()) 44 | } 45 | 46 | // Must set default skipper 47 | BodyDumpWithConfig(BodyDumpConfig{ 48 | Skipper: nil, 49 | Handler: func(c echo.Context, reqBody, resBody []byte) { 50 | requestBody = string(reqBody) 51 | responseBody = string(resBody) 52 | }, 53 | }) 54 | } 55 | 56 | func TestBodyDumpFails(t *testing.T) { 57 | e := echo.New() 58 | hw := "Hello, World!" 
59 | req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(hw)) 60 | rec := httptest.NewRecorder() 61 | c := e.NewContext(req, rec) 62 | h := func(c echo.Context) error { 63 | return errors.New("some error") 64 | } 65 | 66 | mw := BodyDump(func(c echo.Context, reqBody, resBody []byte) {}) 67 | 68 | if !assert.Error(t, mw(h)(c)) { 69 | t.FailNow() 70 | } 71 | 72 | assert.Panics(t, func() { 73 | mw = BodyDumpWithConfig(BodyDumpConfig{ 74 | Skipper: nil, 75 | Handler: nil, 76 | }) 77 | }) 78 | 79 | assert.NotPanics(t, func() { 80 | mw = BodyDumpWithConfig(BodyDumpConfig{ 81 | Skipper: func(c echo.Context) bool { 82 | return true 83 | }, 84 | Handler: func(c echo.Context, reqBody, resBody []byte) { 85 | }, 86 | }) 87 | 88 | if !assert.Error(t, mw(h)(c)) { 89 | t.FailNow() 90 | } 91 | }) 92 | } 93 | 94 | func TestBodyDumpResponseWriter_CanNotFlush(t *testing.T) { 95 | bdrw := bodyDumpResponseWriter{ 96 | ResponseWriter: new(testResponseWriterNoFlushHijack), // this RW does not support flush 97 | } 98 | 99 | assert.PanicsWithError(t, "response writer flushing is not supported", func() { 100 | bdrw.Flush() 101 | }) 102 | } 103 | 104 | func TestBodyDumpResponseWriter_CanFlush(t *testing.T) { 105 | trwu := testResponseWriterUnwrapperHijack{testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()}} 106 | bdrw := bodyDumpResponseWriter{ 107 | ResponseWriter: &trwu, 108 | } 109 | 110 | bdrw.Flush() 111 | assert.Equal(t, 1, trwu.unwrapCalled) 112 | } 113 | 114 | func TestBodyDumpResponseWriter_CanUnwrap(t *testing.T) { 115 | trwu := &testResponseWriterUnwrapper{rw: httptest.NewRecorder()} 116 | bdrw := bodyDumpResponseWriter{ 117 | ResponseWriter: trwu, 118 | } 119 | 120 | result := bdrw.Unwrap() 121 | assert.Equal(t, trwu, result) 122 | } 123 | 124 | func TestBodyDumpResponseWriter_CanHijack(t *testing.T) { 125 | trwu := testResponseWriterUnwrapperHijack{testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: 
httptest.NewRecorder()}} 126 | bdrw := bodyDumpResponseWriter{ 127 | ResponseWriter: &trwu, // this RW supports hijacking through unwrapping 128 | } 129 | 130 | _, _, err := bdrw.Hijack() 131 | assert.EqualError(t, err, "can hijack") 132 | } 133 | 134 | func TestBodyDumpResponseWriter_CanNotHijack(t *testing.T) { 135 | trwu := testResponseWriterUnwrapper{rw: httptest.NewRecorder()} 136 | bdrw := bodyDumpResponseWriter{ 137 | ResponseWriter: &trwu, // this RW supports hijacking through unwrapping 138 | } 139 | 140 | _, _, err := bdrw.Hijack() 141 | assert.EqualError(t, err, "feature not supported") 142 | } 143 | -------------------------------------------------------------------------------- /middleware/body_limit.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors 3 | 4 | package middleware 5 | 6 | import ( 7 | "fmt" 8 | "io" 9 | "sync" 10 | 11 | "github.com/labstack/echo/v4" 12 | "github.com/labstack/gommon/bytes" 13 | ) 14 | 15 | // BodyLimitConfig defines the config for BodyLimit middleware. 16 | type BodyLimitConfig struct { 17 | // Skipper defines a function to skip middleware. 18 | Skipper Skipper 19 | 20 | // Maximum allowed size for a request body, it can be specified 21 | // as `4x` or `4xB`, where x is one of the multiple from K, M, G, T or P. 22 | Limit string `yaml:"limit"` 23 | limit int64 24 | } 25 | 26 | type limitedReader struct { 27 | BodyLimitConfig 28 | reader io.ReadCloser 29 | read int64 30 | } 31 | 32 | // DefaultBodyLimitConfig is the default BodyLimit middleware config. 33 | var DefaultBodyLimitConfig = BodyLimitConfig{ 34 | Skipper: DefaultSkipper, 35 | } 36 | 37 | // BodyLimit returns a BodyLimit middleware. 
38 | // 39 | // BodyLimit middleware sets the maximum allowed size for a request body, if the 40 | // size exceeds the configured limit, it sends "413 - Request Entity Too Large" 41 | // response. The BodyLimit is determined based on both `Content-Length` request 42 | // header and actual content read, which makes it super secure. 43 | // Limit can be specified as `4x` or `4xB`, where x is one of the multiple from K, M, 44 | // G, T or P. 45 | func BodyLimit(limit string) echo.MiddlewareFunc { 46 | c := DefaultBodyLimitConfig 47 | c.Limit = limit 48 | return BodyLimitWithConfig(c) 49 | } 50 | 51 | // BodyLimitWithConfig returns a BodyLimit middleware with config. 52 | // See: `BodyLimit()`. 53 | func BodyLimitWithConfig(config BodyLimitConfig) echo.MiddlewareFunc { 54 | // Defaults 55 | if config.Skipper == nil { 56 | config.Skipper = DefaultBodyLimitConfig.Skipper 57 | } 58 | 59 | limit, err := bytes.Parse(config.Limit) 60 | if err != nil { 61 | panic(fmt.Errorf("echo: invalid body-limit=%s", config.Limit)) 62 | } 63 | config.limit = limit 64 | pool := limitedReaderPool(config) 65 | 66 | return func(next echo.HandlerFunc) echo.HandlerFunc { 67 | return func(c echo.Context) error { 68 | if config.Skipper(c) { 69 | return next(c) 70 | } 71 | 72 | req := c.Request() 73 | 74 | // Based on content length 75 | if req.ContentLength > config.limit { 76 | return echo.ErrStatusRequestEntityTooLarge 77 | } 78 | 79 | // Based on content read 80 | r := pool.Get().(*limitedReader) 81 | r.Reset(req.Body) 82 | defer pool.Put(r) 83 | req.Body = r 84 | 85 | return next(c) 86 | } 87 | } 88 | } 89 | 90 | func (r *limitedReader) Read(b []byte) (n int, err error) { 91 | n, err = r.reader.Read(b) 92 | r.read += int64(n) 93 | if r.read > r.limit { 94 | return n, echo.ErrStatusRequestEntityTooLarge 95 | } 96 | return 97 | } 98 | 99 | func (r *limitedReader) Close() error { 100 | return r.reader.Close() 101 | } 102 | 103 | func (r *limitedReader) Reset(reader io.ReadCloser) { 104 | 
r.reader = reader 105 | r.read = 0 106 | } 107 | 108 | func limitedReaderPool(c BodyLimitConfig) sync.Pool { 109 | return sync.Pool{ 110 | New: func() interface{} { 111 | return &limitedReader{BodyLimitConfig: c} 112 | }, 113 | } 114 | } 115 | -------------------------------------------------------------------------------- /middleware/body_limit_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors 3 | 4 | package middleware 5 | 6 | import ( 7 | "bytes" 8 | "io" 9 | "net/http" 10 | "net/http/httptest" 11 | "testing" 12 | 13 | "github.com/labstack/echo/v4" 14 | "github.com/stretchr/testify/assert" 15 | ) 16 | 17 | func TestBodyLimit(t *testing.T) { 18 | e := echo.New() 19 | hw := []byte("Hello, World!") 20 | req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) 21 | rec := httptest.NewRecorder() 22 | c := e.NewContext(req, rec) 23 | h := func(c echo.Context) error { 24 | body, err := io.ReadAll(c.Request().Body) 25 | if err != nil { 26 | return err 27 | } 28 | return c.String(http.StatusOK, string(body)) 29 | } 30 | 31 | // Based on content length (within limit) 32 | if assert.NoError(t, BodyLimit("2M")(h)(c)) { 33 | assert.Equal(t, http.StatusOK, rec.Code) 34 | assert.Equal(t, hw, rec.Body.Bytes()) 35 | } 36 | 37 | // Based on content length (overlimit) 38 | he := BodyLimit("2B")(h)(c).(*echo.HTTPError) 39 | assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code) 40 | 41 | // Based on content read (within limit) 42 | req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) 43 | req.ContentLength = -1 44 | rec = httptest.NewRecorder() 45 | c = e.NewContext(req, rec) 46 | if assert.NoError(t, BodyLimit("2M")(h)(c)) { 47 | assert.Equal(t, http.StatusOK, rec.Code) 48 | assert.Equal(t, "Hello, World!", rec.Body.String()) 49 | } 50 | 51 | // Based on content read (overlimit) 52 | req = 
httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) 53 | req.ContentLength = -1 54 | rec = httptest.NewRecorder() 55 | c = e.NewContext(req, rec) 56 | he = BodyLimit("2B")(h)(c).(*echo.HTTPError) 57 | assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code) 58 | } 59 | 60 | func TestBodyLimitReader(t *testing.T) { 61 | hw := []byte("Hello, World!") 62 | 63 | config := BodyLimitConfig{ 64 | Skipper: DefaultSkipper, 65 | Limit: "2B", 66 | limit: 2, 67 | } 68 | reader := &limitedReader{ 69 | BodyLimitConfig: config, 70 | reader: io.NopCloser(bytes.NewReader(hw)), 71 | } 72 | 73 | // read all should return ErrStatusRequestEntityTooLarge 74 | _, err := io.ReadAll(reader) 75 | he := err.(*echo.HTTPError) 76 | assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code) 77 | 78 | // reset reader and read two bytes must succeed 79 | bt := make([]byte, 2) 80 | reader.Reset(io.NopCloser(bytes.NewReader(hw))) 81 | n, err := reader.Read(bt) 82 | assert.Equal(t, 2, n) 83 | assert.Equal(t, nil, err) 84 | } 85 | 86 | func TestBodyLimitWithConfig_Skipper(t *testing.T) { 87 | e := echo.New() 88 | h := func(c echo.Context) error { 89 | body, err := io.ReadAll(c.Request().Body) 90 | if err != nil { 91 | return err 92 | } 93 | return c.String(http.StatusOK, string(body)) 94 | } 95 | mw := BodyLimitWithConfig(BodyLimitConfig{ 96 | Skipper: func(c echo.Context) bool { 97 | return true 98 | }, 99 | Limit: "2B", // if not skipped this limit would make request to fail limit check 100 | }) 101 | 102 | hw := []byte("Hello, World!") 103 | req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) 104 | rec := httptest.NewRecorder() 105 | c := e.NewContext(req, rec) 106 | 107 | err := mw(h)(c) 108 | assert.NoError(t, err) 109 | assert.Equal(t, http.StatusOK, rec.Code) 110 | assert.Equal(t, hw, rec.Body.Bytes()) 111 | } 112 | 113 | func TestBodyLimitWithConfig(t *testing.T) { 114 | var testCases = []struct { 115 | name string 116 | givenLimit string 117 | whenBody []byte 
118 | expectBody []byte 119 | expectError string 120 | }{ 121 | { 122 | name: "ok, body is less than limit", 123 | givenLimit: "10B", 124 | whenBody: []byte("123456789"), 125 | expectBody: []byte("123456789"), 126 | expectError: "", 127 | }, 128 | { 129 | name: "nok, body is more than limit", 130 | givenLimit: "9B", 131 | whenBody: []byte("1234567890"), 132 | expectBody: []byte(nil), 133 | expectError: "code=413, message=Request Entity Too Large", 134 | }, 135 | } 136 | 137 | for _, tc := range testCases { 138 | t.Run(tc.name, func(t *testing.T) { 139 | e := echo.New() 140 | h := func(c echo.Context) error { 141 | body, err := io.ReadAll(c.Request().Body) 142 | if err != nil { 143 | return err 144 | } 145 | return c.String(http.StatusOK, string(body)) 146 | } 147 | mw := BodyLimitWithConfig(BodyLimitConfig{ 148 | Limit: tc.givenLimit, 149 | }) 150 | 151 | req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(tc.whenBody)) 152 | rec := httptest.NewRecorder() 153 | c := e.NewContext(req, rec) 154 | 155 | err := mw(h)(c) 156 | if tc.expectError != "" { 157 | assert.EqualError(t, err, tc.expectError) 158 | } else { 159 | assert.NoError(t, err) 160 | } 161 | // not testing status as middlewares return error instead of committing it and OK cases are anyway 200 162 | assert.Equal(t, tc.expectBody, rec.Body.Bytes()) 163 | }) 164 | } 165 | } 166 | 167 | func TestBodyLimit_panicOnInvalidLimit(t *testing.T) { 168 | assert.PanicsWithError( 169 | t, 170 | "echo: invalid body-limit=", 171 | func() { BodyLimit("") }, 172 | ) 173 | } 174 | -------------------------------------------------------------------------------- /middleware/compress.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors 3 | 4 | package middleware 5 | 6 | import ( 7 | "bufio" 8 | "bytes" 9 | "compress/gzip" 10 | "io" 11 | "net" 12 | "net/http" 13 | "strings" 14 
| "sync" 15 | 16 | "github.com/labstack/echo/v4" 17 | ) 18 | 19 | // GzipConfig defines the config for Gzip middleware. 20 | type GzipConfig struct { 21 | // Skipper defines a function to skip middleware. 22 | Skipper Skipper 23 | 24 | // Gzip compression level. 25 | // Optional. Default value -1. 26 | Level int `yaml:"level"` 27 | 28 | // Length threshold before gzip compression is applied. 29 | // Optional. Default value 0. 30 | // 31 | // Most of the time you will not need to change the default. Compressing 32 | // a short response might increase the transmitted data because of the 33 | // gzip format overhead. Compressing the response will also consume CPU 34 | // and time on the server and the client (for decompressing). Depending on 35 | // your use case such a threshold might be useful. 36 | // 37 | // See also: 38 | // https://webmasters.stackexchange.com/questions/31750/what-is-recommended-minimum-object-size-for-gzip-performance-benefits 39 | MinLength int 40 | } 41 | 42 | type gzipResponseWriter struct { 43 | io.Writer 44 | http.ResponseWriter 45 | wroteHeader bool 46 | wroteBody bool 47 | minLength int 48 | minLengthExceeded bool 49 | buffer *bytes.Buffer 50 | code int 51 | } 52 | 53 | const ( 54 | gzipScheme = "gzip" 55 | ) 56 | 57 | // DefaultGzipConfig is the default Gzip middleware config. 58 | var DefaultGzipConfig = GzipConfig{ 59 | Skipper: DefaultSkipper, 60 | Level: -1, 61 | MinLength: 0, 62 | } 63 | 64 | // Gzip returns a middleware which compresses HTTP response using gzip compression 65 | // scheme. 66 | func Gzip() echo.MiddlewareFunc { 67 | return GzipWithConfig(DefaultGzipConfig) 68 | } 69 | 70 | // GzipWithConfig return Gzip middleware with config. 71 | // See: `Gzip()`. 
72 | func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { 73 | // Defaults 74 | if config.Skipper == nil { 75 | config.Skipper = DefaultGzipConfig.Skipper 76 | } 77 | if config.Level == 0 { 78 | config.Level = DefaultGzipConfig.Level 79 | } 80 | if config.MinLength < 0 { 81 | config.MinLength = DefaultGzipConfig.MinLength 82 | } 83 | 84 | pool := gzipCompressPool(config) 85 | bpool := bufferPool() 86 | 87 | return func(next echo.HandlerFunc) echo.HandlerFunc { 88 | return func(c echo.Context) error { 89 | if config.Skipper(c) { 90 | return next(c) 91 | } 92 | 93 | res := c.Response() 94 | res.Header().Add(echo.HeaderVary, echo.HeaderAcceptEncoding) 95 | if strings.Contains(c.Request().Header.Get(echo.HeaderAcceptEncoding), gzipScheme) { 96 | i := pool.Get() 97 | w, ok := i.(*gzip.Writer) 98 | if !ok { 99 | return echo.NewHTTPError(http.StatusInternalServerError, i.(error).Error()) 100 | } 101 | rw := res.Writer 102 | w.Reset(rw) 103 | 104 | buf := bpool.Get().(*bytes.Buffer) 105 | buf.Reset() 106 | 107 | grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw, minLength: config.MinLength, buffer: buf} 108 | defer func() { 109 | // There are different reasons for cases when we have not yet written response to the client and now need to do so. 110 | // a) handler response had only response code and no response body (ala 404 or redirects etc). Response code need to be written now. 111 | // b) body is shorter than our minimum length threshold and being buffered currently and needs to be written 112 | if !grw.wroteBody { 113 | if res.Header().Get(echo.HeaderContentEncoding) == gzipScheme { 114 | res.Header().Del(echo.HeaderContentEncoding) 115 | } 116 | if grw.wroteHeader { 117 | rw.WriteHeader(grw.code) 118 | } 119 | // We have to reset response to it's pristine state when 120 | // nothing is written to body or error is returned. 121 | // See issue #424, #407. 
122 | res.Writer = rw 123 | w.Reset(io.Discard) 124 | } else if !grw.minLengthExceeded { 125 | // Write uncompressed response 126 | res.Writer = rw 127 | if grw.wroteHeader { 128 | grw.ResponseWriter.WriteHeader(grw.code) 129 | } 130 | grw.buffer.WriteTo(rw) 131 | w.Reset(io.Discard) 132 | } 133 | w.Close() 134 | bpool.Put(buf) 135 | pool.Put(w) 136 | }() 137 | res.Writer = grw 138 | } 139 | return next(c) 140 | } 141 | } 142 | } 143 | // WriteHeader records the status code; it is only sent once it is known whether the body will be compressed. 144 | func (w *gzipResponseWriter) WriteHeader(code int) { 145 | w.Header().Del(echo.HeaderContentLength) // Issue #444 146 | 147 | w.wroteHeader = true 148 | 149 | // Delay writing of the header until we know if we'll actually compress the response 150 | w.code = code 151 | } 152 | // Write buffers b until at least minLength bytes have accumulated, then switches to gzip-compressed output. 153 | func (w *gzipResponseWriter) Write(b []byte) (int, error) { 154 | if w.Header().Get(echo.HeaderContentType) == "" { 155 | w.Header().Set(echo.HeaderContentType, http.DetectContentType(b)) 156 | } 157 | w.wroteBody = true 158 | 159 | if !w.minLengthExceeded { 160 | n, err := w.buffer.Write(b) 161 | 162 | if w.buffer.Len() >= w.minLength { 163 | w.minLengthExceeded = true 164 | 165 | // The minimum length is exceeded, add Content-Encoding header and write the header 166 | w.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806 167 | if w.wroteHeader { 168 | w.ResponseWriter.WriteHeader(w.code) 169 | } 170 | // Flush the whole buffer, but report only len(b) as written: io.Writer requires 0 <= n <= len(b). 171 | if _, err := w.Writer.Write(w.buffer.Bytes()); err != nil { 172 | return 0, err 173 | } 174 | } 175 | return n, err 176 | } 177 | return w.Writer.Write(b) 178 | } 179 | // Flush forces compression of any buffered data and flushes both the gzip writer and the underlying response. 180 | func (w *gzipResponseWriter) Flush() { 181 | if !w.minLengthExceeded { 182 | // Enforce compression because we will not know how much more data will come 183 | w.minLengthExceeded = true 184 | w.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806 185 | if w.wroteHeader { 186 | w.ResponseWriter.WriteHeader(w.code) 187 | } 188 | 189 | w.Writer.Write(w.buffer.Bytes()) 190 | } 191 | 192 | w.Writer.(*gzip.Writer).Flush() 193 | _ = http.NewResponseController(w.ResponseWriter).Flush()
194 | } 195 | // Unwrap returns the wrapped http.ResponseWriter so that http.ResponseController can reach it. 196 | func (w *gzipResponseWriter) Unwrap() http.ResponseWriter { 197 | return w.ResponseWriter 198 | } 199 | // Hijack delegates to the wrapped writer through http.ResponseController. 200 | func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { 201 | return http.NewResponseController(w.ResponseWriter).Hijack() 202 | } 203 | // Push delegates to the wrapped writer if it is an http.Pusher, otherwise returns http.ErrNotSupported. 204 | func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error { 205 | if p, ok := w.ResponseWriter.(http.Pusher); ok { 206 | return p.Push(target, opts) 207 | } 208 | return http.ErrNotSupported 209 | } 210 | // gzipCompressPool returns a pool of gzip writers at config.Level; New yields the error itself if the level is invalid. 211 | func gzipCompressPool(config GzipConfig) sync.Pool { 212 | return sync.Pool{ 213 | New: func() interface{} { 214 | w, err := gzip.NewWriterLevel(io.Discard, config.Level) 215 | if err != nil { 216 | return err 217 | } 218 | return w 219 | }, 220 | } 221 | } 222 | // bufferPool returns a pool that allocates empty *bytes.Buffer values on demand. 223 | func bufferPool() sync.Pool { 224 | return sync.Pool{ 225 | New: func() interface{} { 226 | b := &bytes.Buffer{} 227 | return b 228 | }, 229 | } 230 | } 231 | -------------------------------------------------------------------------------- /middleware/context_timeout.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors 3 | 4 | package middleware 5 | 6 | import ( 7 | "context" 8 | "errors" 9 | "time" 10 | 11 | "github.com/labstack/echo/v4" 12 | ) 13 | 14 | // ContextTimeoutConfig defines the config for ContextTimeout middleware. 15 | type ContextTimeoutConfig struct { 16 | // Skipper defines a function to skip middleware. 17 | Skipper Skipper 18 | 19 | // ErrorHandler is called with the non-nil error returned by the next handler; its result is returned instead.
20 | ErrorHandler func(err error, c echo.Context) error 21 | 22 | // Timeout is the deadline applied to the request context. Required: a zero value makes ToMiddleware return an error. 23 | Timeout time.Duration 24 | } 25 | 26 | // ContextTimeout returns a middleware that applies timeout to the request context and turns a context.DeadlineExceeded 27 | // error returned by the next handler into a 503 Service Unavailable error. It panics if timeout is 0. 28 | func ContextTimeout(timeout time.Duration) echo.MiddlewareFunc { 29 | return ContextTimeoutWithConfig(ContextTimeoutConfig{Timeout: timeout}) 30 | } 31 | 32 | // ContextTimeoutWithConfig returns a ContextTimeout middleware with config. It panics if the config is invalid. 33 | func ContextTimeoutWithConfig(config ContextTimeoutConfig) echo.MiddlewareFunc { 34 | mw, err := config.ToMiddleware() 35 | if err != nil { 36 | panic(err) 37 | } 38 | return mw 39 | } 40 | 41 | // ToMiddleware converts the config to a middleware, or returns an error if Timeout is 0. 42 | func (config ContextTimeoutConfig) ToMiddleware() (echo.MiddlewareFunc, error) { 43 | if config.Timeout == 0 { 44 | return nil, errors.New("timeout must be set") 45 | } 46 | if config.Skipper == nil { 47 | config.Skipper = DefaultSkipper 48 | } 49 | if config.ErrorHandler == nil { 50 | config.ErrorHandler = func(err error, c echo.Context) error { 51 | if err != nil && errors.Is(err, context.DeadlineExceeded) { 52 | return echo.ErrServiceUnavailable.WithInternal(err) 53 | } 54 | return err 55 | } 56 | } 57 | 58 | return func(next echo.HandlerFunc) echo.HandlerFunc { 59 | return func(c echo.Context) error { 60 | if config.Skipper(c) { 61 | return next(c) 62 | } 63 | // The deadline context is derived from the incoming request context, so earlier cancellation still propagates. 64 | timeoutContext, cancel := context.WithTimeout(c.Request().Context(), config.Timeout) 65 | defer cancel() 66 | 67 | c.SetRequest(c.Request().WithContext(timeoutContext)) 68 | 69 | if err := next(c); err != nil { 70 | return config.ErrorHandler(err, c) 71 | } 72 | return nil 73 | } 74 | }, nil 75 | } 76 | -------------------------------------------------------------------------------- /middleware/context_timeout_test.go:
-------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors 3 | 4 | package middleware 5 | 6 | import ( 7 | "context" 8 | "errors" 9 | "net/http" 10 | "net/http/httptest" 11 | "net/url" 12 | "strings" 13 | "testing" 14 | "time" 15 | 16 | "github.com/labstack/echo/v4" 17 | "github.com/stretchr/testify/assert" 18 | ) 19 | 20 | func TestContextTimeoutSkipper(t *testing.T) { 21 | t.Parallel() 22 | m := ContextTimeoutWithConfig(ContextTimeoutConfig{ 23 | Skipper: func(context echo.Context) bool { 24 | return true 25 | }, 26 | Timeout: 10 * time.Millisecond, 27 | }) 28 | 29 | req := httptest.NewRequest(http.MethodGet, "/", nil) 30 | rec := httptest.NewRecorder() 31 | 32 | e := echo.New() 33 | c := e.NewContext(req, rec) 34 | 35 | err := m(func(c echo.Context) error { 36 | if err := sleepWithContext(c.Request().Context(), 20*time.Millisecond); err != nil { 37 | return err 38 | } 39 | 40 | return errors.New("response from handler") 41 | })(c) 42 | 43 | // had the middleware not been skipped, the 10ms deadline would have cut the 20ms sleep short and a timeout error would be returned 44 | assert.EqualError(t, err, "response from handler") 45 | } 46 | 47 | func TestContextTimeoutWithTimeout0(t *testing.T) { 48 | t.Parallel() 49 | assert.Panics(t, func() { 50 | ContextTimeout(time.Duration(0)) 51 | }) 52 | } 53 | 54 | func TestContextTimeoutErrorOutInHandler(t *testing.T) { 55 | t.Parallel() 56 | m := ContextTimeoutWithConfig(ContextTimeoutConfig{ 57 | // Timeout has to be non-zero, otherwise ContextTimeoutWithConfig panics 58 | Timeout: 10 * time.Millisecond, 59 | }) 60 | 61 | req := httptest.NewRequest(http.MethodGet, "/", nil) 62 | rec := httptest.NewRecorder() 63 | 64 | e := echo.New() 65 | c := e.NewContext(req, rec) 66 | 67 | rec.Code = 1 // we want to be sure that even 200 will not be sent 68 | err := m(func(c echo.Context) error { 69 | // this error must not be
written to the client response. Middlewares upstream of timeout middleware must be able 70 | // to handle the returned error, and this can be done only when the handler has not yet committed (written status code) 71 | // the response. 72 | return echo.NewHTTPError(http.StatusTeapot, "err") 73 | })(c) 74 | 75 | assert.Error(t, err) 76 | assert.EqualError(t, err, "code=418, message=err") 77 | assert.Equal(t, 1, rec.Code) 78 | assert.Equal(t, "", rec.Body.String()) 79 | } 80 | 81 | func TestContextTimeoutSuccessfulRequest(t *testing.T) { 82 | t.Parallel() 83 | m := ContextTimeoutWithConfig(ContextTimeoutConfig{ 84 | // Timeout has to be non-zero, otherwise ContextTimeoutWithConfig panics 85 | Timeout: 10 * time.Millisecond, 86 | }) 87 | 88 | req := httptest.NewRequest(http.MethodGet, "/", nil) 89 | rec := httptest.NewRecorder() 90 | 91 | e := echo.New() 92 | c := e.NewContext(req, rec) 93 | 94 | err := m(func(c echo.Context) error { 95 | return c.JSON(http.StatusCreated, map[string]string{"data": "ok"}) 96 | })(c) 97 | 98 | assert.NoError(t, err) 99 | assert.Equal(t, http.StatusCreated, rec.Code) 100 | assert.Equal(t, "{\"data\":\"ok\"}\n", rec.Body.String()) 101 | } 102 | 103 | func TestContextTimeoutTestRequestClone(t *testing.T) { 104 | t.Parallel() 105 | req := httptest.NewRequest(http.MethodPost, "/uri?query=value", strings.NewReader(url.Values{"form": {"value"}}.Encode())) 106 | req.AddCookie(&http.Cookie{Name: "cookie", Value: "value"}) 107 | req.Header.Set("Content-Type", "application/x-www-form-urlencoded") 108 | rec := httptest.NewRecorder() 109 | 110 | m := ContextTimeoutWithConfig(ContextTimeoutConfig{ 111 | // Timeout has to be non-zero, otherwise ContextTimeoutWithConfig panics 112 | Timeout: 1 * time.Second, 113 | }) 114 | 115 | e := echo.New() 116 | c := e.NewContext(req, rec) 117 | 118 | err := m(func(c echo.Context) error { 119 | // Cookie test 120 | cookie, err := c.Request().Cookie("cookie") 121 | if assert.NoError(t, err) { 122 |
assert.EqualValues(t, "cookie", cookie.Name) 123 | assert.EqualValues(t, "value", cookie.Value) 124 | } 125 | 126 | // Form values 127 | if assert.NoError(t, c.Request().ParseForm()) { 128 | assert.EqualValues(t, "value", c.Request().FormValue("form")) 129 | } 130 | 131 | // Query string 132 | assert.EqualValues(t, "value", c.Request().URL.Query()["query"][0]) 133 | return nil 134 | })(c) 135 | 136 | assert.NoError(t, err) 137 | } 138 | 139 | func TestContextTimeoutWithDefaultErrorMessage(t *testing.T) { 140 | t.Parallel() 141 | 142 | timeout := 10 * time.Millisecond 143 | m := ContextTimeoutWithConfig(ContextTimeoutConfig{ 144 | Timeout: timeout, 145 | }) 146 | 147 | req := httptest.NewRequest(http.MethodGet, "/", nil) 148 | rec := httptest.NewRecorder() 149 | 150 | e := echo.New() 151 | c := e.NewContext(req, rec) 152 | 153 | err := m(func(c echo.Context) error { 154 | if err := sleepWithContext(c.Request().Context(), 80*time.Millisecond); err != nil { 155 | return err 156 | } 157 | return c.String(http.StatusOK, "Hello, World!") 158 | })(c) 159 | 160 | assert.IsType(t, &echo.HTTPError{}, err) 161 | assert.Error(t, err) 162 | assert.Equal(t, http.StatusServiceUnavailable, err.(*echo.HTTPError).Code) 163 | assert.Equal(t, "Service Unavailable", err.(*echo.HTTPError).Message) 164 | } 165 | 166 | func TestContextTimeoutCanHandleContextDeadlineOnNextHandler(t *testing.T) { 167 | t.Parallel() 168 | 169 | timeoutErrorHandler := func(err error, c echo.Context) error { 170 | if err != nil { 171 | if errors.Is(err, context.DeadlineExceeded) { 172 | return &echo.HTTPError{ 173 | Code: http.StatusServiceUnavailable, 174 | Message: "Timeout!
change me", 175 | } 176 | } 177 | return err 178 | } 179 | return nil 180 | } 181 | 182 | timeout := 50 * time.Millisecond 183 | m := ContextTimeoutWithConfig(ContextTimeoutConfig{ 184 | Timeout: timeout, 185 | ErrorHandler: timeoutErrorHandler, 186 | }) 187 | 188 | req := httptest.NewRequest(http.MethodGet, "/", nil) 189 | rec := httptest.NewRecorder() 190 | 191 | e := echo.New() 192 | c := e.NewContext(req, rec) 193 | 194 | err := m(func(c echo.Context) error { 195 | // NOTE: Very short periods are not reliable for tests due to goroutine scheduling and the unpredictable order 196 | // for 1) request and 2) time goroutine. For most OS this works as expected, but MacOS seems most flaky. 197 | 198 | if err := sleepWithContext(c.Request().Context(), 100*time.Millisecond); err != nil { 199 | return err 200 | } 201 | 202 | // The request context should carry the deadline set by the ContextTimeout middleware 203 | if _, ok := c.Request().Context().Deadline(); !ok { 204 | assert.Fail(t, "No timeout set on Request Context") 205 | } 206 | return c.String(http.StatusOK, "Hello, World!") 207 | })(c) 208 | 209 | assert.IsType(t, &echo.HTTPError{}, err) 210 | assert.Error(t, err) 211 | assert.Equal(t, http.StatusServiceUnavailable, err.(*echo.HTTPError).Code) 212 | assert.Equal(t, "Timeout!
change me", err.(*echo.HTTPError).Message) 213 | } 214 | // sleepWithContext blocks for d or until ctx is done, in which case it returns ctx.Err(). 215 | func sleepWithContext(ctx context.Context, d time.Duration) error { 216 | timer := time.NewTimer(d) 217 | 218 | defer func() { 219 | _ = timer.Stop() 220 | }() 221 | 222 | select { 223 | case <-ctx.Done(): 224 | return ctx.Err() 225 | case <-timer.C: 226 | return nil 227 | } 228 | } 229 | -------------------------------------------------------------------------------- /middleware/csrf.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors 3 | 4 | package middleware 5 | 6 | import ( 7 | "crypto/subtle" 8 | "net/http" 9 | "time" 10 | 11 | "github.com/labstack/echo/v4" 12 | ) 13 | 14 | // CSRFConfig defines the config for CSRF middleware. 15 | type CSRFConfig struct { 16 | // Skipper defines a function to skip middleware. 17 | Skipper Skipper 18 | 19 | // TokenLength is the length of the generated token. 20 | // Optional. Default value 32. 21 | TokenLength uint8 `yaml:"token_length"` 22 | 23 | // TokenLookup is a string in the form of "<source>:<name>" or "<source>:<name>,<source>:<name>" that is used 24 | // to extract token from the request. 25 | // Optional. Default value "header:X-CSRF-Token". 26 | // Possible values: 27 | // - "header:<name>" or "header:<name>:<cut-prefix>" 28 | // - "query:<name>" 29 | // - "form:<name>" 30 | // Multiple sources example: 31 | // - "header:X-CSRF-Token,query:csrf" 32 | TokenLookup string `yaml:"token_lookup"` 33 | 34 | // Context key to store generated CSRF token into context. 35 | // Optional. Default value "csrf". 36 | ContextKey string `yaml:"context_key"` 37 | 38 | // Name of the CSRF cookie. This cookie will store CSRF token. 39 | // Optional. Default value "_csrf". 40 | CookieName string `yaml:"cookie_name"` 41 | 42 | // Domain of the CSRF cookie. 43 | // Optional. Default value none. 44 | CookieDomain string `yaml:"cookie_domain"` 45 | 46 | // Path of the CSRF cookie. 47 | // Optional.
Default value none. 48 | CookiePath string `yaml:"cookie_path"` 49 | 50 | // Max age (in seconds) of the CSRF cookie; it is applied through the cookie's Expires attribute. 51 | // Optional. Default value 86400 (24hr). 52 | CookieMaxAge int `yaml:"cookie_max_age"` 53 | 54 | // Indicates if CSRF cookie is secure. 55 | // Optional. Default value false. Forced to true when CookieSameSite is http.SameSiteNoneMode. 56 | CookieSecure bool `yaml:"cookie_secure"` 57 | 58 | // Indicates if CSRF cookie is HTTP only. 59 | // Optional. Default value false. 60 | CookieHTTPOnly bool `yaml:"cookie_http_only"` 61 | 62 | // Indicates SameSite mode of the CSRF cookie. 63 | // Optional. Default value SameSiteDefaultMode. 64 | CookieSameSite http.SameSite `yaml:"cookie_same_site"` 65 | 66 | // ErrorHandler, if set, receives the validation error and its result is returned instead of that error. 67 | ErrorHandler CSRFErrorHandler 68 | } 69 | 70 | // CSRFErrorHandler is a function which is executed for creating custom errors. 71 | type CSRFErrorHandler func(err error, c echo.Context) error 72 | 73 | // ErrCSRFInvalid is returned when the CSRF check fails (403 Forbidden). 74 | var ErrCSRFInvalid = echo.NewHTTPError(http.StatusForbidden, "invalid csrf token") 75 | 76 | // DefaultCSRFConfig is the default CSRF middleware config. 77 | var DefaultCSRFConfig = CSRFConfig{ 78 | Skipper: DefaultSkipper, 79 | TokenLength: 32, 80 | TokenLookup: "header:" + echo.HeaderXCSRFToken, 81 | ContextKey: "csrf", 82 | CookieName: "_csrf", 83 | CookieMaxAge: 86400, 84 | CookieSameSite: http.SameSiteDefaultMode, 85 | } 86 | 87 | // CSRF returns a Cross-Site Request Forgery (CSRF) middleware. 88 | // See: https://en.wikipedia.org/wiki/Cross-site_request_forgery 89 | func CSRF() echo.MiddlewareFunc { 90 | c := DefaultCSRFConfig 91 | return CSRFWithConfig(c) 92 | } 93 | 94 | // CSRFWithConfig returns a CSRF middleware with config. Zero-valued fields fall back to DefaultCSRFConfig. 95 | // See `CSRF()`.
96 | func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { 97 | // Defaults 98 | if config.Skipper == nil { 99 | config.Skipper = DefaultCSRFConfig.Skipper 100 | } 101 | if config.TokenLength == 0 { 102 | config.TokenLength = DefaultCSRFConfig.TokenLength 103 | } 104 | 105 | if config.TokenLookup == "" { 106 | config.TokenLookup = DefaultCSRFConfig.TokenLookup 107 | } 108 | if config.ContextKey == "" { 109 | config.ContextKey = DefaultCSRFConfig.ContextKey 110 | } 111 | if config.CookieName == "" { 112 | config.CookieName = DefaultCSRFConfig.CookieName 113 | } 114 | if config.CookieMaxAge == 0 { 115 | config.CookieMaxAge = DefaultCSRFConfig.CookieMaxAge 116 | } 117 | if config.CookieSameSite == http.SameSiteNoneMode { 118 | config.CookieSecure = true 119 | } 120 | // An unparsable TokenLookup panics here, when the middleware is created, rather than on each request. 121 | extractors, cErr := CreateExtractors(config.TokenLookup) 122 | if cErr != nil { 123 | panic(cErr) 124 | } 125 | 126 | return func(next echo.HandlerFunc) echo.HandlerFunc { 127 | return func(c echo.Context) error { 128 | if config.Skipper(c) { 129 | return next(c) 130 | } 131 | // Reuse the token stored in the CSRF cookie when present; otherwise generate a new one. 132 | token := "" 133 | if k, err := c.Cookie(config.CookieName); err != nil { 134 | token = randomString(config.TokenLength) 135 | } else { 136 | token = k.Value // Reuse token 137 | } 138 | 139 | switch c.Request().Method { 140 | case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace: 141 | default: 142 | // Validate token only for requests which are not defined as 'safe' by RFC7231 143 | var lastExtractorErr error 144 | var lastTokenErr error 145 | outer: 146 | for _, extractor := range extractors { 147 | clientTokens, err := extractor(c) 148 | if err != nil { 149 | lastExtractorErr = err 150 | continue 151 | } 152 | // The first client token matching the cookie token accepts the request and clears earlier errors. 153 | for _, clientToken := range clientTokens { 154 | if validateCSRFToken(token, clientToken) { 155 | lastTokenErr = nil 156 | lastExtractorErr = nil 157 | break outer 158 | } 159 | lastTokenErr = ErrCSRFInvalid 160 | } 161 | } 162 | var finalErr error 163 | if lastTokenErr != nil {
164 | finalErr = lastTokenErr 165 | } else if lastExtractorErr != nil { 166 | // map extractor sentinel errors to the legacy 400 messages to keep errors backwards compatible; someone could rely on them 167 | if lastExtractorErr == errQueryExtractorValueMissing { 168 | lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in the query string") 169 | } else if lastExtractorErr == errFormExtractorValueMissing { 170 | lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in the form parameter") 171 | } else if lastExtractorErr == errHeaderExtractorValueMissing { 172 | lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in request header") 173 | } else { 174 | lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, lastExtractorErr.Error()) 175 | } 176 | finalErr = lastExtractorErr 177 | } 178 | 179 | if finalErr != nil { 180 | if config.ErrorHandler != nil { 181 | return config.ErrorHandler(finalErr, c) 182 | } 183 | return finalErr 184 | } 185 | } 186 | 187 | // Set (or refresh) the CSRF cookie for every request that was not rejected above 188 | cookie := new(http.Cookie) 189 | cookie.Name = config.CookieName 190 | cookie.Value = token 191 | if config.CookiePath != "" { 192 | cookie.Path = config.CookiePath 193 | } 194 | if config.CookieDomain != "" { 195 | cookie.Domain = config.CookieDomain 196 | } 197 | if config.CookieSameSite != http.SameSiteDefaultMode { 198 | cookie.SameSite = config.CookieSameSite 199 | } 200 | cookie.Expires = time.Now().Add(time.Duration(config.CookieMaxAge) * time.Second) 201 | cookie.Secure = config.CookieSecure 202 | cookie.HttpOnly = config.CookieHTTPOnly 203 | c.SetCookie(cookie) 204 | 205 | // Store token in the context 206 | c.Set(config.ContextKey, token) 207 | 208 | // Protect clients from caching the response 209 | c.Response().Header().Add(echo.HeaderVary, echo.HeaderCookie) 210 | 211 | return next(c) 212 | } 213 | } 214 | } 215 | // validateCSRFToken reports whether both tokens are equal, comparing in constant time to avoid leaking timing information. 216 | func validateCSRFToken(token, clientToken string) bool { 217 | return subtle.ConstantTimeCompare([]byte(token),
[]byte(clientToken)) == 1 218 | } 219 | -------------------------------------------------------------------------------- /middleware/decompress.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors 3 | 4 | package middleware 5 | 6 | import ( 7 | "compress/gzip" 8 | "io" 9 | "net/http" 10 | "sync" 11 | 12 | "github.com/labstack/echo/v4" 13 | ) 14 | 15 | // DecompressConfig defines the config for Decompress middleware. 16 | type DecompressConfig struct { 17 | // Skipper defines a function to skip middleware. 18 | Skipper Skipper 19 | 20 | // GzipDecompressPool provides the sync.Pool used to create/store gzip readers. Optional; defaults to DefaultGzipDecompressPool. 21 | GzipDecompressPool Decompressor 22 | } 23 | 24 | // GZIPEncoding is the Content-Encoding value ("gzip") for which the request body is decompressed. 25 | const GZIPEncoding string = "gzip" 26 | 27 | // Decompressor provides the sync.Pool of gzip readers used by the middleware. Its method is unexported, so only types in this package can implement it. 28 | type Decompressor interface { 29 | gzipDecompressPool() sync.Pool 30 | } 31 | 32 | // DefaultDecompressConfig is the default config for the Decompress middleware. 33 | var DefaultDecompressConfig = DecompressConfig{ 34 | Skipper: DefaultSkipper, 35 | GzipDecompressPool: &DefaultGzipDecompressPool{}, 36 | } 37 | 38 | // DefaultGzipDecompressPool is the default implementation of Decompressor interface 39 | type DefaultGzipDecompressPool struct { 40 | } 41 | // gzipDecompressPool returns a pool whose New allocates an empty *gzip.Reader. 42 | func (d *DefaultGzipDecompressPool) gzipDecompressPool() sync.Pool { 43 | return sync.Pool{New: func() interface{} { return new(gzip.Reader) }} 44 | } 45 | 46 | // Decompress returns a middleware that decompresses the request body when its Content-Encoding is "gzip", using the default config. 47 | func Decompress() echo.MiddlewareFunc { 48 | return DecompressWithConfig(DefaultDecompressConfig) 49 | } 50 | 51 | // DecompressWithConfig returns a middleware that decompresses the request body when its Content-Encoding is set
to "gzip" with config 52 | func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc { 53 | // Defaults 54 | if config.Skipper == nil { 55 | config.Skipper = DefaultDecompressConfig.Skipper 56 | } 57 | if config.GzipDecompressPool == nil { 58 | config.GzipDecompressPool = DefaultDecompressConfig.GzipDecompressPool 59 | } 60 | // The reader pool is created once per wrapped handler, not per request. 61 | return func(next echo.HandlerFunc) echo.HandlerFunc { 62 | pool := config.GzipDecompressPool.gzipDecompressPool() 63 | 64 | return func(c echo.Context) error { 65 | if config.Skipper(c) { 66 | return next(c) 67 | } 68 | 69 | if c.Request().Header.Get(echo.HeaderContentEncoding) != GZIPEncoding { 70 | return next(c) 71 | } 72 | // A pool may hand out an error value instead of a reader; it is reported as 500 Internal Server Error. 73 | i := pool.Get() 74 | gr, ok := i.(*gzip.Reader) 75 | if !ok || gr == nil { 76 | return echo.NewHTTPError(http.StatusInternalServerError, i.(error).Error()) 77 | } 78 | defer pool.Put(gr) 79 | 80 | b := c.Request().Body 81 | defer b.Close() 82 | 83 | if err := gr.Reset(b); err != nil { 84 | if err == io.EOF { // ignore if body is empty 85 | return next(c) 86 | } 87 | return err 88 | } 89 | 90 | // only Close gzip reader if it was set to a proper gzip source otherwise it will panic on close.
91 | defer gr.Close() 92 | 93 | c.Request().Body = gr 94 | 95 | return next(c) 96 | } 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /middleware/decompress_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors 3 | 4 | package middleware 5 | 6 | import ( 7 | "bytes" 8 | "compress/gzip" 9 | "errors" 10 | "io" 11 | "net/http" 12 | "net/http/httptest" 13 | "strings" 14 | "sync" 15 | "testing" 16 | 17 | "github.com/labstack/echo/v4" 18 | "github.com/stretchr/testify/assert" 19 | ) 20 | 21 | func TestDecompress(t *testing.T) { 22 | e := echo.New() 23 | req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test")) 24 | rec := httptest.NewRecorder() 25 | c := e.NewContext(req, rec) 26 | 27 | // Skip if no Content-Encoding header 28 | h := Decompress()(func(c echo.Context) error { 29 | c.Response().Write([]byte("test")) // For Content-Type sniffing 30 | return nil 31 | }) 32 | assert.NoError(t, h(c)) 33 | 34 | assert.Equal(t, "test", rec.Body.String()) 35 | 36 | // Decompress 37 | body := `{"name": "echo"}` 38 | gz, _ := gzipString(body) 39 | req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) 40 | req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) 41 | rec = httptest.NewRecorder() 42 | c = e.NewContext(req, rec) 43 | assert.NoError(t, h(c)) 44 | assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) 45 | b, err := io.ReadAll(req.Body) 46 | assert.NoError(t, err) 47 | assert.Equal(t, body, string(b)) 48 | } 49 | 50 | func TestDecompressDefaultConfig(t *testing.T) { 51 | e := echo.New() 52 | req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test")) 53 | rec := httptest.NewRecorder() 54 | c := e.NewContext(req, rec) 55 | 56 | h := DecompressWithConfig(DecompressConfig{})(func(c echo.Context) error { 57 |
c.Response().Write([]byte("test")) // For Content-Type sniffing 58 | return nil 59 | }) 60 | h(c) 61 | 62 | assert.Equal(t, "test", rec.Body.String()) 63 | 64 | // Decompress 65 | body := `{"name": "echo"}` 66 | gz, _ := gzipString(body) 67 | req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) 68 | req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) 69 | rec = httptest.NewRecorder() 70 | c = e.NewContext(req, rec) 71 | h(c) 72 | assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) 73 | b, err := io.ReadAll(req.Body) 74 | assert.NoError(t, err) 75 | assert.Equal(t, body, string(b)) 76 | } 77 | 78 | func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) { 79 | e := echo.New() 80 | body := `{"name":"echo"}` 81 | gz, _ := gzipString(body) 82 | req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) 83 | req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) 84 | rec := httptest.NewRecorder() 85 | e.NewContext(req, rec) 86 | e.ServeHTTP(rec, req) 87 | assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) 88 | b, err := io.ReadAll(req.Body) 89 | assert.NoError(t, err) 90 | assert.NotEqual(t, b, body) 91 | assert.Equal(t, b, gz) 92 | } 93 | 94 | func TestDecompressNoContent(t *testing.T) { 95 | e := echo.New() 96 | req := httptest.NewRequest(http.MethodGet, "/", nil) 97 | req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) 98 | rec := httptest.NewRecorder() 99 | c := e.NewContext(req, rec) 100 | h := Decompress()(func(c echo.Context) error { 101 | return c.NoContent(http.StatusNoContent) 102 | }) 103 | if assert.NoError(t, h(c)) { 104 | assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) 105 | assert.Empty(t, rec.Header().Get(echo.HeaderContentType)) 106 | assert.Equal(t, 0, len(rec.Body.Bytes())) 107 | } 108 | } 109 | 110 | func TestDecompressErrorReturned(t *testing.T) { 111 | e := echo.New() 112 | e.Use(Decompress()) 113 | e.GET("/", 
func(c echo.Context) error { 114 | return echo.ErrNotFound 115 | }) 116 | req := httptest.NewRequest(http.MethodGet, "/", nil) 117 | req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) 118 | rec := httptest.NewRecorder() 119 | e.ServeHTTP(rec, req) 120 | assert.Equal(t, http.StatusNotFound, rec.Code) 121 | assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) 122 | } 123 | 124 | func TestDecompressSkipper(t *testing.T) { 125 | e := echo.New() 126 | e.Use(DecompressWithConfig(DecompressConfig{ 127 | Skipper: func(c echo.Context) bool { 128 | return c.Request().URL.Path == "/skip" 129 | }, 130 | })) 131 | body := `{"name": "echo"}` 132 | req := httptest.NewRequest(http.MethodPost, "/skip", strings.NewReader(body)) 133 | req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) 134 | rec := httptest.NewRecorder() 135 | c := e.NewContext(req, rec) 136 | e.ServeHTTP(rec, req) 137 | assert.Equal(t, rec.Header().Get(echo.HeaderContentType), echo.MIMEApplicationJSON) 138 | reqBody, err := io.ReadAll(c.Request().Body) 139 | assert.NoError(t, err) 140 | assert.Equal(t, body, string(reqBody)) 141 | } 142 | 143 | type TestDecompressPoolWithError struct { 144 | } 145 | 146 | func (d *TestDecompressPoolWithError) gzipDecompressPool() sync.Pool { 147 | return sync.Pool{ 148 | New: func() interface{} { 149 | return errors.New("pool error") 150 | }, 151 | } 152 | } 153 | 154 | func TestDecompressPoolError(t *testing.T) { 155 | e := echo.New() 156 | e.Use(DecompressWithConfig(DecompressConfig{ 157 | Skipper: DefaultSkipper, 158 | GzipDecompressPool: &TestDecompressPoolWithError{}, 159 | })) 160 | body := `{"name": "echo"}` 161 | req := httptest.NewRequest(http.MethodPost, "/echo", strings.NewReader(body)) 162 | req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) 163 | rec := httptest.NewRecorder() 164 | c := e.NewContext(req, rec) 165 | e.ServeHTTP(rec, req) 166 | assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) 167 | reqBody, err := 
io.ReadAll(c.Request().Body) 168 | assert.NoError(t, err) 169 | assert.Equal(t, body, string(reqBody)) 170 | assert.Equal(t, rec.Code, http.StatusInternalServerError) 171 | } 172 | 173 | func BenchmarkDecompress(b *testing.B) { 174 | e := echo.New() 175 | body := `{"name": "echo"}` 176 | gz, _ := gzipString(body) 177 | req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) 178 | req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) 179 | 180 | h := Decompress()(func(c echo.Context) error { 181 | c.Response().Write([]byte(body)) // For Content-Type sniffing 182 | return nil 183 | }) 184 | 185 | b.ReportAllocs() 186 | b.ResetTimer() 187 | 188 | for i := 0; i < b.N; i++ { 189 | // Decompress 190 | rec := httptest.NewRecorder() 191 | c := e.NewContext(req, rec) 192 | h(c) 193 | } 194 | } 195 | 196 | func gzipString(body string) ([]byte, error) { 197 | var buf bytes.Buffer 198 | gz := gzip.NewWriter(&buf) 199 | 200 | _, err := gz.Write([]byte(body)) 201 | if err != nil { 202 | return nil, err 203 | } 204 | 205 | if err := gz.Close(); err != nil { 206 | return nil, err 207 | } 208 | 209 | return buf.Bytes(), nil 210 | } 211 | -------------------------------------------------------------------------------- /middleware/extractor.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors 3 | 4 | package middleware 5 | 6 | import ( 7 | "errors" 8 | "fmt" 9 | "github.com/labstack/echo/v4" 10 | "net/textproto" 11 | "strings" 12 | ) 13 | 14 | const ( 15 | // extractorLimit is arbitrary number to limit values extractor can return. 
this limits possible resource exhaustion 16 | // attack vector 17 | extractorLimit = 20 18 | ) 19 | 20 | var errHeaderExtractorValueMissing = errors.New("missing value in request header") 21 | var errHeaderExtractorValueInvalid = errors.New("invalid value in request header") 22 | var errQueryExtractorValueMissing = errors.New("missing value in the query string") 23 | var errParamExtractorValueMissing = errors.New("missing value in path params") 24 | var errCookieExtractorValueMissing = errors.New("missing value in cookies") 25 | var errFormExtractorValueMissing = errors.New("missing value in the form") 26 | 27 | // ValuesExtractor defines a function for extracting values (keys/tokens) from the given context. 28 | type ValuesExtractor func(c echo.Context) ([]string, error) 29 | 30 | // CreateExtractors creates ValuesExtractors from given lookups. 31 | // Lookups is a string in the form of ":" or ":,:" that is used 32 | // to extract key from the request. 33 | // Possible values: 34 | // - "header:" or "header::" 35 | // `` is argument value to cut/trim prefix of the extracted value. This is useful if header 36 | // value has static prefix like `Authorization: ` where part that we 37 | // want to cut is ` ` note the space at the end. 38 | // In case of basic authentication `Authorization: Basic ` prefix we want to remove is `Basic `. 
39 | // - "query:" 40 | // - "param:" 41 | // - "form:" 42 | // - "cookie:" 43 | // 44 | // Multiple sources example: 45 | // - "header:Authorization,header:X-Api-Key" 46 | func CreateExtractors(lookups string) ([]ValuesExtractor, error) { 47 | return createExtractors(lookups, "") 48 | } 49 | 50 | func createExtractors(lookups string, authScheme string) ([]ValuesExtractor, error) { 51 | if lookups == "" { 52 | return nil, nil 53 | } 54 | sources := strings.Split(lookups, ",") 55 | var extractors = make([]ValuesExtractor, 0) 56 | for _, source := range sources { 57 | parts := strings.Split(source, ":") 58 | if len(parts) < 2 { 59 | return nil, fmt.Errorf("extractor source for lookup could not be split into needed parts: %v", source) 60 | } 61 | 62 | switch parts[0] { 63 | case "query": 64 | extractors = append(extractors, valuesFromQuery(parts[1])) 65 | case "param": 66 | extractors = append(extractors, valuesFromParam(parts[1])) 67 | case "cookie": 68 | extractors = append(extractors, valuesFromCookie(parts[1])) 69 | case "form": 70 | extractors = append(extractors, valuesFromForm(parts[1])) 71 | case "header": 72 | prefix := "" 73 | if len(parts) > 2 { 74 | prefix = parts[2] 75 | } else if authScheme != "" && parts[1] == echo.HeaderAuthorization { 76 | // backwards compatibility for JWT and KeyAuth: 77 | // * we only apply this fix to Authorization as header we use and uses prefixes like "Bearer " etc 78 | // * previously header extractor assumed that auth-scheme/prefix had a space as suffix we need to retain that 79 | // behaviour for default values and Authorization header. 80 | prefix = authScheme 81 | if !strings.HasSuffix(prefix, " ") { 82 | prefix += " " 83 | } 84 | } 85 | extractors = append(extractors, valuesFromHeader(parts[1], prefix)) 86 | } 87 | } 88 | return extractors, nil 89 | } 90 | 91 | // valuesFromHeader returns a functions that extracts values from the request header. 
92 | // valuePrefix is parameter to remove first part (prefix) of the extracted value. This is useful if header value has static 93 | // prefix like `Authorization: ` where part that we want to remove is ` ` 94 | // note the space at the end. In case of basic authentication `Authorization: Basic ` prefix we want to remove 95 | // is `Basic `. In case of JWT tokens `Authorization: Bearer ` prefix is `Bearer `. 96 | // If prefix is left empty the whole value is returned. 97 | func valuesFromHeader(header string, valuePrefix string) ValuesExtractor { 98 | prefixLen := len(valuePrefix) 99 | // standard library parses http.Request header keys in canonical form but we may provide something else so fix this 100 | header = textproto.CanonicalMIMEHeaderKey(header) 101 | return func(c echo.Context) ([]string, error) { 102 | values := c.Request().Header.Values(header) 103 | if len(values) == 0 { 104 | return nil, errHeaderExtractorValueMissing 105 | } 106 | 107 | result := make([]string, 0) 108 | for i, value := range values { 109 | if prefixLen == 0 { 110 | result = append(result, value) 111 | if i >= extractorLimit-1 { 112 | break 113 | } 114 | continue 115 | } 116 | if len(value) > prefixLen && strings.EqualFold(value[:prefixLen], valuePrefix) { 117 | result = append(result, value[prefixLen:]) 118 | if i >= extractorLimit-1 { 119 | break 120 | } 121 | } 122 | } 123 | 124 | if len(result) == 0 { 125 | if prefixLen > 0 { 126 | return nil, errHeaderExtractorValueInvalid 127 | } 128 | return nil, errHeaderExtractorValueMissing 129 | } 130 | return result, nil 131 | } 132 | } 133 | 134 | // valuesFromQuery returns a function that extracts values from the query string. 
135 | func valuesFromQuery(param string) ValuesExtractor { 136 | return func(c echo.Context) ([]string, error) { 137 | result := c.QueryParams()[param] 138 | if len(result) == 0 { 139 | return nil, errQueryExtractorValueMissing 140 | } else if len(result) > extractorLimit-1 { 141 | result = result[:extractorLimit] 142 | } 143 | return result, nil 144 | } 145 | } 146 | 147 | // valuesFromParam returns a function that extracts values from the url param string. 148 | func valuesFromParam(param string) ValuesExtractor { 149 | return func(c echo.Context) ([]string, error) { 150 | result := make([]string, 0) 151 | paramVales := c.ParamValues() 152 | for i, p := range c.ParamNames() { 153 | if param == p { 154 | result = append(result, paramVales[i]) 155 | if i >= extractorLimit-1 { 156 | break 157 | } 158 | } 159 | } 160 | if len(result) == 0 { 161 | return nil, errParamExtractorValueMissing 162 | } 163 | return result, nil 164 | } 165 | } 166 | 167 | // valuesFromCookie returns a function that extracts values from the named cookie. 168 | func valuesFromCookie(name string) ValuesExtractor { 169 | return func(c echo.Context) ([]string, error) { 170 | cookies := c.Cookies() 171 | if len(cookies) == 0 { 172 | return nil, errCookieExtractorValueMissing 173 | } 174 | 175 | result := make([]string, 0) 176 | for i, cookie := range cookies { 177 | if name == cookie.Name { 178 | result = append(result, cookie.Value) 179 | if i >= extractorLimit-1 { 180 | break 181 | } 182 | } 183 | } 184 | if len(result) == 0 { 185 | return nil, errCookieExtractorValueMissing 186 | } 187 | return result, nil 188 | } 189 | } 190 | 191 | // valuesFromForm returns a function that extracts values from the form field. 
192 | func valuesFromForm(name string) ValuesExtractor { 193 | return func(c echo.Context) ([]string, error) { 194 | if c.Request().Form == nil { 195 | _ = c.Request().ParseMultipartForm(32 << 20) // same what `c.Request().FormValue(name)` does 196 | } 197 | values := c.Request().Form[name] 198 | if len(values) == 0 { 199 | return nil, errFormExtractorValueMissing 200 | } 201 | if len(values) > extractorLimit-1 { 202 | values = values[:extractorLimit] 203 | } 204 | result := append([]string{}, values...) 205 | return result, nil 206 | } 207 | } 208 | -------------------------------------------------------------------------------- /middleware/key_auth.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors 3 | 4 | package middleware 5 | 6 | import ( 7 | "errors" 8 | "github.com/labstack/echo/v4" 9 | "net/http" 10 | ) 11 | 12 | // KeyAuthConfig defines the config for KeyAuth middleware. 13 | type KeyAuthConfig struct { 14 | // Skipper defines a function to skip middleware. 15 | Skipper Skipper 16 | 17 | // KeyLookup is a string in the form of ":" or ":,:" that is used 18 | // to extract key from the request. 19 | // Optional. Default value "header:Authorization". 20 | // Possible values: 21 | // - "header:" or "header::" 22 | // `` is argument value to cut/trim prefix of the extracted value. This is useful if header 23 | // value has static prefix like `Authorization: ` where part that we 24 | // want to cut is ` ` note the space at the end. 25 | // In case of basic authentication `Authorization: Basic ` prefix we want to remove is `Basic `. 26 | // - "query:" 27 | // - "form:" 28 | // - "cookie:" 29 | // Multiple sources example: 30 | // - "header:Authorization,header:X-Api-Key" 31 | KeyLookup string 32 | 33 | // AuthScheme to be used in the Authorization header. 34 | // Optional. Default value "Bearer". 
35 | AuthScheme string 36 | 37 | // Validator is a function to validate key. 38 | // Required. 39 | Validator KeyAuthValidator 40 | 41 | // ErrorHandler defines a function which is executed for an invalid key. 42 | // It may be used to define a custom error. 43 | ErrorHandler KeyAuthErrorHandler 44 | 45 | // ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandler decides to 46 | // ignore the error (by returning `nil`). 47 | // This is useful when parts of your site/api allow public access and some authorized routes provide extra functionality. 48 | // In that case you can use ErrorHandler to set a default public key auth value in the request context 49 | // and continue. Some logic down the remaining execution chain needs to check that (public) key auth value then. 50 | ContinueOnIgnoredError bool 51 | } 52 | 53 | // KeyAuthValidator defines a function to validate KeyAuth credentials. 54 | type KeyAuthValidator func(auth string, c echo.Context) (bool, error) 55 | 56 | // KeyAuthErrorHandler defines a function which is executed for an invalid key. 57 | type KeyAuthErrorHandler func(err error, c echo.Context) error 58 | 59 | // ErrKeyAuthMissing is error type when KeyAuth middleware is unable to extract value from lookups 60 | type ErrKeyAuthMissing struct { 61 | Err error 62 | } 63 | 64 | // DefaultKeyAuthConfig is the default KeyAuth middleware config. 65 | var DefaultKeyAuthConfig = KeyAuthConfig{ 66 | Skipper: DefaultSkipper, 67 | KeyLookup: "header:" + echo.HeaderAuthorization, 68 | AuthScheme: "Bearer", 69 | } 70 | 71 | // Error returns errors text 72 | func (e *ErrKeyAuthMissing) Error() string { 73 | return e.Err.Error() 74 | } 75 | 76 | // Unwrap unwraps error 77 | func (e *ErrKeyAuthMissing) Unwrap() error { 78 | return e.Err 79 | } 80 | 81 | // KeyAuth returns an KeyAuth middleware. 82 | // 83 | // For valid key it calls the next handler. 84 | // For invalid key, it sends "401 - Unauthorized" response. 
85 | // For missing key, it sends "400 - Bad Request" response. 86 | func KeyAuth(fn KeyAuthValidator) echo.MiddlewareFunc { 87 | c := DefaultKeyAuthConfig 88 | c.Validator = fn 89 | return KeyAuthWithConfig(c) 90 | } 91 | 92 | // KeyAuthWithConfig returns an KeyAuth middleware with config. 93 | // See `KeyAuth()`. 94 | func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc { 95 | // Defaults 96 | if config.Skipper == nil { 97 | config.Skipper = DefaultKeyAuthConfig.Skipper 98 | } 99 | // Defaults 100 | if config.AuthScheme == "" { 101 | config.AuthScheme = DefaultKeyAuthConfig.AuthScheme 102 | } 103 | if config.KeyLookup == "" { 104 | config.KeyLookup = DefaultKeyAuthConfig.KeyLookup 105 | } 106 | if config.Validator == nil { 107 | panic("echo: key-auth middleware requires a validator function") 108 | } 109 | 110 | extractors, cErr := createExtractors(config.KeyLookup, config.AuthScheme) 111 | if cErr != nil { 112 | panic(cErr) 113 | } 114 | 115 | return func(next echo.HandlerFunc) echo.HandlerFunc { 116 | return func(c echo.Context) error { 117 | if config.Skipper(c) { 118 | return next(c) 119 | } 120 | 121 | var lastExtractorErr error 122 | var lastValidatorErr error 123 | for _, extractor := range extractors { 124 | keys, err := extractor(c) 125 | if err != nil { 126 | lastExtractorErr = err 127 | continue 128 | } 129 | for _, key := range keys { 130 | valid, err := config.Validator(key, c) 131 | if err != nil { 132 | lastValidatorErr = err 133 | continue 134 | } 135 | if valid { 136 | return next(c) 137 | } 138 | lastValidatorErr = errors.New("invalid key") 139 | } 140 | } 141 | 142 | // we are here only when we did not successfully extract and validate any of keys 143 | err := lastValidatorErr 144 | if err == nil { // prioritize validator errors over extracting errors 145 | // ugly part to preserve backwards compatible errors. 
someone could rely on them 146 | if lastExtractorErr == errQueryExtractorValueMissing { 147 | err = errors.New("missing key in the query string") 148 | } else if lastExtractorErr == errCookieExtractorValueMissing { 149 | err = errors.New("missing key in cookies") 150 | } else if lastExtractorErr == errFormExtractorValueMissing { 151 | err = errors.New("missing key in the form") 152 | } else if lastExtractorErr == errHeaderExtractorValueMissing { 153 | err = errors.New("missing key in request header") 154 | } else if lastExtractorErr == errHeaderExtractorValueInvalid { 155 | err = errors.New("invalid key in the request header") 156 | } else { 157 | err = lastExtractorErr 158 | } 159 | err = &ErrKeyAuthMissing{Err: err} 160 | } 161 | 162 | if config.ErrorHandler != nil { 163 | tmpErr := config.ErrorHandler(err, c) 164 | if config.ContinueOnIgnoredError && tmpErr == nil { 165 | return next(c) 166 | } 167 | return tmpErr 168 | } 169 | if lastValidatorErr != nil { // prioritize validator errors over extracting errors 170 | return &echo.HTTPError{ 171 | Code: http.StatusUnauthorized, 172 | Message: "Unauthorized", 173 | Internal: lastValidatorErr, 174 | } 175 | } 176 | return echo.NewHTTPError(http.StatusBadRequest, err.Error()) 177 | } 178 | } 179 | } 180 | -------------------------------------------------------------------------------- /middleware/logger.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors 3 | 4 | package middleware 5 | 6 | import ( 7 | "bytes" 8 | "encoding/json" 9 | "io" 10 | "strconv" 11 | "strings" 12 | "sync" 13 | "time" 14 | 15 | "github.com/labstack/echo/v4" 16 | "github.com/labstack/gommon/color" 17 | "github.com/valyala/fasttemplate" 18 | ) 19 | 20 | // LoggerConfig defines the config for Logger middleware. 21 | type LoggerConfig struct { 22 | // Skipper defines a function to skip middleware. 
23 | Skipper Skipper 24 | 25 | // Tags to construct the logger format. 26 | // 27 | // - time_unix 28 | // - time_unix_milli 29 | // - time_unix_micro 30 | // - time_unix_nano 31 | // - time_rfc3339 32 | // - time_rfc3339_nano 33 | // - time_custom 34 | // - id (Request ID) 35 | // - remote_ip 36 | // - uri 37 | // - host 38 | // - method 39 | // - path 40 | // - route 41 | // - protocol 42 | // - referer 43 | // - user_agent 44 | // - status 45 | // - error 46 | // - latency (In nanoseconds) 47 | // - latency_human (Human readable) 48 | // - bytes_in (Bytes received) 49 | // - bytes_out (Bytes sent) 50 | // - header: 51 | // - query: 52 | // - form: 53 | // - custom (see CustomTagFunc field) 54 | // 55 | // Example "${remote_ip} ${status}" 56 | // 57 | // Optional. Default value DefaultLoggerConfig.Format. 58 | Format string `yaml:"format"` 59 | 60 | // Optional. Default value DefaultLoggerConfig.CustomTimeFormat. 61 | CustomTimeFormat string `yaml:"custom_time_format"` 62 | 63 | // CustomTagFunc is function called for `${custom}` tag to output user implemented text by writing it to buf. 64 | // Make sure that outputted text creates valid JSON string with other logged tags. 65 | // Optional. 66 | CustomTagFunc func(c echo.Context, buf *bytes.Buffer) (int, error) 67 | 68 | // Output is a writer where logs in JSON format are written. 69 | // Optional. Default value os.Stdout. 70 | Output io.Writer 71 | 72 | template *fasttemplate.Template 73 | colorer *color.Color 74 | pool *sync.Pool 75 | } 76 | 77 | // DefaultLoggerConfig is the default Logger middleware config. 
78 | var DefaultLoggerConfig = LoggerConfig{ 79 | Skipper: DefaultSkipper, 80 | Format: `{"time":"${time_rfc3339_nano}","id":"${id}","remote_ip":"${remote_ip}",` + 81 | `"host":"${host}","method":"${method}","uri":"${uri}","user_agent":"${user_agent}",` + 82 | `"status":${status},"error":"${error}","latency":${latency},"latency_human":"${latency_human}"` + 83 | `,"bytes_in":${bytes_in},"bytes_out":${bytes_out}}` + "\n", 84 | CustomTimeFormat: "2006-01-02 15:04:05.00000", 85 | colorer: color.New(), 86 | } 87 | 88 | // Logger returns a middleware that logs HTTP requests. 89 | func Logger() echo.MiddlewareFunc { 90 | return LoggerWithConfig(DefaultLoggerConfig) 91 | } 92 | 93 | // LoggerWithConfig returns a Logger middleware with config. 94 | // See: `Logger()`. 95 | func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc { 96 | // Defaults 97 | if config.Skipper == nil { 98 | config.Skipper = DefaultLoggerConfig.Skipper 99 | } 100 | if config.Format == "" { 101 | config.Format = DefaultLoggerConfig.Format 102 | } 103 | if config.Output == nil { 104 | config.Output = DefaultLoggerConfig.Output 105 | } 106 | 107 | config.template = fasttemplate.New(config.Format, "${", "}") 108 | config.colorer = color.New() 109 | config.colorer.SetOutput(config.Output) 110 | config.pool = &sync.Pool{ 111 | New: func() interface{} { 112 | return bytes.NewBuffer(make([]byte, 256)) 113 | }, 114 | } 115 | 116 | return func(next echo.HandlerFunc) echo.HandlerFunc { 117 | return func(c echo.Context) (err error) { 118 | if config.Skipper(c) { 119 | return next(c) 120 | } 121 | 122 | req := c.Request() 123 | res := c.Response() 124 | start := time.Now() 125 | if err = next(c); err != nil { 126 | c.Error(err) 127 | } 128 | stop := time.Now() 129 | buf := config.pool.Get().(*bytes.Buffer) 130 | buf.Reset() 131 | defer config.pool.Put(buf) 132 | 133 | if _, err = config.template.ExecuteFunc(buf, func(w io.Writer, tag string) (int, error) { 134 | switch tag { 135 | case "custom": 136 | if 
config.CustomTagFunc == nil { 137 | return 0, nil 138 | } 139 | return config.CustomTagFunc(c, buf) 140 | case "time_unix": 141 | return buf.WriteString(strconv.FormatInt(time.Now().Unix(), 10)) 142 | case "time_unix_milli": 143 | // go 1.17 or later, it supports time#UnixMilli() 144 | return buf.WriteString(strconv.FormatInt(time.Now().UnixNano()/1000000, 10)) 145 | case "time_unix_micro": 146 | // go 1.17 or later, it supports time#UnixMicro() 147 | return buf.WriteString(strconv.FormatInt(time.Now().UnixNano()/1000, 10)) 148 | case "time_unix_nano": 149 | return buf.WriteString(strconv.FormatInt(time.Now().UnixNano(), 10)) 150 | case "time_rfc3339": 151 | return buf.WriteString(time.Now().Format(time.RFC3339)) 152 | case "time_rfc3339_nano": 153 | return buf.WriteString(time.Now().Format(time.RFC3339Nano)) 154 | case "time_custom": 155 | return buf.WriteString(time.Now().Format(config.CustomTimeFormat)) 156 | case "id": 157 | id := req.Header.Get(echo.HeaderXRequestID) 158 | if id == "" { 159 | id = res.Header().Get(echo.HeaderXRequestID) 160 | } 161 | return buf.WriteString(id) 162 | case "remote_ip": 163 | return buf.WriteString(c.RealIP()) 164 | case "host": 165 | return buf.WriteString(req.Host) 166 | case "uri": 167 | return buf.WriteString(req.RequestURI) 168 | case "method": 169 | return buf.WriteString(req.Method) 170 | case "path": 171 | p := req.URL.Path 172 | if p == "" { 173 | p = "/" 174 | } 175 | return buf.WriteString(p) 176 | case "route": 177 | return buf.WriteString(c.Path()) 178 | case "protocol": 179 | return buf.WriteString(req.Proto) 180 | case "referer": 181 | return buf.WriteString(req.Referer()) 182 | case "user_agent": 183 | return buf.WriteString(req.UserAgent()) 184 | case "status": 185 | n := res.Status 186 | s := config.colorer.Green(n) 187 | switch { 188 | case n >= 500: 189 | s = config.colorer.Red(n) 190 | case n >= 400: 191 | s = config.colorer.Yellow(n) 192 | case n >= 300: 193 | s = config.colorer.Cyan(n) 194 | } 195 | return 
buf.WriteString(s) 196 | case "error": 197 | if err != nil { 198 | // Error may contain invalid JSON e.g. `"` 199 | b, _ := json.Marshal(err.Error()) 200 | b = b[1 : len(b)-1] 201 | return buf.Write(b) 202 | } 203 | case "latency": 204 | l := stop.Sub(start) 205 | return buf.WriteString(strconv.FormatInt(int64(l), 10)) 206 | case "latency_human": 207 | return buf.WriteString(stop.Sub(start).String()) 208 | case "bytes_in": 209 | cl := req.Header.Get(echo.HeaderContentLength) 210 | if cl == "" { 211 | cl = "0" 212 | } 213 | return buf.WriteString(cl) 214 | case "bytes_out": 215 | return buf.WriteString(strconv.FormatInt(res.Size, 10)) 216 | default: 217 | switch { 218 | case strings.HasPrefix(tag, "header:"): 219 | return buf.Write([]byte(c.Request().Header.Get(tag[7:]))) 220 | case strings.HasPrefix(tag, "query:"): 221 | return buf.Write([]byte(c.QueryParam(tag[6:]))) 222 | case strings.HasPrefix(tag, "form:"): 223 | return buf.Write([]byte(c.FormValue(tag[5:]))) 224 | case strings.HasPrefix(tag, "cookie:"): 225 | cookie, err := c.Cookie(tag[7:]) 226 | if err == nil { 227 | return buf.Write([]byte(cookie.Value)) 228 | } 229 | } 230 | } 231 | return 0, nil 232 | }); err != nil { 233 | return 234 | } 235 | 236 | if config.Output == nil { 237 | _, err = c.Logger().Output().Write(buf.Bytes()) 238 | return 239 | } 240 | _, err = config.Output.Write(buf.Bytes()) 241 | return 242 | } 243 | } 244 | } 245 | -------------------------------------------------------------------------------- /middleware/method_override.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors 3 | 4 | package middleware 5 | 6 | import ( 7 | "net/http" 8 | 9 | "github.com/labstack/echo/v4" 10 | ) 11 | 12 | // MethodOverrideConfig defines the config for MethodOverride middleware. 
13 | type MethodOverrideConfig struct { 14 | // Skipper defines a function to skip middleware. 15 | Skipper Skipper 16 | 17 | // Getter is a function that gets overridden method from the request. 18 | // Optional. Default values MethodFromHeader(echo.HeaderXHTTPMethodOverride). 19 | Getter MethodOverrideGetter 20 | } 21 | 22 | // MethodOverrideGetter is a function that gets overridden method from the request 23 | type MethodOverrideGetter func(echo.Context) string 24 | 25 | // DefaultMethodOverrideConfig is the default MethodOverride middleware config. 26 | var DefaultMethodOverrideConfig = MethodOverrideConfig{ 27 | Skipper: DefaultSkipper, 28 | Getter: MethodFromHeader(echo.HeaderXHTTPMethodOverride), 29 | } 30 | 31 | // MethodOverride returns a MethodOverride middleware. 32 | // MethodOverride middleware checks for the overridden method from the request and 33 | // uses it instead of the original method. 34 | // 35 | // For security reasons, only `POST` method can be overridden. 36 | func MethodOverride() echo.MiddlewareFunc { 37 | return MethodOverrideWithConfig(DefaultMethodOverrideConfig) 38 | } 39 | 40 | // MethodOverrideWithConfig returns a MethodOverride middleware with config. 41 | // See: `MethodOverride()`. 
42 | func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc { 43 | // Defaults 44 | if config.Skipper == nil { 45 | config.Skipper = DefaultMethodOverrideConfig.Skipper 46 | } 47 | if config.Getter == nil { 48 | config.Getter = DefaultMethodOverrideConfig.Getter 49 | } 50 | 51 | return func(next echo.HandlerFunc) echo.HandlerFunc { 52 | return func(c echo.Context) error { 53 | if config.Skipper(c) { 54 | return next(c) 55 | } 56 | 57 | req := c.Request() 58 | if req.Method == http.MethodPost { 59 | m := config.Getter(c) 60 | if m != "" { 61 | req.Method = m 62 | } 63 | } 64 | return next(c) 65 | } 66 | } 67 | } 68 | 69 | // MethodFromHeader is a `MethodOverrideGetter` that gets overridden method from 70 | // the request header. 71 | func MethodFromHeader(header string) MethodOverrideGetter { 72 | return func(c echo.Context) string { 73 | return c.Request().Header.Get(header) 74 | } 75 | } 76 | 77 | // MethodFromForm is a `MethodOverrideGetter` that gets overridden method from the 78 | // form parameter. 79 | func MethodFromForm(param string) MethodOverrideGetter { 80 | return func(c echo.Context) string { 81 | return c.FormValue(param) 82 | } 83 | } 84 | 85 | // MethodFromQuery is a `MethodOverrideGetter` that gets overridden method from 86 | // the query parameter. 
87 | func MethodFromQuery(param string) MethodOverrideGetter { 88 | return func(c echo.Context) string { 89 | return c.QueryParam(param) 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /middleware/method_override_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors 3 | 4 | package middleware 5 | 6 | import ( 7 | "bytes" 8 | "net/http" 9 | "net/http/httptest" 10 | "testing" 11 | 12 | "github.com/labstack/echo/v4" 13 | "github.com/stretchr/testify/assert" 14 | ) 15 | 16 | func TestMethodOverride(t *testing.T) { 17 | e := echo.New() 18 | m := MethodOverride() 19 | h := func(c echo.Context) error { 20 | return c.String(http.StatusOK, "test") 21 | } 22 | 23 | // Override with http header 24 | req := httptest.NewRequest(http.MethodPost, "/", nil) 25 | rec := httptest.NewRecorder() 26 | req.Header.Set(echo.HeaderXHTTPMethodOverride, http.MethodDelete) 27 | c := e.NewContext(req, rec) 28 | m(h)(c) 29 | assert.Equal(t, http.MethodDelete, req.Method) 30 | 31 | // Override with form parameter 32 | m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromForm("_method")}) 33 | req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader([]byte("_method="+http.MethodDelete))) 34 | rec = httptest.NewRecorder() 35 | req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm) 36 | c = e.NewContext(req, rec) 37 | m(h)(c) 38 | assert.Equal(t, http.MethodDelete, req.Method) 39 | 40 | // Override with query parameter 41 | m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromQuery("_method")}) 42 | req = httptest.NewRequest(http.MethodPost, "/?_method="+http.MethodDelete, nil) 43 | rec = httptest.NewRecorder() 44 | c = e.NewContext(req, rec) 45 | m(h)(c) 46 | assert.Equal(t, http.MethodDelete, req.Method) 47 | 48 | // Ignore `GET` 49 | req = 
httptest.NewRequest(http.MethodGet, "/", nil) 50 | req.Header.Set(echo.HeaderXHTTPMethodOverride, http.MethodDelete) 51 | assert.Equal(t, http.MethodGet, req.Method) 52 | } 53 | -------------------------------------------------------------------------------- /middleware/middleware.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors 3 | 4 | package middleware 5 | 6 | import ( 7 | "net/http" 8 | "regexp" 9 | "strconv" 10 | "strings" 11 | 12 | "github.com/labstack/echo/v4" 13 | ) 14 | 15 | // Skipper defines a function to skip middleware. Returning true skips processing 16 | // the middleware. 17 | type Skipper func(c echo.Context) bool 18 | 19 | // BeforeFunc defines a function which is executed just before the middleware. 20 | type BeforeFunc func(c echo.Context) 21 | 22 | func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer { 23 | groups := pattern.FindAllStringSubmatch(input, -1) 24 | if groups == nil { 25 | return nil 26 | } 27 | values := groups[0][1:] 28 | replace := make([]string, 2*len(values)) 29 | for i, v := range values { 30 | j := 2 * i 31 | replace[j] = "$" + strconv.Itoa(i+1) 32 | replace[j+1] = v 33 | } 34 | return strings.NewReplacer(replace...) 
35 | } 36 | 37 | func rewriteRulesRegex(rewrite map[string]string) map[*regexp.Regexp]string { 38 | // Initialize 39 | rulesRegex := map[*regexp.Regexp]string{} 40 | for k, v := range rewrite { 41 | k = regexp.QuoteMeta(k) 42 | k = strings.ReplaceAll(k, `\*`, "(.*?)") 43 | if strings.HasPrefix(k, `\^`) { 44 | k = strings.ReplaceAll(k, `\^`, "^") 45 | } 46 | k = k + "$" 47 | rulesRegex[regexp.MustCompile(k)] = v 48 | } 49 | return rulesRegex 50 | } 51 | 52 | func rewriteURL(rewriteRegex map[*regexp.Regexp]string, req *http.Request) error { 53 | if len(rewriteRegex) == 0 { 54 | return nil 55 | } 56 | 57 | // Depending on how HTTP request is sent RequestURI could contain Scheme://Host/path or be just /path. 58 | // We only want to use path part for rewriting and therefore trim prefix if it exists 59 | rawURI := req.RequestURI 60 | if rawURI != "" && rawURI[0] != '/' { 61 | prefix := "" 62 | if req.URL.Scheme != "" { 63 | prefix = req.URL.Scheme + "://" 64 | } 65 | if req.URL.Host != "" { 66 | prefix += req.URL.Host // host or host:port 67 | } 68 | if prefix != "" { 69 | rawURI = strings.TrimPrefix(rawURI, prefix) 70 | } 71 | } 72 | 73 | for k, v := range rewriteRegex { 74 | if replacer := captureTokens(k, rawURI); replacer != nil { 75 | url, err := req.URL.Parse(replacer.Replace(v)) 76 | if err != nil { 77 | return err 78 | } 79 | req.URL = url 80 | 81 | return nil // rewrite only once 82 | } 83 | } 84 | return nil 85 | } 86 | 87 | // DefaultSkipper returns false which processes the middleware. 
88 | func DefaultSkipper(echo.Context) bool { 89 | return false 90 | } 91 | -------------------------------------------------------------------------------- /middleware/middleware_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors 3 | 4 | package middleware 5 | 6 | import ( 7 | "bufio" 8 | "errors" 9 | "github.com/stretchr/testify/assert" 10 | "net" 11 | "net/http" 12 | "net/http/httptest" 13 | "regexp" 14 | "testing" 15 | ) 16 | 17 | func TestRewriteURL(t *testing.T) { 18 | var testCases = []struct { 19 | whenURL string 20 | expectPath string 21 | expectRawPath string 22 | expectQuery string 23 | expectErr string 24 | }{ 25 | { 26 | whenURL: "http://localhost:8080/old", 27 | expectPath: "/new", 28 | expectRawPath: "", 29 | }, 30 | { // encoded `ol%64` (decoded `old`) should not be rewritten to `/new` 31 | whenURL: "/ol%64", // `%64` is decoded `d` 32 | expectPath: "/old", 33 | expectRawPath: "/ol%64", 34 | }, 35 | { 36 | whenURL: "http://localhost:8080/users/+_+/orders/___++++?test=1", 37 | expectPath: "/user/+_+/order/___++++", 38 | expectRawPath: "", 39 | expectQuery: "test=1", 40 | }, 41 | { 42 | whenURL: "http://localhost:8080/users/%20a/orders/%20aa", 43 | expectPath: "/user/ a/order/ aa", 44 | expectRawPath: "", 45 | }, 46 | { 47 | whenURL: "http://localhost:8080/%47%6f%2f?test=1", 48 | expectPath: "/Go/", 49 | expectRawPath: "/%47%6f%2f", 50 | expectQuery: "test=1", 51 | }, 52 | { 53 | whenURL: "/users/jill/orders/T%2FcO4lW%2Ft%2FVp%2F", 54 | expectPath: "/user/jill/order/T/cO4lW/t/Vp/", 55 | expectRawPath: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", 56 | }, 57 | { // do nothing, replace nothing 58 | whenURL: "http://localhost:8080/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", 59 | expectPath: "/user/jill/order/T/cO4lW/t/Vp/", 60 | expectRawPath: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", 61 | }, 62 | { 63 | whenURL: 
"http://localhost:8080/static", 64 | expectPath: "/static/path", 65 | expectRawPath: "", 66 | expectQuery: "role=AUTHOR&limit=1000", 67 | }, 68 | { 69 | whenURL: "/static", 70 | expectPath: "/static/path", 71 | expectRawPath: "", 72 | expectQuery: "role=AUTHOR&limit=1000", 73 | }, 74 | } 75 | 76 | rules := map[*regexp.Regexp]string{ 77 | regexp.MustCompile("^/old$"): "/new", 78 | regexp.MustCompile("^/users/(.*?)/orders/(.*?)$"): "/user/$1/order/$2", 79 | regexp.MustCompile("^/static$"): "/static/path?role=AUTHOR&limit=1000", 80 | } 81 | 82 | for _, tc := range testCases { 83 | t.Run(tc.whenURL, func(t *testing.T) { 84 | req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) 85 | 86 | err := rewriteURL(rules, req) 87 | 88 | if tc.expectErr != "" { 89 | assert.EqualError(t, err, tc.expectErr) 90 | } else { 91 | assert.NoError(t, err) 92 | } 93 | assert.Equal(t, tc.expectPath, req.URL.Path) // Path field is stored in decoded form: /%47%6f%2f becomes /Go/. 94 | assert.Equal(t, tc.expectRawPath, req.URL.RawPath) // RawPath, an optional field which only gets set if the default encoding is different from Path. 
95 | assert.Equal(t, tc.expectQuery, req.URL.RawQuery) 96 | }) 97 | } 98 | } 99 | 100 | type testResponseWriterNoFlushHijack struct { 101 | } 102 | 103 | func (w *testResponseWriterNoFlushHijack) WriteHeader(statusCode int) { 104 | } 105 | 106 | func (w *testResponseWriterNoFlushHijack) Write([]byte) (int, error) { 107 | return 0, nil 108 | } 109 | 110 | func (w *testResponseWriterNoFlushHijack) Header() http.Header { 111 | return nil 112 | } 113 | 114 | type testResponseWriterUnwrapper struct { 115 | unwrapCalled int 116 | rw http.ResponseWriter 117 | } 118 | 119 | func (w *testResponseWriterUnwrapper) WriteHeader(statusCode int) { 120 | } 121 | 122 | func (w *testResponseWriterUnwrapper) Write([]byte) (int, error) { 123 | return 0, nil 124 | } 125 | 126 | func (w *testResponseWriterUnwrapper) Header() http.Header { 127 | return nil 128 | } 129 | 130 | func (w *testResponseWriterUnwrapper) Unwrap() http.ResponseWriter { 131 | w.unwrapCalled++ 132 | return w.rw 133 | } 134 | 135 | type testResponseWriterUnwrapperHijack struct { 136 | testResponseWriterUnwrapper 137 | } 138 | 139 | func (w *testResponseWriterUnwrapperHijack) Hijack() (net.Conn, *bufio.ReadWriter, error) { 140 | return nil, nil, errors.New("can hijack") 141 | } 142 | -------------------------------------------------------------------------------- /middleware/recover.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors 3 | 4 | package middleware 5 | 6 | import ( 7 | "fmt" 8 | "net/http" 9 | "runtime" 10 | 11 | "github.com/labstack/echo/v4" 12 | "github.com/labstack/gommon/log" 13 | ) 14 | 15 | // LogErrorFunc defines a function for custom logging in the middleware. 16 | type LogErrorFunc func(c echo.Context, err error, stack []byte) error 17 | 18 | // RecoverConfig defines the config for Recover middleware. 
19 | type RecoverConfig struct { 20 | // Skipper defines a function to skip middleware. 21 | Skipper Skipper 22 | 23 | // Size of the stack to be printed. 24 | // Optional. Default value 4KB. 25 | StackSize int `yaml:"stack_size"` 26 | 27 | // DisableStackAll disables formatting stack traces of all other goroutines 28 | // into buffer after the trace for the current goroutine. 29 | // Optional. Default value false. 30 | DisableStackAll bool `yaml:"disable_stack_all"` 31 | 32 | // DisablePrintStack disables printing stack trace. 33 | // Optional. Default value as false. 34 | DisablePrintStack bool `yaml:"disable_print_stack"` 35 | 36 | // LogLevel is log level to printing stack trace. 37 | // Optional. Default value 0 (Print). 38 | LogLevel log.Lvl 39 | 40 | // LogErrorFunc defines a function for custom logging in the middleware. 41 | // If it's set you don't need to provide LogLevel for config. 42 | // If this function returns nil, the centralized HTTPErrorHandler will not be called. 43 | LogErrorFunc LogErrorFunc 44 | 45 | // DisableErrorHandler disables the call to centralized HTTPErrorHandler. 46 | // The recovered error is then passed back to upstream middleware, instead of swallowing the error. 47 | // Optional. Default value false. 48 | DisableErrorHandler bool `yaml:"disable_error_handler"` 49 | } 50 | 51 | // DefaultRecoverConfig is the default Recover middleware config. 52 | var DefaultRecoverConfig = RecoverConfig{ 53 | Skipper: DefaultSkipper, 54 | StackSize: 4 << 10, // 4 KB 55 | DisableStackAll: false, 56 | DisablePrintStack: false, 57 | LogLevel: 0, 58 | LogErrorFunc: nil, 59 | DisableErrorHandler: false, 60 | } 61 | 62 | // Recover returns a middleware which recovers from panics anywhere in the chain 63 | // and handles the control to the centralized HTTPErrorHandler. 64 | func Recover() echo.MiddlewareFunc { 65 | return RecoverWithConfig(DefaultRecoverConfig) 66 | } 67 | 68 | // RecoverWithConfig returns a Recover middleware with config. 
69 | // See: `Recover()`. 70 | func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc { 71 | // Defaults 72 | if config.Skipper == nil { 73 | config.Skipper = DefaultRecoverConfig.Skipper 74 | } 75 | if config.StackSize == 0 { 76 | config.StackSize = DefaultRecoverConfig.StackSize 77 | } 78 | 79 | return func(next echo.HandlerFunc) echo.HandlerFunc { 80 | return func(c echo.Context) (returnErr error) { 81 | if config.Skipper(c) { 82 | return next(c) 83 | } 84 | 85 | defer func() { 86 | if r := recover(); r != nil { 87 | if r == http.ErrAbortHandler { 88 | panic(r) 89 | } 90 | err, ok := r.(error) 91 | if !ok { 92 | err = fmt.Errorf("%v", r) 93 | } 94 | var stack []byte 95 | var length int 96 | 97 | if !config.DisablePrintStack { 98 | stack = make([]byte, config.StackSize) 99 | length = runtime.Stack(stack, !config.DisableStackAll) 100 | stack = stack[:length] 101 | } 102 | 103 | if config.LogErrorFunc != nil { 104 | err = config.LogErrorFunc(c, err, stack) 105 | } else if !config.DisablePrintStack { 106 | msg := fmt.Sprintf("[PANIC RECOVER] %v %s\n", err, stack[:length]) 107 | switch config.LogLevel { 108 | case log.DEBUG: 109 | c.Logger().Debug(msg) 110 | case log.INFO: 111 | c.Logger().Info(msg) 112 | case log.WARN: 113 | c.Logger().Warn(msg) 114 | case log.ERROR: 115 | c.Logger().Error(msg) 116 | case log.OFF: 117 | // None. 
118 | default: 119 | c.Logger().Print(msg) 120 | } 121 | } 122 | 123 | if err != nil && !config.DisableErrorHandler { 124 | c.Error(err) 125 | } else { 126 | returnErr = err 127 | } 128 | } 129 | }() 130 | return next(c) 131 | } 132 | } 133 | } 134 | -------------------------------------------------------------------------------- /middleware/recover_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors 3 | 4 | package middleware 5 | 6 | import ( 7 | "bytes" 8 | "errors" 9 | "fmt" 10 | "net/http" 11 | "net/http/httptest" 12 | "testing" 13 | 14 | "github.com/labstack/echo/v4" 15 | "github.com/labstack/gommon/log" 16 | "github.com/stretchr/testify/assert" 17 | ) 18 | 19 | func TestRecover(t *testing.T) { 20 | e := echo.New() 21 | buf := new(bytes.Buffer) 22 | e.Logger.SetOutput(buf) 23 | req := httptest.NewRequest(http.MethodGet, "/", nil) 24 | rec := httptest.NewRecorder() 25 | c := e.NewContext(req, rec) 26 | h := Recover()(echo.HandlerFunc(func(c echo.Context) error { 27 | panic("test") 28 | })) 29 | err := h(c) 30 | assert.NoError(t, err) 31 | assert.Equal(t, http.StatusInternalServerError, rec.Code) 32 | assert.Contains(t, buf.String(), "PANIC RECOVER") 33 | } 34 | 35 | func TestRecoverErrAbortHandler(t *testing.T) { 36 | e := echo.New() 37 | buf := new(bytes.Buffer) 38 | e.Logger.SetOutput(buf) 39 | req := httptest.NewRequest(http.MethodGet, "/", nil) 40 | rec := httptest.NewRecorder() 41 | c := e.NewContext(req, rec) 42 | h := Recover()(echo.HandlerFunc(func(c echo.Context) error { 43 | panic(http.ErrAbortHandler) 44 | })) 45 | defer func() { 46 | r := recover() 47 | if r == nil { 48 | assert.Fail(t, "expecting `http.ErrAbortHandler`, got `nil`") 49 | } else { 50 | if err, ok := r.(error); ok { 51 | assert.ErrorIs(t, err, http.ErrAbortHandler) 52 | } else { 53 | assert.Fail(t, "not of error type") 54 | } 55 | } 56 | 
}() 57 | 58 | h(c) 59 | 60 | assert.Equal(t, http.StatusInternalServerError, rec.Code) 61 | assert.NotContains(t, buf.String(), "PANIC RECOVER") 62 | } 63 | 64 | func TestRecoverWithConfig_LogLevel(t *testing.T) { 65 | tests := []struct { 66 | logLevel log.Lvl 67 | levelName string 68 | }{{ 69 | logLevel: log.DEBUG, 70 | levelName: "DEBUG", 71 | }, { 72 | logLevel: log.INFO, 73 | levelName: "INFO", 74 | }, { 75 | logLevel: log.WARN, 76 | levelName: "WARN", 77 | }, { 78 | logLevel: log.ERROR, 79 | levelName: "ERROR", 80 | }, { 81 | logLevel: log.OFF, 82 | levelName: "OFF", 83 | }} 84 | 85 | for _, tt := range tests { 86 | tt := tt 87 | t.Run(tt.levelName, func(t *testing.T) { 88 | e := echo.New() 89 | e.Logger.SetLevel(log.DEBUG) 90 | 91 | buf := new(bytes.Buffer) 92 | e.Logger.SetOutput(buf) 93 | 94 | req := httptest.NewRequest(http.MethodGet, "/", nil) 95 | rec := httptest.NewRecorder() 96 | c := e.NewContext(req, rec) 97 | 98 | config := DefaultRecoverConfig 99 | config.LogLevel = tt.logLevel 100 | h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error { 101 | panic("test") 102 | })) 103 | 104 | h(c) 105 | 106 | assert.Equal(t, http.StatusInternalServerError, rec.Code) 107 | 108 | output := buf.String() 109 | if tt.logLevel == log.OFF { 110 | assert.Empty(t, output) 111 | } else { 112 | assert.Contains(t, output, "PANIC RECOVER") 113 | assert.Contains(t, output, fmt.Sprintf(`"level":"%s"`, tt.levelName)) 114 | } 115 | }) 116 | } 117 | } 118 | 119 | func TestRecoverWithConfig_LogErrorFunc(t *testing.T) { 120 | e := echo.New() 121 | e.Logger.SetLevel(log.DEBUG) 122 | 123 | buf := new(bytes.Buffer) 124 | e.Logger.SetOutput(buf) 125 | 126 | req := httptest.NewRequest(http.MethodGet, "/", nil) 127 | rec := httptest.NewRecorder() 128 | c := e.NewContext(req, rec) 129 | 130 | testError := errors.New("test") 131 | config := DefaultRecoverConfig 132 | config.LogErrorFunc = func(c echo.Context, err error, stack []byte) error { 133 | msg := 
fmt.Sprintf("[PANIC RECOVER] %v %s\n", err, stack) 134 | if errors.Is(err, testError) { 135 | c.Logger().Debug(msg) 136 | } else { 137 | c.Logger().Error(msg) 138 | } 139 | return err 140 | } 141 | 142 | t.Run("first branch case for LogErrorFunc", func(t *testing.T) { 143 | buf.Reset() 144 | h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error { 145 | panic(testError) 146 | })) 147 | 148 | h(c) 149 | assert.Equal(t, http.StatusInternalServerError, rec.Code) 150 | 151 | output := buf.String() 152 | assert.Contains(t, output, "PANIC RECOVER") 153 | assert.Contains(t, output, `"level":"DEBUG"`) 154 | }) 155 | 156 | t.Run("else branch case for LogErrorFunc", func(t *testing.T) { 157 | buf.Reset() 158 | h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error { 159 | panic("other") 160 | })) 161 | 162 | h(c) 163 | assert.Equal(t, http.StatusInternalServerError, rec.Code) 164 | 165 | output := buf.String() 166 | assert.Contains(t, output, "PANIC RECOVER") 167 | assert.Contains(t, output, `"level":"ERROR"`) 168 | }) 169 | } 170 | 171 | func TestRecoverWithDisabled_ErrorHandler(t *testing.T) { 172 | e := echo.New() 173 | buf := new(bytes.Buffer) 174 | e.Logger.SetOutput(buf) 175 | req := httptest.NewRequest(http.MethodGet, "/", nil) 176 | rec := httptest.NewRecorder() 177 | c := e.NewContext(req, rec) 178 | 179 | config := DefaultRecoverConfig 180 | config.DisableErrorHandler = true 181 | h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error { 182 | panic("test") 183 | })) 184 | err := h(c) 185 | 186 | assert.Equal(t, http.StatusOK, rec.Code) 187 | assert.Contains(t, buf.String(), "PANIC RECOVER") 188 | assert.EqualError(t, err, "test") 189 | } 190 | -------------------------------------------------------------------------------- /middleware/redirect.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: © 2015 
LabStack LLC and Echo contributors 3 | 4 | package middleware 5 | 6 | import ( 7 | "net/http" 8 | "strings" 9 | 10 | "github.com/labstack/echo/v4" 11 | ) 12 | 13 | // RedirectConfig defines the config for Redirect middleware. 14 | type RedirectConfig struct { 15 | // Skipper defines a function to skip middleware. 16 | Skipper 17 | 18 | // Status code to be used when redirecting the request. 19 | // Optional. Default value http.StatusMovedPermanently. 20 | Code int `yaml:"code"` 21 | } 22 | 23 | // redirectLogic represents a function that given a scheme, host and uri 24 | // can both: 1) determine if redirect is needed (will set ok accordingly) and 25 | // 2) return the appropriate redirect url. 26 | type redirectLogic func(scheme, host, uri string) (ok bool, url string) 27 | 28 | const www = "www." 29 | 30 | // DefaultRedirectConfig is the default Redirect middleware config. 31 | var DefaultRedirectConfig = RedirectConfig{ 32 | Skipper: DefaultSkipper, 33 | Code: http.StatusMovedPermanently, 34 | } 35 | 36 | // HTTPSRedirect redirects http requests to https. 37 | // For example, http://labstack.com will be redirect to https://labstack.com. 38 | // 39 | // Usage `Echo#Pre(HTTPSRedirect())` 40 | func HTTPSRedirect() echo.MiddlewareFunc { 41 | return HTTPSRedirectWithConfig(DefaultRedirectConfig) 42 | } 43 | 44 | // HTTPSRedirectWithConfig returns an HTTPSRedirect middleware with config. 45 | // See `HTTPSRedirect()`. 46 | func HTTPSRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { 47 | return redirect(config, func(scheme, host, uri string) (bool, string) { 48 | if scheme != "https" { 49 | return true, "https://" + host + uri 50 | } 51 | return false, "" 52 | }) 53 | } 54 | 55 | // HTTPSWWWRedirect redirects http requests to https www. 56 | // For example, http://labstack.com will be redirect to https://www.labstack.com. 
57 | // 58 | // Usage `Echo#Pre(HTTPSWWWRedirect())` 59 | func HTTPSWWWRedirect() echo.MiddlewareFunc { 60 | return HTTPSWWWRedirectWithConfig(DefaultRedirectConfig) 61 | } 62 | 63 | // HTTPSWWWRedirectWithConfig returns an HTTPSRedirect middleware with config. 64 | // See `HTTPSWWWRedirect()`. 65 | func HTTPSWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { 66 | return redirect(config, func(scheme, host, uri string) (bool, string) { 67 | if scheme != "https" && !strings.HasPrefix(host, www) { 68 | return true, "https://www." + host + uri 69 | } 70 | return false, "" 71 | }) 72 | } 73 | 74 | // HTTPSNonWWWRedirect redirects http requests to https non www. 75 | // For example, http://www.labstack.com will be redirect to https://labstack.com. 76 | // 77 | // Usage `Echo#Pre(HTTPSNonWWWRedirect())` 78 | func HTTPSNonWWWRedirect() echo.MiddlewareFunc { 79 | return HTTPSNonWWWRedirectWithConfig(DefaultRedirectConfig) 80 | } 81 | 82 | // HTTPSNonWWWRedirectWithConfig returns an HTTPSRedirect middleware with config. 83 | // See `HTTPSNonWWWRedirect()`. 84 | func HTTPSNonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { 85 | return redirect(config, func(scheme, host, uri string) (ok bool, url string) { 86 | if scheme != "https" { 87 | host = strings.TrimPrefix(host, www) 88 | return true, "https://" + host + uri 89 | } 90 | return false, "" 91 | }) 92 | } 93 | 94 | // WWWRedirect redirects non www requests to www. 95 | // For example, http://labstack.com will be redirect to http://www.labstack.com. 96 | // 97 | // Usage `Echo#Pre(WWWRedirect())` 98 | func WWWRedirect() echo.MiddlewareFunc { 99 | return WWWRedirectWithConfig(DefaultRedirectConfig) 100 | } 101 | 102 | // WWWRedirectWithConfig returns an HTTPSRedirect middleware with config. 103 | // See `WWWRedirect()`. 
104 | func WWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { 105 | return redirect(config, func(scheme, host, uri string) (bool, string) { 106 | if !strings.HasPrefix(host, www) { 107 | return true, scheme + "://www." + host + uri 108 | } 109 | return false, "" 110 | }) 111 | } 112 | 113 | // NonWWWRedirect redirects www requests to non www. 114 | // For example, http://www.labstack.com will be redirect to http://labstack.com. 115 | // 116 | // Usage `Echo#Pre(NonWWWRedirect())` 117 | func NonWWWRedirect() echo.MiddlewareFunc { 118 | return NonWWWRedirectWithConfig(DefaultRedirectConfig) 119 | } 120 | 121 | // NonWWWRedirectWithConfig returns an HTTPSRedirect middleware with config. 122 | // See `NonWWWRedirect()`. 123 | func NonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { 124 | return redirect(config, func(scheme, host, uri string) (bool, string) { 125 | if strings.HasPrefix(host, www) { 126 | return true, scheme + "://" + host[4:] + uri 127 | } 128 | return false, "" 129 | }) 130 | } 131 | 132 | func redirect(config RedirectConfig, cb redirectLogic) echo.MiddlewareFunc { 133 | if config.Skipper == nil { 134 | config.Skipper = DefaultRedirectConfig.Skipper 135 | } 136 | if config.Code == 0 { 137 | config.Code = DefaultRedirectConfig.Code 138 | } 139 | 140 | return func(next echo.HandlerFunc) echo.HandlerFunc { 141 | return func(c echo.Context) error { 142 | if config.Skipper(c) { 143 | return next(c) 144 | } 145 | 146 | req, scheme := c.Request(), c.Scheme() 147 | host := req.Host 148 | if ok, url := cb(scheme, host, req.RequestURI); ok { 149 | return c.Redirect(config.Code, url) 150 | } 151 | 152 | return next(c) 153 | } 154 | } 155 | } 156 | -------------------------------------------------------------------------------- /middleware/request_id.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo 
contributors 3 | 4 | package middleware 5 | 6 | import ( 7 | "github.com/labstack/echo/v4" 8 | ) 9 | 10 | // RequestIDConfig defines the config for RequestID middleware. 11 | type RequestIDConfig struct { 12 | // Skipper defines a function to skip middleware. 13 | Skipper Skipper 14 | 15 | // Generator defines a function to generate an ID. 16 | // Optional. Defaults to generator for random string of length 32. 17 | Generator func() string 18 | 19 | // RequestIDHandler defines a function which is executed for a request id. 20 | RequestIDHandler func(echo.Context, string) 21 | 22 | // TargetHeader defines what header to look for to populate the id 23 | TargetHeader string 24 | } 25 | 26 | // DefaultRequestIDConfig is the default RequestID middleware config. 27 | var DefaultRequestIDConfig = RequestIDConfig{ 28 | Skipper: DefaultSkipper, 29 | Generator: generator, 30 | TargetHeader: echo.HeaderXRequestID, 31 | } 32 | 33 | // RequestID returns a X-Request-ID middleware. 34 | func RequestID() echo.MiddlewareFunc { 35 | return RequestIDWithConfig(DefaultRequestIDConfig) 36 | } 37 | 38 | // RequestIDWithConfig returns a X-Request-ID middleware with config. 
39 | func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc { 40 | // Defaults 41 | if config.Skipper == nil { 42 | config.Skipper = DefaultRequestIDConfig.Skipper 43 | } 44 | if config.Generator == nil { 45 | config.Generator = generator 46 | } 47 | if config.TargetHeader == "" { 48 | config.TargetHeader = echo.HeaderXRequestID 49 | } 50 | 51 | return func(next echo.HandlerFunc) echo.HandlerFunc { 52 | return func(c echo.Context) error { 53 | if config.Skipper(c) { 54 | return next(c) 55 | } 56 | 57 | req := c.Request() 58 | res := c.Response() 59 | rid := req.Header.Get(config.TargetHeader) 60 | if rid == "" { 61 | rid = config.Generator() 62 | } 63 | res.Header().Set(config.TargetHeader, rid) 64 | if config.RequestIDHandler != nil { 65 | config.RequestIDHandler(c, rid) 66 | } 67 | 68 | return next(c) 69 | } 70 | } 71 | } 72 | 73 | func generator() string { 74 | return randomString(32) 75 | } 76 | -------------------------------------------------------------------------------- /middleware/request_id_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors 3 | 4 | package middleware 5 | 6 | import ( 7 | "net/http" 8 | "net/http/httptest" 9 | "testing" 10 | 11 | "github.com/labstack/echo/v4" 12 | "github.com/stretchr/testify/assert" 13 | ) 14 | 15 | func TestRequestID(t *testing.T) { 16 | e := echo.New() 17 | req := httptest.NewRequest(http.MethodGet, "/", nil) 18 | rec := httptest.NewRecorder() 19 | c := e.NewContext(req, rec) 20 | handler := func(c echo.Context) error { 21 | return c.String(http.StatusOK, "test") 22 | } 23 | 24 | rid := RequestIDWithConfig(RequestIDConfig{}) 25 | h := rid(handler) 26 | h(c) 27 | assert.Len(t, rec.Header().Get(echo.HeaderXRequestID), 32) 28 | 29 | // Custom generator and handler 30 | customID := "customGenerator" 31 | calledHandler := false 32 | rid = 
RequestIDWithConfig(RequestIDConfig{ 33 | Generator: func() string { return customID }, 34 | RequestIDHandler: func(_ echo.Context, id string) { 35 | calledHandler = true 36 | assert.Equal(t, customID, id) 37 | }, 38 | }) 39 | h = rid(handler) 40 | h(c) 41 | assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator") 42 | assert.True(t, calledHandler) 43 | } 44 | 45 | func TestRequestID_IDNotAltered(t *testing.T) { 46 | e := echo.New() 47 | req := httptest.NewRequest(http.MethodGet, "/", nil) 48 | req.Header.Add(echo.HeaderXRequestID, "") 49 | 50 | rec := httptest.NewRecorder() 51 | c := e.NewContext(req, rec) 52 | handler := func(c echo.Context) error { 53 | return c.String(http.StatusOK, "test") 54 | } 55 | 56 | rid := RequestIDWithConfig(RequestIDConfig{}) 57 | h := rid(handler) 58 | _ = h(c) 59 | assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "") 60 | } 61 | 62 | func TestRequestIDConfigDifferentHeader(t *testing.T) { 63 | e := echo.New() 64 | req := httptest.NewRequest(http.MethodGet, "/", nil) 65 | rec := httptest.NewRecorder() 66 | c := e.NewContext(req, rec) 67 | handler := func(c echo.Context) error { 68 | return c.String(http.StatusOK, "test") 69 | } 70 | 71 | rid := RequestIDWithConfig(RequestIDConfig{TargetHeader: echo.HeaderXCorrelationID}) 72 | h := rid(handler) 73 | h(c) 74 | assert.Len(t, rec.Header().Get(echo.HeaderXCorrelationID), 32) 75 | 76 | // Custom generator and handler 77 | customID := "customGenerator" 78 | calledHandler := false 79 | rid = RequestIDWithConfig(RequestIDConfig{ 80 | Generator: func() string { return customID }, 81 | TargetHeader: echo.HeaderXCorrelationID, 82 | RequestIDHandler: func(_ echo.Context, id string) { 83 | calledHandler = true 84 | assert.Equal(t, customID, id) 85 | }, 86 | }) 87 | h = rid(handler) 88 | h(c) 89 | assert.Equal(t, rec.Header().Get(echo.HeaderXCorrelationID), "customGenerator") 90 | assert.True(t, calledHandler) 91 | } 92 | 
-------------------------------------------------------------------------------- /middleware/rewrite.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors 3 | 4 | package middleware 5 | 6 | import ( 7 | "regexp" 8 | 9 | "github.com/labstack/echo/v4" 10 | ) 11 | 12 | // RewriteConfig defines the config for Rewrite middleware. 13 | type RewriteConfig struct { 14 | // Skipper defines a function to skip middleware. 15 | Skipper Skipper 16 | 17 | // Rules defines the URL path rewrite rules. The values captured in asterisk can be 18 | // retrieved by index e.g. $1, $2 and so on. 19 | // Example: 20 | // "/old": "/new", 21 | // "/api/*": "/$1", 22 | // "/js/*": "/public/javascripts/$1", 23 | // "/users/*/orders/*": "/user/$1/order/$2", 24 | // Required. 25 | Rules map[string]string `yaml:"rules"` 26 | 27 | // RegexRules defines the URL path rewrite rules using regexp.Rexexp with captures 28 | // Every capture group in the values can be retrieved by index e.g. $1, $2 and so on. 29 | // Example: 30 | // "^/old/[0.9]+/": "/new", 31 | // "^/api/.+?/(.*)": "/v2/$1", 32 | RegexRules map[*regexp.Regexp]string `yaml:"-"` 33 | } 34 | 35 | // DefaultRewriteConfig is the default Rewrite middleware config. 36 | var DefaultRewriteConfig = RewriteConfig{ 37 | Skipper: DefaultSkipper, 38 | } 39 | 40 | // Rewrite returns a Rewrite middleware. 41 | // 42 | // Rewrite middleware rewrites the URL path based on the provided rules. 43 | func Rewrite(rules map[string]string) echo.MiddlewareFunc { 44 | c := DefaultRewriteConfig 45 | c.Rules = rules 46 | return RewriteWithConfig(c) 47 | } 48 | 49 | // RewriteWithConfig returns a Rewrite middleware with config. 50 | // See: `Rewrite()`. 
func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc {
	// Defaults
	if config.Rules == nil && config.RegexRules == nil {
		panic("echo: rewrite middleware requires url path rewrite rules or regex rules")
	}

	if config.Skipper == nil {
		config.Skipper = DefaultRewriteConfig.Skipper
	}

	// Merge both rule sets into a fresh map. Writing into config.RegexRules
	// directly would mutate the map the caller passed in.
	regexRules := make(map[*regexp.Regexp]string, len(config.RegexRules)+len(config.Rules))
	for k, v := range config.RegexRules {
		regexRules[k] = v
	}
	for k, v := range rewriteRulesRegex(config.Rules) {
		regexRules[k] = v
	}
	config.RegexRules = regexRules

	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			if config.Skipper(c) {
				return next(c)
			}

			if err := rewriteURL(config.RegexRules, c.Request()); err != nil {
				return err
			}
			return next(c)
		}
	}
}
81 | -------------------------------------------------------------------------------- /middleware/secure.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors 3 | 4 | package middleware 5 | 6 | import ( 7 | "fmt" 8 | 9 | "github.com/labstack/echo/v4" 10 | ) 11 | 12 | // SecureConfig defines the config for Secure middleware. 13 | type SecureConfig struct { 14 | // Skipper defines a function to skip middleware. 15 | Skipper Skipper 16 | 17 | // XSSProtection provides protection against cross-site scripting attack (XSS) 18 | // by setting the `X-XSS-Protection` header. 19 | // Optional. Default value "1; mode=block". 20 | XSSProtection string `yaml:"xss_protection"` 21 | 22 | // ContentTypeNosniff provides protection against overriding Content-Type 23 | // header by setting the `X-Content-Type-Options` header. 24 | // Optional. Default value "nosniff".
25 | ContentTypeNosniff string `yaml:"content_type_nosniff"` 26 | 27 | // XFrameOptions can be used to indicate whether or not a browser should 28 | // be allowed to render a page in a ,