├── .editorconfig
├── .gitattributes
├── .github
│   ├── FUNDING.yml
│   ├── ISSUE_TEMPLATE.md
│   ├── stale.yml
│   └── workflows
│       ├── checks.yml
│       └── echo.yml
├── .gitignore
├── CHANGELOG.md
├── LICENSE
├── Makefile
├── README.md
├── _fixture
│   ├── _fixture
│   │   └── README.md
│   ├── certs
│   │   ├── README.md
│   │   ├── cert.pem
│   │   └── key.pem
│   ├── favicon.ico
│   ├── folder
│   │   └── index.html
│   ├── images
│   │   └── walle.png
│   └── index.html
├── bind.go
├── bind_test.go
├── binder.go
├── binder_external_test.go
├── binder_test.go
├── codecov.yml
├── context.go
├── context_fs.go
├── context_fs_test.go
├── context_test.go
├── echo.go
├── echo_fs.go
├── echo_fs_test.go
├── echo_test.go
├── go.mod
├── go.sum
├── group.go
├── group_fs.go
├── group_fs_test.go
├── group_test.go
├── ip.go
├── ip_test.go
├── json.go
├── json_test.go
├── log.go
├── middleware
│   ├── basic_auth.go
│   ├── basic_auth_test.go
│   ├── body_dump.go
│   ├── body_dump_test.go
│   ├── body_limit.go
│   ├── body_limit_test.go
│   ├── compress.go
│   ├── compress_test.go
│   ├── context_timeout.go
│   ├── context_timeout_test.go
│   ├── cors.go
│   ├── cors_test.go
│   ├── csrf.go
│   ├── csrf_test.go
│   ├── decompress.go
│   ├── decompress_test.go
│   ├── extractor.go
│   ├── extractor_test.go
│   ├── key_auth.go
│   ├── key_auth_test.go
│   ├── logger.go
│   ├── logger_test.go
│   ├── method_override.go
│   ├── method_override_test.go
│   ├── middleware.go
│   ├── middleware_test.go
│   ├── proxy.go
│   ├── proxy_test.go
│   ├── rate_limiter.go
│   ├── rate_limiter_test.go
│   ├── recover.go
│   ├── recover_test.go
│   ├── redirect.go
│   ├── redirect_test.go
│   ├── request_id.go
│   ├── request_id_test.go
│   ├── request_logger.go
│   ├── request_logger_test.go
│   ├── rewrite.go
│   ├── rewrite_test.go
│   ├── secure.go
│   ├── secure_test.go
│   ├── slash.go
│   ├── slash_test.go
│   ├── static.go
│   ├── static_other.go
│   ├── static_test.go
│   ├── static_windows.go
│   ├── timeout.go
│   ├── timeout_test.go
│   ├── util.go
│   └── util_test.go
├── renderer.go
├── renderer_test.go
├── response.go
├── response_test.go
├── router.go
└── router_test.go
/.editorconfig:
--------------------------------------------------------------------------------
1 | # EditorConfig coding styles definitions. For more information about the
2 | # properties used in this file, please see the EditorConfig documentation:
3 | # http://editorconfig.org/
4 |
5 | # indicate this is the root of the project
6 | root = true
7 |
8 | [*]
9 | charset = utf-8
10 |
11 | end_of_line = LF
12 | insert_final_newline = true
13 | trim_trailing_whitespace = true
14 |
15 | indent_style = space
16 | indent_size = 2
17 |
18 | [Makefile]
19 | indent_style = tab
20 |
21 | [*.md]
22 | trim_trailing_whitespace = false
23 |
24 | [*.go]
25 | indent_style = tab
26 |
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Automatically normalize line endings for all text-based files
2 | # http://git-scm.com/docs/gitattributes#_end_of_line_conversion
3 | * text=auto
4 |
5 | # For the following file types, normalize line endings to LF on check-in and
6 | # prevent conversion to CRLF when they are checked out (this is required in
7 | # order to prevent newline-related issues)
8 | .* text eol=lf
9 | *.go text eol=lf
10 | *.yml text eol=lf
11 | *.html text eol=lf
12 | *.css text eol=lf
13 | *.js text eol=lf
14 | *.json text eol=lf
15 | LICENSE text eol=lf
16 |
17 | # Exclude `website` and `cookbook` from GitHub's language statistics
18 | # https://github.com/github/linguist#using-gitattributes
19 | cookbook/* linguist-documentation
20 | website/* linguist-documentation
21 |
--------------------------------------------------------------------------------
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | # These are supported funding model platforms
2 |
3 | github: [labstack]
4 | patreon: # Replace with a single Patreon username
5 | open_collective: # Replace with a single Open Collective username
6 | ko_fi: # Replace with a single Ko-fi username
7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
9 | liberapay: # Replace with a single Liberapay username
10 | issuehunt: # Replace with a single IssueHunt username
11 | otechie: # Replace with a single Otechie username
12 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
13 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | ### Issue Description
2 |
3 | ### Working code to debug
4 |
5 | ```go
6 | package main
7 |
8 | import (
9 | "github.com/labstack/echo/v4"
10 | "net/http"
11 | "net/http/httptest"
12 | "testing"
13 | )
14 |
15 | func TestExample(t *testing.T) {
16 | e := echo.New()
17 |
18 | e.GET("/", func(c echo.Context) error {
19 | return c.String(http.StatusOK, "Hello, World!")
20 | })
21 |
22 | req := httptest.NewRequest(http.MethodGet, "/", nil)
23 | rec := httptest.NewRecorder()
24 |
25 | e.ServeHTTP(rec, req)
26 |
27 | if rec.Code != http.StatusOK {
28 | t.Errorf("got %d, want %d", rec.Code, http.StatusOK)
29 | }
30 | }
31 | ```
32 |
33 | ### Version/commit
34 |
--------------------------------------------------------------------------------
/.github/stale.yml:
--------------------------------------------------------------------------------
1 | # Number of days of inactivity before an issue becomes stale
2 | daysUntilStale: 60
3 | # Number of days of inactivity before a stale issue is closed
4 | daysUntilClose: 30
5 | # Issues with these labels will never be considered stale
6 | exemptLabels:
7 | - pinned
8 | - security
9 | - bug
10 | - enhancement
11 | # Label to use when marking an issue as stale
12 | staleLabel: stale
13 | # Comment to post when marking an issue as stale. Set to `false` to disable
14 | markComment: >
15 | This issue has been automatically marked as stale because it has not had
16 | recent activity. It will be closed within a month if no further activity occurs.
17 | Thank you for your contributions.
18 | # Comment to post when closing a stale issue. Set to `false` to disable
19 | closeComment: false
20 |
--------------------------------------------------------------------------------
/.github/workflows/checks.yml:
--------------------------------------------------------------------------------
1 | name: Run checks
2 |
3 | on:
4 | push:
5 | branches:
6 | - master
7 | pull_request:
8 | branches:
9 | - master
10 | workflow_dispatch:
11 |
12 | permissions:
13 | contents: read # to fetch code (actions/checkout)
14 |
15 | env:
16 | # run static analysis only with the latest Go version
17 | LATEST_GO_VERSION: "1.24"
18 |
19 | jobs:
20 | check:
21 | runs-on: ubuntu-latest
22 | steps:
23 | - name: Checkout Code
24 | uses: actions/checkout@v4
25 |
26 | - name: Set up Go ${{ env.LATEST_GO_VERSION }}
27 | uses: actions/setup-go@v5
28 | with:
29 | go-version: ${{ env.LATEST_GO_VERSION }}
30 | check-latest: true
31 |
32 | - name: Run golint
33 | run: |
34 | go install golang.org/x/lint/golint@latest
35 | golint -set_exit_status ./...
36 |
37 | - name: Run staticcheck
38 | run: |
39 | go install honnef.co/go/tools/cmd/staticcheck@latest
40 | staticcheck ./...
41 |
42 | - name: Run govulncheck
43 | run: |
44 | go version
45 | go install golang.org/x/vuln/cmd/govulncheck@latest
46 | govulncheck ./...
47 |
48 |
49 |
--------------------------------------------------------------------------------
/.github/workflows/echo.yml:
--------------------------------------------------------------------------------
1 | name: Run Tests
2 |
3 | on:
4 | push:
5 | branches:
6 | - master
7 | pull_request:
8 | branches:
9 | - master
10 | workflow_dispatch:
11 |
12 | permissions:
13 | contents: read # to fetch code (actions/checkout)
14 |
15 | env:
16 | # run coverage and benchmarks only with the latest Go version
17 | LATEST_GO_VERSION: "1.24"
18 |
19 | jobs:
20 | test:
21 | strategy:
22 | matrix:
23 | os: [ubuntu-latest, macos-latest, windows-latest]
24 | # Each major Go release is supported until there are two newer major releases. https://golang.org/doc/devel/release.html#policy
25 | # Echo tests with the last four major releases (unless there are pressing vulnerabilities).
26 | # As we depend on `golang.org/x/` libraries, which only support the last 2 Go releases, there could be situations
27 | # where we deviate from the last-four-major-releases promise.
28 | go: ["1.21", "1.22", "1.23", "1.24"]
29 | name: ${{ matrix.os }} @ Go ${{ matrix.go }}
30 | runs-on: ${{ matrix.os }}
31 | steps:
32 | - name: Checkout Code
33 | uses: actions/checkout@v4
34 |
35 | - name: Set up Go ${{ matrix.go }}
36 | uses: actions/setup-go@v5
37 | with:
38 | go-version: ${{ matrix.go }}
39 |
40 | - name: Run Tests
41 | run: go test -race --coverprofile=coverage.coverprofile --covermode=atomic ./...
42 |
43 | - name: Upload coverage to Codecov
44 | if: success() && matrix.go == env.LATEST_GO_VERSION && matrix.os == 'ubuntu-latest'
45 | uses: codecov/codecov-action@v3
46 | with:
47 | token:
48 | fail_ci_if_error: false
49 |
50 | benchmark:
51 | needs: test
52 | name: Benchmark comparison
53 | runs-on: ubuntu-latest
54 | steps:
55 | - name: Checkout Code (Previous)
56 | uses: actions/checkout@v4
57 | with:
58 | ref: ${{ github.base_ref }}
59 | path: previous
60 |
61 | - name: Checkout Code (New)
62 | uses: actions/checkout@v4
63 | with:
64 | path: new
65 |
66 | - name: Set up Go ${{ env.LATEST_GO_VERSION }}
67 | uses: actions/setup-go@v5
68 | with:
69 | go-version: ${{ env.LATEST_GO_VERSION }}
70 |
71 | - name: Install Dependencies
72 | run: go install golang.org/x/perf/cmd/benchstat@latest
73 |
74 | - name: Run Benchmark (Previous)
75 | run: |
76 | cd previous
77 | go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt
78 |
79 | - name: Run Benchmark (New)
80 | run: |
81 | cd new
82 | go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt
83 |
84 | - name: Run Benchstat
85 | run: |
86 | benchstat previous/benchmark.txt new/benchmark.txt
87 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | coverage.txt
3 | _test
4 | vendor
5 | .idea
6 | *.iml
7 | *.out
8 | .vscode
9 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2021 LabStack
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | PKG := "github.com/labstack/echo"
2 | PKG_LIST := $(shell go list ${PKG}/...)
3 |
4 | tag:
5 | @git tag `grep -P '^\tversion = ' echo.go|cut -f2 -d'"'`
6 | @git tag|grep -v ^v
7 |
8 | .DEFAULT_GOAL := check
9 | check: lint vet race ## Check project
10 |
11 | init:
12 | @go install golang.org/x/lint/golint@latest
13 | @go install honnef.co/go/tools/cmd/staticcheck@latest
14 |
15 | lint: ## Lint the files
16 | @staticcheck ${PKG_LIST}
17 | @golint -set_exit_status ${PKG_LIST}
18 |
19 | vet: ## Vet the files
20 | @go vet ${PKG_LIST}
21 |
22 | test: ## Run tests
23 | @go test -short ${PKG_LIST}
24 |
25 | race: ## Run tests with data race detector
26 | @go test -race ${PKG_LIST}
27 |
28 | benchmark: ## Run benchmarks
29 | @go test -run="-" -bench=".*" ${PKG_LIST}
30 |
31 | help: ## Display this help screen
32 | @grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
33 |
34 | goversion ?= "1.21"
35 | test_version: ## Run tests inside Docker with given version (defaults to 1.21 oldest supported). Example: make test_version goversion=1.21
36 | @docker run --rm -it -v $(shell pwd):/project golang:$(goversion) /bin/sh -c "cd /project && make init check"
37 |
--------------------------------------------------------------------------------
/_fixture/_fixture/README.md:
--------------------------------------------------------------------------------
1 | This directory is used for the static middleware test
--------------------------------------------------------------------------------
/_fixture/certs/README.md:
--------------------------------------------------------------------------------
1 | To generate a valid certificate and private key use the following command:
2 |
3 | ```bash
4 | # In OpenSSL ≥ 1.1.1
5 | openssl req -x509 -newkey rsa:4096 -sha256 -days 9999 -nodes \
6 | -keyout key.pem -out cert.pem -subj "/CN=localhost" \
7 | -addext "subjectAltName=DNS:localhost,IP:127.0.0.1,IP:::1"
8 | ```
9 |
10 | To check a certificate use the following command:
11 | ```bash
12 | openssl x509 -in cert.pem -text
13 | ```
14 |
--------------------------------------------------------------------------------
/_fixture/certs/cert.pem:
--------------------------------------------------------------------------------
1 | -----BEGIN CERTIFICATE-----
2 | MIIFODCCAyCgAwIBAgIUaTvDluaMf+VJgYHQ0HFTS3yuCHYwDQYJKoZIhvcNAQEL
3 | BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTIxMDIyNzIxMzQ0MVoXDTQ4MDcx
4 | NDIxMzQ0MVowFDESMBAGA1UEAwwJbG9jYWxob3N0MIICIjANBgkqhkiG9w0BAQEF
5 | AAOCAg8AMIICCgKCAgEAnqyyAAnWFH2TH7Epj5yfZxYrBvizydZe1Wo/1WpGR2IK
6 | QT+qIul5sEKX/ERqEOXsawSrL3fw9cuSM8Z2vD/57ZZdoSR7XIdVaMDEQenJ968a
7 | HObu4D27uBQwIwrM5ELgnd+fC4gis64nIu+2GSfHumZXi7lLW7DbNm8oWkMqI6tY
8 | 2s2wx2hwGYNVJrwSn4WGnkzhQ5U5mkcsLELMx7GR0Qnv6P7sNGZVeqMU7awkcSpR
9 | crKR1OUP7XCJkEq83WLHSx50+QZv7LiyDmGnujHevRbdSHlcFfHZtaufYat+qICe
10 | S3XADwRQe/0VSsmja6u3DAHy7VmL8PNisAdkopQZrhiI9OvGrpGZffs9zn+s/jeX
11 | N1bqVDihCMiEjqXMlHx2oj3AXrZTFxb7y7Ap9C07nf70lpxQWW9SjMYRF98JBiHF
12 | eJbQkNVkmz6T8ielQbX0l46F2SGK98oyFCGNIAZBUdj5CcS1E6w/lk4t58/em0k7
13 | 3wFC5qg0g0wfIbNSmxljBNxnaBYUqyaaAJJhpaEoOebm4RYV58hQ0FbMfpnLnSh4
14 | dYStsk6i1PumWoa7D45DTtxF3kH7TB3YOB5aWaNGAPQC1m4Qcd23YB5Rd/ABirSp
15 | ux6/cFGosjSfJ/G+G0RhNUpmcbDJvFSOhD2WCuieVhCTAzp+VPIA9bSqD+InlT0C
16 | AwEAAaOBgTB/MB0GA1UdDgQWBBQZyM//SvzYKokQZI/0MVGb6PkH+zAfBgNVHSME
17 | GDAWgBQZyM//SvzYKokQZI/0MVGb6PkH+zAPBgNVHRMBAf8EBTADAQH/MCwGA1Ud
18 | EQQlMCOCCWxvY2FsaG9zdIcEfwAAAYcQAAAAAAAAAAAAAAAAAAAAATANBgkqhkiG
19 | 9w0BAQsFAAOCAgEAKGAJQmQ/KLw8iMb5QsyxxAonVjJ1eDAhNM3GWdHpM0/GFamO
20 | vVtATLQQldwDiZJvrsCQPEc8ctZ2Utvg/StLQ3+rZpsvt0+gcUlLJK61qguwYqb2
21 | +T7VK5s7V/OyI/tsuboOW50Pka9vQHV+Z0aM06Yu+HNDAq/UTpEOb/3MQvZd6Ooy
22 | PTpZtFb/+5jIQa1dIsfFWmpBxF0+wUd9GEkX3j7nekwoZfJ8Ze4GWYERZbOFpDAQ
23 | rIHdthH5VJztnpQJmaKqzgIOF+Rurwlp5ecSC33xNNjDaYtuf/fiWnoKGhHVSBhT
24 | 61+0yxn3rTgh/Dsm95xY00rSX6lmcvI+kRNTUc8GGPz0ajBH6xyY7bNhfMjmnSW/
25 | C/XTEDbTAhT7ndWC5vvzp7ZU0TvN+WY6A0f2kxSnnrEk6QRUvRtKkjAkmAFz8exi
26 | ttBBW0I3E5HNIC5CYRimq/9z+3clM/P1KbNblwuC65bL+PZ+nzFnn5hFaK9eLPol
27 | OwZQXv7IvAw8GfgLTrEUT7eBCQwe1IqesA7NTxF1BVwmNUb2XamvQZ7ly67QybRw
28 | 0uJq80XjpVjBWYTTQy1dsnC2OTKdqGsV9TVIDR+UGfIG9cxL70pEbiSH2AX+IDCy
29 | i3kNIvpXgBliAyOjW6Hj1fv6dNfAat/hqEfnquWkfvcs3HNrG/InwpwNAUs=
30 | -----END CERTIFICATE-----
31 |
--------------------------------------------------------------------------------
/_fixture/certs/key.pem:
--------------------------------------------------------------------------------
1 | -----BEGIN PRIVATE KEY-----
2 | MIIJQwIBADANBgkqhkiG9w0BAQEFAASCCS0wggkpAgEAAoICAQCerLIACdYUfZMf
3 | sSmPnJ9nFisG+LPJ1l7Vaj/VakZHYgpBP6oi6XmwQpf8RGoQ5exrBKsvd/D1y5Iz
4 | xna8P/ntll2hJHtch1VowMRB6cn3rxoc5u7gPbu4FDAjCszkQuCd358LiCKzrici
5 | 77YZJ8e6ZleLuUtbsNs2byhaQyojq1jazbDHaHAZg1UmvBKfhYaeTOFDlTmaRyws
6 | QszHsZHRCe/o/uw0ZlV6oxTtrCRxKlFyspHU5Q/tcImQSrzdYsdLHnT5Bm/suLIO
7 | Yae6Md69Ft1IeVwV8dm1q59hq36ogJ5LdcAPBFB7/RVKyaNrq7cMAfLtWYvw82Kw
8 | B2SilBmuGIj068aukZl9+z3Of6z+N5c3VupUOKEIyISOpcyUfHaiPcBetlMXFvvL
9 | sCn0LTud/vSWnFBZb1KMxhEX3wkGIcV4ltCQ1WSbPpPyJ6VBtfSXjoXZIYr3yjIU
10 | IY0gBkFR2PkJxLUTrD+WTi3nz96bSTvfAULmqDSDTB8hs1KbGWME3GdoFhSrJpoA
11 | kmGloSg55ubhFhXnyFDQVsx+mcudKHh1hK2yTqLU+6ZahrsPjkNO3EXeQftMHdg4
12 | HlpZo0YA9ALWbhBx3bdgHlF38AGKtKm7Hr9wUaiyNJ8n8b4bRGE1SmZxsMm8VI6E
13 | PZYK6J5WEJMDOn5U8gD1tKoP4ieVPQIDAQABAoICAEHF2CsH6MOpofi7GT08cR7s
14 | I33KTcxWngzc9ATk/qjMTO/rEf1Sxmx3zkR1n3nNtQhPcR5GG43nin0HwWQbKOCB
15 | OeJ4GuKp/o9jiHbCEEQpQyvD1jUBofSV+bYs3e2ogy8t6OGA1tGgWPy0XMlkoff0
16 | QEnczw3864FO5m0z9h2/Ax//r02ZTw5kUEG0KAwT709jEuVO0AfRhM/8CKKmSola
17 | EyaDtSmrWbdyLlSuzJRUNFrVBno3UTjdM0iqkks6jN3ojBhFwNNhY/1uIXafAXNk
18 | LOnD1JYMIHCb6X809VWnqvYgozIWWb5rlA3iM2mITmId1LLqMYX5fWj2R5LUzSek
19 | H+XG+F9FIouTaL1ACoXr0zyeY5N5YJdyXYa1tThdW+axX9ZrnPgeiQrmxzKPIyb7
20 | LLlVtNBQUg/t5tX80KyYjkNUu4j3oq/uBYPi0m//ovwMyi9bSbbyPT+cDXuXX5Bc
21 | oY7wyn3evXX0c1R7vdJLZLkLu+ctVex/9hvMjeW/mMasDjLnqY7pF3Skct1SX5N2
22 | U8YVU9bGvFpLEwM9lmi/T7bcv+zbmGPlfTsZiFrCsixPLn7sX7y5M4L8au8O0jh0
23 | nHm/8rWVg1Qw0Hobg3tA8FjeMa8Sr2fYmkNLVKFzhuJLxknTJLaUbX5CymNqWP4H
24 | OctvfSY0nSZ1eQpBkQaJAoIBAQDTb/NhYCfaJBLXHVMy/VYd7kWGZ+I87artcE/l
25 | 8u0pJ8XOP4kp0otFIumpHUFodysAeP6HrI79MuJB40fy91HzWZC+NrPufFFFuZ0z
26 | Ld1o3Y5nAeoZmMlf1F12Oe3OQZy7nm9eNNkfeoVtKqDv4FhAqk+aoMor86HscKsR
27 | C6HlZFdGc7kX0ylrQAXPq9KLhcvUU9oAUpbqTbhYK83IebRJgFDG45HkVo9SUHpF
28 | dmCFSb91eZpRGpdfNLCuLiSu52TebayaUCnceeAt8SyeiChJ/TwWmRRDJS0QUv6h
29 | s3Wdp+cx9ANoujA4XzAs8Fld5IZ4bcG5jjwD62/tJyWrCC5DAoIBAQDAHfHjrYCK
30 | GHBrMj+MA7cK7fCJUn/iJLSLGgo2ANYF5oq9gaCwHCtKIyB9DN/KiY0JpJ6PWg+Q
31 | 9Difq23YXiJjNEBS5EFTu9UwWAr1RhSAegrfHxm0sDbcAx31NtDYvBsADCWQYmzc
32 | KPfBshf5K4g/VCIj2VzC2CE6kNtdhqLU6AV2Pi1Tl1S82xWoAjHy91tDmlFQNWCj
33 | B2ZnZ7tY9zuwDfeBBOVCPHICgl5Q4PrY1KEWEXiNxgbtkNmOPAsY9WSqgOsP9pWK
34 | J924gdCCvovINzZtgRisxKth6Fkhra+VCsheg9SWvgR09Deo6CCoSwYxOSb0cjh2
35 | oyX5Rb1kJ7Z/AoIBAQCX2iNVoBV/GcFeNXV3fXLH9ESCj0FwuNC1zp/TanDhyerK
36 | gd8k5k2Xzcc66gP73vpHUJ6dGlVni4/r+ivGV9HHkF/f/LGlaiuEhBZel2YY1mZb
37 | nIhg8dZOuNqW+mvMYlsKdHNPmW0GqpwBF0iWfu1jI+4gA7Kvdj6o7RIvH8eaVEJK
38 | GvqoHcP1fvmteJ2yDtmhGMfMy4QPqtnmmS8l+CJ/V2SsMuyorXIpkBsAoFAZ6ilT
39 | WY53CT4F5nWt4v39j7pl9SatfT1TV0SmOjvtb6Rf3zu0jyR6RMzkmHa/839ZRylI
40 | OxPntzDCi7qxy7yjLmlVPJ6RgZGgzwqHrEHlX+65AoIBAQCEzu6d3x5B2N02LZli
41 | eFr8MjqbI64GLiulEY5HgNJzZ8k3cjocJI0Ehj36VIEMaYRXSzbVkIO8SCgwsPiR
42 | n5mUDNX+t441jV62Odbxcc3Qdw226rABieOSupDmKEu92GOt57e8FV5939BOVYhf
43 | FunsJYQoViXbCEAIVYVgJSfBmNfVwuvgonfQyn8xErtm4/pyRGa71PqGGSKAj2Qi
44 | /16CuVUFGtZFsLV76JW8wZqHdI4bTF6TW3cEmaLbwcRGL7W0bMSS13rO8/pBh3QW
45 | PhUxhoGYt6rQHHEBkPa04nXDyZ10QRwgTSGVnBIyMK4KyTpxorm8OI2x7dzdcomX
46 | iCCPAoIBAETwfr2JKPb/AzrKhhbZgU+sLVn3WH/nb68VheNEmGOzsqXaSHCR2NOq
47 | /ow7bawjc8yUIhBRzokR4F/7jGolOmfdq0MYFb6/YokssKfv1ugxBhmvOxpZ6F6E
48 | cERJ8Ex/ffQU053gLR/0ammddVuS1GR5I/jEdP0lJVh0xapoZNUlT5dWYCgo20hY
49 | ZAmKpU+veyUn+5Li0pmm959vnLK5LJzEA5mpz3w1QPPtVwQs05dwmEV3CRAcCeeh
50 | 8sXp49WNCSW4I3BxuTZzRV845SGIFhZwgVV42PTp2LPKl2p6E7Bk8xpUCCvBpALp
51 | QmA5yIMx+u2Jpr7fUsXEXEPTEhvjff0=
52 | -----END PRIVATE KEY-----
53 |
--------------------------------------------------------------------------------
/_fixture/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/labstack/echo/98ca08e7dd64075b858e758d6693bf9799340756/_fixture/favicon.ico
--------------------------------------------------------------------------------
/_fixture/folder/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | Echo
6 |
7 |
8 |
9 |
10 |
--------------------------------------------------------------------------------
/_fixture/images/walle.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/labstack/echo/98ca08e7dd64075b858e758d6693bf9799340756/_fixture/images/walle.png
--------------------------------------------------------------------------------
/_fixture/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | Echo
6 |
7 |
8 |
9 |
10 |
--------------------------------------------------------------------------------
/binder_external_test.go:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
3 |
4 | // run tests as external package to get real feel for API
5 | package echo_test
6 |
7 | import (
8 | "encoding/base64"
9 | "fmt"
10 | "github.com/labstack/echo/v4"
11 | "log"
12 | "net/http"
13 | "net/http/httptest"
14 | )
15 |
16 | func ExampleValueBinder_BindErrors() {
17 | // example route function that binds query params to different destinations and returns all bind errors in one go
18 | routeFunc := func(c echo.Context) error {
19 | var opts struct {
20 | Active bool
21 | IDs []int64
22 | }
23 | length := int64(50) // default length is 50
24 |
25 | b := echo.QueryParamsBinder(c)
26 |
27 | errs := b.Int64("length", &length).
28 | Int64s("ids", &opts.IDs).
29 | Bool("active", &opts.Active).
30 | BindErrors() // returns all errors
31 | if errs != nil {
32 | for _, err := range errs {
33 | bErr := err.(*echo.BindingError)
34 | log.Printf("in case you want to access what field: %s values: %v failed", bErr.Field, bErr.Values)
35 | }
36 | return fmt.Errorf("%v fields failed to bind", len(errs))
37 | }
38 | fmt.Printf("active = %v, length = %v, ids = %v", opts.Active, length, opts.IDs)
39 |
40 | return c.JSON(http.StatusOK, opts)
41 | }
42 |
43 | e := echo.New()
44 | c := e.NewContext(
45 | httptest.NewRequest(http.MethodGet, "/api/endpoint?active=true&length=25&ids=1&ids=2&ids=3", nil),
46 | httptest.NewRecorder(),
47 | )
48 |
49 | _ = routeFunc(c)
50 |
51 | // Output: active = true, length = 25, ids = [1 2 3]
52 | }
53 |
54 | func ExampleValueBinder_BindError() {
55 | // example route function that binds query params to different destinations and stops binding on first bind error
56 | failFastRouteFunc := func(c echo.Context) error {
57 | var opts struct {
58 | Active bool
59 | IDs []int64
60 | }
61 | length := int64(50) // default length is 50
62 |
63 | // create binder that stops binding at first error
64 | b := echo.QueryParamsBinder(c)
65 |
66 | err := b.Int64("length", &length).
67 | Int64s("ids", &opts.IDs).
68 | Bool("active", &opts.Active).
69 | BindError() // returns first binding error
70 | if err != nil {
71 | bErr := err.(*echo.BindingError)
72 | return fmt.Errorf("my own custom error for field: %s values: %v", bErr.Field, bErr.Values)
73 | }
74 | fmt.Printf("active = %v, length = %v, ids = %v\n", opts.Active, length, opts.IDs)
75 |
76 | return c.JSON(http.StatusOK, opts)
77 | }
78 |
79 | e := echo.New()
80 | c := e.NewContext(
81 | httptest.NewRequest(http.MethodGet, "/api/endpoint?active=true&length=25&ids=1&ids=2&ids=3", nil),
82 | httptest.NewRecorder(),
83 | )
84 |
85 | _ = failFastRouteFunc(c)
86 |
87 | // Output: active = true, length = 25, ids = [1 2 3]
88 | }
89 |
90 | func ExampleValueBinder_CustomFunc() {
91 | // example route function that binds query params using custom function closure
92 | routeFunc := func(c echo.Context) error {
93 | length := int64(50) // default length is 50
94 | var binary []byte
95 |
96 | b := echo.QueryParamsBinder(c)
97 | errs := b.Int64("length", &length).
98 | CustomFunc("base64", func(values []string) []error {
99 | if len(values) == 0 {
100 | return nil
101 | }
102 | decoded, err := base64.URLEncoding.DecodeString(values[0])
103 | if err != nil {
104 | // in this example we use only first param value but url could contain multiple params in reality and
105 | // therefore in theory produce multiple binding errors
106 | return []error{echo.NewBindingError("base64", values[0:1], "failed to decode base64", err)}
107 | }
108 | binary = decoded
109 | return nil
110 | }).
111 | BindErrors() // returns all errors
112 |
113 | if errs != nil {
114 | for _, err := range errs {
115 | bErr := err.(*echo.BindingError)
116 | log.Printf("in case you want to access what field: %s values: %v failed", bErr.Field, bErr.Values)
117 | }
118 | return fmt.Errorf("%v fields failed to bind", len(errs))
119 | }
120 | fmt.Printf("length = %v, base64 = %s", length, binary)
121 |
122 | return c.JSON(http.StatusOK, "ok")
123 | }
124 |
125 | e := echo.New()
126 | c := e.NewContext(
127 | httptest.NewRequest(http.MethodGet, "/api/endpoint?length=25&base64=SGVsbG8gV29ybGQ%3D", nil),
128 | httptest.NewRecorder(),
129 | )
130 | _ = routeFunc(c)
131 |
132 | // Output: length = 25, base64 = Hello World
133 | }
134 |
--------------------------------------------------------------------------------
/codecov.yml:
--------------------------------------------------------------------------------
1 | coverage:
2 | status:
3 | project:
4 | default:
5 | threshold: 1%
6 | patch:
7 | default:
8 | threshold: 1%
9 |
10 | comment:
11 | require_changes: true
--------------------------------------------------------------------------------
/context_fs.go:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
3 |
4 | package echo
5 |
6 | import (
7 | "errors"
8 | "io"
9 | "io/fs"
10 | "net/http"
11 | "path/filepath"
12 | )
13 |
14 | func (c *context) File(file string) error {
15 | return fsFile(c, file, c.echo.Filesystem)
16 | }
17 |
18 | // FileFS serves file from given file system.
19 | //
20 | // When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary
21 | // prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths
22 | // including `assets/images` as their prefix.
23 | func (c *context) FileFS(file string, filesystem fs.FS) error {
24 | return fsFile(c, file, filesystem)
25 | }
26 |
27 | func fsFile(c Context, file string, filesystem fs.FS) error {
28 | f, err := filesystem.Open(file)
29 | if err != nil {
30 | return ErrNotFound
31 | }
32 | defer f.Close()
33 |
34 | fi, _ := f.Stat()
35 | if fi.IsDir() {
36 | file = filepath.ToSlash(filepath.Join(file, indexPage)) // ToSlash is necessary for Windows. fs.Open and os.Open are different in that aspect.
37 | f, err = filesystem.Open(file)
38 | if err != nil {
39 | return ErrNotFound
40 | }
41 | defer f.Close()
42 | if fi, err = f.Stat(); err != nil {
43 | return err
44 | }
45 | }
46 | ff, ok := f.(io.ReadSeeker)
47 | if !ok {
48 | return errors.New("file does not implement io.ReadSeeker")
49 | }
50 | http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), ff)
51 | return nil
52 | }
53 |
--------------------------------------------------------------------------------
/context_fs_test.go:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
3 |
4 | package echo
5 |
6 | import (
7 | "github.com/stretchr/testify/assert"
8 | "io/fs"
9 | "net/http"
10 | "net/http/httptest"
11 | "os"
12 | "testing"
13 | )
14 |
15 | func TestContext_File(t *testing.T) {
16 | var testCases = []struct {
17 | name string
18 | whenFile string
19 | whenFS fs.FS
20 | expectStatus int
21 | expectStartsWith []byte
22 | expectError string
23 | }{
24 | {
25 | name: "ok, from default file system",
26 | whenFile: "_fixture/images/walle.png",
27 | whenFS: nil,
28 | expectStatus: http.StatusOK,
29 | expectStartsWith: []byte{0x89, 0x50, 0x4e},
30 | },
31 | {
32 | name: "ok, from custom file system",
33 | whenFile: "walle.png",
34 | whenFS: os.DirFS("_fixture/images"),
35 | expectStatus: http.StatusOK,
36 | expectStartsWith: []byte{0x89, 0x50, 0x4e},
37 | },
38 | {
39 | name: "nok, not existent file",
40 | whenFile: "not.png",
41 | whenFS: os.DirFS("_fixture/images"),
42 | expectStatus: http.StatusOK,
43 | expectStartsWith: nil,
44 | expectError: "code=404, message=Not Found",
45 | },
46 | }
47 |
48 | for _, tc := range testCases {
49 | t.Run(tc.name, func(t *testing.T) {
50 | e := New()
51 | if tc.whenFS != nil {
52 | e.Filesystem = tc.whenFS
53 | }
54 |
55 | handler := func(ec Context) error {
56 | return ec.(*context).File(tc.whenFile)
57 | }
58 |
59 | req := httptest.NewRequest(http.MethodGet, "/match.png", nil)
60 | rec := httptest.NewRecorder()
61 | c := e.NewContext(req, rec)
62 |
63 | err := handler(c)
64 |
65 | assert.Equal(t, tc.expectStatus, rec.Code)
66 | if tc.expectError != "" {
67 | assert.EqualError(t, err, tc.expectError)
68 | } else {
69 | assert.NoError(t, err)
70 | }
71 |
72 | body := rec.Body.Bytes()
73 | if len(body) > len(tc.expectStartsWith) {
74 | body = body[:len(tc.expectStartsWith)]
75 | }
76 | assert.Equal(t, tc.expectStartsWith, body)
77 | })
78 | }
79 | }
80 |
81 | func TestContext_FileFS(t *testing.T) {
82 | var testCases = []struct {
83 | name string
84 | whenFile string
85 | whenFS fs.FS
86 | expectStatus int
87 | expectStartsWith []byte
88 | expectError string
89 | }{
90 | {
91 | name: "ok",
92 | whenFile: "walle.png",
93 | whenFS: os.DirFS("_fixture/images"),
94 | expectStatus: http.StatusOK,
95 | expectStartsWith: []byte{0x89, 0x50, 0x4e},
96 | },
97 | {
98 | name: "nok, not existent file",
99 | whenFile: "not.png",
100 | whenFS: os.DirFS("_fixture/images"),
101 | expectStatus: http.StatusOK,
102 | expectStartsWith: nil,
103 | expectError: "code=404, message=Not Found",
104 | },
105 | }
106 |
107 | for _, tc := range testCases {
108 | t.Run(tc.name, func(t *testing.T) {
109 | e := New()
110 |
111 | handler := func(ec Context) error {
112 | return ec.(*context).FileFS(tc.whenFile, tc.whenFS)
113 | }
114 |
115 | req := httptest.NewRequest(http.MethodGet, "/match.png", nil)
116 | rec := httptest.NewRecorder()
117 | c := e.NewContext(req, rec)
118 |
119 | err := handler(c)
120 |
121 | assert.Equal(t, tc.expectStatus, rec.Code)
122 | if tc.expectError != "" {
123 | assert.EqualError(t, err, tc.expectError)
124 | } else {
125 | assert.NoError(t, err)
126 | }
127 |
128 | body := rec.Body.Bytes()
129 | if len(body) > len(tc.expectStartsWith) {
130 | body = body[:len(tc.expectStartsWith)]
131 | }
132 | assert.Equal(t, tc.expectStartsWith, body)
133 | })
134 | }
135 | }
136 |
--------------------------------------------------------------------------------
/echo_fs.go:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
3 |
4 | package echo
5 |
6 | import (
7 | "fmt"
8 | "io/fs"
9 | "net/http"
10 | "net/url"
11 | "os"
12 | "path/filepath"
13 | "strings"
14 | )
15 |
// filesystem holds the file system used by Echo's static-file features
// (Static and File handlers).
type filesystem struct {
	// Filesystem is file system used by Static and File handlers to access files.
	// Defaults to os.DirFS(".")
	//
	// When dealing with `embed.FS`, use `fs := echo.MustSubFS(fs, "rootDirectory")` to create a sub fs that applies
	// the necessary prefix to directory paths. This is necessary because `//go:embed assets/images` embeds files with
	// paths that include `assets/images` as their prefix.
	Filesystem fs.FS
}
25 |
26 | func createFilesystem() filesystem {
27 | return filesystem{
28 | Filesystem: newDefaultFS(),
29 | }
30 | }
31 |
32 | // Static registers a new route with path prefix to serve static files from the provided root directory.
33 | func (e *Echo) Static(pathPrefix, fsRoot string) *Route {
34 | subFs := MustSubFS(e.Filesystem, fsRoot)
35 | return e.Add(
36 | http.MethodGet,
37 | pathPrefix+"*",
38 | StaticDirectoryHandler(subFs, false),
39 | )
40 | }
41 |
42 | // StaticFS registers a new route with path prefix to serve static files from the provided file system.
43 | //
44 | // When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary
45 | // prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths
46 | // including `assets/images` as their prefix.
47 | func (e *Echo) StaticFS(pathPrefix string, filesystem fs.FS) *Route {
48 | return e.Add(
49 | http.MethodGet,
50 | pathPrefix+"*",
51 | StaticDirectoryHandler(filesystem, false),
52 | )
53 | }
54 |
55 | // StaticDirectoryHandler creates handler function to serve files from provided file system
56 | // When disablePathUnescaping is set then file name from path is not unescaped and is served as is.
57 | func StaticDirectoryHandler(fileSystem fs.FS, disablePathUnescaping bool) HandlerFunc {
58 | return func(c Context) error {
59 | p := c.Param("*")
60 | if !disablePathUnescaping { // when router is already unescaping we do not want to do is twice
61 | tmpPath, err := url.PathUnescape(p)
62 | if err != nil {
63 | return fmt.Errorf("failed to unescape path variable: %w", err)
64 | }
65 | p = tmpPath
66 | }
67 |
68 | // fs.FS.Open() already assumes that file names are relative to FS root path and considers name with prefix `/` as invalid
69 | name := filepath.ToSlash(filepath.Clean(strings.TrimPrefix(p, "/")))
70 | fi, err := fs.Stat(fileSystem, name)
71 | if err != nil {
72 | return ErrNotFound
73 | }
74 |
75 | // If the request is for a directory and does not end with "/"
76 | p = c.Request().URL.Path // path must not be empty.
77 | if fi.IsDir() && len(p) > 0 && p[len(p)-1] != '/' {
78 | // Redirect to ends with "/"
79 | return c.Redirect(http.StatusMovedPermanently, sanitizeURI(p+"/"))
80 | }
81 | return fsFile(c, name, fileSystem)
82 | }
83 | }
84 |
85 | // FileFS registers a new route with path to serve file from the provided file system.
86 | func (e *Echo) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) *Route {
87 | return e.GET(path, StaticFileHandler(file, filesystem), m...)
88 | }
89 |
90 | // StaticFileHandler creates handler function to serve file from provided file system
91 | func StaticFileHandler(file string, filesystem fs.FS) HandlerFunc {
92 | return func(c Context) error {
93 | return fsFile(c, file, filesystem)
94 | }
95 | }
96 |
// defaultFS exists to preserve the pre-v4.7.0 behaviour, where files were opened with `os.Open`.
// v4.7 introduced the `echo.Filesystem` field, which is the Go 1.16+ `fs.FS` interface.
// The difference between `os.Open` and `fs.FS.Open` is that an fs.FS does not allow opening paths
// that start with `.`, `..` or `/` etc. For example, previously you could use `../images` in your
// application, but `fs := os.DirFS("./")` would not allow `fs.Open("../images")`. That would break
// all old applications that rely on being able to traverse up from the current working directory.
// NB: private because you really should use fs.FS implementation instances.
type defaultFS struct {
	// fs, when non-nil, handles Open; when nil, Open falls back to os.Open.
	fs fs.FS
	// prefix is the directory that subFS joins relative roots onto.
	prefix string
}

// newDefaultFS returns a defaultFS that wraps no fs.FS, so Open goes straight to os.Open.
// Its prefix is the current working directory.
func newDefaultFS() *defaultFS {
	// The os.Getwd error is deliberately ignored. On failure prefix stays "", and subFS then joins
	// relative roots onto an empty prefix, so they stay relative to the working directory.
	dir, _ := os.Getwd()
	return &defaultFS{
		prefix: dir,
		fs:     nil,
	}
}

// Open implements fs.FS. Without a wrapped fs.FS, name is handed to os.Open unchanged, so
// paths such as "../x" or "/etc/hosts" are accepted. Otherwise the call is delegated to the
// wrapped fs.FS.
//
// The receiver is named d rather than fs so that it does not shadow the io/fs package.
func (d defaultFS) Open(name string) (fs.File, error) {
	if d.fs == nil {
		return os.Open(name)
	}
	return d.fs.Open(name)
}
123 |
124 | func subFS(currentFs fs.FS, root string) (fs.FS, error) {
125 | root = filepath.ToSlash(filepath.Clean(root)) // note: fs.FS operates only with slashes. `ToSlash` is necessary for Windows
126 | if dFS, ok := currentFs.(*defaultFS); ok {
127 | // we need to make exception for `defaultFS` instances as it interprets root prefix differently from fs.FS.
128 | // fs.Fs.Open does not like relative paths ("./", "../") and absolute paths at all but prior echo.Filesystem we
129 | // were able to use paths like `./myfile.log`, `/etc/hosts` and these would work fine with `os.Open` but not with fs.Fs
130 | if !filepath.IsAbs(root) {
131 | root = filepath.Join(dFS.prefix, root)
132 | }
133 | return &defaultFS{
134 | prefix: root,
135 | fs: os.DirFS(root),
136 | }, nil
137 | }
138 | return fs.Sub(currentFs, root)
139 | }
140 |
141 | // MustSubFS creates sub FS from current filesystem or panic on failure.
142 | // Panic happens when `fsRoot` contains invalid path according to `fs.ValidPath` rules.
143 | //
144 | // MustSubFS is helpful when dealing with `embed.FS` because for example `//go:embed assets/images` embeds files with
145 | // paths including `assets/images` as their prefix. In that case use `fs := echo.MustSubFS(fs, "rootDirectory") to
146 | // create sub fs which uses necessary prefix for directory path.
147 | func MustSubFS(currentFs fs.FS, fsRoot string) fs.FS {
148 | subFs, err := subFS(currentFs, fsRoot)
149 | if err != nil {
150 | panic(fmt.Errorf("can not create sub FS, invalid root given, err: %w", err))
151 | }
152 | return subFs
153 | }
154 |
// sanitizeURI protects redirect targets against open redirects.
//
// Browsers treat a URI that starts with two slashes in any combination (`//`, `\\`, `/\`, `\/`)
// as pointing to another host. In such a URI, the whole leading run of slashes and backslashes is
// collapsed into a single "/". Any other URI is returned unchanged.
func sanitizeURI(uri string) string {
	isSlash := func(b byte) bool { return b == '/' || b == '\\' }
	if len(uri) < 2 || !isSlash(uri[0]) || !isSlash(uri[1]) {
		return uri
	}
	return "/" + strings.TrimLeft(uri, `/\`)
}
163 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/labstack/echo/v4
2 |
3 | go 1.23.0
4 |
5 | require (
6 | github.com/labstack/gommon v0.4.2
7 | github.com/stretchr/testify v1.10.0
8 | github.com/valyala/fasttemplate v1.2.2
9 | golang.org/x/crypto v0.38.0
10 | golang.org/x/net v0.40.0
11 | golang.org/x/time v0.11.0
12 | )
13 |
14 | require (
15 | github.com/davecgh/go-spew v1.1.1 // indirect
16 | github.com/mattn/go-colorable v0.1.14 // indirect
17 | github.com/mattn/go-isatty v0.0.20 // indirect
18 | github.com/pmezard/go-difflib v1.0.0 // indirect
19 | github.com/valyala/bytebufferpool v1.0.0 // indirect
20 | golang.org/x/sys v0.33.0 // indirect
21 | golang.org/x/text v0.25.0 // indirect
22 | gopkg.in/yaml.v3 v3.0.1 // indirect
23 | )
24 |
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
2 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
3 | github.com/labstack/gommon v0.4.2 h1:F8qTUNXgG1+6WQmqoUWnz8WiEU60mXVVw0P4ht1WRA0=
4 | github.com/labstack/gommon v0.4.2/go.mod h1:QlUFxVM+SNXhDL/Z7YhocGIBYOiwB0mXm1+1bAPHPyU=
5 | github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
6 | github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
7 | github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
8 | github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
9 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
10 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
11 | github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
12 | github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
13 | github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
14 | github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
15 | github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo=
16 | github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ=
17 | golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8=
18 | golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw=
19 | golang.org/x/net v0.40.0 h1:79Xs7wF06Gbdcg4kdCCIQArK11Z1hr5POQ6+fIYHNuY=
20 | golang.org/x/net v0.40.0/go.mod h1:y0hY0exeL2Pku80/zKK7tpntoX23cqL3Oa6njdgRtds=
21 | golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
22 | golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
23 | golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
24 | golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4=
25 | golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA=
26 | golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0=
27 | golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
28 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
29 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
30 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
31 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
32 |
--------------------------------------------------------------------------------
/group.go:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
3 |
4 | package echo
5 |
6 | import (
7 | "net/http"
8 | )
9 |
// Group is a set of sub-routes for a specified route. It can be used for inner
// routes that share a common middleware or functionality that should be separate
// from the parent echo instance while still inheriting from it.
type Group struct {
	// common is embedded for shared helpers such as file (used by File).
	common
	// host is passed to Echo.add for every route registered through Add; sub-groups created
	// with Group copy it from their parent.
	host string
	// prefix is prepended to every path registered through this group.
	prefix string
	// echo is the Echo instance that routes are registered on.
	echo *Echo
	// middleware is the group-level chain; Add places it before any route-level middleware.
	middleware []MiddlewareFunc
}
20 |
21 | // Use implements `Echo#Use()` for sub-routes within the Group.
22 | func (g *Group) Use(middleware ...MiddlewareFunc) {
23 | g.middleware = append(g.middleware, middleware...)
24 | if len(g.middleware) == 0 {
25 | return
26 | }
27 | // group level middlewares are different from Echo `Pre` and `Use` middlewares (those are global). Group level middlewares
28 | // are only executed if they are added to the Router with route.
29 | // So we register catch all route (404 is a safe way to emulate route match) for this group and now during routing the
30 | // Router would find route to match our request path and therefore guarantee the middleware(s) will get executed.
31 | g.RouteNotFound("", NotFoundHandler)
32 | g.RouteNotFound("/*", NotFoundHandler)
33 | }
34 |
35 | // CONNECT implements `Echo#CONNECT()` for sub-routes within the Group.
36 | func (g *Group) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
37 | return g.Add(http.MethodConnect, path, h, m...)
38 | }
39 |
40 | // DELETE implements `Echo#DELETE()` for sub-routes within the Group.
41 | func (g *Group) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
42 | return g.Add(http.MethodDelete, path, h, m...)
43 | }
44 |
45 | // GET implements `Echo#GET()` for sub-routes within the Group.
46 | func (g *Group) GET(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
47 | return g.Add(http.MethodGet, path, h, m...)
48 | }
49 |
50 | // HEAD implements `Echo#HEAD()` for sub-routes within the Group.
51 | func (g *Group) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
52 | return g.Add(http.MethodHead, path, h, m...)
53 | }
54 |
55 | // OPTIONS implements `Echo#OPTIONS()` for sub-routes within the Group.
56 | func (g *Group) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
57 | return g.Add(http.MethodOptions, path, h, m...)
58 | }
59 |
60 | // PATCH implements `Echo#PATCH()` for sub-routes within the Group.
61 | func (g *Group) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
62 | return g.Add(http.MethodPatch, path, h, m...)
63 | }
64 |
65 | // POST implements `Echo#POST()` for sub-routes within the Group.
66 | func (g *Group) POST(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
67 | return g.Add(http.MethodPost, path, h, m...)
68 | }
69 |
70 | // PUT implements `Echo#PUT()` for sub-routes within the Group.
71 | func (g *Group) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
72 | return g.Add(http.MethodPut, path, h, m...)
73 | }
74 |
75 | // TRACE implements `Echo#TRACE()` for sub-routes within the Group.
76 | func (g *Group) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
77 | return g.Add(http.MethodTrace, path, h, m...)
78 | }
79 |
80 | // Any implements `Echo#Any()` for sub-routes within the Group.
81 | func (g *Group) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route {
82 | routes := make([]*Route, len(methods))
83 | for i, m := range methods {
84 | routes[i] = g.Add(m, path, handler, middleware...)
85 | }
86 | return routes
87 | }
88 |
89 | // Match implements `Echo#Match()` for sub-routes within the Group.
90 | func (g *Group) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route {
91 | routes := make([]*Route, len(methods))
92 | for i, m := range methods {
93 | routes[i] = g.Add(m, path, handler, middleware...)
94 | }
95 | return routes
96 | }
97 |
98 | // Group creates a new sub-group with prefix and optional sub-group-level middleware.
99 | func (g *Group) Group(prefix string, middleware ...MiddlewareFunc) (sg *Group) {
100 | m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware))
101 | m = append(m, g.middleware...)
102 | m = append(m, middleware...)
103 | sg = g.echo.Group(g.prefix+prefix, m...)
104 | sg.host = g.host
105 | return
106 | }
107 |
108 | // File implements `Echo#File()` for sub-routes within the Group.
109 | func (g *Group) File(path, file string) {
110 | g.file(path, file, g.GET)
111 | }
112 |
113 | // RouteNotFound implements `Echo#RouteNotFound()` for sub-routes within the Group.
114 | //
115 | // Example: `g.RouteNotFound("/*", func(c echo.Context) error { return c.NoContent(http.StatusNotFound) })`
116 | func (g *Group) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
117 | return g.Add(RouteNotFound, path, h, m...)
118 | }
119 |
120 | // Add implements `Echo#Add()` for sub-routes within the Group.
121 | func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route {
122 | // Combine into a new slice to avoid accidentally passing the same slice for
123 | // multiple routes, which would lead to later add() calls overwriting the
124 | // middleware from earlier calls.
125 | m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware))
126 | m = append(m, g.middleware...)
127 | m = append(m, middleware...)
128 | return g.echo.add(g.host, method, g.prefix+path, handler, m...)
129 | }
130 |
--------------------------------------------------------------------------------
/group_fs.go:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
3 |
4 | package echo
5 |
6 | import (
7 | "io/fs"
8 | "net/http"
9 | )
10 |
11 | // Static implements `Echo#Static()` for sub-routes within the Group.
12 | func (g *Group) Static(pathPrefix, fsRoot string) {
13 | subFs := MustSubFS(g.echo.Filesystem, fsRoot)
14 | g.StaticFS(pathPrefix, subFs)
15 | }
16 |
17 | // StaticFS implements `Echo#StaticFS()` for sub-routes within the Group.
18 | //
19 | // When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary
20 | // prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths
21 | // including `assets/images` as their prefix.
22 | func (g *Group) StaticFS(pathPrefix string, filesystem fs.FS) {
23 | g.Add(
24 | http.MethodGet,
25 | pathPrefix+"*",
26 | StaticDirectoryHandler(filesystem, false),
27 | )
28 | }
29 |
30 | // FileFS implements `Echo#FileFS()` for sub-routes within the Group.
31 | func (g *Group) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) *Route {
32 | return g.GET(path, StaticFileHandler(file, filesystem), m...)
33 | }
34 |
--------------------------------------------------------------------------------
/group_fs_test.go:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
3 |
4 | package echo
5 |
6 | import (
7 | "github.com/stretchr/testify/assert"
8 | "io/fs"
9 | "net/http"
10 | "net/http/httptest"
11 | "os"
12 | "testing"
13 | )
14 |
15 | func TestGroup_FileFS(t *testing.T) {
16 | var testCases = []struct {
17 | name string
18 | whenPath string
19 | whenFile string
20 | whenFS fs.FS
21 | givenURL string
22 | expectCode int
23 | expectStartsWith []byte
24 | }{
25 | {
26 | name: "ok",
27 | whenPath: "/walle",
28 | whenFS: os.DirFS("_fixture/images"),
29 | whenFile: "walle.png",
30 | givenURL: "/assets/walle",
31 | expectCode: http.StatusOK,
32 | expectStartsWith: []byte{0x89, 0x50, 0x4e},
33 | },
34 | {
35 | name: "nok, requesting invalid path",
36 | whenPath: "/walle",
37 | whenFS: os.DirFS("_fixture/images"),
38 | whenFile: "walle.png",
39 | givenURL: "/assets/walle.png",
40 | expectCode: http.StatusNotFound,
41 | expectStartsWith: []byte(`{"message":"Not Found"}`),
42 | },
43 | {
44 | name: "nok, serving not existent file from filesystem",
45 | whenPath: "/walle",
46 | whenFS: os.DirFS("_fixture/images"),
47 | whenFile: "not-existent.png",
48 | givenURL: "/assets/walle",
49 | expectCode: http.StatusNotFound,
50 | expectStartsWith: []byte(`{"message":"Not Found"}`),
51 | },
52 | }
53 |
54 | for _, tc := range testCases {
55 | t.Run(tc.name, func(t *testing.T) {
56 | e := New()
57 | g := e.Group("/assets")
58 | g.FileFS(tc.whenPath, tc.whenFile, tc.whenFS)
59 |
60 | req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil)
61 | rec := httptest.NewRecorder()
62 |
63 | e.ServeHTTP(rec, req)
64 |
65 | assert.Equal(t, tc.expectCode, rec.Code)
66 |
67 | body := rec.Body.Bytes()
68 | if len(body) > len(tc.expectStartsWith) {
69 | body = body[:len(tc.expectStartsWith)]
70 | }
71 | assert.Equal(t, tc.expectStartsWith, body)
72 | })
73 | }
74 | }
75 |
76 | func TestGroup_StaticPanic(t *testing.T) {
77 | var testCases = []struct {
78 | name string
79 | givenRoot string
80 | }{
81 | {
82 | name: "panics for ../",
83 | givenRoot: "../images",
84 | },
85 | {
86 | name: "panics for /",
87 | givenRoot: "/images",
88 | },
89 | }
90 |
91 | for _, tc := range testCases {
92 | t.Run(tc.name, func(t *testing.T) {
93 | e := New()
94 | e.Filesystem = os.DirFS("./")
95 |
96 | g := e.Group("/assets")
97 |
98 | assert.Panics(t, func() {
99 | g.Static("/images", tc.givenRoot)
100 | })
101 | })
102 | }
103 | }
104 |
--------------------------------------------------------------------------------
/group_test.go:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
3 |
4 | package echo
5 |
6 | import (
7 | "net/http"
8 | "net/http/httptest"
9 | "os"
10 | "testing"
11 |
12 | "github.com/stretchr/testify/assert"
13 | )
14 |
15 | // TODO: Fix me
16 | func TestGroup(t *testing.T) {
17 | g := New().Group("/group")
18 | h := func(Context) error { return nil }
19 | g.CONNECT("/", h)
20 | g.DELETE("/", h)
21 | g.GET("/", h)
22 | g.HEAD("/", h)
23 | g.OPTIONS("/", h)
24 | g.PATCH("/", h)
25 | g.POST("/", h)
26 | g.PUT("/", h)
27 | g.TRACE("/", h)
28 | g.Any("/", h)
29 | g.Match([]string{http.MethodGet, http.MethodPost}, "/", h)
30 | g.Static("/static", "/tmp")
31 | g.File("/walle", "_fixture/images//walle.png")
32 | }
33 |
34 | func TestGroupFile(t *testing.T) {
35 | e := New()
36 | g := e.Group("/group")
37 | g.File("/walle", "_fixture/images/walle.png")
38 | expectedData, err := os.ReadFile("_fixture/images/walle.png")
39 | assert.Nil(t, err)
40 | req := httptest.NewRequest(http.MethodGet, "/group/walle", nil)
41 | rec := httptest.NewRecorder()
42 | e.ServeHTTP(rec, req)
43 | assert.Equal(t, http.StatusOK, rec.Code)
44 | assert.Equal(t, expectedData, rec.Body.Bytes())
45 | }
46 |
47 | func TestGroupRouteMiddleware(t *testing.T) {
48 | // Ensure middleware slices are not re-used
49 | e := New()
50 | g := e.Group("/group")
51 | h := func(Context) error { return nil }
52 | m1 := func(next HandlerFunc) HandlerFunc {
53 | return func(c Context) error {
54 | return next(c)
55 | }
56 | }
57 | m2 := func(next HandlerFunc) HandlerFunc {
58 | return func(c Context) error {
59 | return next(c)
60 | }
61 | }
62 | m3 := func(next HandlerFunc) HandlerFunc {
63 | return func(c Context) error {
64 | return next(c)
65 | }
66 | }
67 | m4 := func(next HandlerFunc) HandlerFunc {
68 | return func(c Context) error {
69 | return c.NoContent(404)
70 | }
71 | }
72 | m5 := func(next HandlerFunc) HandlerFunc {
73 | return func(c Context) error {
74 | return c.NoContent(405)
75 | }
76 | }
77 | g.Use(m1, m2, m3)
78 | g.GET("/404", h, m4)
79 | g.GET("/405", h, m5)
80 |
81 | c, _ := request(http.MethodGet, "/group/404", e)
82 | assert.Equal(t, 404, c)
83 | c, _ = request(http.MethodGet, "/group/405", e)
84 | assert.Equal(t, 405, c)
85 | }
86 |
87 | func TestGroupRouteMiddlewareWithMatchAny(t *testing.T) {
88 | // Ensure middleware and match any routes do not conflict
89 | e := New()
90 | g := e.Group("/group")
91 | m1 := func(next HandlerFunc) HandlerFunc {
92 | return func(c Context) error {
93 | return next(c)
94 | }
95 | }
96 | m2 := func(next HandlerFunc) HandlerFunc {
97 | return func(c Context) error {
98 | return c.String(http.StatusOK, c.Path())
99 | }
100 | }
101 | h := func(c Context) error {
102 | return c.String(http.StatusOK, c.Path())
103 | }
104 | g.Use(m1)
105 | g.GET("/help", h, m2)
106 | g.GET("/*", h, m2)
107 | g.GET("", h, m2)
108 | e.GET("unrelated", h, m2)
109 | e.GET("*", h, m2)
110 |
111 | _, m := request(http.MethodGet, "/group/help", e)
112 | assert.Equal(t, "/group/help", m)
113 | _, m = request(http.MethodGet, "/group/help/other", e)
114 | assert.Equal(t, "/group/*", m)
115 | _, m = request(http.MethodGet, "/group/404", e)
116 | assert.Equal(t, "/group/*", m)
117 | _, m = request(http.MethodGet, "/group", e)
118 | assert.Equal(t, "/group", m)
119 | _, m = request(http.MethodGet, "/other", e)
120 | assert.Equal(t, "/*", m)
121 | _, m = request(http.MethodGet, "/", e)
122 | assert.Equal(t, "/*", m)
123 |
124 | }
125 |
126 | func TestGroup_RouteNotFound(t *testing.T) {
127 | var testCases = []struct {
128 | name string
129 | whenURL string
130 | expectRoute interface{}
131 | expectCode int
132 | }{
133 | {
134 | name: "404, route to static not found handler /group/a/c/xx",
135 | whenURL: "/group/a/c/xx",
136 | expectRoute: "GET /group/a/c/xx",
137 | expectCode: http.StatusNotFound,
138 | },
139 | {
140 | name: "404, route to path param not found handler /group/a/:file",
141 | whenURL: "/group/a/echo.exe",
142 | expectRoute: "GET /group/a/:file",
143 | expectCode: http.StatusNotFound,
144 | },
145 | {
146 | name: "404, route to any not found handler /group/*",
147 | whenURL: "/group/b/echo.exe",
148 | expectRoute: "GET /group/*",
149 | expectCode: http.StatusNotFound,
150 | },
151 | {
152 | name: "200, route /group/a/c/df to /group/a/c/df",
153 | whenURL: "/group/a/c/df",
154 | expectRoute: "GET /group/a/c/df",
155 | expectCode: http.StatusOK,
156 | },
157 | }
158 |
159 | for _, tc := range testCases {
160 | t.Run(tc.name, func(t *testing.T) {
161 | e := New()
162 | g := e.Group("/group")
163 |
164 | okHandler := func(c Context) error {
165 | return c.String(http.StatusOK, c.Request().Method+" "+c.Path())
166 | }
167 | notFoundHandler := func(c Context) error {
168 | return c.String(http.StatusNotFound, c.Request().Method+" "+c.Path())
169 | }
170 |
171 | g.GET("/", okHandler)
172 | g.GET("/a/c/df", okHandler)
173 | g.GET("/a/b*", okHandler)
174 | g.PUT("/*", okHandler)
175 |
176 | g.RouteNotFound("/a/c/xx", notFoundHandler) // static
177 | g.RouteNotFound("/a/:file", notFoundHandler) // param
178 | g.RouteNotFound("/*", notFoundHandler) // any
179 |
180 | req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
181 | rec := httptest.NewRecorder()
182 |
183 | e.ServeHTTP(rec, req)
184 |
185 | assert.Equal(t, tc.expectCode, rec.Code)
186 | assert.Equal(t, tc.expectRoute, rec.Body.String())
187 | })
188 | }
189 | }
190 |
191 | func TestGroup_RouteNotFoundWithMiddleware(t *testing.T) {
192 | var testCases = []struct {
193 | name string
194 | givenCustom404 bool
195 | whenURL string
196 | expectBody interface{}
197 | expectCode int
198 | }{
199 | {
200 | name: "ok, custom 404 handler is called with middleware",
201 | givenCustom404: true,
202 | whenURL: "/group/test3",
203 | expectBody: "GET /group/*",
204 | expectCode: http.StatusNotFound,
205 | },
206 | {
207 | name: "ok, default group 404 handler is called with middleware",
208 | givenCustom404: false,
209 | whenURL: "/group/test3",
210 | expectBody: "{\"message\":\"Not Found\"}\n",
211 | expectCode: http.StatusNotFound,
212 | },
213 | {
214 | name: "ok, (no slash) default group 404 handler is called with middleware",
215 | givenCustom404: false,
216 | whenURL: "/group",
217 | expectBody: "{\"message\":\"Not Found\"}\n",
218 | expectCode: http.StatusNotFound,
219 | },
220 | }
221 | for _, tc := range testCases {
222 | t.Run(tc.name, func(t *testing.T) {
223 |
224 | okHandler := func(c Context) error {
225 | return c.String(http.StatusOK, c.Request().Method+" "+c.Path())
226 | }
227 | notFoundHandler := func(c Context) error {
228 | return c.String(http.StatusNotFound, c.Request().Method+" "+c.Path())
229 | }
230 |
231 | e := New()
232 | e.GET("/test1", okHandler)
233 | e.RouteNotFound("/*", notFoundHandler)
234 |
235 | g := e.Group("/group")
236 | g.GET("/test1", okHandler)
237 |
238 | middlewareCalled := false
239 | g.Use(func(next HandlerFunc) HandlerFunc {
240 | return func(c Context) error {
241 | middlewareCalled = true
242 | return next(c)
243 | }
244 | })
245 | if tc.givenCustom404 {
246 | g.RouteNotFound("/*", notFoundHandler)
247 | }
248 |
249 | req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
250 | rec := httptest.NewRecorder()
251 |
252 | e.ServeHTTP(rec, req)
253 |
254 | assert.True(t, middlewareCalled)
255 | assert.Equal(t, tc.expectCode, rec.Code)
256 | assert.Equal(t, tc.expectBody, rec.Body.String())
257 | })
258 | }
259 | }
260 |
--------------------------------------------------------------------------------
/json.go:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
3 |
4 | package echo
5 |
6 | import (
7 | "encoding/json"
8 | "fmt"
9 | "net/http"
10 | )
11 |
12 | // DefaultJSONSerializer implements JSON encoding using encoding/json.
13 | type DefaultJSONSerializer struct{}
14 |
15 | // Serialize converts an interface into a json and writes it to the response.
16 | // You can optionally use the indent parameter to produce pretty JSONs.
17 | func (d DefaultJSONSerializer) Serialize(c Context, i interface{}, indent string) error {
18 | enc := json.NewEncoder(c.Response())
19 | if indent != "" {
20 | enc.SetIndent("", indent)
21 | }
22 | return enc.Encode(i)
23 | }
24 |
25 | // Deserialize reads a JSON from a request body and converts it into an interface.
26 | func (d DefaultJSONSerializer) Deserialize(c Context, i interface{}) error {
27 | err := json.NewDecoder(c.Request().Body).Decode(i)
28 | if ute, ok := err.(*json.UnmarshalTypeError); ok {
29 | return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Unmarshal type error: expected=%v, got=%v, field=%v, offset=%v", ute.Type, ute.Value, ute.Field, ute.Offset)).SetInternal(err)
30 | } else if se, ok := err.(*json.SyntaxError); ok {
31 | return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Syntax error: offset=%v, error=%v", se.Offset, se.Error())).SetInternal(err)
32 | }
33 | return err
34 | }
35 |
--------------------------------------------------------------------------------
/json_test.go:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
3 |
4 | package echo
5 |
6 | import (
7 | "github.com/stretchr/testify/assert"
8 | "net/http"
9 | "net/http/httptest"
10 | "strings"
11 | "testing"
12 | )
13 |
14 | // Note this test is deliberately simple as there's not a lot to test.
15 | // Just need to ensure it writes JSONs. The heavy work is done by the context methods.
16 | func TestDefaultJSONCodec_Encode(t *testing.T) {
17 | e := New()
18 | req := httptest.NewRequest(http.MethodPost, "/", nil)
19 | rec := httptest.NewRecorder()
20 | c := e.NewContext(req, rec).(*context)
21 |
22 | // Echo
23 | assert.Equal(t, e, c.Echo())
24 |
25 | // Request
26 | assert.NotNil(t, c.Request())
27 |
28 | // Response
29 | assert.NotNil(t, c.Response())
30 |
31 | //--------
32 | // Default JSON encoder
33 | //--------
34 |
35 | enc := new(DefaultJSONSerializer)
36 |
37 | err := enc.Serialize(c, user{1, "Jon Snow"}, "")
38 | if assert.NoError(t, err) {
39 | assert.Equal(t, userJSON+"\n", rec.Body.String())
40 | }
41 |
42 | req = httptest.NewRequest(http.MethodPost, "/", nil)
43 | rec = httptest.NewRecorder()
44 | c = e.NewContext(req, rec).(*context)
45 | err = enc.Serialize(c, user{1, "Jon Snow"}, " ")
46 | if assert.NoError(t, err) {
47 | assert.Equal(t, userJSONPretty+"\n", rec.Body.String())
48 | }
49 | }
50 |
51 | // Note this test is deliberately simple as there's not a lot to test.
52 | // Just need to ensure it writes JSONs. The heavy work is done by the context methods.
53 | func TestDefaultJSONCodec_Decode(t *testing.T) {
54 | e := New()
55 | req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
56 | rec := httptest.NewRecorder()
57 | c := e.NewContext(req, rec).(*context)
58 |
59 | // Echo
60 | assert.Equal(t, e, c.Echo())
61 |
62 | // Request
63 | assert.NotNil(t, c.Request())
64 |
65 | // Response
66 | assert.NotNil(t, c.Response())
67 |
68 | //--------
69 | // Default JSON encoder
70 | //--------
71 |
72 | enc := new(DefaultJSONSerializer)
73 |
74 | var u = user{}
75 | err := enc.Deserialize(c, &u)
76 | if assert.NoError(t, err) {
77 | assert.Equal(t, u, user{ID: 1, Name: "Jon Snow"})
78 | }
79 |
80 | var userUnmarshalSyntaxError = user{}
81 | req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(invalidContent))
82 | rec = httptest.NewRecorder()
83 | c = e.NewContext(req, rec).(*context)
84 | err = enc.Deserialize(c, &userUnmarshalSyntaxError)
85 | assert.IsType(t, &HTTPError{}, err)
86 | assert.EqualError(t, err, "code=400, message=Syntax error: offset=1, error=invalid character 'i' looking for beginning of value, internal=invalid character 'i' looking for beginning of value")
87 |
88 | var userUnmarshalTypeError = struct {
89 | ID string `json:"id"`
90 | Name string `json:"name"`
91 | }{}
92 |
93 | req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
94 | rec = httptest.NewRecorder()
95 | c = e.NewContext(req, rec).(*context)
96 | err = enc.Deserialize(c, &userUnmarshalTypeError)
97 | assert.IsType(t, &HTTPError{}, err)
98 | assert.EqualError(t, err, "code=400, message=Unmarshal type error: expected=string, got=number, field=id, offset=7, internal=json: cannot unmarshal number into Go struct field .id of type string")
99 |
100 | }
101 |
--------------------------------------------------------------------------------
/log.go:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
3 |
4 | package echo
5 |
6 | import (
7 | "github.com/labstack/gommon/log"
8 | "io"
9 | )
10 |
11 | // Logger defines the logging interface.
12 | type Logger interface {
13 | Output() io.Writer
14 | SetOutput(w io.Writer)
15 | Prefix() string
16 | SetPrefix(p string)
17 | Level() log.Lvl
18 | SetLevel(v log.Lvl)
19 | SetHeader(h string)
20 | Print(i ...interface{})
21 | Printf(format string, args ...interface{})
22 | Printj(j log.JSON)
23 | Debug(i ...interface{})
24 | Debugf(format string, args ...interface{})
25 | Debugj(j log.JSON)
26 | Info(i ...interface{})
27 | Infof(format string, args ...interface{})
28 | Infoj(j log.JSON)
29 | Warn(i ...interface{})
30 | Warnf(format string, args ...interface{})
31 | Warnj(j log.JSON)
32 | Error(i ...interface{})
33 | Errorf(format string, args ...interface{})
34 | Errorj(j log.JSON)
35 | Fatal(i ...interface{})
36 | Fatalj(j log.JSON)
37 | Fatalf(format string, args ...interface{})
38 | Panic(i ...interface{})
39 | Panicj(j log.JSON)
40 | Panicf(format string, args ...interface{})
41 | }
42 |
--------------------------------------------------------------------------------
/middleware/basic_auth.go:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
3 |
4 | package middleware
5 |
6 | import (
7 | "encoding/base64"
8 | "net/http"
9 | "strconv"
10 | "strings"
11 |
12 | "github.com/labstack/echo/v4"
13 | )
14 |
15 | // BasicAuthConfig defines the config for BasicAuth middleware.
16 | type BasicAuthConfig struct {
17 | // Skipper defines a function to skip middleware.
18 | Skipper Skipper
19 |
20 | // Validator is a function to validate BasicAuth credentials.
21 | // Required.
22 | Validator BasicAuthValidator
23 |
24 | // Realm is a string to define realm attribute of BasicAuth.
25 | // Default value "Restricted".
26 | Realm string
27 | }
28 |
29 | // BasicAuthValidator defines a function to validate BasicAuth credentials.
30 | // The function should return a boolean indicating whether the credentials are valid,
31 | // and an error if any error occurs during the validation process.
32 | type BasicAuthValidator func(string, string, echo.Context) (bool, error)
33 |
34 | const (
35 | basic = "basic"
36 | defaultRealm = "Restricted"
37 | )
38 |
39 | // DefaultBasicAuthConfig is the default BasicAuth middleware config.
40 | var DefaultBasicAuthConfig = BasicAuthConfig{
41 | Skipper: DefaultSkipper,
42 | Realm: defaultRealm,
43 | }
44 |
45 | // BasicAuth returns an BasicAuth middleware.
46 | //
47 | // For valid credentials it calls the next handler.
48 | // For missing or invalid credentials, it sends "401 - Unauthorized" response.
49 | func BasicAuth(fn BasicAuthValidator) echo.MiddlewareFunc {
50 | c := DefaultBasicAuthConfig
51 | c.Validator = fn
52 | return BasicAuthWithConfig(c)
53 | }
54 |
55 | // BasicAuthWithConfig returns an BasicAuth middleware with config.
56 | // See `BasicAuth()`.
57 | func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
58 | // Defaults
59 | if config.Validator == nil {
60 | panic("echo: basic-auth middleware requires a validator function")
61 | }
62 | if config.Skipper == nil {
63 | config.Skipper = DefaultBasicAuthConfig.Skipper
64 | }
65 | if config.Realm == "" {
66 | config.Realm = defaultRealm
67 | }
68 |
69 | return func(next echo.HandlerFunc) echo.HandlerFunc {
70 | return func(c echo.Context) error {
71 | if config.Skipper(c) {
72 | return next(c)
73 | }
74 |
75 | auth := c.Request().Header.Get(echo.HeaderAuthorization)
76 | l := len(basic)
77 |
78 | if len(auth) > l+1 && strings.EqualFold(auth[:l], basic) {
79 | // Invalid base64 shouldn't be treated as error
80 | // instead should be treated as invalid client input
81 | b, err := base64.StdEncoding.DecodeString(auth[l+1:])
82 | if err != nil {
83 | return echo.NewHTTPError(http.StatusBadRequest).SetInternal(err)
84 | }
85 |
86 | cred := string(b)
87 | for i := 0; i < len(cred); i++ {
88 | if cred[i] == ':' {
89 | // Verify credentials
90 | valid, err := config.Validator(cred[:i], cred[i+1:], c)
91 | if err != nil {
92 | return err
93 | } else if valid {
94 | return next(c)
95 | }
96 | break
97 | }
98 | }
99 | }
100 |
101 | realm := defaultRealm
102 | if config.Realm != defaultRealm {
103 | realm = strconv.Quote(config.Realm)
104 | }
105 |
106 | // Need to return `401` for browsers to pop-up login box.
107 | c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm="+realm)
108 | return echo.ErrUnauthorized
109 | }
110 | }
111 | }
112 |
--------------------------------------------------------------------------------
/middleware/basic_auth_test.go:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
3 |
4 | package middleware
5 |
6 | import (
7 | "encoding/base64"
8 | "errors"
9 | "net/http"
10 | "net/http/httptest"
11 | "strings"
12 | "testing"
13 |
14 | "github.com/labstack/echo/v4"
15 | "github.com/stretchr/testify/assert"
16 | )
17 |
18 | func TestBasicAuth(t *testing.T) {
19 | e := echo.New()
20 |
21 | mockValidator := func(u, p string, c echo.Context) (bool, error) {
22 | if u == "joe" && p == "secret" {
23 | return true, nil
24 | }
25 | return false, nil
26 | }
27 |
28 | tests := []struct {
29 | name string
30 | authHeader string
31 | expectedCode int
32 | expectedAuth string
33 | skipperResult bool
34 | expectedErr bool
35 | expectedErrMsg string
36 | }{
37 | {
38 | name: "Valid credentials",
39 | authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")),
40 | expectedCode: http.StatusOK,
41 | },
42 | {
43 | name: "Case-insensitive header scheme",
44 | authHeader: strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")),
45 | expectedCode: http.StatusOK,
46 | },
47 | {
48 | name: "Invalid credentials",
49 | authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password")),
50 | expectedCode: http.StatusUnauthorized,
51 | expectedAuth: basic + ` realm="someRealm"`,
52 | expectedErr: true,
53 | expectedErrMsg: "Unauthorized",
54 | },
55 | {
56 | name: "Invalid base64 string",
57 | authHeader: basic + " invalidString",
58 | expectedCode: http.StatusBadRequest,
59 | expectedErr: true,
60 | expectedErrMsg: "Bad Request",
61 | },
62 | {
63 | name: "Missing Authorization header",
64 | expectedCode: http.StatusUnauthorized,
65 | expectedErr: true,
66 | expectedErrMsg: "Unauthorized",
67 | },
68 | {
69 | name: "Invalid Authorization header",
70 | authHeader: base64.StdEncoding.EncodeToString([]byte("invalid")),
71 | expectedCode: http.StatusUnauthorized,
72 | expectedErr: true,
73 | expectedErrMsg: "Unauthorized",
74 | },
75 | {
76 | name: "Skipped Request",
77 | authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:skip")),
78 | expectedCode: http.StatusOK,
79 | skipperResult: true,
80 | },
81 | }
82 |
83 | for _, tt := range tests {
84 | t.Run(tt.name, func(t *testing.T) {
85 |
86 | req := httptest.NewRequest(http.MethodGet, "/", nil)
87 | res := httptest.NewRecorder()
88 | c := e.NewContext(req, res)
89 |
90 | if tt.authHeader != "" {
91 | req.Header.Set(echo.HeaderAuthorization, tt.authHeader)
92 | }
93 |
94 | h := BasicAuthWithConfig(BasicAuthConfig{
95 | Validator: mockValidator,
96 | Realm: "someRealm",
97 | Skipper: func(c echo.Context) bool {
98 | return tt.skipperResult
99 | },
100 | })(func(c echo.Context) error {
101 | return c.String(http.StatusOK, "test")
102 | })
103 |
104 | err := h(c)
105 |
106 | if tt.expectedErr {
107 | var he *echo.HTTPError
108 | errors.As(err, &he)
109 | assert.Equal(t, tt.expectedCode, he.Code)
110 | if tt.expectedAuth != "" {
111 | assert.Equal(t, tt.expectedAuth, res.Header().Get(echo.HeaderWWWAuthenticate))
112 | }
113 | } else {
114 | assert.NoError(t, err)
115 | assert.Equal(t, tt.expectedCode, res.Code)
116 | }
117 | })
118 | }
119 | }
120 |
--------------------------------------------------------------------------------
/middleware/body_dump.go:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
3 |
4 | package middleware
5 |
6 | import (
7 | "bufio"
8 | "bytes"
9 | "errors"
10 | "io"
11 | "net"
12 | "net/http"
13 |
14 | "github.com/labstack/echo/v4"
15 | )
16 |
17 | // BodyDumpConfig defines the config for BodyDump middleware.
18 | type BodyDumpConfig struct {
19 | // Skipper defines a function to skip middleware.
20 | Skipper Skipper
21 |
22 | // Handler receives request and response payload.
23 | // Required.
24 | Handler BodyDumpHandler
25 | }
26 |
27 | // BodyDumpHandler receives the request and response payload.
28 | type BodyDumpHandler func(echo.Context, []byte, []byte)
29 |
30 | type bodyDumpResponseWriter struct {
31 | io.Writer
32 | http.ResponseWriter
33 | }
34 |
35 | // DefaultBodyDumpConfig is the default BodyDump middleware config.
36 | var DefaultBodyDumpConfig = BodyDumpConfig{
37 | Skipper: DefaultSkipper,
38 | }
39 |
40 | // BodyDump returns a BodyDump middleware.
41 | //
42 | // BodyDump middleware captures the request and response payload and calls the
43 | // registered handler.
44 | func BodyDump(handler BodyDumpHandler) echo.MiddlewareFunc {
45 | c := DefaultBodyDumpConfig
46 | c.Handler = handler
47 | return BodyDumpWithConfig(c)
48 | }
49 |
50 | // BodyDumpWithConfig returns a BodyDump middleware with config.
51 | // See: `BodyDump()`.
52 | func BodyDumpWithConfig(config BodyDumpConfig) echo.MiddlewareFunc {
53 | // Defaults
54 | if config.Handler == nil {
55 | panic("echo: body-dump middleware requires a handler function")
56 | }
57 | if config.Skipper == nil {
58 | config.Skipper = DefaultBodyDumpConfig.Skipper
59 | }
60 |
61 | return func(next echo.HandlerFunc) echo.HandlerFunc {
62 | return func(c echo.Context) (err error) {
63 | if config.Skipper(c) {
64 | return next(c)
65 | }
66 |
67 | // Request
68 | reqBody := []byte{}
69 | if c.Request().Body != nil { // Read
70 | reqBody, _ = io.ReadAll(c.Request().Body)
71 | }
72 | c.Request().Body = io.NopCloser(bytes.NewBuffer(reqBody)) // Reset
73 |
74 | // Response
75 | resBody := new(bytes.Buffer)
76 | mw := io.MultiWriter(c.Response().Writer, resBody)
77 | writer := &bodyDumpResponseWriter{Writer: mw, ResponseWriter: c.Response().Writer}
78 | c.Response().Writer = writer
79 |
80 | if err = next(c); err != nil {
81 | c.Error(err)
82 | }
83 |
84 | // Callback
85 | config.Handler(c, reqBody, resBody.Bytes())
86 |
87 | return
88 | }
89 | }
90 | }
91 |
92 | func (w *bodyDumpResponseWriter) WriteHeader(code int) {
93 | w.ResponseWriter.WriteHeader(code)
94 | }
95 |
96 | func (w *bodyDumpResponseWriter) Write(b []byte) (int, error) {
97 | return w.Writer.Write(b)
98 | }
99 |
100 | func (w *bodyDumpResponseWriter) Flush() {
101 | err := http.NewResponseController(w.ResponseWriter).Flush()
102 | if err != nil && errors.Is(err, http.ErrNotSupported) {
103 | panic(errors.New("response writer flushing is not supported"))
104 | }
105 | }
106 |
107 | func (w *bodyDumpResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
108 | return http.NewResponseController(w.ResponseWriter).Hijack()
109 | }
110 |
111 | func (w *bodyDumpResponseWriter) Unwrap() http.ResponseWriter {
112 | return w.ResponseWriter
113 | }
114 |
--------------------------------------------------------------------------------
/middleware/body_dump_test.go:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
3 |
4 | package middleware
5 |
6 | import (
7 | "errors"
8 | "io"
9 | "net/http"
10 | "net/http/httptest"
11 | "strings"
12 | "testing"
13 |
14 | "github.com/labstack/echo/v4"
15 | "github.com/stretchr/testify/assert"
16 | )
17 |
18 | func TestBodyDump(t *testing.T) {
19 | e := echo.New()
20 | hw := "Hello, World!"
21 | req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(hw))
22 | rec := httptest.NewRecorder()
23 | c := e.NewContext(req, rec)
24 | h := func(c echo.Context) error {
25 | body, err := io.ReadAll(c.Request().Body)
26 | if err != nil {
27 | return err
28 | }
29 | return c.String(http.StatusOK, string(body))
30 | }
31 |
32 | requestBody := ""
33 | responseBody := ""
34 | mw := BodyDump(func(c echo.Context, reqBody, resBody []byte) {
35 | requestBody = string(reqBody)
36 | responseBody = string(resBody)
37 | })
38 |
39 | if assert.NoError(t, mw(h)(c)) {
40 | assert.Equal(t, requestBody, hw)
41 | assert.Equal(t, responseBody, hw)
42 | assert.Equal(t, http.StatusOK, rec.Code)
43 | assert.Equal(t, hw, rec.Body.String())
44 | }
45 |
46 | // Must set default skipper
47 | BodyDumpWithConfig(BodyDumpConfig{
48 | Skipper: nil,
49 | Handler: func(c echo.Context, reqBody, resBody []byte) {
50 | requestBody = string(reqBody)
51 | responseBody = string(resBody)
52 | },
53 | })
54 | }
55 |
56 | func TestBodyDumpFails(t *testing.T) {
57 | e := echo.New()
58 | hw := "Hello, World!"
59 | req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(hw))
60 | rec := httptest.NewRecorder()
61 | c := e.NewContext(req, rec)
62 | h := func(c echo.Context) error {
63 | return errors.New("some error")
64 | }
65 |
66 | mw := BodyDump(func(c echo.Context, reqBody, resBody []byte) {})
67 |
68 | if !assert.Error(t, mw(h)(c)) {
69 | t.FailNow()
70 | }
71 |
72 | assert.Panics(t, func() {
73 | mw = BodyDumpWithConfig(BodyDumpConfig{
74 | Skipper: nil,
75 | Handler: nil,
76 | })
77 | })
78 |
79 | assert.NotPanics(t, func() {
80 | mw = BodyDumpWithConfig(BodyDumpConfig{
81 | Skipper: func(c echo.Context) bool {
82 | return true
83 | },
84 | Handler: func(c echo.Context, reqBody, resBody []byte) {
85 | },
86 | })
87 |
88 | if !assert.Error(t, mw(h)(c)) {
89 | t.FailNow()
90 | }
91 | })
92 | }
93 |
94 | func TestBodyDumpResponseWriter_CanNotFlush(t *testing.T) {
95 | bdrw := bodyDumpResponseWriter{
96 | ResponseWriter: new(testResponseWriterNoFlushHijack), // this RW does not support flush
97 | }
98 |
99 | assert.PanicsWithError(t, "response writer flushing is not supported", func() {
100 | bdrw.Flush()
101 | })
102 | }
103 |
104 | func TestBodyDumpResponseWriter_CanFlush(t *testing.T) {
105 | trwu := testResponseWriterUnwrapperHijack{testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()}}
106 | bdrw := bodyDumpResponseWriter{
107 | ResponseWriter: &trwu,
108 | }
109 |
110 | bdrw.Flush()
111 | assert.Equal(t, 1, trwu.unwrapCalled)
112 | }
113 |
114 | func TestBodyDumpResponseWriter_CanUnwrap(t *testing.T) {
115 | trwu := &testResponseWriterUnwrapper{rw: httptest.NewRecorder()}
116 | bdrw := bodyDumpResponseWriter{
117 | ResponseWriter: trwu,
118 | }
119 |
120 | result := bdrw.Unwrap()
121 | assert.Equal(t, trwu, result)
122 | }
123 |
124 | func TestBodyDumpResponseWriter_CanHijack(t *testing.T) {
125 | trwu := testResponseWriterUnwrapperHijack{testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()}}
126 | bdrw := bodyDumpResponseWriter{
127 | ResponseWriter: &trwu, // this RW supports hijacking through unwrapping
128 | }
129 |
130 | _, _, err := bdrw.Hijack()
131 | assert.EqualError(t, err, "can hijack")
132 | }
133 |
134 | func TestBodyDumpResponseWriter_CanNotHijack(t *testing.T) {
135 | trwu := testResponseWriterUnwrapper{rw: httptest.NewRecorder()}
136 | bdrw := bodyDumpResponseWriter{
137 | ResponseWriter: &trwu, // this RW supports hijacking through unwrapping
138 | }
139 |
140 | _, _, err := bdrw.Hijack()
141 | assert.EqualError(t, err, "feature not supported")
142 | }
143 |
--------------------------------------------------------------------------------
/middleware/body_limit.go:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
3 |
4 | package middleware
5 |
6 | import (
7 | "fmt"
8 | "io"
9 | "sync"
10 |
11 | "github.com/labstack/echo/v4"
12 | "github.com/labstack/gommon/bytes"
13 | )
14 |
15 | // BodyLimitConfig defines the config for BodyLimit middleware.
16 | type BodyLimitConfig struct {
17 | // Skipper defines a function to skip middleware.
18 | Skipper Skipper
19 |
20 | // Maximum allowed size for a request body, it can be specified
21 | // as `4x` or `4xB`, where x is one of the multiple from K, M, G, T or P.
22 | Limit string `yaml:"limit"`
23 | limit int64
24 | }
25 |
26 | type limitedReader struct {
27 | BodyLimitConfig
28 | reader io.ReadCloser
29 | read int64
30 | }
31 |
32 | // DefaultBodyLimitConfig is the default BodyLimit middleware config.
33 | var DefaultBodyLimitConfig = BodyLimitConfig{
34 | Skipper: DefaultSkipper,
35 | }
36 |
37 | // BodyLimit returns a BodyLimit middleware.
38 | //
39 | // BodyLimit middleware sets the maximum allowed size for a request body, if the
40 | // size exceeds the configured limit, it sends "413 - Request Entity Too Large"
41 | // response. The BodyLimit is determined based on both `Content-Length` request
42 | // header and actual content read, which makes it super secure.
43 | // Limit can be specified as `4x` or `4xB`, where x is one of the multiple from K, M,
44 | // G, T or P.
45 | func BodyLimit(limit string) echo.MiddlewareFunc {
46 | c := DefaultBodyLimitConfig
47 | c.Limit = limit
48 | return BodyLimitWithConfig(c)
49 | }
50 |
51 | // BodyLimitWithConfig returns a BodyLimit middleware with config.
52 | // See: `BodyLimit()`.
53 | func BodyLimitWithConfig(config BodyLimitConfig) echo.MiddlewareFunc {
54 | // Defaults
55 | if config.Skipper == nil {
56 | config.Skipper = DefaultBodyLimitConfig.Skipper
57 | }
58 |
59 | limit, err := bytes.Parse(config.Limit)
60 | if err != nil {
61 | panic(fmt.Errorf("echo: invalid body-limit=%s", config.Limit))
62 | }
63 | config.limit = limit
64 | pool := limitedReaderPool(config)
65 |
66 | return func(next echo.HandlerFunc) echo.HandlerFunc {
67 | return func(c echo.Context) error {
68 | if config.Skipper(c) {
69 | return next(c)
70 | }
71 |
72 | req := c.Request()
73 |
74 | // Based on content length
75 | if req.ContentLength > config.limit {
76 | return echo.ErrStatusRequestEntityTooLarge
77 | }
78 |
79 | // Based on content read
80 | r := pool.Get().(*limitedReader)
81 | r.Reset(req.Body)
82 | defer pool.Put(r)
83 | req.Body = r
84 |
85 | return next(c)
86 | }
87 | }
88 | }
89 |
90 | func (r *limitedReader) Read(b []byte) (n int, err error) {
91 | n, err = r.reader.Read(b)
92 | r.read += int64(n)
93 | if r.read > r.limit {
94 | return n, echo.ErrStatusRequestEntityTooLarge
95 | }
96 | return
97 | }
98 |
99 | func (r *limitedReader) Close() error {
100 | return r.reader.Close()
101 | }
102 |
103 | func (r *limitedReader) Reset(reader io.ReadCloser) {
104 | r.reader = reader
105 | r.read = 0
106 | }
107 |
108 | func limitedReaderPool(c BodyLimitConfig) sync.Pool {
109 | return sync.Pool{
110 | New: func() interface{} {
111 | return &limitedReader{BodyLimitConfig: c}
112 | },
113 | }
114 | }
115 |
--------------------------------------------------------------------------------
/middleware/body_limit_test.go:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
3 |
4 | package middleware
5 |
6 | import (
7 | "bytes"
8 | "io"
9 | "net/http"
10 | "net/http/httptest"
11 | "testing"
12 |
13 | "github.com/labstack/echo/v4"
14 | "github.com/stretchr/testify/assert"
15 | )
16 |
17 | func TestBodyLimit(t *testing.T) {
18 | e := echo.New()
19 | hw := []byte("Hello, World!")
20 | req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
21 | rec := httptest.NewRecorder()
22 | c := e.NewContext(req, rec)
23 | h := func(c echo.Context) error {
24 | body, err := io.ReadAll(c.Request().Body)
25 | if err != nil {
26 | return err
27 | }
28 | return c.String(http.StatusOK, string(body))
29 | }
30 |
31 | // Based on content length (within limit)
32 | if assert.NoError(t, BodyLimit("2M")(h)(c)) {
33 | assert.Equal(t, http.StatusOK, rec.Code)
34 | assert.Equal(t, hw, rec.Body.Bytes())
35 | }
36 |
37 | // Based on content length (overlimit)
38 | he := BodyLimit("2B")(h)(c).(*echo.HTTPError)
39 | assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code)
40 |
41 | // Based on content read (within limit)
42 | req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
43 | req.ContentLength = -1
44 | rec = httptest.NewRecorder()
45 | c = e.NewContext(req, rec)
46 | if assert.NoError(t, BodyLimit("2M")(h)(c)) {
47 | assert.Equal(t, http.StatusOK, rec.Code)
48 | assert.Equal(t, "Hello, World!", rec.Body.String())
49 | }
50 |
51 | // Based on content read (overlimit)
52 | req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
53 | req.ContentLength = -1
54 | rec = httptest.NewRecorder()
55 | c = e.NewContext(req, rec)
56 | he = BodyLimit("2B")(h)(c).(*echo.HTTPError)
57 | assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code)
58 | }
59 |
60 | func TestBodyLimitReader(t *testing.T) {
61 | hw := []byte("Hello, World!")
62 |
63 | config := BodyLimitConfig{
64 | Skipper: DefaultSkipper,
65 | Limit: "2B",
66 | limit: 2,
67 | }
68 | reader := &limitedReader{
69 | BodyLimitConfig: config,
70 | reader: io.NopCloser(bytes.NewReader(hw)),
71 | }
72 |
73 | // read all should return ErrStatusRequestEntityTooLarge
74 | _, err := io.ReadAll(reader)
75 | he := err.(*echo.HTTPError)
76 | assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code)
77 |
78 | // reset reader and read two bytes must succeed
79 | bt := make([]byte, 2)
80 | reader.Reset(io.NopCloser(bytes.NewReader(hw)))
81 | n, err := reader.Read(bt)
82 | assert.Equal(t, 2, n)
83 | assert.Equal(t, nil, err)
84 | }
85 |
86 | func TestBodyLimitWithConfig_Skipper(t *testing.T) {
87 | e := echo.New()
88 | h := func(c echo.Context) error {
89 | body, err := io.ReadAll(c.Request().Body)
90 | if err != nil {
91 | return err
92 | }
93 | return c.String(http.StatusOK, string(body))
94 | }
95 | mw := BodyLimitWithConfig(BodyLimitConfig{
96 | Skipper: func(c echo.Context) bool {
97 | return true
98 | },
99 | Limit: "2B", // if not skipped this limit would make request to fail limit check
100 | })
101 |
102 | hw := []byte("Hello, World!")
103 | req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
104 | rec := httptest.NewRecorder()
105 | c := e.NewContext(req, rec)
106 |
107 | err := mw(h)(c)
108 | assert.NoError(t, err)
109 | assert.Equal(t, http.StatusOK, rec.Code)
110 | assert.Equal(t, hw, rec.Body.Bytes())
111 | }
112 |
113 | func TestBodyLimitWithConfig(t *testing.T) {
114 | var testCases = []struct {
115 | name string
116 | givenLimit string
117 | whenBody []byte
118 | expectBody []byte
119 | expectError string
120 | }{
121 | {
122 | name: "ok, body is less than limit",
123 | givenLimit: "10B",
124 | whenBody: []byte("123456789"),
125 | expectBody: []byte("123456789"),
126 | expectError: "",
127 | },
128 | {
129 | name: "nok, body is more than limit",
130 | givenLimit: "9B",
131 | whenBody: []byte("1234567890"),
132 | expectBody: []byte(nil),
133 | expectError: "code=413, message=Request Entity Too Large",
134 | },
135 | }
136 |
137 | for _, tc := range testCases {
138 | t.Run(tc.name, func(t *testing.T) {
139 | e := echo.New()
140 | h := func(c echo.Context) error {
141 | body, err := io.ReadAll(c.Request().Body)
142 | if err != nil {
143 | return err
144 | }
145 | return c.String(http.StatusOK, string(body))
146 | }
147 | mw := BodyLimitWithConfig(BodyLimitConfig{
148 | Limit: tc.givenLimit,
149 | })
150 |
151 | req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(tc.whenBody))
152 | rec := httptest.NewRecorder()
153 | c := e.NewContext(req, rec)
154 |
155 | err := mw(h)(c)
156 | if tc.expectError != "" {
157 | assert.EqualError(t, err, tc.expectError)
158 | } else {
159 | assert.NoError(t, err)
160 | }
161 | // not testing status as middlewares return error instead of committing it and OK cases are anyway 200
162 | assert.Equal(t, tc.expectBody, rec.Body.Bytes())
163 | })
164 | }
165 | }
166 |
167 | func TestBodyLimit_panicOnInvalidLimit(t *testing.T) {
168 | assert.PanicsWithError(
169 | t,
170 | "echo: invalid body-limit=",
171 | func() { BodyLimit("") },
172 | )
173 | }
174 |
--------------------------------------------------------------------------------
/middleware/compress.go:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
3 |
4 | package middleware
5 |
6 | import (
7 | "bufio"
8 | "bytes"
9 | "compress/gzip"
10 | "io"
11 | "net"
12 | "net/http"
13 | "strings"
14 | "sync"
15 |
16 | "github.com/labstack/echo/v4"
17 | )
18 |
19 | // GzipConfig defines the config for Gzip middleware.
20 | type GzipConfig struct {
21 | // Skipper defines a function to skip middleware.
22 | Skipper Skipper
23 |
24 | // Gzip compression level.
25 | // Optional. Default value -1.
26 | Level int `yaml:"level"`
27 |
28 | // Length threshold before gzip compression is applied.
29 | // Optional. Default value 0.
30 | //
31 | // Most of the time you will not need to change the default. Compressing
32 | // a short response might increase the transmitted data because of the
33 | // gzip format overhead. Compressing the response will also consume CPU
34 | // and time on the server and the client (for decompressing). Depending on
35 | // your use case such a threshold might be useful.
36 | //
37 | // See also:
38 | // https://webmasters.stackexchange.com/questions/31750/what-is-recommended-minimum-object-size-for-gzip-performance-benefits
39 | MinLength int
40 | }
41 |
42 | type gzipResponseWriter struct {
43 | io.Writer
44 | http.ResponseWriter
45 | wroteHeader bool
46 | wroteBody bool
47 | minLength int
48 | minLengthExceeded bool
49 | buffer *bytes.Buffer
50 | code int
51 | }
52 |
53 | const (
54 | gzipScheme = "gzip"
55 | )
56 |
57 | // DefaultGzipConfig is the default Gzip middleware config.
58 | var DefaultGzipConfig = GzipConfig{
59 | Skipper: DefaultSkipper,
60 | Level: -1,
61 | MinLength: 0,
62 | }
63 |
64 | // Gzip returns a middleware which compresses HTTP response using gzip compression
65 | // scheme.
66 | func Gzip() echo.MiddlewareFunc {
67 | return GzipWithConfig(DefaultGzipConfig)
68 | }
69 |
70 | // GzipWithConfig return Gzip middleware with config.
71 | // See: `Gzip()`.
72 | func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
73 | // Defaults
74 | if config.Skipper == nil {
75 | config.Skipper = DefaultGzipConfig.Skipper
76 | }
77 | if config.Level == 0 {
78 | config.Level = DefaultGzipConfig.Level
79 | }
80 | if config.MinLength < 0 {
81 | config.MinLength = DefaultGzipConfig.MinLength
82 | }
83 |
84 | pool := gzipCompressPool(config)
85 | bpool := bufferPool()
86 |
87 | return func(next echo.HandlerFunc) echo.HandlerFunc {
88 | return func(c echo.Context) error {
89 | if config.Skipper(c) {
90 | return next(c)
91 | }
92 |
93 | res := c.Response()
94 | res.Header().Add(echo.HeaderVary, echo.HeaderAcceptEncoding)
95 | if strings.Contains(c.Request().Header.Get(echo.HeaderAcceptEncoding), gzipScheme) {
96 | i := pool.Get()
97 | w, ok := i.(*gzip.Writer)
98 | if !ok {
99 | return echo.NewHTTPError(http.StatusInternalServerError, i.(error).Error())
100 | }
101 | rw := res.Writer
102 | w.Reset(rw)
103 |
104 | buf := bpool.Get().(*bytes.Buffer)
105 | buf.Reset()
106 |
107 | grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw, minLength: config.MinLength, buffer: buf}
108 | defer func() {
109 | // There are different reasons for cases when we have not yet written response to the client and now need to do so.
110 | // a) handler response had only response code and no response body (ala 404 or redirects etc). Response code need to be written now.
111 | // b) body is shorter than our minimum length threshold and being buffered currently and needs to be written
112 | if !grw.wroteBody {
113 | if res.Header().Get(echo.HeaderContentEncoding) == gzipScheme {
114 | res.Header().Del(echo.HeaderContentEncoding)
115 | }
116 | if grw.wroteHeader {
117 | rw.WriteHeader(grw.code)
118 | }
119 | // We have to reset response to it's pristine state when
120 | // nothing is written to body or error is returned.
121 | // See issue #424, #407.
122 | res.Writer = rw
123 | w.Reset(io.Discard)
124 | } else if !grw.minLengthExceeded {
125 | // Write uncompressed response
126 | res.Writer = rw
127 | if grw.wroteHeader {
128 | grw.ResponseWriter.WriteHeader(grw.code)
129 | }
130 | grw.buffer.WriteTo(rw)
131 | w.Reset(io.Discard)
132 | }
133 | w.Close()
134 | bpool.Put(buf)
135 | pool.Put(w)
136 | }()
137 | res.Writer = grw
138 | }
139 | return next(c)
140 | }
141 | }
142 | }
143 |
144 | func (w *gzipResponseWriter) WriteHeader(code int) {
145 | w.Header().Del(echo.HeaderContentLength) // Issue #444
146 |
147 | w.wroteHeader = true
148 |
149 | // Delay writing of the header until we know if we'll actually compress the response
150 | w.code = code
151 | }
152 |
153 | func (w *gzipResponseWriter) Write(b []byte) (int, error) {
154 | if w.Header().Get(echo.HeaderContentType) == "" {
155 | w.Header().Set(echo.HeaderContentType, http.DetectContentType(b))
156 | }
157 | w.wroteBody = true
158 |
159 | if !w.minLengthExceeded {
160 | n, err := w.buffer.Write(b)
161 |
162 | if w.buffer.Len() >= w.minLength {
163 | w.minLengthExceeded = true
164 |
165 | // The minimum length is exceeded, add Content-Encoding header and write the header
166 | w.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806
167 | if w.wroteHeader {
168 | w.ResponseWriter.WriteHeader(w.code)
169 | }
170 |
171 | return w.Writer.Write(w.buffer.Bytes())
172 | }
173 |
174 | return n, err
175 | }
176 |
177 | return w.Writer.Write(b)
178 | }
179 |
180 | func (w *gzipResponseWriter) Flush() {
181 | if !w.minLengthExceeded {
182 | // Enforce compression because we will not know how much more data will come
183 | w.minLengthExceeded = true
184 | w.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806
185 | if w.wroteHeader {
186 | w.ResponseWriter.WriteHeader(w.code)
187 | }
188 |
189 | w.Writer.Write(w.buffer.Bytes())
190 | }
191 |
192 | w.Writer.(*gzip.Writer).Flush()
193 | _ = http.NewResponseController(w.ResponseWriter).Flush()
194 | }
195 |
196 | func (w *gzipResponseWriter) Unwrap() http.ResponseWriter {
197 | return w.ResponseWriter
198 | }
199 |
200 | func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
201 | return http.NewResponseController(w.ResponseWriter).Hijack()
202 | }
203 |
204 | func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error {
205 | if p, ok := w.ResponseWriter.(http.Pusher); ok {
206 | return p.Push(target, opts)
207 | }
208 | return http.ErrNotSupported
209 | }
210 |
211 | func gzipCompressPool(config GzipConfig) sync.Pool {
212 | return sync.Pool{
213 | New: func() interface{} {
214 | w, err := gzip.NewWriterLevel(io.Discard, config.Level)
215 | if err != nil {
216 | return err
217 | }
218 | return w
219 | },
220 | }
221 | }
222 |
// bufferPool returns a pool handing out *bytes.Buffer values; freshly created
// buffers are empty.
func bufferPool() sync.Pool {
	return sync.Pool{
		New: func() interface{} { return new(bytes.Buffer) },
	}
}
231 |
--------------------------------------------------------------------------------
/middleware/context_timeout.go:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
3 |
4 | package middleware
5 |
6 | import (
7 | "context"
8 | "errors"
9 | "time"
10 |
11 | "github.com/labstack/echo/v4"
12 | )
13 |
// ContextTimeoutConfig defines the config for ContextTimeout middleware.
type ContextTimeoutConfig struct {
	// Skipper defines a function to skip middleware.
	// Optional. When nil, DefaultSkipper is used.
	Skipper Skipper

	// ErrorHandler is called with any non-nil error returned by the next handler
	// in the chain; its return value is what the middleware returns.
	// Optional. When nil, errors matching context.DeadlineExceeded are turned into
	// a 503 Service Unavailable error and all other errors are returned unchanged.
	ErrorHandler func(err error, c echo.Context) error

	// Timeout is applied to the request context with context.WithTimeout.
	// Required: a zero value makes ToMiddleware return an error (and
	// ContextTimeoutWithConfig panic).
	Timeout time.Duration
}
25 |
26 | // ContextTimeout returns a middleware which returns error (503 Service Unavailable error) to client
27 | // when underlying method returns context.DeadlineExceeded error.
28 | func ContextTimeout(timeout time.Duration) echo.MiddlewareFunc {
29 | return ContextTimeoutWithConfig(ContextTimeoutConfig{Timeout: timeout})
30 | }
31 |
32 | // ContextTimeoutWithConfig returns a Timeout middleware with config.
33 | func ContextTimeoutWithConfig(config ContextTimeoutConfig) echo.MiddlewareFunc {
34 | mw, err := config.ToMiddleware()
35 | if err != nil {
36 | panic(err)
37 | }
38 | return mw
39 | }
40 |
41 | // ToMiddleware converts Config to middleware.
42 | func (config ContextTimeoutConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
43 | if config.Timeout == 0 {
44 | return nil, errors.New("timeout must be set")
45 | }
46 | if config.Skipper == nil {
47 | config.Skipper = DefaultSkipper
48 | }
49 | if config.ErrorHandler == nil {
50 | config.ErrorHandler = func(err error, c echo.Context) error {
51 | if err != nil && errors.Is(err, context.DeadlineExceeded) {
52 | return echo.ErrServiceUnavailable.WithInternal(err)
53 | }
54 | return err
55 | }
56 | }
57 |
58 | return func(next echo.HandlerFunc) echo.HandlerFunc {
59 | return func(c echo.Context) error {
60 | if config.Skipper(c) {
61 | return next(c)
62 | }
63 |
64 | timeoutContext, cancel := context.WithTimeout(c.Request().Context(), config.Timeout)
65 | defer cancel()
66 |
67 | c.SetRequest(c.Request().WithContext(timeoutContext))
68 |
69 | if err := next(c); err != nil {
70 | return config.ErrorHandler(err, c)
71 | }
72 | return nil
73 | }
74 | }, nil
75 | }
76 |
--------------------------------------------------------------------------------
/middleware/context_timeout_test.go:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
3 |
4 | package middleware
5 |
6 | import (
7 | "context"
8 | "errors"
9 | "net/http"
10 | "net/http/httptest"
11 | "net/url"
12 | "strings"
13 | "testing"
14 | "time"
15 |
16 | "github.com/labstack/echo/v4"
17 | "github.com/stretchr/testify/assert"
18 | )
19 |
20 | func TestContextTimeoutSkipper(t *testing.T) {
21 | t.Parallel()
22 | m := ContextTimeoutWithConfig(ContextTimeoutConfig{
23 | Skipper: func(context echo.Context) bool {
24 | return true
25 | },
26 | Timeout: 10 * time.Millisecond,
27 | })
28 |
29 | req := httptest.NewRequest(http.MethodGet, "/", nil)
30 | rec := httptest.NewRecorder()
31 |
32 | e := echo.New()
33 | c := e.NewContext(req, rec)
34 |
35 | err := m(func(c echo.Context) error {
36 | if err := sleepWithContext(c.Request().Context(), time.Duration(20*time.Millisecond)); err != nil {
37 | return err
38 | }
39 |
40 | return errors.New("response from handler")
41 | })(c)
42 |
43 | // if not skipped we would have not returned error due context timeout logic
44 | assert.EqualError(t, err, "response from handler")
45 | }
46 |
47 | func TestContextTimeoutWithTimeout0(t *testing.T) {
48 | t.Parallel()
49 | assert.Panics(t, func() {
50 | ContextTimeout(time.Duration(0))
51 | })
52 | }
53 |
54 | func TestContextTimeoutErrorOutInHandler(t *testing.T) {
55 | t.Parallel()
56 | m := ContextTimeoutWithConfig(ContextTimeoutConfig{
57 | // Timeout has to be defined or the whole flow for timeout middleware will be skipped
58 | Timeout: 10 * time.Millisecond,
59 | })
60 |
61 | req := httptest.NewRequest(http.MethodGet, "/", nil)
62 | rec := httptest.NewRecorder()
63 |
64 | e := echo.New()
65 | c := e.NewContext(req, rec)
66 |
67 | rec.Code = 1 // we want to be sure that even 200 will not be sent
68 | err := m(func(c echo.Context) error {
69 | // this error must not be written to the client response. Middlewares upstream of timeout middleware must be able
70 | // to handle returned error and this can be done only then handler has not yet committed (written status code)
71 | // the response.
72 | return echo.NewHTTPError(http.StatusTeapot, "err")
73 | })(c)
74 |
75 | assert.Error(t, err)
76 | assert.EqualError(t, err, "code=418, message=err")
77 | assert.Equal(t, 1, rec.Code)
78 | assert.Equal(t, "", rec.Body.String())
79 | }
80 |
81 | func TestContextTimeoutSuccessfulRequest(t *testing.T) {
82 | t.Parallel()
83 | m := ContextTimeoutWithConfig(ContextTimeoutConfig{
84 | // Timeout has to be defined or the whole flow for timeout middleware will be skipped
85 | Timeout: 10 * time.Millisecond,
86 | })
87 |
88 | req := httptest.NewRequest(http.MethodGet, "/", nil)
89 | rec := httptest.NewRecorder()
90 |
91 | e := echo.New()
92 | c := e.NewContext(req, rec)
93 |
94 | err := m(func(c echo.Context) error {
95 | return c.JSON(http.StatusCreated, map[string]string{"data": "ok"})
96 | })(c)
97 |
98 | assert.NoError(t, err)
99 | assert.Equal(t, http.StatusCreated, rec.Code)
100 | assert.Equal(t, "{\"data\":\"ok\"}\n", rec.Body.String())
101 | }
102 |
103 | func TestContextTimeoutTestRequestClone(t *testing.T) {
104 | t.Parallel()
105 | req := httptest.NewRequest(http.MethodPost, "/uri?query=value", strings.NewReader(url.Values{"form": {"value"}}.Encode()))
106 | req.AddCookie(&http.Cookie{Name: "cookie", Value: "value"})
107 | req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
108 | rec := httptest.NewRecorder()
109 |
110 | m := ContextTimeoutWithConfig(ContextTimeoutConfig{
111 | // Timeout has to be defined or the whole flow for timeout middleware will be skipped
112 | Timeout: 1 * time.Second,
113 | })
114 |
115 | e := echo.New()
116 | c := e.NewContext(req, rec)
117 |
118 | err := m(func(c echo.Context) error {
119 | // Cookie test
120 | cookie, err := c.Request().Cookie("cookie")
121 | if assert.NoError(t, err) {
122 | assert.EqualValues(t, "cookie", cookie.Name)
123 | assert.EqualValues(t, "value", cookie.Value)
124 | }
125 |
126 | // Form values
127 | if assert.NoError(t, c.Request().ParseForm()) {
128 | assert.EqualValues(t, "value", c.Request().FormValue("form"))
129 | }
130 |
131 | // Query string
132 | assert.EqualValues(t, "value", c.Request().URL.Query()["query"][0])
133 | return nil
134 | })(c)
135 |
136 | assert.NoError(t, err)
137 | }
138 |
139 | func TestContextTimeoutWithDefaultErrorMessage(t *testing.T) {
140 | t.Parallel()
141 |
142 | timeout := 10 * time.Millisecond
143 | m := ContextTimeoutWithConfig(ContextTimeoutConfig{
144 | Timeout: timeout,
145 | })
146 |
147 | req := httptest.NewRequest(http.MethodGet, "/", nil)
148 | rec := httptest.NewRecorder()
149 |
150 | e := echo.New()
151 | c := e.NewContext(req, rec)
152 |
153 | err := m(func(c echo.Context) error {
154 | if err := sleepWithContext(c.Request().Context(), time.Duration(80*time.Millisecond)); err != nil {
155 | return err
156 | }
157 | return c.String(http.StatusOK, "Hello, World!")
158 | })(c)
159 |
160 | assert.IsType(t, &echo.HTTPError{}, err)
161 | assert.Error(t, err)
162 | assert.Equal(t, http.StatusServiceUnavailable, err.(*echo.HTTPError).Code)
163 | assert.Equal(t, "Service Unavailable", err.(*echo.HTTPError).Message)
164 | }
165 |
166 | func TestContextTimeoutCanHandleContextDeadlineOnNextHandler(t *testing.T) {
167 | t.Parallel()
168 |
169 | timeoutErrorHandler := func(err error, c echo.Context) error {
170 | if err != nil {
171 | if errors.Is(err, context.DeadlineExceeded) {
172 | return &echo.HTTPError{
173 | Code: http.StatusServiceUnavailable,
174 | Message: "Timeout! change me",
175 | }
176 | }
177 | return err
178 | }
179 | return nil
180 | }
181 |
182 | timeout := 50 * time.Millisecond
183 | m := ContextTimeoutWithConfig(ContextTimeoutConfig{
184 | Timeout: timeout,
185 | ErrorHandler: timeoutErrorHandler,
186 | })
187 |
188 | req := httptest.NewRequest(http.MethodGet, "/", nil)
189 | rec := httptest.NewRecorder()
190 |
191 | e := echo.New()
192 | c := e.NewContext(req, rec)
193 |
194 | err := m(func(c echo.Context) error {
195 | // NOTE: Very short periods are not reliable for tests due to Go routine scheduling and the unpredictable order
196 | // for 1) request and 2) time goroutine. For most OS this works as expected, but MacOS seems most flaky.
197 |
198 | if err := sleepWithContext(c.Request().Context(), 100*time.Millisecond); err != nil {
199 | return err
200 | }
201 |
202 | // The Request Context should have a Deadline set by http.ContextTimeoutHandler
203 | if _, ok := c.Request().Context().Deadline(); !ok {
204 | assert.Fail(t, "No timeout set on Request Context")
205 | }
206 | return c.String(http.StatusOK, "Hello, World!")
207 | })(c)
208 |
209 | assert.IsType(t, &echo.HTTPError{}, err)
210 | assert.Error(t, err)
211 | assert.Equal(t, http.StatusServiceUnavailable, err.(*echo.HTTPError).Code)
212 | assert.Equal(t, "Timeout! change me", err.(*echo.HTTPError).Message)
213 | }
214 |
// sleepWithContext blocks for d or until ctx is done, whichever comes first.
// It returns nil when the full duration elapsed and ctx.Err() otherwise:
// context.DeadlineExceeded when the deadline passed, context.Canceled when the
// context was cancelled.
func sleepWithContext(ctx context.Context, d time.Duration) error {
	timer := time.NewTimer(d)
	defer timer.Stop()

	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-timer.C:
		return nil
	}
}
229 |
--------------------------------------------------------------------------------
/middleware/csrf.go:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
3 |
4 | package middleware
5 |
6 | import (
7 | "crypto/subtle"
8 | "net/http"
9 | "time"
10 |
11 | "github.com/labstack/echo/v4"
12 | )
13 |
// CSRFConfig defines the config for CSRF middleware.
type CSRFConfig struct {
	// Skipper defines a function to skip middleware.
	Skipper Skipper

	// TokenLength is the length of the generated token.
	// Optional. Default value 32.
	TokenLength uint8 `yaml:"token_length"`

	// TokenLookup is a string in the form of "{source}:{name}" or
	// "{source}:{name},{source}:{name}" that is used to extract the token from
	// the request.
	// Optional. Default value "header:X-CSRF-Token".
	// Possible values:
	// - "header:{name}" or "header:{name}:{cut-prefix}"
	// - "query:{name}"
	// - "form:{name}"
	// Multiple sources example:
	// - "header:X-CSRF-Token,query:csrf"
	TokenLookup string `yaml:"token_lookup"`

	// Context key to store generated CSRF token into context.
	// Optional. Default value "csrf".
	ContextKey string `yaml:"context_key"`

	// Name of the CSRF cookie. This cookie will store CSRF token.
	// Optional. Default value "_csrf".
	CookieName string `yaml:"cookie_name"`

	// Domain of the CSRF cookie.
	// Optional. Default value none.
	CookieDomain string `yaml:"cookie_domain"`

	// Path of the CSRF cookie.
	// Optional. Default value none.
	CookiePath string `yaml:"cookie_path"`

	// Max age (in seconds) of the CSRF cookie; used to compute the cookie's
	// Expires attribute.
	// Optional. Default value 86400 (24hr).
	CookieMaxAge int `yaml:"cookie_max_age"`

	// Indicates if CSRF cookie is secure.
	// Optional. Default value false, but forced to true when CookieSameSite is
	// http.SameSiteNoneMode.
	CookieSecure bool `yaml:"cookie_secure"`

	// Indicates if CSRF cookie is HTTP only.
	// Optional. Default value false.
	CookieHTTPOnly bool `yaml:"cookie_http_only"`

	// Indicates SameSite mode of the CSRF cookie.
	// Optional. Default value SameSiteDefaultMode, in which case the cookie's
	// SameSite attribute is left unset.
	CookieSameSite http.SameSite `yaml:"cookie_same_site"`

	// ErrorHandler is called with the validation error when a CSRF check fails;
	// its return value is returned by the middleware.
	// Optional. When nil, the validation error itself is returned.
	ErrorHandler CSRFErrorHandler
}
69 |
70 | // CSRFErrorHandler is a function which is executed for creating custom errors.
71 | type CSRFErrorHandler func(err error, c echo.Context) error
72 |
73 | // ErrCSRFInvalid is returned when CSRF check fails
74 | var ErrCSRFInvalid = echo.NewHTTPError(http.StatusForbidden, "invalid csrf token")
75 |
76 | // DefaultCSRFConfig is the default CSRF middleware config.
77 | var DefaultCSRFConfig = CSRFConfig{
78 | Skipper: DefaultSkipper,
79 | TokenLength: 32,
80 | TokenLookup: "header:" + echo.HeaderXCSRFToken,
81 | ContextKey: "csrf",
82 | CookieName: "_csrf",
83 | CookieMaxAge: 86400,
84 | CookieSameSite: http.SameSiteDefaultMode,
85 | }
86 |
87 | // CSRF returns a Cross-Site Request Forgery (CSRF) middleware.
88 | // See: https://en.wikipedia.org/wiki/Cross-site_request_forgery
89 | func CSRF() echo.MiddlewareFunc {
90 | c := DefaultCSRFConfig
91 | return CSRFWithConfig(c)
92 | }
93 |
94 | // CSRFWithConfig returns a CSRF middleware with config.
95 | // See `CSRF()`.
96 | func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
97 | // Defaults
98 | if config.Skipper == nil {
99 | config.Skipper = DefaultCSRFConfig.Skipper
100 | }
101 | if config.TokenLength == 0 {
102 | config.TokenLength = DefaultCSRFConfig.TokenLength
103 | }
104 |
105 | if config.TokenLookup == "" {
106 | config.TokenLookup = DefaultCSRFConfig.TokenLookup
107 | }
108 | if config.ContextKey == "" {
109 | config.ContextKey = DefaultCSRFConfig.ContextKey
110 | }
111 | if config.CookieName == "" {
112 | config.CookieName = DefaultCSRFConfig.CookieName
113 | }
114 | if config.CookieMaxAge == 0 {
115 | config.CookieMaxAge = DefaultCSRFConfig.CookieMaxAge
116 | }
117 | if config.CookieSameSite == http.SameSiteNoneMode {
118 | config.CookieSecure = true
119 | }
120 |
121 | extractors, cErr := CreateExtractors(config.TokenLookup)
122 | if cErr != nil {
123 | panic(cErr)
124 | }
125 |
126 | return func(next echo.HandlerFunc) echo.HandlerFunc {
127 | return func(c echo.Context) error {
128 | if config.Skipper(c) {
129 | return next(c)
130 | }
131 |
132 | token := ""
133 | if k, err := c.Cookie(config.CookieName); err != nil {
134 | token = randomString(config.TokenLength)
135 | } else {
136 | token = k.Value // Reuse token
137 | }
138 |
139 | switch c.Request().Method {
140 | case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
141 | default:
142 | // Validate token only for requests which are not defined as 'safe' by RFC7231
143 | var lastExtractorErr error
144 | var lastTokenErr error
145 | outer:
146 | for _, extractor := range extractors {
147 | clientTokens, err := extractor(c)
148 | if err != nil {
149 | lastExtractorErr = err
150 | continue
151 | }
152 |
153 | for _, clientToken := range clientTokens {
154 | if validateCSRFToken(token, clientToken) {
155 | lastTokenErr = nil
156 | lastExtractorErr = nil
157 | break outer
158 | }
159 | lastTokenErr = ErrCSRFInvalid
160 | }
161 | }
162 | var finalErr error
163 | if lastTokenErr != nil {
164 | finalErr = lastTokenErr
165 | } else if lastExtractorErr != nil {
166 | // ugly part to preserve backwards compatible errors. someone could rely on them
167 | if lastExtractorErr == errQueryExtractorValueMissing {
168 | lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in the query string")
169 | } else if lastExtractorErr == errFormExtractorValueMissing {
170 | lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in the form parameter")
171 | } else if lastExtractorErr == errHeaderExtractorValueMissing {
172 | lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in request header")
173 | } else {
174 | lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, lastExtractorErr.Error())
175 | }
176 | finalErr = lastExtractorErr
177 | }
178 |
179 | if finalErr != nil {
180 | if config.ErrorHandler != nil {
181 | return config.ErrorHandler(finalErr, c)
182 | }
183 | return finalErr
184 | }
185 | }
186 |
187 | // Set CSRF cookie
188 | cookie := new(http.Cookie)
189 | cookie.Name = config.CookieName
190 | cookie.Value = token
191 | if config.CookiePath != "" {
192 | cookie.Path = config.CookiePath
193 | }
194 | if config.CookieDomain != "" {
195 | cookie.Domain = config.CookieDomain
196 | }
197 | if config.CookieSameSite != http.SameSiteDefaultMode {
198 | cookie.SameSite = config.CookieSameSite
199 | }
200 | cookie.Expires = time.Now().Add(time.Duration(config.CookieMaxAge) * time.Second)
201 | cookie.Secure = config.CookieSecure
202 | cookie.HttpOnly = config.CookieHTTPOnly
203 | c.SetCookie(cookie)
204 |
205 | // Store token in the context
206 | c.Set(config.ContextKey, token)
207 |
208 | // Protect clients from caching the response
209 | c.Response().Header().Add(echo.HeaderVary, echo.HeaderCookie)
210 |
211 | return next(c)
212 | }
213 | }
214 | }
215 |
// validateCSRFToken reports whether clientToken equals token. The comparison
// time does not depend on the token contents (only on their lengths), so that
// response timing cannot reveal how much of a guessed token was correct.
func validateCSRFToken(token, clientToken string) bool {
	expected, actual := []byte(token), []byte(clientToken)
	return subtle.ConstantTimeCompare(expected, actual) == 1
}
219 |
--------------------------------------------------------------------------------
/middleware/decompress.go:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
3 |
4 | package middleware
5 |
6 | import (
7 | "compress/gzip"
8 | "io"
9 | "net/http"
10 | "sync"
11 |
12 | "github.com/labstack/echo/v4"
13 | )
14 |
// DecompressConfig defines the config for Decompress middleware.
type DecompressConfig struct {
	// Skipper defines a function to skip middleware.
	// Optional. When nil, DecompressWithConfig substitutes a default skipper.
	Skipper Skipper

	// GzipDecompressPool provides the sync.Pool used to create/store gzip readers.
	// Optional. When nil, DefaultDecompressConfig.GzipDecompressPool is used.
	GzipDecompressPool Decompressor
}
23 |
24 | // GZIPEncoding content-encoding header if set to "gzip", decompress body contents.
25 | const GZIPEncoding string = "gzip"
26 |
27 | // Decompressor is used to get the sync.Pool used by the middleware to get Gzip readers
28 | type Decompressor interface {
29 | gzipDecompressPool() sync.Pool
30 | }
31 |
32 | // DefaultDecompressConfig defines the config for decompress middleware
33 | var DefaultDecompressConfig = DecompressConfig{
34 | Skipper: DefaultSkipper,
35 | GzipDecompressPool: &DefaultGzipDecompressPool{},
36 | }
37 |
38 | // DefaultGzipDecompressPool is the default implementation of Decompressor interface
39 | type DefaultGzipDecompressPool struct {
40 | }
41 |
42 | func (d *DefaultGzipDecompressPool) gzipDecompressPool() sync.Pool {
43 | return sync.Pool{New: func() interface{} { return new(gzip.Reader) }}
44 | }
45 |
46 | // Decompress decompresses request body based if content encoding type is set to "gzip" with default config
47 | func Decompress() echo.MiddlewareFunc {
48 | return DecompressWithConfig(DefaultDecompressConfig)
49 | }
50 |
51 | // DecompressWithConfig decompresses request body based if content encoding type is set to "gzip" with config
52 | func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc {
53 | // Defaults
54 | if config.Skipper == nil {
55 | config.Skipper = DefaultGzipConfig.Skipper
56 | }
57 | if config.GzipDecompressPool == nil {
58 | config.GzipDecompressPool = DefaultDecompressConfig.GzipDecompressPool
59 | }
60 |
61 | return func(next echo.HandlerFunc) echo.HandlerFunc {
62 | pool := config.GzipDecompressPool.gzipDecompressPool()
63 |
64 | return func(c echo.Context) error {
65 | if config.Skipper(c) {
66 | return next(c)
67 | }
68 |
69 | if c.Request().Header.Get(echo.HeaderContentEncoding) != GZIPEncoding {
70 | return next(c)
71 | }
72 |
73 | i := pool.Get()
74 | gr, ok := i.(*gzip.Reader)
75 | if !ok || gr == nil {
76 | return echo.NewHTTPError(http.StatusInternalServerError, i.(error).Error())
77 | }
78 | defer pool.Put(gr)
79 |
80 | b := c.Request().Body
81 | defer b.Close()
82 |
83 | if err := gr.Reset(b); err != nil {
84 | if err == io.EOF { //ignore if body is empty
85 | return next(c)
86 | }
87 | return err
88 | }
89 |
90 | // only Close gzip reader if it was set to a proper gzip source otherwise it will panic on close.
91 | defer gr.Close()
92 |
93 | c.Request().Body = gr
94 |
95 | return next(c)
96 | }
97 | }
98 | }
99 |
--------------------------------------------------------------------------------
/middleware/decompress_test.go:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
3 |
4 | package middleware
5 |
6 | import (
7 | "bytes"
8 | "compress/gzip"
9 | "errors"
10 | "io"
11 | "net/http"
12 | "net/http/httptest"
13 | "strings"
14 | "sync"
15 | "testing"
16 |
17 | "github.com/labstack/echo/v4"
18 | "github.com/stretchr/testify/assert"
19 | )
20 |
21 | func TestDecompress(t *testing.T) {
22 | e := echo.New()
23 | req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
24 | rec := httptest.NewRecorder()
25 | c := e.NewContext(req, rec)
26 |
27 | // Skip if no Content-Encoding header
28 | h := Decompress()(func(c echo.Context) error {
29 | c.Response().Write([]byte("test")) // For Content-Type sniffing
30 | return nil
31 | })
32 | h(c)
33 |
34 | assert.Equal(t, "test", rec.Body.String())
35 |
36 | // Decompress
37 | body := `{"name": "echo"}`
38 | gz, _ := gzipString(body)
39 | req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
40 | req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
41 | rec = httptest.NewRecorder()
42 | c = e.NewContext(req, rec)
43 | h(c)
44 | assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
45 | b, err := io.ReadAll(req.Body)
46 | assert.NoError(t, err)
47 | assert.Equal(t, body, string(b))
48 | }
49 |
50 | func TestDecompressDefaultConfig(t *testing.T) {
51 | e := echo.New()
52 | req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
53 | rec := httptest.NewRecorder()
54 | c := e.NewContext(req, rec)
55 |
56 | h := DecompressWithConfig(DecompressConfig{})(func(c echo.Context) error {
57 | c.Response().Write([]byte("test")) // For Content-Type sniffing
58 | return nil
59 | })
60 | h(c)
61 |
62 | assert.Equal(t, "test", rec.Body.String())
63 |
64 | // Decompress
65 | body := `{"name": "echo"}`
66 | gz, _ := gzipString(body)
67 | req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
68 | req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
69 | rec = httptest.NewRecorder()
70 | c = e.NewContext(req, rec)
71 | h(c)
72 | assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
73 | b, err := io.ReadAll(req.Body)
74 | assert.NoError(t, err)
75 | assert.Equal(t, body, string(b))
76 | }
77 |
78 | func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) {
79 | e := echo.New()
80 | body := `{"name":"echo"}`
81 | gz, _ := gzipString(body)
82 | req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
83 | req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
84 | rec := httptest.NewRecorder()
85 | e.NewContext(req, rec)
86 | e.ServeHTTP(rec, req)
87 | assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
88 | b, err := io.ReadAll(req.Body)
89 | assert.NoError(t, err)
90 | assert.NotEqual(t, b, body)
91 | assert.Equal(t, b, gz)
92 | }
93 |
94 | func TestDecompressNoContent(t *testing.T) {
95 | e := echo.New()
96 | req := httptest.NewRequest(http.MethodGet, "/", nil)
97 | req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
98 | rec := httptest.NewRecorder()
99 | c := e.NewContext(req, rec)
100 | h := Decompress()(func(c echo.Context) error {
101 | return c.NoContent(http.StatusNoContent)
102 | })
103 | if assert.NoError(t, h(c)) {
104 | assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
105 | assert.Empty(t, rec.Header().Get(echo.HeaderContentType))
106 | assert.Equal(t, 0, len(rec.Body.Bytes()))
107 | }
108 | }
109 |
110 | func TestDecompressErrorReturned(t *testing.T) {
111 | e := echo.New()
112 | e.Use(Decompress())
113 | e.GET("/", func(c echo.Context) error {
114 | return echo.ErrNotFound
115 | })
116 | req := httptest.NewRequest(http.MethodGet, "/", nil)
117 | req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
118 | rec := httptest.NewRecorder()
119 | e.ServeHTTP(rec, req)
120 | assert.Equal(t, http.StatusNotFound, rec.Code)
121 | assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
122 | }
123 |
124 | func TestDecompressSkipper(t *testing.T) {
125 | e := echo.New()
126 | e.Use(DecompressWithConfig(DecompressConfig{
127 | Skipper: func(c echo.Context) bool {
128 | return c.Request().URL.Path == "/skip"
129 | },
130 | }))
131 | body := `{"name": "echo"}`
132 | req := httptest.NewRequest(http.MethodPost, "/skip", strings.NewReader(body))
133 | req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
134 | rec := httptest.NewRecorder()
135 | c := e.NewContext(req, rec)
136 | e.ServeHTTP(rec, req)
137 | assert.Equal(t, rec.Header().Get(echo.HeaderContentType), echo.MIMEApplicationJSON)
138 | reqBody, err := io.ReadAll(c.Request().Body)
139 | assert.NoError(t, err)
140 | assert.Equal(t, body, string(reqBody))
141 | }
142 |
// TestDecompressPoolWithError is a Decompressor whose pool only ever yields an
// error value, never a gzip.Reader, to exercise the middleware's failure path.
type TestDecompressPoolWithError struct{}

func (d *TestDecompressPoolWithError) gzipDecompressPool() sync.Pool {
	failing := func() interface{} { return errors.New("pool error") }
	return sync.Pool{New: failing}
}
153 |
154 | func TestDecompressPoolError(t *testing.T) {
155 | e := echo.New()
156 | e.Use(DecompressWithConfig(DecompressConfig{
157 | Skipper: DefaultSkipper,
158 | GzipDecompressPool: &TestDecompressPoolWithError{},
159 | }))
160 | body := `{"name": "echo"}`
161 | req := httptest.NewRequest(http.MethodPost, "/echo", strings.NewReader(body))
162 | req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
163 | rec := httptest.NewRecorder()
164 | c := e.NewContext(req, rec)
165 | e.ServeHTTP(rec, req)
166 | assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
167 | reqBody, err := io.ReadAll(c.Request().Body)
168 | assert.NoError(t, err)
169 | assert.Equal(t, body, string(reqBody))
170 | assert.Equal(t, rec.Code, http.StatusInternalServerError)
171 | }
172 |
173 | func BenchmarkDecompress(b *testing.B) {
174 | e := echo.New()
175 | body := `{"name": "echo"}`
176 | gz, _ := gzipString(body)
177 | req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
178 | req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
179 |
180 | h := Decompress()(func(c echo.Context) error {
181 | c.Response().Write([]byte(body)) // For Content-Type sniffing
182 | return nil
183 | })
184 |
185 | b.ReportAllocs()
186 | b.ResetTimer()
187 |
188 | for i := 0; i < b.N; i++ {
189 | // Decompress
190 | rec := httptest.NewRecorder()
191 | c := e.NewContext(req, rec)
192 | h(c)
193 | }
194 | }
195 |
// gzipString compresses body with the default gzip settings and returns the
// complete gzip stream (header, compressed data and trailer).
func gzipString(body string) ([]byte, error) {
	buf := new(bytes.Buffer)
	zw := gzip.NewWriter(buf)
	if _, err := zw.Write([]byte(body)); err != nil {
		return nil, err
	}
	// Close flushes the remaining data and writes the gzip trailer.
	if err := zw.Close(); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}
211 |
--------------------------------------------------------------------------------
/middleware/extractor.go:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
3 |
4 | package middleware
5 |
6 | import (
7 | "errors"
8 | "fmt"
9 | "github.com/labstack/echo/v4"
10 | "net/textproto"
11 | "strings"
12 | )
13 |
14 | const (
15 | // extractorLimit is arbitrary number to limit values extractor can return. this limits possible resource exhaustion
16 | // attack vector
17 | extractorLimit = 20
18 | )
19 |
20 | var errHeaderExtractorValueMissing = errors.New("missing value in request header")
21 | var errHeaderExtractorValueInvalid = errors.New("invalid value in request header")
22 | var errQueryExtractorValueMissing = errors.New("missing value in the query string")
23 | var errParamExtractorValueMissing = errors.New("missing value in path params")
24 | var errCookieExtractorValueMissing = errors.New("missing value in cookies")
25 | var errFormExtractorValueMissing = errors.New("missing value in the form")
26 |
27 | // ValuesExtractor defines a function for extracting values (keys/tokens) from the given context.
28 | type ValuesExtractor func(c echo.Context) ([]string, error)
29 |
30 | // CreateExtractors creates ValuesExtractors from given lookups.
31 | // Lookups is a string in the form of ":" or ":,:" that is used
32 | // to extract key from the request.
33 | // Possible values:
34 | // - "header:" or "header::"
35 | // `` is argument value to cut/trim prefix of the extracted value. This is useful if header
36 | // value has static prefix like `Authorization: ` where part that we
37 | // want to cut is ` ` note the space at the end.
38 | // In case of basic authentication `Authorization: Basic ` prefix we want to remove is `Basic `.
39 | // - "query:"
40 | // - "param:"
41 | // - "form:"
42 | // - "cookie:"
43 | //
44 | // Multiple sources example:
45 | // - "header:Authorization,header:X-Api-Key"
46 | func CreateExtractors(lookups string) ([]ValuesExtractor, error) {
47 | return createExtractors(lookups, "")
48 | }
49 |
50 | func createExtractors(lookups string, authScheme string) ([]ValuesExtractor, error) {
51 | if lookups == "" {
52 | return nil, nil
53 | }
54 | sources := strings.Split(lookups, ",")
55 | var extractors = make([]ValuesExtractor, 0)
56 | for _, source := range sources {
57 | parts := strings.Split(source, ":")
58 | if len(parts) < 2 {
59 | return nil, fmt.Errorf("extractor source for lookup could not be split into needed parts: %v", source)
60 | }
61 |
62 | switch parts[0] {
63 | case "query":
64 | extractors = append(extractors, valuesFromQuery(parts[1]))
65 | case "param":
66 | extractors = append(extractors, valuesFromParam(parts[1]))
67 | case "cookie":
68 | extractors = append(extractors, valuesFromCookie(parts[1]))
69 | case "form":
70 | extractors = append(extractors, valuesFromForm(parts[1]))
71 | case "header":
72 | prefix := ""
73 | if len(parts) > 2 {
74 | prefix = parts[2]
75 | } else if authScheme != "" && parts[1] == echo.HeaderAuthorization {
76 | // backwards compatibility for JWT and KeyAuth:
77 | // * we only apply this fix to Authorization as header we use and uses prefixes like "Bearer " etc
78 | // * previously header extractor assumed that auth-scheme/prefix had a space as suffix we need to retain that
79 | // behaviour for default values and Authorization header.
80 | prefix = authScheme
81 | if !strings.HasSuffix(prefix, " ") {
82 | prefix += " "
83 | }
84 | }
85 | extractors = append(extractors, valuesFromHeader(parts[1], prefix))
86 | }
87 | }
88 | return extractors, nil
89 | }
90 |
91 | // valuesFromHeader returns a functions that extracts values from the request header.
92 | // valuePrefix is parameter to remove first part (prefix) of the extracted value. This is useful if header value has static
93 | // prefix like `Authorization: ` where part that we want to remove is ` `
94 | // note the space at the end. In case of basic authentication `Authorization: Basic ` prefix we want to remove
95 | // is `Basic `. In case of JWT tokens `Authorization: Bearer ` prefix is `Bearer `.
96 | // If prefix is left empty the whole value is returned.
97 | func valuesFromHeader(header string, valuePrefix string) ValuesExtractor {
98 | prefixLen := len(valuePrefix)
99 | // standard library parses http.Request header keys in canonical form but we may provide something else so fix this
100 | header = textproto.CanonicalMIMEHeaderKey(header)
101 | return func(c echo.Context) ([]string, error) {
102 | values := c.Request().Header.Values(header)
103 | if len(values) == 0 {
104 | return nil, errHeaderExtractorValueMissing
105 | }
106 |
107 | result := make([]string, 0)
108 | for i, value := range values {
109 | if prefixLen == 0 {
110 | result = append(result, value)
111 | if i >= extractorLimit-1 {
112 | break
113 | }
114 | continue
115 | }
116 | if len(value) > prefixLen && strings.EqualFold(value[:prefixLen], valuePrefix) {
117 | result = append(result, value[prefixLen:])
118 | if i >= extractorLimit-1 {
119 | break
120 | }
121 | }
122 | }
123 |
124 | if len(result) == 0 {
125 | if prefixLen > 0 {
126 | return nil, errHeaderExtractorValueInvalid
127 | }
128 | return nil, errHeaderExtractorValueMissing
129 | }
130 | return result, nil
131 | }
132 | }
133 |
134 | // valuesFromQuery returns a function that extracts values from the query string.
135 | func valuesFromQuery(param string) ValuesExtractor {
136 | return func(c echo.Context) ([]string, error) {
137 | result := c.QueryParams()[param]
138 | if len(result) == 0 {
139 | return nil, errQueryExtractorValueMissing
140 | } else if len(result) > extractorLimit-1 {
141 | result = result[:extractorLimit]
142 | }
143 | return result, nil
144 | }
145 | }
146 |
147 | // valuesFromParam returns a function that extracts values from the url param string.
148 | func valuesFromParam(param string) ValuesExtractor {
149 | return func(c echo.Context) ([]string, error) {
150 | result := make([]string, 0)
151 | paramVales := c.ParamValues()
152 | for i, p := range c.ParamNames() {
153 | if param == p {
154 | result = append(result, paramVales[i])
155 | if i >= extractorLimit-1 {
156 | break
157 | }
158 | }
159 | }
160 | if len(result) == 0 {
161 | return nil, errParamExtractorValueMissing
162 | }
163 | return result, nil
164 | }
165 | }
166 |
167 | // valuesFromCookie returns a function that extracts values from the named cookie.
168 | func valuesFromCookie(name string) ValuesExtractor {
169 | return func(c echo.Context) ([]string, error) {
170 | cookies := c.Cookies()
171 | if len(cookies) == 0 {
172 | return nil, errCookieExtractorValueMissing
173 | }
174 |
175 | result := make([]string, 0)
176 | for i, cookie := range cookies {
177 | if name == cookie.Name {
178 | result = append(result, cookie.Value)
179 | if i >= extractorLimit-1 {
180 | break
181 | }
182 | }
183 | }
184 | if len(result) == 0 {
185 | return nil, errCookieExtractorValueMissing
186 | }
187 | return result, nil
188 | }
189 | }
190 |
191 | // valuesFromForm returns a function that extracts values from the form field.
192 | func valuesFromForm(name string) ValuesExtractor {
193 | return func(c echo.Context) ([]string, error) {
194 | if c.Request().Form == nil {
195 | _ = c.Request().ParseMultipartForm(32 << 20) // same what `c.Request().FormValue(name)` does
196 | }
197 | values := c.Request().Form[name]
198 | if len(values) == 0 {
199 | return nil, errFormExtractorValueMissing
200 | }
201 | if len(values) > extractorLimit-1 {
202 | values = values[:extractorLimit]
203 | }
204 | result := append([]string{}, values...)
205 | return result, nil
206 | }
207 | }
208 |
--------------------------------------------------------------------------------
/middleware/key_auth.go:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
3 |
4 | package middleware
5 |
6 | import (
7 | "errors"
8 | "github.com/labstack/echo/v4"
9 | "net/http"
10 | )
11 |
// KeyAuthConfig defines the config for KeyAuth middleware.
type KeyAuthConfig struct {
	// Skipper defines a function to skip middleware.
	Skipper Skipper

	// KeyLookup is a string in the form of "<source>:<name>" or "<source>:<name>,<source>:<name>" that is used
	// to extract key from the request.
	// Optional. Default value "header:Authorization".
	// Possible values:
	// - "header:<name>" or "header:<name>:<cut-prefix>"
	//   `<cut-prefix>` is argument value to cut/trim prefix of the extracted value. This is useful if header
	//   value has static prefix like `Authorization: <auth-scheme> <authorisation-parameters>` where part that we
	//   want to cut is `<auth-scheme> ` note the space at the end.
	//   In case of basic authentication `Authorization: Basic <credentials>` prefix we want to remove is `Basic `.
	// - "query:<name>"
	// - "param:<name>"
	// - "form:<name>"
	// - "cookie:<name>"
	// Multiple sources example:
	// - "header:Authorization,header:X-Api-Key"
	// Sources are tried in the listed order and every extracted key is passed to Validator until one is valid.
	KeyLookup string

	// AuthScheme to be used in the Authorization header.
	// It is used as the value prefix for a "header:Authorization" lookup that has no explicit <cut-prefix>;
	// a single trailing space is added when missing (so "Bearer" cuts "Bearer ").
	// Optional. Default value "Bearer".
	AuthScheme string

	// Validator is a function to validate key.
	// Required. KeyAuthWithConfig panics when it is nil.
	Validator KeyAuthValidator

	// ErrorHandler defines a function which is executed for an invalid key.
	// It may be used to define a custom error.
	// It is also called when no key could be extracted; in that case err is an *ErrKeyAuthMissing.
	// When a key was rejected, err is the last error returned by Validator, or an "invalid key" error when
	// Validator only returned false.
	ErrorHandler KeyAuthErrorHandler

	// ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandler decides to
	// ignore the error (by returning `nil`).
	// This is useful when parts of your site/api allow public access and some authorized routes provide extra functionality.
	// In that case you can use ErrorHandler to set a default public key auth value in the request context
	// and continue. Some logic down the remaining execution chain needs to check that (public) key auth value then.
	ContinueOnIgnoredError bool
}
52 |
// KeyAuthValidator defines a function to validate KeyAuth credentials.
// It is called for each extracted key in turn. Returning (true, nil) accepts the request; returning false or a
// non-nil error makes the middleware try the next key, and respond with an error if no key is accepted.
type KeyAuthValidator func(auth string, c echo.Context) (bool, error)

// KeyAuthErrorHandler defines a function which is executed for an invalid key.
// It is also called for a missing key, with an *ErrKeyAuthMissing error.
type KeyAuthErrorHandler func(err error, c echo.Context) error

// ErrKeyAuthMissing is error type when KeyAuth middleware is unable to extract value from lookups
type ErrKeyAuthMissing struct {
	// Err is the underlying error describing why no key could be extracted.
	Err error
}
63 |
// DefaultKeyAuthConfig is the default KeyAuth middleware config.
// It reads the key from the Authorization header and cuts the "Bearer " prefix from it.
// It has no Validator, which must always be supplied.
var DefaultKeyAuthConfig = KeyAuthConfig{
	Skipper:    DefaultSkipper,
	KeyLookup:  "header:" + echo.HeaderAuthorization,
	AuthScheme: "Bearer",
}
70 |
71 | // Error returns errors text
72 | func (e *ErrKeyAuthMissing) Error() string {
73 | return e.Err.Error()
74 | }
75 |
76 | // Unwrap unwraps error
77 | func (e *ErrKeyAuthMissing) Unwrap() error {
78 | return e.Err
79 | }
80 |
81 | // KeyAuth returns an KeyAuth middleware.
82 | //
83 | // For valid key it calls the next handler.
84 | // For invalid key, it sends "401 - Unauthorized" response.
85 | // For missing key, it sends "400 - Bad Request" response.
86 | func KeyAuth(fn KeyAuthValidator) echo.MiddlewareFunc {
87 | c := DefaultKeyAuthConfig
88 | c.Validator = fn
89 | return KeyAuthWithConfig(c)
90 | }
91 |
92 | // KeyAuthWithConfig returns an KeyAuth middleware with config.
93 | // See `KeyAuth()`.
94 | func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc {
95 | // Defaults
96 | if config.Skipper == nil {
97 | config.Skipper = DefaultKeyAuthConfig.Skipper
98 | }
99 | // Defaults
100 | if config.AuthScheme == "" {
101 | config.AuthScheme = DefaultKeyAuthConfig.AuthScheme
102 | }
103 | if config.KeyLookup == "" {
104 | config.KeyLookup = DefaultKeyAuthConfig.KeyLookup
105 | }
106 | if config.Validator == nil {
107 | panic("echo: key-auth middleware requires a validator function")
108 | }
109 |
110 | extractors, cErr := createExtractors(config.KeyLookup, config.AuthScheme)
111 | if cErr != nil {
112 | panic(cErr)
113 | }
114 |
115 | return func(next echo.HandlerFunc) echo.HandlerFunc {
116 | return func(c echo.Context) error {
117 | if config.Skipper(c) {
118 | return next(c)
119 | }
120 |
121 | var lastExtractorErr error
122 | var lastValidatorErr error
123 | for _, extractor := range extractors {
124 | keys, err := extractor(c)
125 | if err != nil {
126 | lastExtractorErr = err
127 | continue
128 | }
129 | for _, key := range keys {
130 | valid, err := config.Validator(key, c)
131 | if err != nil {
132 | lastValidatorErr = err
133 | continue
134 | }
135 | if valid {
136 | return next(c)
137 | }
138 | lastValidatorErr = errors.New("invalid key")
139 | }
140 | }
141 |
142 | // we are here only when we did not successfully extract and validate any of keys
143 | err := lastValidatorErr
144 | if err == nil { // prioritize validator errors over extracting errors
145 | // ugly part to preserve backwards compatible errors. someone could rely on them
146 | if lastExtractorErr == errQueryExtractorValueMissing {
147 | err = errors.New("missing key in the query string")
148 | } else if lastExtractorErr == errCookieExtractorValueMissing {
149 | err = errors.New("missing key in cookies")
150 | } else if lastExtractorErr == errFormExtractorValueMissing {
151 | err = errors.New("missing key in the form")
152 | } else if lastExtractorErr == errHeaderExtractorValueMissing {
153 | err = errors.New("missing key in request header")
154 | } else if lastExtractorErr == errHeaderExtractorValueInvalid {
155 | err = errors.New("invalid key in the request header")
156 | } else {
157 | err = lastExtractorErr
158 | }
159 | err = &ErrKeyAuthMissing{Err: err}
160 | }
161 |
162 | if config.ErrorHandler != nil {
163 | tmpErr := config.ErrorHandler(err, c)
164 | if config.ContinueOnIgnoredError && tmpErr == nil {
165 | return next(c)
166 | }
167 | return tmpErr
168 | }
169 | if lastValidatorErr != nil { // prioritize validator errors over extracting errors
170 | return &echo.HTTPError{
171 | Code: http.StatusUnauthorized,
172 | Message: "Unauthorized",
173 | Internal: lastValidatorErr,
174 | }
175 | }
176 | return echo.NewHTTPError(http.StatusBadRequest, err.Error())
177 | }
178 | }
179 | }
180 |
--------------------------------------------------------------------------------
/middleware/logger.go:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
3 |
4 | package middleware
5 |
6 | import (
7 | "bytes"
8 | "encoding/json"
9 | "io"
10 | "strconv"
11 | "strings"
12 | "sync"
13 | "time"
14 |
15 | "github.com/labstack/echo/v4"
16 | "github.com/labstack/gommon/color"
17 | "github.com/valyala/fasttemplate"
18 | )
19 |
// LoggerConfig defines the config for Logger middleware.
type LoggerConfig struct {
	// Skipper defines a function to skip middleware.
	Skipper Skipper

	// Tags to construct the logger format.
	//
	// - time_unix
	// - time_unix_milli
	// - time_unix_micro
	// - time_unix_nano
	// - time_rfc3339
	// - time_rfc3339_nano
	// - time_custom (formatted with CustomTimeFormat)
	// - id (Request ID)
	// - remote_ip
	// - uri
	// - host
	// - method
	// - path
	// - route
	// - protocol
	// - referer
	// - user_agent
	// - status
	// - error
	// - latency (In nanoseconds)
	// - latency_human (Human readable)
	// - bytes_in (Bytes received, taken from the Content-Length request header)
	// - bytes_out (Bytes sent)
	// - header:<NAME>
	// - query:<NAME>
	// - form:<NAME>
	// - cookie:<NAME>
	// - custom (see CustomTagFunc field)
	//
	// Only the `error` tag is JSON-escaped; other tag values (including request-controlled ones such as
	// user_agent, header, query, form and cookie) are written as-is.
	//
	// Example "${remote_ip} ${status}"
	//
	// Optional. Default value DefaultLoggerConfig.Format.
	Format string `yaml:"format"`

	// CustomTimeFormat is the time layout used by the `time_custom` tag.
	// Optional. Default value DefaultLoggerConfig.CustomTimeFormat.
	CustomTimeFormat string `yaml:"custom_time_format"`

	// CustomTagFunc is function called for `${custom}` tag to output user implemented text by writing it to buf.
	// Make sure that outputted text creates valid JSON string with other logged tags.
	// Optional.
	CustomTagFunc func(c echo.Context, buf *bytes.Buffer) (int, error)

	// Output is a writer where the rendered log lines are written.
	// Optional. When nil, the output of the echo instance logger (c.Logger().Output()) is used.
	Output io.Writer

	// template is the compiled Format; set by LoggerWithConfig.
	template *fasttemplate.Template
	// colorer colors the `status` tag; set by LoggerWithConfig.
	colorer *color.Color
	// pool recycles the *bytes.Buffer each log line is rendered into; set by LoggerWithConfig.
	pool *sync.Pool
}
76 |
// DefaultLoggerConfig is the default Logger middleware config.
// Its Format renders one JSON object per request, terminated by a newline.
// Output is left nil, so log lines go to the echo instance logger's output.
var DefaultLoggerConfig = LoggerConfig{
	Skipper: DefaultSkipper,
	Format: `{"time":"${time_rfc3339_nano}","id":"${id}","remote_ip":"${remote_ip}",` +
		`"host":"${host}","method":"${method}","uri":"${uri}","user_agent":"${user_agent}",` +
		`"status":${status},"error":"${error}","latency":${latency},"latency_human":"${latency_human}"` +
		`,"bytes_in":${bytes_in},"bytes_out":${bytes_out}}` + "\n",
	CustomTimeFormat: "2006-01-02 15:04:05.00000",
	colorer:          color.New(),
}
87 |
88 | // Logger returns a middleware that logs HTTP requests.
89 | func Logger() echo.MiddlewareFunc {
90 | return LoggerWithConfig(DefaultLoggerConfig)
91 | }
92 |
93 | // LoggerWithConfig returns a Logger middleware with config.
94 | // See: `Logger()`.
95 | func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc {
96 | // Defaults
97 | if config.Skipper == nil {
98 | config.Skipper = DefaultLoggerConfig.Skipper
99 | }
100 | if config.Format == "" {
101 | config.Format = DefaultLoggerConfig.Format
102 | }
103 | if config.Output == nil {
104 | config.Output = DefaultLoggerConfig.Output
105 | }
106 |
107 | config.template = fasttemplate.New(config.Format, "${", "}")
108 | config.colorer = color.New()
109 | config.colorer.SetOutput(config.Output)
110 | config.pool = &sync.Pool{
111 | New: func() interface{} {
112 | return bytes.NewBuffer(make([]byte, 256))
113 | },
114 | }
115 |
116 | return func(next echo.HandlerFunc) echo.HandlerFunc {
117 | return func(c echo.Context) (err error) {
118 | if config.Skipper(c) {
119 | return next(c)
120 | }
121 |
122 | req := c.Request()
123 | res := c.Response()
124 | start := time.Now()
125 | if err = next(c); err != nil {
126 | c.Error(err)
127 | }
128 | stop := time.Now()
129 | buf := config.pool.Get().(*bytes.Buffer)
130 | buf.Reset()
131 | defer config.pool.Put(buf)
132 |
133 | if _, err = config.template.ExecuteFunc(buf, func(w io.Writer, tag string) (int, error) {
134 | switch tag {
135 | case "custom":
136 | if config.CustomTagFunc == nil {
137 | return 0, nil
138 | }
139 | return config.CustomTagFunc(c, buf)
140 | case "time_unix":
141 | return buf.WriteString(strconv.FormatInt(time.Now().Unix(), 10))
142 | case "time_unix_milli":
143 | // go 1.17 or later, it supports time#UnixMilli()
144 | return buf.WriteString(strconv.FormatInt(time.Now().UnixNano()/1000000, 10))
145 | case "time_unix_micro":
146 | // go 1.17 or later, it supports time#UnixMicro()
147 | return buf.WriteString(strconv.FormatInt(time.Now().UnixNano()/1000, 10))
148 | case "time_unix_nano":
149 | return buf.WriteString(strconv.FormatInt(time.Now().UnixNano(), 10))
150 | case "time_rfc3339":
151 | return buf.WriteString(time.Now().Format(time.RFC3339))
152 | case "time_rfc3339_nano":
153 | return buf.WriteString(time.Now().Format(time.RFC3339Nano))
154 | case "time_custom":
155 | return buf.WriteString(time.Now().Format(config.CustomTimeFormat))
156 | case "id":
157 | id := req.Header.Get(echo.HeaderXRequestID)
158 | if id == "" {
159 | id = res.Header().Get(echo.HeaderXRequestID)
160 | }
161 | return buf.WriteString(id)
162 | case "remote_ip":
163 | return buf.WriteString(c.RealIP())
164 | case "host":
165 | return buf.WriteString(req.Host)
166 | case "uri":
167 | return buf.WriteString(req.RequestURI)
168 | case "method":
169 | return buf.WriteString(req.Method)
170 | case "path":
171 | p := req.URL.Path
172 | if p == "" {
173 | p = "/"
174 | }
175 | return buf.WriteString(p)
176 | case "route":
177 | return buf.WriteString(c.Path())
178 | case "protocol":
179 | return buf.WriteString(req.Proto)
180 | case "referer":
181 | return buf.WriteString(req.Referer())
182 | case "user_agent":
183 | return buf.WriteString(req.UserAgent())
184 | case "status":
185 | n := res.Status
186 | s := config.colorer.Green(n)
187 | switch {
188 | case n >= 500:
189 | s = config.colorer.Red(n)
190 | case n >= 400:
191 | s = config.colorer.Yellow(n)
192 | case n >= 300:
193 | s = config.colorer.Cyan(n)
194 | }
195 | return buf.WriteString(s)
196 | case "error":
197 | if err != nil {
198 | // Error may contain invalid JSON e.g. `"`
199 | b, _ := json.Marshal(err.Error())
200 | b = b[1 : len(b)-1]
201 | return buf.Write(b)
202 | }
203 | case "latency":
204 | l := stop.Sub(start)
205 | return buf.WriteString(strconv.FormatInt(int64(l), 10))
206 | case "latency_human":
207 | return buf.WriteString(stop.Sub(start).String())
208 | case "bytes_in":
209 | cl := req.Header.Get(echo.HeaderContentLength)
210 | if cl == "" {
211 | cl = "0"
212 | }
213 | return buf.WriteString(cl)
214 | case "bytes_out":
215 | return buf.WriteString(strconv.FormatInt(res.Size, 10))
216 | default:
217 | switch {
218 | case strings.HasPrefix(tag, "header:"):
219 | return buf.Write([]byte(c.Request().Header.Get(tag[7:])))
220 | case strings.HasPrefix(tag, "query:"):
221 | return buf.Write([]byte(c.QueryParam(tag[6:])))
222 | case strings.HasPrefix(tag, "form:"):
223 | return buf.Write([]byte(c.FormValue(tag[5:])))
224 | case strings.HasPrefix(tag, "cookie:"):
225 | cookie, err := c.Cookie(tag[7:])
226 | if err == nil {
227 | return buf.Write([]byte(cookie.Value))
228 | }
229 | }
230 | }
231 | return 0, nil
232 | }); err != nil {
233 | return
234 | }
235 |
236 | if config.Output == nil {
237 | _, err = c.Logger().Output().Write(buf.Bytes())
238 | return
239 | }
240 | _, err = config.Output.Write(buf.Bytes())
241 | return
242 | }
243 | }
244 | }
245 |
--------------------------------------------------------------------------------
/middleware/method_override.go:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
3 |
4 | package middleware
5 |
6 | import (
7 | "net/http"
8 |
9 | "github.com/labstack/echo/v4"
10 | )
11 |
// MethodOverrideConfig defines the config for MethodOverride middleware.
type MethodOverrideConfig struct {
	// Skipper defines a function to skip middleware.
	Skipper Skipper

	// Getter is a function that gets overridden method from the request.
	// An empty result leaves the request method unchanged.
	// Optional. Default values MethodFromHeader(echo.HeaderXHTTPMethodOverride).
	Getter MethodOverrideGetter
}

// MethodOverrideGetter is a function that gets overridden method from the request.
// The middleware treats an empty string as "no override".
type MethodOverrideGetter func(echo.Context) string

// DefaultMethodOverrideConfig is the default MethodOverride middleware config.
// It reads the overridden method from the echo.HeaderXHTTPMethodOverride request header.
var DefaultMethodOverrideConfig = MethodOverrideConfig{
	Skipper: DefaultSkipper,
	Getter:  MethodFromHeader(echo.HeaderXHTTPMethodOverride),
}
30 |
31 | // MethodOverride returns a MethodOverride middleware.
32 | // MethodOverride middleware checks for the overridden method from the request and
33 | // uses it instead of the original method.
34 | //
35 | // For security reasons, only `POST` method can be overridden.
36 | func MethodOverride() echo.MiddlewareFunc {
37 | return MethodOverrideWithConfig(DefaultMethodOverrideConfig)
38 | }
39 |
40 | // MethodOverrideWithConfig returns a MethodOverride middleware with config.
41 | // See: `MethodOverride()`.
42 | func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc {
43 | // Defaults
44 | if config.Skipper == nil {
45 | config.Skipper = DefaultMethodOverrideConfig.Skipper
46 | }
47 | if config.Getter == nil {
48 | config.Getter = DefaultMethodOverrideConfig.Getter
49 | }
50 |
51 | return func(next echo.HandlerFunc) echo.HandlerFunc {
52 | return func(c echo.Context) error {
53 | if config.Skipper(c) {
54 | return next(c)
55 | }
56 |
57 | req := c.Request()
58 | if req.Method == http.MethodPost {
59 | m := config.Getter(c)
60 | if m != "" {
61 | req.Method = m
62 | }
63 | }
64 | return next(c)
65 | }
66 | }
67 | }
68 |
69 | // MethodFromHeader is a `MethodOverrideGetter` that gets overridden method from
70 | // the request header.
71 | func MethodFromHeader(header string) MethodOverrideGetter {
72 | return func(c echo.Context) string {
73 | return c.Request().Header.Get(header)
74 | }
75 | }
76 |
77 | // MethodFromForm is a `MethodOverrideGetter` that gets overridden method from the
78 | // form parameter.
79 | func MethodFromForm(param string) MethodOverrideGetter {
80 | return func(c echo.Context) string {
81 | return c.FormValue(param)
82 | }
83 | }
84 |
85 | // MethodFromQuery is a `MethodOverrideGetter` that gets overridden method from
86 | // the query parameter.
87 | func MethodFromQuery(param string) MethodOverrideGetter {
88 | return func(c echo.Context) string {
89 | return c.QueryParam(param)
90 | }
91 | }
92 |
--------------------------------------------------------------------------------
/middleware/method_override_test.go:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
3 |
4 | package middleware
5 |
6 | import (
7 | "bytes"
8 | "net/http"
9 | "net/http/httptest"
10 | "testing"
11 |
12 | "github.com/labstack/echo/v4"
13 | "github.com/stretchr/testify/assert"
14 | )
15 |
16 | func TestMethodOverride(t *testing.T) {
17 | e := echo.New()
18 | m := MethodOverride()
19 | h := func(c echo.Context) error {
20 | return c.String(http.StatusOK, "test")
21 | }
22 |
23 | // Override with http header
24 | req := httptest.NewRequest(http.MethodPost, "/", nil)
25 | rec := httptest.NewRecorder()
26 | req.Header.Set(echo.HeaderXHTTPMethodOverride, http.MethodDelete)
27 | c := e.NewContext(req, rec)
28 | m(h)(c)
29 | assert.Equal(t, http.MethodDelete, req.Method)
30 |
31 | // Override with form parameter
32 | m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromForm("_method")})
33 | req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader([]byte("_method="+http.MethodDelete)))
34 | rec = httptest.NewRecorder()
35 | req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm)
36 | c = e.NewContext(req, rec)
37 | m(h)(c)
38 | assert.Equal(t, http.MethodDelete, req.Method)
39 |
40 | // Override with query parameter
41 | m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromQuery("_method")})
42 | req = httptest.NewRequest(http.MethodPost, "/?_method="+http.MethodDelete, nil)
43 | rec = httptest.NewRecorder()
44 | c = e.NewContext(req, rec)
45 | m(h)(c)
46 | assert.Equal(t, http.MethodDelete, req.Method)
47 |
48 | // Ignore `GET`
49 | req = httptest.NewRequest(http.MethodGet, "/", nil)
50 | req.Header.Set(echo.HeaderXHTTPMethodOverride, http.MethodDelete)
51 | assert.Equal(t, http.MethodGet, req.Method)
52 | }
53 |
--------------------------------------------------------------------------------
/middleware/middleware.go:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
3 |
4 | package middleware
5 |
6 | import (
7 | "net/http"
8 | "regexp"
9 | "strconv"
10 | "strings"
11 |
12 | "github.com/labstack/echo/v4"
13 | )
14 |
// Skipper defines a function to skip middleware. Returning true skips processing
// the middleware: middleware in this package check their Skipper first and, when it
// returns true, hand the request straight to the next handler.
type Skipper func(c echo.Context) bool

// BeforeFunc defines a function which is executed just before the middleware.
type BeforeFunc func(c echo.Context)
21 |
// captureTokens matches pattern against input and returns a replacer that substitutes the positional
// placeholders "$1", "$2", ... with the values captured by the groups of the leftmost match.
// It returns nil when pattern does not match input.
//
// Placeholders are registered from the highest index down because strings.Replacer compares old strings
// in argument order: if "$1" came first, "$10" in a template would become the first group followed by "0".
func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer {
	match := pattern.FindStringSubmatch(input)
	if match == nil {
		return nil
	}
	values := match[1:]
	replace := make([]string, 0, 2*len(values))
	for i := len(values); i >= 1; i-- {
		replace = append(replace, "$"+strconv.Itoa(i), values[i-1])
	}
	return strings.NewReplacer(replace...)
}
36 |
// rewriteRulesRegex compiles rewrite rules into regular expressions mapped to their replacement templates.
//
// Each rule key is matched literally except for two wildcards:
//   - every `*` becomes a lazy capture group `(.*?)`, referenced as $1, $2, ... in the replacement;
//   - a leading `^` anchors the rule at the start of the path.
//
// Every key is anchored at the end. A `^` anywhere other than the first character is matched literally.
func rewriteRulesRegex(rewrite map[string]string) map[*regexp.Regexp]string {
	rulesRegex := make(map[*regexp.Regexp]string, len(rewrite))
	for k, v := range rewrite {
		k = regexp.QuoteMeta(k)
		k = strings.ReplaceAll(k, `\*`, "(.*?)")
		if strings.HasPrefix(k, `\^`) {
			// only the leading caret is an anchor; later carets stay escaped and match literally
			k = "^" + strings.TrimPrefix(k, `\^`)
		}
		rulesRegex[regexp.MustCompile(k+"$")] = v
	}
	return rulesRegex
}
51 |
52 | func rewriteURL(rewriteRegex map[*regexp.Regexp]string, req *http.Request) error {
53 | if len(rewriteRegex) == 0 {
54 | return nil
55 | }
56 |
57 | // Depending on how HTTP request is sent RequestURI could contain Scheme://Host/path or be just /path.
58 | // We only want to use path part for rewriting and therefore trim prefix if it exists
59 | rawURI := req.RequestURI
60 | if rawURI != "" && rawURI[0] != '/' {
61 | prefix := ""
62 | if req.URL.Scheme != "" {
63 | prefix = req.URL.Scheme + "://"
64 | }
65 | if req.URL.Host != "" {
66 | prefix += req.URL.Host // host or host:port
67 | }
68 | if prefix != "" {
69 | rawURI = strings.TrimPrefix(rawURI, prefix)
70 | }
71 | }
72 |
73 | for k, v := range rewriteRegex {
74 | if replacer := captureTokens(k, rawURI); replacer != nil {
75 | url, err := req.URL.Parse(replacer.Replace(v))
76 | if err != nil {
77 | return err
78 | }
79 | req.URL = url
80 |
81 | return nil // rewrite only once
82 | }
83 | }
84 | return nil
85 | }
86 |
87 | // DefaultSkipper returns false which processes the middleware.
88 | func DefaultSkipper(echo.Context) bool {
89 | return false
90 | }
91 |
--------------------------------------------------------------------------------
/middleware/middleware_test.go:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
3 |
4 | package middleware
5 |
6 | import (
7 | "bufio"
8 | "errors"
9 | "github.com/stretchr/testify/assert"
10 | "net"
11 | "net/http"
12 | "net/http/httptest"
13 | "regexp"
14 | "testing"
15 | )
16 |
17 | func TestRewriteURL(t *testing.T) {
18 | var testCases = []struct {
19 | whenURL string
20 | expectPath string
21 | expectRawPath string
22 | expectQuery string
23 | expectErr string
24 | }{
25 | {
26 | whenURL: "http://localhost:8080/old",
27 | expectPath: "/new",
28 | expectRawPath: "",
29 | },
30 | { // encoded `ol%64` (decoded `old`) should not be rewritten to `/new`
31 | whenURL: "/ol%64", // `%64` is decoded `d`
32 | expectPath: "/old",
33 | expectRawPath: "/ol%64",
34 | },
35 | {
36 | whenURL: "http://localhost:8080/users/+_+/orders/___++++?test=1",
37 | expectPath: "/user/+_+/order/___++++",
38 | expectRawPath: "",
39 | expectQuery: "test=1",
40 | },
41 | {
42 | whenURL: "http://localhost:8080/users/%20a/orders/%20aa",
43 | expectPath: "/user/ a/order/ aa",
44 | expectRawPath: "",
45 | },
46 | {
47 | whenURL: "http://localhost:8080/%47%6f%2f?test=1",
48 | expectPath: "/Go/",
49 | expectRawPath: "/%47%6f%2f",
50 | expectQuery: "test=1",
51 | },
52 | {
53 | whenURL: "/users/jill/orders/T%2FcO4lW%2Ft%2FVp%2F",
54 | expectPath: "/user/jill/order/T/cO4lW/t/Vp/",
55 | expectRawPath: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F",
56 | },
57 | { // do nothing, replace nothing
58 | whenURL: "http://localhost:8080/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F",
59 | expectPath: "/user/jill/order/T/cO4lW/t/Vp/",
60 | expectRawPath: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F",
61 | },
62 | {
63 | whenURL: "http://localhost:8080/static",
64 | expectPath: "/static/path",
65 | expectRawPath: "",
66 | expectQuery: "role=AUTHOR&limit=1000",
67 | },
68 | {
69 | whenURL: "/static",
70 | expectPath: "/static/path",
71 | expectRawPath: "",
72 | expectQuery: "role=AUTHOR&limit=1000",
73 | },
74 | }
75 |
76 | rules := map[*regexp.Regexp]string{
77 | regexp.MustCompile("^/old$"): "/new",
78 | regexp.MustCompile("^/users/(.*?)/orders/(.*?)$"): "/user/$1/order/$2",
79 | regexp.MustCompile("^/static$"): "/static/path?role=AUTHOR&limit=1000",
80 | }
81 |
82 | for _, tc := range testCases {
83 | t.Run(tc.whenURL, func(t *testing.T) {
84 | req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
85 |
86 | err := rewriteURL(rules, req)
87 |
88 | if tc.expectErr != "" {
89 | assert.EqualError(t, err, tc.expectErr)
90 | } else {
91 | assert.NoError(t, err)
92 | }
93 | assert.Equal(t, tc.expectPath, req.URL.Path) // Path field is stored in decoded form: /%47%6f%2f becomes /Go/.
94 | assert.Equal(t, tc.expectRawPath, req.URL.RawPath) // RawPath, an optional field which only gets set if the default encoding is different from Path.
95 | assert.Equal(t, tc.expectQuery, req.URL.RawQuery)
96 | })
97 | }
98 | }
99 |
// testResponseWriterNoFlushHijack is a no-op http.ResponseWriter that implements neither
// http.Flusher nor http.Hijacker.
type testResponseWriterNoFlushHijack struct{}

func (w *testResponseWriterNoFlushHijack) WriteHeader(int) {}

func (w *testResponseWriterNoFlushHijack) Write([]byte) (int, error) { return 0, nil }

func (w *testResponseWriterNoFlushHijack) Header() http.Header { return nil }

// testResponseWriterUnwrapper is a no-op http.ResponseWriter whose Unwrap method returns rw
// and counts how many times it was called.
type testResponseWriterUnwrapper struct {
	unwrapCalled int
	rw           http.ResponseWriter
}

func (w *testResponseWriterUnwrapper) WriteHeader(int) {}

func (w *testResponseWriterUnwrapper) Write([]byte) (int, error) { return 0, nil }

func (w *testResponseWriterUnwrapper) Header() http.Header { return nil }

func (w *testResponseWriterUnwrapper) Unwrap() http.ResponseWriter {
	w.unwrapCalled++
	return w.rw
}

// testResponseWriterUnwrapperHijack extends testResponseWriterUnwrapper with a Hijack method
// that always fails with the error "can hijack".
type testResponseWriterUnwrapperHijack struct {
	testResponseWriterUnwrapper
}

func (w *testResponseWriterUnwrapperHijack) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	return nil, nil, errors.New("can hijack")
}
142 |
--------------------------------------------------------------------------------
/middleware/recover.go:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
3 |
4 | package middleware
5 |
6 | import (
7 | "fmt"
8 | "net/http"
9 | "runtime"
10 |
11 | "github.com/labstack/echo/v4"
12 | "github.com/labstack/gommon/log"
13 | )
14 |
// LogErrorFunc defines a function for custom logging in the middleware.
// It receives the recovered panic as an error and the captured stack trace (empty when
// DisablePrintStack is set); its return value replaces the recovered error.
type LogErrorFunc func(c echo.Context, err error, stack []byte) error

// RecoverConfig defines the config for Recover middleware.
type RecoverConfig struct {
	// Skipper defines a function to skip middleware.
	Skipper Skipper

	// Size of the stack to be printed, in bytes (size of the buffer passed to runtime.Stack).
	// Optional. Default value 4KB.
	StackSize int `yaml:"stack_size"`

	// DisableStackAll disables formatting stack traces of all other goroutines
	// into buffer after the trace for the current goroutine.
	// Optional. Default value false.
	DisableStackAll bool `yaml:"disable_stack_all"`

	// DisablePrintStack disables printing stack trace.
	// When set, no stack trace is captured at all.
	// Optional. Default value as false.
	DisablePrintStack bool `yaml:"disable_print_stack"`

	// LogLevel is log level to printing stack trace.
	// Optional. Default value 0 (Print).
	LogLevel log.Lvl

	// LogErrorFunc defines a function for custom logging in the middleware.
	// If it's set you don't need to provide LogLevel for config.
	// If this function returns nil, the centralized HTTPErrorHandler will not be called.
	LogErrorFunc LogErrorFunc

	// DisableErrorHandler disables the call to centralized HTTPErrorHandler.
	// The recovered error is then passed back to upstream middleware, instead of swallowing the error.
	// Optional. Default value false.
	DisableErrorHandler bool `yaml:"disable_error_handler"`
}

// DefaultRecoverConfig is the default Recover middleware config.
// It captures a 4 KB stack trace including all goroutines and logs it at the default level.
var DefaultRecoverConfig = RecoverConfig{
	Skipper:             DefaultSkipper,
	StackSize:           4 << 10, // 4 KB
	DisableStackAll:     false,
	DisablePrintStack:   false,
	LogLevel:            0,
	LogErrorFunc:        nil,
	DisableErrorHandler: false,
}
61 |
62 | // Recover returns a middleware which recovers from panics anywhere in the chain
63 | // and handles the control to the centralized HTTPErrorHandler.
64 | func Recover() echo.MiddlewareFunc {
65 | return RecoverWithConfig(DefaultRecoverConfig)
66 | }
67 |
68 | // RecoverWithConfig returns a Recover middleware with config.
69 | // See: `Recover()`.
70 | func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc {
71 | // Defaults
72 | if config.Skipper == nil {
73 | config.Skipper = DefaultRecoverConfig.Skipper
74 | }
75 | if config.StackSize == 0 {
76 | config.StackSize = DefaultRecoverConfig.StackSize
77 | }
78 |
79 | return func(next echo.HandlerFunc) echo.HandlerFunc {
80 | return func(c echo.Context) (returnErr error) {
81 | if config.Skipper(c) {
82 | return next(c)
83 | }
84 |
85 | defer func() {
86 | if r := recover(); r != nil {
87 | if r == http.ErrAbortHandler {
88 | panic(r)
89 | }
90 | err, ok := r.(error)
91 | if !ok {
92 | err = fmt.Errorf("%v", r)
93 | }
94 | var stack []byte
95 | var length int
96 |
97 | if !config.DisablePrintStack {
98 | stack = make([]byte, config.StackSize)
99 | length = runtime.Stack(stack, !config.DisableStackAll)
100 | stack = stack[:length]
101 | }
102 |
103 | if config.LogErrorFunc != nil {
104 | err = config.LogErrorFunc(c, err, stack)
105 | } else if !config.DisablePrintStack {
106 | msg := fmt.Sprintf("[PANIC RECOVER] %v %s\n", err, stack[:length])
107 | switch config.LogLevel {
108 | case log.DEBUG:
109 | c.Logger().Debug(msg)
110 | case log.INFO:
111 | c.Logger().Info(msg)
112 | case log.WARN:
113 | c.Logger().Warn(msg)
114 | case log.ERROR:
115 | c.Logger().Error(msg)
116 | case log.OFF:
117 | // None.
118 | default:
119 | c.Logger().Print(msg)
120 | }
121 | }
122 |
123 | if err != nil && !config.DisableErrorHandler {
124 | c.Error(err)
125 | } else {
126 | returnErr = err
127 | }
128 | }
129 | }()
130 | return next(c)
131 | }
132 | }
133 | }
134 |
--------------------------------------------------------------------------------
/middleware/recover_test.go:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
3 |
4 | package middleware
5 |
6 | import (
7 | "bytes"
8 | "errors"
9 | "fmt"
10 | "net/http"
11 | "net/http/httptest"
12 | "testing"
13 |
14 | "github.com/labstack/echo/v4"
15 | "github.com/labstack/gommon/log"
16 | "github.com/stretchr/testify/assert"
17 | )
18 |
19 | func TestRecover(t *testing.T) {
20 | e := echo.New()
21 | buf := new(bytes.Buffer)
22 | e.Logger.SetOutput(buf)
23 | req := httptest.NewRequest(http.MethodGet, "/", nil)
24 | rec := httptest.NewRecorder()
25 | c := e.NewContext(req, rec)
26 | h := Recover()(echo.HandlerFunc(func(c echo.Context) error {
27 | panic("test")
28 | }))
29 | err := h(c)
30 | assert.NoError(t, err)
31 | assert.Equal(t, http.StatusInternalServerError, rec.Code)
32 | assert.Contains(t, buf.String(), "PANIC RECOVER")
33 | }
34 |
35 | func TestRecoverErrAbortHandler(t *testing.T) {
36 | e := echo.New()
37 | buf := new(bytes.Buffer)
38 | e.Logger.SetOutput(buf)
39 | req := httptest.NewRequest(http.MethodGet, "/", nil)
40 | rec := httptest.NewRecorder()
41 | c := e.NewContext(req, rec)
42 | h := Recover()(echo.HandlerFunc(func(c echo.Context) error {
43 | panic(http.ErrAbortHandler)
44 | }))
45 | defer func() {
46 | r := recover()
47 | if r == nil {
48 | assert.Fail(t, "expecting `http.ErrAbortHandler`, got `nil`")
49 | } else {
50 | if err, ok := r.(error); ok {
51 | assert.ErrorIs(t, err, http.ErrAbortHandler)
52 | } else {
53 | assert.Fail(t, "not of error type")
54 | }
55 | }
56 | }()
57 |
58 | h(c)
59 |
60 | assert.Equal(t, http.StatusInternalServerError, rec.Code)
61 | assert.NotContains(t, buf.String(), "PANIC RECOVER")
62 | }
63 |
64 | func TestRecoverWithConfig_LogLevel(t *testing.T) {
65 | tests := []struct {
66 | logLevel log.Lvl
67 | levelName string
68 | }{{
69 | logLevel: log.DEBUG,
70 | levelName: "DEBUG",
71 | }, {
72 | logLevel: log.INFO,
73 | levelName: "INFO",
74 | }, {
75 | logLevel: log.WARN,
76 | levelName: "WARN",
77 | }, {
78 | logLevel: log.ERROR,
79 | levelName: "ERROR",
80 | }, {
81 | logLevel: log.OFF,
82 | levelName: "OFF",
83 | }}
84 |
85 | for _, tt := range tests {
86 | tt := tt
87 | t.Run(tt.levelName, func(t *testing.T) {
88 | e := echo.New()
89 | e.Logger.SetLevel(log.DEBUG)
90 |
91 | buf := new(bytes.Buffer)
92 | e.Logger.SetOutput(buf)
93 |
94 | req := httptest.NewRequest(http.MethodGet, "/", nil)
95 | rec := httptest.NewRecorder()
96 | c := e.NewContext(req, rec)
97 |
98 | config := DefaultRecoverConfig
99 | config.LogLevel = tt.logLevel
100 | h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error {
101 | panic("test")
102 | }))
103 |
104 | h(c)
105 |
106 | assert.Equal(t, http.StatusInternalServerError, rec.Code)
107 |
108 | output := buf.String()
109 | if tt.logLevel == log.OFF {
110 | assert.Empty(t, output)
111 | } else {
112 | assert.Contains(t, output, "PANIC RECOVER")
113 | assert.Contains(t, output, fmt.Sprintf(`"level":"%s"`, tt.levelName))
114 | }
115 | })
116 | }
117 | }
118 |
119 | func TestRecoverWithConfig_LogErrorFunc(t *testing.T) {
120 | e := echo.New()
121 | e.Logger.SetLevel(log.DEBUG)
122 |
123 | buf := new(bytes.Buffer)
124 | e.Logger.SetOutput(buf)
125 |
126 | req := httptest.NewRequest(http.MethodGet, "/", nil)
127 | rec := httptest.NewRecorder()
128 | c := e.NewContext(req, rec)
129 |
130 | testError := errors.New("test")
131 | config := DefaultRecoverConfig
132 | config.LogErrorFunc = func(c echo.Context, err error, stack []byte) error {
133 | msg := fmt.Sprintf("[PANIC RECOVER] %v %s\n", err, stack)
134 | if errors.Is(err, testError) {
135 | c.Logger().Debug(msg)
136 | } else {
137 | c.Logger().Error(msg)
138 | }
139 | return err
140 | }
141 |
142 | t.Run("first branch case for LogErrorFunc", func(t *testing.T) {
143 | buf.Reset()
144 | h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error {
145 | panic(testError)
146 | }))
147 |
148 | h(c)
149 | assert.Equal(t, http.StatusInternalServerError, rec.Code)
150 |
151 | output := buf.String()
152 | assert.Contains(t, output, "PANIC RECOVER")
153 | assert.Contains(t, output, `"level":"DEBUG"`)
154 | })
155 |
156 | t.Run("else branch case for LogErrorFunc", func(t *testing.T) {
157 | buf.Reset()
158 | h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error {
159 | panic("other")
160 | }))
161 |
162 | h(c)
163 | assert.Equal(t, http.StatusInternalServerError, rec.Code)
164 |
165 | output := buf.String()
166 | assert.Contains(t, output, "PANIC RECOVER")
167 | assert.Contains(t, output, `"level":"ERROR"`)
168 | })
169 | }
170 |
171 | func TestRecoverWithDisabled_ErrorHandler(t *testing.T) {
172 | e := echo.New()
173 | buf := new(bytes.Buffer)
174 | e.Logger.SetOutput(buf)
175 | req := httptest.NewRequest(http.MethodGet, "/", nil)
176 | rec := httptest.NewRecorder()
177 | c := e.NewContext(req, rec)
178 |
179 | config := DefaultRecoverConfig
180 | config.DisableErrorHandler = true
181 | h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error {
182 | panic("test")
183 | }))
184 | err := h(c)
185 |
186 | assert.Equal(t, http.StatusOK, rec.Code)
187 | assert.Contains(t, buf.String(), "PANIC RECOVER")
188 | assert.EqualError(t, err, "test")
189 | }
190 |
--------------------------------------------------------------------------------
/middleware/redirect.go:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
3 |
4 | package middleware
5 |
6 | import (
7 | "net/http"
8 | "strings"
9 |
10 | "github.com/labstack/echo/v4"
11 | )
12 |
// RedirectConfig defines the config for Redirect middleware.
type RedirectConfig struct {
	// Skipper defines a function to skip middleware.
	// Optional. A nil Skipper is replaced with DefaultRedirectConfig.Skipper.
	Skipper

	// Status code to be used when redirecting the request.
	// Optional. Default value http.StatusMovedPermanently (used when Code is 0).
	Code int `yaml:"code"`
}
22 |
// redirectLogic represents a function that, given a scheme, host and request
// URI, both 1) decides whether a redirect is needed (reported through ok) and
// 2) returns the URL to redirect to when it is.
type redirectLogic func(scheme, host, uri string) (ok bool, url string)

// www is the host prefix that the www-related redirects check for.
const www = "www."

// DefaultRedirectConfig is the default Redirect middleware config.
// It uses DefaultSkipper and answers with 301 Moved Permanently.
var DefaultRedirectConfig = RedirectConfig{
	Skipper: DefaultSkipper,
	Code:    http.StatusMovedPermanently,
}
35 |
36 | // HTTPSRedirect redirects http requests to https.
37 | // For example, http://labstack.com will be redirect to https://labstack.com.
38 | //
39 | // Usage `Echo#Pre(HTTPSRedirect())`
40 | func HTTPSRedirect() echo.MiddlewareFunc {
41 | return HTTPSRedirectWithConfig(DefaultRedirectConfig)
42 | }
43 |
44 | // HTTPSRedirectWithConfig returns an HTTPSRedirect middleware with config.
45 | // See `HTTPSRedirect()`.
46 | func HTTPSRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
47 | return redirect(config, func(scheme, host, uri string) (bool, string) {
48 | if scheme != "https" {
49 | return true, "https://" + host + uri
50 | }
51 | return false, ""
52 | })
53 | }
54 |
55 | // HTTPSWWWRedirect redirects http requests to https www.
56 | // For example, http://labstack.com will be redirect to https://www.labstack.com.
57 | //
58 | // Usage `Echo#Pre(HTTPSWWWRedirect())`
59 | func HTTPSWWWRedirect() echo.MiddlewareFunc {
60 | return HTTPSWWWRedirectWithConfig(DefaultRedirectConfig)
61 | }
62 |
63 | // HTTPSWWWRedirectWithConfig returns an HTTPSRedirect middleware with config.
64 | // See `HTTPSWWWRedirect()`.
65 | func HTTPSWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
66 | return redirect(config, func(scheme, host, uri string) (bool, string) {
67 | if scheme != "https" && !strings.HasPrefix(host, www) {
68 | return true, "https://www." + host + uri
69 | }
70 | return false, ""
71 | })
72 | }
73 |
74 | // HTTPSNonWWWRedirect redirects http requests to https non www.
75 | // For example, http://www.labstack.com will be redirect to https://labstack.com.
76 | //
77 | // Usage `Echo#Pre(HTTPSNonWWWRedirect())`
78 | func HTTPSNonWWWRedirect() echo.MiddlewareFunc {
79 | return HTTPSNonWWWRedirectWithConfig(DefaultRedirectConfig)
80 | }
81 |
82 | // HTTPSNonWWWRedirectWithConfig returns an HTTPSRedirect middleware with config.
83 | // See `HTTPSNonWWWRedirect()`.
84 | func HTTPSNonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
85 | return redirect(config, func(scheme, host, uri string) (ok bool, url string) {
86 | if scheme != "https" {
87 | host = strings.TrimPrefix(host, www)
88 | return true, "https://" + host + uri
89 | }
90 | return false, ""
91 | })
92 | }
93 |
94 | // WWWRedirect redirects non www requests to www.
95 | // For example, http://labstack.com will be redirect to http://www.labstack.com.
96 | //
97 | // Usage `Echo#Pre(WWWRedirect())`
98 | func WWWRedirect() echo.MiddlewareFunc {
99 | return WWWRedirectWithConfig(DefaultRedirectConfig)
100 | }
101 |
102 | // WWWRedirectWithConfig returns an HTTPSRedirect middleware with config.
103 | // See `WWWRedirect()`.
104 | func WWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
105 | return redirect(config, func(scheme, host, uri string) (bool, string) {
106 | if !strings.HasPrefix(host, www) {
107 | return true, scheme + "://www." + host + uri
108 | }
109 | return false, ""
110 | })
111 | }
112 |
113 | // NonWWWRedirect redirects www requests to non www.
114 | // For example, http://www.labstack.com will be redirect to http://labstack.com.
115 | //
116 | // Usage `Echo#Pre(NonWWWRedirect())`
117 | func NonWWWRedirect() echo.MiddlewareFunc {
118 | return NonWWWRedirectWithConfig(DefaultRedirectConfig)
119 | }
120 |
121 | // NonWWWRedirectWithConfig returns an HTTPSRedirect middleware with config.
122 | // See `NonWWWRedirect()`.
123 | func NonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
124 | return redirect(config, func(scheme, host, uri string) (bool, string) {
125 | if strings.HasPrefix(host, www) {
126 | return true, scheme + "://" + host[4:] + uri
127 | }
128 | return false, ""
129 | })
130 | }
131 |
132 | func redirect(config RedirectConfig, cb redirectLogic) echo.MiddlewareFunc {
133 | if config.Skipper == nil {
134 | config.Skipper = DefaultRedirectConfig.Skipper
135 | }
136 | if config.Code == 0 {
137 | config.Code = DefaultRedirectConfig.Code
138 | }
139 |
140 | return func(next echo.HandlerFunc) echo.HandlerFunc {
141 | return func(c echo.Context) error {
142 | if config.Skipper(c) {
143 | return next(c)
144 | }
145 |
146 | req, scheme := c.Request(), c.Scheme()
147 | host := req.Host
148 | if ok, url := cb(scheme, host, req.RequestURI); ok {
149 | return c.Redirect(config.Code, url)
150 | }
151 |
152 | return next(c)
153 | }
154 | }
155 | }
156 |
--------------------------------------------------------------------------------
/middleware/request_id.go:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
3 |
4 | package middleware
5 |
6 | import (
7 | "github.com/labstack/echo/v4"
8 | )
9 |
// RequestIDConfig defines the config for RequestID middleware.
type RequestIDConfig struct {
	// Skipper defines a function to skip middleware.
	// Optional. A nil Skipper is replaced with DefaultRequestIDConfig.Skipper.
	Skipper Skipper

	// Generator defines a function to generate an ID. It is only called when
	// the request carries no (or an empty) TargetHeader value.
	// Optional. Defaults to generator for random string of length 32.
	Generator func() string

	// RequestIDHandler defines a function which is executed for a request id.
	// Optional. When set, it is called with the context and the resolved id
	// before the next handler runs.
	RequestIDHandler func(echo.Context, string)

	// TargetHeader defines what header to look for to populate the id; the
	// resolved id is written to the same header on the response.
	// Optional. Defaults to echo.HeaderXRequestID.
	TargetHeader string
}
25 |
26 | // DefaultRequestIDConfig is the default RequestID middleware config.
27 | var DefaultRequestIDConfig = RequestIDConfig{
28 | Skipper: DefaultSkipper,
29 | Generator: generator,
30 | TargetHeader: echo.HeaderXRequestID,
31 | }
32 |
33 | // RequestID returns a X-Request-ID middleware.
34 | func RequestID() echo.MiddlewareFunc {
35 | return RequestIDWithConfig(DefaultRequestIDConfig)
36 | }
37 |
38 | // RequestIDWithConfig returns a X-Request-ID middleware with config.
39 | func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc {
40 | // Defaults
41 | if config.Skipper == nil {
42 | config.Skipper = DefaultRequestIDConfig.Skipper
43 | }
44 | if config.Generator == nil {
45 | config.Generator = generator
46 | }
47 | if config.TargetHeader == "" {
48 | config.TargetHeader = echo.HeaderXRequestID
49 | }
50 |
51 | return func(next echo.HandlerFunc) echo.HandlerFunc {
52 | return func(c echo.Context) error {
53 | if config.Skipper(c) {
54 | return next(c)
55 | }
56 |
57 | req := c.Request()
58 | res := c.Response()
59 | rid := req.Header.Get(config.TargetHeader)
60 | if rid == "" {
61 | rid = config.Generator()
62 | }
63 | res.Header().Set(config.TargetHeader, rid)
64 | if config.RequestIDHandler != nil {
65 | config.RequestIDHandler(c, rid)
66 | }
67 |
68 | return next(c)
69 | }
70 | }
71 | }
72 |
73 | func generator() string {
74 | return randomString(32)
75 | }
76 |
--------------------------------------------------------------------------------
/middleware/request_id_test.go:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
3 |
4 | package middleware
5 |
6 | import (
7 | "net/http"
8 | "net/http/httptest"
9 | "testing"
10 |
11 | "github.com/labstack/echo/v4"
12 | "github.com/stretchr/testify/assert"
13 | )
14 |
15 | func TestRequestID(t *testing.T) {
16 | e := echo.New()
17 | req := httptest.NewRequest(http.MethodGet, "/", nil)
18 | rec := httptest.NewRecorder()
19 | c := e.NewContext(req, rec)
20 | handler := func(c echo.Context) error {
21 | return c.String(http.StatusOK, "test")
22 | }
23 |
24 | rid := RequestIDWithConfig(RequestIDConfig{})
25 | h := rid(handler)
26 | h(c)
27 | assert.Len(t, rec.Header().Get(echo.HeaderXRequestID), 32)
28 |
29 | // Custom generator and handler
30 | customID := "customGenerator"
31 | calledHandler := false
32 | rid = RequestIDWithConfig(RequestIDConfig{
33 | Generator: func() string { return customID },
34 | RequestIDHandler: func(_ echo.Context, id string) {
35 | calledHandler = true
36 | assert.Equal(t, customID, id)
37 | },
38 | })
39 | h = rid(handler)
40 | h(c)
41 | assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator")
42 | assert.True(t, calledHandler)
43 | }
44 |
45 | func TestRequestID_IDNotAltered(t *testing.T) {
46 | e := echo.New()
47 | req := httptest.NewRequest(http.MethodGet, "/", nil)
48 | req.Header.Add(echo.HeaderXRequestID, "")
49 |
50 | rec := httptest.NewRecorder()
51 | c := e.NewContext(req, rec)
52 | handler := func(c echo.Context) error {
53 | return c.String(http.StatusOK, "test")
54 | }
55 |
56 | rid := RequestIDWithConfig(RequestIDConfig{})
57 | h := rid(handler)
58 | _ = h(c)
59 | assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "")
60 | }
61 |
62 | func TestRequestIDConfigDifferentHeader(t *testing.T) {
63 | e := echo.New()
64 | req := httptest.NewRequest(http.MethodGet, "/", nil)
65 | rec := httptest.NewRecorder()
66 | c := e.NewContext(req, rec)
67 | handler := func(c echo.Context) error {
68 | return c.String(http.StatusOK, "test")
69 | }
70 |
71 | rid := RequestIDWithConfig(RequestIDConfig{TargetHeader: echo.HeaderXCorrelationID})
72 | h := rid(handler)
73 | h(c)
74 | assert.Len(t, rec.Header().Get(echo.HeaderXCorrelationID), 32)
75 |
76 | // Custom generator and handler
77 | customID := "customGenerator"
78 | calledHandler := false
79 | rid = RequestIDWithConfig(RequestIDConfig{
80 | Generator: func() string { return customID },
81 | TargetHeader: echo.HeaderXCorrelationID,
82 | RequestIDHandler: func(_ echo.Context, id string) {
83 | calledHandler = true
84 | assert.Equal(t, customID, id)
85 | },
86 | })
87 | h = rid(handler)
88 | h(c)
89 | assert.Equal(t, rec.Header().Get(echo.HeaderXCorrelationID), "customGenerator")
90 | assert.True(t, calledHandler)
91 | }
92 |
--------------------------------------------------------------------------------
/middleware/rewrite.go:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
3 |
4 | package middleware
5 |
6 | import (
7 | "regexp"
8 |
9 | "github.com/labstack/echo/v4"
10 | )
11 |
// RewriteConfig defines the config for Rewrite middleware.
type RewriteConfig struct {
	// Skipper defines a function to skip middleware.
	// Optional.
	Skipper Skipper

	// Rules defines the URL path rewrite rules. The values captured in asterisk can be
	// retrieved by index e.g. $1, $2 and so on.
	// Example:
	// "/old": "/new",
	// "/api/*": "/$1",
	// "/js/*": "/public/javascripts/$1",
	// "/users/*/orders/*": "/user/$1/order/$2",
	// Required unless RegexRules is set.
	Rules map[string]string `yaml:"rules"`

	// RegexRules defines the URL path rewrite rules using regexp.Regexp with captures.
	// Every capture group in the values can be retrieved by index e.g. $1, $2 and so on.
	// Example:
	// "^/old/[0-9]+/": "/new",
	// "^/api/.+?/(.*)": "/v2/$1",
	// Required unless Rules is set. Rules, when given, are converted to regular
	// expressions and combined with these entries when the middleware is built.
	RegexRules map[*regexp.Regexp]string `yaml:"-"`
}
34 |
35 | // DefaultRewriteConfig is the default Rewrite middleware config.
36 | var DefaultRewriteConfig = RewriteConfig{
37 | Skipper: DefaultSkipper,
38 | }
39 |
40 | // Rewrite returns a Rewrite middleware.
41 | //
42 | // Rewrite middleware rewrites the URL path based on the provided rules.
43 | func Rewrite(rules map[string]string) echo.MiddlewareFunc {
44 | c := DefaultRewriteConfig
45 | c.Rules = rules
46 | return RewriteWithConfig(c)
47 | }
48 |
49 | // RewriteWithConfig returns a Rewrite middleware with config.
50 | // See: `Rewrite()`.
51 | func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc {
52 | // Defaults
53 | if config.Rules == nil && config.RegexRules == nil {
54 | panic("echo: rewrite middleware requires url path rewrite rules or regex rules")
55 | }
56 |
57 | if config.Skipper == nil {
58 | config.Skipper = DefaultBodyDumpConfig.Skipper
59 | }
60 |
61 | if config.RegexRules == nil {
62 | config.RegexRules = make(map[*regexp.Regexp]string)
63 | }
64 | for k, v := range rewriteRulesRegex(config.Rules) {
65 | config.RegexRules[k] = v
66 | }
67 |
68 | return func(next echo.HandlerFunc) echo.HandlerFunc {
69 | return func(c echo.Context) (err error) {
70 | if config.Skipper(c) {
71 | return next(c)
72 | }
73 |
74 | if err := rewriteURL(config.RegexRules, c.Request()); err != nil {
75 | return err
76 | }
77 | return next(c)
78 | }
79 | }
80 | }
81 |
--------------------------------------------------------------------------------
/middleware/secure.go:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | // SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
3 |
4 | package middleware
5 |
6 | import (
7 | "fmt"
8 |
9 | "github.com/labstack/echo/v4"
10 | )
11 |
12 | // SecureConfig defines the config for Secure middleware.
13 | type SecureConfig struct {
14 | // Skipper defines a function to skip middleware.
15 | Skipper Skipper
16 |
17 | // XSSProtection provides protection against cross-site scripting attack (XSS)
18 | // by setting the `X-XSS-Protection` header.
19 | // Optional. Default value "1; mode=block".
20 | XSSProtection string `yaml:"xss_protection"`
21 |
22 | // ContentTypeNosniff provides protection against overriding Content-Type
23 | // header by setting the `X-Content-Type-Options` header.
24 | // Optional. Default value "nosniff".
25 | ContentTypeNosniff string `yaml:"content_type_nosniff"`
26 |
27 | // XFrameOptions can be used to indicate whether or not a browser should
28 | // be allowed to render a page in a ,