├── .editorconfig ├── .gitattributes ├── .github └── workflows │ ├── checks.yml │ └── echo.yml ├── .gitignore ├── CHANGELOG.md ├── LICENSE ├── Makefile ├── README.md ├── codecov.yml ├── extractors.go ├── extractors_test.go ├── go.mod ├── go.sum ├── jwt.go ├── jwt_benchmark_test.go ├── jwt_external_test.go ├── jwt_integration_test.go └── jwt_test.go /.editorconfig: -------------------------------------------------------------------------------- 1 | # EditorConfig coding styles definitions. For more information about the 2 | # properties used in this file, please see the EditorConfig documentation: 3 | # http://editorconfig.org/ 4 | 5 | # indicate this is the root of the project 6 | root = true 7 | 8 | [*] 9 | charset = utf-8 10 | 11 | end_of_line = LF 12 | insert_final_newline = true 13 | trim_trailing_whitespace = true 14 | 15 | indent_style = space 16 | indent_size = 2 17 | 18 | [Makefile] 19 | indent_style = tab 20 | 21 | [*.md] 22 | trim_trailing_whitespace = false 23 | 24 | [*.go] 25 | indent_style = tab 26 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Automatically normalize line endings for all text-based files 2 | # http://git-scm.com/docs/gitattributes#_end_of_line_conversion 3 | * text=auto 4 | 5 | # For the following file types, normalize line endings to LF on checking and 6 | # prevent conversion to CRLF when they are checked out (this is required in 7 | # order to prevent newline related issues) 8 | .* text eol=lf 9 | *.go text eol=lf 10 | *.yml text eol=lf 11 | *.html text eol=lf 12 | *.css text eol=lf 13 | *.js text eol=lf 14 | *.json text eol=lf 15 | LICENSE text eol=lf 16 | 17 | -------------------------------------------------------------------------------- /.github/workflows/checks.yml: -------------------------------------------------------------------------------- 1 | name: Run checks 2 | 3 | on: 4 | push: 
5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | workflow_dispatch: 11 | 12 | permissions: 13 | contents: read # to fetch code (actions/checkout) 14 | 15 | env: 16 | # run static analysis only with the latest Go version 17 | LATEST_GO_VERSION: "1.24" 18 | 19 | jobs: 20 | check: 21 | runs-on: ubuntu-latest 22 | steps: 23 | - name: Checkout Code 24 | uses: actions/checkout@v4 25 | 26 | - name: Set up Go ${{ env.LATEST_GO_VERSION }} 27 | uses: actions/setup-go@v5 28 | with: 29 | go-version: ${{ env.LATEST_GO_VERSION }} 30 | check-latest: true 31 | 32 | - name: Run golint 33 | run: | 34 | go install golang.org/x/lint/golint@latest 35 | golint -set_exit_status ./... 36 | 37 | - name: Run staticcheck 38 | run: | 39 | go install honnef.co/go/tools/cmd/staticcheck@latest 40 | staticcheck ./... 41 | 42 | - name: Run govulncheck 43 | run: | 44 | go version 45 | go install golang.org/x/vuln/cmd/govulncheck@latest 46 | govulncheck ./... 47 | -------------------------------------------------------------------------------- /.github/workflows/echo.yml: -------------------------------------------------------------------------------- 1 | name: Run Tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | workflow_dispatch: 11 | 12 | permissions: 13 | contents: read # to fetch code (actions/checkout) 14 | 15 | env: 16 | # run coverage and benchmarks only with the latest Go version 17 | LATEST_GO_VERSION: "1.24" 18 | 19 | jobs: 20 | test: 21 | strategy: 22 | matrix: 23 | os: [ubuntu-latest, macos-latest, windows-latest] 24 | # Each major Go release is supported until there are two newer major releases. https://golang.org/doc/devel/release.html#policy 25 | # Echo tests with last four major releases (unless there are pressing vulnerabilities) 26 | # As we depend on `golang.org/x/` libraries which only support last 2 Go releases we could have situations when 27 | # we deviate from the last four major releases promise.
28 | go: ["1.23", "1.24"] 29 | name: ${{ matrix.os }} @ Go ${{ matrix.go }} 30 | runs-on: ${{ matrix.os }} 31 | steps: 32 | - name: Checkout Code 33 | uses: actions/checkout@v4 34 | 35 | - name: Set up Go ${{ matrix.go }} 36 | uses: actions/setup-go@v5 37 | with: 38 | go-version: ${{ matrix.go }} 39 | 40 | - name: Run Tests 41 | run: go test -race --coverprofile=coverage.coverprofile --covermode=atomic ./... 42 | 43 | - name: Upload coverage to Codecov 44 | if: success() && matrix.go == env.LATEST_GO_VERSION && matrix.os == 'ubuntu-latest' 45 | uses: codecov/codecov-action@v3 46 | with: 47 | token: 48 | fail_ci_if_error: false 49 | 50 | benchmark: 51 | needs: test 52 | name: Benchmark comparison 53 | runs-on: ubuntu-latest 54 | steps: 55 | - name: Checkout Code (Previous) 56 | uses: actions/checkout@v4 57 | with: 58 | ref: ${{ github.base_ref }} 59 | path: previous 60 | 61 | - name: Checkout Code (New) 62 | uses: actions/checkout@v4 63 | with: 64 | path: new 65 | 66 | - name: Set up Go ${{ env.LATEST_GO_VERSION }} 67 | uses: actions/setup-go@v5 68 | with: 69 | go-version: ${{ env.LATEST_GO_VERSION }} 70 | 71 | - name: Install Dependencies 72 | run: go install golang.org/x/perf/cmd/benchstat@latest 73 | 74 | - name: Run Benchmark (Previous) 75 | run: | 76 | cd previous 77 | go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt 78 | 79 | - name: Run Benchmark (New) 80 | run: | 81 | cd new 82 | go test -run="-" -bench=".*" -count=8 ./...
> benchmark.txt 83 | 84 | - name: Run Benchstat 85 | run: | 86 | benchstat previous/benchmark.txt new/benchmark.txt 87 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | coverage.txt 3 | _test 4 | vendor 5 | .idea 6 | *.iml 7 | *.out 8 | .vscode 9 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | ## v4.3.1 - 2025-03-22 4 | 5 | **Security** 6 | 7 | * update JWT dependencies (https://github.com/advisories/GHSA-mh63-6h87-95cp) by @aldas in [#31](https://github.com/labstack/echo-jwt/pull/31) 8 | 9 | 10 | ## v4.3.0 - 2024-12-04 11 | 12 | **Enhancements** 13 | 14 | * Update Echo dependency to v4.13.0 by @aldas in [#28](https://github.com/labstack/echo-jwt/pull/28) 15 | 16 | 17 | ## v4.2.1 - 2024-12-04 18 | 19 | **Enhancements** 20 | 21 | * Return HTTP status 400 if missing JWT by @kitloong in [#13](https://github.com/labstack/echo-jwt/pull/13) 22 | * Update dependencies and CI flow by @aldas in [#21](https://github.com/labstack/echo-jwt/pull/21), [#24](https://github.com/labstack/echo-jwt/pull/24), [#27](https://github.com/labstack/echo-jwt/pull/27) 23 | * Improve readme formatting by @aldas in [#25](https://github.com/labstack/echo-jwt/pull/25) 24 | 25 | 26 | ## v4.2.0 - 2023-01-26 27 | 28 | **Breaking change:** [JWT](https://github.com/golang-jwt/jwt) has been upgraded to `v5`. Check/test all your code involved with JWT tokens/claims.
Search for `github.com/golang-jwt/jwt/v4` 29 | and replace it with `github.com/golang-jwt/jwt/v5` 30 | 31 | **Enhancements** 32 | 33 | * Upgrade `golang-jwt/jwt` library to `v5` [#9](https://github.com/labstack/echo-jwt/pull/9) 34 | 35 | 36 | ## v4.1.0 - 2023-01-26 37 | 38 | **Enhancements** 39 | 40 | * Add TokenExtractionError and TokenParsingError types to help distinguishing error source in ErrorHandler [#6](https://github.com/labstack/echo-jwt/pull/6) 41 | 42 | 43 | ## v4.0.1 - 2023-01-24 44 | 45 | **Fixes** 46 | 47 | * Fix data race in error path [#4](https://github.com/labstack/echo-jwt/pull/4) 48 | 49 | 50 | **Enhancements** 51 | 52 | * add TokenError as error returned when parsing fails [#3](https://github.com/labstack/echo-jwt/pull/3) 53 | 54 | 55 | ## v4.0.0 - 2022-12-27 56 | 57 | * First release 58 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016 LabStack and Echo contributors 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | PKG := "github.com/labstack/echo-jwt" 2 | PKG_LIST := $(shell go list ${PKG}/...) 3 | 4 | .DEFAULT_GOAL := check 5 | check: lint vet race ## Check project 6 | 7 | 8 | init: 9 | @go install golang.org/x/lint/golint@latest 10 | @go install honnef.co/go/tools/cmd/staticcheck@latest 11 | 12 | lint: ## Lint the files 13 | @staticcheck ${PKG_LIST} 14 | @golint -set_exit_status ${PKG_LIST} 15 | 16 | vet: ## Vet the files 17 | @go vet ${PKG_LIST} 18 | 19 | test: ## Run tests 20 | @go test -short ${PKG_LIST} 21 | 22 | race: ## Run tests with data race detector 23 | @go test -race ${PKG_LIST} 24 | 25 | benchmark: ## Run benchmarks 26 | @go test -run="-" -bench=".*" ${PKG_LIST} 27 | 28 | format: ## Format the source code 29 | @find ./ -type f -name "*.go" -exec gofmt -w {} \; 30 | 31 | help: ## Display this help screen 32 | @grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' 33 | 34 | goversion ?= "1.20" 35 | test_version: ## Run tests inside Docker with given version (defaults to 1.20 oldest supported). 
Example: make test_version goversion=1.20 36 | @docker run --rm -it -v $(shell pwd):/project golang:$(goversion) /bin/sh -c "cd /project && make race" 37 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Sourcegraph](https://sourcegraph.com/github.com/labstack/echo-jwt/-/badge.svg?style=flat-square)](https://sourcegraph.com/github.com/labstack/echo-jwt?badge) 2 | [![GoDoc](http://img.shields.io/badge/go-documentation-blue.svg?style=flat-square)](https://pkg.go.dev/github.com/labstack/echo-jwt/v4) 3 | [![Go Report Card](https://goreportcard.com/badge/github.com/labstack/echo-jwt?style=flat-square)](https://goreportcard.com/report/github.com/labstack/echo-jwt) 4 | [![Codecov](https://img.shields.io/codecov/c/github/labstack/echo-jwt.svg?style=flat-square)](https://codecov.io/gh/labstack/echo-jwt) 5 | [![License](http://img.shields.io/badge/license-mit-blue.svg?style=flat-square)](https://raw.githubusercontent.com/labstack/echo-jwt/master/LICENSE) 6 | 7 | # Echo JWT middleware 8 | 9 | JWT middleware for [Echo](https://github.com/labstack/echo) framework. This middleware uses by default [golang-jwt/jwt/v5](https://github.com/golang-jwt/jwt) 10 | as JWT implementation. 11 | 12 | ## Versioning 13 | 14 | This repository does not use semantic versioning. MAJOR version tracks which Echo version should be used. MINOR version 15 | tracks API changes (possibly backwards incompatible) and PATCH version is incremented for fixes. 16 | 17 | NB: When `golang-jwt` MAJOR version changes this library will release MINOR version with **breaking change**. Always 18 | add at least one integration test in your project. 19 | 20 | For Echo `v4` use `v4.x.y` releases. 21 | Minimal needed Echo versions: 22 | * `v4.0.0` needs Echo `v4.7.0+` 23 | 24 | `main` branch is compatible with the latest Echo version. 
25 | 26 | ## Usage 27 | 28 | Add JWT middleware dependency with go modules 29 | ```bash 30 | go get github.com/labstack/echo-jwt/v4 31 | ``` 32 | 33 | Use as import statement 34 | ```go 35 | import "github.com/labstack/echo-jwt/v4" 36 | ``` 37 | 38 | Add middleware in simplified form, by providing only the secret key 39 | ```go 40 | e.Use(echojwt.JWT([]byte("secret"))) 41 | ``` 42 | 43 | Add middleware with configuration options 44 | ```go 45 | e.Use(echojwt.WithConfig(echojwt.Config{ 46 | // ... 47 | SigningKey: []byte("secret"), 48 | // ... 49 | })) 50 | ``` 51 | 52 | Extract token in handler 53 | ```go 54 | import "github.com/golang-jwt/jwt/v5" 55 | 56 | // ... 57 | 58 | e.GET("/", func(c echo.Context) error { 59 | token, ok := c.Get("user").(*jwt.Token) // by default token is stored under `user` key 60 | if !ok { 61 | return errors.New("JWT token missing or invalid") 62 | } 63 | claims, ok := token.Claims.(jwt.MapClaims) // by default claims is of type `jwt.MapClaims` 64 | if !ok { 65 | return errors.New("failed to cast claims as jwt.MapClaims") 66 | } 67 | return c.JSON(http.StatusOK, claims) 68 | }) 69 | ``` 70 | 71 | ## IMPORTANT: Integration Testing with JWT Library 72 | 73 | Ensure that your project includes at least one integration test to detect changes in major versions of the `golang-jwt/jwt` library early. 74 | This is crucial because type assertions like `token := c.Get("user").(*jwt.Token)` may fail silently if the imported version of the JWT library (e.g., `import "github.com/golang-jwt/jwt/v5"`) differs from the version used internally by dependencies (e.g., echo-jwt may now use `v6`). Such discrepancies can lead to invalid casts, causing your handlers to panic or throw errors. Integration tests help safeguard against these version mismatches. 
75 | 76 | ```go 77 | func TestIntegrationMiddlewareWithHandler(t *testing.T) { 78 | e := echo.New() 79 | e.Use(echojwt.WithConfig(echojwt.Config{ 80 | SigningKey: []byte("secret"), 81 | })) 82 | 83 | // use handler that gets token from context to fail your CI flow when `golang-jwt/jwt` library version changes 84 | // a) `token, ok := c.Get("user").(*jwt.Token)` 85 | // b) `token := c.Get("user").(*jwt.Token)` 86 | e.GET("/example", exampleHandler) 87 | 88 | req := httptest.NewRequest(http.MethodGet, "/example", nil) 89 | req.Header.Set(echo.HeaderAuthorization, "Bearer <Token>") 90 | res := httptest.NewRecorder() 91 | 92 | e.ServeHTTP(res, req) 93 | 94 | if res.Code != 200 { 95 | t.Fail() 96 | } 97 | } 98 | ``` 99 | 100 | 101 | ## Full example 102 | 103 | ```go 104 | package main 105 | 106 | import ( 107 | "errors" 108 | "github.com/golang-jwt/jwt/v5" 109 | "github.com/labstack/echo-jwt/v4" 110 | "github.com/labstack/echo/v4" 111 | "github.com/labstack/echo/v4/middleware" 112 | "log" 113 | "net/http" 114 | ) 115 | 116 | func main() { 117 | e := echo.New() 118 | e.Use(middleware.Logger()) 119 | e.Use(middleware.Recover()) 120 | 121 | e.Use(echojwt.WithConfig(echojwt.Config{ 122 | SigningKey: []byte("secret"), 123 | })) 124 | 125 | e.GET("/", func(c echo.Context) error { 126 | token, ok := c.Get("user").(*jwt.Token) // by default token is stored under `user` key 127 | if !ok { 128 | return errors.New("JWT token missing or invalid") 129 | } 130 | claims, ok := token.Claims.(jwt.MapClaims) // by default claims is of type `jwt.MapClaims` 131 | if !ok { 132 | return errors.New("failed to cast claims as jwt.MapClaims") 133 | } 134 | return c.JSON(http.StatusOK, claims) 135 | }) 136 | 137 | if err := e.Start(":8080"); err != http.ErrServerClosed { 138 | log.Fatal(err) 139 | } 140 | } 141 | ``` 142 | 143 | Test with 144 | ```bash 145 | curl -v -H "Authorization: Bearer
eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ" http://localhost:8080 146 | ``` 147 | 148 | Output should be 149 | ```bash 150 | * Trying 127.0.0.1:8080... 151 | * Connected to localhost (127.0.0.1) port 8080 (#0) 152 | > GET / HTTP/1.1 153 | > Host: localhost:8080 154 | > User-Agent: curl/7.81.0 155 | > Accept: */* 156 | > Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ 157 | > 158 | * Mark bundle as not supporting multiuse 159 | < HTTP/1.1 200 OK 160 | < Content-Type: application/json; charset=UTF-8 161 | < Date: Sun, 27 Nov 2022 21:34:17 GMT 162 | < Content-Length: 52 163 | < 164 | {"admin":true,"name":"John Doe","sub":"1234567890"} 165 | ``` 166 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | project: 4 | default: 5 | threshold: 1.0% 6 | patch: 7 | default: 8 | threshold: 1.0% 9 | 10 | comment: 11 | require_changes: true 12 | -------------------------------------------------------------------------------- /extractors.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: © 2016 LabStack and Echo contributors 3 | 4 | package echojwt 5 | 6 | import ( 7 | "errors" 8 | "fmt" 9 | "github.com/labstack/echo/v4" 10 | "github.com/labstack/echo/v4/middleware" 11 | "net/textproto" 12 | "strings" 13 | ) 14 | 15 | const ( 16 | // extractorLimit is arbitrary number to limit values extractor can return. 
this limits possible resource exhaustion 17 | // attack vector 18 | extractorLimit = 20 19 | ) 20 | 21 | // TokenExtractionError is catch all type for all errors that occur when the token is extracted from the request. This 22 | // helps to distinguish extractor errors from token parsing errors even if custom extractors or token parsing functions 23 | // are being used that have their own custom errors. 24 | type TokenExtractionError struct { 25 | Err error 26 | } 27 | 28 | // Is checks if target error is same as TokenExtractionError 29 | func (e TokenExtractionError) Is(target error) bool { return target == ErrJWTMissing } // to provide some compatibility with older error handling logic 30 | 31 | func (e *TokenExtractionError) Error() string { return e.Err.Error() } 32 | func (e *TokenExtractionError) Unwrap() error { return e.Err } 33 | 34 | var errHeaderExtractorValueMissing = errors.New("missing value in request header") 35 | var errHeaderExtractorValueInvalid = errors.New("invalid value in request header") 36 | var errQueryExtractorValueMissing = errors.New("missing value in the query string") 37 | var errParamExtractorValueMissing = errors.New("missing value in path params") 38 | var errCookieExtractorValueMissing = errors.New("missing value in cookies") 39 | var errFormExtractorValueMissing = errors.New("missing value in the form") 40 | 41 | // CreateExtractors creates ValuesExtractors from given lookups. 42 | // Lookups is a string in the form of ":" or ":,:" that is used 43 | // to extract key from the request. 44 | // Possible values: 45 | // - "header:" or "header::" 46 | // `` is argument value to cut/trim prefix of the extracted value. This is useful if header 47 | // value has static prefix like `Authorization: ` where part that we 48 | // want to cut is ` ` note the space at the end. 49 | // In case of basic authentication `Authorization: Basic ` prefix we want to remove is `Basic `. 
50 | // - "query:" 51 | // - "param:" 52 | // - "form:" 53 | // - "cookie:" 54 | // 55 | // Multiple sources example: 56 | // - "header:Authorization,header:X-Api-Key" 57 | func CreateExtractors(lookups string) ([]middleware.ValuesExtractor, error) { 58 | if lookups == "" { 59 | return nil, nil 60 | } 61 | sources := strings.Split(lookups, ",") 62 | var extractors = make([]middleware.ValuesExtractor, 0) 63 | for _, source := range sources { 64 | parts := strings.Split(source, ":") 65 | if len(parts) < 2 { 66 | return nil, fmt.Errorf("extractor source for lookup could not be split into needed parts: %v", source) 67 | } 68 | 69 | switch parts[0] { 70 | case "query": 71 | extractors = append(extractors, valuesFromQuery(parts[1])) 72 | case "param": 73 | extractors = append(extractors, valuesFromParam(parts[1])) 74 | case "cookie": 75 | extractors = append(extractors, valuesFromCookie(parts[1])) 76 | case "form": 77 | extractors = append(extractors, valuesFromForm(parts[1])) 78 | case "header": 79 | prefix := "" 80 | if len(parts) > 2 { 81 | prefix = parts[2] 82 | } 83 | extractors = append(extractors, valuesFromHeader(parts[1], prefix)) 84 | } 85 | } 86 | return extractors, nil 87 | } 88 | 89 | // valuesFromHeader returns a functions that extracts values from the request header. 90 | // valuePrefix is parameter to remove first part (prefix) of the extracted value. This is useful if header value has static 91 | // prefix like `Authorization: ` where part that we want to remove is ` ` 92 | // note the space at the end. In case of basic authentication `Authorization: Basic ` prefix we want to remove 93 | // is `Basic `. In case of JWT tokens `Authorization: Bearer ` prefix is `Bearer `. 94 | // If prefix is left empty the whole value is returned. 
95 | func valuesFromHeader(header string, valuePrefix string) middleware.ValuesExtractor { 96 | prefixLen := len(valuePrefix) 97 | // standard library parses http.Request header keys in canonical form but we may provide something else so fix this 98 | header = textproto.CanonicalMIMEHeaderKey(header) 99 | return func(c echo.Context) ([]string, error) { 100 | values := c.Request().Header.Values(header) 101 | if len(values) == 0 { 102 | return nil, errHeaderExtractorValueMissing 103 | } 104 | 105 | result := make([]string, 0) 106 | for i, value := range values { 107 | if prefixLen == 0 { 108 | result = append(result, value) 109 | if i >= extractorLimit-1 { 110 | break 111 | } 112 | continue 113 | } 114 | if len(value) > prefixLen && strings.EqualFold(value[:prefixLen], valuePrefix) { 115 | result = append(result, value[prefixLen:]) 116 | if i >= extractorLimit-1 { 117 | break 118 | } 119 | } 120 | } 121 | 122 | if len(result) == 0 { 123 | if prefixLen > 0 { 124 | return nil, errHeaderExtractorValueInvalid 125 | } 126 | return nil, errHeaderExtractorValueMissing 127 | } 128 | return result, nil 129 | } 130 | } 131 | 132 | // valuesFromQuery returns a function that extracts values from the query string. 133 | func valuesFromQuery(param string) middleware.ValuesExtractor { 134 | return func(c echo.Context) ([]string, error) { 135 | result := c.QueryParams()[param] 136 | if len(result) == 0 { 137 | return nil, errQueryExtractorValueMissing 138 | } else if len(result) > extractorLimit-1 { 139 | result = result[:extractorLimit] 140 | } 141 | return result, nil 142 | } 143 | } 144 | 145 | // valuesFromParam returns a function that extracts values from the url param string. 
146 | func valuesFromParam(param string) middleware.ValuesExtractor { 147 | return func(c echo.Context) ([]string, error) { 148 | result := make([]string, 0) 149 | paramVales := c.ParamValues() 150 | for i, p := range c.ParamNames() { 151 | if param == p { 152 | result = append(result, paramVales[i]) 153 | if i >= extractorLimit-1 { 154 | break 155 | } 156 | } 157 | } 158 | if len(result) == 0 { 159 | return nil, errParamExtractorValueMissing 160 | } 161 | return result, nil 162 | } 163 | } 164 | 165 | // valuesFromCookie returns a function that extracts values from the named cookie. 166 | func valuesFromCookie(name string) middleware.ValuesExtractor { 167 | return func(c echo.Context) ([]string, error) { 168 | cookies := c.Cookies() 169 | if len(cookies) == 0 { 170 | return nil, errCookieExtractorValueMissing 171 | } 172 | 173 | result := make([]string, 0) 174 | for i, cookie := range cookies { 175 | if name == cookie.Name { 176 | result = append(result, cookie.Value) 177 | if i >= extractorLimit-1 { 178 | break 179 | } 180 | } 181 | } 182 | if len(result) == 0 { 183 | return nil, errCookieExtractorValueMissing 184 | } 185 | return result, nil 186 | } 187 | } 188 | 189 | // valuesFromForm returns a function that extracts values from the form field. 190 | func valuesFromForm(name string) middleware.ValuesExtractor { 191 | return func(c echo.Context) ([]string, error) { 192 | if c.Request().Form == nil { 193 | _ = c.Request().ParseMultipartForm(32 << 20) // same what `c.Request().FormValue(name)` does 194 | } 195 | values := c.Request().Form[name] 196 | if len(values) == 0 { 197 | return nil, errFormExtractorValueMissing 198 | } 199 | if len(values) > extractorLimit-1 { 200 | values = values[:extractorLimit] 201 | } 202 | result := append([]string{}, values...) 
203 | return result, nil 204 | } 205 | } 206 | -------------------------------------------------------------------------------- /extractors_test.go: -------------------------------------------------------------------------------- 1 | package echojwt 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | "fmt" 7 | "github.com/labstack/echo/v4" 8 | "github.com/stretchr/testify/assert" 9 | "mime/multipart" 10 | "net/http" 11 | "net/http/httptest" 12 | "net/url" 13 | "strings" 14 | "testing" 15 | ) 16 | 17 | func TestTokenExtractionError_Is(t *testing.T) { 18 | given := echo.ErrUnauthorized.SetInternal(&TokenExtractionError{Err: errCookieExtractorValueMissing}) 19 | 20 | assert.True(t, errors.Is(given, ErrJWTMissing)) 21 | assert.True(t, errors.Is(given, errCookieExtractorValueMissing)) 22 | } 23 | 24 | func TestTokenExtractionError_Error(t *testing.T) { 25 | given := &TokenExtractionError{Err: errCookieExtractorValueMissing} 26 | assert.Equal(t, "missing value in cookies", given.Error()) 27 | } 28 | 29 | func TestTokenExtractionError_Unwrap(t *testing.T) { 30 | given := &TokenExtractionError{Err: errCookieExtractorValueMissing} 31 | assert.Equal(t, errCookieExtractorValueMissing, given.Unwrap()) 32 | } 33 | 34 | type pathParam struct { 35 | name string 36 | value string 37 | } 38 | 39 | func setPathParams(c echo.Context, params []pathParam) { 40 | names := make([]string, 0, len(params)) 41 | values := make([]string, 0, len(params)) 42 | for _, pp := range params { 43 | names = append(names, pp.name) 44 | values = append(values, pp.value) 45 | } 46 | c.SetParamNames(names...) 47 | c.SetParamValues(values...) 
48 | } 49 | 50 | func TestCreateExtractors(t *testing.T) { 51 | var testCases = []struct { 52 | name string 53 | givenRequest func() *http.Request 54 | givenPathParams []pathParam 55 | whenLoopups string 56 | expectValues []string 57 | expectCreateError string 58 | expectError string 59 | }{ 60 | { 61 | name: "ok, header", 62 | givenRequest: func() *http.Request { 63 | req := httptest.NewRequest(http.MethodGet, "/", nil) 64 | req.Header.Set(echo.HeaderAuthorization, "Bearer token") 65 | return req 66 | }, 67 | whenLoopups: "header:Authorization:Bearer ", 68 | expectValues: []string{"token"}, 69 | }, 70 | { 71 | name: "ok, form", 72 | givenRequest: func() *http.Request { 73 | f := make(url.Values) 74 | f.Set("name", "Jon Snow") 75 | 76 | req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode())) 77 | req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm) 78 | return req 79 | }, 80 | whenLoopups: "form:name", 81 | expectValues: []string{"Jon Snow"}, 82 | }, 83 | { 84 | name: "ok, cookie", 85 | givenRequest: func() *http.Request { 86 | req := httptest.NewRequest(http.MethodGet, "/", nil) 87 | req.Header.Set(echo.HeaderCookie, "_csrf=token") 88 | return req 89 | }, 90 | whenLoopups: "cookie:_csrf", 91 | expectValues: []string{"token"}, 92 | }, 93 | { 94 | name: "ok, param", 95 | givenPathParams: []pathParam{ 96 | {name: "id", value: "123"}, 97 | }, 98 | whenLoopups: "param:id", 99 | expectValues: []string{"123"}, 100 | }, 101 | { 102 | name: "ok, query", 103 | givenRequest: func() *http.Request { 104 | req := httptest.NewRequest(http.MethodGet, "/?id=999", nil) 105 | return req 106 | }, 107 | whenLoopups: "query:id", 108 | expectValues: []string{"999"}, 109 | }, 110 | { 111 | name: "nok, invalid lookup", 112 | whenLoopups: "query", 113 | expectCreateError: "extractor source for lookup could not be split into needed parts: query", 114 | }, 115 | } 116 | 117 | for _, tc := range testCases { 118 | t.Run(tc.name, func(t *testing.T) { 119 | e 
:= echo.New() 120 | 121 | req := httptest.NewRequest(http.MethodGet, "/", nil) 122 | if tc.givenRequest != nil { 123 | req = tc.givenRequest() 124 | } 125 | rec := httptest.NewRecorder() 126 | c := e.NewContext(req, rec) 127 | if tc.givenPathParams != nil { 128 | setPathParams(c, tc.givenPathParams) 129 | } 130 | 131 | extractors, err := CreateExtractors(tc.whenLoopups) 132 | if tc.expectCreateError != "" { 133 | assert.EqualError(t, err, tc.expectCreateError) 134 | return 135 | } 136 | assert.NoError(t, err) 137 | 138 | for _, e := range extractors { 139 | values, eErr := e(c) 140 | assert.Equal(t, tc.expectValues, values) 141 | if tc.expectError != "" { 142 | assert.EqualError(t, eErr, tc.expectError) 143 | return 144 | } 145 | assert.NoError(t, eErr) 146 | } 147 | }) 148 | } 149 | } 150 | 151 | func TestValuesFromHeader(t *testing.T) { 152 | exampleRequest := func(req *http.Request) { 153 | req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==") 154 | } 155 | 156 | var testCases = []struct { 157 | name string 158 | givenRequest func(req *http.Request) 159 | whenName string 160 | whenValuePrefix string 161 | expectValues []string 162 | expectError string 163 | }{ 164 | { 165 | name: "ok, single value", 166 | givenRequest: exampleRequest, 167 | whenName: echo.HeaderAuthorization, 168 | whenValuePrefix: "basic ", 169 | expectValues: []string{"dXNlcjpwYXNzd29yZA=="}, 170 | }, 171 | { 172 | name: "ok, single value, case insensitive", 173 | givenRequest: exampleRequest, 174 | whenName: echo.HeaderAuthorization, 175 | whenValuePrefix: "Basic ", 176 | expectValues: []string{"dXNlcjpwYXNzd29yZA=="}, 177 | }, 178 | { 179 | name: "ok, multiple value", 180 | givenRequest: func(req *http.Request) { 181 | req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==") 182 | req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0") 183 | }, 184 | whenName: echo.HeaderAuthorization, 185 | whenValuePrefix: "basic ", 186 | expectValues: 
[]string{"dXNlcjpwYXNzd29yZA==", "dGVzdDp0ZXN0"}, 187 | }, 188 | { 189 | name: "ok, empty prefix", 190 | givenRequest: exampleRequest, 191 | whenName: echo.HeaderAuthorization, 192 | whenValuePrefix: "", 193 | expectValues: []string{"basic dXNlcjpwYXNzd29yZA=="}, 194 | }, 195 | { 196 | name: "nok, no matching due different prefix", 197 | givenRequest: func(req *http.Request) { 198 | req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==") 199 | req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0") 200 | }, 201 | whenName: echo.HeaderAuthorization, 202 | whenValuePrefix: "Bearer ", 203 | expectError: errHeaderExtractorValueInvalid.Error(), 204 | }, 205 | { 206 | name: "nok, no matching due different prefix", 207 | givenRequest: func(req *http.Request) { 208 | req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==") 209 | req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0") 210 | }, 211 | whenName: echo.HeaderWWWAuthenticate, 212 | whenValuePrefix: "", 213 | expectError: errHeaderExtractorValueMissing.Error(), 214 | }, 215 | { 216 | name: "nok, no headers", 217 | givenRequest: nil, 218 | whenName: echo.HeaderAuthorization, 219 | whenValuePrefix: "basic ", 220 | expectError: errHeaderExtractorValueMissing.Error(), 221 | }, 222 | { 223 | name: "ok, prefix, cut values over extractorLimit", 224 | givenRequest: func(req *http.Request) { 225 | for i := 1; i <= 25; i++ { 226 | req.Header.Add(echo.HeaderAuthorization, fmt.Sprintf("basic %v", i)) 227 | } 228 | }, 229 | whenName: echo.HeaderAuthorization, 230 | whenValuePrefix: "basic ", 231 | expectValues: []string{ 232 | "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", 233 | "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", 234 | }, 235 | }, 236 | { 237 | name: "ok, cut values over extractorLimit", 238 | givenRequest: func(req *http.Request) { 239 | for i := 1; i <= 25; i++ { 240 | req.Header.Add(echo.HeaderAuthorization, fmt.Sprintf("%v", i)) 241 | } 242 | }, 243 | 
whenName: echo.HeaderAuthorization, 244 | whenValuePrefix: "", 245 | expectValues: []string{ 246 | "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", 247 | "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", 248 | }, 249 | }, 250 | } 251 | 252 | for _, tc := range testCases { 253 | t.Run(tc.name, func(t *testing.T) { 254 | e := echo.New() 255 | 256 | req := httptest.NewRequest(http.MethodGet, "/", nil) 257 | if tc.givenRequest != nil { 258 | tc.givenRequest(req) 259 | } 260 | rec := httptest.NewRecorder() 261 | c := e.NewContext(req, rec) 262 | 263 | extractor := valuesFromHeader(tc.whenName, tc.whenValuePrefix) 264 | 265 | values, err := extractor(c) 266 | assert.Equal(t, tc.expectValues, values) 267 | if tc.expectError != "" { 268 | assert.EqualError(t, err, tc.expectError) 269 | } else { 270 | assert.NoError(t, err) 271 | } 272 | }) 273 | } 274 | } 275 | 276 | func TestValuesFromQuery(t *testing.T) { 277 | var testCases = []struct { 278 | name string 279 | givenQueryPart string 280 | whenName string 281 | expectValues []string 282 | expectError string 283 | }{ 284 | { 285 | name: "ok, single value", 286 | givenQueryPart: "?id=123&name=test", 287 | whenName: "id", 288 | expectValues: []string{"123"}, 289 | }, 290 | { 291 | name: "ok, multiple value", 292 | givenQueryPart: "?id=123&id=456&name=test", 293 | whenName: "id", 294 | expectValues: []string{"123", "456"}, 295 | }, 296 | { 297 | name: "nok, missing value", 298 | givenQueryPart: "?id=123&name=test", 299 | whenName: "nope", 300 | expectError: errQueryExtractorValueMissing.Error(), 301 | }, 302 | { 303 | name: "ok, cut values over extractorLimit", 304 | givenQueryPart: "?name=test" + 305 | "&id=1&id=2&id=3&id=4&id=5&id=6&id=7&id=8&id=9&id=10" + 306 | "&id=11&id=12&id=13&id=14&id=15&id=16&id=17&id=18&id=19&id=20" + 307 | "&id=21&id=22&id=23&id=24&id=25", 308 | whenName: "id", 309 | expectValues: []string{ 310 | "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", 311 | "11", "12", "13", "14", "15", 
"16", "17", "18", "19", "20", 312 | }, 313 | }, 314 | } 315 | 316 | for _, tc := range testCases { 317 | t.Run(tc.name, func(t *testing.T) { 318 | e := echo.New() 319 | 320 | req := httptest.NewRequest(http.MethodGet, "/"+tc.givenQueryPart, nil) 321 | rec := httptest.NewRecorder() 322 | c := e.NewContext(req, rec) 323 | 324 | extractor := valuesFromQuery(tc.whenName) 325 | 326 | values, err := extractor(c) 327 | assert.Equal(t, tc.expectValues, values) 328 | if tc.expectError != "" { 329 | assert.EqualError(t, err, tc.expectError) 330 | } else { 331 | assert.NoError(t, err) 332 | } 333 | }) 334 | } 335 | } 336 | 337 | func TestValuesFromParam(t *testing.T) { 338 | examplePathParams := []pathParam{ 339 | {name: "id", value: "123"}, 340 | {name: "gid", value: "456"}, 341 | {name: "gid", value: "789"}, 342 | } 343 | examplePathParams20 := make([]pathParam, 0) 344 | for i := 1; i < 25; i++ { 345 | examplePathParams20 = append(examplePathParams20, pathParam{name: "id", value: fmt.Sprintf("%v", i)}) 346 | } 347 | 348 | var testCases = []struct { 349 | name string 350 | givenPathParams []pathParam 351 | whenName string 352 | expectValues []string 353 | expectError string 354 | }{ 355 | { 356 | name: "ok, single value", 357 | givenPathParams: examplePathParams, 358 | whenName: "id", 359 | expectValues: []string{"123"}, 360 | }, 361 | { 362 | name: "ok, multiple value", 363 | givenPathParams: examplePathParams, 364 | whenName: "gid", 365 | expectValues: []string{"456", "789"}, 366 | }, 367 | { 368 | name: "nok, no values", 369 | givenPathParams: nil, 370 | whenName: "nope", 371 | expectValues: nil, 372 | expectError: errParamExtractorValueMissing.Error(), 373 | }, 374 | { 375 | name: "nok, no matching value", 376 | givenPathParams: examplePathParams, 377 | whenName: "nope", 378 | expectValues: nil, 379 | expectError: errParamExtractorValueMissing.Error(), 380 | }, 381 | { 382 | name: "ok, cut values over extractorLimit", 383 | givenPathParams: examplePathParams20, 384 | 
whenName: "id", 385 | expectValues: []string{ 386 | "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", 387 | "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", 388 | }, 389 | }, 390 | } 391 | 392 | for _, tc := range testCases { 393 | t.Run(tc.name, func(t *testing.T) { 394 | e := echo.New() 395 | 396 | req := httptest.NewRequest(http.MethodGet, "/", nil) 397 | rec := httptest.NewRecorder() 398 | c := e.NewContext(req, rec) 399 | if tc.givenPathParams != nil { 400 | setPathParams(c, tc.givenPathParams) 401 | } 402 | 403 | extractor := valuesFromParam(tc.whenName) 404 | 405 | values, err := extractor(c) 406 | assert.Equal(t, tc.expectValues, values) 407 | if tc.expectError != "" { 408 | assert.EqualError(t, err, tc.expectError) 409 | } else { 410 | assert.NoError(t, err) 411 | } 412 | }) 413 | } 414 | } 415 | 416 | func TestValuesFromCookie(t *testing.T) { 417 | exampleRequest := func(req *http.Request) { 418 | req.Header.Set(echo.HeaderCookie, "_csrf=token") 419 | } 420 | 421 | var testCases = []struct { 422 | name string 423 | givenRequest func(req *http.Request) 424 | whenName string 425 | expectValues []string 426 | expectError string 427 | }{ 428 | { 429 | name: "ok, single value", 430 | givenRequest: exampleRequest, 431 | whenName: "_csrf", 432 | expectValues: []string{"token"}, 433 | }, 434 | { 435 | name: "ok, multiple value", 436 | givenRequest: func(req *http.Request) { 437 | req.Header.Add(echo.HeaderCookie, "_csrf=token") 438 | req.Header.Add(echo.HeaderCookie, "_csrf=token2") 439 | }, 440 | whenName: "_csrf", 441 | expectValues: []string{"token", "token2"}, 442 | }, 443 | { 444 | name: "nok, no matching cookie", 445 | givenRequest: exampleRequest, 446 | whenName: "xxx", 447 | expectValues: nil, 448 | expectError: errCookieExtractorValueMissing.Error(), 449 | }, 450 | { 451 | name: "nok, no cookies at all", 452 | givenRequest: nil, 453 | whenName: "xxx", 454 | expectValues: nil, 455 | expectError: errCookieExtractorValueMissing.Error(), 456 | }, 
457 | { 458 | name: "ok, cut values over extractorLimit", 459 | givenRequest: func(req *http.Request) { 460 | for i := 1; i < 25; i++ { 461 | req.Header.Add(echo.HeaderCookie, fmt.Sprintf("_csrf=%v", i)) 462 | } 463 | }, 464 | whenName: "_csrf", 465 | expectValues: []string{ 466 | "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", 467 | "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", 468 | }, 469 | }, 470 | } 471 | 472 | for _, tc := range testCases { 473 | t.Run(tc.name, func(t *testing.T) { 474 | e := echo.New() 475 | 476 | req := httptest.NewRequest(http.MethodGet, "/", nil) 477 | if tc.givenRequest != nil { 478 | tc.givenRequest(req) 479 | } 480 | rec := httptest.NewRecorder() 481 | c := e.NewContext(req, rec) 482 | 483 | extractor := valuesFromCookie(tc.whenName) 484 | 485 | values, err := extractor(c) 486 | assert.Equal(t, tc.expectValues, values) 487 | if tc.expectError != "" { 488 | assert.EqualError(t, err, tc.expectError) 489 | } else { 490 | assert.NoError(t, err) 491 | } 492 | }) 493 | } 494 | } 495 | 496 | func TestValuesFromForm(t *testing.T) { 497 | examplePostFormRequest := func(mod func(v *url.Values)) *http.Request { 498 | f := make(url.Values) 499 | f.Set("name", "Jon Snow") 500 | f.Set("emails[]", "jon@labstack.com") 501 | if mod != nil { 502 | mod(&f) 503 | } 504 | 505 | req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode())) 506 | req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm) 507 | 508 | return req 509 | } 510 | exampleGetFormRequest := func(mod func(v *url.Values)) *http.Request { 511 | f := make(url.Values) 512 | f.Set("name", "Jon Snow") 513 | f.Set("emails[]", "jon@labstack.com") 514 | if mod != nil { 515 | mod(&f) 516 | } 517 | 518 | req := httptest.NewRequest(http.MethodGet, "/?"+f.Encode(), nil) 519 | return req 520 | } 521 | 522 | exampleMultiPartFormRequest := func(mod func(w *multipart.Writer)) *http.Request { 523 | var b bytes.Buffer 524 | w := multipart.NewWriter(&b) 525 | 
w.WriteField("name", "Jon Snow") 526 | w.WriteField("emails[]", "jon@labstack.com") 527 | if mod != nil { 528 | mod(w) 529 | } 530 | 531 | fw, _ := w.CreateFormFile("upload", "my.file") 532 | fw.Write([]byte(`
hi
`)) 533 | w.Close() 534 | 535 | req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(b.String())) 536 | req.Header.Add(echo.HeaderContentType, w.FormDataContentType()) 537 | 538 | return req 539 | } 540 | 541 | var testCases = []struct { 542 | name string 543 | givenRequest *http.Request 544 | whenName string 545 | expectValues []string 546 | expectError string 547 | }{ 548 | { 549 | name: "ok, POST form, single value", 550 | givenRequest: examplePostFormRequest(nil), 551 | whenName: "emails[]", 552 | expectValues: []string{"jon@labstack.com"}, 553 | }, 554 | { 555 | name: "ok, POST form, multiple value", 556 | givenRequest: examplePostFormRequest(func(v *url.Values) { 557 | v.Add("emails[]", "snow@labstack.com") 558 | }), 559 | whenName: "emails[]", 560 | expectValues: []string{"jon@labstack.com", "snow@labstack.com"}, 561 | }, 562 | { 563 | name: "ok, POST multipart/form, multiple value", 564 | givenRequest: exampleMultiPartFormRequest(func(w *multipart.Writer) { 565 | w.WriteField("emails[]", "snow@labstack.com") 566 | }), 567 | whenName: "emails[]", 568 | expectValues: []string{"jon@labstack.com", "snow@labstack.com"}, 569 | }, 570 | { 571 | name: "ok, GET form, single value", 572 | givenRequest: exampleGetFormRequest(nil), 573 | whenName: "emails[]", 574 | expectValues: []string{"jon@labstack.com"}, 575 | }, 576 | { 577 | name: "ok, GET form, multiple value", 578 | givenRequest: examplePostFormRequest(func(v *url.Values) { 579 | v.Add("emails[]", "snow@labstack.com") 580 | }), 581 | whenName: "emails[]", 582 | expectValues: []string{"jon@labstack.com", "snow@labstack.com"}, 583 | }, 584 | { 585 | name: "nok, POST form, value missing", 586 | givenRequest: examplePostFormRequest(nil), 587 | whenName: "nope", 588 | expectError: errFormExtractorValueMissing.Error(), 589 | }, 590 | { 591 | name: "ok, cut values over extractorLimit", 592 | givenRequest: examplePostFormRequest(func(v *url.Values) { 593 | for i := 1; i < 25; i++ { 594 | v.Add("id[]", 
fmt.Sprintf("%v", i)) 595 | } 596 | }), 597 | whenName: "id[]", 598 | expectValues: []string{ 599 | "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", 600 | "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", 601 | }, 602 | }, 603 | } 604 | 605 | for _, tc := range testCases { 606 | t.Run(tc.name, func(t *testing.T) { 607 | e := echo.New() 608 | 609 | req := tc.givenRequest 610 | rec := httptest.NewRecorder() 611 | c := e.NewContext(req, rec) 612 | 613 | extractor := valuesFromForm(tc.whenName) 614 | 615 | values, err := extractor(c) 616 | assert.Equal(t, tc.expectValues, values) 617 | if tc.expectError != "" { 618 | assert.EqualError(t, err, tc.expectError) 619 | } else { 620 | assert.NoError(t, err) 621 | } 622 | }) 623 | } 624 | } 625 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/labstack/echo-jwt/v4 2 | 3 | go 1.23.0 4 | 5 | require ( 6 | github.com/golang-jwt/jwt/v5 v5.2.2 7 | github.com/labstack/echo/v4 v4.13.3 8 | github.com/stretchr/testify v1.10.0 9 | ) 10 | 11 | require ( 12 | github.com/davecgh/go-spew v1.1.1 // indirect 13 | github.com/labstack/gommon v0.4.2 // indirect 14 | github.com/mattn/go-colorable v0.1.14 // indirect 15 | github.com/mattn/go-isatty v0.0.20 // indirect 16 | github.com/pmezard/go-difflib v1.0.0 // indirect 17 | github.com/valyala/bytebufferpool v1.0.0 // indirect 18 | github.com/valyala/fasttemplate v1.2.2 // indirect 19 | golang.org/x/crypto v0.36.0 // indirect 20 | golang.org/x/net v0.37.0 // indirect 21 | golang.org/x/sys v0.31.0 // indirect 22 | golang.org/x/text v0.23.0 // indirect 23 | golang.org/x/time v0.11.0 // indirect 24 | gopkg.in/yaml.v3 v3.0.1 // indirect 25 | ) 26 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | 
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 2 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= 4 | github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= 5 | github.com/labstack/echo/v4 v4.13.3 h1:pwhpCPrTl5qry5HRdM5FwdXnhXSLSY+WE+YQSeCaafY= 6 | github.com/labstack/echo/v4 v4.13.3/go.mod h1:o90YNEeQWjDozo584l7AwhJMHN0bOC4tAfg+Xox9q5g= 7 | github.com/labstack/gommon v0.4.2 h1:F8qTUNXgG1+6WQmqoUWnz8WiEU60mXVVw0P4ht1WRA0= 8 | github.com/labstack/gommon v0.4.2/go.mod h1:QlUFxVM+SNXhDL/Z7YhocGIBYOiwB0mXm1+1bAPHPyU= 9 | github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= 10 | github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= 11 | github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= 12 | github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= 13 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 14 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 15 | github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= 16 | github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 17 | github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= 18 | github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= 19 | github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo= 20 | github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= 21 | golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= 22 | golang.org/x/crypto v0.36.0/go.mod 
h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= 23 | golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c= 24 | golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= 25 | golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 26 | golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= 27 | golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= 28 | golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= 29 | golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= 30 | golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0= 31 | golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= 32 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 33 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 34 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 35 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 36 | -------------------------------------------------------------------------------- /jwt.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: © 2016 LabStack and Echo contributors 3 | 4 | package echojwt 5 | 6 | import ( 7 | "errors" 8 | "fmt" 9 | "net/http" 10 | 11 | "github.com/golang-jwt/jwt/v5" 12 | "github.com/labstack/echo/v4" 13 | "github.com/labstack/echo/v4/middleware" 14 | ) 15 | 16 | // Config defines the config for JWT middleware. 17 | type Config struct { 18 | // Skipper defines a function to skip middleware. 19 | Skipper middleware.Skipper 20 | 21 | // BeforeFunc defines a function which is executed just before the middleware. 
22 | BeforeFunc middleware.BeforeFunc 23 | 24 | // SuccessHandler defines a function which is executed for a valid token. 25 | SuccessHandler func(c echo.Context) 26 | 27 | // ErrorHandler defines a function which is executed when all lookups have been done and none of them produced a token accepted 28 | // by ParseTokenFunc. It receives *TokenParsingError (wrapping the last parsing error) if any extracted value failed to parse, otherwise *TokenExtractionError (wrapping the last extractor error). 29 | // It may be used to define a custom JWT error. 30 | // 31 | // Note: when error handler swallows the error (returns nil) middleware continues handler chain execution towards handler ONLY if ContinueOnIgnoredError is true; 32 | // otherwise the middleware returns nil without calling the next handler. This is useful in cases when portion of your site/api is publicly accessible and has extra features for authorized users. 33 | // In that case you can use ErrorHandler (together with ContinueOnIgnoredError) to set default public JWT token value to request and continue with handler chain. 34 | ErrorHandler func(c echo.Context, err error) error 35 | 36 | // ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandler decides to 37 | // ignore the error (by returning `nil`). 38 | // This is useful when parts of your site/api allow public access and some authorized routes provide extra functionality. 39 | // In that case you can use ErrorHandler to set a default public JWT token value in the request context 40 | // and continue. Some logic down the remaining execution chain needs to check that (public) token value then. 41 | ContinueOnIgnoredError bool 42 | 43 | // Context key to store user information from the token into context. 44 | // Optional. Default value "user". 45 | ContextKey string 46 | 47 | // Signing key to validate token. 48 | // This is one of the three options to provide a token validation key. 49 | // The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey. 50 | // Required if neither user-defined KeyFunc nor SigningKeys is provided. 51 | SigningKey interface{} 52 | 53 | // Map of signing keys to validate token with kid field usage.
54 | // This is one of the three options to provide a token validation key. 55 | // The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey. 56 | // Required if neither user-defined KeyFunc nor SigningKey is provided. 57 | SigningKeys map[string]interface{} 58 | 59 | // Signing method used to check the token's signing algorithm. 60 | // Optional. Default value HS256. 61 | SigningMethod string 62 | 63 | // KeyFunc defines a user-defined function that supplies the public key for a token validation. 64 | // The function shall take care of verifying the signing algorithm and selecting the proper key. 65 | // A user-defined KeyFunc can be useful if tokens are issued by an external party. 66 | // Used by default ParseTokenFunc implementation. 67 | // 68 | // When a user-defined KeyFunc is provided, SigningKey, SigningKeys, and SigningMethod are ignored. 69 | // This is one of the three options to provide a token validation key. 70 | // The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey. 71 | // Required if neither SigningKeys nor SigningKey is provided. 72 | // Not used if custom ParseTokenFunc is set. 73 | // Default to an internal implementation verifying the signing algorithm and selecting the proper key. 74 | KeyFunc jwt.Keyfunc 75 | 76 | // TokenLookup is a string in the form of ":" or ":,:" that is used 77 | // to extract token from the request. 78 | // Optional. Default value "header:Authorization". 79 | // Possible values: 80 | // - "header:" or "header::" 81 | // `` is argument value to cut/trim prefix of the extracted value. This is useful if header 82 | // value has static prefix like `Authorization: ` where part that we 83 | // want to cut is ` ` note the space at the end. 84 | // In case of JWT tokens `Authorization: Bearer ` prefix we cut is `Bearer `. 85 | // If prefix is left empty the whole value is returned. 
86 | // - "query:" 87 | // - "param:" 88 | // - "cookie:" 89 | // - "form:" 90 | // Multiple sources example: 91 | // - "header:Authorization:Bearer ,cookie:myowncookie" 92 | TokenLookup string 93 | 94 | // TokenLookupFuncs defines a list of user-defined functions that extract JWT token from the given context. 95 | // This is one of the two options to provide a token extractor. 96 | // The order of precedence is user-defined TokenLookupFuncs, and TokenLookup. 97 | // You can also provide both if you want. 98 | TokenLookupFuncs []middleware.ValuesExtractor 99 | 100 | // ParseTokenFunc defines a user-defined function that parses token from given auth. Returns an error when token 101 | // parsing fails or parsed token is invalid. 102 | // Defaults to implementation using `github.com/golang-jwt/jwt` as JWT implementation library 103 | ParseTokenFunc func(c echo.Context, auth string) (interface{}, error) 104 | 105 | // Claims are extendable claims data defining token content. Used by default ParseTokenFunc implementation. 106 | // Not used if custom ParseTokenFunc is set. 107 | // Optional. Defaults to function returning jwt.MapClaims 108 | NewClaimsFunc func(c echo.Context) jwt.Claims 109 | } 110 | 111 | const ( 112 | // AlgorithmHS256 is token signing algorithm 113 | AlgorithmHS256 = "HS256" 114 | ) 115 | 116 | // ErrJWTMissing denotes an error raised when JWT token value could not be extracted from request 117 | var ErrJWTMissing = echo.NewHTTPError(http.StatusUnauthorized, "missing or malformed jwt") 118 | 119 | // ErrJWTInvalid denotes an error raised when JWT token value is invalid or expired 120 | var ErrJWTInvalid = echo.NewHTTPError(http.StatusUnauthorized, "invalid or expired jwt") 121 | 122 | // TokenParsingError is catch all type for all errors that occur when token is parsed. In case of library default 123 | // token parsing functions are being used this error instance wraps TokenError. 
This helps to distinguish extractor 124 | // errors from token parsing errors even if custom extractors or token parsing functions are being used that have 125 | // their own custom errors. 126 | type TokenParsingError struct { 127 | Err error 128 | } 129 | 130 | // Is checks if target error is same as TokenParsingError 131 | func (e TokenParsingError) Is(target error) bool { return target == ErrJWTInvalid } // to provide some compatibility with older error handling logic 132 | 133 | func (e *TokenParsingError) Error() string { return e.Err.Error() } 134 | func (e *TokenParsingError) Unwrap() error { return e.Err } 135 | 136 | // TokenError is used to return error with error occurred JWT token when processing JWT token 137 | type TokenError struct { 138 | Token *jwt.Token 139 | Err error 140 | } 141 | 142 | func (e *TokenError) Error() string { return e.Err.Error() } 143 | 144 | func (e *TokenError) Unwrap() error { return e.Err } 145 | 146 | // JWT returns a JSON Web Token (JWT) auth middleware. 147 | // 148 | // For valid token, it sets the user in context and calls next handler. 149 | // For invalid token, it returns "401 - Unauthorized" error. 150 | // For missing token, it returns "400 - Bad Request" error. 151 | // 152 | // See: https://jwt.io/introduction 153 | func JWT(signingKey interface{}) echo.MiddlewareFunc { 154 | return WithConfig(Config{SigningKey: signingKey}) 155 | } 156 | 157 | // WithConfig returns a JSON Web Token (JWT) auth middleware or panics if configuration is invalid. 158 | // 159 | // For valid token, it sets the user in context and calls next handler. 160 | // For invalid token, it returns "401 - Unauthorized" error. 161 | // For missing token, it returns "400 - Bad Request" error. 
162 | // 163 | // See: https://jwt.io/introduction 164 | func WithConfig(config Config) echo.MiddlewareFunc { 165 | mw, err := config.ToMiddleware() 166 | if err != nil { 167 | panic(err) 168 | } 169 | return mw 170 | } 171 | 172 | // ToMiddleware converts Config to middleware or returns an error for invalid configuration. 173 | func (config Config) ToMiddleware() (echo.MiddlewareFunc, error) { 174 | if config.Skipper == nil { 175 | config.Skipper = middleware.DefaultSkipper 176 | } 177 | if config.ContextKey == "" { 178 | config.ContextKey = "user" 179 | } 180 | if config.TokenLookup == "" && len(config.TokenLookupFuncs) == 0 { 181 | config.TokenLookup = "header:Authorization:Bearer " 182 | } 183 | if config.SigningMethod == "" { 184 | config.SigningMethod = AlgorithmHS256 185 | } 186 | 187 | if config.NewClaimsFunc == nil { 188 | config.NewClaimsFunc = func(c echo.Context) jwt.Claims { 189 | return jwt.MapClaims{} 190 | } 191 | } 192 | if config.SigningKey == nil && len(config.SigningKeys) == 0 && config.KeyFunc == nil && config.ParseTokenFunc == nil { 193 | return nil, errors.New("jwt middleware requires signing key") 194 | } 195 | if config.KeyFunc == nil { 196 | config.KeyFunc = config.defaultKeyFunc 197 | } 198 | if config.ParseTokenFunc == nil { 199 | config.ParseTokenFunc = config.defaultParseTokenFunc 200 | } 201 | extractors, ceErr := CreateExtractors(config.TokenLookup) 202 | if ceErr != nil { 203 | return nil, ceErr 204 | } 205 | if len(config.TokenLookupFuncs) > 0 { 206 | extractors = append(append(make([]middleware.ValuesExtractor, 0, len(config.TokenLookupFuncs)+len(extractors)), config.TokenLookupFuncs...), extractors...) // fresh slice: never write into spare capacity of the caller's TokenLookupFuncs
207 | } 208 | 209 | return func(next echo.HandlerFunc) echo.HandlerFunc { 210 | return func(c echo.Context) error { 211 | if config.Skipper(c) { 212 | return next(c) 213 | } 214 | 215 | if config.BeforeFunc != nil { 216 | config.BeforeFunc(c) 217 | } 218 | var lastExtractorErr error 219 | var lastTokenErr error 220 | for _, extractor := range extractors { 221 | auths, extrErr := extractor(c) 222 | if extrErr != nil { 223 | lastExtractorErr = extrErr 224 | continue 225 | } 226 | for _, auth := range auths { 227 | token, err := config.ParseTokenFunc(c, auth) 228 | if err != nil { 229 | lastTokenErr = err 230 | continue 231 | } 232 | // Store user information from token into context. 233 | c.Set(config.ContextKey, token) 234 | if config.SuccessHandler != nil { 235 | config.SuccessHandler(c) 236 | } 237 | return next(c) 238 | } 239 | } 240 | 241 | // prioritize token errors over extracting errors as parsing is occurs further in process, meaning we managed to 242 | // extract at least one token and failed to parse it 243 | var err error 244 | if lastTokenErr != nil { 245 | err = &TokenParsingError{Err: lastTokenErr} 246 | } else if lastExtractorErr != nil { 247 | err = &TokenExtractionError{Err: lastExtractorErr} 248 | } 249 | if config.ErrorHandler != nil { 250 | tmpErr := config.ErrorHandler(c, err) 251 | if config.ContinueOnIgnoredError && tmpErr == nil { 252 | return next(c) 253 | } 254 | return tmpErr 255 | } 256 | 257 | if lastTokenErr == nil { 258 | return echo.NewHTTPError(http.StatusBadRequest, "missing or malformed jwt").SetInternal(err) 259 | } 260 | 261 | return echo.NewHTTPError(http.StatusUnauthorized, "invalid or expired jwt").SetInternal(err) 262 | } 263 | }, nil 264 | } 265 | 266 | // defaultKeyFunc creates JWTGo implementation for KeyFunc. 267 | // 268 | // error returns TokenError. 
269 | func (config Config) defaultKeyFunc(token *jwt.Token) (interface{}, error) { 270 | if token.Method.Alg() != config.SigningMethod { 271 | return nil, &TokenError{Token: token, Err: fmt.Errorf("unexpected jwt signing method=%v", token.Header["alg"])} 272 | } 273 | if len(config.SigningKeys) == 0 { 274 | return config.SigningKey, nil 275 | } 276 | 277 | if kid, ok := token.Header["kid"].(string); ok { 278 | if key, ok := config.SigningKeys[kid]; ok { 279 | return key, nil 280 | } 281 | } 282 | return nil, &TokenError{Token: token, Err: fmt.Errorf("unexpected jwt key id=%v", token.Header["kid"])} 283 | } 284 | 285 | // defaultParseTokenFunc creates JWTGo implementation for ParseTokenFunc. 286 | // 287 | // error returns TokenError. 288 | func (config Config) defaultParseTokenFunc(c echo.Context, auth string) (interface{}, error) { 289 | token, err := jwt.ParseWithClaims(auth, config.NewClaimsFunc(c), config.KeyFunc) 290 | if err != nil { 291 | return nil, &TokenError{Token: token, Err: err} 292 | } 293 | if !token.Valid { 294 | return nil, &TokenError{Token: token, Err: errors.New("invalid token")} 295 | } 296 | return token, nil 297 | } 298 | -------------------------------------------------------------------------------- /jwt_benchmark_test.go: -------------------------------------------------------------------------------- 1 | package echojwt 2 | 3 | import ( 4 | "github.com/golang-jwt/jwt/v5" 5 | "github.com/labstack/echo/v4" 6 | "net/http" 7 | "net/http/httptest" 8 | "testing" 9 | ) 10 | 11 | func BenchmarkJWTSuccessPath(b *testing.B) { 12 | e := echo.New() 13 | 14 | e.GET("/", func(c echo.Context) error { 15 | token := c.Get("user").(*jwt.Token) 16 | return c.JSON(http.StatusTeapot, token.Claims) 17 | }) 18 | 19 | b.ReportAllocs() 20 | mw, err := Config{SigningKey: []byte("secret")}.ToMiddleware() 21 | if err != nil { 22 | b.Fatal(err) 23 | } 24 | e.Use(mw) 25 | 26 | b.ResetTimer() 27 | for i := 0; i < b.N; i++ { 28 | req := 
httptest.NewRequest(http.MethodGet, "/", nil) 29 | req.Header.Set(echo.HeaderAuthorization, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") 30 | res := httptest.NewRecorder() 31 | 32 | e.ServeHTTP(res, req) 33 | 34 | if res.Code != http.StatusTeapot { // valid token reaches the handler, which responds 418 35 | b.Fatalf("unexpected status code: %d", res.Code) 36 | } 37 | } 38 | } 39 | 40 | func BenchmarkJWTErrorPath(b *testing.B) { 41 | e := echo.New() 42 | 43 | e.GET("/", func(c echo.Context) error { 44 | token := c.Get("user").(*jwt.Token) 45 | return c.JSON(http.StatusTeapot, token.Claims) 46 | }) 47 | 48 | b.ReportAllocs() 49 | mw, err := Config{SigningKey: []byte("secret")}.ToMiddleware() 50 | if err != nil { 51 | b.Fatal(err) 52 | } 53 | e.Use(mw) 54 | 55 | b.ResetTimer() 56 | for i := 0; i < b.N; i++ { 57 | req := httptest.NewRequest(http.MethodGet, "/", nil) 58 | req.Header.Set(echo.HeaderAuthorization, "Bearer x.x.x") 59 | res := httptest.NewRecorder() 60 | 61 | e.ServeHTTP(res, req) 62 | 63 | if res.Code != http.StatusUnauthorized { 64 | b.Fatalf("unexpected status code: %d", res.Code) 65 | } 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /jwt_external_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: © 2016 LabStack and Echo contributors 3 | 4 | package echojwt_test 5 | 6 | import ( 7 | "errors" 8 | "fmt" 9 | "github.com/golang-jwt/jwt/v5" 10 | echojwt "github.com/labstack/echo-jwt/v4" 11 | "github.com/labstack/echo/v4" 12 | "io" 13 | "log" 14 | "net" 15 | "net/http" 16 | "time" 17 | ) 18 | 19 | func ExampleWithConfig_usage() { 20 | e := echo.New() 21 | 22 | e.Use(echojwt.WithConfig(echojwt.Config{ 23 | SigningKey: []byte("secret"), 24 | })) 25 | 26 | e.GET("/", func(c echo.Context) error { 27 | // make sure that your imports are correct versions.
for example if you use `"github.com/golang-jwt/jwt"` as 28 | // import this cast will fail and `"github.com/golang-jwt/jwt/v5"` will succeed. 29 | // Although `.(*jwt.Token)` looks exactly the same for both packages but this struct is still different 30 | token, ok := c.Get("user").(*jwt.Token) // by default token is stored under `user` key 31 | if !ok { 32 | return errors.New("JWT token missing or invalid") 33 | } 34 | claims, ok := token.Claims.(jwt.MapClaims) // by default claims is of type `jwt.MapClaims` 35 | if !ok { 36 | return errors.New("failed to cast claims as jwt.MapClaims") 37 | } 38 | return c.JSON(http.StatusOK, claims) 39 | }) 40 | 41 | // ----------------------- start server on random port ----------------------- 42 | l, err := net.Listen("tcp", ":0") 43 | if err != nil { 44 | log.Fatal(err) 45 | } 46 | go func(e *echo.Echo, l net.Listener) { 47 | s := http.Server{Handler: e} 48 | if err := s.Serve(l); err != http.ErrServerClosed { 49 | log.Fatal(err) 50 | } 51 | }(e, l) 52 | time.Sleep(100 * time.Millisecond) 53 | 54 | // ----------------------- execute HTTP request with valid token and check the response ----------------------- 55 | requestURL := fmt.Sprintf("http://%v", l.Addr().String()) 56 | req, err := http.NewRequest(http.MethodGet, requestURL, nil) 57 | if err != nil { 58 | log.Fatal(err) 59 | } 60 | req.Header.Set(echo.HeaderAuthorization, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") 61 | 62 | res, err := http.DefaultClient.Do(req) 63 | if err != nil { 64 | log.Fatal(err) 65 | } 66 | 67 | body, err := io.ReadAll(res.Body) 68 | if err != nil { 69 | log.Fatal(err) 70 | } 71 | 72 | fmt.Printf("Response: status code: %d, body: %s\n", res.StatusCode, body) 73 | 74 | // Output: Response: status code: 200, body: {"admin":true,"name":"John Doe","sub":"1234567890"} 75 | } 76 | 
-------------------------------------------------------------------------------- /jwt_integration_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: © 2016 LabStack and Echo contributors 3 | 4 | package echojwt_test 5 | 6 | import ( 7 | "errors" 8 | "github.com/golang-jwt/jwt/v5" 9 | echojwt "github.com/labstack/echo-jwt/v4" 10 | "github.com/labstack/echo/v4" 11 | "net/http" 12 | "net/http/httptest" 13 | "testing" 14 | ) 15 | 16 | func TestIntegrationMiddlewareWithHandler(t *testing.T) { 17 | e := echo.New() 18 | e.Use(echojwt.WithConfig(echojwt.Config{ 19 | SigningKey: []byte("secret"), 20 | })) 21 | 22 | e.GET("/example", exampleHandler) 23 | 24 | req := httptest.NewRequest(http.MethodGet, "/example", nil) 25 | req.Header.Set(echo.HeaderAuthorization, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") 26 | res := httptest.NewRecorder() 27 | 28 | e.ServeHTTP(res, req) 29 | 30 | if res.Code != 200 { 31 | t.Failed() 32 | } 33 | } 34 | 35 | func exampleHandler(c echo.Context) error { 36 | // make sure that your imports are correct versions. for example if you use `"github.com/golang-jwt/jwt"` as 37 | // import this cast will fail and `"github.com/golang-jwt/jwt/v5"` will succeed. 
38 | // Although `.(*jwt.Token)` looks exactly the same for both packages but this struct is still different 39 | token, ok := c.Get("user").(*jwt.Token) 40 | if !ok { 41 | return errors.New("JWT token missing or invalid") 42 | } 43 | 44 | claims, ok := token.Claims.(jwt.MapClaims) // by default claims is of type `jwt.MapClaims` 45 | if !ok { 46 | return errors.New("failed to cast claims as jwt.MapClaims") 47 | } 48 | return c.JSON(http.StatusOK, claims) 49 | } 50 | -------------------------------------------------------------------------------- /jwt_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: © 2016 LabStack and Echo contributors 3 | 4 | package echojwt 5 | 6 | import ( 7 | "errors" 8 | "fmt" 9 | "net/http" 10 | "net/http/httptest" 11 | "net/url" 12 | "strings" 13 | "testing" 14 | 15 | "github.com/golang-jwt/jwt/v5" 16 | "github.com/labstack/echo/v4" 17 | "github.com/labstack/echo/v4/middleware" 18 | "github.com/stretchr/testify/assert" 19 | ) 20 | 21 | func TestTokenParsingError_Is(t *testing.T) { 22 | err := errors.New("parsing error") 23 | given := echo.ErrUnauthorized.SetInternal(&TokenParsingError{Err: err}) 24 | 25 | assert.True(t, errors.Is(given, ErrJWTInvalid)) 26 | assert.True(t, errors.Is(given, err)) 27 | } 28 | 29 | func TestTokenParsingError_Error(t *testing.T) { 30 | given := &TokenParsingError{Err: errors.New("parsing error")} 31 | assert.Equal(t, "parsing error", given.Error()) 32 | } 33 | 34 | func TestTokenParsingError_Unwrap(t *testing.T) { 35 | inner := errors.New("parsing error") 36 | given := &TokenParsingError{Err: inner} 37 | assert.Equal(t, inner, given.Unwrap()) 38 | } 39 | 40 | // jwtCustomInfo defines some custom types we're going to use within our tokens. 
41 | type jwtCustomInfo struct { 42 | Name string `json:"name"` 43 | Admin bool `json:"admin"` 44 | } 45 | 46 | // jwtCustomClaims are custom claims expanding default ones. 47 | type jwtCustomClaims struct { 48 | jwt.RegisteredClaims 49 | jwtCustomInfo 50 | } 51 | 52 | func TestJWT(t *testing.T) { 53 | e := echo.New() 54 | 55 | e.GET("/", func(c echo.Context) error { 56 | token := c.Get("user").(*jwt.Token) 57 | return c.JSON(http.StatusOK, token.Claims) 58 | }) 59 | 60 | e.Use(JWT([]byte("secret"))) 61 | 62 | req := httptest.NewRequest(http.MethodGet, "/", nil) 63 | req.Header.Set(echo.HeaderAuthorization, "bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") 64 | res := httptest.NewRecorder() 65 | 66 | e.ServeHTTP(res, req) 67 | 68 | assert.Equal(t, http.StatusOK, res.Code) 69 | assert.Equal(t, `{"admin":true,"name":"John Doe","sub":"1234567890"}`+"\n", res.Body.String()) 70 | } 71 | 72 | func TestJWT_combinations(t *testing.T) { 73 | e := echo.New() 74 | handler := func(c echo.Context) error { 75 | return c.String(http.StatusOK, "test") 76 | } 77 | token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ" 78 | validKey := []byte("secret") 79 | invalidKey := []byte("invalid-key") 80 | validAuth := "Bearer " + token 81 | 82 | var testCases = []struct { 83 | name string 84 | config Config 85 | reqURL string // "/" if empty 86 | hdrAuth string 87 | hdrCookie string // test.Request doesn't provide SetCookie(); use name=val 88 | formValues map[string]string 89 | expectPanic bool 90 | expectToMiddlewareError string 91 | expectError string 92 | }{ 93 | { 94 | name: "No signing key provided", 95 | expectToMiddlewareError: "jwt middleware requires signing key", 96 | }, 97 | { 98 | name: "invalid TokenLookup", 99 | config: Config{ 100 | SigningKey: validKey, 101 | 
SigningMethod: "RS256", 102 | TokenLookup: "q", 103 | }, 104 | expectToMiddlewareError: "extractor source for lookup could not be split into needed parts: q", 105 | }, 106 | { 107 | name: "Unexpected signing method", 108 | hdrAuth: validAuth, 109 | config: Config{ 110 | SigningKey: validKey, 111 | SigningMethod: "RS256", 112 | }, 113 | expectError: "code=401, message=invalid or expired jwt, internal=token is unverifiable: error while executing keyfunc: unexpected jwt signing method=HS256", 114 | }, 115 | { 116 | name: "Invalid key", 117 | hdrAuth: validAuth, 118 | config: Config{ 119 | SigningKey: invalidKey, 120 | }, 121 | expectError: "code=401, message=invalid or expired jwt, internal=token signature is invalid: signature is invalid", 122 | }, 123 | { 124 | name: "Valid JWT", 125 | hdrAuth: validAuth, 126 | config: Config{ 127 | SigningKey: validKey, 128 | }, 129 | }, 130 | { 131 | name: "Valid JWT with custom AuthScheme", 132 | hdrAuth: "Token" + " " + token, 133 | config: Config{ 134 | TokenLookup: "header:" + echo.HeaderAuthorization + ":Token ", 135 | SigningKey: validKey, 136 | }, 137 | }, 138 | { 139 | name: "Valid JWT with custom claims", 140 | hdrAuth: validAuth, 141 | config: Config{ 142 | SigningKey: []byte("secret"), 143 | NewClaimsFunc: func(c echo.Context) jwt.Claims { 144 | return &jwtCustomClaims{ // this needs to be pointer to json unmarshalling to work 145 | jwtCustomInfo: jwtCustomInfo{ 146 | Name: "John Doe", 147 | Admin: true, 148 | }, 149 | } 150 | }, 151 | }, 152 | }, 153 | { 154 | name: "Invalid Authorization header", 155 | hdrAuth: "invalid-auth", 156 | config: Config{ 157 | SigningKey: validKey, 158 | }, 159 | expectError: "code=400, message=missing or malformed jwt, internal=invalid value in request header", 160 | }, 161 | { 162 | name: "Empty header auth field", 163 | config: Config{ 164 | SigningKey: validKey, 165 | }, 166 | expectError: "code=400, message=missing or malformed jwt, internal=invalid value in request header", 167 | }, 
168 | { 169 | name: "Valid query method", 170 | config: Config{ 171 | SigningKey: validKey, 172 | TokenLookup: "query:jwt", 173 | }, 174 | reqURL: "/?a=b&jwt=" + token, 175 | }, 176 | { 177 | name: "Invalid query param name", 178 | config: Config{ 179 | SigningKey: validKey, 180 | TokenLookup: "query:jwt", 181 | }, 182 | reqURL: "/?a=b&jwtxyz=" + token, 183 | expectError: "code=400, message=missing or malformed jwt, internal=missing value in the query string", 184 | }, 185 | { 186 | name: "Invalid query param value", 187 | config: Config{ 188 | SigningKey: validKey, 189 | TokenLookup: "query:jwt", 190 | }, 191 | reqURL: "/?a=b&jwt=invalid-token", 192 | expectError: "code=401, message=invalid or expired jwt, internal=token is malformed: token contains an invalid number of segments", 193 | }, 194 | { 195 | name: "Empty query", 196 | config: Config{ 197 | SigningKey: validKey, 198 | TokenLookup: "query:jwt", 199 | }, 200 | reqURL: "/?a=b", 201 | expectError: "code=400, message=missing or malformed jwt, internal=missing value in the query string", 202 | }, 203 | { 204 | config: Config{ 205 | SigningKey: validKey, 206 | TokenLookup: "param:jwt", 207 | }, 208 | reqURL: "/" + token, 209 | name: "Valid param method", 210 | }, 211 | { 212 | config: Config{ 213 | SigningKey: validKey, 214 | TokenLookup: "cookie:jwt", 215 | }, 216 | hdrCookie: "jwt=" + token, 217 | name: "Valid cookie method", 218 | }, 219 | { 220 | config: Config{ 221 | SigningKey: validKey, 222 | TokenLookup: "query:jwt,cookie:jwt", 223 | }, 224 | hdrCookie: "jwt=" + token, 225 | name: "Multiple jwt lookuop", 226 | }, 227 | { 228 | name: "Invalid token with cookie method", 229 | config: Config{ 230 | SigningKey: validKey, 231 | TokenLookup: "cookie:jwt", 232 | }, 233 | hdrCookie: "jwt=invalid", 234 | expectError: "code=401, message=invalid or expired jwt, internal=token is malformed: token contains an invalid number of segments", 235 | }, 236 | { 237 | name: "Empty cookie", 238 | config: Config{ 239 | 
SigningKey: validKey, 240 | TokenLookup: "cookie:jwt", 241 | }, 242 | expectError: "code=400, message=missing or malformed jwt, internal=missing value in cookies", 243 | }, 244 | { 245 | name: "Valid form method", 246 | config: Config{ 247 | SigningKey: validKey, 248 | TokenLookup: "form:jwt", 249 | }, 250 | formValues: map[string]string{"jwt": token}, 251 | }, 252 | { 253 | name: "Invalid token with form method", 254 | config: Config{ 255 | SigningKey: validKey, 256 | TokenLookup: "form:jwt", 257 | }, 258 | formValues: map[string]string{"jwt": "invalid"}, 259 | expectError: "code=401, message=invalid or expired jwt, internal=token is malformed: token contains an invalid number of segments", 260 | }, 261 | { 262 | name: "Empty form field", 263 | config: Config{ 264 | SigningKey: validKey, 265 | TokenLookup: "form:jwt", 266 | }, 267 | expectError: "code=400, message=missing or malformed jwt, internal=missing value in the form", 268 | }, 269 | } 270 | 271 | for _, tc := range testCases { 272 | t.Run(tc.name, func(t *testing.T) { 273 | if tc.reqURL == "" { 274 | tc.reqURL = "/" 275 | } 276 | 277 | var req *http.Request 278 | if len(tc.formValues) > 0 { 279 | form := url.Values{} 280 | for k, v := range tc.formValues { 281 | form.Set(k, v) 282 | } 283 | req = httptest.NewRequest(http.MethodPost, tc.reqURL, strings.NewReader(form.Encode())) 284 | req.Header.Set(echo.HeaderContentType, "application/x-www-form-urlencoded") 285 | req.ParseForm() 286 | } else { 287 | req = httptest.NewRequest(http.MethodGet, tc.reqURL, nil) 288 | } 289 | res := httptest.NewRecorder() 290 | req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth) 291 | req.Header.Set(echo.HeaderCookie, tc.hdrCookie) 292 | c := e.NewContext(req, res) 293 | 294 | if tc.reqURL == "/"+token { 295 | c.SetParamNames("jwt") 296 | c.SetParamValues(token) 297 | } 298 | 299 | mw, err := tc.config.ToMiddleware() 300 | if tc.expectToMiddlewareError != "" { 301 | assert.EqualError(t, err, tc.expectToMiddlewareError) 302 | 
return 303 | } 304 | 305 | hErr := mw(handler)(c) 306 | if tc.expectError != "" { 307 | assert.EqualError(t, hErr, tc.expectError) 308 | return 309 | } 310 | if !assert.NoError(t, hErr) { 311 | return 312 | } 313 | 314 | user := c.Get("user").(*jwt.Token) 315 | switch claims := user.Claims.(type) { 316 | case jwt.MapClaims: 317 | assert.Equal(t, claims["name"], "John Doe") 318 | case *jwtCustomClaims: 319 | assert.Equal(t, claims.Name, "John Doe") 320 | assert.Equal(t, claims.Admin, true) 321 | default: 322 | panic("unexpected type of claims") 323 | } 324 | }) 325 | } 326 | } 327 | 328 | func TestJWTwithKID(t *testing.T) { 329 | e := echo.New() 330 | handler := func(c echo.Context) error { 331 | return c.String(http.StatusOK, "test") 332 | } 333 | 334 | firstToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsImtpZCI6ImZpcnN0T25lIn0.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.w5VGpHOe0jlNgf7jMVLHzIYH_XULmpUlreJnilwSkWk" 335 | secondToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsImtpZCI6InNlY29uZE9uZSJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.sdghDYQ85jdh0hgQ6bKbMguLI_NSPYWjkhVJkee-yZM" 336 | wrongToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsImtpZCI6InNlY29uZE9uZSJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.RyhLybtVLpoewF6nz9YN79oXo32kAtgUxp8FNwTkb90" 337 | staticToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.1_-XFYUPpJfgsaGwYhgZEt7hfySMg-a3GN-nfZmbW7o" 338 | validKeys := map[string]interface{}{"firstOne": []byte("first_secret"), "secondOne": []byte("second_secret")} 339 | invalidKeys := map[string]interface{}{"thirdOne": []byte("third_secret")} 340 | staticSecret := []byte("static_secret") 341 | invalidStaticSecret := []byte("invalid_secret") 342 | 343 | var testCases = []struct { 344 | expErrCode int // 0 for Success 345 | config Config 346 | hdrAuth string 347 | name string 348 | }{ 349 | { 350 | name: "First token 
valid", 351 | hdrAuth: "Bearer " + firstToken, 352 | config: Config{SigningKeys: validKeys}, 353 | }, 354 | { 355 | name: "Second token valid", 356 | hdrAuth: "Bearer " + secondToken, 357 | config: Config{SigningKeys: validKeys}, 358 | }, 359 | { 360 | name: "Wrong key id token", 361 | expErrCode: http.StatusUnauthorized, 362 | hdrAuth: "Bearer " + wrongToken, 363 | config: Config{SigningKeys: validKeys}, 364 | }, 365 | { 366 | name: "Valid static secret token", 367 | hdrAuth: "Bearer " + staticToken, 368 | config: Config{SigningKey: staticSecret}, 369 | }, 370 | { 371 | name: "Invalid static secret", 372 | expErrCode: http.StatusUnauthorized, 373 | hdrAuth: "Bearer " + staticToken, 374 | config: Config{SigningKey: invalidStaticSecret}, 375 | }, 376 | { 377 | name: "Invalid keys first token", 378 | expErrCode: http.StatusUnauthorized, 379 | hdrAuth: "Bearer " + firstToken, 380 | config: Config{SigningKeys: invalidKeys}, 381 | }, 382 | { 383 | name: "Invalid keys second token", 384 | expErrCode: http.StatusUnauthorized, 385 | hdrAuth: "Bearer " + secondToken, 386 | config: Config{SigningKeys: invalidKeys}, 387 | }, 388 | } 389 | 390 | for _, tc := range testCases { 391 | t.Run(tc.name, func(t *testing.T) { 392 | req := httptest.NewRequest(http.MethodGet, "/", nil) 393 | res := httptest.NewRecorder() 394 | req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth) 395 | c := e.NewContext(req, res) 396 | 397 | if tc.expErrCode != 0 { 398 | h := WithConfig(tc.config)(handler) 399 | he := h(c).(*echo.HTTPError) 400 | assert.Equal(t, tc.expErrCode, he.Code) 401 | return 402 | } 403 | 404 | h := WithConfig(tc.config)(handler) 405 | if assert.NoError(t, h(c), tc.name) { 406 | user := c.Get("user").(*jwt.Token) 407 | switch claims := user.Claims.(type) { 408 | case jwt.MapClaims: 409 | assert.Equal(t, claims["name"], "John Doe") 410 | case *jwtCustomClaims: 411 | assert.Equal(t, claims.Name, "John Doe") 412 | assert.Equal(t, claims.Admin, true) 413 | default: 414 | 
panic("unexpected type of claims") 415 | } 416 | } 417 | }) 418 | } 419 | } 420 | 421 | func TestConfig_skipper(t *testing.T) { 422 | e := echo.New() 423 | 424 | e.Use(WithConfig(Config{ 425 | Skipper: func(context echo.Context) bool { 426 | return true // skip everything 427 | }, 428 | SigningKey: []byte("secret"), 429 | })) 430 | 431 | isCalled := false 432 | e.GET("/", func(c echo.Context) error { 433 | isCalled = true 434 | return c.String(http.StatusTeapot, "test") 435 | }) 436 | 437 | req := httptest.NewRequest(http.MethodGet, "/", nil) 438 | res := httptest.NewRecorder() 439 | e.ServeHTTP(res, req) 440 | 441 | assert.Equal(t, http.StatusTeapot, res.Code) 442 | assert.True(t, isCalled) 443 | } 444 | 445 | func TestConfig_BeforeFunc(t *testing.T) { 446 | e := echo.New() 447 | e.GET("/", func(c echo.Context) error { 448 | return c.String(http.StatusTeapot, "test") 449 | }) 450 | 451 | isCalled := false 452 | e.Use(WithConfig(Config{ 453 | BeforeFunc: func(context echo.Context) { 454 | isCalled = true 455 | }, 456 | SigningKey: []byte("secret"), 457 | })) 458 | 459 | req := httptest.NewRequest(http.MethodGet, "/", nil) 460 | req.Header.Set(echo.HeaderAuthorization, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") 461 | res := httptest.NewRecorder() 462 | e.ServeHTTP(res, req) 463 | 464 | assert.Equal(t, http.StatusTeapot, res.Code) 465 | assert.True(t, isCalled) 466 | } 467 | 468 | func TestConfig_ErrorHandling(t *testing.T) { 469 | var testCases = []struct { 470 | name string 471 | given Config 472 | whenAuthHeader string 473 | expectError string 474 | }{ 475 | { 476 | name: "ok, ErrorHandler is executed", 477 | given: Config{ 478 | SigningKey: []byte("secret"), 479 | ErrorHandler: func(c echo.Context, err error) error { 480 | return echo.NewHTTPError(http.StatusTeapot, "custom_error") 481 | }, 482 | }, 483 | expectError: "code=418, 
message=custom_error", 484 | }, 485 | { 486 | name: "ok, extractor errors are distinguishable as TokenExtractionError", 487 | given: Config{ 488 | SigningKey: []byte("secret"), 489 | ErrorHandler: func(c echo.Context, err error) error { 490 | var extratorErr *TokenExtractionError 491 | if !errors.As(err, &extratorErr) { 492 | panic("must get TokenExtractionError") 493 | } 494 | return err 495 | }, 496 | }, 497 | expectError: "missing value in request header", 498 | }, 499 | { 500 | name: "ok, token parsing errors are distinguishable as TokenParsingError", 501 | given: Config{ 502 | SigningKey: []byte("secret"), 503 | ErrorHandler: func(c echo.Context, err error) error { 504 | var tpErr *TokenParsingError 505 | if !errors.As(err, &tpErr) { 506 | panic("must get TokenParsingError") 507 | } 508 | var tErr *TokenError 509 | if !errors.As(err, &tErr) { 510 | panic("must get TokenError") 511 | } 512 | return err 513 | }, 514 | }, 515 | whenAuthHeader: "Bearer x.x.x", 516 | expectError: "token is malformed: could not base64 decode header: illegal base64 data at input byte 0", 517 | }, 518 | } 519 | 520 | for _, tc := range testCases { 521 | t.Run(tc.name, func(t *testing.T) { 522 | e := echo.New() 523 | h := func(c echo.Context) error { 524 | return c.String(http.StatusNotImplemented, "should not end up here") 525 | } 526 | 527 | jwtMiddlewareFunc := WithConfig(tc.given) 528 | 529 | req := httptest.NewRequest(http.MethodGet, "/", nil) 530 | if tc.whenAuthHeader != "" { 531 | req.Header.Set(echo.HeaderAuthorization, tc.whenAuthHeader) 532 | 533 | } 534 | res := httptest.NewRecorder() 535 | c := e.NewContext(req, res) 536 | 537 | err := jwtMiddlewareFunc(h)(c) 538 | 539 | assert.EqualError(t, err, tc.expectError) 540 | }) 541 | } 542 | } 543 | 544 | func TestConfig_parseTokenErrorHandling(t *testing.T) { 545 | var testCases = []struct { 546 | name string 547 | given Config 548 | expectErr string 549 | }{ 550 | { 551 | name: "ok, ErrorHandler is executed", 552 | given: 
Config{ 553 | ErrorHandler: func(c echo.Context, err error) error { 554 | return echo.NewHTTPError(http.StatusTeapot, "ErrorHandler: "+err.Error()) 555 | }, 556 | }, 557 | expectErr: "{\"message\":\"ErrorHandler: parsing failed\"}\n", 558 | }, 559 | } 560 | 561 | for _, tc := range testCases { 562 | t.Run(tc.name, func(t *testing.T) { 563 | e := echo.New() 564 | //e.Debug = true 565 | e.GET("/", func(c echo.Context) error { 566 | return c.String(http.StatusNotImplemented, "should not end up here") 567 | }) 568 | 569 | config := tc.given 570 | parseTokenCalled := false 571 | config.ParseTokenFunc = func(c echo.Context, auth string) (interface{}, error) { 572 | parseTokenCalled = true 573 | return nil, errors.New("parsing failed") 574 | } 575 | e.Use(WithConfig(config)) 576 | 577 | req := httptest.NewRequest(http.MethodGet, "/", nil) 578 | req.Header.Set(echo.HeaderAuthorization, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") 579 | res := httptest.NewRecorder() 580 | 581 | e.ServeHTTP(res, req) 582 | 583 | assert.Equal(t, http.StatusTeapot, res.Code) 584 | assert.Equal(t, tc.expectErr, res.Body.String()) 585 | assert.True(t, parseTokenCalled) 586 | }) 587 | } 588 | } 589 | 590 | func TestConfig_custom_ParseTokenFunc_Keyfunc(t *testing.T) { 591 | e := echo.New() 592 | e.GET("/", func(c echo.Context) error { 593 | return c.String(http.StatusTeapot, "test") 594 | }) 595 | 596 | // example of minimal custom ParseTokenFunc implementation. 
Allows you to use different versions of `github.com/golang-jwt/jwt` 597 | // with current JWT middleware 598 | signingKey := []byte("secret") 599 | 600 | config := Config{ 601 | ParseTokenFunc: func(c echo.Context, auth string) (interface{}, error) { 602 | keyFunc := func(t *jwt.Token) (interface{}, error) { 603 | if t.Method.Alg() != "HS256" { 604 | return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"]) 605 | } 606 | return signingKey, nil 607 | } 608 | 609 | // claims are of type `jwt.MapClaims` when token is created with `jwt.Parse` 610 | token, err := jwt.Parse(auth, keyFunc) 611 | if err != nil { 612 | return nil, err 613 | } 614 | if !token.Valid { 615 | return nil, errors.New("invalid token") 616 | } 617 | return token, nil 618 | }, 619 | } 620 | 621 | e.Use(WithConfig(config)) 622 | 623 | req := httptest.NewRequest(http.MethodGet, "/", nil) 624 | req.Header.Set(echo.HeaderAuthorization, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") 625 | res := httptest.NewRecorder() 626 | e.ServeHTTP(res, req) 627 | 628 | assert.Equal(t, http.StatusTeapot, res.Code) 629 | } 630 | 631 | func TestMustJWTWithConfig_SuccessHandler(t *testing.T) { 632 | e := echo.New() 633 | 634 | e.GET("/", func(c echo.Context) error { 635 | success := c.Get("success").(string) 636 | user := c.Get("user").(string) 637 | return c.String(http.StatusTeapot, fmt.Sprintf("%v:%v", success, user)) 638 | }) 639 | 640 | mw, err := Config{ 641 | ParseTokenFunc: func(c echo.Context, auth string) (interface{}, error) { 642 | return auth, nil 643 | }, 644 | SuccessHandler: func(c echo.Context) { 645 | c.Set("success", "yes") 646 | }, 647 | }.ToMiddleware() 648 | assert.NoError(t, err) 649 | e.Use(mw) 650 | 651 | req := httptest.NewRequest(http.MethodGet, "/", nil) 652 | req.Header.Add(echo.HeaderAuthorization, "Bearer valid_token_base64") 653 | res := httptest.NewRecorder() 
654 | e.ServeHTTP(res, req) 655 | 656 | assert.Equal(t, "yes:valid_token_base64", res.Body.String()) 657 | assert.Equal(t, http.StatusTeapot, res.Code) 658 | } 659 | 660 | func TestJWTWithConfig_ContinueOnIgnoredError(t *testing.T) { 661 | var testCases = []struct { 662 | name string 663 | givenContinueOnIgnoredError bool 664 | givenErrorHandler func(c echo.Context, err error) error 665 | givenTokenLookup string 666 | whenAuthHeaders []string 667 | whenCookies []string 668 | whenParseReturn string 669 | whenParseError error 670 | expectHandlerCalled bool 671 | expect string 672 | expectCode int 673 | }{ 674 | { 675 | name: "ok, with valid JWT from auth header", 676 | givenContinueOnIgnoredError: true, 677 | givenErrorHandler: func(c echo.Context, err error) error { 678 | return nil 679 | }, 680 | whenAuthHeaders: []string{"Bearer valid_token_base64"}, 681 | whenParseReturn: "valid_token", 682 | expectCode: http.StatusTeapot, 683 | expect: "valid_token", 684 | }, 685 | { 686 | name: "ok, missing header, callNext and set public_token from error handler", 687 | givenContinueOnIgnoredError: true, 688 | givenErrorHandler: func(c echo.Context, err error) error { 689 | var extratorErr *TokenExtractionError 690 | if !errors.As(err, &extratorErr) { 691 | panic("must get TokenExtractionError") 692 | } 693 | c.Set("user", "public_token") 694 | return nil 695 | }, 696 | whenAuthHeaders: []string{}, // no JWT header 697 | expectCode: http.StatusTeapot, 698 | expect: "public_token", 699 | }, 700 | { 701 | name: "ok, invalid token, callNext and set public_token from error handler", 702 | givenContinueOnIgnoredError: true, 703 | givenErrorHandler: func(c echo.Context, err error) error { 704 | // this is probably not realistic usecase. 
on parse error you probably want to return error 705 | if err.Error() != "parser_error" { 706 | panic("must get parser_error") 707 | } 708 | c.Set("user", "public_token") 709 | return nil 710 | }, 711 | whenAuthHeaders: []string{"Bearer invalid_header"}, 712 | whenParseError: errors.New("parser_error"), 713 | expectCode: http.StatusTeapot, 714 | expect: "public_token", 715 | }, 716 | { 717 | name: "nok, invalid token, return error from error handler", 718 | givenContinueOnIgnoredError: true, 719 | givenErrorHandler: func(c echo.Context, err error) error { 720 | if err.Error() != "parser_error" { 721 | panic("must get parser_error") 722 | } 723 | return err 724 | }, 725 | whenAuthHeaders: []string{"Bearer invalid_header"}, 726 | whenParseError: errors.New("parser_error"), 727 | expectCode: http.StatusInternalServerError, 728 | expect: "{\"message\":\"Internal Server Error\"}\n", 729 | }, 730 | { 731 | name: "nok, ContinueOnIgnoredError but return error from error handler", 732 | givenContinueOnIgnoredError: true, 733 | givenErrorHandler: func(c echo.Context, err error) error { 734 | return echo.ErrUnauthorized 735 | }, 736 | whenAuthHeaders: []string{}, // no JWT header 737 | expectCode: http.StatusUnauthorized, 738 | expect: "{\"message\":\"Unauthorized\"}\n", 739 | }, 740 | { 741 | name: "nok, ContinueOnIgnoredError=false", 742 | givenContinueOnIgnoredError: false, 743 | givenErrorHandler: func(c echo.Context, err error) error { 744 | return echo.ErrUnauthorized 745 | }, 746 | whenAuthHeaders: []string{}, // no JWT header 747 | expectCode: http.StatusUnauthorized, 748 | expect: "{\"message\":\"Unauthorized\"}\n", 749 | }, 750 | } 751 | 752 | for _, tc := range testCases { 753 | t.Run(tc.name, func(t *testing.T) { 754 | e := echo.New() 755 | 756 | e.GET("/", func(c echo.Context) error { 757 | token := c.Get("user").(string) 758 | return c.String(http.StatusTeapot, token) 759 | }) 760 | 761 | mw, err := Config{ 762 | ContinueOnIgnoredError: 
tc.givenContinueOnIgnoredError, 763 | TokenLookup: tc.givenTokenLookup, 764 | ParseTokenFunc: func(c echo.Context, auth string) (interface{}, error) { 765 | return tc.whenParseReturn, tc.whenParseError 766 | }, 767 | ErrorHandler: tc.givenErrorHandler, 768 | }.ToMiddleware() 769 | assert.NoError(t, err) 770 | e.Use(mw) 771 | 772 | req := httptest.NewRequest(http.MethodGet, "/", nil) 773 | for _, a := range tc.whenAuthHeaders { 774 | req.Header.Add(echo.HeaderAuthorization, a) 775 | } 776 | res := httptest.NewRecorder() 777 | e.ServeHTTP(res, req) 778 | 779 | assert.Equal(t, tc.expect, res.Body.String()) 780 | assert.Equal(t, tc.expectCode, res.Code) 781 | }) 782 | } 783 | } 784 | 785 | func TestConfig_TokenLookupFuncs(t *testing.T) { 786 | e := echo.New() 787 | 788 | e.GET("/", func(c echo.Context) error { 789 | token := c.Get("user").(*jwt.Token) 790 | return c.JSON(http.StatusOK, token.Claims) 791 | }) 792 | 793 | e.Use(WithConfig(Config{ 794 | SigningKey: []byte("secret"), 795 | TokenLookupFuncs: []middleware.ValuesExtractor{ 796 | func(c echo.Context) ([]string, error) { 797 | return []string{c.Request().Header.Get("X-API-Key")}, nil 798 | }, 799 | }, 800 | })) 801 | 802 | req := httptest.NewRequest(http.MethodGet, "/", nil) 803 | req.Header.Set("X-API-Key", "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") 804 | res := httptest.NewRecorder() 805 | e.ServeHTTP(res, req) 806 | 807 | assert.Equal(t, http.StatusOK, res.Code) 808 | assert.Equal(t, `{"admin":true,"name":"John Doe","sub":"1234567890"}`+"\n", res.Body.String()) 809 | } 810 | 811 | func TestWithConfig_panic(t *testing.T) { 812 | assert.PanicsWithError(t, 813 | "jwt middleware requires signing key", 814 | func() { 815 | WithConfig(Config{}) 816 | }, 817 | ) 818 | } 819 | 820 | func TestDataRacesOnParallelExecution(t *testing.T) { 821 | var testCases = []struct { 822 | name string 823 | whenHeader 
string 824 | expectCode int 825 | }{ // run multiple cases in parallel to catch data races 826 | { 827 | name: "ok", 828 | whenHeader: "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ", 829 | expectCode: http.StatusTeapot, 830 | }, 831 | { 832 | name: "nok", 833 | whenHeader: "Bearer x.x.x", 834 | expectCode: http.StatusUnauthorized, 835 | }, 836 | { 837 | name: "nok, simulatenous error", 838 | whenHeader: "Bearer x.x.x", 839 | expectCode: http.StatusUnauthorized, 840 | }, 841 | } 842 | 843 | e := echo.New() 844 | e.GET("/", func(c echo.Context) error { 845 | token := c.Get("user").(*jwt.Token) 846 | return c.JSON(http.StatusTeapot, token.Claims) 847 | }) 848 | 849 | mw, err := Config{SigningKey: []byte("secret")}.ToMiddleware() 850 | if err != nil { 851 | t.Fatal(err) 852 | } 853 | e.Use(mw) 854 | 855 | for _, tc := range testCases { 856 | tc := tc 857 | t.Run(tc.name, func(t *testing.T) { 858 | t.Parallel() 859 | 860 | req := httptest.NewRequest(http.MethodGet, "/", nil) 861 | req.Header.Set(echo.HeaderAuthorization, tc.whenHeader) 862 | res := httptest.NewRecorder() 863 | 864 | e.ServeHTTP(res, req) 865 | 866 | if res.Code != tc.expectCode { 867 | t.Failed() 868 | } 869 | }) 870 | } 871 | } 872 | --------------------------------------------------------------------------------