├── .CHANGELOG.md ├── .LICENSE_FILE_HEADER ├── .deepsource.toml ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── feature_request.md │ └── question.md ├── pre-commit ├── pre-push └── workflows │ ├── changelog.yml │ ├── go-fmt.yml │ ├── go.yml │ ├── golangci-lint.yml │ ├── integration_test.yml │ ├── license.yml │ └── stale.yml ├── .gitignore ├── .licenserc.json ├── LICENSE ├── Makefile ├── README.md ├── aggregate.go ├── aggregate_test.go ├── assignment.go ├── assignment_test.go ├── builder.go ├── builder_test.go ├── column.go ├── column_test.go ├── core.go ├── db.go ├── db_test.go ├── delete.go ├── delete_test.go ├── error.go ├── expression.go ├── expression_test.go ├── go.mod ├── go.sum ├── insert.go ├── insert_builder.go ├── insert_test.go ├── internal ├── datasource │ ├── cluster │ │ ├── cluster_db.go │ │ └── cluster_db_test.go │ ├── masterslave │ │ ├── master_slave_db.go │ │ ├── master_slave_db_test.go │ │ └── slaves │ │ │ ├── dns │ │ │ ├── dns.go │ │ │ ├── dns_test.go │ │ │ └── mysql │ │ │ │ ├── dsn.go │ │ │ │ └── dsn_test.go │ │ │ ├── roundrobin │ │ │ ├── roundrobin.go │ │ │ └── roundrobin_test.go │ │ │ └── type.go │ ├── shardingsource │ │ ├── sharding_datasource.go │ │ └── sharding_datasource_test.go │ ├── single │ │ ├── db.go │ │ └── db_test.go │ ├── transaction │ │ ├── delay_transaction.go │ │ ├── delay_transaction_test.go │ │ ├── single_transaction.go │ │ ├── single_transaction_test.go │ │ ├── transaction.go │ │ ├── transaction_suite_test.go │ │ ├── transaction_test.go │ │ └── types.go │ └── types.go ├── dialect │ ├── dialect.go │ └── dialect_test.go ├── errs │ └── error.go ├── integration │ ├── base_test.go │ ├── delete_composition_test.go │ ├── delete_masterslave_test.go │ ├── delete_test.go │ ├── insert_combination_test.go │ ├── insert_masterslave_test.go │ ├── insert_test.go │ ├── select_combination_test.go │ ├── select_masterslave_test.go │ ├── select_test.go │ ├── sharding_delay_transaction_test.go │ ├── sharding_insert_test.go │ ├── 
sharding_select_test.go │ ├── sharding_single_transaction_test.go │ ├── sharding_suite_test.go │ ├── sharding_update_test.go │ ├── update_combination_test.go │ ├── update_masterslave_test.go │ └── update_test.go ├── merger │ ├── factory │ │ ├── factory.go │ │ └── factory_test.go │ ├── internal │ │ ├── aggregatemerger │ │ │ ├── aggregator │ │ │ │ ├── avg.go │ │ │ │ ├── avg_test.go │ │ │ │ ├── count.go │ │ │ │ ├── count_test.go │ │ │ │ ├── max.go │ │ │ │ ├── max_test.go │ │ │ │ ├── min.go │ │ │ │ ├── min_test.go │ │ │ │ ├── sum.go │ │ │ │ ├── sum_test.go │ │ │ │ └── type.go │ │ │ ├── merger.go │ │ │ └── merger_test.go │ │ ├── batchmerger │ │ │ ├── merger.go │ │ │ └── merger_test.go │ │ ├── distinctmerger │ │ │ ├── merger.go │ │ │ └── merger_test.go │ │ ├── errs │ │ │ └── error.go │ │ ├── groupbymerger │ │ │ ├── aggregator_merger.go │ │ │ └── aggregator_merger_test.go │ │ ├── pagedmerger │ │ │ ├── merger.go │ │ │ └── merger_test.go │ │ └── sortmerger │ │ │ ├── heap │ │ │ ├── heap.go │ │ │ └── heap_test.go │ │ │ ├── merger.go │ │ │ └── merger_test.go │ ├── type.go │ └── type_test.go ├── model │ ├── model.go │ └── model_test.go ├── operator │ └── operator.go ├── query │ └── query.go ├── rows │ ├── convert_assign.go │ ├── convert_assign_test.go │ ├── data_rows.go │ ├── data_rows_test.go │ └── types.go ├── sharding │ ├── compare.go │ ├── hash │ │ ├── hash.go │ │ └── shadow_hash.go │ ├── result.go │ └── types.go ├── test │ ├── types.go │ └── types_test.go └── valuer │ ├── primitive.go │ ├── primitive_test.go │ ├── reflect.go │ ├── reflect_test.go │ ├── unsafe.go │ ├── unsafe_test.go │ └── value.go ├── middleware.go ├── middleware └── querylog │ ├── querylog.go │ └── querylog_test.go ├── middleware_test.go ├── predicate.go ├── predicate_test.go ├── result.go ├── result_test.go ├── script ├── fmt.sh ├── integrate_test.sh ├── integration_test_compose.yml ├── mysql │ ├── init.sql │ ├── master │ │ ├── init.sql │ │ └── master.sh │ └── slave │ │ ├── init.sql │ │ └── slave.sh └── 
setup.sh ├── select.go ├── select_builder.go ├── select_test.go ├── session.go ├── sharding_builder.go ├── sharding_insert.go ├── sharding_insert_test.go ├── sharding_select.go ├── sharding_select_test.go ├── sharding_update.go ├── sharding_update_test.go ├── table.go ├── transaction.go ├── transaction_test.go ├── types.go ├── update.go ├── update_builder.go └── update_test.go /.LICENSE_FILE_HEADER: -------------------------------------------------------------------------------- 1 | Copyright 2021 ecodeclub 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. -------------------------------------------------------------------------------- /.deepsource.toml: -------------------------------------------------------------------------------- 1 | # Copyright 2021 ecodeclub 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | version = 1 16 | 17 | [[analyzers]] 18 | name = "go" 19 | enabled = true 20 | 21 | [analyzers.meta] 22 | import_root = "github.com/ecodeclub/eorm" 23 | dependencies_vendored = false 24 | cyclomatic_complexity_threshold = "high" 25 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | **仅限中文** 11 | 12 | 在提交 issue 之前请先查找[已有 issues](https://github.com/ecodeclub/eorm/issues),避免重复上报。 13 | 14 | 并且确保自己已经: 15 | - [ ] 阅读过文档 16 | - [ ] 阅读过注释 17 | - [ ] 阅读过例子 18 | 19 | ### 问题简要描述 20 | 21 | ### 复现步骤 22 | > 如果涉及到数据库表,你必须提供模型定义和表结构定义 23 | 24 | ### 错误日志或者截图 25 | 26 | ### 你期望的结果 27 | 28 | ### 你排查的结果,或者你觉得可行的修复方案 29 | > 可选。我们希望你能够尽量先排查问题,帮助我们减轻维护负担。这对于你个人能力提升同样是有帮助的。 30 | 31 | ### 你使用的是 eorm 哪个版本? 32 | 33 | ### 你设置的 Go 环境? 34 | > 上传 `go env` 的结果 -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: feature 6 | assignees: '' 7 | 8 | --- 9 | 10 | **仅限中文** 11 | 12 | ### 使用场景 13 | 14 | ### 行业分析 15 | > 如果你知道有框架提供了类似功能,可以在这里描述,并且给出文档或者例子 16 | 17 | ### 可行方案 18 | > 如果你有设计思路或者解决方案,请在这里提供。你可以提供多个方案,并且给出自己的选择 19 | 20 | ### 其它 21 | > 任何你觉得有利于解决问题的补充说明 22 | 23 | ### 你使用的是 eorm 哪个版本? 24 | 25 | ### 你设置的 Go 环境? 
26 | > 上传 `go env` 的结果 27 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/question.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Question 3 | about: Want to ask some questions 4 | title: '' 5 | labels: question 6 | --- 7 | 8 | **仅限中文** 9 | 10 | 在提交 issue 之前请先查找[已有 issues](https://github.com/ecodeclub/eorm/issues),避免重复上报。 11 | 12 | 并且确保自己已经: 13 | - [ ] 阅读过文档 14 | - [ ] 阅读过注释 15 | - [ ] 阅读过例子 16 | 17 | ### 你的问题 18 | 19 | ### 你使用的是 eorm 哪个版本? 20 | 21 | ### 你设置的 Go 环境? 22 | > 上传 `go env` 的结果 23 | -------------------------------------------------------------------------------- /.github/pre-commit: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # Copyright 2021 ecodeclub 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | # To use, store as .git/hooks/pre-commit inside your repository and make sure 17 | # it has execute permissions. 18 | # 19 | # This script does not handle file names that contain spaces. 
20 | 21 | # Pre-commit configuration 22 | 23 | printf "执行检查中...\n" 24 | # Announce the check before it runs, and abort if "make check" itself fails; 25 | # otherwise this hook treats any stdout from "make check" as a sign that files were modified. 26 | if ! RESULT=$(make check); then 27 | echo >&2 "[ERROR]: make check 执行失败" 28 | exit 1 29 | fi 30 | 31 | if [ -n "$RESULT" ]; then 32 | echo >&2 "[ERROR]: 有文件发生变更,请将变更文件添加到本次提交中" 33 | exit 1 34 | fi 35 | 36 | exit 0 37 | -------------------------------------------------------------------------------- /.github/pre-push: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # Copyright 2021 ecodeclub 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # git test pre-push hook 16 | # 17 | # To use, store as .git/hooks/pre-push inside your repository and make sure 18 | # it has execute permissions. 19 | # 20 | # This script does not handle file names that contain spaces. 21 | 22 | # Pre-push configuration 23 | remote=$1 24 | url=$2 25 | echo >&2 "Try pushing $url to $remote" 26 | 27 | TEST="go test ./... -race -cover -failfast" 28 | LINTER="golangci-lint run" 29 | 30 | # Run test and return if failed 31 | printf "Running go test...\n" 32 | $TEST 33 | RESULT=$? 
34 | if [ "$RESULT" -ne 0 ]; then 35 | echo >&2 "$TEST" 36 | echo >&2 "Check code to pass test." 37 | exit 1 38 | fi 39 | 40 | # Run linter and return if failed 41 | printf "Running go linter...\n" 42 | $LINTER 43 | RESULT=$? 44 | if [ "$RESULT" -ne 0 ]; then 45 | echo >&2 "$LINTER" 46 | echo >&2 "Check code to pass linter." 
47 | exit 1 48 | fi 49 | 50 | exit 0 -------------------------------------------------------------------------------- /.github/workflows/changelog.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2021 ecodeclub 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | name: changelog 16 | 17 | on: 18 | pull_request: 19 | types: [opened, synchronize, reopened, labeled, unlabeled] 20 | branches: 21 | - dev 22 | - main 23 | 24 | jobs: 25 | changelog: 26 | runs-on: ubuntu-latest 27 | if: "!contains(github.event.pull_request.labels.*.name, 'Skip Changelog')" 28 | 29 | steps: 30 | - uses: actions/checkout@v2 31 | 32 | - name: Check for CHANGELOG changes 33 | run: | 34 | # Only the latest commit of the feature branch is available 35 | # automatically. To diff with the base branch, we need to 36 | # fetch that too (and we only need its latest commit). 37 | git fetch origin ${{ github.base_ref }} --depth=1 38 | if [[ $(git diff --name-only FETCH_HEAD | grep CHANGELOG) ]] 39 | then 40 | echo "A CHANGELOG was modified. Looks good!" 41 | else 42 | echo "No CHANGELOG was modified." 43 | echo "Please add a CHANGELOG entry, or add the \"Skip Changelog\" label if not required." 
44 | false 45 | fi -------------------------------------------------------------------------------- /.github/workflows/go-fmt.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2021 ecodeclub 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | name: Format Go code 16 | 17 | on: 18 | push: 19 | branches: [ main, dev] 20 | pull_request: 21 | branches: [ main, dev] 22 | 23 | jobs: 24 | build: 25 | runs-on: ubuntu-latest 26 | steps: 27 | - uses: actions/checkout@v3 28 | - name: Set up Go 29 | uses: actions/setup-go@v3 30 | with: 31 | go-version: ">=1.20.0" 32 | 33 | - name: Install goimports 34 | run: go install golang.org/x/tools/cmd/goimports@latest 35 | 36 | - name: Check 37 | run: | 38 | make check 39 | if [ -n "$(git status --porcelain)" ]; then 40 | echo >&2 "错误: 请在本地运行命令'make check'后再提交." 41 | exit 1 42 | fi -------------------------------------------------------------------------------- /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2021 ecodeclub 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | name: Go 16 | 17 | on: 18 | push: 19 | branches: [ main, dev] 20 | pull_request: 21 | branches: [ main, dev] 22 | 23 | jobs: 24 | build: 25 | runs-on: ubuntu-latest 26 | steps: 27 | - uses: actions/checkout@v2 28 | - name: Set up Go 29 | uses: actions/setup-go@v2 30 | with: 31 | go-version: '1.20' 32 | 33 | - name: Build 34 | run: go build -v ./... 35 | 36 | - name: Test 37 | run: go test -race -coverprofile=cover.out -v ./... 38 | 39 | - name: Post Coverage 40 | uses: codecov/codecov-action@v2 -------------------------------------------------------------------------------- /.github/workflows/golangci-lint.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2021 ecodeclub 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | name: golangci-lint 16 | on: 17 | push: 18 | tags: 19 | - v* 20 | branches: 21 | - master 22 | - main 23 | - dev 24 | pull_request: 25 | permissions: 26 | contents: read 27 | # Optional: allow read access to pull request. Use with `only-new-issues` option. 28 | # pull-requests: read 29 | jobs: 30 | golangci: 31 | name: lint 32 | runs-on: ubuntu-latest 33 | steps: 34 | - uses: actions/setup-go@v3 35 | with: 36 | go-version: '1.20' 37 | - uses: actions/checkout@v3 38 | - name: golangci-lint 39 | uses: golangci/golangci-lint-action@v3 40 | with: 41 | # Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version 42 | version: "v1.52.2" 43 | 44 | # Optional: working directory, useful for monorepos 45 | # working-directory: somedir 46 | 47 | # Optional: golangci-lint command line arguments. 48 | args: --timeout=10m # --issues-exit-code=0 49 | 50 | # Optional: show only new issues if it's a pull request. The default value is `false`. 51 | only-new-issues: true 52 | 53 | # Optional: if set to true then the all caching functionality will be complete disabled, 54 | # takes precedence over all other caching options. 55 | # skip-cache: true 56 | 57 | # Optional: if set to true then the action don't cache or restore ~/go/pkg. 58 | # skip-pkg-cache: true 59 | 60 | # Optional: if set to true then the action don't cache or restore ~/.cache/go-build. 61 | # skip-build-cache: true -------------------------------------------------------------------------------- /.github/workflows/integration_test.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2021 ecodeclub 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | name: Integration Test 16 | 17 | on: 18 | push: 19 | branches: [ main, dev] 20 | pull_request: 21 | branches: [ main, dev] 22 | 23 | jobs: 24 | build: 25 | runs-on: ubuntu-latest 26 | steps: 27 | - uses: actions/checkout@v2 28 | - name: Set up Go 29 | uses: actions/setup-go@v2 30 | with: 31 | go-version: '1.20' 32 | 33 | - name: Test 34 | run: sudo sh ./script/integrate_test.sh -------------------------------------------------------------------------------- /.github/workflows/license.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2021 ecodeclub 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | name: Check License Lines 16 | on: 17 | pull_request: 18 | types: [opened, synchronize, reopened, labeled, unlabeled] 19 | branches: 20 | - dev 21 | - main 22 | jobs: 23 | check-license-lines: 24 | runs-on: ubuntu-latest 25 | steps: 26 | - uses: actions/checkout@master 27 | - name: Check License Lines 28 | uses: kt3k/license_checker@v1.0.6 -------------------------------------------------------------------------------- /.github/workflows/stale.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2021 ecodeclub 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | name: Mark stale issues and pull requests 16 | 17 | on: 18 | schedule: 19 | - cron: "30 1 * * *" 20 | 21 | jobs: 22 | stale: 23 | 24 | runs-on: ubuntu-latest 25 | 26 | steps: 27 | - uses: actions/stale@v4 28 | with: 29 | repo-token: ${{ secrets.GITHUB_TOKEN }} 30 | stale-issue-message: 'This issue is inactive for a long time.' 
31 | stale-pr-message: 'This PR is inactive for a long time' 32 | stale-issue-label: 'inactive-issue' 33 | stale-pr-label: 'inactive-pr' 34 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.exe~ 4 | *.dll 5 | *.so 6 | *.dylib 7 | 8 | # Test binary, built with `go test -c` 9 | *.test 10 | 11 | # Output of the go coverage tool, specifically when used with LiteIDE 12 | *.out 13 | 14 | # Dependency directories (remove the comment below to include it) 15 | # vendor/ 16 | 17 | .idea 18 | *.iml 19 | fuzz 20 | go.work -------------------------------------------------------------------------------- /.licenserc.json: -------------------------------------------------------------------------------- 1 | { 2 | "**/*.go": "// Copyright 2021 ecodeclub", 3 | "**/*.{yml,toml}": "# Copyright 2021 ecodeclub" 4 | } -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # 单元测试 2 | .PHONY: ut 3 | ut: 4 | @go test -race ./... 
5 | 6 | .PHONY: setup 7 | setup: 8 | @sh ./script/setup.sh 9 | 10 | .PHONY: lint 11 | lint: 12 | golangci-lint run 13 | 14 | .PHONY: fmt 15 | fmt: 16 | @sh ./script/fmt.sh 17 | 18 | .PHONY: tidy 19 | tidy: 20 | @go mod tidy -v 21 | 22 | .PHONY: check 23 | check: 24 | @$(MAKE) --no-print-directory fmt 25 | @$(MAKE) --no-print-directory tidy 26 | 27 | # e2e 测试 28 | .PHONY: e2e 29 | e2e: 30 | sh ./script/integrate_test.sh 31 | 32 | .PHONY: e2e_up 33 | e2e_up: 34 | docker compose -f script/integration_test_compose.yml up -d 35 | 36 | .PHONY: e2e_down 37 | e2e_down: 38 | docker compose -f script/integration_test_compose.yml down -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EORM 2 | 3 | [![DeepSource](https://deepsource.io/gh/ecodeclub/eorm.svg/?label=active+issues&show_trend=true&token=pKgxd-FmZ5F3l0M2iXQRpBRy)](https://deepsource.io/gh/ecodeclub/eorm/?ref=repository-badge) 4 | 5 | 简单的 ORM 框架。 6 | 7 | ## 使用注意事项(技术选型前必读) 8 | 9 | 丑话说前头。 10 | 11 | - 目前来说,它还没有达到一个稳定的状态,在正式发布 v1.0.0 之前,你可以认为它都处于一种实验和学习状态。在这个状态下,我们并不会对 API 的稳定性做任何承诺。从这个角度来说,我们也不建议你将它用在一些核心项目里面 12 | 13 | ### Go 版本 14 | 15 | 请使用 Go 1.20 以上版本。 16 | 17 | ### SQL 2003 标准 18 | 理论上来说,我们计划支持 [SQL 2003 standard](https://ronsavage.github.io/SQL/sql-2003-2.bnf.html#query%20specification). 
不过据我们所知,并不是所有的数据库都支持全部的 SQL 2003 标准,所以用户还是需要进一步检查目标数据库的语法。 19 | 20 | ### 全中文仓库 21 | 22 | 这是一个全中文的仓库。这意味着注释、文档和错误信息,都会是中文。介意的用户可以选择别的 ORM 仓库。 23 | 24 | 但是不必来反馈说希望提供英文版本,我们是不会提供的,因为: 25 | - 这个仓库目前以及可预见的将来,都不会有外国人使用。很多国内开发者的开源仓库提供了英文选项,但是实际上英文用户数量感人,我就不想做这种性价比低的事情; 26 | - 双语言会导致英文用户有一些不切实际的期望,比如说在 issues 和 discussions 里面都使用英文交流。但是这对于我的目标来说,也是不现实的。因为很多国内开发者不具备这种英文能力,而且这些开发者非常倔强,即便我们一再强调请使用英文,他们也会固执地使用中文。因此我索性将这个搞成全中文仓库,一了百了; 27 | - 中文终究是我的母语,所以使用中文的表达能力更加准确; 28 | - 翻译软件非常发达,真有英语用户要用,可以自己找翻译软件; 29 | 30 | ### 社区组织和讨论 31 | 32 | 短时间内,我不会组建任何的微信群、QQ 群或者钉钉群之类的即时通讯群。 33 | 34 | 我注意到很多开源仓库都会组织类似的群,但是这种群有利有弊,而且对于项目本身来说是弊大于利的。组建了不同的群之后会导致问题讨论被切割到不同的群里面。例如某个群讨论了 A 问题,其它群完全看不到。 35 | 36 | 另外一个原因就是,因为即时通讯过于便捷,会给维护者带来庞大的维护压力。用户可能期望自己的所有答案都能从群里得到解答,因此不愿意自己花时间去读文档、读注释、读例子。在这种情况下,他们会频繁艾特维护者,并希望维护者能够实时给出详细回答。而实际情况是,一般的小项目维护者可能只有两三个人,所以没有足够的精力来维护这种群。 37 | 38 | 我想要的社区是大家统一在 GitHub 下,利用 issue 和 discussion 来讨论问题。这样别的用户都可以搜索到所有的讨论。 39 | 40 | ## 加入我们 41 | - [贡献者指南](https://doc.meoying.com/) 42 | -------------------------------------------------------------------------------- /aggregate.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package eorm 16 | 17 | // Aggregate represents aggregate expression, including AVG, MAX, MIN... 
18 | type Aggregate struct { 19 | table TableReference 20 | fn string 21 | arg string 22 | alias string 23 | distinct bool 24 | } 25 | 26 | // As returns a copy of the aggregate with the given alias. 27 | // The alias is normally expected to equal a column name, which is then mapped to 28 | // the corresponding field; e.g. for alias "avg_age" the value is received by the 29 | // AvgAge field by default. 30 | // Every other setting (function, argument, table and DISTINCT) is kept. 31 | func (a Aggregate) As(alias string) Selectable { 32 | // a is a value receiver (already a copy), so this does not mutate the caller's Aggregate. 33 | a.alias = alias 34 | return a 35 | } 36 | 37 | // Avg represents AVG 38 | func Avg(c string) Aggregate { 39 | return Aggregate{ 40 | fn: "AVG", 41 | arg: c, 42 | } 43 | } 44 | 45 | // Max represents MAX 46 | func Max(c string) Aggregate { 47 | return Aggregate{ 48 | fn: "MAX", 49 | arg: c, 50 | } 51 | } 52 | 53 | // Min represents MIN 54 | func Min(c string) Aggregate { 55 | return Aggregate{ 56 | fn: "MIN", 57 | arg: c, 58 | } 59 | } 60 | 61 | // Count represents COUNT 62 | func Count(c string) Aggregate { 63 | return Aggregate{ 64 | fn: "COUNT", 65 | arg: c, 66 | } 67 | } 68 | 69 | // Sum represents SUM 70 | func Sum(c string) Aggregate { 71 | return Aggregate{ 72 | fn: "SUM", 73 | arg: c, 74 | } 75 | } 76 | 77 | // CountDistinct represents COUNT(DISTINCT XXX) 78 | func CountDistinct(col string) Aggregate { 79 | a := Count(col) 80 | a.distinct = true 81 | return a 82 | } 83 | 84 | // AvgDistinct represents AVG(DISTINCT XXX) 85 | func AvgDistinct(col string) Aggregate { 86 | a := Avg(col) 87 | a.distinct = true 88 | return a 89 | } 90 | 91 | // SumDistinct represents SUM(DISTINCT XXX) 92 | func SumDistinct(col string) Aggregate { 93 | a := Sum(col) 94 | a.distinct = true 95 | return a 96 | } 97 | 98 | // selected is an unexported marker method; presumably it is what lets Aggregate be used as a Selectable (TODO confirm). 
99 | func (Aggregate) selected() {} 100 | 101 | // EQ = 102 | func (a Aggregate) EQ(val interface{}) Predicate { 103 | return Predicate{ 104 | left: a, 105 | op: opEQ, 106 | right: valueOf(val), 107 | } 108 | } 109 | 110 | // NEQ != 111 | func (a Aggregate) NEQ(val interface{}) Predicate { 112 | return Predicate{ 113 | left: a, 114 | op: opNEQ, 115 | right: valueOf(val), 116 | } 117 | } 118 | 119 | // LT < 120 | func (a Aggregate) LT(val interface{}) Predicate { 121 | return Predicate{ 122 | left: a, 123 | op: opLT, 124 | right: valueOf(val), 125 | } 126 | } 127 | 128 | // LTEQ <= 129 | func (a Aggregate) LTEQ(val interface{}) Predicate { 130 | return Predicate{ 131 | left: a, 132 | op: opLTEQ, 133 | right: valueOf(val), 134 | } 135 | } 136 | 137 | // GT > 138 | func (a Aggregate) GT(val interface{}) Predicate { 139 | return Predicate{ 140 | left: a, 141 | op: opGT, 142 | right: valueOf(val), 143 | } 144 | } 145 | 146 | // GTEQ >= 147 | func (a Aggregate) GTEQ(val interface{}) Predicate { 148 | return Predicate{ 149 | left: a, 150 | op: opGTEQ, 151 | right: valueOf(val), 152 | } 153 | } 154 | 155 | // expr returns an empty string; the SQL for an aggregate is presumably rendered by the builder instead (TODO confirm). 156 | func (Aggregate) expr() (string, error) { 157 | return "", nil 158 | } 159 | 
-------------------------------------------------------------------------------- /aggregate_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License.
14 | 15 | package eorm 16 | 17 | import ( 18 | "fmt" 19 | "testing" 20 | 21 | "github.com/stretchr/testify/assert" 22 | ) 23 | 24 | func TestAggregate(t *testing.T) { 25 | db := memoryDB() 26 | testCases := []CommonTestCase{ 27 | { 28 | name: "avg", 29 | builder: NewSelector[TestModel](db).Select(Avg("Age")), 30 | wantSql: "SELECT AVG(`age`) FROM `test_model`;", 31 | }, 32 | { 33 | name: "max", 34 | builder: NewSelector[TestModel](db).Select(Max("Age")), 35 | wantSql: "SELECT MAX(`age`) FROM `test_model`;", 36 | }, 37 | { 38 | name: "min", 39 | builder: NewSelector[TestModel](db).Select(Min("Age").As("min_age")), 40 | wantSql: "SELECT MIN(`age`) AS `min_age` FROM `test_model`;", 41 | }, 42 | { 43 | name: "sum", 44 | builder: NewSelector[TestModel](db).Select(Sum("Age")), 45 | wantSql: "SELECT SUM(`age`) FROM `test_model`;", 46 | }, 47 | { 48 | name: "count", 49 | builder: NewSelector[TestModel](db).Select(Count("Age")), 50 | wantSql: "SELECT COUNT(`age`) FROM `test_model`;", 51 | }, 52 | { 53 | name: "count distinct", 54 | builder: NewSelector[TestModel](db).Select(CountDistinct("FirstName")), 55 | wantSql: "SELECT COUNT(DISTINCT `first_name`) FROM `test_model`;", 56 | }, 57 | { 58 | name: "avg distinct", 59 | builder: NewSelector[TestModel](db).Select(AvgDistinct("FirstName")), 60 | wantSql: "SELECT AVG(DISTINCT `first_name`) FROM `test_model`;", 61 | }, 62 | { 63 | name: "SUM distinct", 64 | builder: NewSelector[TestModel](db).Select(SumDistinct("FirstName")), 65 | wantSql: "SELECT SUM(DISTINCT `first_name`) FROM `test_model`;", 66 | }, 67 | } 68 | 69 | for _, tc := range testCases { 70 | c := tc 71 | t.Run(c.name, func(t *testing.T) { 72 | query, err := c.builder.Build() 73 | assert.Equal(t, c.wantErr, err) 74 | if err != nil { 75 | return 76 | } 77 | assert.Equal(t, c.wantSql, query.SQL) 78 | assert.Equal(t, c.wantArgs, query.Args) 79 | }) 80 | } 81 | } 82 | 83 | func ExampleAggregate_As() { 84 | db := memoryDB() 85 | query, _ := 
NewSelector[TestModel](db).Select(Avg("Age").As("avg_age")).Build() 86 | fmt.Println(query.SQL) 87 | // Output: SELECT AVG(`age`) AS `avg_age` FROM `test_model`; 88 | } 89 | 90 | func ExampleAvg() { 91 | db := memoryDB() 92 | query, _ := NewSelector[TestModel](db).Select(Avg("Age").As("avg_age")).Build() 93 | fmt.Println(query.SQL) 94 | // Output: SELECT AVG(`age`) AS `avg_age` FROM `test_model`; 95 | } 96 | 97 | func ExampleCount() { 98 | db := memoryDB() 99 | query, _ := NewSelector[TestModel](db).Select(Count("Age")).Build() 100 | fmt.Println(query.SQL) 101 | // Output: SELECT COUNT(`age`) FROM `test_model`; 102 | } 103 | 104 | func ExampleMax() { 105 | db := memoryDB() 106 | query, _ := NewSelector[TestModel](db).Select(Max("Age")).Build() 107 | fmt.Println(query.SQL) 108 | // Output: SELECT MAX(`age`) FROM `test_model`; 109 | } 110 | 111 | func ExampleMin() { 112 | db := memoryDB() 113 | query, _ := NewSelector[TestModel](db).Select(Min("Age")).Build() 114 | fmt.Println(query.SQL) 115 | // Output: SELECT MIN(`age`) FROM `test_model`; 116 | } 117 | 118 | func ExampleSum() { 119 | db := memoryDB() 120 | query, _ := NewSelector[TestModel](db).Select(Sum("Age")).Build() 121 | fmt.Println(query.SQL) 122 | // Output: SELECT SUM(`age`) FROM `test_model`; 123 | } 124 | -------------------------------------------------------------------------------- /assignment.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package eorm 16 | 17 | // Assignable represents that something could be used alias "assignment" statement 18 | type Assignable interface { 19 | assign() 20 | } 21 | 22 | // Assignment represents assignment statement 23 | type Assignment binaryExpr 24 | 25 | func Assign(column string, value interface{}) Assignment { 26 | var expr Expr 27 | switch v := value.(type) { 28 | case Expr: 29 | expr = v 30 | default: 31 | expr = valueExpr{val: v} 32 | } 33 | return Assignment{left: C(column), op: opEQ, right: expr} 34 | } 35 | 36 | func (Assignment) assign() { 37 | panic("implement me") 38 | } 39 | 40 | type valueExpr struct { 41 | val interface{} 42 | } 43 | 44 | func (valueExpr) expr() (string, error) { 45 | return "", nil 46 | } 47 | -------------------------------------------------------------------------------- /assignment_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package eorm 16 | 17 | import "fmt" 18 | 19 | func ExampleAssign() { 20 | db := memoryDB() 21 | tm := &TestModel{} 22 | examples := []struct { 23 | assign Assignment 24 | assignStr string 25 | wantSQL string 26 | wantArgs []interface{} 27 | }{ 28 | { 29 | assign: Assign("Age", 18), 30 | assignStr: `Assign("Age", 18)`, 31 | wantSQL: "UPDATE `test_model` SET `age`=?;", 32 | wantArgs: []interface{}{18}, 33 | }, 34 | { 35 | assign: Assign("Age", C("Id")), 36 | assignStr: `Assign("Age", C("Id"))`, 37 | wantSQL: "UPDATE `test_model` SET `age`=`id`;", 38 | }, 39 | { 40 | assign: Assign("Age", C("Age").Add(1)), 41 | assignStr: `Assign("Age", C("Age").Add(1))`, 42 | wantSQL: "UPDATE `test_model` SET `age`=`age`+?;", 43 | wantArgs: []interface{}{1}, 44 | }, 45 | { 46 | assign: Assign("Age", Raw("`age`+`id`+1")), 47 | assignStr: "Assign(\"Age\", Raw(\"`age`+`id`+1\"))", 48 | wantSQL: "UPDATE `test_model` SET `age`=`age`+`id`+1;", 49 | }, 50 | } 51 | for _, exp := range examples { 52 | query, _ := NewUpdater[TestModel](db).Update(tm).Set(exp.assign).Build() 53 | fmt.Printf(` 54 | Assignment: %s 55 | SQL: %s 56 | Args: %v 57 | `, exp.assignStr, query.SQL, query.Args) 58 | } 59 | // Output: 60 | // 61 | // Assignment: Assign("Age", 18) 62 | // SQL: UPDATE `test_model` SET `age`=?; 63 | // Args: [18] 64 | // 65 | // Assignment: Assign("Age", C("Id")) 66 | // SQL: UPDATE `test_model` SET `age`=`id`; 67 | // Args: [] 68 | // 69 | // Assignment: Assign("Age", C("Age").Add(1)) 70 | // SQL: UPDATE `test_model` SET `age`=(`age`+?); 71 | // Args: [1] 72 | // 73 | // Assignment: Assign("Age", Raw("`age`+`id`+1")) 74 | // SQL: UPDATE `test_model` SET `age`=`age`+`id`+1; 75 | // Args: [] 76 | } 77 | -------------------------------------------------------------------------------- /core.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | 
// you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package eorm 16 | 17 | import ( 18 | "context" 19 | "reflect" 20 | 21 | "github.com/ecodeclub/eorm/internal/dialect" 22 | "github.com/ecodeclub/eorm/internal/errs" 23 | "github.com/ecodeclub/eorm/internal/model" 24 | "github.com/ecodeclub/eorm/internal/valuer" 25 | ) 26 | 27 | type core struct { 28 | metaRegistry model.MetaRegistry 29 | dialect dialect.Dialect 30 | valCreator valuer.PrimitiveCreator 31 | ms []Middleware 32 | } 33 | 34 | func getHandler[T any](ctx context.Context, sess Session, c core, qc *QueryContext) *QueryResult { 35 | rows, err := sess.queryContext(ctx, qc.q) 36 | if err != nil { 37 | return &QueryResult{Err: err} 38 | } 39 | defer func() { 40 | _ = rows.Close() 41 | }() 42 | if !rows.Next() { 43 | return &QueryResult{Err: errs.ErrNoRows} 44 | } 45 | 46 | tp := new(T) 47 | meta := qc.meta 48 | if meta == nil && reflect.TypeOf(tp).Elem().Kind() == reflect.Struct { 49 | // 当通过 RawQuery 方法调用 Get ,如果 T 是 time.Time, sql.Scanner 的实现, 50 | // 内置类型或者基本类型时, 在这里都会报错,但是这种情况我们认为是可以接受的 51 | // 所以在此将报错忽略,因为基本类型取值用不到 meta 里的数据 52 | meta, _ = c.metaRegistry.Get(tp) 53 | } 54 | 55 | val := c.valCreator.NewPrimitiveValue(tp, meta) 56 | if err = val.SetColumns(rows); err != nil { 57 | return &QueryResult{Err: err} 58 | } 59 | 60 | return &QueryResult{Result: tp} 61 | } 62 | 63 | func get[T any](ctx context.Context, sess Session, core core, qc *QueryContext) *QueryResult { 64 | var handler HandleFunc = func(ctx context.Context, queryContext 
*QueryContext) *QueryResult { 65 | return getHandler[T](ctx, sess, core, queryContext) 66 | } 67 | ms := core.ms 68 | for i := len(ms) - 1; i >= 0; i-- { 69 | handler = ms[i](handler) 70 | } 71 | return handler(ctx, qc) 72 | } 73 | 74 | func getMultiHandler[T any](ctx context.Context, sess Session, c core, qc *QueryContext) *QueryResult { 75 | rows, err := sess.queryContext(ctx, qc.q) 76 | if err != nil { 77 | return &QueryResult{Err: err} 78 | } 79 | defer func() { 80 | _ = rows.Close() 81 | }() 82 | res := make([]*T, 0, 16) 83 | meta := qc.meta 84 | if meta == nil { 85 | t := new(T) 86 | if reflect.TypeOf(t).Elem().Kind() == reflect.Struct { 87 | // 当通过 RawQuery 方法调用 Get ,如果 T 是 time.Time, sql.Scanner 的实现, 88 | // 内置类型或者基本类型时, 在这里都会报错,但是这种情况我们认为是可以接受的 89 | // 所以在此将报错忽略,因为基本类型取值用不到 meta 里的数据 90 | meta, _ = c.metaRegistry.Get(t) 91 | } 92 | } 93 | for rows.Next() { 94 | tp := new(T) 95 | val := c.valCreator.NewPrimitiveValue(tp, meta) 96 | if err = val.SetColumns(rows); err != nil { 97 | return &QueryResult{Err: err} 98 | } 99 | res = append(res, tp) 100 | } 101 | 102 | return &QueryResult{Result: res} 103 | } 104 | 105 | func getMulti[T any](ctx context.Context, sess Session, core core, qc *QueryContext) *QueryResult { 106 | var handler HandleFunc = func(ctx context.Context, queryContext *QueryContext) *QueryResult { 107 | return getMultiHandler[T](ctx, sess, core, queryContext) 108 | } 109 | ms := core.ms 110 | for i := len(ms) - 1; i >= 0; i-- { 111 | handler = ms[i](handler) 112 | } 113 | return handler(ctx, qc) 114 | } 115 | -------------------------------------------------------------------------------- /db.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package eorm 16 | 17 | import ( 18 | "context" 19 | "database/sql" 20 | 21 | "github.com/ecodeclub/eorm/internal/datasource" 22 | "github.com/ecodeclub/eorm/internal/datasource/single" 23 | "github.com/ecodeclub/eorm/internal/dialect" 24 | "github.com/ecodeclub/eorm/internal/errs" 25 | "github.com/ecodeclub/eorm/internal/model" 26 | "github.com/ecodeclub/eorm/internal/valuer" 27 | ) 28 | 29 | const ( 30 | SELECT = "SELECT" 31 | DELETE = "DELETE" 32 | UPDATE = "UPDATE" 33 | INSERT = "INSERT" 34 | RAW = "RAW" 35 | ) 36 | 37 | // DBOption configure DB 38 | type DBOption func(db *DB) 39 | 40 | // DB represents a database 41 | type DB struct { 42 | baseSession 43 | ds datasource.DataSource 44 | } 45 | 46 | // DBWithMiddlewares 为 db 配置 Middleware 47 | func DBWithMiddlewares(ms ...Middleware) DBOption { 48 | return func(db *DB) { 49 | db.ms = ms 50 | } 51 | } 52 | 53 | func DBWithMetaRegistry(r model.MetaRegistry) DBOption { 54 | return func(db *DB) { 55 | db.metaRegistry = r 56 | } 57 | } 58 | 59 | func UseReflection() DBOption { 60 | return func(db *DB) { 61 | db.valCreator = valuer.PrimitiveCreator{Creator: valuer.NewUnsafeValue} 62 | } 63 | } 64 | 65 | // Open 创建一个 ORM 实例 66 | // 注意该实例是一个无状态的对象,你应该尽可能复用它 67 | func Open(driver string, dsn string, opts ...DBOption) (*DB, error) { 68 | db, err := single.OpenDB(driver, dsn) 69 | if err != nil { 70 | return nil, err 71 | } 72 | return OpenDS(driver, db, opts...) 
73 | } 74 | 75 | func OpenDS(driver string, ds datasource.DataSource, opts ...DBOption) (*DB, error) { 76 | dl, err := dialect.Of(driver) 77 | if err != nil { 78 | return nil, err 79 | } 80 | orm := &DB{ 81 | baseSession: baseSession{ 82 | executor: ds, 83 | core: core{ 84 | metaRegistry: model.NewMetaRegistry(), 85 | dialect: dl, 86 | // 可以设为默认,因为原本这里也有默认 87 | valCreator: valuer.PrimitiveCreator{ 88 | Creator: valuer.NewUnsafeValue, 89 | }, 90 | }, 91 | }, 92 | ds: ds, 93 | } 94 | for _, o := range opts { 95 | o(orm) 96 | } 97 | return orm, nil 98 | } 99 | 100 | func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { 101 | inst, ok := db.ds.(datasource.TxBeginner) 102 | if !ok { 103 | return nil, errs.ErrNotCompleteTxBeginner 104 | } 105 | tx, err := inst.BeginTx(ctx, opts) 106 | if err != nil { 107 | return nil, err 108 | } 109 | return &Tx{tx: tx, baseSession: baseSession{ 110 | executor: tx, 111 | core: db.core, 112 | }}, nil 113 | } 114 | 115 | func (db *DB) Close() error { 116 | return db.ds.Close() 117 | } 118 | -------------------------------------------------------------------------------- /delete.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package eorm 16 | 17 | import ( 18 | "context" 19 | 20 | "github.com/valyala/bytebufferpool" 21 | ) 22 | 23 | var _ QueryBuilder = &Deleter[any]{} 24 | 25 | // Deleter builds DELETE query 26 | type Deleter[T any] struct { 27 | builder 28 | Session 29 | table interface{} 30 | where []Predicate 31 | } 32 | 33 | // NewDeleter 开始构建一个 DELETE 查询 34 | func NewDeleter[T any](sess Session) *Deleter[T] { 35 | return &Deleter[T]{ 36 | builder: builder{ 37 | core: sess.getCore(), 38 | buffer: bytebufferpool.Get(), 39 | }, 40 | Session: sess, 41 | } 42 | } 43 | 44 | // Build returns DELETE query 45 | func (d *Deleter[T]) Build() (Query, error) { 46 | defer bytebufferpool.Put(d.buffer) 47 | _, _ = d.buffer.WriteString("DELETE FROM ") 48 | var err error 49 | if d.table == nil { 50 | d.table = new(T) 51 | } 52 | d.meta, err = d.metaRegistry.Get(d.table) 53 | if err != nil { 54 | return EmptyQuery, err 55 | } 56 | 57 | d.quote(d.meta.TableName) 58 | if len(d.where) > 0 { 59 | d.writeString(" WHERE ") 60 | err = d.buildPredicates(d.where) 61 | if err != nil { 62 | return EmptyQuery, err 63 | } 64 | } 65 | d.end() 66 | return Query{SQL: d.buffer.String(), Args: d.args}, nil 67 | } 68 | 69 | // From accepts model definition 70 | func (d *Deleter[T]) From(table interface{}) *Deleter[T] { 71 | d.table = table 72 | return d 73 | } 74 | 75 | // Where accepts predicates 76 | func (d *Deleter[T]) Where(predicates ...Predicate) *Deleter[T] { 77 | d.where = predicates 78 | return d 79 | } 80 | 81 | // Exec sql 82 | func (d *Deleter[T]) Exec(ctx context.Context) Result { 83 | query, err := d.Build() 84 | if err != nil { 85 | return Result{err: err} 86 | } 87 | return newQuerier[T](d.Session, query, d.meta, DELETE).Exec(ctx) 88 | } 89 | -------------------------------------------------------------------------------- /error.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2 | // 3 | // Licensed under the Apache 
License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package eorm 16 | 17 | import "github.com/ecodeclub/eorm/internal/errs" 18 | 19 | // 哨兵错误,或者说预定义错误,谨慎添加 20 | var ( 21 | // ErrNoRows 代表没有找到数据 22 | ErrNoRows = errs.ErrNoRows 23 | ) 24 | -------------------------------------------------------------------------------- /expression.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package eorm 16 | 17 | import operator "github.com/ecodeclub/eorm/internal/operator" 18 | 19 | // Expr is the top interface. It represents everything. 
// NOTE(review): collapsed listing of expression.go lines 20-123. Rough edges visible here:
// MathExpr.Add returns Expr while MathExpr.Multi returns MathExpr, so `.Add(x).Multi(y)` does
// not compile; SubqueryExpr.expr panics with "implement me".
20 | type Expr interface { 21 | expr() (string, error) 22 | } 23 | 24 | // RawExpr uses a raw string as an Expr 25 | type RawExpr struct { 26 | raw string 27 | args []interface{} 28 | } 29 | 30 | // Raw wraps expr and its args, used verbatim, into a RawExpr 31 | func Raw(expr string, args ...interface{}) RawExpr { 32 | return RawExpr{ 33 | raw: expr, 34 | args: args, 35 | } 36 | } 37 | 38 | func (r RawExpr) expr() (string, error) { 39 | return r.raw, nil 40 | } 41 | 42 | // AsPredicate returns a Predicate whose left-hand side is this RawExpr 43 | // eorm does not validate any Predicate generated from a RawExpr 44 | func (r RawExpr) AsPredicate() Predicate { 45 | return Predicate{ 46 | left: r, 47 | } 48 | } 49 | 50 | func (RawExpr) selected() {} 51 | 52 | type binaryExpr struct { 53 | left Expr 54 | op operator.Op 55 | right Expr 56 | } 57 | 58 | func (binaryExpr) expr() (string, error) { 59 | return "", nil 60 | } 61 | 62 | type MathExpr binaryExpr 63 | 64 | func (m MathExpr) Add(val interface{}) Expr { 65 | return MathExpr{ 66 | left: m, 67 | op: opAdd, 68 | right: valueOf(val), 69 | } 70 | } 71 | 72 | func (m MathExpr) Multi(val interface{}) MathExpr { 73 | return MathExpr{ 74 | left: m, 75 | op: opMulti, 76 | right: valueOf(val), 77 | } 78 | } 79 | 80 | func (MathExpr) expr() (string, error) { 81 | return "", nil 82 | } 83 | 84 | func valueOf(val interface{}) Expr { 85 | switch v := val.(type) { 86 | case Expr: 87 | return v 88 | default: 89 | return valueExpr{val: val} 90 | } 91 | } 92 | 93 | type SubqueryExpr struct { 94 | s Subquery 95 | // predicate: ALL, ANY or SOME 96 | pred string 97 | } 98 | 99 | func (SubqueryExpr) expr() (string, error) { 100 | panic("implement me") 101 | } 102 | 103 | func Any(sub Subquery) SubqueryExpr { 104 | return SubqueryExpr{ 105 | s: sub, 106 | pred: "ANY", 107 | } 108 | } 109 | 110 | func All(sub Subquery) SubqueryExpr { 111 | return SubqueryExpr{ 112 | s: sub, 113 | pred: "ALL", 114 | } 115 | } 116 | 117 | func Some(sub Subquery) SubqueryExpr { 118 | return SubqueryExpr{ 119 | s: sub, 120 | pred: "SOME", 
// NOTE(review): this part of the listing is corrupted. It jumps from expression_test.go line 29
// (the `Raw("` argument that starts with `id`) straight to count.go line 39 (`... = len(col) {`).
// The rest of expression_test.go and count.go lines 1-39 are missing, possibly because a `<...>`
// span was stripped during extraction. Restore this region from the repository before relying on it.
121 | } 122 | } 123 | -------------------------------------------------------------------------------- /expression_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package eorm 16 | 17 | import ( 18 | "fmt" 19 | "testing" 20 | 21 | "github.com/stretchr/testify/assert" 22 | ) 23 | 24 | func TestRawExpr_AsPredicate(t *testing.T) { 25 | db := memoryDB() 26 | testCases := []CommonTestCase{ 27 | { 28 | name: "simple", 29 | builder: NewSelector[TestModel](db).Where(Raw("`id`= len(col) { 40 | return nil, errs.ErrMergerInvalidAggregateColumnIndex 41 | } 42 | return s.countNullableAggregator, nil 43 | } 44 | 45 | func (s *Count) ColumnInfo() merger.ColumnInfo { 46 | return s.countInfo 47 | } 48 | 49 | func (s *Count) Name() string { 50 | return s.name 51 | } 52 | 53 | func NewCount(info merger.ColumnInfo) *Count { 54 | return &Count{ 55 | name: "COUNT", 56 | countInfo: info, 57 | } 58 | } 59 | 60 | func countAggregate[T AggregateElement](cols [][]any, countIndex int) (any, error) { 61 | var count T 62 | for _, col := range cols { 63 | count += col[countIndex].(T) 64 | } 65 | return count, nil 66 | } 67 | func (*Count) countNullableAggregator(colsData [][]any, countIndex int) (any, error) { 68 | notNullCols, kind := nullableAggregator(colsData, countIndex) 69 | // every database returned NULL for this column, so returning the first NULL value is enough
// count.go (continued): rest of countNullableAggregator, countAggregateFuncMapping, then the
// count_test.go license header.
70 | if len(notNullCols) == 0 { 71 | return colsData[0][countIndex], nil 72 | } 73 | countFunc, ok := countAggregateFuncMapping[kind] 74 | if !ok { 75 | return nil, errs.ErrMergerAggregateFuncNotFound 76 | } 77 | return countFunc(notNullCols, countIndex) 78 | } 79 | 80 | var countAggregateFuncMapping = map[reflect.Kind]func([][]any, int) (any, error){ 81 | reflect.Int: countAggregate[int], 82 | reflect.Int8: countAggregate[int8], 83 | reflect.Int16: countAggregate[int16], 84 | reflect.Int32: countAggregate[int32], 85 | reflect.Int64: countAggregate[int64], 86 | reflect.Uint8: countAggregate[uint8], 87 | reflect.Uint16: countAggregate[uint16], 88 | reflect.Uint32: countAggregate[uint32], 89 | reflect.Uint64: countAggregate[uint64], 90 | reflect.Float32: countAggregate[float32], 91 | reflect.Float64: countAggregate[float64], 92 | reflect.Uint: countAggregate[uint], 93 | } 94 | -------------------------------------------------------------------------------- /internal/merger/internal/aggregatemerger/aggregator/count_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License.
14 | 15 | package aggregator 16 | 17 | import ( 18 | "database/sql" 19 | "testing" 20 | 21 | "github.com/ecodeclub/eorm/internal/merger" 22 | 23 | "github.com/ecodeclub/eorm/internal/merger/internal/errs" 24 | 25 | "github.com/stretchr/testify/assert" 26 | ) 27 | 28 | func TestCount_Aggregate(t *testing.T) { 29 | testcases := []struct { 30 | name string 31 | input [][]any 32 | wantVal any 33 | wantErr error 34 | countIndex int 35 | }{ 36 | { 37 | name: "count正常合并", 38 | input: [][]any{ 39 | { 40 | int64(10), 41 | }, 42 | { 43 | int64(20), 44 | }, 45 | { 46 | int64(30), 47 | }, 48 | }, 49 | wantVal: int64(60), 50 | countIndex: 0, 51 | }, 52 | { 53 | name: "传入的参数非AggregateElement类型", 54 | input: [][]any{ 55 | { 56 | "1", 57 | }, 58 | { 59 | "3", 60 | }, 61 | }, 62 | wantErr: errs.ErrMergerAggregateFuncNotFound, 63 | countIndex: 0, 64 | }, 65 | { 66 | name: "columnInfo的index不合法", 67 | input: [][]any{ 68 | { 69 | int64(10), 70 | }, 71 | { 72 | int64(20), 73 | }, 74 | { 75 | int64(30), 76 | }, 77 | }, 78 | countIndex: 20, 79 | wantErr: errs.ErrMergerInvalidAggregateColumnIndex, 80 | }, 81 | { 82 | name: "columnInfo为nullable类型", 83 | input: [][]any{ 84 | { 85 | sql.NullInt64{ 86 | Int64: 4, 87 | Valid: true, 88 | }, 89 | }, 90 | { 91 | sql.NullFloat64{ 92 | Valid: false, 93 | }, 94 | }, 95 | { 96 | sql.NullInt64{ 97 | Valid: true, 98 | Int64: 7, 99 | }, 100 | }, 101 | }, 102 | countIndex: 0, 103 | wantVal: int64(11), 104 | }, 105 | { 106 | name: "所有列查出来的都为null", 107 | input: [][]any{ 108 | { 109 | sql.NullInt64{ 110 | Valid: false, 111 | }, 112 | }, 113 | { 114 | sql.NullInt64{ 115 | Valid: false, 116 | }, 117 | }, 118 | { 119 | sql.NullInt64{ 120 | Valid: false, 121 | }, 122 | }, 123 | }, 124 | countIndex: 0, 125 | wantVal: sql.NullInt64{ 126 | Valid: false, 127 | }, 128 | }, 129 | { 130 | name: "所有列查出来的都不是null", 131 | input: [][]any{ 132 | { 133 | sql.NullInt64{ 134 | Int64: 8, 135 | Valid: true, 136 | }, 137 | }, 138 | { 139 | sql.NullInt64{ 140 | Int64: 9, 141 | 
Valid: true, 142 | }, 143 | }, 144 | { 145 | sql.NullInt64{ 146 | Valid: true, 147 | Int64: 8, 148 | }, 149 | }, 150 | }, 151 | countIndex: 0, 152 | wantVal: int64(25), 153 | }, 154 | { 155 | name: "表示 三者混合情况", 156 | input: [][]any{ 157 | { 158 | sql.NullInt64{ 159 | Int64: 8, 160 | Valid: true, 161 | }, 162 | }, 163 | { 164 | sql.NullInt64{Valid: false}, 165 | }, 166 | { 167 | int64(8), 168 | }, 169 | }, 170 | countIndex: 0, 171 | wantVal: int64(16), 172 | }, 173 | } 174 | for _, tc := range testcases { 175 | t.Run(tc.name, func(t *testing.T) { 176 | info := merger.ColumnInfo{Index: tc.countIndex, Name: "id", AggregateFunc: "COUNT"} 177 | count := NewCount(info) 178 | val, err := count.Aggregate(tc.input) 179 | assert.Equal(t, tc.wantErr, err) 180 | if err != nil { 181 | return 182 | } 183 | assert.Equal(t, tc.wantVal, val) 184 | assert.Equal(t, info, count.ColumnInfo()) 185 | }) 186 | } 187 | 188 | } 189 | -------------------------------------------------------------------------------- /internal/merger/internal/aggregatemerger/aggregator/max.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package aggregator 16 | 17 | import ( 18 | "reflect" 19 | 20 | "github.com/ecodeclub/eorm/internal/merger" 21 | 22 | "github.com/ecodeclub/eorm/internal/merger/internal/errs" 23 | ) 24 | 25 | type Max struct { 26 | name string 27 | maxColumnInfo merger.ColumnInfo 28 | } 29 | 30 | func (m *Max) Aggregate(cols [][]any) (any, error) { 31 | maxFunc, err := m.findMaxFunc(cols[0]) 32 | if err != nil { 33 | return nil, err 34 | } 35 | return maxFunc(cols, m.maxColumnInfo.Index) 36 | } 37 | func (m *Max) findMaxFunc(col []any) (func([][]any, int) (any, error), error) { 38 | maxIndex := m.maxColumnInfo.Index 39 | if maxIndex < 0 || maxIndex >= len(col) { 40 | return nil, errs.ErrMergerInvalidAggregateColumnIndex 41 | } 42 | return m.maxNullableAggregator, nil 43 | } 44 | 45 | func (m *Max) ColumnInfo() merger.ColumnInfo { 46 | return m.maxColumnInfo 47 | } 48 | 49 | func (m *Max) Name() string { 50 | return m.name 51 | } 52 | 53 | func NewMax(info merger.ColumnInfo) *Max { 54 | return &Max{ 55 | name: "MAX", 56 | maxColumnInfo: info, 57 | } 58 | } 59 | 60 | func maxAggregator[T AggregateElement](colsData [][]any, maxIndex int) (any, error) { 61 | return findExtremeValue[T](colsData, isMaxValue[T], maxIndex) 62 | } 63 | func (*Max) maxNullableAggregator(colsData [][]any, maxIndex int) (any, error) { 64 | notNullCols, kind := nullableAggregator(colsData, maxIndex) 65 | // 说明几个数据库里查出来的数据都为null,返回第一个null值即可 66 | if len(notNullCols) == 0 { 67 | return colsData[0][maxIndex], nil 68 | } 69 | maxFunc, ok := maxFuncMapping[kind] 70 | if !ok { 71 | return nil, errs.ErrMergerAggregateFuncNotFound 72 | } 73 | return maxFunc(notNullCols, maxIndex) 74 | } 75 | 76 | var maxFuncMapping = map[reflect.Kind]func([][]any, int) (any, error){ 77 | reflect.Int: maxAggregator[int], 78 | reflect.Int8: maxAggregator[int8], 79 | reflect.Int16: maxAggregator[int16], 80 | reflect.Int32: maxAggregator[int32], 81 | reflect.Int64: maxAggregator[int64], 82 | reflect.Uint8: maxAggregator[uint8], 
83 | reflect.Uint16: maxAggregator[uint16], 84 | reflect.Uint32: maxAggregator[uint32], 85 | reflect.Uint64: maxAggregator[uint64], 86 | reflect.Float32: maxAggregator[float32], 87 | reflect.Float64: maxAggregator[float64], 88 | reflect.Uint: maxAggregator[uint], 89 | } 90 | 91 | type extremeValueFunc[T AggregateElement] func(T, T) bool 92 | 93 | func findExtremeValue[T AggregateElement](colsData [][]any, isExtremeValue extremeValueFunc[T], index int) (any, error) { 94 | var ans T 95 | for idx, colData := range colsData { 96 | data := colData[index].(T) 97 | if idx == 0 { 98 | ans = data 99 | continue 100 | } 101 | if isExtremeValue(ans, data) { 102 | ans = data 103 | } 104 | } 105 | return ans, nil 106 | } 107 | 108 | func isMaxValue[T AggregateElement](maxData T, data T) bool { 109 | return maxData < data 110 | } 111 | -------------------------------------------------------------------------------- /internal/merger/internal/aggregatemerger/aggregator/max_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package aggregator 16 | 17 | import ( 18 | "database/sql" 19 | "testing" 20 | 21 | "github.com/ecodeclub/eorm/internal/merger" 22 | 23 | "github.com/ecodeclub/eorm/internal/merger/internal/errs" 24 | "github.com/stretchr/testify/assert" 25 | ) 26 | 27 | func TestMax_Aggregate(t *testing.T) { 28 | testcases := []struct { 29 | name string 30 | input [][]any 31 | wantVal any 32 | wantErr error 33 | maxIndex int 34 | }{ 35 | { 36 | name: "MAX正常合并", 37 | input: [][]any{ 38 | { 39 | int64(10), 40 | }, 41 | { 42 | int64(20), 43 | }, 44 | { 45 | int64(30), 46 | }, 47 | }, 48 | wantVal: int64(30), 49 | maxIndex: 0, 50 | }, 51 | { 52 | name: "传入的参数非AggregateElement类型", 53 | input: [][]any{ 54 | { 55 | "1", 56 | }, 57 | { 58 | "3", 59 | }, 60 | }, 61 | wantErr: errs.ErrMergerAggregateFuncNotFound, 62 | maxIndex: 0, 63 | }, 64 | { 65 | name: "columnInfo的index不合法", 66 | input: [][]any{ 67 | { 68 | int64(10), 69 | }, 70 | { 71 | int64(20), 72 | }, 73 | { 74 | int64(30), 75 | }, 76 | }, 77 | maxIndex: 20, 78 | wantErr: errs.ErrMergerInvalidAggregateColumnIndex, 79 | }, 80 | { 81 | name: "columnInfo为nullable类型", 82 | input: [][]any{ 83 | { 84 | sql.NullFloat64{ 85 | Float64: 2.2, 86 | Valid: true, 87 | }, 88 | }, 89 | { 90 | sql.NullFloat64{ 91 | Valid: false, 92 | }, 93 | }, 94 | { 95 | sql.NullFloat64{ 96 | Valid: true, 97 | Float64: 3.4, 98 | }, 99 | }, 100 | }, 101 | maxIndex: 0, 102 | wantVal: 3.4, 103 | }, 104 | { 105 | name: "所有列查出来的都为null", 106 | input: [][]any{ 107 | { 108 | sql.NullFloat64{ 109 | Valid: false, 110 | }, 111 | }, 112 | { 113 | sql.NullFloat64{ 114 | Valid: false, 115 | }, 116 | }, 117 | { 118 | sql.NullFloat64{ 119 | Valid: false, 120 | }, 121 | }, 122 | }, 123 | maxIndex: 0, 124 | wantVal: sql.NullFloat64{ 125 | Valid: false, 126 | }, 127 | }, 128 | { 129 | name: "所有列查出来的都不是null", 130 | input: [][]any{ 131 | { 132 | sql.NullFloat64{ 133 | Float64: 2.2, 134 | Valid: true, 135 | }, 136 | }, 137 | { 138 | sql.NullFloat64{ 139 | Float64: 5.6, 140 | 
Valid: true, 141 | }, 142 | }, 143 | { 144 | sql.NullFloat64{ 145 | Valid: true, 146 | Float64: 3.4, 147 | }, 148 | }, 149 | }, 150 | maxIndex: 0, 151 | wantVal: 5.6, 152 | }, 153 | { 154 | name: "表示 三者混合情况", 155 | input: [][]any{ 156 | { 157 | sql.NullFloat64{ 158 | Float64: 2.2, 159 | Valid: true, 160 | }, 161 | }, 162 | { 163 | sql.NullFloat64{ 164 | Valid: false, 165 | }, 166 | }, 167 | { 168 | 3.4, 169 | }, 170 | }, 171 | maxIndex: 0, 172 | wantVal: 3.4, 173 | }, 174 | } 175 | for _, tc := range testcases { 176 | t.Run(tc.name, func(t *testing.T) { 177 | info := merger.ColumnInfo{Index: tc.maxIndex, Name: "id", AggregateFunc: "MAX"} 178 | m := NewMax(info) 179 | val, err := m.Aggregate(tc.input) 180 | assert.Equal(t, tc.wantErr, err) 181 | if err != nil { 182 | return 183 | } 184 | assert.Equal(t, tc.wantVal, val) 185 | assert.Equal(t, info, m.maxColumnInfo) 186 | }) 187 | } 188 | 189 | } 190 | -------------------------------------------------------------------------------- /internal/merger/internal/aggregatemerger/aggregator/min.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package aggregator 16 | 17 | import ( 18 | "reflect" 19 | 20 | "github.com/ecodeclub/eorm/internal/merger" 21 | 22 | "github.com/ecodeclub/eorm/internal/merger/internal/errs" 23 | ) 24 | 25 | type Min struct { 26 | name string 27 | minColumnInfo merger.ColumnInfo 28 | } 29 | 30 | func (m *Min) Aggregate(cols [][]any) (any, error) { 31 | minFunc, err := m.findMinFunc(cols[0]) 32 | if err != nil { 33 | return nil, err 34 | } 35 | return minFunc(cols, m.minColumnInfo.Index) 36 | } 37 | 38 | func (m *Min) findMinFunc(col []any) (func([][]any, int) (any, error), error) { 39 | minIndex := m.minColumnInfo.Index 40 | if minIndex < 0 || minIndex >= len(col) { 41 | return nil, errs.ErrMergerInvalidAggregateColumnIndex 42 | } 43 | return m.minNullableAggregator, nil 44 | } 45 | 46 | func (m *Min) ColumnInfo() merger.ColumnInfo { 47 | return m.minColumnInfo 48 | } 49 | 50 | func (m *Min) Name() string { 51 | return m.name 52 | } 53 | 54 | func NewMin(info merger.ColumnInfo) *Min { 55 | return &Min{ 56 | name: "MIN", 57 | minColumnInfo: info, 58 | } 59 | } 60 | 61 | func minAggregator[T AggregateElement](colsData [][]any, minIndex int) (any, error) { 62 | return findExtremeValue[T](colsData, isMinValue[T], minIndex) 63 | } 64 | 65 | func (*Min) minNullableAggregator(colsData [][]any, minIndex int) (any, error) { 66 | notNullCols, kind := nullableAggregator(colsData, minIndex) 67 | // 说明几个数据库里查出来的数据都为null,返回第一个null值即可 68 | if len(notNullCols) == 0 { 69 | return colsData[0][minIndex], nil 70 | } 71 | minFunc, ok := minFuncMapping[kind] 72 | if !ok { 73 | return nil, errs.ErrMergerAggregateFuncNotFound 74 | } 75 | return minFunc(notNullCols, minIndex) 76 | } 77 | 78 | var minFuncMapping = map[reflect.Kind]func([][]any, int) (any, error){ 79 | reflect.Int: minAggregator[int], 80 | reflect.Int8: minAggregator[int8], 81 | reflect.Int16: minAggregator[int16], 82 | reflect.Int32: minAggregator[int32], 83 | reflect.Int64: minAggregator[int64], 84 | reflect.Uint8: 
minAggregator[uint8], 85 | reflect.Uint16: minAggregator[uint16], 86 | reflect.Uint32: minAggregator[uint32], 87 | reflect.Uint64: minAggregator[uint64], 88 | reflect.Float32: minAggregator[float32], 89 | reflect.Float64: minAggregator[float64], 90 | reflect.Uint: minAggregator[uint], 91 | } 92 | 93 | func isMinValue[T AggregateElement](minData T, data T) bool { 94 | return minData > data 95 | } 96 | -------------------------------------------------------------------------------- /internal/merger/internal/aggregatemerger/aggregator/min_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package aggregator 16 | 17 | import ( 18 | "database/sql" 19 | "testing" 20 | 21 | "github.com/ecodeclub/eorm/internal/merger" 22 | 23 | "github.com/ecodeclub/eorm/internal/merger/internal/errs" 24 | "github.com/stretchr/testify/assert" 25 | ) 26 | 27 | func TestMin_Aggregate(t *testing.T) { 28 | testcases := []struct { 29 | name string 30 | input [][]any 31 | wantVal any 32 | wantErr error 33 | minIndex int 34 | }{ 35 | { 36 | name: "Min正常合并", 37 | input: [][]any{ 38 | { 39 | int64(10), 40 | }, 41 | { 42 | int64(20), 43 | }, 44 | { 45 | int64(30), 46 | }, 47 | }, 48 | wantVal: int64(10), 49 | minIndex: 0, 50 | }, 51 | { 52 | name: "传入的参数非AggregateElement类型", 53 | input: [][]any{ 54 | { 55 | "1", 56 | }, 57 | { 58 | "3", 59 | }, 60 | }, 61 | wantErr: errs.ErrMergerAggregateFuncNotFound, 62 | minIndex: 0, 63 | }, 64 | { 65 | name: "columnInfo的index不合法", 66 | input: [][]any{ 67 | { 68 | int64(10), 69 | }, 70 | { 71 | int64(20), 72 | }, 73 | { 74 | int64(30), 75 | }, 76 | }, 77 | minIndex: 20, 78 | wantErr: errs.ErrMergerInvalidAggregateColumnIndex, 79 | }, 80 | { 81 | name: "columnInfo为nullable类型", 82 | input: [][]any{ 83 | { 84 | sql.NullFloat64{ 85 | Float64: 2.2, 86 | Valid: true, 87 | }, 88 | }, 89 | { 90 | sql.NullFloat64{ 91 | Valid: false, 92 | }, 93 | }, 94 | { 95 | sql.NullFloat64{ 96 | Valid: true, 97 | Float64: 3.4, 98 | }, 99 | }, 100 | }, 101 | minIndex: 0, 102 | wantVal: 2.2, 103 | }, 104 | { 105 | name: "所有列查出来的都为null", 106 | input: [][]any{ 107 | { 108 | sql.NullFloat64{ 109 | Valid: false, 110 | }, 111 | }, 112 | { 113 | sql.NullFloat64{ 114 | Valid: false, 115 | }, 116 | }, 117 | { 118 | sql.NullFloat64{ 119 | Valid: false, 120 | }, 121 | }, 122 | }, 123 | minIndex: 0, 124 | wantVal: sql.NullFloat64{ 125 | Valid: false, 126 | }, 127 | }, 128 | { 129 | name: "所有列查出来的都不是null", 130 | input: [][]any{ 131 | { 132 | sql.NullFloat64{ 133 | Float64: 2.2, 134 | Valid: true, 135 | }, 136 | }, 137 | { 138 | sql.NullFloat64{ 139 | Float64: 5.6, 140 | 
Valid: true, 141 | }, 142 | }, 143 | { 144 | sql.NullFloat64{ 145 | Valid: true, 146 | Float64: 3.4, 147 | }, 148 | }, 149 | }, 150 | minIndex: 0, 151 | wantVal: 2.2, 152 | }, 153 | { 154 | name: "三者混合情况", 155 | input: [][]any{ 156 | { 157 | sql.NullFloat64{ 158 | Float64: 2.2, 159 | Valid: true, 160 | }, 161 | }, 162 | { 163 | sql.NullFloat64{ 164 | Valid: false, 165 | }, 166 | }, 167 | { 168 | 3.4, 169 | }, 170 | }, 171 | minIndex: 0, 172 | wantVal: 2.2, 173 | }, 174 | } 175 | for _, tc := range testcases { 176 | t.Run(tc.name, func(t *testing.T) { 177 | info := merger.ColumnInfo{Index: tc.minIndex, Name: "id", AggregateFunc: "MIN"} 178 | m := NewMin(info) 179 | val, err := m.Aggregate(tc.input) 180 | assert.Equal(t, tc.wantErr, err) 181 | if err != nil { 182 | return 183 | } 184 | assert.Equal(t, tc.wantVal, val) 185 | assert.Equal(t, info, m.ColumnInfo()) 186 | }) 187 | } 188 | 189 | } 190 | -------------------------------------------------------------------------------- /internal/merger/internal/aggregatemerger/aggregator/sum.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package aggregator 16 | 17 | import ( 18 | "reflect" 19 | 20 | "github.com/ecodeclub/eorm/internal/merger" 21 | 22 | "github.com/ecodeclub/eorm/internal/merger/internal/errs" 23 | ) 24 | 25 | type Sum struct { 26 | name string 27 | sumColumnInfo merger.ColumnInfo 28 | } 29 | 30 | func (s *Sum) Aggregate(cols [][]any) (any, error) { 31 | sumFunc, err := s.findSumFunc(cols[0]) 32 | if err != nil { 33 | return nil, err 34 | } 35 | return sumFunc(cols, s.sumColumnInfo.Index) 36 | } 37 | 38 | func (s *Sum) findSumFunc(col []any) (func([][]any, int) (any, error), error) { 39 | sumIndex := s.sumColumnInfo.Index 40 | if sumIndex < 0 || sumIndex >= len(col) { 41 | return nil, errs.ErrMergerInvalidAggregateColumnIndex 42 | } 43 | return s.sumNullableAggregator, nil 44 | } 45 | 46 | func (s *Sum) ColumnInfo() merger.ColumnInfo { 47 | return s.sumColumnInfo 48 | } 49 | 50 | func (s *Sum) Name() string { 51 | return s.name 52 | } 53 | 54 | func NewSum(info merger.ColumnInfo) *Sum { 55 | return &Sum{ 56 | name: "SUM", 57 | sumColumnInfo: info, 58 | } 59 | } 60 | 61 | func sumAggregate[T AggregateElement](cols [][]any, sumIndex int) (any, error) { 62 | var sum T 63 | for _, col := range cols { 64 | sum += col[sumIndex].(T) 65 | } 66 | return sum, nil 67 | } 68 | 69 | func (*Sum) sumNullableAggregator(colsData [][]any, sumIndex int) (any, error) { 70 | notNullCols, kind := nullableAggregator(colsData, sumIndex) 71 | // 说明几个数据库里查出来的数据都为null,返回第一个null值即可 72 | if len(notNullCols) == 0 { 73 | return colsData[0][sumIndex], nil 74 | } 75 | sumFunc, ok := sumAggregateFuncMapping[kind] 76 | if !ok { 77 | return nil, errs.ErrMergerAggregateFuncNotFound 78 | } 79 | return sumFunc(notNullCols, sumIndex) 80 | } 81 | 82 | var sumAggregateFuncMapping = map[reflect.Kind]func([][]any, int) (any, error){ 83 | reflect.Int: sumAggregate[int], 84 | reflect.Int8: sumAggregate[int8], 85 | reflect.Int16: sumAggregate[int16], 86 | reflect.Int32: sumAggregate[int32], 87 | reflect.Int64: 
sumAggregate[int64], 88 | reflect.Uint8: sumAggregate[uint8], 89 | reflect.Uint16: sumAggregate[uint16], 90 | reflect.Uint32: sumAggregate[uint32], 91 | reflect.Uint64: sumAggregate[uint64], 92 | reflect.Float32: sumAggregate[float32], 93 | reflect.Float64: sumAggregate[float64], 94 | reflect.Uint: sumAggregate[uint], 95 | } 96 | -------------------------------------------------------------------------------- /internal/merger/internal/aggregatemerger/aggregator/sum_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package aggregator 16 | 17 | import ( 18 | "database/sql" 19 | "testing" 20 | 21 | "github.com/ecodeclub/eorm/internal/merger" 22 | 23 | "github.com/ecodeclub/eorm/internal/merger/internal/errs" 24 | 25 | "github.com/stretchr/testify/assert" 26 | ) 27 | 28 | func TestSum_Aggregate(t *testing.T) { 29 | testcases := []struct { 30 | name string 31 | input [][]any 32 | wantVal any 33 | wantErr error 34 | sumIndex int 35 | }{ 36 | { 37 | name: "sum正常合并", 38 | input: [][]any{ 39 | { 40 | int64(10), 41 | }, 42 | { 43 | int64(20), 44 | }, 45 | { 46 | int64(30), 47 | }, 48 | }, 49 | wantVal: int64(60), 50 | sumIndex: 0, 51 | }, 52 | 53 | { 54 | name: "传入的参数非AggregateElement类型", 55 | input: [][]any{ 56 | { 57 | "1", 58 | }, 59 | { 60 | "3", 61 | }, 62 | }, 63 | wantErr: errs.ErrMergerAggregateFuncNotFound, 64 | sumIndex: 0, 65 | }, 66 | { 67 | name: "columnInfo的index不合法", 68 | input: [][]any{ 69 | { 70 | int64(10), 71 | }, 72 | { 73 | int64(20), 74 | }, 75 | { 76 | int64(30), 77 | }, 78 | }, 79 | sumIndex: 20, 80 | wantErr: errs.ErrMergerInvalidAggregateColumnIndex, 81 | }, 82 | { 83 | name: "columnInfo为nullable类型", 84 | input: [][]any{ 85 | { 86 | sql.NullFloat64{ 87 | Float64: 2.2, 88 | Valid: true, 89 | }, 90 | }, 91 | { 92 | sql.NullFloat64{ 93 | Valid: false, 94 | }, 95 | }, 96 | { 97 | sql.NullFloat64{ 98 | Valid: true, 99 | Float64: 3.4, 100 | }, 101 | }, 102 | }, 103 | sumIndex: 0, 104 | wantVal: 5.6, 105 | }, 106 | { 107 | name: "所有列查出来的都为null", 108 | input: [][]any{ 109 | { 110 | sql.NullFloat64{ 111 | Valid: false, 112 | }, 113 | }, 114 | { 115 | sql.NullFloat64{ 116 | Valid: false, 117 | }, 118 | }, 119 | { 120 | sql.NullFloat64{ 121 | Valid: false, 122 | }, 123 | }, 124 | }, 125 | sumIndex: 0, 126 | wantVal: sql.NullFloat64{ 127 | Valid: false, 128 | }, 129 | }, 130 | { 131 | name: "所有列查出来的都不是null", 132 | input: [][]any{ 133 | { 134 | sql.NullFloat64{ 135 | Float64: 2.2, 136 | Valid: true, 137 | }, 138 | }, 139 | { 140 | sql.NullFloat64{ 141 | Float64: 
5.6, 142 | Valid: true, 143 | }, 144 | }, 145 | { 146 | sql.NullFloat64{ 147 | Valid: true, 148 | Float64: 3.4, 149 | }, 150 | }, 151 | }, 152 | sumIndex: 0, 153 | wantVal: 11.2, 154 | }, 155 | { 156 | name: "表示三者混合的情况", 157 | input: [][]any{ 158 | { 159 | sql.NullFloat64{ 160 | Float64: 2.2, 161 | Valid: true, 162 | }, 163 | }, 164 | { 165 | 5.6, 166 | }, 167 | { 168 | sql.NullFloat64{ 169 | Valid: false, 170 | }, 171 | }, 172 | }, 173 | sumIndex: 0, 174 | wantVal: 7.8, 175 | }, 176 | } 177 | for _, tc := range testcases { 178 | t.Run(tc.name, func(t *testing.T) { 179 | info := merger.ColumnInfo{Index: tc.sumIndex, Name: "id", AggregateFunc: "SUM"} 180 | sum := NewSum(info) 181 | val, err := sum.Aggregate(tc.input) 182 | assert.Equal(t, tc.wantErr, err) 183 | if err != nil { 184 | return 185 | } 186 | assert.Equal(t, tc.wantVal, val) 187 | assert.Equal(t, info, sum.ColumnInfo()) 188 | }) 189 | } 190 | 191 | } 192 | -------------------------------------------------------------------------------- /internal/merger/internal/aggregatemerger/aggregator/type.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package aggregator 16 | 17 | import ( 18 | "database/sql/driver" 19 | "reflect" 20 | 21 | "github.com/ecodeclub/eorm/internal/merger" 22 | ) 23 | 24 | type AggregateElement interface { 25 | ~int | ~int8 | ~int16 | ~int32 | ~int64 | ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~float32 | ~float64 26 | } 27 | 28 | type Aggregator interface { 29 | // Aggregate 将多个列聚合 cols表示sqlRows列表里的数据,聚合函数通过下标拿到需要的列 30 | Aggregate(cols [][]any) (any, error) 31 | // ColumnInfo 聚合列的信息 32 | ColumnInfo() merger.ColumnInfo 33 | // Name 聚合函数本身的名称, MIN/MAX/SUM/COUNT/AVG 34 | Name() string 35 | } 36 | 37 | // nullableAggregator 处理查询到的nullable类型的数据,第一个返回值为 非null的数据 如果是sql.nullfloat64{value: 1.1,valid: true},返回的就是1.1,第二个返回值为value的kind 38 | func nullableAggregator(colsData [][]any, index int) ([][]any, reflect.Kind) { 39 | notNullCols := make([][]any, 0, len(colsData)) 40 | var kind reflect.Kind 41 | for _, colData := range colsData { 42 | col := colData[index] 43 | if reflect.TypeOf(col).Kind() == reflect.Struct { 44 | maxVal, _ := col.(driver.Valuer).Value() 45 | if maxVal != nil { 46 | kind = reflect.TypeOf(maxVal).Kind() 47 | colData[index] = maxVal 48 | notNullCols = append(notNullCols, colData) 49 | } 50 | } else { 51 | kind = reflect.TypeOf(col).Kind() 52 | notNullCols = append(notNullCols, colData) 53 | } 54 | } 55 | return notNullCols, kind 56 | } 57 | -------------------------------------------------------------------------------- /internal/merger/internal/errs/error.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package errs declares the errors of the merger implementations; every
// message is prefixed with "merger:".
package errs

import (
	"errors"
	"fmt"
)

var (
	// ErrEmptySortColumns: the list of sort columns is empty.
	ErrEmptySortColumns = errors.New("merger: 排序列为空")
	// ErrMergerEmptyRows: the list of sql.Rows is empty.
	ErrMergerEmptyRows = errors.New("merger: sql.Rows列表为空")
	// ErrMergerRowsIsNull: the list of sql.Rows contains a nil element.
	ErrMergerRowsIsNull = errors.New("merger: sql.Rows列表中有元素为nil")
	// ErrMergerScanNotNext: Scan was called without a preceding call to Next.
	ErrMergerScanNotNext = errors.New("merger: Scan之前没有调用Next方法")
	// ErrMergerRowsClosed: the Rows have already been closed.
	ErrMergerRowsClosed = errors.New("merger: Rows已经关闭")
	// ErrMergerRowsDiff: the sql.Rows in the list have different columns.
	ErrMergerRowsDiff = errors.New("merger: sql.Rows列表中的字段不同")
	// ErrMergerInvalidLimitOrOffset: offset or limit is less than 0.
	ErrMergerInvalidLimitOrOffset = errors.New("merger: offset或limit小于0")
	// ErrMergerAggregateHasEmptyRows: one or more entries of the rows list
	// were empty while computing an aggregate function.
	ErrMergerAggregateHasEmptyRows = errors.New("merger: 聚合函数计算时rowsList有一个或多个为空")
	// ErrMergerInvalidAggregateColumnIndex: the index in ColumnInfo is invalid.
	ErrMergerInvalidAggregateColumnIndex = errors.New("merger: ColumnInfo的index不合法")
	// ErrMergerAggregateFuncNotFound: no aggregate function implementation was
	// found.
	ErrMergerAggregateFuncNotFound = errors.New("merger: 聚合函数方法未找到")

	// ErrDistinctColsRepeated: the DISTINCT columns contain duplicates.
	ErrDistinctColsRepeated = errors.New("merger: 去重列重复")
	// ErrSortColListNotContainDistinctCol: the sort columns include a column
	// that is not in the DISTINCT column list.
	ErrSortColListNotContainDistinctCol = errors.New("merger: 排序列里包含不在去重列表中的列")
	// ErrDistinctColsNotInCols: a DISTINCT column is not among the columns
	// returned by the database.
	ErrDistinctColsNotInCols = errors.New("merger:去重列不在数据库字段集合里面")
	// ErrDistinctColsIsNull: the DISTINCT column list is empty.
	ErrDistinctColsIsNull = errors.New("merger:去重列为空")
)

// NewRepeatSortColumn returns an error reporting that the sort column column
// appears more than once.
func NewRepeatSortColumn(column string) error {
	return fmt.Errorf("merger: 排序列重复%s", column)
}

// NewInvalidSortColumn returns an error reporting that the sort column column
// is not among the columns returned by the database.
func NewInvalidSortColumn(column string) error {
	return fmt.Errorf("merger: 数据库字段中没有这个排序列:%s", column)
}
-------------------------------------------------------------------------------- /internal/merger/internal/pagedmerger/merger.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2
| // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package pagedmerger 16 | 17 | import ( 18 | "context" 19 | "database/sql" 20 | "fmt" 21 | "sync" 22 | 23 | "github.com/ecodeclub/eorm/internal/rows" 24 | 25 | "github.com/ecodeclub/eorm/internal/merger" 26 | "github.com/ecodeclub/eorm/internal/merger/internal/errs" 27 | ) 28 | 29 | type Merger struct { 30 | m merger.Merger 31 | limit int 32 | offset int 33 | } 34 | 35 | func NewMerger(m merger.Merger, offset int, limit int) (*Merger, error) { 36 | if offset < 0 || limit <= 0 { 37 | return nil, errs.ErrMergerInvalidLimitOrOffset 38 | } 39 | 40 | return &Merger{ 41 | m: m, 42 | limit: limit, 43 | offset: offset, 44 | }, nil 45 | } 46 | 47 | func (m *Merger) Merge(ctx context.Context, results []rows.Rows) (rows.Rows, error) { 48 | rs, err := m.m.Merge(ctx, results) 49 | if err != nil { 50 | return nil, err 51 | } 52 | err = m.nextOffset(ctx, rs) 53 | if err != nil { 54 | return nil, err 55 | } 56 | return &Rows{ 57 | rows: rs, 58 | mu: &sync.RWMutex{}, 59 | limit: m.limit, 60 | }, nil 61 | } 62 | 63 | // nextOffset 会把游标挪到 offset 所指定的位置。 64 | func (m *Merger) nextOffset(ctx context.Context, rows rows.Rows) error { 65 | offset := m.offset 66 | for i := 0; i < offset; i++ { 67 | if ctx.Err() != nil { 68 | return ctx.Err() 69 | } 70 | // 如果偏移量超过rows结果集返回的行数,不会报错。用户最终查到0行 71 | if !rows.Next() { 72 | return rows.Err() 73 | } 74 | } 75 | return nil 76 | } 77 | 78 | type 
Rows struct { 79 | rows rows.Rows 80 | limit int 81 | cnt int 82 | lastErr error 83 | closed bool 84 | mu *sync.RWMutex 85 | } 86 | 87 | func (*Rows) NextResultSet() bool { 88 | return false 89 | } 90 | 91 | func (r *Rows) Next() bool { 92 | r.mu.Lock() 93 | if r.closed { 94 | r.mu.Unlock() 95 | return false 96 | } 97 | if r.cnt >= r.limit || r.lastErr != nil { 98 | r.mu.Unlock() 99 | _ = r.Close() 100 | return false 101 | } 102 | canNext, err := r.nextVal() 103 | if err != nil { 104 | r.lastErr = err 105 | r.mu.Unlock() 106 | _ = r.Close() 107 | return false 108 | } 109 | if !canNext { 110 | r.mu.Unlock() 111 | _ = r.Close() 112 | return canNext 113 | } 114 | r.cnt++ 115 | r.mu.Unlock() 116 | return canNext 117 | } 118 | func (r *Rows) nextVal() (bool, error) { 119 | if r.rows.Next() { 120 | return true, nil 121 | } 122 | if r.rows.Err() != nil { 123 | return false, r.rows.Err() 124 | } 125 | return false, nil 126 | } 127 | 128 | func (r *Rows) Scan(dest ...any) error { 129 | r.mu.RLock() 130 | defer r.mu.RUnlock() 131 | if r.lastErr != nil { 132 | return r.lastErr 133 | } 134 | if r.closed { 135 | return errs.ErrMergerRowsClosed 136 | } 137 | return r.rows.Scan(dest...) 
138 | } 139 | 140 | func (r *Rows) Close() error { 141 | r.mu.Lock() 142 | defer r.mu.Unlock() 143 | r.closed = true 144 | return r.rows.Close() 145 | } 146 | 147 | func (r *Rows) ColumnTypes() ([]*sql.ColumnType, error) { 148 | r.mu.Lock() 149 | defer r.mu.Unlock() 150 | if r.closed { 151 | return nil, fmt.Errorf("%w", errs.ErrMergerRowsClosed) 152 | } 153 | return r.rows.ColumnTypes() 154 | } 155 | func (r *Rows) Columns() ([]string, error) { 156 | return r.rows.Columns() 157 | } 158 | 159 | func (r *Rows) Err() error { 160 | r.mu.RLock() 161 | defer r.mu.RUnlock() 162 | return r.lastErr 163 | } 164 | -------------------------------------------------------------------------------- /internal/merger/internal/sortmerger/heap/heap.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package heap 16 | 17 | import ( 18 | "container/heap" 19 | "database/sql/driver" 20 | "reflect" 21 | 22 | "github.com/ecodeclub/eorm/internal/merger" 23 | ) 24 | 25 | type Heap struct { 26 | nodes []*Node 27 | sortColumns merger.SortColumns 28 | } 29 | 30 | func NewHeap(h []*Node, sortColumns merger.SortColumns) *Heap { 31 | hp := &Heap{nodes: h, sortColumns: sortColumns} 32 | heap.Init(hp) 33 | return hp 34 | } 35 | 36 | func (h *Heap) Len() int { 37 | return len(h.nodes) 38 | } 39 | 40 | func (h *Heap) Less(i, j int) bool { 41 | for k := 0; k < h.sortColumns.Len(); k++ { 42 | valueI := h.nodes[i].SortColumnValues[k] 43 | valueJ := h.nodes[j].SortColumnValues[k] 44 | _, ok := valueJ.(driver.Valuer) 45 | var cmp func(any, any, merger.Order) int 46 | if ok { 47 | cmp = merger.CompareNullable 48 | } else { 49 | kind := reflect.TypeOf(valueI).Kind() 50 | cmp = merger.CompareFuncMapping[kind] 51 | } 52 | res := cmp(valueI, valueJ, h.sortColumns.Get(k).Order) 53 | if res == 0 { 54 | continue 55 | } 56 | if res == -1 { 57 | return true 58 | } 59 | return false 60 | } 61 | return false 62 | } 63 | 64 | func (h *Heap) Swap(i, j int) { 65 | h.nodes[i], h.nodes[j] = h.nodes[j], h.nodes[i] 66 | } 67 | 68 | func (h *Heap) Push(x any) { 69 | h.nodes = append(h.nodes, x.(*Node)) 70 | } 71 | 72 | func (h *Heap) Pop() any { 73 | v := h.nodes[len(h.nodes)-1] 74 | h.nodes = h.nodes[:len(h.nodes)-1] 75 | return v 76 | } 77 | 78 | type Node struct { 79 | RowsListIndex int 80 | // 用于排序列 81 | SortColumnValues []any 82 | // 完整的行中的所有列,不用于排序仅用于缓存行数据 83 | ColumnValues []any 84 | } 85 | -------------------------------------------------------------------------------- /internal/operator/operator.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package Operator 16 | 17 | import "github.com/ecodeclub/eorm/internal/errs" 18 | 19 | type Op struct { 20 | Symbol string 21 | Text string 22 | } 23 | 24 | var emptyOp = Op{} 25 | 26 | var ( 27 | OpLT = Op{Symbol: "<", Text: "<"} 28 | OpLTEQ = Op{Symbol: "<=", Text: "<="} 29 | OpGT = Op{Symbol: ">", Text: ">"} 30 | OpGTEQ = Op{Symbol: ">=", Text: ">="} 31 | OpEQ = Op{Symbol: "=", Text: "="} 32 | OpNEQ = Op{Symbol: "!=", Text: "!="} 33 | OpAdd = Op{Symbol: "+", Text: "+"} 34 | // OpIn = Op{Symbol: "IN", Text: " IN "} 35 | // OpMinus = Op{Symbol:"-", Text: "-"} 36 | OpMulti = Op{Symbol: "*", Text: "*"} 37 | // OpDiv = Op{Symbol:"/", Text: "/"} 38 | OpAnd = Op{Symbol: "AND", Text: " AND "} 39 | OpOr = Op{Symbol: "OR", Text: " OR "} 40 | OpNot = Op{Symbol: "NOT", Text: "NOT "} 41 | OpIn = Op{Symbol: "IN", Text: " IN "} 42 | OpNotIN = Op{Symbol: "NOT IN", Text: " NOT IN "} 43 | OpFalse = Op{Symbol: "FALSE", Text: "FALSE"} 44 | OpLike = Op{Symbol: "LIKE", Text: " LIKE "} 45 | OpNotLike = Op{Symbol: "NOT LIKE", Text: " NOT LIKE "} 46 | OpExist = Op{Symbol: "EXIST", Text: "EXIST "} 47 | ) 48 | 49 | func NegateOp(op Op) (Op, error) { 50 | switch op { 51 | case OpNEQ: 52 | return OpEQ, nil 53 | case OpEQ: 54 | return OpNEQ, nil 55 | case OpIn: 56 | return OpNotIN, nil 57 | case OpNotIN: 58 | return OpIn, nil 59 | case OpGT: 60 | return OpLTEQ, nil 61 | case OpLT: 62 | return OpGTEQ, nil 63 | case OpGTEQ: 64 | return OpLT, nil 65 | case OpLTEQ: 66 | return OpGT, nil 67 | default: 68 | return emptyOp, 
errs.NewUnsupportedOperatorError(op.Text) 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /internal/query/query.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package query 16 | 17 | import "fmt" 18 | 19 | type Query struct { 20 | SQL string 21 | Args []any 22 | DB string 23 | Datasource string 24 | } 25 | 26 | func (q Query) String() string { 27 | return fmt.Sprintf("SQL: %s\nArgs: %#v\n", q.SQL, q.Args) 28 | } 29 | 30 | type Feature int 31 | 32 | const ( 33 | AggregateFunc Feature = 1 << iota 34 | GroupBy 35 | Distinct 36 | OrderBy 37 | Limit 38 | ) 39 | -------------------------------------------------------------------------------- /internal/rows/convert_assign.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package rows 16 | 17 | import ( 18 | "database/sql" 19 | "database/sql/driver" 20 | _ "unsafe" 21 | ) 22 | 23 | //go:linkname sqlConvertAssign database/sql.convertAssign 24 | func sqlConvertAssign(dest, src any) error 25 | 26 | func ConvertAssign(dest, src any) error { 27 | srcVal, ok := src.(driver.Valuer) 28 | if ok { 29 | var err error 30 | src, err = srcVal.Value() 31 | if err != nil { 32 | return err 33 | } 34 | } 35 | // 预处理一下 sqlConvertAssign 不支持的转换,遇到一个加一个 36 | switch sv := src.(type) { 37 | case sql.RawBytes: 38 | switch dv := dest.(type) { 39 | case *string: 40 | *dv = string(sv) 41 | return nil 42 | } 43 | } 44 | return sqlConvertAssign(dest, src) 45 | } 46 | -------------------------------------------------------------------------------- /internal/rows/data_rows.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package rows 16 | 17 | import ( 18 | "database/sql" 19 | 20 | "github.com/ecodeclub/eorm/internal/errs" 21 | ) 22 | 23 | var _ Rows = (*DataRows)(nil) 24 | 25 | // DataRows 直接传入数据,伪装成了一个 Rows 26 | // 非线程安全实现 27 | type DataRows struct { 28 | data [][]any 29 | len int 30 | columns []string 31 | columnTypes []*sql.ColumnType 32 | // 第几行 33 | idx int 34 | } 35 | 36 | func (*DataRows) NextResultSet() bool { 37 | return false 38 | } 39 | 40 | func (d *DataRows) ColumnTypes() ([]*sql.ColumnType, error) { 41 | return d.columnTypes, nil 42 | } 43 | 44 | func NewDataRows(data [][]any, columns []string, columnTypes []*sql.ColumnType) *DataRows { 45 | // 这里并没有什么必要检查 data 和 columns 的输入 46 | // 因为只有在很故意的情况下,data 和 columns 才可能会有问题 47 | return &DataRows{ 48 | data: data, 49 | len: len(data), 50 | columns: columns, 51 | idx: -1, 52 | columnTypes: columnTypes, 53 | } 54 | } 55 | 56 | func (d *DataRows) Next() bool { 57 | if d.idx >= d.len-1 { 58 | return false 59 | } 60 | d.idx++ 61 | return true 62 | } 63 | 64 | func (d *DataRows) Scan(dest ...any) error { 65 | // 不需要检测,作为内部代码我们可以预期用户会主动控制 66 | data := d.data[d.idx] 67 | if len(data) != len(dest) { 68 | return errs.NewErrScanWrongDestinationArguments(len(data), len(dest)) 69 | } 70 | for idx, dst := range dest { 71 | if err := ConvertAssign(dst, data[idx]); err != nil { 72 | return err 73 | } 74 | } 75 | return nil 76 | } 77 | 78 | func (*DataRows) Close() error { 79 | return nil 80 | } 81 | 82 | func (d *DataRows) Columns() ([]string, error) { 83 | return d.columns, nil 84 | } 85 | 86 | func (*DataRows) Err() error { 87 | return nil 88 | } 89 | -------------------------------------------------------------------------------- /internal/rows/types.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package rows 16 | 17 | import ( 18 | "github.com/ecodeclub/ekit/sqlx" 19 | ) 20 | 21 | // Rows is an alias of ekit's sqlx.Rows; the usage and semantics of its methods follow sql.Rows as closely as possible. 22 | type Rows = sqlx.Rows 23 | -------------------------------------------------------------------------------- /internal/sharding/compare.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License.
14 | 15 | package sharding 16 | 17 | import "strings" 18 | 19 | func CompareDSDBTab(i, j Dst) int { 20 | strI := strings.Join([]string{i.Name, i.DB, i.Table}, "") 21 | strJ := strings.Join([]string{j.Name, j.DB, j.Table}, "") 22 | if strI < strJ { 23 | return -1 24 | } else if strI == strJ { 25 | return 0 26 | } 27 | return 1 28 | 29 | } 30 | -------------------------------------------------------------------------------- /internal/sharding/hash/shadow_hash.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package hash 16 | 17 | import ( 18 | "context" 19 | "fmt" 20 | "strings" 21 | 22 | "github.com/ecodeclub/eorm/internal/errs" 23 | "github.com/ecodeclub/eorm/internal/sharding" 24 | ) 25 | 26 | // ShadowHash TODO experiemntal 27 | type ShadowHash struct { 28 | *Hash 29 | Prefix string 30 | } 31 | 32 | func (h *ShadowHash) Broadcast(ctx context.Context) []sharding.Dst { 33 | res := make([]sharding.Dst, 0, 8) 34 | for i := 0; i < h.DBPattern.Base; i++ { 35 | dbName := fmt.Sprintf(h.Prefix+h.DBPattern.Name, i) 36 | for j := 0; j < h.TablePattern.Base; j++ { 37 | res = append(res, sharding.Dst{ 38 | Name: h.DsPattern.Name, 39 | DB: dbName, 40 | Table: fmt.Sprintf(h.Prefix+h.TablePattern.Name, j), 41 | }) 42 | } 43 | } 44 | return res 45 | } 46 | 47 | func (h *ShadowHash) Sharding(ctx context.Context, req sharding.Request) (sharding.Response, error) { 48 | if h.ShardingKey == "" { 49 | return sharding.EmptyResp, errs.ErrMissingShardingKey 50 | } 51 | skVal, ok := req.SkValues[h.ShardingKey] 52 | if !ok { 53 | return sharding.Response{Dsts: h.Broadcast(ctx)}, nil 54 | } 55 | dbName := h.DBPattern.Name 56 | if !h.DBPattern.NotSharding && strings.Contains(dbName, "%d") { 57 | dbName = fmt.Sprintf(dbName, skVal.(int)%h.DBPattern.Base) 58 | } 59 | tbName := h.TablePattern.Name 60 | if !h.TablePattern.NotSharding && strings.Contains(tbName, "%d") { 61 | tbName = fmt.Sprintf(tbName, skVal.(int)%h.TablePattern.Base) 62 | } 63 | dsName := h.DsPattern.Name 64 | if !h.DsPattern.NotSharding && strings.Contains(dsName, "%d") { 65 | dsName = fmt.Sprintf(dsName, skVal.(int)%h.DsPattern.Base) 66 | } 67 | if isSourceKey(ctx) { 68 | dsName = h.Prefix + dsName 69 | } 70 | if isDBKey(ctx) { 71 | dbName = h.Prefix + dbName 72 | } 73 | 74 | if isTableKey(ctx) { 75 | tbName = h.Prefix + tbName 76 | } 77 | 78 | return sharding.Response{ 79 | Dsts: []sharding.Dst{{Name: dsName, DB: dbName, Table: tbName}}, 80 | }, nil 81 | } 82 | 83 | type sourceKey struct{} 84 | 85 | type dbKey 
struct{} 86 | 87 | type tableKey struct{} 88 | 89 | func CtxWithTableKey(ctx context.Context) context.Context { 90 | return context.WithValue(ctx, tableKey{}, true) 91 | } 92 | 93 | func CtxWithDBKey(ctx context.Context) context.Context { 94 | return context.WithValue(ctx, dbKey{}, true) 95 | } 96 | 97 | func CtxWithSourceKey(ctx context.Context) context.Context { 98 | return context.WithValue(ctx, sourceKey{}, true) 99 | } 100 | 101 | func isSourceKey(ctx context.Context) bool { 102 | return ctx.Value(sourceKey{}) != nil 103 | } 104 | 105 | func isDBKey(ctx context.Context) bool { 106 | return ctx.Value(dbKey{}) != nil 107 | } 108 | 109 | func isTableKey(ctx context.Context) bool { 110 | return ctx.Value(tableKey{}) != nil 111 | } 112 | -------------------------------------------------------------------------------- /internal/sharding/result.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package sharding 16 | 17 | import "database/sql" 18 | 19 | type Result struct { 20 | err error 21 | res []sql.Result 22 | } 23 | 24 | func (r Result) Err() error { 25 | return r.err 26 | } 27 | 28 | func (r Result) LastInsertId() (int64, error) { 29 | if r.err != nil { 30 | return 0, r.err 31 | } 32 | return r.res[len(r.res)-1].LastInsertId() 33 | } 34 | func (r Result) RowsAffected() (int64, error) { 35 | if r.err != nil { 36 | return 0, r.err 37 | } 38 | var sum int64 39 | for _, i := range r.res { 40 | n, err := i.RowsAffected() 41 | if err != nil { 42 | return 0, err 43 | } 44 | sum += n 45 | } 46 | return sum, nil 47 | } 48 | 49 | func NewResult(res []sql.Result, err error) Result { 50 | return Result{res: res, err: err} 51 | } 52 | -------------------------------------------------------------------------------- /internal/sharding/types.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package sharding 16 | 17 | import ( 18 | "context" 19 | 20 | operator "github.com/ecodeclub/eorm/internal/operator" 21 | "github.com/ecodeclub/eorm/internal/query" 22 | ) 23 | 24 | var EmptyResp = Response{} 25 | var EmptyQuery = Query{} 26 | 27 | type Algorithm interface { 28 | // Sharding 返回分库分表之后目标库和目标表信息 29 | Sharding(ctx context.Context, req Request) (Response, error) 30 | // Broadcast 返回所有的目标库、目标表 31 | Broadcast(ctx context.Context) []Dst 32 | // ShardingKeys 返回所有的 sharding key 33 | // 这部分不包含任何放在 context.Context 中的部分,例如 shadow 标记位等 34 | // 或者说,它只是指数据库中用于分库分表的列 35 | ShardingKeys() []string 36 | } 37 | 38 | // Executor sql 语句执行器 39 | type Executor interface { 40 | Exec(ctx context.Context) Result 41 | } 42 | 43 | // QueryBuilder sharding sql 构造抽象 44 | type QueryBuilder interface { 45 | Build(ctx context.Context) ([]Query, error) 46 | } 47 | 48 | type Query = query.Query 49 | 50 | type Dst struct { 51 | // Name 数据源的逻辑名字 52 | Name string 53 | DB string 54 | Table string 55 | } 56 | 57 | func (r Dst) Equals(l Dst) bool { 58 | return r.Name == l.Name && r.DB == l.DB && r.Table == l.Table 59 | } 60 | 61 | func (r Dst) NotEquals(l Dst) bool { 62 | return r.Name != l.Name || r.DB != l.DB || r.Table != l.Table 63 | } 64 | 65 | type Request struct { 66 | Op operator.Op 67 | SkValues map[string]any 68 | } 69 | 70 | type Response struct { 71 | Dsts []Dst 72 | } 73 | -------------------------------------------------------------------------------- /internal/test/types_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package test 16 | 17 | import ( 18 | "testing" 19 | 20 | "github.com/ecodeclub/ekit" 21 | "github.com/stretchr/testify/assert" 22 | ) 23 | 24 | func TestJsonColumn_Scan(t *testing.T) { 25 | type User struct { 26 | Name string 27 | } 28 | testCases := []struct { 29 | name string 30 | input any 31 | wantVal User 32 | wantErr string 33 | }{ 34 | { 35 | name: "empty string", 36 | input: ``, 37 | }, 38 | { 39 | name: "no fields", 40 | input: `{}`, 41 | wantVal: User{}, 42 | }, 43 | { 44 | name: "string", 45 | input: `{"name":"Tom"}`, 46 | wantVal: User{Name: "Tom"}, 47 | }, 48 | { 49 | name: "nil bytes", 50 | input: []byte(nil), 51 | }, 52 | { 53 | name: "empty bytes", 54 | input: []byte(""), 55 | }, 56 | { 57 | name: "bytes", 58 | input: []byte(`{"name":"Tom"}`), 59 | wantVal: User{Name: "Tom"}, 60 | }, 61 | { 62 | name: "nil", 63 | }, 64 | { 65 | name: "empty bytes ptr", 66 | input: ekit.ToPtr[[]byte]([]byte("")), 67 | }, 68 | { 69 | name: "bytes ptr", 70 | input: ekit.ToPtr[[]byte]([]byte(`{"name":"Tom"}`)), 71 | wantVal: User{Name: "Tom"}, 72 | }, 73 | } 74 | 75 | for _, tc := range testCases { 76 | t.Run(tc.name, func(t *testing.T) { 77 | js := &JsonColumn{} 78 | err := js.Scan(tc.input) 79 | if tc.wantErr != "" { 80 | assert.EqualError(t, err, tc.wantErr) 81 | return 82 | } else { 83 | assert.Nil(t, err) 84 | } 85 | _, err = js.Value() 86 | assert.Nil(t, err) 87 | assert.EqualValues(t, tc.wantVal, js.Val) 88 | }) 89 | } 90 | } 91 | 
-------------------------------------------------------------------------------- /internal/valuer/primitive.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package valuer 16 | 17 | import ( 18 | "database/sql" 19 | "reflect" 20 | 21 | "github.com/ecodeclub/eorm/internal/rows" 22 | 23 | "github.com/ecodeclub/eorm/internal/model" 24 | ) 25 | 26 | // primitiveValue 支持基本类型 Value 27 | type primitiveValue struct { 28 | Value 29 | val any 30 | valType reflect.Type 31 | } 32 | 33 | // Field 返回字段值 34 | func (s primitiveValue) Field(name string) (reflect.Value, error) { 35 | return s.Value.Field(name) 36 | } 37 | 38 | // SetColumns 设置列值, 支持基本类型,基于 reflect 与 unsafe Value 封装 39 | func (s primitiveValue) SetColumns(rows rows.Rows) error { 40 | switch s.valType.Elem().Kind() { 41 | case reflect.Struct: 42 | if scanner, ok := s.val.(sql.Scanner); ok { 43 | return rows.Scan(scanner) 44 | } 45 | return s.Value.SetColumns(rows) 46 | default: 47 | return rows.Scan(s.val) 48 | } 49 | } 50 | 51 | // PrimitiveCreator 支持基本类型的 Creator, 基于原生的 Creator 扩展 52 | type PrimitiveCreator struct { 53 | Creator 54 | } 55 | 56 | // NewPrimitiveValue 返回一个封装好的,基于支持基本类型实现的 Value 57 | // 输入 val 必须是一个指向结构体实例的指针,而不能是任何其它类型 58 | func (c PrimitiveCreator) NewPrimitiveValue(val any, meta *model.TableMeta) Value { 59 | return 
primitiveValue{ 60 | val: val, 61 | Value: c.Creator(val, meta), 62 | valType: reflect.TypeOf(val), 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /internal/valuer/reflect.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package valuer 16 | 17 | import ( 18 | "reflect" 19 | 20 | "github.com/ecodeclub/eorm/internal/rows" 21 | 22 | "github.com/ecodeclub/eorm/internal/errs" 23 | "github.com/ecodeclub/eorm/internal/model" 24 | ) 25 | 26 | var _ Creator = NewReflectValue 27 | 28 | // reflectValue 基于反射的 Value 29 | type reflectValue struct { 30 | val reflect.Value 31 | meta *model.TableMeta 32 | } 33 | 34 | // NewReflectValue 返回一个封装好的,基于反射实现的 Value 35 | // 输入 val 必须是一个指向结构体实例的指针,而不能是任何其它类型 36 | func NewReflectValue(val interface{}, meta *model.TableMeta) Value { 37 | return reflectValue{ 38 | val: reflect.ValueOf(val).Elem(), 39 | meta: meta, 40 | } 41 | } 42 | 43 | // Field 返回字段值 44 | func (r reflectValue) Field(name string) (reflect.Value, error) { 45 | res, ok := r.fieldByIndex(name) 46 | if !ok { 47 | return reflect.Value{}, errs.NewInvalidFieldError(name) 48 | } 49 | return res, nil 50 | } 51 | 52 | func (r reflectValue) fieldByIndex(name string) (reflect.Value, bool) { 53 | cm, ok := r.meta.FieldMap[name] 54 | if !ok { 55 | return 
reflect.Value{}, false 56 | } 57 | value := r.val 58 | for _, i := range cm.FieldIndexes { 59 | value = value.Field(i) 60 | } 61 | return value, true 62 | } 63 | 64 | func (r reflectValue) SetColumns(rows rows.Rows) error { 65 | cs, err := rows.Columns() 66 | if err != nil { 67 | return err 68 | } 69 | if len(cs) > len(r.meta.Columns) { 70 | return errs.ErrTooManyColumns 71 | } 72 | 73 | // TODO 性能优化 74 | // colValues 和 colEleValues 实质上最终都指向同一个对象 75 | colValues := make([]interface{}, len(cs)) 76 | colEleValues := make([]reflect.Value, len(cs)) 77 | 78 | for i, c := range cs { 79 | cm, ok := r.meta.ColumnMap[c] 80 | if !ok { 81 | return errs.NewInvalidColumnError(c) 82 | } 83 | val := reflect.New(cm.Typ) 84 | colValues[i] = val.Interface() 85 | colEleValues[i] = val.Elem() 86 | } 87 | 88 | if err = rows.Scan(colValues...); err != nil { 89 | return err 90 | } 91 | 92 | for i, c := range cs { 93 | cm := r.meta.ColumnMap[c] 94 | fd, _ := r.fieldByIndex(cm.FieldName) 95 | fd.Set(colEleValues[i]) 96 | } 97 | return nil 98 | } 99 | -------------------------------------------------------------------------------- /internal/valuer/unsafe.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package valuer 16 | 17 | import ( 18 | "reflect" 19 | "unsafe" 20 | 21 | "github.com/ecodeclub/eorm/internal/rows" 22 | 23 | "github.com/ecodeclub/eorm/internal/errs" 24 | "github.com/ecodeclub/eorm/internal/model" 25 | ) 26 | 27 | var _ Creator = NewUnsafeValue 28 | 29 | type unsafeValue struct { 30 | val reflect.Value 31 | addr unsafe.Pointer 32 | meta *model.TableMeta 33 | } 34 | 35 | func NewUnsafeValue(val interface{}, meta *model.TableMeta) Value { 36 | refVal := reflect.ValueOf(val) 37 | return unsafeValue{ 38 | meta: meta, 39 | val: refVal.Elem(), 40 | addr: unsafe.Pointer(refVal.Pointer()), 41 | } 42 | } 43 | 44 | func (u unsafeValue) Field(name string) (reflect.Value, error) { 45 | fd, ok := u.meta.FieldMap[name] 46 | if !ok { 47 | return reflect.Value{}, errs.NewInvalidFieldError(name) 48 | } 49 | ptr := unsafe.Pointer(uintptr(u.addr) + fd.Offset) 50 | val := reflect.NewAt(fd.Typ, ptr).Elem() 51 | return val, nil 52 | } 53 | 54 | func (u unsafeValue) SetColumns(rows rows.Rows) error { 55 | 56 | cs, err := rows.Columns() 57 | if err != nil { 58 | return err 59 | } 60 | if len(cs) > len(u.meta.Columns) { 61 | return errs.ErrTooManyColumns 62 | } 63 | 64 | // TODO 性能优化 65 | colValues := make([]interface{}, len(cs)) 66 | for i, c := range cs { 67 | cm, ok := u.meta.ColumnMap[c] 68 | if !ok { 69 | return errs.NewInvalidColumnError(c) 70 | } 71 | ptr := unsafe.Pointer(uintptr(u.addr) + cm.Offset) 72 | val := reflect.NewAt(cm.Typ, ptr) 73 | colValues[i] = val.Interface() 74 | } 75 | return rows.Scan(colValues...) 76 | } 77 | -------------------------------------------------------------------------------- /internal/valuer/value.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package valuer 16 | 17 | import ( 18 | "reflect" 19 | 20 | "github.com/ecodeclub/eorm/internal/rows" 21 | 22 | "github.com/ecodeclub/eorm/internal/model" 23 | ) 24 | 25 | // Value is the internal abstraction over a struct instance: it reads the struct's fields and fills the struct from query results. 26 | type Value interface { 27 | // Field returns the struct field named name (name is the Go field name, not the column name). 28 | Field(name string) (reflect.Value, error) 29 | // SetColumns writes the current row of rows into the struct; columns are matched by column name. 30 | // Note: val may be reused by upper layers, so its contents can be tampered with unexpectedly. 31 | // Previous signature, kept for reference: SetColumns(rows *sql.Rows) error 32 | 33 | SetColumns(rows rows.Rows) error 34 | } 35 | // Creator builds a Value that wraps val, described by the table metadata meta. 36 | type Creator func(val any, meta *model.TableMeta) Value 37 | -------------------------------------------------------------------------------- /middleware.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License.
14 | 15 | package eorm 16 | 17 | import ( 18 | "context" 19 | 20 | "github.com/ecodeclub/eorm/internal/model" 21 | ) 22 | 23 | type QueryContext struct { 24 | Type string 25 | meta *model.TableMeta 26 | q Query 27 | } 28 | 29 | func (qc *QueryContext) GetQuery() Query { 30 | return qc.q 31 | } 32 | 33 | type QueryResult struct { 34 | Result any 35 | Err error 36 | } 37 | 38 | type Middleware func(next HandleFunc) HandleFunc 39 | 40 | type HandleFunc func(ctx context.Context, queryContext *QueryContext) *QueryResult 41 | -------------------------------------------------------------------------------- /middleware/querylog/querylog.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package querylog 16 | 17 | import ( 18 | "context" 19 | "log" 20 | 21 | "github.com/ecodeclub/eorm" 22 | ) 23 | 24 | type MiddlewareBuilder struct { 25 | logFunc func(sql string, args ...any) 26 | } 27 | 28 | func NewBuilder() *MiddlewareBuilder { 29 | return &MiddlewareBuilder{ 30 | logFunc: func(sql string, args ...any) { 31 | log.Println(sql, args) 32 | }, 33 | } 34 | 35 | } 36 | 37 | func (b *MiddlewareBuilder) LogFunc(logFunc func(sql string, args ...any)) *MiddlewareBuilder { 38 | b.logFunc = logFunc 39 | return b 40 | } 41 | 42 | func (b *MiddlewareBuilder) Build() eorm.Middleware { 43 | return func(next eorm.HandleFunc) eorm.HandleFunc { 44 | return func(ctx context.Context, queryContext *eorm.QueryContext) *eorm.QueryResult { 45 | query := queryContext.GetQuery() 46 | b.logFunc(query.SQL, query.Args...) 47 | return next(ctx, queryContext) 48 | } 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /middleware/querylog/querylog_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package querylog 16 | 17 | import ( 18 | "context" 19 | "database/sql" 20 | "errors" 21 | "fmt" 22 | "strings" 23 | "testing" 24 | 25 | "github.com/ecodeclub/eorm" 26 | _ "github.com/mattn/go-sqlite3" 27 | "github.com/stretchr/testify/assert" 28 | ) 29 | 30 | func TestMiddlewareBuilder_Build(t *testing.T) { 31 | testCases := []struct { 32 | name string 33 | mdls []eorm.Middleware 34 | builder *testMiddlewareBuilder 35 | wantVal string 36 | wantErr error 37 | }{ 38 | { 39 | name: "default", 40 | builder: &testMiddlewareBuilder{ 41 | MiddlewareBuilder: NewBuilder(), 42 | printVal: strings.Builder{}, 43 | }, 44 | mdls: []eorm.Middleware{}, 45 | }, 46 | { 47 | name: "output args", 48 | builder: func() *testMiddlewareBuilder { 49 | b := &testMiddlewareBuilder{ 50 | MiddlewareBuilder: NewBuilder(), 51 | printVal: strings.Builder{}, 52 | } 53 | logfunc := func(sql string, args ...any) { 54 | fmt.Println(sql, args) 55 | b.printVal.WriteString(sql) 56 | } 57 | b.LogFunc(logfunc) 58 | return b 59 | }(), 60 | mdls: []eorm.Middleware{}, 61 | wantVal: "SELECT `id`,`first_name`,`age`,`last_name` FROM `test_model` LIMIT ?;", 62 | }, 63 | { 64 | name: "not args", 65 | builder: &testMiddlewareBuilder{ 66 | printVal: strings.Builder{}, 67 | MiddlewareBuilder: NewBuilder().LogFunc(func(sql string, args ...any) { 68 | fmt.Println(sql) 69 | }), 70 | }, 71 | mdls: []eorm.Middleware{}, 72 | wantVal: "SELECT `id`,`first_name`,`age`,`last_name` FROM `test_model` LIMIT ?;", 73 | }, 74 | { 75 | name: "interrupt err", 76 | builder: &testMiddlewareBuilder{ 77 | printVal: strings.Builder{}, 78 | MiddlewareBuilder: NewBuilder().LogFunc(func(sql string, args ...any) { 79 | fmt.Println(sql) 80 | }), 81 | }, 82 | mdls: func() []eorm.Middleware { 83 | var interrupt eorm.Middleware = func(next eorm.HandleFunc) eorm.HandleFunc { 84 | return func(ctx context.Context, qc *eorm.QueryContext) *eorm.QueryResult { 85 | return &eorm.QueryResult{ 86 | Err: errors.New("interrupt execution"), 87 | } 
88 | } 89 | } 90 | return []eorm.Middleware{interrupt} 91 | }(), 92 | wantErr: errors.New("interrupt execution"), 93 | }, 94 | } 95 | 96 | for _, tc := range testCases { 97 | t.Run(tc.name, func(t *testing.T) { 98 | mdls := tc.mdls 99 | mdls = append(mdls, tc.builder.Build()) 100 | orm, err := eorm.Open("sqlite3", 101 | "file:test.db?cache=shared&mode=memory", 102 | eorm.DBWithMiddlewares(mdls...)) 103 | if err != nil { 104 | t.Fatal(err) 105 | } 106 | defer func() { 107 | _ = orm.Close() 108 | }() 109 | _, err = eorm.NewSelector[TestModel](orm).Get(context.Background()) 110 | if err.Error() == "no such table: test_model" { 111 | return 112 | } 113 | if err != nil { 114 | assert.Equal(t, tc.wantErr, err) 115 | return 116 | } 117 | assert.Equal(t, tc.wantVal, tc.builder.printVal.String()) 118 | }) 119 | } 120 | 121 | } 122 | 123 | type testMiddlewareBuilder struct { 124 | *MiddlewareBuilder 125 | printVal strings.Builder 126 | } 127 | 128 | type TestModel struct { 129 | Id int64 `eorm:"primary_key"` 130 | FirstName string 131 | Age int8 132 | LastName *sql.NullString 133 | } 134 | -------------------------------------------------------------------------------- /middleware_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package eorm 16 | 17 | import ( 18 | "context" 19 | "errors" 20 | "testing" 21 | 22 | "github.com/stretchr/testify/require" 23 | 24 | "github.com/stretchr/testify/assert" 25 | ) 26 | 27 | func Test_Middleware(t *testing.T) { 28 | testCases := []struct { 29 | name string 30 | wantErr error 31 | mdls []Middleware 32 | }{ 33 | { 34 | name: "one middleware", 35 | mdls: func() []Middleware { 36 | var mdl Middleware = func(next HandleFunc) HandleFunc { 37 | return func(ctx context.Context, queryContext *QueryContext) *QueryResult { 38 | return &QueryResult{} 39 | } 40 | } 41 | return []Middleware{mdl} 42 | }(), 43 | }, 44 | { 45 | name: "many middleware", 46 | mdls: func() []Middleware { 47 | mdl1 := func(next HandleFunc) HandleFunc { 48 | return func(ctx context.Context, queryContext *QueryContext) *QueryResult { 49 | return &QueryResult{Result: "mdl1"} 50 | } 51 | } 52 | mdl2 := func(next HandleFunc) HandleFunc { 53 | return func(ctx context.Context, queryContext *QueryContext) *QueryResult { 54 | return &QueryResult{Result: "mdl2"} 55 | } 56 | } 57 | return []Middleware{mdl1, mdl2} 58 | }(), 59 | }, 60 | } 61 | for _, tc := range testCases { 62 | t.Run(tc.name, func(t *testing.T) { 63 | db, err := Open("sqlite3", "file:test.db?cache=shared&mode=memory", 64 | DBWithMiddlewares(tc.mdls...)) 65 | if err != nil { 66 | t.Error(err) 67 | } 68 | defer func() { 69 | _ = db.Close() 70 | }() 71 | assert.EqualValues(t, tc.mdls, db.ms) 72 | }) 73 | } 74 | } 75 | 76 | func Test_Middleware_order(t *testing.T) { 77 | var res []byte 78 | var mdl1 Middleware = func(next HandleFunc) HandleFunc { 79 | return func(ctx context.Context, qc *QueryContext) *QueryResult { 80 | res = append(res, '1') 81 | return next(ctx, qc) 82 | } 83 | } 84 | var mdl2 Middleware = func(next HandleFunc) HandleFunc { 85 | return func(ctx context.Context, qc *QueryContext) *QueryResult { 86 | res = append(res, '2') 87 | return next(ctx, qc) 88 | } 89 | } 90 | 91 | var mdl3 Middleware = func(next 
HandleFunc) HandleFunc { 92 | return func(ctx context.Context, qc *QueryContext) *QueryResult { 93 | res = append(res, '3') 94 | return next(ctx, qc) 95 | } 96 | } 97 | var last Middleware = func(next HandleFunc) HandleFunc { 98 | return func(ctx context.Context, qc *QueryContext) *QueryResult { 99 | return &QueryResult{ 100 | Err: errors.New("mock error"), 101 | } 102 | } 103 | } 104 | db, err := Open("sqlite3", "file:test.db?cache=shared&mode=memory", 105 | DBWithMiddlewares(mdl1, mdl2, mdl3, last)) 106 | require.NoError(t, err) 107 | 108 | _, err = NewSelector[TestModel](db).Get(context.Background()) 109 | assert.Equal(t, errors.New("mock error"), err) 110 | assert.Equal(t, "123", string(res)) 111 | 112 | } 113 | 114 | func TestQueryContext(t *testing.T) { 115 | testCases := []struct { 116 | name string 117 | wantErr error 118 | q Query 119 | qc *QueryContext 120 | }{ 121 | { 122 | name: "one middleware", 123 | q: Query{ 124 | SQL: `SELECT * FROM user_tab WHERE id = ?;`, 125 | Args: []any{1}, 126 | }, 127 | qc: &QueryContext{ 128 | q: Query{ 129 | SQL: `SELECT * FROM user_tab WHERE id = ?;`, 130 | Args: []any{1}, 131 | }, 132 | }, 133 | }, 134 | } 135 | for _, tc := range testCases { 136 | t.Run(tc.name, func(t *testing.T) { 137 | assert.EqualValues(t, tc.q, tc.qc.GetQuery()) 138 | }) 139 | } 140 | } 141 | -------------------------------------------------------------------------------- /predicate.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package eorm 16 | 17 | import operator "github.com/ecodeclub/eorm/internal/operator" 18 | 19 | // type op Operator.Op 20 | var ( 21 | opLT = operator.OpLT 22 | opLTEQ = operator.OpLTEQ 23 | opGT = operator.OpGT 24 | opGTEQ = operator.OpGTEQ 25 | opEQ = operator.OpEQ 26 | opNEQ = operator.OpNEQ 27 | opAdd = operator.OpAdd 28 | opMulti = operator.OpMulti 29 | opAnd = operator.OpAnd 30 | opOr = operator.OpOr 31 | opNot = operator.OpNot 32 | opIn = operator.OpIn 33 | opNotIN = operator.OpNotIN 34 | opFalse = operator.OpFalse 35 | opLike = operator.OpLike 36 | opNotLike = operator.OpNotLike 37 | opExist = operator.OpExist 38 | ) 39 | 40 | // Predicate will be used in Where Or Having 41 | type Predicate binaryExpr 42 | 43 | var emptyPredicate = Predicate{} 44 | 45 | func (Predicate) expr() (string, error) { 46 | return "", nil 47 | } 48 | 49 | // Exist indicates "Exist" 50 | func Exist(sub Subquery) Predicate { 51 | return Predicate{ 52 | op: opExist, 53 | right: sub, 54 | } 55 | } 56 | 57 | // Not indicates "NOT" 58 | func Not(p Predicate) Predicate { 59 | return Predicate{ 60 | left: Raw(""), 61 | op: opNot, 62 | right: p, 63 | } 64 | } 65 | 66 | // And indicates "AND" 67 | func (p Predicate) And(pred Predicate) Predicate { 68 | return Predicate{ 69 | left: p, 70 | op: opAnd, 71 | right: pred, 72 | } 73 | } 74 | 75 | // Or indicates "OR" 76 | func (p Predicate) Or(pred Predicate) Predicate { 77 | return Predicate{ 78 | left: p, 79 | op: opOr, 80 | right: pred, 81 | } 82 | } 83 | 
-------------------------------------------------------------------------------- /result.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package eorm 16 | 17 | import "database/sql" 18 | 19 | type Result struct { 20 | err error 21 | res sql.Result 22 | } 23 | 24 | func (r Result) Err() error { 25 | return r.err 26 | } 27 | 28 | func (r Result) LastInsertId() (int64, error) { 29 | if r.err != nil { 30 | return 0, r.err 31 | } 32 | return r.res.LastInsertId() 33 | } 34 | 35 | func (r Result) RowsAffected() (int64, error) { 36 | if r.err != nil { 37 | return 0, r.err 38 | } 39 | return r.res.RowsAffected() 40 | } 41 | -------------------------------------------------------------------------------- /result_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package eorm 16 | 17 | import ( 18 | "errors" 19 | "testing" 20 | 21 | "github.com/DATA-DOG/go-sqlmock" 22 | "github.com/stretchr/testify/assert" 23 | ) 24 | 25 | func TestResult_RowsAffected(t *testing.T) { 26 | testCases := []struct { 27 | name string 28 | res Result 29 | wantAffected int64 30 | wantErr error 31 | }{ 32 | { 33 | name: "err", 34 | wantErr: errors.New("exec err"), 35 | res: Result{err: errors.New("exec err")}, 36 | }, 37 | { 38 | name: "unknown error", 39 | wantErr: errors.New("unknown error"), 40 | res: Result{res: sqlmock.NewErrorResult(errors.New("unknown error"))}, 41 | }, 42 | { 43 | name: "no err", 44 | wantAffected: int64(234), 45 | res: Result{res: sqlmock.NewResult(123, 234)}, 46 | }, 47 | } 48 | for _, tc := range testCases { 49 | t.Run(tc.name, func(t *testing.T) { 50 | affected, err := tc.res.RowsAffected() 51 | assert.Equal(t, tc.wantErr, err) 52 | if err != nil { 53 | return 54 | } 55 | assert.Equal(t, tc.wantAffected, affected) 56 | }) 57 | } 58 | } 59 | 60 | func TestResult_LastInsertId(t *testing.T) { 61 | testCases := []struct { 62 | name string 63 | res Result 64 | wantLastId int64 65 | wantErr error 66 | }{ 67 | { 68 | name: "err", 69 | wantErr: errors.New("exec err"), 70 | res: Result{err: errors.New("exec err")}, 71 | }, 72 | { 73 | name: "res err", 74 | wantErr: errors.New("exec err"), 75 | res: Result{res: sqlmock.NewErrorResult(errors.New("exec err"))}, 76 | }, 77 | { 78 | name: "no err", 79 | wantLastId: int64(123), 80 | res: Result{res: sqlmock.NewResult(123, 
234)}, 81 | }, 82 | } 83 | for _, tc := range testCases { 84 | t.Run(tc.name, func(t *testing.T) { 85 | id, err := tc.res.LastInsertId() 86 | assert.Equal(t, tc.wantErr, err) 87 | if err != nil { 88 | return 89 | } 90 | assert.Equal(t, tc.wantLastId, id) 91 | }) 92 | } 93 | } 94 | -------------------------------------------------------------------------------- /script/fmt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # shellcheck disable=SC2044 4 | for item in $(find . -type f -name '*.go' -not -path './.idea/*'); do 5 | goimports -l -w "$item"; 6 | done -------------------------------------------------------------------------------- /script/integrate_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | docker compose -f script/integration_test_compose.yml down 5 | docker compose -f script/integration_test_compose.yml up -d 6 | #sudo echo "127.0.0.1 slave.a.com" >> /etc/hosts 7 | go test -timeout=30m -race ./... -tags=e2e 8 | docker compose -f script/integration_test_compose.yml down 9 | -------------------------------------------------------------------------------- /script/integration_test_compose.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2021 ecodeclub 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 

version: "3.0"
services:
  mysql8:
    image: mysql:8.0.29
    restart: always
    command: --default-authentication-plugin=mysql_native_password
    environment:
      MYSQL_ROOT_PASSWORD: root
    volumes:
      - ./mysql/init.sql:/docker-entrypoint-initdb.d/init.sql
    ports:
      - "13306:3306"
  master:
    image: mysql:8.0.29
    ports:
      - '13307:3306'
    restart: always
    hostname: mysql-master
    environment:
      MYSQL_ROOT_PASSWORD: root
      MASTER_SYNC_USER: "sync"
      MASTER_SYNC_PASSWORD: "123456"
      ADMIN_USER: "root"
      ADMIN_PASSWORD: "root"
    command:
      - "--server-id=1"
      - "--character-set-server=utf8mb4"
      - "--collation-server=utf8mb4_unicode_ci"
      - "--log-bin=mysql-bin"
      - "--sync_binlog=1"
      - "--binlog-ignore-db=mysql"
      - "--binlog-ignore-db=sys"
      - "--binlog-ignore-db=performance_schema"
      - "--binlog-ignore-db=information_schema"
      - "--sql_mode=NO_AUTO_VALUE_ON_ZERO,STRICT_TRANS_TABLES,NO_ZERO_IN_DATE,NO_ZERO_DATE,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION,PIPES_AS_CONCAT,ANSI_QUOTES"
    volumes:
      - ./mysql/master:/docker-entrypoint-initdb.d/
      - ./mysql/init.sql:/docker-entrypoint-initdb.d/init.sql
    # Healthy only once the real server accepts TCP connections, i.e. after
    # the init scripts (including master.sh, which creates the replication
    # user) have finished. The init-time temporary server does not listen on
    # TCP, so it does not satisfy this check.
    healthcheck:
      test: ["CMD", "mysqladmin", "ping", "-h", "127.0.0.1", "-uroot", "-proot"]
      interval: 5s
      timeout: 5s
      retries: 30
  slave:
    image: mysql:8.0.29
    container_name: mysql-slave
    ports:
      - '13308:3306'
    restart: always
    hostname: mysql-slave
    # slave.sh queries the master during initialisation (SHOW MASTER STATUS),
    # so the master must be fully up first. The `condition` form requires
    # Compose v2 (`docker compose`), which the scripts use.
    depends_on:
      master:
        condition: service_healthy
    environment:
      MYSQL_ROOT_PASSWORD: "root"
      SLAVE_SYNC_USER: "sync"
      SLAVE_SYNC_PASSWORD: "123456"
      ADMIN_USER: "root"
      ADMIN_PASSWORD: "root"
      MASTER_HOST: "mysql-master"
    command:
      - "--server-id=2"
      - "--character-set-server=utf8mb4"
      - "--collation-server=utf8mb4_unicode_ci"
      - "--sql_mode=NO_AUTO_VALUE_ON_ZERO,STRICT_TRANS_TABLES,NO_ZERO_IN_DATE,NO_ZERO_DATE,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION,PIPES_AS_CONCAT,ANSI_QUOTES"
    volumes:
      - ./mysql/slave:/docker-entrypoint-initdb.d/
      - ./mysql/init.sql:/docker-entrypoint-initdb.d/init.sql
#  mysql5:
#    image: mysql:5.7.38
#    restart: always
#    environment:
#      MYSQL_ROOT_PASSWORD: root
#    volumes:
#      - ./init.sql:/script/sql/mysql.sh
#    ports:
#      - "13307:3306"

--------------------------------------------------------------------------------
/script/mysql/master/init.sql:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ecodeclub/eorm/f8fa86c70e9e1e3941653032663d07c1798085d8/script/mysql/master/init.sql
--------------------------------------------------------------------------------
/script/mysql/master/master.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# User name of the account replicas use to replicate from this master.
MASTER_SYNC_USER=${MASTER_SYNC_USER:-sync_admin}
# Password of the replication account.
MASTER_SYNC_PASSWORD=${MASTER_SYNC_PASSWORD:-123456}
# Account used to log in to MySQL and run the statements below.
ADMIN_USER=${ADMIN_USER:-root}
# Password of ADMIN_USER.
ADMIN_PASSWORD=${ADMIN_PASSWORD:-root}
# Host pattern the replication account is allowed to connect from.
ALLOW_HOST=${ALLOW_HOST:-%}
# SQL that creates the replication account.
CREATE_USER_SQL="CREATE USER '$MASTER_SYNC_USER'@'$ALLOW_HOST' IDENTIFIED BY '$MASTER_SYNC_PASSWORD';"
# SQL that grants the replication privileges: REPLICATION SLAVE is the
# privilege a replica needs to replicate; REPLICATION CLIENT lets the account
# run SHOW MASTER STATUS.
GRANT_PRIVILEGES_SQL="GRANT SELECT,REPLICATION SLAVE,REPLICATION CLIENT ON *.* TO '$MASTER_SYNC_USER'@'$ALLOW_HOST';"
# SQL that reloads the privilege tables.
FLUSH_PRIVILEGES_SQL="FLUSH PRIVILEGES;"
# Run all of the above in one client session.
mysql -u"$ADMIN_USER" -p"$ADMIN_PASSWORD" -e "$CREATE_USER_SQL $GRANT_PRIVILEGES_SQL $FLUSH_PRIVILEGES_SQL"
--------------------------------------------------------------------------------
/script/mysql/slave/init.sql:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ecodeclub/eorm/f8fa86c70e9e1e3941653032663d07c1798085d8/script/mysql/slave/init.sql
--------------------------------------------------------------------------------
/script/mysql/slave/slave.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Replication account on the master.
SLAVE_SYNC_USER="${SLAVE_SYNC_USER:-sync_admin}"
# Password of the replication account.
SLAVE_SYNC_PASSWORD="${SLAVE_SYNC_PASSWORD:-123456}"
# Admin account of this (replica) server.
ADMIN_USER="${ADMIN_USER:-root}"
# Password of the replica admin account.
ADMIN_PASSWORD="${ADMIN_PASSWORD:-root}"
# Host of the master to replicate from.
MASTER_HOST="${MASTER_HOST:-%}"

# Read the master's current binlog file and position with SHOW MASTER STATUS
# (the replication account needs REPLICATION CLIENT for this). The master may
# still be initialising when this runs, so retry for about a minute instead of
# issuing CHANGE MASTER with an empty file/position, which is invalid SQL.
# Running the query inside `if` also keeps a failed attempt from aborting a
# caller that runs under `set -e`.
RESULT=""
for attempt in {1..30}; do
    if RESULT=$(mysql -u"$SLAVE_SYNC_USER" -h"$MASTER_HOST" -p"$SLAVE_SYNC_PASSWORD" -e "SHOW MASTER STATUS;" | tail -n +2 | awk '{print $1,$2}') \
        && [ -n "$RESULT" ]; then
        break
    fi
    echo "slave.sh: master '$MASTER_HOST' not ready (attempt $attempt), retrying..." >&2
    sleep 2
done
if [ -z "$RESULT" ]; then
    echo "slave.sh: could not read binlog coordinates from master '$MASTER_HOST'" >&2
    exit 1
fi
# RESULT is "<log file> <position>".
read -r LOG_FILE_NAME LOG_FILE_POS <<<"$RESULT"

# Point this replica at the master's current binlog coordinates.
SYNC_SQL="change master to master_host='$MASTER_HOST',master_user='$SLAVE_SYNC_USER',master_password='$SLAVE_SYNC_PASSWORD',master_log_file='$LOG_FILE_NAME',master_log_pos=$LOG_FILE_POS,get_master_public_key=1;"
# Start replication.
START_SYNC_SQL="start slave;"
# Show the replication status. \G already terminates the statement; a trailing
# ';' would be an empty statement that the mysql client reports as an error.
STATUS_SQL="show slave status\G"
mysql -u"$ADMIN_USER" -p"$ADMIN_PASSWORD" -e "$SYNC_SQL $START_SYNC_SQL $STATUS_SQL"
--------------------------------------------------------------------------------
/script/setup.sh:
--------------------------------------------------------------------------------
#!/bin/sh
# Copyright 2021 ecodeclub
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | SOURCE_COMMIT=.github/pre-commit 16 | TARGET_COMMIT=.git/hooks/pre-commit 17 | SOURCE_PUSH=.github/pre-push 18 | TARGET_PUSH=.git/hooks/pre-push 19 | 20 | # copy pre-commit file if not exist. 21 | if [ ! -f $TARGET_COMMIT ]; then 22 | echo "设置 git pre-commit hooks..." 23 | cp $SOURCE_COMMIT $TARGET_COMMIT 24 | fi 25 | 26 | # copy pre-push file if not exist. 27 | if [ ! -f $TARGET_PUSH ]; then 28 | echo "设置 git pre-push hooks..." 29 | cp $SOURCE_PUSH $TARGET_PUSH 30 | fi 31 | 32 | # add permission to TARGET_PUSH and TARGET_COMMIT file. 33 | test -x $TARGET_PUSH || chmod +x $TARGET_PUSH 34 | test -x $TARGET_COMMIT || chmod +x $TARGET_COMMIT 35 | 36 | echo "安装 golangci-lint..." 37 | go install github.com/golangci/golangci-lint/cmd/golangci-lint@v1.52.2 38 | 39 | echo "安装 goimports..." 40 | go install golang.org/x/tools/cmd/goimports@latest -------------------------------------------------------------------------------- /select_builder.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package eorm 16 | 17 | type selectorBuilderAttribute struct { 18 | columns []Selectable 19 | where []Predicate 20 | having []Predicate 21 | groupBy []string 22 | orderBy []OrderBy 23 | 24 | distinct bool 25 | offset int 26 | limit int 27 | } 28 | 29 | type selectorBuilder struct { 30 | builder 31 | selectorBuilderAttribute 32 | } 33 | 34 | type shardingSelectorBuilder struct { 35 | shardingBuilder 36 | selectorBuilderAttribute 37 | } 38 | -------------------------------------------------------------------------------- /session.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package eorm 16 | 17 | import ( 18 | "context" 19 | "database/sql" 20 | 21 | "github.com/ecodeclub/ekit/list" 22 | "github.com/ecodeclub/eorm/internal/datasource" 23 | "github.com/ecodeclub/eorm/internal/rows" 24 | "golang.org/x/sync/errgroup" 25 | ) 26 | 27 | var _ Session = (*baseSession)(nil) 28 | 29 | type baseSession struct { 30 | core 31 | executor datasource.Executor 32 | } 33 | 34 | func (sess *baseSession) queryContext(ctx context.Context, q Query) (rows.Rows, error) { 35 | return sess.executor.Query(ctx, q) 36 | } 37 | 38 | func (sess *baseSession) queryMulti(ctx context.Context, qs []Query) (list.List[rows.Rows], error) { 39 | res := &list.ConcurrentList[rows.Rows]{ 40 | List: list.NewArrayList[rows.Rows](len(qs)), 41 | } 42 | var eg errgroup.Group 43 | for _, query := range qs { 44 | q := query 45 | eg.Go(func() error { 46 | rs, err := sess.queryContext(ctx, q) 47 | if err == nil { 48 | return res.Append(rs) 49 | } 50 | return err 51 | }) 52 | } 53 | return res, eg.Wait() 54 | } 55 | 56 | func (sess *baseSession) execContext(ctx context.Context, q Query) (sql.Result, error) { 57 | return sess.executor.Exec(ctx, q) 58 | } 59 | 60 | func (sess *baseSession) getCore() core { 61 | return sess.core 62 | } 63 | -------------------------------------------------------------------------------- /transaction.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package eorm 16 | 17 | import ( 18 | "context" 19 | "database/sql" 20 | 21 | "github.com/ecodeclub/ekit/list" 22 | "github.com/ecodeclub/ekit/mapx" 23 | "github.com/ecodeclub/ekit/sqlx" 24 | "github.com/ecodeclub/eorm/internal/rows" 25 | "github.com/valyala/bytebufferpool" 26 | "golang.org/x/sync/errgroup" 27 | 28 | "github.com/ecodeclub/eorm/internal/datasource" 29 | ) 30 | 31 | type Tx struct { 32 | baseSession 33 | tx datasource.Tx 34 | } 35 | 36 | func (t *Tx) queryMulti(ctx context.Context, qs []Query) (list.List[rows.Rows], error) { 37 | // 事务在查询的时候,需要将同一个 DB 上的语句合并在一起 38 | // 参考 https://github.com/ecodeclub/eorm/discussions/213 39 | mp := mapx.NewMultiBuiltinMap[string, Query](len(qs)) 40 | for _, q := range qs { 41 | if err := mp.Put(q.DB+"_"+q.Datasource, q); err != nil { 42 | return nil, err 43 | } 44 | } 45 | keys := mp.Keys() 46 | rowsList := &list.ConcurrentList[rows.Rows]{ 47 | List: list.NewArrayList[rows.Rows](len(keys)), 48 | } 49 | var eg errgroup.Group 50 | for _, key := range keys { 51 | dbQs, _ := mp.Get(key) 52 | eg.Go(func() error { 53 | return t.execDBQueries(ctx, dbQs, rowsList) 54 | }) 55 | } 56 | return rowsList, eg.Wait() 57 | } 58 | 59 | // execDBQueries 执行某个 DB 上的全部查询。 60 | // 执行结果会被加入进去 rowsList 里面。虽然这种修改传入参数的做法不是很好,但是作为一个内部方法还是可以接受的。 61 | func (t *Tx) execDBQueries(ctx context.Context, dbQs []Query, rowsList *list.ConcurrentList[rows.Rows]) error { 62 | qsCnt := len(dbQs) 63 | // 考虑到大部分都只有一个查询,我们做一个快路径的优化。 64 | if qsCnt == 1 { 65 | rs, err := t.tx.Query(ctx, dbQs[0]) 66 | if err != nil { 67 | return err 68 | } 69 | return rowsList.Append(rs) 70 | } 71 | // 慢路径,也就是必须要把同一个库的查询合并在一起 72 | q := t.mergeDBQueries(dbQs) 73 | rs, err := t.tx.Query(ctx, q) 74 | if err != nil { 75 | return err 76 | } 77 | // 查询之后,事务必须再次按照结果集分割开。 78 | // 这样是为了让结果集的数量和查询数量保持一致。 79 | return t.splitTxResultSet(rowsList, rs) 80 | } 81 | 82 | 
func (t *Tx) splitTxResultSet(list list.List[rows.Rows], rs *sql.Rows) error { 83 | cs, err := rs.Columns() 84 | if err != nil { 85 | return err 86 | } 87 | ct, err := rs.ColumnTypes() 88 | if err != nil { 89 | return err 90 | } 91 | scanner, err := sqlx.NewSQLRowsScanner(rs) 92 | if err != nil { 93 | return err 94 | } 95 | // 虽然这里我们可以尝试不读取最后一个 ResultSet 96 | // 但是这个优化目前来说不准备做, 97 | // 防止用户出现因为类型转换遇到一些潜在的问题 98 | // 数据库类型到 GO 类型再到用户希望的类型,是一个漫长的过程。 99 | hasNext := true 100 | for hasNext { 101 | var data [][]any 102 | data, err = scanner.ScanAll() 103 | if err != nil { 104 | return err 105 | } 106 | err = list.Append(rows.NewDataRows(data, cs, ct)) 107 | if err != nil { 108 | return err 109 | } 110 | hasNext = scanner.NextResultSet() 111 | } 112 | return nil 113 | } 114 | 115 | func (t *Tx) mergeDBQueries(dbQs []Query) Query { 116 | buffer := bytebufferpool.Get() 117 | defer bytebufferpool.Put(buffer) 118 | first := dbQs[0] 119 | // 预估有多少查询参数,一个查询的参数个数 * 查询个数 120 | args := make([]any, 0, len(first.Args)*len(dbQs)) 121 | for _, dbQ := range dbQs { 122 | _, _ = buffer.WriteString(dbQ.SQL) 123 | args = append(args, dbQ.Args...) 124 | } 125 | return Query{ 126 | SQL: buffer.String(), 127 | Args: args, 128 | DB: first.DB, 129 | Datasource: first.Datasource, 130 | } 131 | } 132 | 133 | func (t *Tx) Commit() error { 134 | return t.tx.Commit() 135 | } 136 | 137 | func (t *Tx) Rollback() error { 138 | return t.tx.Rollback() 139 | } 140 | -------------------------------------------------------------------------------- /types.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package eorm 16 | 17 | import ( 18 | "context" 19 | "database/sql" 20 | 21 | "github.com/ecodeclub/ekit/list" 22 | "github.com/ecodeclub/eorm/internal/rows" 23 | ) 24 | 25 | // Executor sql 语句执行器 26 | type Executor interface { 27 | Exec(ctx context.Context) Result 28 | } 29 | 30 | // QueryBuilder 普通 sql 构造抽象 31 | type QueryBuilder interface { 32 | Build() (Query, error) 33 | } 34 | 35 | // Session 代表一个抽象的概念,即会话 36 | type Session interface { 37 | getCore() core 38 | queryMulti(ctx context.Context, qs []Query) (list.List[rows.Rows], error) 39 | queryContext(ctx context.Context, query Query) (rows.Rows, error) 40 | execContext(ctx context.Context, query Query) (sql.Result, error) 41 | } 42 | -------------------------------------------------------------------------------- /update_builder.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 ecodeclub 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package eorm 16 | 17 | import "github.com/ecodeclub/eorm/internal/valuer" 18 | 19 | type updaterBuilderAttribute struct { 20 | val valuer.Value 21 | where []Predicate 22 | assigns []Assignable 23 | ignoreNilVal bool 24 | ignoreZeroVal bool 25 | } 26 | 27 | type updaterBuilder struct { 28 | builder 29 | updaterBuilderAttribute 30 | } 31 | 32 | type shardingUpdaterBuilder struct { 33 | shardingBuilder 34 | updaterBuilderAttribute 35 | } 36 | --------------------------------------------------------------------------------