├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── feature_request.md │ ├── question.md │ └── refactor.md └── workflows │ ├── go-fmt.yml │ ├── go.yml │ ├── golangci-lint.yml │ ├── integration_test.yml │ ├── license.yml │ └── stale.yml ├── .gitignore ├── .golangci.yml ├── .license_header ├── .licenserc.json ├── .script ├── goimports.sh ├── integrate_test.sh ├── integration_test_compose.yml └── setup.sh ├── LICENSE ├── Makefile ├── README.md ├── error.go ├── gctx ├── context.go └── context_test.go ├── go.mod ├── go.sum ├── internal ├── crawlerdetect │ ├── baidu_strategy.go │ ├── baidu_strategy_test.go │ ├── bing_strategy.go │ ├── bing_strategy_test.go │ ├── crawler_detector.go │ ├── google_strategy.go │ ├── google_strategy_test.go │ ├── sogou_strategy.go │ └── sogou_strategy_test.go ├── e2e │ ├── activelimit_test.go │ ├── base_suite.go │ ├── dependency.go │ ├── gin_writer.go │ └── ratelimit_test.go ├── errs │ └── error.go ├── jwt │ ├── claims_option.go │ ├── claims_option_test.go │ ├── management.go │ ├── management_test.go │ └── types.go ├── mocks │ ├── pipeline.mock.go │ └── redis.mock.go └── ratelimit │ ├── mocks │ └── ratelimit.mock.go │ ├── redis_slide_window.go │ ├── redis_slide_window_test.go │ ├── slide_window.lua │ └── types.go ├── middlewares ├── accesslog │ ├── builder.go │ └── builder_test.go ├── activelimit │ ├── locallimit │ │ ├── builder.go │ │ └── builder_test.go │ └── redislimit │ │ ├── builder.go │ │ └── builder_test.go ├── crawlerdetect │ ├── builder.go │ └── builder_test.go └── ratelimit │ ├── builder.go │ ├── builder_test.go │ └── redis_slide_window.go ├── session ├── builder.go ├── builder_test.go ├── cookie │ ├── carrier.go │ └── carrier_test.go ├── global.go ├── global_test.go ├── header │ ├── carrier.go │ └── carrier_test.go ├── memory.go ├── memory_test.go ├── middleware_builder.go ├── mixin │ ├── carrier.go │ └── carrier_test.go ├── provider.mock_test.go ├── redis │ ├── provider.go │ ├── provider_test.go │ ├── session.go │ └── 
session_test.go └── types.go ├── types.go └── wrapper_func.go /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | **仅限中文** 11 | 12 | 在提交之前请先查找[已有 issues](https://github.com/ecodeclub/ginx/issues),避免重复上报。 13 | 14 | 并且确保自己已经: 15 | - [ ] 阅读过文档 16 | - [ ] 阅读过代码注释 17 | - [ ] 阅读过相关测试 18 | 19 | ### 问题简要描述 20 | 21 | ### 复现步骤 22 | > 通过编写单元、集成及e2e测试来复现Bug 23 | 24 | ### 错误日志或者截图 25 | 26 | ### 你期望的结果 27 | 28 | ### 你排查的结果,或者你觉得可行的修复方案 29 | > 可选。希望你能够尽量先排查问题,这对于你个人能力提升很有帮助。 30 | 31 | ### 你设置的 Go 环境? 32 | > 上传 `go env` 的结果 -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: feature 6 | assignees: '' 7 | 8 | --- 9 | 10 | **仅限中文** 11 | 12 | ### 使用场景 13 | 14 | ### 行业分析 15 | > 如果你知道有框架提供了类似功能,可以在这里描述,并且给出文档或者例子 16 | 17 | ### 可行方案 18 | > 如果你有设计思路或者解决方案,请在这里提供。你可以提供多个方案,并且给出自己的选择 19 | 20 | ### 其它 21 | > 任何你觉得有利于解决问题的补充说明 22 | 23 | ### 你设置的 Go 环境? 24 | > 上传 `go env` 的结果 -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/question.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Question 3 | about: Want to ask some questions 4 | title: '' 5 | labels: question 6 | --- 7 | 8 | **仅限中文** 9 | 10 | 在提问之前请先查找[已有 issues](https://github.com/ecodeclub/ginx/issues),避免重复提问。 11 | 12 | 并且确保自己已经: 13 | - [ ] 阅读过文档 14 | - [ ] 阅读过代码注释 15 | - [ ] 阅读过相关测试 16 | 17 | ### 你的问题 18 | 19 | ### 你设置的 Go 环境? 
20 | > 上传 `go env` 的结果 -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/refactor.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Refactor request 3 | about: Refactor existing code 4 | title: '' 5 | labels: refactor 6 | assignees: '' 7 | 8 | --- 9 | 10 | **仅限中文** 11 | 12 | ### 当前实现缺陷 13 | 14 | ### 重构方案 15 | > 描述可以如何重构,以及重构之后带来的效果,如可读性、性能等方面的提升 16 | 17 | ### 其它 18 | > 任何你觉得有利于解决问题的补充说明 -------------------------------------------------------------------------------- /.github/workflows/go-fmt.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2023 ecodeclub 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | name: Format Go code 16 | 17 | on: 18 | push: 19 | branches: [ main,dev ] 20 | pull_request: 21 | branches: [ main,dev ] 22 | 23 | jobs: 24 | build: 25 | runs-on: ubuntu-latest 26 | steps: 27 | - uses: actions/checkout@v4 28 | - name: Set up Go 29 | uses: actions/setup-go@v3 30 | with: 31 | go-version: '>=1.21.1' 32 | 33 | - name: Install goimports 34 | run: go install golang.org/x/tools/cmd/goimports@latest 35 | 36 | - name: Check 37 | run: | 38 | make check 39 | if [ -n "$(git status --porcelain)" ]; then 40 | echo >&2 "错误: 请在本地运行命令'make check'后再提交." 
41 | exit 1 42 | fi -------------------------------------------------------------------------------- /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2023 ecodeclub 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | name: Go 16 | 17 | on: 18 | push: 19 | branches: [ dev,main ] 20 | pull_request: 21 | branches: [ dev,main ] 22 | 23 | jobs: 24 | build: 25 | runs-on: ubuntu-latest 26 | steps: 27 | - uses: actions/checkout@v4 28 | - name: Set up Go 29 | uses: actions/setup-go@v3 30 | with: 31 | go-version: '1.21.1' 32 | 33 | - name: Build 34 | run: go build -v ./... 35 | 36 | - name: Unit Test 37 | run: go test -race -v ./... -------------------------------------------------------------------------------- /.github/workflows/golangci-lint.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2023 ecodeclub 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | name: golangci-lint 16 | on: 17 | push: 18 | branches: 19 | - dev 20 | - main 21 | pull_request: 22 | branches: 23 | - dev 24 | - main 25 | permissions: 26 | contents: read 27 | # Optional: allow read access to pull request. Use with `only-new-issues` option. 28 | pull-requests: read 29 | jobs: 30 | golangci: 31 | name: lint 32 | runs-on: ubuntu-latest 33 | steps: 34 | - uses: actions/setup-go@v3 35 | with: 36 | go-version: '1.21.1' 37 | - uses: actions/checkout@v4 38 | - name: golangci-lint 39 | uses: golangci/golangci-lint-action@v3 40 | with: 41 | # Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version 42 | version: latest 43 | 44 | # Optional: working directory, useful for monorepos 45 | # working-directory: somedir 46 | 47 | # Optional: golangci-lint command line arguments. 48 | args: -c .golangci.yml 49 | 50 | # Optional: show only new issues if it's a pull request. The default value is `false`. 51 | only-new-issues: true 52 | 53 | # Optional: if set to true then the all caching functionality will be complete disabled, 54 | # takes precedence over all other caching options. 55 | # skip-cache: true 56 | 57 | # Optional: if set to true then the action don't cache or restore ~/go/pkg. 58 | # skip-pkg-cache: true 59 | 60 | # Optional: if set to true then the action don't cache or restore ~/.cache/go-build. 61 | # skip-build-cache: true -------------------------------------------------------------------------------- /.github/workflows/integration_test.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2023 ecodeclub 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | name: Integration Test 16 | 17 | on: 18 | push: 19 | branches: [ main, dev] 20 | pull_request: 21 | branches: [ main, dev] 22 | 23 | jobs: 24 | build: 25 | runs-on: ubuntu-latest 26 | steps: 27 | - uses: actions/checkout@v4 28 | - name: Set up Go 29 | uses: actions/setup-go@v2 30 | with: 31 | go-version: '1.21.1' 32 | 33 | - name: Test 34 | run: make e2e 35 | 36 | - name: Post Coverage 37 | uses: codecov/codecov-action@v2 -------------------------------------------------------------------------------- /.github/workflows/license.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2023 ecodeclub 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | name: Check License Lines 16 | on: 17 | pull_request: 18 | types: [opened, synchronize, reopened, labeled, unlabeled] 19 | branches: 20 | - develop 21 | - main 22 | - dev 23 | jobs: 24 | check-license-lines: 25 | runs-on: ubuntu-latest 26 | steps: 27 | - uses: denoland/setup-deno@v1 28 | with: 29 | deno-version: "1.40.4" 30 | - uses: actions/checkout@v4 31 | - name: Check license 32 | run: deno run --allow-read https://deno.land/x/license_checker@v3.2.3/main.ts 33 | -------------------------------------------------------------------------------- /.github/workflows/stale.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2023 ecodeclub 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | name: Mark stale issues and pull requests 16 | 17 | on: 18 | schedule: 19 | - cron: "30 1 * * *" 20 | 21 | jobs: 22 | stale: 23 | 24 | runs-on: ubuntu-latest 25 | 26 | steps: 27 | - uses: actions/stale@v4 28 | with: 29 | repo-token: ${{ secrets.GITHUB_TOKEN }} 30 | stale-issue-message: 'This issue is inactive for a long time.' 
31 | stale-pr-message: 'This PR is inactive for a long time' 32 | stale-issue-label: 'inactive-issue' 33 | stale-pr-label: 'inactive-pr' 34 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # If you prefer the allow list template instead of the deny list, see community template: 2 | # https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore 3 | # 4 | # Binaries for programs and plugins 5 | *.exe 6 | *.exe~ 7 | *.dll 8 | *.so 9 | *.dylib 10 | 11 | # Test binary, built with `go test -c` 12 | *.test 13 | 14 | # Output of the go coverage tool, specifically when used with LiteIDE 15 | *.out 16 | 17 | # Dependency directories (remove the comment below to include it) 18 | # vendor/ 19 | 20 | # Go workspace file 21 | go.work 22 | 23 | .idea 24 | *.out 25 | .run -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2023 ecodeclub 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | run: 16 | go: '1.21' 17 | skip-dirs: 18 | - .idea -------------------------------------------------------------------------------- /.license_header: -------------------------------------------------------------------------------- 1 | Copyright 2023 ecodeclub 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. -------------------------------------------------------------------------------- /.licenserc.json: -------------------------------------------------------------------------------- 1 | { 2 | "**/*.go": "// Copyright 2023 ecodeclub", 3 | "**/*.{yml,toml}": "# Copyright 2023 ecodeclub", 4 | "**/*.sh": "# Copyright 2023 ecodeclub", 5 | "ignore": [ 6 | "internal/mocks/", 7 | "session/provider.mock_test.go", 8 | "internal/ratelimit/mocks/ratelimit.mock.go" 9 | ] 10 | } -------------------------------------------------------------------------------- /.script/goimports.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2023 ecodeclub 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | goimports -l -w $(find . -type f -name '*.go' -not -path "./.idea/*") -------------------------------------------------------------------------------- /.script/integrate_test.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2023 ecodeclub 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | #!/usr/bin/env bash 16 | 17 | set -e 18 | docker compose -f .script/integration_test_compose.yml down 19 | docker compose -f .script/integration_test_compose.yml up -d 20 | go test -race -coverprofile=cover.out -tags=e2e ./... 
21 | docker compose -f .script/integration_test_compose.yml down -------------------------------------------------------------------------------- /.script/integration_test_compose.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2023 ecodeclub 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | version: '3.0' 16 | 17 | services: 18 | redis: 19 | image: redis:latest 20 | ports: 21 | - "16379:6379" 22 | -------------------------------------------------------------------------------- /.script/setup.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2023 ecodeclub 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
14 | 15 | SOURCE_COMMIT=.github/pre-commit 16 | TARGET_COMMIT=.git/hooks/pre-commit 17 | SOURCE_PUSH=.github/pre-push 18 | TARGET_PUSH=.git/hooks/pre-push 19 | 20 | # copy pre-commit file if not exist. 21 | echo "设置 git pre-commit hooks..." 22 | cp $SOURCE_COMMIT $TARGET_COMMIT 23 | 24 | # copy pre-push file if not exist. 25 | echo "设置 git pre-push hooks..." 26 | cp $SOURCE_PUSH $TARGET_PUSH 27 | 28 | # add permission to TARGET_PUSH and TARGET_COMMIT file. 29 | test -x $TARGET_PUSH || chmod +x $TARGET_PUSH 30 | test -x $TARGET_COMMIT || chmod +x $TARGET_COMMIT 31 | 32 | echo "安装 golangci-lint..." 33 | go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest 34 | 35 | echo "安装 goimports..." 36 | go install golang.org/x/tools/cmd/goimports@latest -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: bench 2 | bench: 3 | @go test -bench=. -benchmem ./... 4 | 5 | .PHONY: ut 6 | ut: 7 | @go test -tags=goexperiment.arenas -race ./... 
8 | 9 | .PHONY: setup 10 | setup: 11 | @sh ./.script/setup.sh 12 | 13 | .PHONY: fmt 14 | fmt: 15 | @sh ./.script/goimports.sh 16 | 17 | .PHONY: lint 18 | lint: 19 | @golangci-lint run -c .golangci.yml 20 | 21 | .PHONY: tidy 22 | tidy: 23 | @go mod tidy -v 24 | 25 | .PHONY: check 26 | check: 27 | @$(MAKE) fmt 28 | @$(MAKE) tidy 29 | 30 | # e2e 测试 31 | .PHONY: e2e 32 | e2e: 33 | sh ./.script/integrate_test.sh 34 | 35 | .PHONY: e2e_up 36 | e2e_up: 37 | docker compose -f .script/integration_test_compose.yml up -d 38 | 39 | .PHONY: e2e_down 40 | e2e_down: 41 | docker compose -f .script/integration_test_compose.yml down 42 | mock: 43 | mockgen -copyright_file=.license_header -package=mocks -destination=internal/mocks/pipeline.mock.go github.com/redis/go-redis/v9 Pipeliner 44 | mockgen -copyright_file=.license_header -source=session/types.go -package=session -destination=session/provider.mock_test.go Provider -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ginx 2 | GIN 的加强版,提供了一些比较好用的特性,以及一些基于个人理解的新插件。 3 | -------------------------------------------------------------------------------- /error.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package ginx 16 | 17 | import "github.com/ecodeclub/ginx/internal/errs" 18 | 19 | var ErrNoResponse = errs.ErrNoResponse 20 | var ErrUnauthorized = errs.ErrUnauthorized 21 | -------------------------------------------------------------------------------- /gctx/context.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package gctx 16 | 17 | import ( 18 | "github.com/ecodeclub/ekit" 19 | "github.com/gin-gonic/gin" 20 | ) 21 | 22 | type Context struct { 23 | *gin.Context 24 | } 25 | 26 | func (c *Context) Param(key string) ekit.AnyValue { 27 | return ekit.AnyValue{ 28 | Val: c.Context.Param(key), 29 | } 30 | } 31 | 32 | func (c *Context) Query(key string) ekit.AnyValue { 33 | return ekit.AnyValue{ 34 | Val: c.Context.Query(key), 35 | } 36 | } 37 | 38 | func (c *Context) Cookie(key string) ekit.AnyValue { 39 | val, err := c.Context.Cookie(key) 40 | return ekit.AnyValue{ 41 | Val: val, 42 | Err: err, 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /gctx/context_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package gctx 16 | 17 | import ( 18 | "net/http" 19 | "net/http/httptest" 20 | "net/url" 21 | "testing" 22 | 23 | "github.com/gin-gonic/gin" 24 | "github.com/stretchr/testify/assert" 25 | "github.com/stretchr/testify/require" 26 | ) 27 | 28 | func TestContext_Query(t *testing.T) { 29 | testCases := []struct { 30 | name string 31 | req func(t *testing.T) *http.Request 32 | key string 33 | wantErr error 34 | wantVal any 35 | }{ 36 | { 37 | name: "获得数据", 38 | req: func(t *testing.T) *http.Request { 39 | req, err := http.NewRequest(http.MethodGet, "http://localhost/abc?name=123&age=18", nil) 40 | require.NoError(t, err) 41 | return req 42 | }, 43 | key: "name", 44 | wantVal: "123", 45 | }, 46 | { 47 | name: "没有数据", 48 | req: func(t *testing.T) *http.Request { 49 | req, err := http.NewRequest(http.MethodGet, "http://localhost/abc?name=123&age=18", nil) 50 | require.NoError(t, err) 51 | return req 52 | }, 53 | key: "nickname", 54 | wantVal: "", 55 | }, 56 | } 57 | 58 | for _, tc := range testCases { 59 | t.Run(tc.name, func(t *testing.T) { 60 | ctx := &Context{Context: &gin.Context{ 61 | Request: tc.req(t), 62 | }} 63 | name := ctx.Query(tc.key) 64 | val, err := name.String() 65 | assert.Equal(t, tc.wantErr, err) 66 | assert.Equal(t, tc.wantVal, val) 67 | }) 68 | } 69 | } 70 | 71 | func TestContext_Param(t *testing.T) { 72 | testCases := []struct { 73 | name string 74 | req func(t *testing.T) *http.Request 75 | key string 76 | wantErr error 77 | wantVal any 78 | }{ 79 | { 80 | name: "获得数据", 81 | req: func(t 
*testing.T) *http.Request { 82 | req, err := http.NewRequest(http.MethodPost, "http://localhost/hello/world?name=123&age=18", nil) 83 | require.NoError(t, err) 84 | req.Form = url.Values{} 85 | req.Form.Set("name", "form") 86 | return req 87 | }, 88 | key: "name", 89 | wantVal: "world", 90 | }, 91 | { 92 | name: "没有数据", 93 | req: func(t *testing.T) *http.Request { 94 | req, err := http.NewRequest(http.MethodPost, "http://localhost/hello/world?name=123&age=18", nil) 95 | require.NoError(t, err) 96 | return req 97 | }, 98 | key: "nickname", 99 | wantVal: "", 100 | }, 101 | } 102 | 103 | for _, tc := range testCases { 104 | t.Run(tc.name, func(t *testing.T) { 105 | server := gin.Default() 106 | server.POST("/hello/:name", func(context *gin.Context) { 107 | ctx := &Context{Context: context} 108 | name := ctx.Param(tc.key) 109 | val, err := name.String() 110 | assert.Equal(t, tc.wantErr, err) 111 | assert.Equal(t, tc.wantVal, val) 112 | }) 113 | recorder := httptest.NewRecorder() 114 | server.ServeHTTP(recorder, tc.req(t)) 115 | // A non-200 status means no route matched, so the handler and its assertions never ran. 116 | assert.Equal(t, http.StatusOK, recorder.Code) 117 | }) 118 | } 119 | } 120 | 121 | func TestContext_Cookie(t *testing.T) { 122 | testCases := []struct { 123 | name string 124 | req func(t *testing.T) *http.Request 125 | key string 126 | wantErr error 127 | wantVal any 128 | }{ 129 | { 130 | name: "有cookie", 131 | req: func(t *testing.T) *http.Request { 132 | req, err := http.NewRequest(http.MethodPost, "http://localhost/hello?name=123&age=18", nil) 133 | require.NoError(t, err) 134 | req.AddCookie(&http.Cookie{ 135 | Name: "name", 136 | Value: "world", 137 | }) 138 | return req 139 | }, 140 | key: "name", 141 | wantVal: "world", 142 | }, 143 | { 144 | name: "没有 cookie", 145 | req: func(t *testing.T) *http.Request { 146 | req, err := http.NewRequest(http.MethodPost, "http://localhost/hello?name=123&age=18", nil) 147 | require.NoError(t, err) 148 | return req 149 | }, 150 | key: "nickname", 151 | wantErr: http.ErrNoCookie, 152 | wantVal: "", 153 | }, 154 | } 155 | for _, tc := range testCases { 156 | t.Run(tc.name, func(t *testing.T) { 157 | server
:= gin.Default() 158 | server.POST("/hello", func(context *gin.Context) { 159 | ctx := &Context{Context: context} 160 | name := ctx.Cookie(tc.key) 161 | val, err := name.String() 162 | assert.Equal(t, tc.wantErr, err) 163 | assert.Equal(t, tc.wantVal, val) 164 | }) 165 | recorder := httptest.NewRecorder() 166 | server.ServeHTTP(recorder, tc.req(t)) 167 | assert.Equal(t, http.StatusOK, recorder.Code) 168 | }) 169 | } 170 | } 171 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/ecodeclub/ginx 2 | 3 | go 1.21 4 | 5 | require ( 6 | github.com/ecodeclub/ekit v0.0.8-0.20240211141809-d8a351a335b5 7 | github.com/gin-gonic/gin v1.9.1 8 | github.com/golang-jwt/jwt/v5 v5.0.0 9 | github.com/google/uuid v1.6.0 10 | github.com/redis/go-redis/v9 v9.2.1 11 | github.com/stretchr/testify v1.8.4 12 | go.uber.org/atomic v1.11.0 13 | go.uber.org/mock v0.3.0 14 | ) 15 | 16 | require ( 17 | github.com/bytedance/sonic v1.9.1 // indirect 18 | github.com/cespare/xxhash/v2 v2.2.0 // indirect 19 | github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect 20 | github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect 21 | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect 22 | github.com/gabriel-vasile/mimetype v1.4.2 // indirect 23 | github.com/gin-contrib/sse v0.1.0 // indirect 24 | github.com/go-playground/locales v0.14.1 // indirect 25 | github.com/go-playground/universal-translator v0.18.1 // indirect 26 | github.com/go-playground/validator/v10 v10.14.0 // indirect 27 | github.com/goccy/go-json v0.10.2 // indirect 28 | github.com/json-iterator/go v1.1.12 // indirect 29 | github.com/klauspost/cpuid/v2 v2.2.4 // indirect 30 | github.com/kr/text v0.2.0 // indirect 31 | github.com/leodido/go-urn v1.2.4 // indirect 32 | github.com/mattn/go-isatty v0.0.19 // indirect 33 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd
// indirect 34 | github.com/modern-go/reflect2 v1.0.2 // indirect 35 | github.com/pelletier/go-toml/v2 v2.0.8 // indirect 36 | github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect 37 | github.com/twitchyliquid64/golang-asm v0.15.1 // indirect 38 | github.com/ugorji/go/codec v1.2.11 // indirect 39 | golang.org/x/arch v0.3.0 // indirect 40 | golang.org/x/crypto v0.9.0 // indirect 41 | golang.org/x/net v0.10.0 // indirect 42 | golang.org/x/sys v0.8.0 // indirect 43 | golang.org/x/text v0.9.0 // indirect 44 | google.golang.org/protobuf v1.30.0 // indirect 45 | gopkg.in/yaml.v3 v3.0.1 // indirect 46 | ) 47 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= 2 | github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= 3 | github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= 4 | github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= 5 | github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= 6 | github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= 7 | github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= 8 | github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= 9 | github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= 10 | github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= 11 | github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= 12 | github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= 13 | 
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= 14 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 15 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 16 | github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= 17 | github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 18 | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= 19 | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= 20 | github.com/ecodeclub/ekit v0.0.8-0.20240211141809-d8a351a335b5 h1:beyGjdznTmvRRzEdxHF6SeYRLPAsCPr5OzR6Xb2bd1k= 21 | github.com/ecodeclub/ekit v0.0.8-0.20240211141809-d8a351a335b5/go.mod h1:rEGubThvxoIQT/qnbVBkZgSvYwgKrY/dtwEWKRTmgeY= 22 | github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= 23 | github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= 24 | github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= 25 | github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= 26 | github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg= 27 | github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU= 28 | github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= 29 | github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= 30 | github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= 31 | github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= 32 | github.com/go-playground/universal-translator 
v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= 33 | github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= 34 | github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js= 35 | github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= 36 | github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= 37 | github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= 38 | github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE= 39 | github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= 40 | github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= 41 | github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= 42 | github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= 43 | github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= 44 | github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= 45 | github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 46 | github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= 47 | github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= 48 | github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= 49 | github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk= 50 | github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= 51 | github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= 52 | github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= 53 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 54 
| github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 55 | github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= 56 | github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= 57 | github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= 58 | github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= 59 | github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= 60 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= 61 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= 62 | github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= 63 | github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= 64 | github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= 65 | github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= 66 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 67 | github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= 68 | github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 69 | github.com/redis/go-redis/v9 v9.2.1 h1:WlYJg71ODF0dVspZZCpYmoF1+U1Jjk9Rwd7pq6QmlCg= 70 | github.com/redis/go-redis/v9 v9.2.1/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M= 71 | github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= 72 | github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= 73 | github.com/stretchr/objx v0.1.0/go.mod 
h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 74 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= 75 | github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= 76 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 77 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 78 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 79 | github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= 80 | github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= 81 | github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= 82 | github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= 83 | github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= 84 | github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= 85 | github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= 86 | github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= 87 | github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU= 88 | github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= 89 | go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= 90 | go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= 91 | go.uber.org/mock v0.3.0 h1:3mUxI1No2/60yUYax92Pt8eNOEecx2D3lcXZh2NEZJo= 92 | go.uber.org/mock v0.3.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= 93 | golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= 94 | golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= 95 | 
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= 96 | golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= 97 | golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= 98 | golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= 99 | golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= 100 | golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 101 | golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 102 | golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= 103 | golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 104 | golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= 105 | golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= 106 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= 107 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 108 | google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= 109 | google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= 110 | google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= 111 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 112 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= 113 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= 114 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 115 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 116 | gopkg.in/yaml.v3 
v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 117 | rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= 118 | -------------------------------------------------------------------------------- /internal/crawlerdetect/baidu_strategy.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package crawlerdetect 16 | 17 | type BaiduStrategy struct { 18 | *UniversalStrategy 19 | } 20 | 21 | func NewBaiduStrategy() *BaiduStrategy { 22 | return &BaiduStrategy{ 23 | UniversalStrategy: NewUniversalStrategy([]string{"baidu.com", "baidu.jp"}), 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /internal/crawlerdetect/baidu_strategy_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package crawlerdetect 16 | 17 | import ( 18 | "net" 19 | "testing" 20 | 21 | "github.com/stretchr/testify/require" 22 | ) 23 | 24 | func TestBaiduStrategy(t *testing.T) { 25 | s := NewBaiduStrategy() 26 | require.NotNil(t, s) 27 | testCases := []struct { 28 | name string 29 | ip string 30 | matched bool 31 | // errFunc validates the error from CheckCrawler; nil means no error is expected. 32 | errFunc require.ErrorAssertionFunc 33 | }{ 34 | { 35 | name: "无效 ip", 36 | ip: "256.0.0.0", 37 | matched: false, 38 | errFunc: func(t require.TestingT, err error, i ...interface{}) { 39 | var dnsError *net.DNSError 40 | // Fail through t so a mismatch fails only this subtest. 41 | require.ErrorAs(t, err, &dnsError, i...) 42 | }, 43 | }, 44 | { 45 | name: "非百度 ip", 46 | ip: "166.249.90.77", 47 | matched: false, 48 | }, 49 | { 50 | name: "百度 ip", 51 | ip: "111.206.198.69", 52 | matched: true, 53 | }, 54 | } 55 | 56 | for _, tc := range testCases { 57 | t.Run(tc.name, func(t *testing.T) { 58 | m, err := s.CheckCrawler(tc.ip) 59 | // Never call a nil errFunc: an unexpected lookup error must fail, not panic. 60 | if tc.errFunc != nil { 61 | tc.errFunc(t, err) 62 | } else { 63 | require.NoError(t, err) 64 | } 65 | require.Equal(t, tc.matched, m) 66 | }) 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /internal/crawlerdetect/bing_strategy.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package crawlerdetect 16 | 17 | type BingStrategy struct { 18 | *UniversalStrategy 19 | } 20 | 21 | func NewBingStrategy() *BingStrategy { 22 | return &BingStrategy{ 23 | UniversalStrategy: NewUniversalStrategy([]string{"search.msn.com"}), 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /internal/crawlerdetect/bing_strategy_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package crawlerdetect 16 | 17 | import ( 18 | "net" 19 | "testing" 20 | 21 | "github.com/stretchr/testify/require" 22 | ) 23 | 24 | func TestBingStrategy(t *testing.T) { 25 | s := NewBingStrategy() 26 | require.NotNil(t, s) 27 | testCases := []struct { 28 | name string 29 | ip string 30 | matched bool 31 | // errFunc validates the error from CheckCrawler; nil means no error is expected. 32 | errFunc require.ErrorAssertionFunc 33 | }{ 34 | { 35 | name: "无效 ip", 36 | ip: "256.0.0.0", 37 | matched: false, 38 | errFunc: func(t require.TestingT, err error, i ...interface{}) { 39 | var dnsError *net.DNSError 40 | // Fail through t so a mismatch fails only this subtest. 41 | require.ErrorAs(t, err, &dnsError, i...) 42 | }, 43 | }, 44 | { 45 | name: "非必应 ip", 46 | ip: "166.249.90.77", 47 | matched: false, 48 | }, 49 | { 50 | name: "必应 ip", 51 | ip: "157.55.39.1", 52 | matched: true, 53 | }, 54 | } 55 | 56 | for _, tc := range testCases { 57 | t.Run(tc.name, func(t *testing.T) { 58 | m, err := s.CheckCrawler(tc.ip) 59 | // Never call a nil errFunc: an unexpected lookup error must fail, not panic. 60 | if tc.errFunc != nil { 61 | tc.errFunc(t, err) 62 | } else { 63 | require.NoError(t, err) 64 | } 65 | require.Equal(t, tc.matched, m) 66 | }) 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /internal/crawlerdetect/crawler_detector.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License.
14 | 15 | package crawlerdetect 16 | 17 | import ( 18 | "net" 19 | "slices" 20 | "strings" 21 | ) 22 | 23 | const ( 24 | Baidu = "baidu" 25 | Bing = "bing" 26 | Google = "google" 27 | Sogou = "sogou" 28 | ) 29 | 30 | var strategyMap = map[string]Strategy{ 31 | Baidu: NewBaiduStrategy(), 32 | Bing: NewBingStrategy(), 33 | Google: NewGoogleStrategy(), 34 | Sogou: NewSoGouStrategy(), 35 | } 36 | 37 | // Strategy reports whether a client IP belongs to a particular crawler. 38 | type Strategy interface { 39 | CheckCrawler(ip string) (bool, error) 40 | } 41 | 42 | // UniversalStrategy verifies a crawler IP with a reverse DNS lookup followed 43 | // by a forward lookup that must resolve back to the same IP. 44 | type UniversalStrategy struct { 45 | // Hosts are the crawler's domains; a reverse DNS name must equal one of 46 | // them or be a sub-domain of it. 47 | Hosts []string 48 | } 49 | 50 | // NewUniversalStrategy creates a UniversalStrategy accepting the given domains. 51 | func NewUniversalStrategy(hosts []string) *UniversalStrategy { 52 | return &UniversalStrategy{ 53 | Hosts: hosts, 54 | } 55 | } 56 | 57 | // CheckCrawler reports whether ip belongs to the crawler. DNS lookup errors 58 | // are returned; an IP whose reverse names match none of Hosts, or whose 59 | // matched name does not resolve back to ip, yields (false, nil). 60 | func (s *UniversalStrategy) CheckCrawler(ip string) (bool, error) { 61 | names, err := net.LookupAddr(ip) 62 | if err != nil { 63 | return false, err 64 | } 65 | if len(names) == 0 { 66 | return false, nil 67 | } 68 | 69 | name, matched := s.matchHost(names) 70 | if !matched { 71 | return false, nil 72 | } 73 | 74 | // Forward-confirm the reverse record: whoever controls the PTR record of 75 | // an IP can make it claim any name. 76 | ips, err := net.LookupIP(name) 77 | if err != nil { 78 | return false, err 79 | } 80 | // Compare parsed addresses, not strings, so equivalent spellings of the 81 | // same address (e.g. upper-case or zero-expanded IPv6) still match. 82 | target := net.ParseIP(ip) 83 | return slices.ContainsFunc(ips, target.Equal), nil 84 | } 85 | 86 | // matchHost returns the first of names that equals, or is a sub-domain of, 87 | // one of s.Hosts (Hosts are tried in order). 88 | func (s *UniversalStrategy) matchHost(names []string) (string, bool) { 89 | for _, host := range s.Hosts { 90 | for _, name := range names { 91 | if isDomainOrSubdomain(name, host) { 92 | return name, true 93 | } 94 | } 95 | } 96 | return "", false 97 | } 98 | 99 | // isDomainOrSubdomain reports whether name is host or a sub-domain of host. 100 | // Matching is case-insensitive and ignores a trailing root dot, which reverse 101 | // lookups usually include (e.g. "crawl-66-249-66-1.googlebot.com."). A plain 102 | // substring test is unsafe: "google.com.attacker.example" contains "google.com". 103 | func isDomainOrSubdomain(name, host string) bool { 104 | name = strings.ToLower(strings.TrimSuffix(name, ".")) 105 | host = strings.ToLower(strings.TrimSuffix(host, ".")) 106 | return name == host || strings.HasSuffix(name, "."+host) 107 | } 108 | 109 | // NewCrawlerDetector returns the Strategy registered under crawler (Baidu, 110 | // Bing, Google or Sogou), or nil for an unknown name. 111 | func NewCrawlerDetector(crawler string) Strategy { 112 | return strategyMap[crawler] 113 | } 114 | -------------------------------------------------------------------------------- /internal/crawlerdetect/google_strategy.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | //
Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package crawlerdetect 16 | 17 | type GoogleStrategy struct { 18 | *UniversalStrategy 19 | } 20 | 21 | func NewGoogleStrategy() *GoogleStrategy { 22 | return &GoogleStrategy{ 23 | UniversalStrategy: NewUniversalStrategy([]string{"googlebot.com", "google.com", "googleusercontent.com"}), 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /internal/crawlerdetect/google_strategy_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package crawlerdetect 16 | 17 | import ( 18 | "errors" 19 | "log" 20 | "net" 21 | "testing" 22 | 23 | "github.com/stretchr/testify/require" 24 | ) 25 | 26 | func TestGoogleStrategy(t *testing.T) { 27 | s := NewGoogleStrategy() 28 | require.NotNil(t, s) 29 | testCases := []struct { 30 | name string 31 | ip string 32 | matched bool 33 | errFunc require.ErrorAssertionFunc 34 | }{ 35 | { 36 | name: "无效 ip", 37 | ip: "256.0.0.0", 38 | matched: false, 39 | errFunc: func(t require.TestingT, err error, i ...interface{}) { 40 | var dnsError *net.DNSError 41 | if !errors.As(err, &dnsError) { 42 | log.Fatal(err) 43 | } 44 | }, 45 | }, 46 | { 47 | name: "非谷歌 ip", 48 | ip: "166.249.90.77", 49 | matched: false, 50 | }, 51 | { 52 | name: "谷歌 ip", 53 | ip: "66.249.90.77", 54 | matched: true, 55 | }, 56 | } 57 | 58 | for _, tc := range testCases { 59 | t.Run(tc.name, func(t *testing.T) { 60 | m, err := s.CheckCrawler(tc.ip) 61 | if err != nil { 62 | tc.errFunc(t, err) 63 | } 64 | require.Equal(t, tc.matched, m) 65 | }) 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /internal/crawlerdetect/sogou_strategy.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
package crawlerdetect

import (
	"net"
	"slices"
	"strings"
)

// SoGouStrategy identifies the Sogou crawler by reverse DNS only. It accepts
// an IP when one of its PTR names lies within one of Hosts.
//
// NOTE(review): unlike UniversalStrategy there is no forward-confirming
// lookup. Anyone controlling the reverse zone of their own IP can publish a
// matching PTR name, so a positive result is a hint rather than proof.
type SoGouStrategy struct {
	// Hosts are the domains a PTR name must equal or be a subdomain of.
	Hosts []string
}

// NewSoGouStrategy returns a SoGouStrategy that accepts PTR names under
// sogou.com.
func NewSoGouStrategy() *SoGouStrategy {
	return &SoGouStrategy{
		Hosts: []string{"sogou.com"},
	}
}

// CheckCrawler reports whether a PTR name of ip lies within s.Hosts.
// A failed reverse lookup is returned as (false, err). An IP without PTR
// names yields (false, nil).
func (s *SoGouStrategy) CheckCrawler(ip string) (bool, error) {
	names, err := net.LookupAddr(ip)
	if err != nil || len(names) == 0 {
		return false, err
	}
	return s.matchHost(names), nil
}

// matchHost reports whether any of names lies within any of s.Hosts.
//
// Matching ignores letter case and a trailing root dot. A name must equal
// the host or end with "."+host, so "sogou.com.attacker.example" and
// "notsogou.com" are rejected, and an empty host matches nothing.
func (s *SoGouStrategy) matchHost(names []string) bool {
	return slices.ContainsFunc(s.Hosts, func(host string) bool {
		domain := strings.ToLower(strings.TrimSuffix(host, "."))
		if domain == "" {
			return false
		}
		return slices.ContainsFunc(names, func(name string) bool {
			candidate := strings.ToLower(strings.TrimSuffix(name, "."))
			return candidate == domain || strings.HasSuffix(candidate, "."+domain)
		})
	})
}

// --------------------------------------------------------------------------------
// /internal/crawlerdetect/sogou_strategy_test.go:
// --------------------------------------------------------------------------------
// Copyright 2023 ecodeclub
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14 | 15 | package crawlerdetect 16 | 17 | import ( 18 | "errors" 19 | "log" 20 | "net" 21 | "testing" 22 | 23 | "github.com/stretchr/testify/require" 24 | ) 25 | 26 | func TestSoGouStrategy(t *testing.T) { 27 | s := NewSoGouStrategy() 28 | require.NotNil(t, s) 29 | testCases := []struct { 30 | name string 31 | ip string 32 | matched bool 33 | errFunc require.ErrorAssertionFunc 34 | }{ 35 | { 36 | name: "无效 ip", 37 | ip: "256.0.0.0", 38 | matched: false, 39 | errFunc: func(t require.TestingT, err error, i ...interface{}) { 40 | var dnsError *net.DNSError 41 | if !errors.As(err, &dnsError) { 42 | log.Fatal(err) 43 | } 44 | }, 45 | }, 46 | { 47 | name: "非搜狗 ip", 48 | ip: "166.249.90.77", 49 | matched: false, 50 | }, 51 | //{ 52 | // name: "搜狗 ip", 53 | // ip: "123.126.113.110", 54 | // matched: true, 55 | //}, 56 | } 57 | 58 | for _, tc := range testCases { 59 | t.Run(tc.name, func(t *testing.T) { 60 | m, err := s.CheckCrawler(tc.ip) 61 | if err != nil { 62 | tc.errFunc(t, err) 63 | } 64 | require.Equal(t, tc.matched, m) 65 | }) 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /internal/e2e/activelimit_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | //go:build e2e 16 | 17 | package e2e 18 | 19 | import ( 20 | "context" 21 | "fmt" 22 | "net/http" 23 | "net/http/httptest" 24 | "testing" 25 | "time" 26 | 27 | "github.com/ecodeclub/ginx/middlewares/activelimit/redislimit" 28 | "github.com/gin-gonic/gin" 29 | "github.com/stretchr/testify/assert" 30 | "github.com/stretchr/testify/require" 31 | ) 32 | 33 | func TestBuilder_e2e_ActiveRedisLimit(t *testing.T) { 34 | redisClient := newRedisTestClient() 35 | ctx, cancel := context.WithTimeout(context.Background(), time.Second*3) 36 | defer cancel() 37 | err := redisClient.Ping(ctx).Err() 38 | if err != nil { 39 | panic("redislimit 连接失败") 40 | } 41 | defer func() { 42 | _ = redisClient.Close() 43 | }() 44 | 45 | testCases := []struct { 46 | name string 47 | maxCount int64 48 | key string 49 | getReq func() *http.Request 50 | createMiddleware func(maxActive int64, key string) gin.HandlerFunc 51 | before func(server *gin.Engine, key string) 52 | 53 | interval time.Duration 54 | after func(string2 string) (int64, error) 55 | 56 | //响应的code 57 | wantCode int 58 | 59 | //检查退出的时候redis 状态 60 | afterCount int64 61 | afterErr error 62 | }{ 63 | { 64 | name: "开启限流,RedisLimit正常操作", 65 | 66 | createMiddleware: func(maxActive int64, key string) gin.HandlerFunc { 67 | return redislimit.NewRedisActiveLimit(redisClient, maxActive, key).Build() 68 | }, 69 | getReq: func() *http.Request { 70 | req, err := http.NewRequest(http.MethodGet, "/activelimit", nil) 71 | require.NoError(t, err) 72 | return req 73 | }, 74 | before: func(server *gin.Engine, key string) { 75 | 76 | }, 77 | interval: time.Millisecond * 10, 78 | after: func(key string) (int64, error) { 79 | 80 | return redisClient.Get(context.Background(), key).Int64() 81 | }, 82 | 83 | maxCount: 1, 84 | key: "test", 85 | wantCode: http.StatusOK, 86 | 87 | afterCount: 0, 88 | afterErr: nil, 89 | }, 90 | { 91 | name: "开启限流,RedisLimit,有一个人长时间没退出,导致限流", 92 | 93 | createMiddleware: func(maxActive int64, key string) gin.HandlerFunc 
{ 94 | return redislimit.NewRedisActiveLimit(redisClient, maxActive, key).Build() 95 | }, 96 | getReq: func() *http.Request { 97 | req, err := http.NewRequest(http.MethodGet, "/activelimit", nil) 98 | require.NoError(t, err) 99 | return req 100 | }, 101 | before: func(server *gin.Engine, key string) { 102 | 103 | req, err := http.NewRequest(http.MethodGet, "/activelimit3", nil) 104 | require.NoError(t, err) 105 | resp := httptest.NewRecorder() 106 | server.ServeHTTP(resp, req) 107 | assert.Equal(t, 200, resp.Code) 108 | }, 109 | 110 | interval: time.Millisecond * 50, 111 | after: func(key string) (int64, error) { 112 | 113 | return redisClient.Get(context.Background(), key).Int64() 114 | }, 115 | maxCount: 1, 116 | key: "test", 117 | wantCode: http.StatusTooManyRequests, 118 | afterCount: 1, 119 | afterErr: nil, 120 | }, 121 | { 122 | name: "开启限流,RedisLimit,有一个人长时间没退出,等待前面退出后,正常请求....", 123 | 124 | createMiddleware: func(maxActive int64, key string) gin.HandlerFunc { 125 | return redislimit.NewRedisActiveLimit(redisClient, maxActive, key).Build() 126 | }, 127 | getReq: func() *http.Request { 128 | req, err := http.NewRequest(http.MethodGet, "/activelimit", nil) 129 | require.NoError(t, err) 130 | return req 131 | }, 132 | before: func(server *gin.Engine, key string) { 133 | req, err := http.NewRequest(http.MethodGet, "/activelimit3", nil) 134 | require.NoError(t, err) 135 | resp := httptest.NewRecorder() 136 | server.ServeHTTP(resp, req) 137 | assert.Equal(t, 200, resp.Code) 138 | }, 139 | interval: time.Millisecond * 200, 140 | after: func(key string) (int64, error) { 141 | 142 | return redisClient.Get(context.Background(), key).Int64() 143 | }, 144 | maxCount: 1, 145 | key: "test", 146 | wantCode: http.StatusOK, 147 | afterCount: 0, 148 | afterErr: nil, 149 | }, 150 | } 151 | 152 | for _, tc := range testCases { 153 | //这里延时的原因是 保证builder 中的defer 延时操作不会导致测试的异常 154 | time.Sleep(time.Millisecond * 100) 155 | redisClient.Del(context.Background(), tc.key) 156 | 
fmt.Println(redisClient.Get(context.Background(), tc.key).Int64()) 157 | tc := tc 158 | t.Run(tc.name, func(t *testing.T) { 159 | 160 | server := gin.Default() 161 | server.Use(tc.createMiddleware(tc.maxCount, tc.key)) 162 | server.GET("/activelimit", func(ctx *gin.Context) { 163 | ctx.Status(http.StatusOK) 164 | }) 165 | server.GET("/activelimit3", func(ctx *gin.Context) { 166 | time.Sleep(time.Millisecond * 100) 167 | ctx.Status(http.StatusOK) 168 | }) 169 | resp := httptest.NewRecorder() 170 | go func() { 171 | tc.before(server, tc.key) 172 | }() 173 | time.Sleep(tc.interval) 174 | server.ServeHTTP(resp, tc.getReq()) 175 | assert.Equal(t, tc.wantCode, resp.Code) 176 | 177 | afterCount, err := tc.after(tc.key) 178 | 179 | assert.Equal(t, tc.afterCount, afterCount) 180 | assert.Equal(t, tc.afterErr, err) 181 | }) 182 | } 183 | } 184 | -------------------------------------------------------------------------------- /internal/e2e/base_suite.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | //go:build e2e 16 | 17 | package e2e 18 | 19 | import ( 20 | "github.com/redis/go-redis/v9" 21 | "github.com/stretchr/testify/suite" 22 | ) 23 | 24 | type BaseSuite struct { 25 | suite.Suite 26 | RDB redis.Cmdable 27 | } 28 | 29 | func (s *BaseSuite) SetupSuite() { 30 | s.RDB = newRedisTestClient() 31 | } 32 | 33 | func (s *BaseSuite) TearDownSuite() { 34 | if s.RDB != nil { 35 | s.RDB.(*redis.Client).Close() 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /internal/e2e/dependency.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | //go:build e2e 16 | 17 | package e2e 18 | 19 | import "github.com/redis/go-redis/v9" 20 | 21 | var redisCfg = &redis.Options{ 22 | Addr: "localhost:16379", 23 | Password: "", 24 | DB: 0, 25 | } 26 | 27 | func newRedisTestClient() *redis.Client { 28 | return redis.NewClient(redisCfg) 29 | } 30 | -------------------------------------------------------------------------------- /internal/e2e/gin_writer.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package e2e

import (
	"bufio"
	"net"
	"net/http"
)

// GinResponseWriter embeds an http.ResponseWriter and adds stub methods
// (presumably to satisfy gin.ResponseWriter in tests — confirm). Every stub
// panics with "implement me"; only the embedded writer's methods work.
type GinResponseWriter struct {
	http.ResponseWriter
}

// Pusher is a stub; it always panics.
func (*GinResponseWriter) Pusher() http.Pusher {
	panic("implement me")
}

// WriteHeaderNow is a stub; it always panics.
func (*GinResponseWriter) WriteHeaderNow() {
	panic("implement me")
}

// Written is a stub; it always panics.
func (*GinResponseWriter) Written() bool {
	panic("implement me")
}

// WriteString is a stub; it always panics.
func (*GinResponseWriter) WriteString(string) (int, error) {
	panic("implement me")
}

// Size is a stub; it always panics.
func (*GinResponseWriter) Size() int {
	panic("implement me")
}

// Status is a stub; it always panics.
func (*GinResponseWriter) Status() int {
	panic("implement me")
}

// CloseNotify is a stub; it always panics.
func (*GinResponseWriter) CloseNotify() <-chan bool {
	panic("implement me")
}

// Flush is a stub; it always panics.
func (*GinResponseWriter) Flush() {
	panic("implement me")
}

// Hijack is a stub; it always panics.
func (*GinResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	panic("implement me")
}

// --------------------------------------------------------------------------------
// /internal/e2e/ratelimit_test.go:
// --------------------------------------------------------------------------------
// Copyright 2023 ecodeclub
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | //go:build e2e 16 | 17 | package e2e 18 | 19 | import ( 20 | "context" 21 | "fmt" 22 | "net/http" 23 | "net/http/httptest" 24 | "testing" 25 | "time" 26 | 27 | "github.com/gin-gonic/gin" 28 | "github.com/redis/go-redis/v9" 29 | "github.com/stretchr/testify/assert" 30 | 31 | limit "github.com/ecodeclub/ginx/middlewares/ratelimit" 32 | ) 33 | 34 | func TestBuilder_e2e_RateLimit(t *testing.T) { 35 | const ( 36 | ip = "127.0.0.1" 37 | limitURL = "/limit" 38 | ) 39 | rdb := initRedis() 40 | server := initWebServer(rdb) 41 | RegisterRoutes(server) 42 | 43 | tests := []struct { 44 | // 名字 45 | name string 46 | // 要提前准备数据 47 | before func(t *testing.T) 48 | // 验证并且删除数据 49 | after func(t *testing.T) 50 | 51 | // 预期响应 52 | wantCode int 53 | }{ 54 | { 55 | name: "不限流", 56 | before: func(t *testing.T) {}, 57 | after: func(t *testing.T) { 58 | rdb.Del(context.Background(), fmt.Sprintf("ip-limiter:%s", ip)) 59 | }, 60 | wantCode: http.StatusOK, 61 | }, 62 | { 63 | name: "限流", 64 | before: func(t *testing.T) { 65 | req, err := http.NewRequest(http.MethodGet, limitURL, nil) 66 | req.RemoteAddr = ip + ":80" 67 | assert.NoError(t, err) 68 | recorder := httptest.NewRecorder() 69 | server.ServeHTTP(recorder, req) 70 | }, 71 | after: func(t *testing.T) { 72 | rdb.Del(context.Background(), fmt.Sprintf("ip-limiter:%s", ip)) 73 | }, 74 | wantCode: http.StatusTooManyRequests, 75 | }, 76 | } 77 | for _, tt := range tests { 78 | t.Run(tt.name, func(t *testing.T) { 79 | defer tt.after(t) 80 | tt.before(t) 81 | req, err := 
http.NewRequest(http.MethodGet, limitURL, nil) 82 | req.RemoteAddr = ip + ":80" 83 | assert.NoError(t, err) 84 | 85 | recorder := httptest.NewRecorder() 86 | server.ServeHTTP(recorder, req) 87 | 88 | code := recorder.Code 89 | assert.Equal(t, tt.wantCode, code) 90 | }) 91 | } 92 | } 93 | 94 | func RegisterRoutes(server *gin.Engine) { 95 | server.GET("/limit", func(ctx *gin.Context) { 96 | ctx.Status(http.StatusOK) 97 | }) 98 | } 99 | 100 | func initRedis() redis.Cmdable { 101 | redisClient := redis.NewClient(&redis.Options{ 102 | Addr: "localhost:16379", 103 | }) 104 | return redisClient 105 | } 106 | 107 | func initWebServer(cmd redis.Cmdable) *gin.Engine { 108 | server := gin.Default() 109 | limiter := limit.NewRedisSlidingWindowLimiter(cmd, 500*time.Millisecond, 1) 110 | server.Use(limit.NewBuilder(limiter).Build()) 111 | return server 112 | } 113 | -------------------------------------------------------------------------------- /internal/errs/error.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package errs 16 | 17 | import "errors" 18 | 19 | var ErrUnauthorized = errors.New("未授权") 20 | var ErrSessionKeyNotFound = errors.New("session 中没找到对应的 key") 21 | 22 | // ErrNoResponse 是一个 sentinel 错误。 23 | // 也就是说,你可以通过返回这个 ErrNoResponse 来告诉 ginx 不需要继续写响应。 24 | // 大多数情况下,这意味着你已经写回了响应。 25 | var ErrNoResponse = errors.New("不需要返回 response") 26 | -------------------------------------------------------------------------------- /internal/jwt/claims_option.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package jwt 16 | 17 | import ( 18 | "time" 19 | 20 | "github.com/ecodeclub/ekit/bean/option" 21 | "github.com/golang-jwt/jwt/v5" 22 | ) 23 | 24 | type Options struct { 25 | Expire time.Duration // 有效期 26 | EncryptionKey string // 加密密钥 27 | DecryptKey string // 解密密钥 28 | Method jwt.SigningMethod // 签名方式 29 | Issuer string // 签发人 30 | genIDFn func() string // 生成 JWT ID (jti) 的函数 31 | } 32 | 33 | // NewOptions 定义一个 JWT 配置. 34 | // DecryptKey: 默认与 EncryptionKey 相同. 35 | // Method: 默认使用 jwt.SigningMethodHS256 签名方式. 
36 | func NewOptions(expire time.Duration, encryptionKey string, 37 | opts ...option.Option[Options]) Options { 38 | dOpts := Options{ 39 | Expire: expire, 40 | EncryptionKey: encryptionKey, 41 | DecryptKey: encryptionKey, 42 | Method: jwt.SigningMethodHS256, 43 | genIDFn: func() string { return "" }, 44 | } 45 | 46 | option.Apply[Options](&dOpts, opts...) 47 | 48 | return dOpts 49 | } 50 | 51 | // WithDecryptKey 设置解密密钥. 52 | func WithDecryptKey(decryptKey string) option.Option[Options] { 53 | return func(o *Options) { 54 | o.DecryptKey = decryptKey 55 | } 56 | } 57 | 58 | // WithMethod 设置 JWT 的签名方法. 59 | func WithMethod(method jwt.SigningMethod) option.Option[Options] { 60 | return func(o *Options) { 61 | o.Method = method 62 | } 63 | } 64 | 65 | // WithIssuer 设置签发人. 66 | func WithIssuer(issuer string) option.Option[Options] { 67 | return func(o *Options) { 68 | o.Issuer = issuer 69 | } 70 | } 71 | 72 | // WithGenIDFunc 设置生成 JWT ID 的函数. 73 | // 可以设置成 WithGenIDFunc(uuid.NewString). 74 | func WithGenIDFunc(fn func() string) option.Option[Options] { 75 | return func(o *Options) { 76 | o.genIDFn = fn 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /internal/jwt/claims_option_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package jwt 16 | 17 | import ( 18 | "testing" 19 | "time" 20 | 21 | "github.com/ecodeclub/ekit/bean/option" 22 | "github.com/golang-jwt/jwt/v5" 23 | "github.com/stretchr/testify/assert" 24 | ) 25 | 26 | func TestNewOptions(t *testing.T) { 27 | var genIDFn func() string 28 | tests := []struct { 29 | name string 30 | expire time.Duration 31 | encryptionKey string 32 | want Options 33 | }{ 34 | { 35 | name: "normal", 36 | expire: 10 * time.Minute, 37 | encryptionKey: "sign key", 38 | want: Options{ 39 | Expire: 10 * time.Minute, 40 | EncryptionKey: "sign key", 41 | DecryptKey: "sign key", 42 | Method: jwt.SigningMethodHS256, 43 | genIDFn: genIDFn, 44 | }, 45 | }, 46 | } 47 | for _, tt := range tests { 48 | t.Run(tt.name, func(t *testing.T) { 49 | got := NewOptions(tt.expire, tt.encryptionKey) 50 | got.genIDFn = genIDFn 51 | assert.Equal(t, tt.want, got) 52 | }) 53 | } 54 | } 55 | 56 | func TestWithDecryptKey(t *testing.T) { 57 | tests := []struct { 58 | name string 59 | fn func() option.Option[Options] 60 | want string 61 | }{ 62 | { 63 | name: "normal", 64 | fn: func() option.Option[Options] { 65 | return nil 66 | }, 67 | want: encryptionKey, 68 | }, 69 | { 70 | name: "set_another_key", 71 | fn: func() option.Option[Options] { 72 | return WithDecryptKey("another sign key") 73 | }, 74 | want: "another sign key", 75 | }, 76 | } 77 | for _, tt := range tests { 78 | t.Run(tt.name, func(t *testing.T) { 79 | var got string 80 | if tt.fn() == nil { 81 | got = NewOptions(defaultExpire, encryptionKey). 
82 | DecryptKey 83 | } else { 84 | got = NewOptions(defaultExpire, encryptionKey, 85 | tt.fn()).DecryptKey 86 | } 87 | assert.Equal(t, tt.want, got) 88 | }) 89 | } 90 | } 91 | 92 | func TestWithMethod(t *testing.T) { 93 | tests := []struct { 94 | name string 95 | fn func() option.Option[Options] 96 | want jwt.SigningMethod 97 | }{ 98 | { 99 | name: "normal", 100 | fn: func() option.Option[Options] { 101 | return nil 102 | }, 103 | want: jwt.SigningMethodHS256, 104 | }, 105 | { 106 | name: "set_another_method", 107 | fn: func() option.Option[Options] { 108 | return WithMethod(jwt.SigningMethodHS384) 109 | }, 110 | want: jwt.SigningMethodHS384, 111 | }, 112 | } 113 | for _, tt := range tests { 114 | t.Run(tt.name, func(t *testing.T) { 115 | var got jwt.SigningMethod 116 | if tt.fn() == nil { 117 | got = NewOptions(defaultExpire, encryptionKey). 118 | Method 119 | } else { 120 | got = NewOptions(defaultExpire, encryptionKey, 121 | tt.fn()).Method 122 | } 123 | assert.Equal(t, tt.want, got) 124 | }) 125 | } 126 | } 127 | 128 | func TestWithIssuer(t *testing.T) { 129 | tests := []struct { 130 | name string 131 | fn func() option.Option[Options] 132 | want string 133 | }{ 134 | { 135 | name: "normal", 136 | fn: func() option.Option[Options] { 137 | return nil 138 | }, 139 | }, 140 | { 141 | name: "set_another_issuer", 142 | fn: func() option.Option[Options] { 143 | return WithIssuer("foo") 144 | }, 145 | want: "foo", 146 | }, 147 | } 148 | for _, tt := range tests { 149 | t.Run(tt.name, func(t *testing.T) { 150 | var got string 151 | if tt.fn() == nil { 152 | got = NewOptions(defaultExpire, encryptionKey). 
153 | Issuer 154 | } else { 155 | got = NewOptions(defaultExpire, encryptionKey, 156 | tt.fn()).Issuer 157 | } 158 | assert.Equal(t, tt.want, got) 159 | }) 160 | } 161 | } 162 | 163 | func TestWithGenIDFunc(t *testing.T) { 164 | tests := []struct { 165 | name string 166 | fn func() option.Option[Options] 167 | want string 168 | }{ 169 | { 170 | name: "normal", 171 | fn: func() option.Option[Options] { 172 | return nil 173 | }, 174 | }, 175 | { 176 | name: "set_another_gen_id_func", 177 | fn: func() option.Option[Options] { 178 | return WithGenIDFunc(func() string { 179 | return "unique id" 180 | }) 181 | }, 182 | want: "unique id", 183 | }, 184 | } 185 | for _, tt := range tests { 186 | t.Run(tt.name, func(t *testing.T) { 187 | var got string 188 | if tt.fn() == nil { 189 | got = NewOptions(defaultExpire, encryptionKey). 190 | genIDFn() 191 | } else { 192 | got = NewOptions(defaultExpire, encryptionKey, 193 | tt.fn()).genIDFn() 194 | } 195 | assert.Equal(t, tt.want, got) 196 | }) 197 | } 198 | } 199 | -------------------------------------------------------------------------------- /internal/jwt/management.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package jwt 16 | 17 | import ( 18 | "fmt" 19 | "time" 20 | 21 | "github.com/ecodeclub/ekit/bean/option" 22 | "github.com/golang-jwt/jwt/v5" 23 | ) 24 | 25 | var _ Manager[int] = &Management[int]{} 26 | 27 | type Management[T any] struct { 28 | accessJWTOptions Options // 资源 token 选项 29 | nowFunc func() time.Time // 控制 jwt 的时间 30 | } 31 | 32 | // NewManagement 定义一个 Management. 33 | // allowTokenHeader: 默认使用 authorization 为认证请求头. 34 | // exposeAccessHeader: 默认使用 x-access-token 为暴露外部的资源请求头. 35 | // exposeRefreshHeader: 默认使用 x-refresh-token 为暴露外部的刷新请求头. 36 | // refreshJWTOptions: 默认使用 nil 为刷新 token 的配置, 37 | // 如要使用 refresh 相关功能则需要使用 WithRefreshJWTOptions 添加相关配置. 38 | // rotateRefreshToken: 默认不轮换刷新令牌. 39 | // 该配置需要设置 refreshJWTOptions 才有效. 40 | func NewManagement[T any](accessJWTOptions Options, 41 | opts ...option.Option[Management[T]]) *Management[T] { 42 | dOpts := defaultManagementOptions[T]() 43 | dOpts.accessJWTOptions = accessJWTOptions 44 | option.Apply[Management[T]](&dOpts, opts...) 45 | return &dOpts 46 | } 47 | 48 | func defaultManagementOptions[T any]() Management[T] { 49 | return Management[T]{ 50 | nowFunc: time.Now, 51 | } 52 | } 53 | 54 | // WithNowFunc 设置当前时间. 55 | // 一般用于测试固定 jwt. 56 | func WithNowFunc[T any](nowFunc func() time.Time) option.Option[Management[T]] { 57 | return func(m *Management[T]) { 58 | m.nowFunc = nowFunc 59 | } 60 | } 61 | 62 | // GenerateAccessToken 生成资源 token. 
63 | func (m *Management[T]) GenerateAccessToken(data T) (string, error) { 64 | nowTime := m.nowFunc() 65 | claims := RegisteredClaims[T]{ 66 | Data: data, 67 | RegisteredClaims: jwt.RegisteredClaims{ 68 | Issuer: m.accessJWTOptions.Issuer, 69 | ExpiresAt: jwt.NewNumericDate(nowTime.Add(m.accessJWTOptions.Expire)), 70 | IssuedAt: jwt.NewNumericDate(nowTime), 71 | ID: m.accessJWTOptions.genIDFn(), 72 | }, 73 | } 74 | token := jwt.NewWithClaims(m.accessJWTOptions.Method, claims) 75 | return token.SignedString([]byte(m.accessJWTOptions.EncryptionKey)) 76 | } 77 | 78 | // VerifyAccessToken 校验资源 token. 79 | func (m *Management[T]) VerifyAccessToken(token string, opts ...jwt.ParserOption) (RegisteredClaims[T], error) { 80 | t, err := jwt.ParseWithClaims(token, &RegisteredClaims[T]{}, 81 | func(*jwt.Token) (interface{}, error) { 82 | return []byte(m.accessJWTOptions.DecryptKey), nil 83 | }, 84 | opts..., 85 | ) 86 | if err != nil || !t.Valid { 87 | return RegisteredClaims[T]{}, fmt.Errorf("验证失败: %v", err) 88 | } 89 | clm, _ := t.Claims.(*RegisteredClaims[T]) 90 | return *clm, nil 91 | } 92 | -------------------------------------------------------------------------------- /internal/jwt/management_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package jwt 16 | 17 | import ( 18 | "fmt" 19 | "net/http" 20 | "testing" 21 | "time" 22 | 23 | "github.com/gin-gonic/gin" 24 | "github.com/golang-jwt/jwt/v5" 25 | "github.com/stretchr/testify/assert" 26 | ) 27 | 28 | type data struct { 29 | Foo string `json:"foo"` 30 | } 31 | 32 | var ( 33 | defaultExpire = 10 * time.Minute 34 | defaultClaims = RegisteredClaims[data]{ 35 | Data: data{Foo: "1"}, 36 | RegisteredClaims: jwt.RegisteredClaims{ 37 | ExpiresAt: jwt.NewNumericDate(nowTime.Add(defaultExpire)), 38 | IssuedAt: jwt.NewNumericDate(nowTime), 39 | }, 40 | } 41 | encryptionKey = "sign key" 42 | nowTime = time.UnixMilli(1695571200000) 43 | defaultOption = NewOptions(defaultExpire, encryptionKey) 44 | defaultManagement = NewManagement[data](defaultOption, 45 | WithNowFunc[data](func() time.Time { 46 | return nowTime 47 | }), 48 | ) 49 | ) 50 | 51 | func TestManagement_GenerateAccessToken(t *testing.T) { 52 | m := defaultManagement 53 | type testCase[T any] struct { 54 | name string 55 | data T 56 | want string 57 | wantErr error 58 | } 59 | tests := []testCase[data]{ 60 | { 61 | name: "normal", 62 | data: data{Foo: "1"}, 63 | want: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.RMpM5YNgxl9OtCy4lt_JRxv6k8s6plCkthnAV-vbXEQ", 64 | }, 65 | } 66 | for _, tt := range tests { 67 | t.Run(tt.name, func(t *testing.T) { 68 | got, err := m.GenerateAccessToken(tt.data) 69 | assert.Equal(t, tt.wantErr, err) 70 | assert.Equal(t, tt.want, got) 71 | }) 72 | } 73 | } 74 | 75 | func TestManagement_VerifyAccessToken(t *testing.T) { 76 | type testCase[T any] struct { 77 | name string 78 | m *Management[T] 79 | token string 80 | want RegisteredClaims[T] 81 | wantErr error 82 | } 83 | tests := []testCase[data]{ 84 | { 85 | name: "normal", 86 | m: defaultManagement, 87 | token: 
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.RMpM5YNgxl9OtCy4lt_JRxv6k8s6plCkthnAV-vbXEQ", 88 | want: defaultClaims, 89 | }, 90 | { 91 | // token 过期了 92 | name: "token_expired", 93 | m: NewManagement[data](defaultOption, 94 | WithNowFunc[data](func() time.Time { 95 | return time.UnixMilli(1695671200000) 96 | }), 97 | ), 98 | token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.RMpM5YNgxl9OtCy4lt_JRxv6k8s6plCkthnAV-vbXEQ", 99 | wantErr: fmt.Errorf("验证失败: %v", 100 | fmt.Errorf("%v: %v", jwt.ErrTokenInvalidClaims, jwt.ErrTokenExpired)), 101 | }, 102 | { 103 | // token 签名错误 104 | name: "bad_sign_key", 105 | m: NewManagement[data]( 106 | defaultOption, 107 | WithNowFunc[data](func() time.Time { 108 | return nowTime 109 | }), 110 | ), 111 | token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkYXRhIjp7ImZvbyI6IjEifSwiZXhwIjoxNjk1NTcxODAwLCJpYXQiOjE2OTU1NzEyMDB9.pnP991l48s_j4fkiZnmh48gjgDGult9Or_wLChHvYp0", 112 | wantErr: fmt.Errorf("验证失败: %v", 113 | fmt.Errorf("%v: %v", jwt.ErrTokenSignatureInvalid, jwt.ErrSignatureInvalid)), 114 | }, 115 | { 116 | // 错误的 token 117 | name: "bad_token", 118 | m: defaultManagement, 119 | token: "bad_token", 120 | wantErr: fmt.Errorf("验证失败: %v: token contains an invalid number of segments", 121 | jwt.ErrTokenMalformed), 122 | }, 123 | } 124 | for _, tt := range tests { 125 | t.Run(tt.name, func(t *testing.T) { 126 | got, err := tt.m.VerifyAccessToken(tt.token, 127 | jwt.WithTimeFunc(tt.m.nowFunc)) 128 | assert.Equal(t, tt.wantErr, err) 129 | assert.Equal(t, tt.want, got) 130 | }) 131 | } 132 | } 133 | 134 | func TestNewManagement(t *testing.T) { 135 | type testCase[T any] struct { 136 | name string 137 | accessJWTOptions Options 138 | wantPanic bool 139 | } 140 | tests := []testCase[data]{ 141 | { 142 | name: "normal", 143 | accessJWTOptions: defaultOption, 144 | wantPanic: false, 145 | }, 146 | } 147 | for _, tt := 
range tests { 148 | t.Run(tt.name, func(t *testing.T) { 149 | defer func() { 150 | // fail whenever the observed panic state differs from tt.wantPanic 151 | if err := recover(); (err != nil) != tt.wantPanic { 152 | t.Errorf("panic 与预期不符: wantPanic=%v, recovered=%v", 153 | tt.wantPanic, err) 154 | } 155 | }() 156 | NewManagement[data](tt.accessJWTOptions) 157 | }) 158 | } 159 | } 160 | 161 | func (m *Management[T]) registerRoutes(server *gin.Engine) { 162 | server.GET("/", func(ctx *gin.Context) { 163 | ctx.Status(http.StatusOK) 164 | }) 165 | server.GET("/login", func(ctx *gin.Context) { 166 | ctx.Status(http.StatusOK) 167 | }) 168 | } 169 | -------------------------------------------------------------------------------- /internal/jwt/types.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package jwt 16 | 17 | import ( 18 | "github.com/golang-jwt/jwt/v5" 19 | ) 20 | 21 | // Manager jwt 管理器. 22 | type Manager[T any] interface { 23 | // GenerateAccessToken 生成资源 token. 24 | GenerateAccessToken(data T) (string, error) 25 | 26 | // VerifyAccessToken 校验资源 token.
27 | VerifyAccessToken(token string, opts ...jwt.ParserOption) (RegisteredClaims[T], error) 28 | } 29 | 30 | type RegisteredClaims[T any] struct { 31 | Data T `json:"data"` 32 | jwt.RegisteredClaims 33 | } 34 | -------------------------------------------------------------------------------- /internal/ratelimit/mocks/ratelimit.mock.go: -------------------------------------------------------------------------------- 1 | // Code generated by MockGen. DO NOT EDIT. 2 | // Source: types.go 3 | // 4 | // Generated by this command: 5 | // 6 | // mockgen -source=types.go -package=limitmocks -destination=./mocks/ratelimit.mock.go 7 | // 8 | // Package limitmocks is a generated GoMock package. 9 | package limitmocks 10 | 11 | import ( 12 | context "context" 13 | reflect "reflect" 14 | 15 | gomock "go.uber.org/mock/gomock" 16 | ) 17 | 18 | // MockLimiter is a mock of Limiter interface. 19 | type MockLimiter struct { 20 | ctrl *gomock.Controller 21 | recorder *MockLimiterMockRecorder 22 | } 23 | 24 | // MockLimiterMockRecorder is the mock recorder for MockLimiter. 25 | type MockLimiterMockRecorder struct { 26 | mock *MockLimiter 27 | } 28 | 29 | // NewMockLimiter creates a new mock instance. 30 | func NewMockLimiter(ctrl *gomock.Controller) *MockLimiter { 31 | mock := &MockLimiter{ctrl: ctrl} 32 | mock.recorder = &MockLimiterMockRecorder{mock} 33 | return mock 34 | } 35 | 36 | // EXPECT returns an object that allows the caller to indicate expected use. 37 | func (m *MockLimiter) EXPECT() *MockLimiterMockRecorder { 38 | return m.recorder 39 | } 40 | 41 | // Limit mocks base method. 42 | func (m *MockLimiter) Limit(ctx context.Context, key string) (bool, error) { 43 | m.ctrl.T.Helper() 44 | ret := m.ctrl.Call(m, "Limit", ctx, key) 45 | ret0, _ := ret[0].(bool) 46 | ret1, _ := ret[1].(error) 47 | return ret0, ret1 48 | } 49 | 50 | // Limit indicates an expected call of Limit. 
51 | func (mr *MockLimiterMockRecorder) Limit(ctx, key any) *gomock.Call { 52 | mr.mock.ctrl.T.Helper() 53 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Limit", reflect.TypeOf((*MockLimiter)(nil).Limit), ctx, key) 54 | } 55 | -------------------------------------------------------------------------------- /internal/ratelimit/redis_slide_window.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package ratelimit 16 | 17 | import ( 18 | "context" 19 | _ "embed" 20 | "fmt" 21 | "time" 22 | 23 | "github.com/google/uuid" 24 | "github.com/redis/go-redis/v9" 25 | ) 26 | 27 | //go:embed slide_window.lua 28 | var luaSlideWindow string 29 | 30 | // RedisSlidingWindowLimiter Redis 上的滑动窗口算法限流器实现 31 | type RedisSlidingWindowLimiter struct { 32 | Cmd redis.Cmdable 33 | 34 | // 窗口大小 35 | Interval time.Duration 36 | // 阈值 37 | Rate int 38 | // Interval 内允许 Rate 个请求 39 | // 1s 内允许 3000 个请求 40 | } 41 | 42 | func (r *RedisSlidingWindowLimiter) Limit(ctx context.Context, key string) (bool, error) { 43 | uid, err := uuid.NewUUID() 44 | if err != nil { 45 | return false, fmt.Errorf("generate uuid failed: %w", err) 46 | } 47 | return r.Cmd.Eval(ctx, luaSlideWindow, []string{key}, 48 | r.Interval.Milliseconds(), r.Rate, time.Now().UnixMilli(), uid.String()).Bool() 49 | } 50 | -------------------------------------------------------------------------------- /internal/ratelimit/redis_slide_window_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | //go:build e2e 16 | 17 | package ratelimit 18 | 19 | import ( 20 | "context" 21 | "testing" 22 | "time" 23 | 24 | "github.com/redis/go-redis/v9" 25 | "github.com/stretchr/testify/assert" 26 | ) 27 | 28 | func TestRedisSlidingWindowLimiter_Limit(t *testing.T) { 29 | r := &RedisSlidingWindowLimiter{ 30 | Cmd: initRedis(), 31 | Interval: 500 * time.Millisecond, 32 | Rate: 1, 33 | } 34 | tests := []struct { 35 | name string 36 | ctx context.Context 37 | key string 38 | interval time.Duration 39 | want bool 40 | wantErr error 41 | }{ 42 | { 43 | name: "正常通过", 44 | ctx: context.Background(), 45 | key: "foo", 46 | want: false, 47 | }, 48 | { 49 | name: "另外一个key正常通过", 50 | ctx: context.Background(), 51 | key: "bar", 52 | want: false, 53 | }, 54 | { 55 | name: "限流", 56 | ctx: context.Background(), 57 | key: "foo", 58 | interval: 300 * time.Millisecond, 59 | want: true, 60 | }, 61 | { 62 | name: "窗口有空余正常通过", 63 | ctx: context.Background(), 64 | key: "foo", 65 | interval: 510 * time.Millisecond, 66 | want: false, 67 | }, 68 | } 69 | for _, tt := range tests { 70 | t.Run(tt.name, func(t *testing.T) { 71 | <-time.After(tt.interval) 72 | got, err := r.Limit(tt.ctx, tt.key) 73 | assert.Equal(t, tt.wantErr, err) 74 | assert.Equal(t, tt.want, got) 75 | }) 76 | } 77 | } 78 | 79 | func initRedis() redis.Cmdable { 80 | redisClient := redis.NewClient(&redis.Options{ 81 | Addr: "localhost:16379", 82 | }) 83 | return redisClient 84 | } 85 | 86 | func TestRedisSlidingWindowLimiter(t *testing.T) { 87 | r := &RedisSlidingWindowLimiter{ 88 | Cmd: initRedis(), 89 | Interval: time.Second, 90 | Rate: 1200, 91 | } 92 | var ( 93 | total = 1500 // 总请求数 94 | succCount int // 成功请求数 95 | limitCount int // 被限流的请求数 96 | ) 97 | start := time.Now() 98 | for i := 0; i < total; i++ { 99 | limit, err := r.Limit(context.Background(), "TestRedisSlidingWindowLimiter") 100 | if err != nil { 101 | t.Fatalf("limit error: %v", err) 102 | return 103 | } 104 | if limit { 105 | limitCount++ 106 | continue 
107 | } 108 | succCount++ 109 | } 110 | end := time.Now() 111 | t.Logf("开始时间: %v", start.Format(time.StampMilli)) 112 | t.Logf("结束时间: %v", end.Format(time.StampMilli)) 113 | t.Logf("total: %d, succ: %d, limited: %d", total, succCount, limitCount) 114 | } 115 | -------------------------------------------------------------------------------- /internal/ratelimit/slide_window.lua: -------------------------------------------------------------------------------- 1 | -- 限流对象 2 | local key = KEYS[1] 3 | -- 窗口大小 4 | local window = tonumber(ARGV[1]) 5 | -- 阈值 6 | local threshold = tonumber(ARGV[2]) 7 | local now = tonumber(ARGV[3]) 8 | -- 唯一ID, 用于解决同一时间内多个请求只统计一次的问题 9 | -- SEE: issue #27 10 | local uid = ARGV[4] 11 | -- 窗口的起始时间 12 | local min = now - window 13 | 14 | redis.call('ZREMRANGEBYSCORE', key, '-inf', min) 15 | local cnt = redis.call('ZCOUNT', key, '-inf', '+inf') 16 | -- local cnt = redis.call('ZCOUNT', key, min, '+inf') 17 | if cnt >= threshold then 18 | -- 执行限流 19 | return "true" 20 | else 21 | -- score 设置为当前时间, member 设置为唯一id 22 | redis.call('ZADD', key, now, uid) 23 | redis.call('PEXPIRE', key, window) 24 | return "false" 25 | end -------------------------------------------------------------------------------- /internal/ratelimit/types.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package ratelimit 16 | 17 | import "context" 18 | 19 | //go:generate mockgen -source=types.go -package=limitmocks -destination=./mocks/ratelimit.mock.go 20 | type Limiter interface { 21 | // Limit 有没有触发限流。key 就是限流对象 22 | // bool 代表是否限流,true 就是要限流 23 | // err 限流器本身有没有错误 24 | Limit(ctx context.Context, key string) (bool, error) 25 | } 26 | -------------------------------------------------------------------------------- /middlewares/accesslog/builder.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package accesslog 16 | 17 | import ( 18 | "bytes" 19 | "context" 20 | "io" 21 | "time" 22 | 23 | "github.com/gin-gonic/gin" 24 | "go.uber.org/atomic" 25 | ) 26 | 27 | type AccessLog struct { 28 | //http 请求类型 29 | Method string 30 | //url 整个请求的url 31 | Url string 32 | //请求体 33 | ReqBody string 34 | //响应体 35 | RespBody string 36 | //处理时间 37 | Duration string 38 | //状态码 39 | Status int 40 | } 41 | 42 | type Builder struct { 43 | allowReqBody *atomic.Bool 44 | allowRespBody *atomic.Bool 45 | //logger logger.LoggerV1 //这里要自己确认用什么日志级别 46 | // 47 | loggerFunc func(ctx context.Context, al *AccessLog) 48 | maxLength *atomic.Int64 49 | } 50 | 51 | func NewBuilder(fn func(ctx context.Context, al *AccessLog)) *Builder { 52 | return &Builder{ 53 | allowReqBody: atomic.NewBool(false), 54 | allowRespBody: atomic.NewBool(false), 55 | loggerFunc: fn, 56 | maxLength: atomic.NewInt64(1024), 57 | } 58 | } 59 | 60 | // AllowReqBody 是否打印请求体 61 | func (b *Builder) AllowReqBody() *Builder { 62 | b.allowReqBody.Store(true) 63 | return b 64 | } 65 | 66 | // AllowRespBody 是否打印响应体 67 | func (b *Builder) AllowRespBody() *Builder { 68 | b.allowRespBody.Store(true) 69 | return b 70 | } 71 | 72 | // MaxLength 打印的最大长度 73 | func (b *Builder) MaxLength(maxLength int64) *Builder { 74 | b.maxLength.Store(maxLength) 75 | return b 76 | } 77 | 78 | func (b *Builder) Builder() gin.HandlerFunc { 79 | return func(ctx *gin.Context) { 80 | var ( 81 | //请求处理开始时间 82 | start = time.Now() 83 | //url 84 | url = ctx.Request.URL.String() 85 | //url 长度 86 | curLen = int64(len(url)) 87 | //运行打印的最大长度 88 | maxLength = b.maxLength.Load() 89 | //是否打印请求体 90 | allowReqBody = b.allowReqBody.Load() 91 | //是否打印响应体 92 | allowRespBody = b.allowRespBody.Load() 93 | ) 94 | 95 | if curLen >= maxLength { 96 | url = url[:maxLength] 97 | } 98 | 99 | accessLog := &AccessLog{ 100 | Method: ctx.Request.Method, 101 | Url: url, 102 | } 103 | if ctx.Request.Body != nil && allowReqBody { 104 | body, _ := ctx.GetRawData() 105 | 
ctx.Request.Body = io.NopCloser(bytes.NewReader(body)) 106 | if int64(len(body)) >= maxLength { 107 | body = body[:maxLength] 108 | } 109 | //注意资源的消耗 110 | accessLog.ReqBody = string(body) 111 | } 112 | 113 | if allowRespBody { 114 | ctx.Writer = responseWriter{ 115 | ResponseWriter: ctx.Writer, 116 | al: accessLog, 117 | maxLength: maxLength, 118 | } 119 | } 120 | 121 | defer func() { 122 | accessLog.Duration = time.Since(start).String() 123 | //日志打印 124 | b.loggerFunc(ctx, accessLog) 125 | }() 126 | ctx.Next() 127 | } 128 | } 129 | 130 | type responseWriter struct { 131 | gin.ResponseWriter 132 | al *AccessLog 133 | maxLength int64 134 | } 135 | 136 | func (r responseWriter) WriteHeader(statusCode int) { 137 | 138 | r.al.Status = statusCode 139 | r.ResponseWriter.WriteHeader(statusCode) 140 | } 141 | // Write records at most maxLength bytes of data in the access log, but always forwards the complete data to the client. 142 | func (r responseWriter) Write(data []byte) (int, error) { 143 | logged := data 144 | if int64(len(logged)) > r.maxLength { 145 | logged = logged[:r.maxLength] 146 | } 147 | r.al.RespBody = string(logged) 148 | return r.ResponseWriter.Write(data) 149 | } 150 | -------------------------------------------------------------------------------- /middlewares/accesslog/builder_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License.
14 | 15 | package accesslog 16 | 17 | import ( 18 | "context" 19 | "fmt" 20 | "net/http" 21 | "net/http/httptest" 22 | "strings" 23 | "testing" 24 | "time" 25 | 26 | "github.com/gin-gonic/gin" 27 | "github.com/stretchr/testify/assert" 28 | "github.com/stretchr/testify/require" 29 | ) 30 | 31 | func TestBuilder_Builder(t *testing.T) { 32 | testCases := []struct { 33 | name string 34 | getReq func() *http.Request 35 | accesslog *AccessLog 36 | logfunc func(accesslog *AccessLog) func(ctx context.Context, al *AccessLog) 37 | middleWarebuilder func(func(ctx context.Context, al *AccessLog)) gin.HandlerFunc 38 | setStatus int 39 | setRsp string 40 | resultAccessLog *AccessLog 41 | }{ 42 | { 43 | name: "不打印请求体,响应体", 44 | getReq: func() *http.Request { 45 | req, err := http.NewRequest(http.MethodGet, "/accesslog", nil) 46 | require.NoError(t, err) 47 | return req 48 | }, 49 | accesslog: &AccessLog{}, 50 | logfunc: func(accesslog *AccessLog) func(ctx context.Context, al *AccessLog) { 51 | 52 | return func(ctx context.Context, al *AccessLog) { 53 | // 54 | //accesslog.Status = al.Status 55 | //accesslog.Method = al.Method 56 | ////url 整个请求的u 57 | //accesslog.Url = al.Url 58 | ////请求体 59 | //accesslog.ReqBody = al.ReqBody 60 | ////响应体 61 | //accesslog.RespBody = al.RespBody 62 | ////处理时间 63 | //accesslog.Duration = al.Duration 64 | ////状态码 65 | //accesslog.Status = al.Status 66 | copy(accesslog, al) 67 | fmt.Printf("请求类型: %s \n请求url:%s \n请求体:%s \n响应体:%s \n状态码:%d \n消耗时间:%s \n", al.Method, al.Url, al.ReqBody, al.RespBody, al.Status, al.Duration) 68 | } 69 | }, 70 | middleWarebuilder: func(f func(ctx context.Context, al *AccessLog)) gin.HandlerFunc { 71 | return NewBuilder(f).Builder() 72 | }, 73 | resultAccessLog: &AccessLog{ 74 | Method: "GET", 75 | Url: "/accesslog", 76 | }, 77 | }, 78 | { 79 | name: "不打印请求体,打印响应体", 80 | getReq: func() *http.Request { 81 | req, err := http.NewRequest(http.MethodGet, "/accesslog", nil) 82 | require.NoError(t, err) 83 | return req 84 | }, 85 | 
accesslog: &AccessLog{}, 86 | logfunc: func(accesslog *AccessLog) func(ctx context.Context, al *AccessLog) { 87 | 88 | return func(ctx context.Context, al *AccessLog) { 89 | 90 | copy(accesslog, al) 91 | 92 | fmt.Printf("请求类型: %s \n请求url:%s \n请求体:%s \n响应体:%s \n状态码:%d \n消耗时间:%s \n", al.Method, al.Url, al.ReqBody, al.RespBody, al.Status, al.Duration) 93 | } 94 | }, 95 | middleWarebuilder: func(f func(ctx context.Context, al *AccessLog)) gin.HandlerFunc { 96 | return NewBuilder(f).AllowRespBody().Builder() 97 | }, 98 | resultAccessLog: &AccessLog{ 99 | Method: "GET", 100 | Url: "/accesslog", 101 | RespBody: `{"msg":"aa22"}`, 102 | Status: http.StatusOK, 103 | }, 104 | }, 105 | { 106 | name: "打印请求体,不打印响应体", 107 | getReq: func() *http.Request { 108 | read := strings.NewReader(`{"msg":"aa11"}`) 109 | 110 | req, err := http.NewRequest(http.MethodGet, "/accesslog", read) 111 | require.NoError(t, err) 112 | return req 113 | }, 114 | accesslog: &AccessLog{}, 115 | logfunc: func(accesslog *AccessLog) func(ctx context.Context, al *AccessLog) { 116 | 117 | return func(ctx context.Context, al *AccessLog) { 118 | 119 | copy(accesslog, al) 120 | 121 | fmt.Printf("请求类型: %s \n请求url:%s \n请求体:%s \n响应体:%s \n状态码:%d \n消耗时间:%s \n", al.Method, al.Url, al.ReqBody, al.RespBody, al.Status, al.Duration) 122 | } 123 | }, 124 | middleWarebuilder: func(f func(ctx context.Context, al *AccessLog)) gin.HandlerFunc { 125 | return NewBuilder(f).AllowReqBody().Builder() 126 | }, 127 | resultAccessLog: &AccessLog{ 128 | Method: "GET", 129 | Url: "/accesslog", 130 | ReqBody: `{"msg":"aa11"}`, 131 | }, 132 | }, 133 | { 134 | name: "打印请求体,打印响应体", 135 | getReq: func() *http.Request { 136 | read := strings.NewReader(`{"msg":"aa11"}`) 137 | 138 | req, err := http.NewRequest(http.MethodGet, "/accesslog", read) 139 | require.NoError(t, err) 140 | return req 141 | }, 142 | accesslog: &AccessLog{}, 143 | logfunc: func(accesslog *AccessLog) func(ctx context.Context, al *AccessLog) { 144 | 145 | return func(ctx 
context.Context, al *AccessLog) { 146 | 147 | copy(accesslog, al) 148 | 149 | fmt.Printf("请求类型: %s \n请求url:%s \n请求体:%s \n响应体:%s \n状态码:%d \n消耗时间:%s \n", al.Method, al.Url, al.ReqBody, al.RespBody, al.Status, al.Duration) 150 | } 151 | }, 152 | middleWarebuilder: func(f func(ctx context.Context, al *AccessLog)) gin.HandlerFunc { 153 | return NewBuilder(f).AllowReqBody().AllowRespBody().Builder() 154 | }, 155 | resultAccessLog: &AccessLog{ 156 | Method: "GET", 157 | Url: "/accesslog", 158 | ReqBody: `{"msg":"aa11"}`, 159 | RespBody: `{"msg":"aa22"}`, 160 | Status: http.StatusOK, 161 | }, 162 | }, 163 | { 164 | name: "打印请求体超标,不打印响应体,限制长度为10", 165 | getReq: func() *http.Request { 166 | read := strings.NewReader(`{"msg":"aa11"}`) 167 | 168 | req, err := http.NewRequest(http.MethodGet, "/accesslog", read) 169 | require.NoError(t, err) 170 | return req 171 | }, 172 | accesslog: &AccessLog{}, 173 | logfunc: func(accesslog *AccessLog) func(ctx context.Context, al *AccessLog) { 174 | 175 | return func(ctx context.Context, al *AccessLog) { 176 | 177 | copy(accesslog, al) 178 | 179 | fmt.Printf("请求类型: %s \n请求url:%s \n请求体:%s \n响应体:%s \n状态码:%d \n消耗时间:%s \n", al.Method, al.Url, al.ReqBody, al.RespBody, al.Status, al.Duration) 180 | } 181 | }, 182 | middleWarebuilder: func(f func(ctx context.Context, al *AccessLog)) gin.HandlerFunc { 183 | return NewBuilder(f).AllowReqBody().MaxLength(10).Builder() 184 | }, 185 | resultAccessLog: &AccessLog{ 186 | Method: "GET", 187 | Url: "/accesslog", 188 | ReqBody: `{"msg":"aa`, 189 | }, 190 | }, 191 | { 192 | name: "不打印请求体,打印响应体超标,限制长度为10", 193 | getReq: func() *http.Request { 194 | read := strings.NewReader(`{"msg":"aa11"}`) 195 | 196 | req, err := http.NewRequest(http.MethodGet, "/accesslog", read) 197 | require.NoError(t, err) 198 | return req 199 | }, 200 | accesslog: &AccessLog{}, 201 | logfunc: func(accesslog *AccessLog) func(ctx context.Context, al *AccessLog) { 202 | 203 | return func(ctx context.Context, al *AccessLog) { 204 | 205 | 
copy(accesslog, al) 206 | 207 | fmt.Printf("请求类型: %s \n请求url:%s \n请求体:%s \n响应体:%s \n状态码:%d \n消耗时间:%s \n", al.Method, al.Url, al.ReqBody, al.RespBody, al.Status, al.Duration) 208 | } 209 | }, 210 | middleWarebuilder: func(f func(ctx context.Context, al *AccessLog)) gin.HandlerFunc { 211 | return NewBuilder(f).AllowRespBody().MaxLength(10).Builder() 212 | }, 213 | resultAccessLog: &AccessLog{ 214 | Method: "GET", 215 | Url: "/accesslog", 216 | RespBody: `{"msg":"aa`, 217 | Status: http.StatusOK, 218 | }, 219 | }, 220 | } 221 | 222 | for _, tc := range testCases { 223 | t.Run(tc.name, func(t *testing.T) { 224 | server := gin.Default() 225 | server.Use(tc.middleWarebuilder(tc.logfunc(tc.accesslog))) 226 | server.GET("/accesslog", func(ctx *gin.Context) { 227 | ctx.JSON(http.StatusOK, map[string]any{ 228 | "msg": "aa22", 229 | }) 230 | }) 231 | resp := httptest.NewRecorder() 232 | 233 | server.ServeHTTP(resp, tc.getReq()) 234 | //中间件使用的defer 所有这里要给点时间 235 | time.Sleep(time.Millisecond * 100) 236 | assert.Equal(t, tc.accesslog.Method, tc.resultAccessLog.Method) 237 | assert.Equal(t, tc.accesslog.Url, tc.resultAccessLog.Url) 238 | assert.Equal(t, tc.accesslog.ReqBody, tc.resultAccessLog.ReqBody) 239 | assert.Equal(t, tc.accesslog.RespBody, tc.resultAccessLog.RespBody) 240 | //时间不好判断 241 | //assert.Equal(t, tc.accesslog.Duration, tc.resultAccessLog.Duration) 242 | 243 | assert.Equal(t, tc.accesslog.Status, tc.resultAccessLog.Status) 244 | 245 | }) 246 | 247 | } 248 | } 249 | 250 | func copy(source, target *AccessLog) { 251 | source.Status = target.Status 252 | source.Method = target.Method 253 | //url 整个请求的u 254 | source.Url = target.Url 255 | //请求体 256 | source.ReqBody = target.ReqBody 257 | //响应体 258 | source.RespBody = target.RespBody 259 | //处理时间 260 | source.Duration = target.Duration 261 | //状态码 262 | source.Status = target.Status 263 | } 264 | -------------------------------------------------------------------------------- 
/middlewares/activelimit/locallimit/builder.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package locallimit 16 | 17 | import ( 18 | "net/http" 19 | 20 | "github.com/gin-gonic/gin" 21 | "go.uber.org/atomic" 22 | ) 23 | 24 | type LocalActiveLimit struct { 25 | // 最大限制个数 26 | maxActive *atomic.Int64 27 | // 当前活跃个数 28 | countActive *atomic.Int64 29 | } 30 | 31 | // NewLocalActiveLimit 全局限流 32 | func NewLocalActiveLimit(maxActive int64) *LocalActiveLimit { 33 | return &LocalActiveLimit{ 34 | maxActive: atomic.NewInt64(maxActive), 35 | countActive: atomic.NewInt64(0), 36 | } 37 | } 38 | 39 | func (a *LocalActiveLimit) Build() gin.HandlerFunc { 40 | return func(ctx *gin.Context) { 41 | current := a.countActive.Add(1) 42 | defer func() { 43 | a.countActive.Sub(1) 44 | }() 45 | if current <= a.maxActive.Load() { 46 | ctx.Next() 47 | } else { 48 | ctx.AbortWithStatus(http.StatusTooManyRequests) 49 | } 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /middlewares/activelimit/locallimit/builder_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the 
License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package locallimit 16 | 17 | import ( 18 | "net/http" 19 | "net/http/httptest" 20 | "testing" 21 | 22 | "github.com/gin-gonic/gin" 23 | "github.com/stretchr/testify/assert" 24 | "github.com/stretchr/testify/require" 25 | ) 26 | 27 | func TestLocalActiveLimit_Build(t *testing.T) { 28 | const ( 29 | url = "/" 30 | ) 31 | tests := []struct { 32 | name string 33 | countActive int64 34 | reqBuilder func(t *testing.T) *http.Request 35 | 36 | // 预期响应 37 | wantCode int 38 | }{ 39 | { 40 | name: "正常通过", 41 | countActive: 0, 42 | reqBuilder: func(t *testing.T) *http.Request { 43 | req, err := http.NewRequest(http.MethodGet, url, nil) 44 | require.NoError(t, err) 45 | return req 46 | }, 47 | wantCode: http.StatusNoContent, 48 | }, 49 | { 50 | name: "限流中", 51 | countActive: 1, 52 | reqBuilder: func(t *testing.T) *http.Request { 53 | req, err := http.NewRequest(http.MethodGet, url, nil) 54 | require.NoError(t, err) 55 | return req 56 | }, 57 | wantCode: http.StatusTooManyRequests, 58 | }, 59 | } 60 | for _, tt := range tests { 61 | t.Run(tt.name, func(t *testing.T) { 62 | limit := NewLocalActiveLimit(1) 63 | limit.countActive.Store(tt.countActive) 64 | server := gin.Default() 65 | server.Use(limit.Build()) 66 | server.GET(url, func(c *gin.Context) { 67 | c.Status(http.StatusNoContent) 68 | }) 69 | 70 | req := tt.reqBuilder(t) 71 | recorder := httptest.NewRecorder() 72 | 73 | server.ServeHTTP(recorder, req) 74 | 75 | assert.Equal(t, tt.wantCode, recorder.Code) 76 | }) 77 | } 78 | } 79 | 
-------------------------------------------------------------------------------- /middlewares/activelimit/redislimit/builder.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package redislimit 16 | 17 | import ( 18 | "fmt" 19 | "net/http" 20 | 21 | "github.com/gin-gonic/gin" 22 | "github.com/redis/go-redis/v9" 23 | "go.uber.org/atomic" 24 | ) 25 | 26 | type RedisActiveLimit struct { 27 | // 最大限制个数 28 | maxActive *atomic.Int64 29 | 30 | // 用来记录连接数的key 31 | key string 32 | cmd redis.Cmdable 33 | logFn func(msg any, args ...any) 34 | } 35 | 36 | // NewRedisActiveLimit 全局限流 37 | func NewRedisActiveLimit(cmd redis.Cmdable, maxActive int64, key string) *RedisActiveLimit { 38 | return &RedisActiveLimit{ 39 | maxActive: atomic.NewInt64(maxActive), 40 | key: key, 41 | cmd: cmd, 42 | logFn: func(msg any, args ...any) { 43 | fmt.Printf("%v 详细信息: %v \n", msg, args) 44 | }, 45 | } 46 | } 47 | 48 | func (a *RedisActiveLimit) SetLogFunc(fun func(msg any, args ...any)) *RedisActiveLimit { 49 | a.logFn = fun 50 | return a 51 | } 52 | 53 | func (a *RedisActiveLimit) Build() gin.HandlerFunc { 54 | return func(ctx *gin.Context) { 55 | currentCount, err := a.cmd.Incr(ctx, a.key).Result() 56 | if err != nil { 57 | // 为了安全性 直接返回异常 58 | a.logFn("redis 加一操作", err) 59 | 
ctx.AbortWithStatus(http.StatusInternalServerError) 60 | return 61 | } 62 | defer func() { 63 | if err = a.cmd.Decr(ctx, a.key).Err(); err != nil { 64 | a.logFn("redis 减一操作", err) 65 | } 66 | }() 67 | if currentCount <= a.maxActive.Load() { 68 | ctx.Next() 69 | } else { 70 | a.logFn("web server ", "限流中..") 71 | ctx.AbortWithStatus(http.StatusTooManyRequests) 72 | } 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /middlewares/activelimit/redislimit/builder_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package redislimit 16 | 17 | import ( 18 | "context" 19 | "errors" 20 | "net/http" 21 | "net/http/httptest" 22 | "testing" 23 | 24 | "github.com/gin-gonic/gin" 25 | "github.com/redis/go-redis/v9" 26 | "github.com/stretchr/testify/assert" 27 | "github.com/stretchr/testify/require" 28 | "go.uber.org/mock/gomock" 29 | 30 | "github.com/ecodeclub/ginx/internal/mocks" 31 | ) 32 | 33 | func TestRedisActiveLimit_Build(t *testing.T) { 34 | const ( 35 | url = "/" 36 | key = "limit" 37 | ) 38 | tests := []struct { 39 | name string 40 | mock func(ctrl *gomock.Controller) redis.Cmdable 41 | reqBuilder func(t *testing.T) *http.Request 42 | 43 | // 预期响应 44 | wantCode int 45 | }{ 46 | { 47 | name: "正常通过", 48 | mock: func(ctrl *gomock.Controller) redis.Cmdable { 49 | cmd := mocks.NewMockCmdable(ctrl) 50 | res := redis.NewIntCmd(context.Background()) 51 | res.SetVal(int64(1)) 52 | cmd.EXPECT().Incr(gomock.Any(), key).Return(res) 53 | res = redis.NewIntCmd(context.Background()) 54 | cmd.EXPECT().Decr(gomock.Any(), key).Return(res) 55 | return cmd 56 | }, 57 | reqBuilder: func(t *testing.T) *http.Request { 58 | req, err := http.NewRequest(http.MethodGet, url, nil) 59 | require.NoError(t, err) 60 | return req 61 | }, 62 | wantCode: http.StatusNoContent, 63 | }, 64 | { 65 | name: "限流中", 66 | mock: func(ctrl *gomock.Controller) redis.Cmdable { 67 | ctx := context.Background() 68 | cmd := mocks.NewMockCmdable(ctrl) 69 | res := redis.NewIntCmd(ctx) 70 | res.SetVal(int64(2)) 71 | cmd.EXPECT().Incr(gomock.Any(), key).Return(res) 72 | res = redis.NewIntCmd(ctx) 73 | cmd.EXPECT().Decr(gomock.Any(), key).Return(res) 74 | return cmd 75 | }, 76 | reqBuilder: func(t *testing.T) *http.Request { 77 | req, err := http.NewRequest(http.MethodGet, url, nil) 78 | require.NoError(t, err) 79 | return req 80 | }, 81 | wantCode: http.StatusTooManyRequests, 82 | }, 83 | { 84 | name: "defer 的减1操作失败", 85 | mock: func(ctrl *gomock.Controller) redis.Cmdable { 86 | ctx := context.Background() 87 | cmd := 
mocks.NewMockCmdable(ctrl) 88 | res := redis.NewIntCmd(ctx) 89 | res.SetVal(int64(1)) 90 | cmd.EXPECT().Incr(gomock.Any(), key).Return(res) 91 | res = redis.NewIntCmd(ctx) 92 | res.SetErr(errors.New("模拟 redis 操作失败")) 93 | cmd.EXPECT().Decr(gomock.Any(), key).Return(res) 94 | return cmd 95 | }, 96 | reqBuilder: func(t *testing.T) *http.Request { 97 | req, err := http.NewRequest(http.MethodGet, url, nil) 98 | require.NoError(t, err) 99 | return req 100 | }, 101 | wantCode: http.StatusNoContent, 102 | }, 103 | { 104 | name: "刚进入中间件的加1操作失败", 105 | mock: func(ctrl *gomock.Controller) redis.Cmdable { 106 | cmd := mocks.NewMockCmdable(ctrl) 107 | res := redis.NewIntCmd(context.Background()) 108 | res.SetErr(errors.New("模拟 redis 操作失败")) 109 | cmd.EXPECT().Incr(gomock.Any(), key).Return(res) 110 | return cmd 111 | }, 112 | reqBuilder: func(t *testing.T) *http.Request { 113 | req, err := http.NewRequest(http.MethodGet, url, nil) 114 | require.NoError(t, err) 115 | return req 116 | }, 117 | wantCode: http.StatusInternalServerError, 118 | }, 119 | } 120 | for _, tt := range tests { 121 | t.Run(tt.name, func(t *testing.T) { 122 | ctrl := gomock.NewController(t) 123 | defer ctrl.Finish() 124 | limit := NewRedisActiveLimit(tt.mock(ctrl), 1, key) 125 | server := gin.Default() 126 | server.Use(limit.Build()) 127 | server.GET(url, func(c *gin.Context) { 128 | c.Status(http.StatusNoContent) 129 | }) 130 | 131 | req := tt.reqBuilder(t) 132 | recorder := httptest.NewRecorder() 133 | 134 | server.ServeHTTP(recorder, req) 135 | 136 | assert.Equal(t, tt.wantCode, recorder.Code) 137 | }) 138 | } 139 | } 140 | -------------------------------------------------------------------------------- /middlewares/crawlerdetect/builder.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package crawlerdetect 16 | 17 | import ( 18 | "log/slog" 19 | "net/http" 20 | "strings" 21 | 22 | "github.com/ecodeclub/ginx/internal/crawlerdetect" 23 | "github.com/gin-gonic/gin" 24 | ) 25 | 26 | const Baidu = crawlerdetect.Baidu 27 | const Bing = crawlerdetect.Bing 28 | const Google = crawlerdetect.Google 29 | const Sogou = crawlerdetect.Sogou 30 | 31 | type Builder struct { 32 | crawlersMap map[string]string 33 | } 34 | 35 | func NewBuilder() *Builder { 36 | return &Builder{ 37 | // 常用的 User-Agent 映射 38 | crawlersMap: map[string]string{ 39 | "Baiduspider": Baidu, 40 | "Baiduspider-render": Baidu, 41 | 42 | "bingbot": Bing, 43 | "adidxbot": Bing, 44 | "MicrosoftPreview": Bing, 45 | 46 | "Googlebot": Google, 47 | "Googlebot-Image": Google, 48 | "Googlebot-News": Google, 49 | "Googlebot-Video": Google, 50 | "Storebot-Google": Google, 51 | "Google-InspectionTool": Google, 52 | "GoogleOther": Google, 53 | "Google-Extended": Google, 54 | 55 | "Sogou web spider": Sogou, 56 | }, 57 | } 58 | } 59 | 60 | // AddUserAgent 添加 user-agent 映射 61 | // 例如: 62 | // 63 | // map[string][]string{ 64 | // crawlerdetect.Baidu: []string{"NewBaiduUserAgent"}, 65 | // crawlerdetect.Bing: []string{"NewBingUserAgent"}, 66 | // } 67 | func (b *Builder) AddUserAgent(userAgents map[string][]string) *Builder { 68 | for crawler, values := range userAgents { 69 | for _, userAgent := range values { 70 | b.crawlersMap[userAgent] = crawler 71 | } 72 | } 73 | return b 74 | } 75 | 76 | func (b *Builder) RemoveUserAgent(userAgents ...string) 
*Builder { 77 | for _, userAgent := range userAgents { 78 | delete(b.crawlersMap, userAgent) 79 | } 80 | return b 81 | } 82 | 83 | func (b *Builder) Build() gin.HandlerFunc { 84 | return func(ctx *gin.Context) { 85 | userAgent := ctx.GetHeader("User-Agent") 86 | ip := ctx.ClientIP() 87 | if ip == "" { 88 | slog.ErrorContext(ctx, "crawlerdetect", "error", "ip is empty.") 89 | ctx.AbortWithStatus(http.StatusForbidden) 90 | return 91 | } 92 | crawlerDetector := b.getCrawlerDetector(userAgent) 93 | if crawlerDetector == nil { 94 | ctx.AbortWithStatus(http.StatusForbidden) 95 | return 96 | } 97 | pass, err := crawlerDetector.CheckCrawler(ip) 98 | if err != nil { 99 | slog.ErrorContext(ctx, "crawlerdetect", "error", err.Error()) 100 | ctx.AbortWithStatus(http.StatusInternalServerError) 101 | return 102 | } 103 | if !pass { 104 | ctx.AbortWithStatus(http.StatusForbidden) 105 | return 106 | } 107 | ctx.Next() 108 | } 109 | } 110 | 111 | func (b *Builder) getCrawlerDetector(userAgent string) crawlerdetect.Strategy { 112 | for key, value := range b.crawlersMap { 113 | if strings.Contains(userAgent, key) { 114 | return crawlerdetect.NewCrawlerDetector(value) 115 | } 116 | } 117 | return nil 118 | } 119 | -------------------------------------------------------------------------------- /middlewares/crawlerdetect/builder_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package crawlerdetect 16 | 17 | import ( 18 | "net/http" 19 | "net/http/httptest" 20 | "testing" 21 | 22 | "github.com/gin-gonic/gin" 23 | "github.com/stretchr/testify/require" 24 | ) 25 | 26 | func Test_Builder(t *testing.T) { 27 | testCases := []struct { 28 | name string 29 | 30 | reqBuilder func(t *testing.T) *http.Request 31 | 32 | wantCode int 33 | }{ 34 | { 35 | name: "空 ip", 36 | reqBuilder: func(t *testing.T) *http.Request { 37 | req, err := http.NewRequest(http.MethodGet, "/test", nil) 38 | if err != nil { 39 | t.Fatal(err) 40 | } 41 | return req 42 | }, 43 | wantCode: 403, 44 | }, 45 | { 46 | name: "无效 ip", 47 | reqBuilder: func(t *testing.T) *http.Request { 48 | req, err := http.NewRequest(http.MethodGet, "/test", nil) 49 | if err != nil { 50 | t.Fatal(err) 51 | } 52 | req.Header.Set("X-Forwarded-For", "256.0.0.0") 53 | req.Header.Set("User-Agent", "Mozilla/5.0 (compatible; Baiduspider/2.0; +http://www.baidu.com/search/spider.html)") 54 | return req 55 | }, 56 | wantCode: 500, 57 | }, 58 | { 59 | name: "用户", 60 | reqBuilder: func(t *testing.T) *http.Request { 61 | req, err := http.NewRequest(http.MethodGet, "/test", nil) 62 | if err != nil { 63 | t.Fatal(err) 64 | } 65 | req.Header.Set("X-Forwarded-For", "155.206.198.69") 66 | req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/89.0.4389.82 Safari/537.36") 67 | return req 68 | }, 69 | wantCode: 403, 70 | }, 71 | { 72 | name: "百度 - Baiduspider", 73 | reqBuilder: func(t *testing.T) *http.Request { 74 | req, err := http.NewRequest(http.MethodGet, "/test", nil) 75 | if err != nil { 76 | t.Fatal(err) 77 | } 78 | req.Header.Set("User-Agent", "Mozilla/5.0 (compatible; Baiduspider/2.0; +http://www.baidu.com/search/spider.html)") 79 | req.Header.Set("X-Forwarded-For", "111.206.198.69") 80 | return req 81 | }, 82 | wantCode: 
200, 83 | }, 84 | { 85 | name: "百度 - Baiduspider-render", 86 | reqBuilder: func(t *testing.T) *http.Request { 87 | req, err := http.NewRequest(http.MethodGet, "/test", nil) 88 | if err != nil { 89 | t.Fatal(err) 90 | } 91 | req.Header.Set("User-Agent", "Mozilla/5.0 (iPhone;CPU iPhone OS 9_1 like Mac OS X) AppleWebKit/601.1.46 (KHTML, like Gecko)Version/9.0 Mobile/13B143 Safari/601.1 (compatible; Baiduspider-render/2.0;Smartapp; +http://www.baidu.com/search/spider.html)") 92 | req.Header.Set("X-Forwarded-For", "111.206.198.69") 93 | return req 94 | }, 95 | wantCode: 200, 96 | }, 97 | { 98 | name: "必应 - bingbot", 99 | reqBuilder: func(t *testing.T) *http.Request { 100 | req, err := http.NewRequest(http.MethodGet, "/test", nil) 101 | if err != nil { 102 | t.Fatal(err) 103 | } 104 | req.Header.Set("User-Agent", "Mozilla/5.0 AppleWebKit/537.36 (KHTML, like Gecko; compatible; bingbot/2.0; +http://www.bing.com/bingbot.htm) Chrome/") 105 | req.Header.Set("X-Forwarded-For", "157.55.39.1") 106 | return req 107 | }, 108 | wantCode: 200, 109 | }, 110 | { 111 | name: "必应 - adidxbot", 112 | reqBuilder: func(t *testing.T) *http.Request { 113 | req, err := http.NewRequest(http.MethodGet, "/test", nil) 114 | if err != nil { 115 | t.Fatal(err) 116 | } 117 | req.Header.Set("User-Agent", "Mozilla/5.0 (Windows Phone 8.1; ARM; Trident/7.0; Touch; rv:11.0; IEMobile/11.0; NOKIA; Lumia 530) like Gecko (compatible; adidxbot/2.0; +http://www.bing.com/bingbot.htm)") 118 | req.Header.Set("X-Forwarded-For", "157.55.39.1") 119 | return req 120 | }, 121 | wantCode: 200, 122 | }, 123 | { 124 | name: "必应 - MicrosoftPreview", 125 | reqBuilder: func(t *testing.T) *http.Request { 126 | req, err := http.NewRequest(http.MethodGet, "/test", nil) 127 | if err != nil { 128 | t.Fatal(err) 129 | } 130 | req.Header.Set("User-Agent", "Mozilla/5.0 AppleWebKit/537.36 (KHTML, like Gecko; compatible; MicrosoftPreview/2.0; +https://aka.ms/MicrosoftPreview) Chrome/W.X.Y.Z Safari/537.36") 131 | 
req.Header.Set("X-Forwarded-For", "157.55.39.1") 132 | return req 133 | }, 134 | wantCode: 200, 135 | }, 136 | { 137 | name: "谷歌 - Googlebot", 138 | reqBuilder: func(t *testing.T) *http.Request { 139 | req, err := http.NewRequest(http.MethodGet, "/test", nil) 140 | if err != nil { 141 | t.Fatal(err) 142 | } 143 | req.Header.Set("User-Agent", "Mozilla/5.0 (Linux; Android 6.0.1; Nexus 5X Build/MMB29P) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/W.X.Y.Z Mobile Safari/537.36 (compatible; Googlebot/2.1; +http://www.google.com/bot.html)") 144 | req.Header.Set("X-Forwarded-For", "66.249.66.1") 145 | return req 146 | }, 147 | wantCode: 200, 148 | }, 149 | { 150 | name: "谷歌 - Googlebot-Image", 151 | reqBuilder: func(t *testing.T) *http.Request { 152 | req, err := http.NewRequest(http.MethodGet, "/test", nil) 153 | if err != nil { 154 | t.Fatal(err) 155 | } 156 | req.Header.Set("User-Agent", "Mozilla/5.0 (Linux; Android 6.0.1; Nexus 5X Build/MMB29P) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/W.X.Y.Z Mobile Safari/537.36 (compatible; Googlebot-Image/1.0; +http://www.google.com/bot.html)") 157 | req.Header.Set("X-Forwarded-For", "35.247.243.240") 158 | return req 159 | }, 160 | wantCode: 200, 161 | }, 162 | { 163 | name: "谷歌 - Googlebot-News", 164 | reqBuilder: func(t *testing.T) *http.Request { 165 | req, err := http.NewRequest(http.MethodGet, "/test", nil) 166 | if err != nil { 167 | t.Fatal(err) 168 | } 169 | req.Header.Set("User-Agent", "Mozilla/5.0 (Linux; Android 6.0.1; Nexus 5X Build/MMB29P) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/W.X.Y.Z Mobile Safari/537.36 (compatible; Googlebot-News/1.0; +http://www.google.com/bot.html)") 170 | req.Header.Set("X-Forwarded-For", "66.249.90.77") 171 | return req 172 | }, 173 | wantCode: 200, 174 | }, 175 | { 176 | name: "谷歌 - Storebot-Google", 177 | reqBuilder: func(t *testing.T) *http.Request { 178 | req, err := http.NewRequest(http.MethodGet, "/test", nil) 179 | if err != nil { 180 | t.Fatal(err) 181 | } 182 | 
req.Header.Set("User-Agent", "Mozilla/5.0 (X11; Linux x86_64; Storebot-Google/1.0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/W.X.Y.Z Safari/537.36") 183 | req.Header.Set("X-Forwarded-For", "66.249.66.1") 184 | return req 185 | }, 186 | wantCode: 200, 187 | }, 188 | { 189 | name: "谷歌 - Google-InspectionTool", 190 | reqBuilder: func(t *testing.T) *http.Request { 191 | req, err := http.NewRequest(http.MethodGet, "/test", nil) 192 | if err != nil { 193 | t.Fatal(err) 194 | } 195 | req.Header.Set("User-Agent", "Mozilla/5.0 (Linux; Android 6.0.1; Nexus 5X Build/MMB29P) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/W.X.Y.Z Mobile Safari/537.36 (compatible; Google-InspectionTool/1.0;)") 196 | req.Header.Set("X-Forwarded-For", "66.249.66.1") 197 | return req 198 | }, 199 | wantCode: 200, 200 | }, 201 | { 202 | name: "谷歌 - GoogleOther", 203 | reqBuilder: func(t *testing.T) *http.Request { 204 | req, err := http.NewRequest(http.MethodGet, "/test", nil) 205 | if err != nil { 206 | t.Fatal(err) 207 | } 208 | req.Header.Set("User-Agent", "Mozilla/5.0 (X11; Linux x86_64; GoogleOther/1.0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/W.X.Y.Z Safari/537.36") 209 | req.Header.Set("X-Forwarded-For", "66.249.66.1") 210 | return req 211 | }, 212 | wantCode: 200, 213 | }, 214 | { 215 | name: "谷歌 - Google-Extended", 216 | reqBuilder: func(t *testing.T) *http.Request { 217 | req, err := http.NewRequest(http.MethodGet, "/test", nil) 218 | if err != nil { 219 | t.Fatal(err) 220 | } 221 | req.Header.Set("User-Agent", "Mozilla/5.0 (X11; Linux x86_64; Google-Extended/1.0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/W.X.Y.Z Safari/537.36") 222 | req.Header.Set("X-Forwarded-For", "66.249.66.1") 223 | return req 224 | }, 225 | wantCode: 200, 226 | }, 227 | { 228 | name: "搜狗 - Sogou web spider", 229 | reqBuilder: func(t *testing.T) *http.Request { 230 | req, err := http.NewRequest(http.MethodGet, "/test", nil) 231 | if err != nil { 232 | t.Fatal(err) 233 | } 234 | 
req.Header.Set("User-Agent", "Mozilla/5.0 (X11; Linux x86_64; Google-Extended/1.0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/W.X.Y.Z Safari/537.36") 235 | req.Header.Set("X-Forwarded-For", "66.249.66.1") 236 | return req 237 | }, 238 | wantCode: 200, 239 | }, 240 | } 241 | for _, tc := range testCases { 242 | t.Run(tc.name, func(t *testing.T) { 243 | server := gin.Default() 244 | server.TrustedPlatform = "X-Forwarded-For" 245 | server.Use(NewBuilder().Build()) 246 | server.GET("/test", func(ctx *gin.Context) { 247 | ctx.JSON(200, nil) 248 | }) 249 | 250 | recorder := httptest.NewRecorder() 251 | req := tc.reqBuilder(t) 252 | 253 | server.ServeHTTP(recorder, req) 254 | 255 | require.Equal(t, tc.wantCode, recorder.Code) 256 | }) 257 | } 258 | } 259 | 260 | func TestBuilder_AddUserAgent(t *testing.T) { 261 | b := NewBuilder().AddUserAgent(map[string][]string{ 262 | Baidu: {"test-new-baidu-user-agent"}, 263 | }) 264 | v, exist := b.crawlersMap["test-new-baidu-user-agent"] 265 | require.Equal(t, Baidu, v) 266 | require.True(t, exist) 267 | } 268 | 269 | func TestBuilder_RemoveUserAgent(t *testing.T) { 270 | b := NewBuilder().RemoveUserAgent("Baiduspider") 271 | v, exist := b.crawlersMap["Baiduspider"] 272 | require.Equal(t, "", v) 273 | require.False(t, exist) 274 | } 275 | -------------------------------------------------------------------------------- /middlewares/ratelimit/builder.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package ratelimit 16 | 17 | import ( 18 | "log" 19 | "net/http" 20 | "strings" 21 | 22 | "github.com/gin-gonic/gin" 23 | 24 | "github.com/ecodeclub/ginx/internal/ratelimit" 25 | ) 26 | 27 | type Builder struct { 28 | limiter ratelimit.Limiter 29 | genKeyFn func(ctx *gin.Context) string 30 | logFn func(msg any, args ...any) 31 | } 32 | 33 | // NewBuilder 34 | // genKeyFn: 默认使用 IP 限流. 35 | // logFn: 默认使用 log.Println(). 36 | func NewBuilder(limiter ratelimit.Limiter) *Builder { 37 | return &Builder{ 38 | limiter: limiter, 39 | genKeyFn: func(ctx *gin.Context) string { 40 | var b strings.Builder 41 | b.WriteString("ip-limiter") 42 | b.WriteString(":") 43 | b.WriteString(ctx.ClientIP()) 44 | return b.String() 45 | }, 46 | logFn: func(msg any, args ...any) { 47 | v := make([]any, 0, len(args)+1) 48 | v = append(v, msg) 49 | v = append(v, args...) 50 | log.Println(v...) 
51 | }, 52 | } 53 | } 54 | 55 | func (b *Builder) SetKeyGenFunc(fn func(*gin.Context) string) *Builder { 56 | b.genKeyFn = fn 57 | return b 58 | } 59 | 60 | func (b *Builder) SetLogFunc(fn func(msg any, args ...any)) *Builder { 61 | b.logFn = fn 62 | return b 63 | } 64 | 65 | func (b *Builder) Build() gin.HandlerFunc { 66 | return func(ctx *gin.Context) { 67 | limited, err := b.limit(ctx) 68 | if err != nil { 69 | b.logFn(err) 70 | ctx.AbortWithStatus(http.StatusInternalServerError) 71 | return 72 | } 73 | if limited { 74 | ctx.AbortWithStatus(http.StatusTooManyRequests) 75 | return 76 | } 77 | ctx.Next() 78 | } 79 | } 80 | 81 | func (b *Builder) limit(ctx *gin.Context) (bool, error) { 82 | return b.limiter.Limit(ctx, b.genKeyFn(ctx)) 83 | } 84 | -------------------------------------------------------------------------------- /middlewares/ratelimit/builder_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package ratelimit 16 | 17 | import ( 18 | "errors" 19 | "net/http" 20 | "net/http/httptest" 21 | "testing" 22 | 23 | "github.com/gin-gonic/gin" 24 | "github.com/stretchr/testify/assert" 25 | "go.uber.org/mock/gomock" 26 | 27 | "github.com/ecodeclub/ginx/internal/ratelimit" 28 | limitmocks "github.com/ecodeclub/ginx/internal/ratelimit/mocks" 29 | ) 30 | 31 | func TestBuilder_SetKeyGenFunc(t *testing.T) { 32 | tests := []struct { 33 | name string 34 | reqBuilder func(t *testing.T) *http.Request 35 | fn func(*gin.Context) string 36 | want string 37 | }{ 38 | { 39 | name: "设置key成功", 40 | reqBuilder: func(t *testing.T) *http.Request { 41 | req, err := http.NewRequest(http.MethodGet, "", nil) 42 | if err != nil { 43 | t.Fatal(err) 44 | } 45 | req.RemoteAddr = "127.0.0.1:80" 46 | return req 47 | }, 48 | fn: func(ctx *gin.Context) string { 49 | return "test" 50 | }, 51 | want: "test", 52 | }, 53 | { 54 | name: "默认key", 55 | reqBuilder: func(t *testing.T) *http.Request { 56 | req, err := http.NewRequest(http.MethodGet, "", nil) 57 | if err != nil { 58 | t.Fatal(err) 59 | } 60 | req.RemoteAddr = "127.0.0.1:80" 61 | return req 62 | }, 63 | want: "ip-limiter:127.0.0.1", 64 | }, 65 | } 66 | for _, tt := range tests { 67 | t.Run(tt.name, func(t *testing.T) { 68 | b := NewBuilder(nil) 69 | if tt.fn != nil { 70 | b.SetKeyGenFunc(tt.fn) 71 | } 72 | 73 | recorder := httptest.NewRecorder() 74 | ctx, _ := gin.CreateTestContext(recorder) 75 | req := tt.reqBuilder(t) 76 | ctx.Request = req 77 | 78 | assert.Equal(t, tt.want, b.genKeyFn(ctx)) 79 | }) 80 | } 81 | } 82 | 83 | func TestBuilder_Build(t *testing.T) { 84 | const limitURL = "/limit" 85 | tests := []struct { 86 | name string 87 | 88 | mock func(ctrl *gomock.Controller) ratelimit.Limiter 89 | reqBuilder func(t *testing.T) *http.Request 90 | 91 | // 预期响应 92 | wantCode int 93 | }{ 94 | { 95 | name: "不限流", 96 | mock: func(ctrl *gomock.Controller) ratelimit.Limiter { 97 | limiter := limitmocks.NewMockLimiter(ctrl) 98 | 
limiter.EXPECT().Limit(gomock.Any(), gomock.Any()). 99 | Return(false, nil) 100 | return limiter 101 | }, 102 | reqBuilder: func(t *testing.T) *http.Request { 103 | req, err := http.NewRequest(http.MethodGet, limitURL, nil) 104 | if err != nil { 105 | t.Fatal(err) 106 | } 107 | return req 108 | }, 109 | wantCode: http.StatusOK, 110 | }, 111 | { 112 | name: "限流", 113 | mock: func(ctrl *gomock.Controller) ratelimit.Limiter { 114 | limiter := limitmocks.NewMockLimiter(ctrl) 115 | limiter.EXPECT().Limit(gomock.Any(), gomock.Any()). 116 | Return(true, nil) 117 | return limiter 118 | }, 119 | reqBuilder: func(t *testing.T) *http.Request { 120 | req, err := http.NewRequest(http.MethodGet, limitURL, nil) 121 | if err != nil { 122 | t.Fatal(err) 123 | } 124 | return req 125 | }, 126 | wantCode: http.StatusTooManyRequests, 127 | }, 128 | { 129 | name: "系统错误", 130 | mock: func(ctrl *gomock.Controller) ratelimit.Limiter { 131 | limiter := limitmocks.NewMockLimiter(ctrl) 132 | limiter.EXPECT().Limit(gomock.Any(), gomock.Any()). 
133 | Return(false, errors.New("模拟系统错误")) 134 | return limiter 135 | }, 136 | reqBuilder: func(t *testing.T) *http.Request { 137 | req, err := http.NewRequest(http.MethodGet, limitURL, nil) 138 | if err != nil { 139 | t.Fatal(err) 140 | } 141 | return req 142 | }, 143 | wantCode: http.StatusInternalServerError, 144 | }, 145 | } 146 | for _, tt := range tests { 147 | t.Run(tt.name, func(t *testing.T) { 148 | ctrl := gomock.NewController(t) 149 | defer ctrl.Finish() 150 | svc := NewBuilder(tt.mock(ctrl)) 151 | 152 | server := gin.Default() 153 | server.Use(svc.Build()) 154 | svc.RegisterRoutes(server) 155 | 156 | req := tt.reqBuilder(t) 157 | recorder := httptest.NewRecorder() 158 | 159 | server.ServeHTTP(recorder, req) 160 | 161 | assert.Equal(t, tt.wantCode, recorder.Code) 162 | }) 163 | } 164 | } 165 | 166 | func TestBuilder_limit(t *testing.T) { 167 | tests := []struct { 168 | name string 169 | 170 | mock func(ctrl *gomock.Controller) ratelimit.Limiter 171 | reqBuilder func(t *testing.T) *http.Request 172 | 173 | // 预期响应 174 | want bool 175 | wantErr error 176 | }{ 177 | { 178 | name: "不限流", 179 | mock: func(ctrl *gomock.Controller) ratelimit.Limiter { 180 | limiter := limitmocks.NewMockLimiter(ctrl) 181 | limiter.EXPECT().Limit(gomock.Any(), gomock.Any()). 182 | Return(false, nil) 183 | return limiter 184 | }, 185 | reqBuilder: func(t *testing.T) *http.Request { 186 | req, err := http.NewRequest(http.MethodGet, "", nil) 187 | if err != nil { 188 | t.Fatal(err) 189 | } 190 | req.RemoteAddr = "127.0.0.1:80" 191 | return req 192 | }, 193 | want: false, 194 | }, 195 | { 196 | name: "限流", 197 | mock: func(ctrl *gomock.Controller) ratelimit.Limiter { 198 | limiter := limitmocks.NewMockLimiter(ctrl) 199 | limiter.EXPECT().Limit(gomock.Any(), gomock.Any()). 
200 | Return(true, nil) 201 | return limiter 202 | }, 203 | reqBuilder: func(t *testing.T) *http.Request { 204 | req, err := http.NewRequest(http.MethodGet, "", nil) 205 | if err != nil { 206 | t.Fatal(err) 207 | } 208 | req.RemoteAddr = "127.0.0.1:80" 209 | return req 210 | }, 211 | want: true, 212 | }, 213 | { 214 | name: "限流代码出错", 215 | mock: func(ctrl *gomock.Controller) ratelimit.Limiter { 216 | limiter := limitmocks.NewMockLimiter(ctrl) 217 | limiter.EXPECT().Limit(gomock.Any(), gomock.Any()). 218 | Return(false, errors.New("模拟系统错误")) 219 | return limiter 220 | }, 221 | reqBuilder: func(t *testing.T) *http.Request { 222 | req, err := http.NewRequest(http.MethodGet, "", nil) 223 | if err != nil { 224 | t.Fatal(err) 225 | } 226 | req.RemoteAddr = "127.0.0.1:80" 227 | return req 228 | }, 229 | want: false, 230 | wantErr: errors.New("模拟系统错误"), 231 | }, 232 | } 233 | for _, tt := range tests { 234 | t.Run(tt.name, func(t *testing.T) { 235 | ctrl := gomock.NewController(t) 236 | defer ctrl.Finish() 237 | limiter := tt.mock(ctrl) 238 | b := NewBuilder(limiter) 239 | 240 | recorder := httptest.NewRecorder() 241 | ctx, _ := gin.CreateTestContext(recorder) 242 | req := tt.reqBuilder(t) 243 | ctx.Request = req 244 | 245 | got, err := b.limit(ctx) 246 | assert.Equal(t, tt.wantErr, err) 247 | assert.Equal(t, tt.want, got) 248 | }) 249 | } 250 | } 251 | 252 | func (b *Builder) RegisterRoutes(server *gin.Engine) { 253 | server.GET("/limit", func(ctx *gin.Context) { 254 | ctx.Status(http.StatusOK) 255 | }) 256 | } 257 | -------------------------------------------------------------------------------- /middlewares/ratelimit/redis_slide_window.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package ratelimit 16 | 17 | import ( 18 | "time" 19 | 20 | "github.com/redis/go-redis/v9" 21 | 22 | "github.com/ecodeclub/ginx/internal/ratelimit" 23 | ) 24 | 25 | // NewRedisSlidingWindowLimiter 创建一个基于 redis 的滑动窗口限流器. 26 | // cmd: 可传入 redis 的客户端 27 | // interval: 窗口大小 28 | // rate: 阈值 29 | // 表示: 在 interval 内允许 rate 个请求 30 | // 示例: 1s 内允许 3000 个请求 31 | func NewRedisSlidingWindowLimiter(cmd redis.Cmdable, 32 | interval time.Duration, rate int) ratelimit.Limiter { 33 | return &ratelimit.RedisSlidingWindowLimiter{ 34 | Cmd: cmd, 35 | Interval: interval, 36 | Rate: rate, 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /session/builder.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package session 16 | 17 | import "github.com/ecodeclub/ginx/gctx" 18 | 19 | // Builder is a fluent helper for constructing a Session; each setter returns the Builder itself. 20 | type Builder struct { 21 | ctx *gctx.Context 22 | uid int64 23 | jwtData map[string]string 24 | sessData map[string]any 25 | sp Provider 26 | } 27 | 28 | // NewSessionBuilder creates a Builder that constructs a Session for uid. 29 | // It uses defaultProvider (read at call time) unless SetProvider overrides it. 30 | func NewSessionBuilder(ctx *gctx.Context, uid int64) *Builder { 31 | return &Builder{ 32 | ctx: ctx, 33 | uid: uid, 34 | sp: defaultProvider, 35 | } 36 | } 37 | // SetProvider replaces the Provider that Build will use. 38 | func (b *Builder) SetProvider(p Provider) *Builder { 39 | b.sp = p 40 | return b 41 | } 42 | // SetJwtData sets the map passed to Provider.NewSession as jwtData (stored as-is, not copied). 43 | func (b *Builder) SetJwtData(data map[string]string) *Builder { 44 | b.jwtData = data 45 | return b 46 | } 47 | // SetSessData sets the map passed to Provider.NewSession as sessData (stored as-is, not copied). 48 | func (b *Builder) SetSessData(data map[string]any) *Builder { 49 | b.sessData = data 50 | return b 51 | } 52 | // Build creates the Session by calling NewSession on the configured Provider. 53 | func (b *Builder) Build() (Session, error) { 54 | return b.sp.NewSession(b.ctx, b.uid, b.jwtData, b.sessData) 55 | } 56 | -------------------------------------------------------------------------------- /session/builder_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License.
14 | 15 | package session 16 | 17 | import ( 18 | "testing" 19 | 20 | "github.com/ecodeclub/ginx/gctx" 21 | "github.com/stretchr/testify/assert" 22 | "github.com/stretchr/testify/require" 23 | "go.uber.org/mock/gomock" 24 | ) 25 | 26 | func TestBuilder(t *testing.T) { 27 | ctrl := gomock.NewController(t) 28 | defer ctrl.Finish() 29 | p := NewMockProvider(ctrl) 30 | p.EXPECT().NewSession(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 31 | DoAndReturn(func(ctx *gctx.Context, uid int64, jwtData map[string]string, 32 | sessData map[string]any) (Session, error) { 33 | return &MemorySession{data: sessData, 34 | claims: Claims{Uid: uid, Data: jwtData}}, nil 35 | }) 36 | sess, err := NewSessionBuilder(new(gctx.Context), 123). 37 | SetProvider(p). 38 | SetJwtData(map[string]string{"jwt": "true"}). 39 | SetSessData(map[string]any{"session": "true"}). 40 | Build() 41 | require.NoError(t, err) 42 | assert.Equal(t, &MemorySession{ 43 | data: map[string]any{"session": "true"}, 44 | claims: Claims{ 45 | Uid: 123, 46 | Data: map[string]string{"jwt": "true"}, 47 | }, 48 | }, sess) 49 | } 50 | -------------------------------------------------------------------------------- /session/cookie/carrier.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License.
14 | 15 | package cookie 16 | 17 | import ( 18 | "github.com/ecodeclub/ginx/gctx" 19 | "github.com/ecodeclub/ginx/session" 20 | ) 21 | 22 | var _ session.TokenCarrier = &TokenCarrier{} 23 | 24 | type TokenCarrier struct { 25 | MaxAge int 26 | Name string 27 | Path string 28 | Domain string 29 | Secure bool 30 | HttpOnly bool 31 | } 32 | 33 | func (t *TokenCarrier) Clear(ctx *gctx.Context) { 34 | // A MaxAge of -1 is equivalent to deleting the cookie (net/http emits a negative MaxAge as Max-Age=0). 35 | ctx.SetCookie(t.Name, "", -1, t.Path, t.Domain, t.Secure, t.HttpOnly) 36 | } 37 | // Inject writes value into the cookie named Name, using the configured MaxAge, Path, Domain, Secure and HttpOnly. 38 | func (t *TokenCarrier) Inject(ctx *gctx.Context, value string) { 39 | ctx.SetCookie(t.Name, value, t.MaxAge, t.Path, t.Domain, t.Secure, t.HttpOnly) 40 | } 41 | // Extract returns the cookie named Name as a string, falling back to "" via StringOrDefault. 42 | func (t *TokenCarrier) Extract(ctx *gctx.Context) string { 43 | return ctx.Cookie(t.Name).StringOrDefault("") 44 | } 45 | -------------------------------------------------------------------------------- /session/cookie/carrier_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License.
14 | 15 | package cookie 16 | 17 | import ( 18 | "net/http" 19 | "net/http/httptest" 20 | "strings" 21 | "testing" 22 | 23 | "github.com/ecodeclub/ginx" 24 | "github.com/gin-gonic/gin" 25 | "github.com/stretchr/testify/assert" 26 | "github.com/stretchr/testify/suite" 27 | ) 28 | 29 | type CarrierTestSuite struct { 30 | suite.Suite 31 | } 32 | 33 | func (s *CarrierTestSuite) TestInject() { 34 | instance := &TokenCarrier{ 35 | Name: "ssid", 36 | } 37 | val := "this is token" 38 | recorder := httptest.NewRecorder() 39 | ctx, _ := gin.CreateTestContext(recorder) 40 | instance.Inject(&ginx.Context{ 41 | Context: ctx, 42 | }, val) 43 | // The cookie value is not checked in detail; a non-empty Set-Cookie is considered enough. 44 | ck := recorder.Header().Get("Set-Cookie") 45 | assert.NotEmpty(s.T(), ck) 46 | } 47 | 48 | func (s *CarrierTestSuite) TestExtract() { 49 | instance := &TokenCarrier{ 50 | Name: "ssid", 51 | } 52 | recorder := httptest.NewRecorder() 53 | ctx, _ := gin.CreateTestContext(recorder) 54 | val := "this is token" 55 | ctx.Request = &http.Request{ 56 | Header: http.Header{}, 57 | } 58 | ctx.Request.AddCookie(&http.Cookie{ 59 | Name: "ssid", 60 | Value: val, 61 | }) 62 | res := instance.Extract(&ginx.Context{ 63 | Context: ctx, 64 | }) 65 | assert.Equal(s.T(), val, res) 66 | } 67 | // TestClear expects Clear's negative MaxAge, which net/http serializes as "Max-Age=0" (never "Max-Age=-1"). 68 | func (s *CarrierTestSuite) TestClear() { 69 | instance := &TokenCarrier{ 70 | Name: "ssid", 71 | } 72 | recorder := httptest.NewRecorder() 73 | ctx, _ := gin.CreateTestContext(recorder) 74 | instance.Clear(&ginx.Context{ 75 | Context: ctx, 76 | }) 77 | ck := recorder.Header().Get("Set-Cookie") 78 | assert.True(s.T(), strings.Contains(ck, "Max-Age=0"), ck) 79 | assert.NotEmpty(s.T(), ck) 80 | } 81 | 82 | func TestCarrier(t *testing.T) { 83 | suite.Run(t, new(CarrierTestSuite)) 84 | } 85 | -------------------------------------------------------------------------------- /session/global.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, 
Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package session 16 | 17 | import ( 18 | "time" 19 | 20 | "github.com/ecodeclub/ginx/gctx" 21 | "github.com/gin-gonic/gin" 22 | ) 23 | // CtxSessionKey is the gin context key under which the login middleware stores the Session. 24 | const CtxSessionKey = "_session" 25 | // defaultProvider backs the package-level helpers below; it starts out nil, set it with SetDefaultProvider. 26 | var defaultProvider Provider 27 | // NewSession creates a Session through defaultProvider. 28 | func NewSession(ctx *gctx.Context, uid int64, 29 | jwtData map[string]string, 30 | sessData map[string]any) (Session, error) { 31 | return defaultProvider.NewSession( 32 | ctx, 33 | uid, 34 | jwtData, 35 | sessData) 36 | } 37 | 38 | // Get delegates to defaultProvider.Get; see that method's documentation for details. 39 | func Get(ctx *gctx.Context) (Session, error) { 40 | return defaultProvider.Get(ctx) 41 | } 42 | // SetDefaultProvider sets the Provider used by the package-level helpers in this file. 43 | func SetDefaultProvider(sp Provider) { 44 | defaultProvider = sp 45 | } 46 | // DefaultProvider returns the Provider currently used by the package-level helpers. 47 | func DefaultProvider() Provider { 48 | return defaultProvider 49 | } 50 | // CheckLoginMiddleware builds a login check bound to the Provider set at call time (later SetDefaultProvider calls do not affect it); tokens with less than 30 minutes left are renewed. 51 | func CheckLoginMiddleware() gin.HandlerFunc { 52 | return (&MiddlewareBuilder{sp: defaultProvider, Threshold: time.Minute * 30}).Build() 53 | } 54 | // RenewAccessToken delegates to defaultProvider.RenewAccessToken. 55 | func RenewAccessToken(ctx *gctx.Context) error { 56 | return defaultProvider.RenewAccessToken(ctx) 57 | } 58 | // UpdateClaims delegates to defaultProvider.UpdateClaims. 59 | func UpdateClaims(ctx *gctx.Context, claims Claims) error { 60 | return defaultProvider.UpdateClaims(ctx, claims) 61 | } 62 | -------------------------------------------------------------------------------- /session/global_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 
(the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package session 16 | 17 | import ( 18 | "net/http" 19 | "net/http/httptest" 20 | "testing" 21 | 22 | "github.com/ecodeclub/ginx/gctx" 23 | "github.com/ecodeclub/ginx/internal/errs" 24 | "github.com/gin-gonic/gin" 25 | "github.com/stretchr/testify/assert" 26 | "github.com/stretchr/testify/require" 27 | "go.uber.org/mock/gomock" 28 | ) 29 | 30 | func TestNewSession(t *testing.T) { 31 | ctrl := gomock.NewController(t) 32 | defer ctrl.Finish() 33 | p := NewMockProvider(ctrl) 34 | // The cost of package-level state: the test has to swap in the global provider and reset it afterwards. 35 | SetDefaultProvider(p) 36 | defer SetDefaultProvider(nil) 37 | p.EXPECT().NewSession(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
38 | DoAndReturn(func(ctx *gctx.Context, uid int64, jwtData map[string]string, 39 | sessData map[string]any) (Session, error) { 40 | return &MemorySession{data: sessData, 41 | claims: Claims{Uid: uid, Data: jwtData}}, nil 42 | }) 43 | sess, err := NewSession(new(gctx.Context), 123, 44 | map[string]string{"jwt": "true"}, 45 | map[string]any{"session": "true"}) 46 | require.NoError(t, err) 47 | assert.Equal(t, &MemorySession{ 48 | data: map[string]any{"session": "true"}, 49 | claims: Claims{ 50 | Uid: 123, 51 | Data: map[string]string{"jwt": "true"}, 52 | }, 53 | }, sess) 54 | } 55 | 56 | func TestCheckLoginMiddleware(t *testing.T) { 57 | ctrl := gomock.NewController(t) 58 | defer ctrl.Finish() 59 | p := NewMockProvider(ctrl) 60 | // The cost of package-level state: the test has to swap in the global provider and reset it afterwards. 61 | SetDefaultProvider(p) 62 | p.EXPECT().RenewAccessToken(gomock.Any()).AnyTimes().Return(nil) 63 | defer SetDefaultProvider(nil) 64 | server := gin.Default() 65 | server.Use(CheckLoginMiddleware()) 66 | server.GET("/hello", func(ctx *gin.Context) { 67 | ctx.String(http.StatusOK, "OK") 68 | }) 69 | 70 | // First request: rejected (Get returns ErrUnauthorized) 71 | p.EXPECT().Get(gomock.Any()).Return(nil, errs.ErrUnauthorized) 72 | recorder := httptest.NewRecorder() 73 | req, err := http.NewRequest(http.MethodGet, "http://localhost/hello", nil) 74 | require.NoError(t, err) 75 | server.ServeHTTP(recorder, req) 76 | assert.Equal(t, 401, recorder.Code) 77 | 78 | // Second request: handled 79 | 80 | p.EXPECT().Get(gomock.Any()).Return(NewMemorySession(Claims{}), nil) 81 | recorder = httptest.NewRecorder() 82 | req, err = http.NewRequest(http.MethodGet, "http://localhost/hello", nil) 83 | require.NoError(t, err) 84 | server.ServeHTTP(recorder, req) 85 | assert.Equal(t, 200, recorder.Code) 86 | } 87 | -------------------------------------------------------------------------------- /session/header/carrier.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the 
"License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package header 16 | 17 | import ( 18 | "strings" 19 | 20 | "github.com/ecodeclub/ginx/gctx" 21 | "github.com/ecodeclub/ginx/session" 22 | ) 23 | // TokenCarrier writes the token to a response header and reads it back from the request's Authorization header. 24 | type TokenCarrier struct { 25 | // Name is the response header the token is written to. 26 | // The token is always read from the request's Authorization header, assuming the Bearer scheme. 27 | Name string 28 | } 29 | // Clear sets the response header named Name to an empty value. 30 | func (t *TokenCarrier) Clear(ctx *gctx.Context) { 31 | // Setting an empty token is treated as clearing the token. 32 | ctx.Writer.Header().Set(t.Name, "") 33 | } 34 | // Inject sets the response header named Name to value. 35 | func (t *TokenCarrier) Inject(ctx *gctx.Context, value string) { 36 | ctx.Writer.Header().Set(t.Name, value) 37 | } 38 | 39 | // Extract always reads the request's Authorization header; a leading "Bearer " is stripped, otherwise the value is returned unchanged. 40 | func (t *TokenCarrier) Extract(ctx *gctx.Context) string { 41 | token := ctx.Request.Header.Get("Authorization") 42 | const bearerPrefix = "Bearer " 43 | return strings.TrimPrefix(token, bearerPrefix) 44 | } 45 | 46 | var _ session.TokenCarrier = &TokenCarrier{} 47 | // NewTokenCarrier returns a TokenCarrier that writes the token to the X-Access-Token response header. 48 | func NewTokenCarrier() *TokenCarrier { 49 | return &TokenCarrier{ 50 | Name: "X-Access-Token", 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /session/header/carrier_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package header 16 | 17 | import ( 18 | "fmt" 19 | "net/http" 20 | "net/http/httptest" 21 | "testing" 22 | 23 | "github.com/ecodeclub/ginx" 24 | "github.com/gin-gonic/gin" 25 | "github.com/stretchr/testify/assert" 26 | "github.com/stretchr/testify/suite" 27 | ) 28 | 29 | type CarrierTestSuite struct { 30 | suite.Suite 31 | } 32 | 33 | func (s *CarrierTestSuite) TestInject() { 34 | instance := NewTokenCarrier() 35 | val := "this is token" 36 | recorder := httptest.NewRecorder() 37 | ctx, _ := gin.CreateTestContext(recorder) 38 | instance.Inject(&ginx.Context{ 39 | Context: ctx, 40 | }, val) 41 | ck := recorder.Header().Get(instance.Name) 42 | assert.NotEmpty(s.T(), ck) 43 | } 44 | 45 | func (s *CarrierTestSuite) TestExtract() { 46 | instance := NewTokenCarrier() 47 | recorder := httptest.NewRecorder() 48 | ctx, _ := gin.CreateTestContext(recorder) 49 | val := "this is token" 50 | ctx.Request = &http.Request{ 51 | Header: http.Header{ 52 | "Authorization": []string{fmt.Sprintf("Bearer %s", val)}, 53 | }, 54 | } 55 | res := instance.Extract(&ginx.Context{ 56 | Context: ctx, 57 | }) 58 | assert.Equal(s.T(), val, res) 59 | } 60 | 61 | func (s *CarrierTestSuite) TestClear() { 62 | instance := &TokenCarrier{ 63 | Name: "ssid", 64 | } 65 | recorder := httptest.NewRecorder() 66 | ctx, _ := gin.CreateTestContext(recorder) 67 | instance.Clear(&ginx.Context{ 68 | Context: ctx, 69 | }) 70 | ck := recorder.Header().Get(instance.Name) 71 | assert.Equal(s.T(), "", ck) 72 | } 73 | 74 | func TestCarrier(t *testing.T) { 75 | 
suite.Run(t, new(CarrierTestSuite)) 76 | } 77 | -------------------------------------------------------------------------------- /session/memory.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package session 16 | 17 | import ( 18 | "context" 19 | 20 | "github.com/ecodeclub/ginx/gctx" 21 | 22 | "github.com/ecodeclub/ekit" 23 | "github.com/ecodeclub/ginx/internal/errs" 24 | ) 25 | 26 | var _ Session = &MemorySession{} 27 | 28 | // MemorySession is an in-memory Session, typically used for testing; it uses a plain map with no locking, so it is not safe for concurrent use. 29 | type MemorySession struct { 30 | data map[string]any 31 | claims Claims 32 | } 33 | // Destroy is a no-op that always returns nil. 34 | func (m *MemorySession) Destroy(ctx context.Context) error { 35 | return nil 36 | } 37 | // UpdateClaims is a no-op: it returns nil without changing the stored claims. 38 | func (m *MemorySession) UpdateClaims(ctx *gctx.Context, claims Claims) error { 39 | return nil 40 | } 41 | // Del removes key from the session data. 42 | func (m *MemorySession) Del(ctx context.Context, key string) error { 43 | delete(m.data, key) 44 | return nil 45 | } 46 | // NewMemorySession creates a MemorySession with an empty (non-nil) data map and the given claims. 47 | func NewMemorySession(cl Claims) *MemorySession { 48 | return &MemorySession{ 49 | data: map[string]any{}, 50 | claims: cl, 51 | } 52 | } 53 | // Set stores val under key; data must be non-nil (NewMemorySession guarantees this), otherwise the map write panics. 54 | func (m *MemorySession) Set(ctx context.Context, key string, val any) error { 55 | m.data[key] = val 56 | return nil 57 | } 58 | // Get returns the value stored under key, or an AnyValue whose Err is errs.ErrSessionKeyNotFound. 59 | func (m *MemorySession) Get(ctx context.Context, key string) ekit.AnyValue { 60 | val, ok := m.data[key] 61 | if !ok { 62 | return ekit.AnyValue{Err: 
errs.ErrSessionKeyNotFound} 63 | } 64 | return ekit.AnyValue{Val: val} 65 | } 66 | 67 | func (m *MemorySession) Claims() Claims { 68 | return m.claims 69 | } 70 | -------------------------------------------------------------------------------- /session/memory_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package session 16 | 17 | import ( 18 | "context" 19 | "testing" 20 | 21 | "github.com/ecodeclub/ekit" 22 | "github.com/ecodeclub/ginx/internal/errs" 23 | "github.com/stretchr/testify/assert" 24 | "github.com/stretchr/testify/require" 25 | ) 26 | 27 | func TestMemorySession_GetSet(t *testing.T) { 28 | testCases := []struct { 29 | name string 30 | 31 | // data to insert 32 | key string 33 | val string 34 | 35 | getKey string 36 | 37 | wantVal ekit.AnyValue 38 | }{ 39 | { 40 | name: "成功获取", 41 | key: "key1", 42 | val: "value1", 43 | getKey: "key1", 44 | wantVal: ekit.AnyValue{Val: "value1"}, 45 | }, 46 | { 47 | name: "没有数据", 48 | key: "key1", 49 | val: "value1", 50 | getKey: "key2", 51 | wantVal: ekit.AnyValue{Err: errs.ErrSessionKeyNotFound}, 52 | }, 53 | } 54 | 55 | for _, tc := range testCases { 56 | t.Run(tc.name, func(t *testing.T) { 57 | ms := NewMemorySession(Claims{}) 58 | ctx := context.Background() 59 | err := ms.Set(ctx, tc.key, tc.val) 60 | require.NoError(t, err) 61 | val := 
ms.Get(ctx, tc.getKey) 62 | assert.Equal(t, tc.wantVal, val) 63 | }) 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /session/middleware_builder.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package session 16 | 17 | import ( 18 | "log/slog" 19 | "net/http" 20 | "time" 21 | 22 | "github.com/ecodeclub/ginx/gctx" 23 | "github.com/gin-gonic/gin" 24 | ) 25 | 26 | // MiddlewareBuilder builds the login-check middleware. 27 | type MiddlewareBuilder struct { 28 | sp Provider 29 | // Threshold: when the token has less than this much validity left, it is renewed. 30 | Threshold time.Duration 31 | } 32 | // Build returns a handler that aborts with 401 when no Session can be obtained, renews the access token when Claims.Expiration (compared as Unix milliseconds) is within Threshold of now, and stores the Session under CtxSessionKey. 33 | func (b *MiddlewareBuilder) Build() gin.HandlerFunc { 34 | threshold := b.Threshold.Milliseconds() 35 | return func(ctx *gin.Context) { 36 | ctxx := &gctx.Context{Context: ctx} 37 | sess, err := b.sp.Get(ctxx) 38 | if err != nil { 39 | slog.Debug("未授权", slog.Any("err", err)) 40 | ctx.AbortWithStatus(http.StatusUnauthorized) 41 | return 42 | } 43 | expiration := sess.Claims().Expiration 44 | if expiration-time.Now().UnixMilli() < threshold { 45 | // Renew the token; a failure is only logged and the request still proceeds. 46 | err = b.sp.RenewAccessToken(ctxx) 47 | if err != nil { 48 | slog.Warn("刷新 token 失败", slog.String("err", err.Error())) 49 | } 50 | } 51 | ctx.Set(CtxSessionKey, sess) 52 | } 53 | } 54 | 
-------------------------------------------------------------------------------- /session/mixin/carrier.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package mixin 16 | 17 | import ( 18 | "github.com/ecodeclub/ginx/gctx" 19 | "github.com/ecodeclub/ginx/session" 20 | ) 21 | // TokenCarrier combines several carriers: Inject and Clear go to every one of them, Extract returns the first non-empty token. 22 | type TokenCarrier struct { 23 | carriers []session.TokenCarrier 24 | } 25 | // NewTokenCarrier creates a TokenCarrier; the order of carriers is the order in which Extract tries them. 26 | func NewTokenCarrier(carriers ...session.TokenCarrier) *TokenCarrier { 27 | return &TokenCarrier{carriers: carriers} 28 | } 29 | // Inject writes value through every carrier. 30 | func (t *TokenCarrier) Inject(ctx *gctx.Context, value string) { 31 | for _, carrier := range t.carriers { 32 | carrier.Inject(ctx, value) 33 | } 34 | } 35 | // Extract returns the first non-empty token found, in carrier order, or "" if no carrier has one. 36 | func (t *TokenCarrier) Extract(ctx *gctx.Context) string { 37 | for _, carrier := range t.carriers { 38 | val := carrier.Extract(ctx) 39 | if val != "" { 40 | return val 41 | } 42 | } 43 | return "" 44 | } 45 | // Clear clears the token from every carrier. 46 | func (t *TokenCarrier) Clear(ctx *gctx.Context) { 47 | for _, carrier := range t.carriers { 48 | carrier.Clear(ctx) 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /session/mixin/carrier_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 
4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package mixin 16 | 17 | import ( 18 | "net/http" 19 | "net/http/httptest" 20 | "strings" 21 | "testing" 22 | 23 | "github.com/ecodeclub/ginx" 24 | "github.com/ecodeclub/ginx/session/cookie" 25 | "github.com/ecodeclub/ginx/session/header" 26 | "github.com/gin-gonic/gin" 27 | "github.com/stretchr/testify/assert" 28 | "github.com/stretchr/testify/suite" 29 | ) 30 | 31 | type CarrierTestSuite struct { 32 | suite.Suite 33 | carrier *TokenCarrier 34 | } 35 | // SetupSuite mixes a header carrier (tried first by Extract) with a cookie carrier. 36 | func (s *CarrierTestSuite) SetupSuite() { 37 | hc := header.NewTokenCarrier() 38 | ck := &cookie.TokenCarrier{ 39 | MaxAge: 1000, 40 | Name: "ssid", 41 | } 42 | s.carrier = NewTokenCarrier(hc, ck) 43 | } 44 | 45 | func (s *CarrierTestSuite) TestInject() { 46 | val := "this is token" 47 | recorder := httptest.NewRecorder() 48 | ctx, _ := gin.CreateTestContext(recorder) 49 | s.carrier.Inject(&ginx.Context{ 50 | Context: ctx, 51 | }, val) 52 | // The cookie value is not checked in detail; a non-empty Set-Cookie is considered enough. 53 | ck := recorder.Header().Get("Set-Cookie") 54 | assert.NotEmpty(s.T(), ck) 55 | 56 | ck = recorder.Header().Get("X-Access-Token") 57 | assert.NotEmpty(s.T(), ck) 58 | } 59 | 60 | func (s *CarrierTestSuite) TestExtract() { 61 | testCases := []struct { 62 | name string 63 | ctxBuilder func() *ginx.Context 64 | wantVal string 65 | }{ 66 | { 67 | name: "从 header 中取出", 68 | ctxBuilder: func() *ginx.Context { 69 | recorder := httptest.NewRecorder() 70 | ctx, _ := gin.CreateTestContext(recorder) 71 | val := "this is 
token" 72 | ctx.Request = &http.Request{ 73 | Header: http.Header{}, 74 | } 75 | // Only the Authorization header carries the token here, so the 76 | // header carrier (tried first) must be the one that finds it. 77 | ctx.Request.Header.Set("Authorization", 78 | "Bearer "+val) 79 | return &ginx.Context{Context: ctx} 80 | }, 81 | wantVal: "this is token", 82 | }, 83 | { 84 | name: "从 cookie 中取出", 85 | ctxBuilder: func() *ginx.Context { 86 | recorder := httptest.NewRecorder() 87 | ctx, _ := gin.CreateTestContext(recorder) 88 | val := "this is token" 89 | ctx.Request = &http.Request{ 90 | Header: http.Header{}, 91 | } 92 | ctx.Request.AddCookie(&http.Cookie{ 93 | Name: "ssid", 94 | Value: val, 95 | }) 96 | return &ginx.Context{Context: ctx} 97 | }, 98 | wantVal: "this is token", 99 | }, 100 | { 101 | name: "都没有", 102 | ctxBuilder: func() *ginx.Context { 103 | recorder := httptest.NewRecorder() 104 | ctx, _ := gin.CreateTestContext(recorder) 105 | ctx.Request = &http.Request{ 106 | Header: http.Header{}, 107 | } 108 | return &ginx.Context{Context: ctx} 109 | }, 110 | wantVal: "", 111 | }, 112 | } 113 | 114 | for _, tc := range testCases { 115 | s.T().Run(tc.name, func(t *testing.T) { 116 | val := s.carrier.Extract(tc.ctxBuilder()) 117 | assert.Equal(t, tc.wantVal, val) 118 | }) 119 | } 120 | } 121 | // TestClear expects the cookie carrier's negative MaxAge, which net/http serializes as "Max-Age=0" (never "Max-Age=-1"). 122 | func (s *CarrierTestSuite) TestClear() { 123 | recorder := httptest.NewRecorder() 124 | ctx, _ := gin.CreateTestContext(recorder) 125 | s.carrier.Clear(&ginx.Context{ 126 | Context: ctx, 127 | }) 128 | ck := recorder.Header().Get("Set-Cookie") 129 | assert.True(s.T(), strings.Contains(ck, "Max-Age=0"), ck) 130 | assert.NotEmpty(s.T(), ck) 131 | 132 | ck = recorder.Header().Get("X-Access-Token") 133 | assert.Equal(s.T(), "", ck) 134 | } 135 | 136 | func TestCarrier(t *testing.T) { 137 | suite.Run(t, new(CarrierTestSuite)) 138 | } 139 | -------------------------------------------------------------------------------- /session/provider.mock_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, 
Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | // Code generated by MockGen. DO NOT EDIT. 16 | // Source: session/types.go 17 | // 18 | // Generated by this command: 19 | // 20 | // mockgen -copyright_file=.license_header -source=session/types.go -package=session -destination=session/provider.mock_test.go Provider 21 | // 22 | // Package session is a generated GoMock package. 23 | package session 24 | 25 | import ( 26 | context "context" 27 | reflect "reflect" 28 | 29 | ekit "github.com/ecodeclub/ekit" 30 | gctx "github.com/ecodeclub/ginx/gctx" 31 | gomock "go.uber.org/mock/gomock" 32 | ) 33 | 34 | // MockSession is a mock of Session interface. 35 | type MockSession struct { 36 | ctrl *gomock.Controller 37 | recorder *MockSessionMockRecorder 38 | } 39 | 40 | // MockSessionMockRecorder is the mock recorder for MockSession. 41 | type MockSessionMockRecorder struct { 42 | mock *MockSession 43 | } 44 | 45 | // NewMockSession creates a new mock instance. 46 | func NewMockSession(ctrl *gomock.Controller) *MockSession { 47 | mock := &MockSession{ctrl: ctrl} 48 | mock.recorder = &MockSessionMockRecorder{mock} 49 | return mock 50 | } 51 | 52 | // EXPECT returns an object that allows the caller to indicate expected use. 53 | func (m *MockSession) EXPECT() *MockSessionMockRecorder { 54 | return m.recorder 55 | } 56 | 57 | // Claims mocks base method.
58 | func (m *MockSession) Claims() Claims { 59 | m.ctrl.T.Helper() 60 | ret := m.ctrl.Call(m, "Claims") 61 | ret0, _ := ret[0].(Claims) 62 | return ret0 63 | } 64 | 65 | // Claims indicates an expected call of Claims. 66 | func (mr *MockSessionMockRecorder) Claims() *gomock.Call { 67 | mr.mock.ctrl.T.Helper() 68 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Claims", reflect.TypeOf((*MockSession)(nil).Claims)) 69 | } 70 | 71 | // Del mocks base method. 72 | func (m *MockSession) Del(ctx context.Context, key string) error { 73 | m.ctrl.T.Helper() 74 | ret := m.ctrl.Call(m, "Del", ctx, key) 75 | ret0, _ := ret[0].(error) 76 | return ret0 77 | } 78 | 79 | // Del indicates an expected call of Del. 80 | func (mr *MockSessionMockRecorder) Del(ctx, key any) *gomock.Call { 81 | mr.mock.ctrl.T.Helper() 82 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Del", reflect.TypeOf((*MockSession)(nil).Del), ctx, key) 83 | } 84 | 85 | // Destroy mocks base method. 86 | func (m *MockSession) Destroy(ctx context.Context) error { 87 | m.ctrl.T.Helper() 88 | ret := m.ctrl.Call(m, "Destroy", ctx) 89 | ret0, _ := ret[0].(error) 90 | return ret0 91 | } 92 | 93 | // Destroy indicates an expected call of Destroy. 94 | func (mr *MockSessionMockRecorder) Destroy(ctx any) *gomock.Call { 95 | mr.mock.ctrl.T.Helper() 96 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Destroy", reflect.TypeOf((*MockSession)(nil).Destroy), ctx) 97 | } 98 | 99 | // Get mocks base method. 100 | func (m *MockSession) Get(ctx context.Context, key string) ekit.AnyValue { 101 | m.ctrl.T.Helper() 102 | ret := m.ctrl.Call(m, "Get", ctx, key) 103 | ret0, _ := ret[0].(ekit.AnyValue) 104 | return ret0 105 | } 106 | 107 | // Get indicates an expected call of Get.
108 | func (mr *MockSessionMockRecorder) Get(ctx, key any) *gomock.Call { 109 | mr.mock.ctrl.T.Helper() 110 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockSession)(nil).Get), ctx, key) 111 | } 112 | 113 | // Set mocks base method. 114 | func (m *MockSession) Set(ctx context.Context, key string, val any) error { 115 | m.ctrl.T.Helper() 116 | ret := m.ctrl.Call(m, "Set", ctx, key, val) 117 | ret0, _ := ret[0].(error) 118 | return ret0 119 | } 120 | 121 | // Set indicates an expected call of Set. 122 | func (mr *MockSessionMockRecorder) Set(ctx, key, val any) *gomock.Call { 123 | mr.mock.ctrl.T.Helper() 124 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Set", reflect.TypeOf((*MockSession)(nil).Set), ctx, key, val) 125 | } 126 | 127 | // MockProvider is a mock of Provider interface. 128 | type MockProvider struct { 129 | ctrl *gomock.Controller 130 | recorder *MockProviderMockRecorder 131 | } 132 | 133 | // MockProviderMockRecorder is the mock recorder for MockProvider. 134 | type MockProviderMockRecorder struct { 135 | mock *MockProvider 136 | } 137 | 138 | // NewMockProvider creates a new mock instance. 139 | func NewMockProvider(ctrl *gomock.Controller) *MockProvider { 140 | mock := &MockProvider{ctrl: ctrl} 141 | mock.recorder = &MockProviderMockRecorder{mock} 142 | return mock 143 | } 144 | 145 | // EXPECT returns an object that allows the caller to indicate expected use. 146 | func (m *MockProvider) EXPECT() *MockProviderMockRecorder { 147 | return m.recorder 148 | } 149 | 150 | // Destroy mocks base method. 151 | func (m *MockProvider) Destroy(ctx *gctx.Context) error { 152 | m.ctrl.T.Helper() 153 | ret := m.ctrl.Call(m, "Destroy", ctx) 154 | ret0, _ := ret[0].(error) 155 | return ret0 156 | } 157 | 158 | // Destroy indicates an expected call of Destroy.
159 | func (mr *MockProviderMockRecorder) Destroy(ctx any) *gomock.Call { 160 | mr.mock.ctrl.T.Helper() 161 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Destroy", reflect.TypeOf((*MockProvider)(nil).Destroy), ctx) 162 | } 163 | 164 | // Get mocks base method. 165 | func (m *MockProvider) Get(ctx *gctx.Context) (Session, error) { 166 | m.ctrl.T.Helper() 167 | ret := m.ctrl.Call(m, "Get", ctx) 168 | ret0, _ := ret[0].(Session) 169 | ret1, _ := ret[1].(error) 170 | return ret0, ret1 171 | } 172 | 173 | // Get indicates an expected call of Get. 174 | func (mr *MockProviderMockRecorder) Get(ctx any) *gomock.Call { 175 | mr.mock.ctrl.T.Helper() 176 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockProvider)(nil).Get), ctx) 177 | } 178 | 179 | // NewSession mocks base method. 180 | func (m *MockProvider) NewSession(ctx *gctx.Context, uid int64, jwtData map[string]string, sessData map[string]any) (Session, error) { 181 | m.ctrl.T.Helper() 182 | ret := m.ctrl.Call(m, "NewSession", ctx, uid, jwtData, sessData) 183 | ret0, _ := ret[0].(Session) 184 | ret1, _ := ret[1].(error) 185 | return ret0, ret1 186 | } 187 | 188 | // NewSession indicates an expected call of NewSession. 189 | func (mr *MockProviderMockRecorder) NewSession(ctx, uid, jwtData, sessData any) *gomock.Call { 190 | mr.mock.ctrl.T.Helper() 191 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewSession", reflect.TypeOf((*MockProvider)(nil).NewSession), ctx, uid, jwtData, sessData) 192 | } 193 | 194 | // RenewAccessToken mocks base method. 195 | func (m *MockProvider) RenewAccessToken(ctx *gctx.Context) error { 196 | m.ctrl.T.Helper() 197 | ret := m.ctrl.Call(m, "RenewAccessToken", ctx) 198 | ret0, _ := ret[0].(error) 199 | return ret0 200 | } 201 | 202 | // RenewAccessToken indicates an expected call of RenewAccessToken.
203 | func (mr *MockProviderMockRecorder) RenewAccessToken(ctx any) *gomock.Call { 204 | mr.mock.ctrl.T.Helper() 205 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RenewAccessToken", reflect.TypeOf((*MockProvider)(nil).RenewAccessToken), ctx) 206 | } 207 | 208 | // UpdateClaims mocks base method. 209 | func (m *MockProvider) UpdateClaims(ctx *gctx.Context, claims Claims) error { 210 | m.ctrl.T.Helper() 211 | ret := m.ctrl.Call(m, "UpdateClaims", ctx, claims) 212 | ret0, _ := ret[0].(error) 213 | return ret0 214 | } 215 | 216 | // UpdateClaims indicates an expected call of UpdateClaims. 217 | func (mr *MockProviderMockRecorder) UpdateClaims(ctx, claims any) *gomock.Call { 218 | mr.mock.ctrl.T.Helper() 219 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateClaims", reflect.TypeOf((*MockProvider)(nil).UpdateClaims), ctx, claims) 220 | } 221 | 222 | // MockTokenCarrier is a mock of TokenCarrier interface. 223 | type MockTokenCarrier struct { 224 | ctrl *gomock.Controller 225 | recorder *MockTokenCarrierMockRecorder 226 | } 227 | 228 | // MockTokenCarrierMockRecorder is the mock recorder for MockTokenCarrier. 229 | type MockTokenCarrierMockRecorder struct { 230 | mock *MockTokenCarrier 231 | } 232 | 233 | // NewMockTokenCarrier creates a new mock instance. 234 | func NewMockTokenCarrier(ctrl *gomock.Controller) *MockTokenCarrier { 235 | mock := &MockTokenCarrier{ctrl: ctrl} 236 | mock.recorder = &MockTokenCarrierMockRecorder{mock} 237 | return mock 238 | } 239 | 240 | // EXPECT returns an object that allows the caller to indicate expected use. 241 | func (m *MockTokenCarrier) EXPECT() *MockTokenCarrierMockRecorder { 242 | return m.recorder 243 | } 244 | 245 | // Clear mocks base method. 246 | func (m *MockTokenCarrier) Clear(ctx *gctx.Context) { 247 | m.ctrl.T.Helper() 248 | m.ctrl.Call(m, "Clear", ctx) 249 | } 250 | 251 | // Clear indicates an expected call of Clear.
252 | func (mr *MockTokenCarrierMockRecorder) Clear(ctx any) *gomock.Call { 253 | mr.mock.ctrl.T.Helper() 254 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Clear", reflect.TypeOf((*MockTokenCarrier)(nil).Clear), ctx) 255 | } 256 | 257 | // Extract mocks base method. 258 | func (m *MockTokenCarrier) Extract(ctx *gctx.Context) string { 259 | m.ctrl.T.Helper() 260 | ret := m.ctrl.Call(m, "Extract", ctx) 261 | ret0, _ := ret[0].(string) 262 | return ret0 263 | } 264 | 265 | // Extract indicates an expected call of Extract. 266 | func (mr *MockTokenCarrierMockRecorder) Extract(ctx any) *gomock.Call { 267 | mr.mock.ctrl.T.Helper() 268 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Extract", reflect.TypeOf((*MockTokenCarrier)(nil).Extract), ctx) 269 | } 270 | 271 | // Inject mocks base method. 272 | func (m *MockTokenCarrier) Inject(ctx *gctx.Context, value string) { 273 | m.ctrl.T.Helper() 274 | m.ctrl.Call(m, "Inject", ctx, value) 275 | } 276 | 277 | // Inject indicates an expected call of Inject. 278 | func (mr *MockTokenCarrierMockRecorder) Inject(ctx, value any) *gomock.Call { 279 | mr.mock.ctrl.T.Helper() 280 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Inject", reflect.TypeOf((*MockTokenCarrier)(nil).Inject), ctx, value) 281 | } 282 | -------------------------------------------------------------------------------- /session/redis/provider.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package redis 16 | 17 | import ( 18 | "time" 19 | 20 | "github.com/ecodeclub/ginx/session/header" 21 | 22 | "github.com/ecodeclub/ginx" 23 | 24 | "github.com/ecodeclub/ginx/gctx" 25 | ijwt "github.com/ecodeclub/ginx/internal/jwt" 26 | "github.com/ecodeclub/ginx/session" 27 | "github.com/google/uuid" 28 | "github.com/redis/go-redis/v9" 29 | ) 30 | 31 | var _ session.Provider = &SessionProvider{} 32 | 33 | // SessionProvider 默认情况下,产生的 Session 一个 token, 34 | // 而如何返回,以及如何携带,取决于具体的 TokenCarrier 实现 35 | // 很多字段并没有暴露,如果你需要自定义,可以发 issue 36 | type SessionProvider struct { 37 | client redis.Cmdable 38 | m ijwt.Manager[session.Claims] 39 | TokenCarrier session.TokenCarrier 40 | expiration time.Duration 41 | } 42 | 43 | func (rsp *SessionProvider) Destroy(ctx *gctx.Context) error { 44 | sess, err := rsp.Get(ctx) 45 | if err != nil { 46 | return err 47 | } 48 | // 清除 token 49 | rsp.TokenCarrier.Clear(ctx) 50 | return sess.Destroy(ctx) 51 | } 52 | 53 | // UpdateClaims 在这个实现里面,claims 同时写进去了 54 | func (rsp *SessionProvider) UpdateClaims(ctx *gctx.Context, claims session.Claims) error { 55 | accessToken, err := rsp.m.GenerateAccessToken(claims) 56 | if err != nil { 57 | return err 58 | } 59 | rsp.TokenCarrier.Inject(ctx, accessToken) 60 | return nil 61 | } 62 | 63 | func (rsp *SessionProvider) RenewAccessToken(ctx *ginx.Context) error { 64 | // 此时这里应该放着 RefreshToken 65 | rt := rsp.TokenCarrier.Extract(ctx) 66 | jwtClaims, err := rsp.m.VerifyAccessToken(rt) 67 | if err != nil { 68 | return err 69 | } 70 | claims := 
jwtClaims.Data 71 | accessToken, err := rsp.m.GenerateAccessToken(claims) 72 | rsp.TokenCarrier.Inject(ctx, accessToken) 73 | return err 74 | } 75 | 76 | // NewSession 的时候,要先把这个 data 写入到对应的 token 里面 77 | func (rsp *SessionProvider) NewSession(ctx *gctx.Context, 78 | uid int64, 79 | jwtData map[string]string, 80 | sessData map[string]any) (session.Session, error) { 81 | ssid := uuid.New().String() 82 | claims := session.Claims{Uid: uid, 83 | SSID: ssid, 84 | Expiration: time.Now().Add(rsp.expiration).UnixMilli(), 85 | Data: jwtData} 86 | accessToken, err := rsp.m.GenerateAccessToken(claims) 87 | if err != nil { 88 | return nil, err 89 | } 90 | rsp.TokenCarrier.Inject(ctx, accessToken) 91 | res := newRedisSession(ssid, rsp.expiration, rsp.client, claims) 92 | if sessData == nil { 93 | sessData = make(map[string]any, 1) 94 | } 95 | sessData["uid"] = uid 96 | err = res.init(ctx, sessData) 97 | return res, err 98 | } 99 | 100 | // Get 返回 Session,如果没有拿到 session 或者 session 已经过期,会返回 error 101 | func (rsp *SessionProvider) Get(ctx *gctx.Context) (session.Session, error) { 102 | val, _ := ctx.Get(session.CtxSessionKey) 103 | // 对接口断言,而不是对实现断言 104 | res, ok := val.(session.Session) 105 | if ok { 106 | return res, nil 107 | } 108 | token := rsp.TokenCarrier.Extract(ctx) 109 | claims, err := rsp.m.VerifyAccessToken(token) 110 | if err != nil { 111 | return nil, err 112 | } 113 | res = newRedisSession(claims.Data.SSID, rsp.expiration, rsp.client, claims.Data) 114 | return res, nil 115 | } 116 | 117 | // NewSessionProvider 用于管理 Session 118 | func NewSessionProvider(client redis.Cmdable, jwtKey string, 119 | expiration time.Duration) *SessionProvider { 120 | // 长 token 过期时间,被看做是 Session 的过期时间 121 | m := ijwt.NewManagement[session.Claims](ijwt.NewOptions(expiration, jwtKey)) 122 | return &SessionProvider{ 123 | client: client, 124 | TokenCarrier: header.NewTokenCarrier(), 125 | m: m, 126 | expiration: expiration, 127 | } 128 | } 129 | 
-------------------------------------------------------------------------------- /session/redis/provider_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | //go:build e2e 16 | 17 | package redis 18 | 19 | import ( 20 | "net/http/httptest" 21 | "testing" 22 | "time" 23 | 24 | "github.com/ecodeclub/ginx/internal/e2e" 25 | "github.com/stretchr/testify/suite" 26 | 27 | "github.com/ecodeclub/ginx/gctx" 28 | "github.com/ecodeclub/ginx/internal/mocks" 29 | "github.com/ecodeclub/ginx/session" 30 | "github.com/gin-gonic/gin" 31 | "github.com/redis/go-redis/v9" 32 | "github.com/stretchr/testify/assert" 33 | "github.com/stretchr/testify/require" 34 | "go.uber.org/mock/gomock" 35 | ) 36 | 37 | type ProviderTestSuite struct { 38 | e2e.BaseSuite 39 | } 40 | 41 | func TestSessionProvider_UpdateClaims(t *testing.T) { 42 | testCases := []struct { 43 | name string 44 | mock func(ctrl *gomock.Controller) redis.Cmdable 45 | wantErr error 46 | }{ 47 | { 48 | name: "更新成功", 49 | mock: func(ctrl *gomock.Controller) redis.Cmdable { 50 | cmd := mocks.NewMockCmdable(ctrl) 51 | pip := mocks.NewMockPipeliner(ctrl) 52 | pip.EXPECT().HMSet(gomock.Any(), gomock.Any(), gomock.Any()). 
53 | AnyTimes().Return(nil) 54 | pip.EXPECT().Expire(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) 55 | pip.EXPECT().Exec(gomock.Any()).Return(nil, nil) 56 | cmd.EXPECT().Pipeline().Return(pip) 57 | return cmd 58 | }, 59 | }, 60 | } 61 | for _, tc := range testCases { 62 | t.Run(tc.name, func(t *testing.T) { 63 | ctrl := gomock.NewController(t) 64 | defer ctrl.Finish() 65 | client := tc.mock(ctrl) 66 | sp := NewSessionProvider(client, "123", time.Minute) 67 | recorder := httptest.NewRecorder() 68 | 69 | ctx, _ := gin.CreateTestContext(recorder) 70 | // 先创建一个 71 | _, err := sp.NewSession(&gctx.Context{ 72 | Context: ctx, 73 | }, 123, map[string]string{"hello": "world"}, map[string]any{}) 74 | 75 | gtx := &gctx.Context{ 76 | Context: ctx, 77 | } 78 | newCl := session.Claims{ 79 | Uid: 234, 80 | SSID: "ssid_123", 81 | Expiration: 123, 82 | Data: map[string]string{"hello": "nihao"}} 83 | 84 | err = sp.UpdateClaims(gtx, newCl) 85 | assert.Equal(t, tc.wantErr, err) 86 | if err != nil { 87 | return 88 | } 89 | token := ctx.Writer.Header().Get("X-Access-Token") 90 | rc, err := sp.m.VerifyAccessToken(token) 91 | require.NoError(t, err) 92 | cl := rc.Data 93 | assert.Equal(t, newCl, cl) 94 | }) 95 | } 96 | } 97 | 98 | func TestProvider(t *testing.T) { 99 | suite.Run(t, new(ProviderTestSuite)) 100 | } 101 | 102 | // 历史测试,后面考虑删了 103 | func TestSessionProvider_NewSession(t *testing.T) { 104 | testCases := []struct { 105 | name string 106 | mock func(ctrl *gomock.Controller) redis.Cmdable 107 | key string 108 | wantErr error 109 | }{ 110 | { 111 | name: "成功", 112 | mock: func(ctrl *gomock.Controller) redis.Cmdable { 113 | cmd := mocks.NewMockCmdable(ctrl) 114 | pip := mocks.NewMockPipeliner(ctrl) 115 | pip.EXPECT().HMSet(gomock.Any(), gomock.Any(), gomock.Any()). 
116 | AnyTimes().Return(nil) 117 | pip.EXPECT().Expire(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) 118 | pip.EXPECT().Exec(gomock.Any()).Return(nil, nil) 119 | cmd.EXPECT().Pipeline().Return(pip) 120 | return cmd 121 | }, 122 | key: "key1", 123 | }, 124 | } 125 | for _, tc := range testCases { 126 | t.Run(tc.name, func(t *testing.T) { 127 | ctrl := gomock.NewController(t) 128 | defer ctrl.Finish() 129 | client := tc.mock(ctrl) 130 | sp := NewSessionProvider(client, tc.key, time.Minute) 131 | recorder := httptest.NewRecorder() 132 | ctx, _ := gin.CreateTestContext(recorder) 133 | sess, err := sp.NewSession(&gctx.Context{ 134 | Context: ctx, 135 | }, 123, map[string]string{"hello": "world"}, map[string]any{}) 136 | assert.Equal(t, tc.wantErr, err) 137 | if err != nil { 138 | return 139 | } 140 | rs, ok := sess.(*Session) 141 | require.True(t, ok) 142 | cl := rs.Claims() 143 | assert.True(t, len(cl.SSID) > 0) 144 | cl.SSID = "" 145 | assert.Greater(t, cl.Expiration, int64(0)) 146 | cl.Expiration = 0 147 | assert.Equal(t, session.Claims{ 148 | Uid: 123, 149 | Data: map[string]string{"hello": "world"}, 150 | }, cl) 151 | }) 152 | } 153 | } 154 | -------------------------------------------------------------------------------- /session/redis/session.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package redis 16 | 17 | import ( 18 | "context" 19 | "time" 20 | 21 | "github.com/ecodeclub/ekit" 22 | "github.com/ecodeclub/ginx/session" 23 | "github.com/redis/go-redis/v9" 24 | ) 25 | 26 | var _ session.Session = &Session{} 27 | 28 | // Session 生命周期应该和 http 请求保持一致 29 | type Session struct { 30 | client redis.Cmdable 31 | // key 是 ssid 拼接而成。注意,它不是 access token 32 | key string 33 | claims session.Claims 34 | expiration time.Duration 35 | } 36 | 37 | func (sess *Session) Destroy(ctx context.Context) error { 38 | return sess.client.Del(ctx, sess.key).Err() 39 | } 40 | 41 | func (sess *Session) Del(ctx context.Context, key string) error { 42 | return sess.client.Del(ctx, sess.key, key).Err() 43 | } 44 | 45 | func (sess *Session) Set(ctx context.Context, key string, val any) error { 46 | return sess.client.HSet(ctx, sess.key, key, val).Err() 47 | } 48 | 49 | func (sess *Session) init(ctx context.Context, kvs map[string]any) error { 50 | pip := sess.client.Pipeline() 51 | for k, v := range kvs { 52 | pip.HMSet(ctx, sess.key, k, v) 53 | } 54 | pip.Expire(ctx, sess.key, sess.expiration) 55 | _, err := pip.Exec(ctx) 56 | return err 57 | } 58 | 59 | func (sess *Session) Get(ctx context.Context, key string) ekit.AnyValue { 60 | res, err := sess.client.HGet(ctx, sess.key, key).Result() 61 | if err != nil { 62 | return ekit.AnyValue{Err: err} 63 | } 64 | return ekit.AnyValue{ 65 | Val: res, 66 | } 67 | } 68 | 69 | func (sess *Session) Claims() session.Claims { 70 | return sess.claims 71 | } 72 | 73 | func newRedisSession( 74 | ssid string, 75 | expiration time.Duration, 76 | client redis.Cmdable, cl session.Claims) *Session { 77 | return &Session{ 78 | client: client, 79 | key: "session:" + ssid, 80 | expiration: expiration, 81 | claims: cl, 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /session/redis/session_test.go: -------------------------------------------------------------------------------- 1 | // 
Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | //go:build e2e 16 | 17 | package redis 18 | 19 | import ( 20 | "context" 21 | "testing" 22 | "time" 23 | 24 | "github.com/ecodeclub/ginx/internal/e2e" 25 | "github.com/ecodeclub/ginx/session" 26 | "github.com/stretchr/testify/assert" 27 | "github.com/stretchr/testify/require" 28 | "github.com/stretchr/testify/suite" 29 | ) 30 | 31 | type SessionE2ETestSuite struct { 32 | e2e.BaseSuite 33 | } 34 | 35 | func (s *SessionE2ETestSuite) TestGetSetDel() { 36 | ssid := "test_ssid" 37 | sess := newRedisSession(ssid, time.Minute, s.RDB, session.Claims{ 38 | Uid: 123, 39 | SSID: ssid, 40 | Data: map[string]string{ 41 | "key1": "value1", 42 | }, 43 | }) 44 | ctx, cancel := context.WithTimeout(context.Background(), time.Second) 45 | defer cancel() 46 | defer sess.Destroy(ctx) 47 | ssKey1, ssVal1 := "ss_key1", "ss_val1" 48 | err := sess.Set(ctx, ssKey1, ssVal1) 49 | require.NoError(s.T(), err) 50 | val, err := sess.Get(ctx, ssKey1).AsString() 51 | require.NoError(s.T(), err) 52 | assert.Equal(s.T(), ssVal1, val) 53 | } 54 | 55 | func TestSession(t *testing.T) { 56 | suite.Run(t, new(SessionE2ETestSuite)) 57 | } 58 | -------------------------------------------------------------------------------- /session/types.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed 
under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package session 16 | 17 | import ( 18 | "context" 19 | 20 | "github.com/ecodeclub/ekit" 21 | "github.com/ecodeclub/ginx/gctx" 22 | "github.com/ecodeclub/ginx/internal/errs" 23 | ) 24 | 25 | // Session 混合了 JWT 的设计。 26 | type Session interface { 27 | // Set 将数据写入到 Session 里面 28 | Set(ctx context.Context, key string, val any) error 29 | // Get 从 Session 中获取数据,注意,这个方法不会从 JWT 里面获取数据 30 | Get(ctx context.Context, key string) ekit.AnyValue 31 | // Del 删除对应的数据 32 | Del(ctx context.Context, key string) error 33 | // Destroy 销毁整个 Session 34 | Destroy(ctx context.Context) error 35 | // Claims 编码进去了 JWT 里面的数据 36 | Claims() Claims 37 | } 38 | 39 | // Provider 定义了 Session 的整个管理机制。 40 | // 所有的 Session 都必须支持 jwt 41 | // 42 | //go:generate mockgen -source=./types.go -destination=./provider.mock_test.go -package=session Provider 43 | type Provider interface { 44 | // NewSession 将会初始化 Session 45 | // 其中 jwtData 将编码进去 jwt 中 46 | // sessData 将被放进去 Session 中 47 | NewSession(ctx *gctx.Context, uid int64, jwtData map[string]string, 48 | sessData map[string]any) (Session, error) 49 | // Get 尝试拿到 Session,如果没有,返回 error 50 | // Get 必须校验 Session 的合法性。 51 | // 也就是,用户可以预期拿到的 Session 永远是没有过期,直接可用的 52 | Get(ctx *gctx.Context) (Session, error) 53 | 54 | // Destroy 销毁 Session,一般用在退出登录这种地方 55 | Destroy(ctx *gctx.Context) error 56 | 57 | // UpdateClaims 修改 claims 的数据 58 | // 但是因为 jwt 本身是不可变的,所以实际上这里是重新生成了一个 jwt 的 token 
59 | // 必须传入正确的 SSID 60 | UpdateClaims(ctx *gctx.Context, claims Claims) error 61 | 62 | // RenewAccessToken 刷新并且返回一个新的 access token 63 | // 注意,必须是之前的 AccessToken 快要过期但是还没过期的时候 64 | RenewAccessToken(ctx *gctx.Context) error 65 | } 66 | 67 | type Claims struct { 68 | Uid int64 69 | SSID string 70 | Data map[string]string 71 | // 过期时间。毫秒数 72 | Expiration int64 73 | } 74 | 75 | func (c Claims) Get(key string) ekit.AnyValue { 76 | val, ok := c.Data[key] 77 | if !ok { 78 | return ekit.AnyValue{Err: errs.ErrSessionKeyNotFound} 79 | } 80 | return ekit.AnyValue{Val: val} 81 | } 82 | 83 | // TokenCarrier 用于决定是使用 Header 还是使用 Cookie 来作为 84 | type TokenCarrier interface { 85 | Inject(ctx *gctx.Context, value string) 86 | Extract(ctx *gctx.Context) string 87 | Clear(ctx *gctx.Context) 88 | } 89 | -------------------------------------------------------------------------------- /types.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package ginx 16 | 17 | import ( 18 | "github.com/ecodeclub/ginx/gctx" 19 | "github.com/gin-gonic/gin" 20 | ) 21 | 22 | type Handler interface { 23 | PrivateRoutes(server *gin.Engine) 24 | PublicRoutes(server *gin.Engine) 25 | } 26 | 27 | type Result struct { 28 | Code int `json:"code"` 29 | Msg string `json:"msg"` 30 | Data any `json:"data"` 31 | } 32 | 33 | type Context = gctx.Context 34 | -------------------------------------------------------------------------------- /wrapper_func.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package ginx 16 | 17 | import ( 18 | "errors" 19 | "log/slog" 20 | "net/http" 21 | 22 | "github.com/ecodeclub/ginx/session" 23 | "github.com/gin-gonic/gin" 24 | ) 25 | 26 | func W(fn func(ctx *Context) (Result, error)) gin.HandlerFunc { 27 | return func(ctx *gin.Context) { 28 | res, err := fn(&Context{Context: ctx}) 29 | if errors.Is(err, ErrNoResponse) { 30 | slog.Debug("不需要响应", slog.Any("err", err)) 31 | return 32 | } 33 | if errors.Is(err, ErrUnauthorized) { 34 | slog.Debug("未授权", slog.Any("err", err)) 35 | ctx.AbortWithStatus(http.StatusUnauthorized) 36 | return 37 | } 38 | if err != nil { 39 | slog.Error("执行业务逻辑失败", slog.Any("err", err)) 40 | ctx.PureJSON(http.StatusInternalServerError, res) 41 | return 42 | } 43 | ctx.PureJSON(http.StatusOK, res) 44 | } 45 | } 46 | 47 | func B[Req any](fn func(ctx *Context, req Req) (Result, error)) gin.HandlerFunc { 48 | return func(ctx *gin.Context) { 49 | var req Req 50 | if err := ctx.Bind(&req); err != nil { 51 | slog.Debug("绑定参数失败", slog.Any("err", err)) 52 | return 53 | } 54 | res, err := fn(&Context{Context: ctx}, req) 55 | if errors.Is(err, ErrNoResponse) { 56 | slog.Debug("不需要响应", slog.Any("err", err)) 57 | return 58 | } 59 | if errors.Is(err, ErrUnauthorized) { 60 | slog.Debug("未授权", slog.Any("err", err)) 61 | ctx.AbortWithStatus(http.StatusUnauthorized) 62 | return 63 | } 64 | if err != nil { 65 | slog.Error("执行业务逻辑失败", slog.Any("err", err)) 66 | ctx.PureJSON(http.StatusInternalServerError, res) 67 | return 68 | } 69 | ctx.PureJSON(http.StatusOK, res) 70 | } 71 | } 72 | 73 | // BS 的意思是,传入的业务逻辑方法可以接受 req 和 sess 两个参数 74 | func BS[Req any](fn func(ctx *Context, req Req, sess session.Session) (Result, error)) gin.HandlerFunc { 75 | return func(ctx *gin.Context) { 76 | gtx := &Context{Context: ctx} 77 | sess, err := session.Get(gtx) 78 | if err != nil { 79 | slog.Debug("获取 Session 失败", slog.Any("err", err)) 80 | ctx.AbortWithStatus(http.StatusUnauthorized) 81 | return 82 | } 83 | var req Req 84 | // Bind 
方法本身会返回 400 的错误 85 | if err := ctx.Bind(&req); err != nil { 86 | slog.Debug("绑定参数失败", slog.Any("err", err)) 87 | return 88 | } 89 | res, err := fn(gtx, req, sess) 90 | if errors.Is(err, ErrNoResponse) { 91 | slog.Debug("不需要响应", slog.Any("err", err)) 92 | return 93 | } 94 | // 如果里面有权限校验,那么会返回 401 错误(目前来看,主要是登录态校验) 95 | if errors.Is(err, ErrUnauthorized) { 96 | slog.Debug("未授权", slog.Any("err", err)) 97 | ctx.AbortWithStatus(http.StatusUnauthorized) 98 | return 99 | } 100 | if err != nil { 101 | slog.Error("执行业务逻辑失败", slog.Any("err", err)) 102 | ctx.PureJSON(http.StatusInternalServerError, res) 103 | return 104 | } 105 | ctx.PureJSON(http.StatusOK, res) 106 | } 107 | } 108 | 109 | // S 的意思是,传入的业务逻辑方法可以接受 Session 参数 110 | func S(fn func(ctx *Context, sess session.Session) (Result, error)) gin.HandlerFunc { 111 | return func(ctx *gin.Context) { 112 | gtx := &Context{Context: ctx} 113 | sess, err := session.Get(gtx) 114 | if err != nil { 115 | slog.Debug("获取 Session 失败", slog.Any("err", err)) 116 | ctx.AbortWithStatus(http.StatusUnauthorized) 117 | return 118 | } 119 | res, err := fn(gtx, sess) 120 | if errors.Is(err, ErrNoResponse) { 121 | slog.Debug("不需要响应", slog.Any("err", err)) 122 | return 123 | } 124 | // 如果里面有权限校验,那么会返回 401 错误(目前来看,主要是登录态校验) 125 | if errors.Is(err, ErrUnauthorized) { 126 | slog.Debug("未授权", slog.Any("err", err)) 127 | ctx.AbortWithStatus(http.StatusUnauthorized) 128 | return 129 | } 130 | if err != nil { 131 | slog.Error("执行业务逻辑失败", slog.Any("err", err)) 132 | ctx.PureJSON(http.StatusInternalServerError, res) 133 | return 134 | } 135 | ctx.PureJSON(http.StatusOK, res) 136 | } 137 | } 138 | --------------------------------------------------------------------------------