├── .github ├── FUNDING.yml ├── dependabot.yml ├── labels.json ├── release-drafter.yml └── workflows │ ├── create-release.yml │ ├── golangci-lint.yml │ ├── invalid_question.yml │ ├── labeler.yml │ ├── missing_playground.yml │ ├── stale.yml │ └── tests.yml ├── .gitignore ├── .golangci.yml ├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── association.go ├── callbacks.go ├── callbacks ├── associations.go ├── callbacks.go ├── callmethod.go ├── create.go ├── create_test.go ├── delete.go ├── helper.go ├── helper_test.go ├── interfaces.go ├── preload.go ├── query.go ├── raw.go ├── row.go ├── transaction.go └── update.go ├── chainable_api.go ├── clause ├── benchmarks_test.go ├── clause.go ├── clause_test.go ├── delete.go ├── delete_test.go ├── expression.go ├── expression_test.go ├── from.go ├── from_test.go ├── group_by.go ├── group_by_test.go ├── insert.go ├── insert_test.go ├── joins.go ├── joins_test.go ├── limit.go ├── limit_test.go ├── locking.go ├── locking_test.go ├── on_conflict.go ├── order_by.go ├── order_by_test.go ├── returning.go ├── returning_test.go ├── select.go ├── select_test.go ├── set.go ├── set_test.go ├── update.go ├── update_test.go ├── values.go ├── values_test.go ├── where.go ├── where_test.go └── with.go ├── errors.go ├── finisher_api.go ├── generics.go ├── go.mod ├── go.sum ├── gorm.go ├── interfaces.go ├── internal ├── lru │ └── lru.go └── stmt_store │ └── stmt_store.go ├── logger ├── logger.go ├── sql.go └── sql_test.go ├── migrator.go ├── migrator ├── column_type.go ├── index.go ├── migrator.go └── table_type.go ├── model.go ├── prepare_stmt.go ├── scan.go ├── schema ├── callbacks_test.go ├── constraint.go ├── constraint_test.go ├── field.go ├── field_test.go ├── index.go ├── index_test.go ├── interfaces.go ├── model_test.go ├── naming.go ├── naming_test.go ├── pool.go ├── relationship.go ├── relationship_test.go ├── schema.go ├── schema_helper_test.go ├── schema_test.go ├── serializer.go ├── utils.go └── utils_test.go ├── soft_delete.go 
├── statement.go ├── statement_test.go ├── tests ├── .gitignore ├── README.md ├── associations_belongs_to_test.go ├── associations_has_many_test.go ├── associations_has_one_test.go ├── associations_many2many_test.go ├── associations_test.go ├── benchmark_test.go ├── callbacks_test.go ├── compose.yml ├── connection_test.go ├── connpool_test.go ├── count_test.go ├── create_test.go ├── customize_field_test.go ├── default_value_test.go ├── delete_test.go ├── distinct_test.go ├── embedded_struct_test.go ├── error_translator_test.go ├── generics_test.go ├── go.mod ├── gorm_test.go ├── group_by_test.go ├── helper_test.go ├── hooks_test.go ├── joins_table_test.go ├── joins_test.go ├── lru_test.go ├── main_test.go ├── migrate_test.go ├── multi_primary_keys_test.go ├── named_argument_test.go ├── named_polymorphic_test.go ├── non_std_test.go ├── postgres_test.go ├── preload_suits_test.go ├── preload_test.go ├── prepared_stmt_test.go ├── query_test.go ├── scan_test.go ├── scanner_valuer_test.go ├── scopes_test.go ├── serializer_test.go ├── soft_delete_test.go ├── sql_builder_test.go ├── table_test.go ├── tests_all.sh ├── tests_test.go ├── tracer_test.go ├── transaction_test.go ├── update_belongs_to_test.go ├── update_has_many_test.go ├── update_has_one_test.go ├── update_many2many_test.go ├── update_test.go └── upsert_test.go └── utils ├── tests ├── dummy_dialecter.go ├── models.go └── utils.go ├── utils.go ├── utils_test.go ├── utils_unix_test.go └── utils_windows_test.go /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: [jinzhu] 4 | patreon: jinzhu 5 | open_collective: gorm 6 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | --- 2 | version: 2 3 | updates: 4 | - package-ecosystem: gomod 5 | directory: / 6 | schedule: 
7 | interval: weekly 8 | - package-ecosystem: github-actions 9 | directory: / 10 | schedule: 11 | interval: weekly 12 | - package-ecosystem: gomod 13 | directory: /tests 14 | schedule: 15 | interval: weekly 16 | -------------------------------------------------------------------------------- /.github/labels.json: -------------------------------------------------------------------------------- 1 | { 2 | "labels": { 3 | "critical": { 4 | "name": "type:critical", 5 | "colour": "#E84137", 6 | "description": "critical questions" 7 | }, 8 | "question": { 9 | "name": "type:question", 10 | "colour": "#EDEDED", 11 | "description": "general questions" 12 | }, 13 | "feature": { 14 | "name": "type:feature_request", 15 | "colour": "#43952A", 16 | "description": "feature request" 17 | }, 18 | "invalid_question": { 19 | "name": "type:invalid question", 20 | "colour": "#CF2E1F", 21 | "description": "invalid question (not related to GORM or described in document or not enough information provided)" 22 | }, 23 | "with_playground": { 24 | "name": "type:with reproduction steps", 25 | "colour": "#00ff00", 26 | "description": "with reproduction steps" 27 | }, 28 | "without_playground": { 29 | "name": "type:missing reproduction steps", 30 | "colour": "#CF2E1F", 31 | "description": "missing reproduction steps" 32 | }, 33 | "has_pr": { 34 | "name": "type:has pull request", 35 | "colour": "#43952A", 36 | "description": "has pull request" 37 | }, 38 | "not_tested": { 39 | "name": "type:not tested", 40 | "colour": "#CF2E1F", 41 | "description": "not tested" 42 | }, 43 | "tested": { 44 | "name": "type:tested", 45 | "colour": "#00ff00", 46 | "description": "tested" 47 | }, 48 | "breaking_change": { 49 | "name": "type:breaking change", 50 | "colour": "#CF2E1F", 51 | "description": "breaking change" 52 | } 53 | }, 54 | "issue": { 55 | "with_playground": { 56 | "requires": 1, 57 | "conditions": [ 58 | { 59 | "type": "descriptionMatches", 60 | "pattern": 
"/github.com\/go-gorm\/playground\/pull\/\\d\\d+/s" 61 | } 62 | ] 63 | }, 64 | "critical": { 65 | "requires": 1, 66 | "conditions": [ 67 | { 68 | "type": "descriptionMatches", 69 | "pattern": "/(critical|urgent)/i" 70 | }, 71 | { 72 | "type": "titleMatches", 73 | "pattern": "/(critical|urgent)/i" 74 | } 75 | ] 76 | }, 77 | "question": { 78 | "requires": 1, 79 | "conditions": [ 80 | { 81 | "type": "titleMatches", 82 | "pattern": "/question/i" 83 | }, 84 | { 85 | "type": "descriptionMatches", 86 | "pattern": "/question/i" 87 | } 88 | ] 89 | }, 90 | "feature": { 91 | "requires": 1, 92 | "conditions": [ 93 | { 94 | "type": "titleMatches", 95 | "pattern": "/feature/i" 96 | }, 97 | { 98 | "type": "descriptionMatches", 99 | "pattern": "/Describe the feature/i" 100 | } 101 | ] 102 | }, 103 | "without_playground": { 104 | "requires": 6, 105 | "conditions": [ 106 | { 107 | "type": "descriptionMatches", 108 | "pattern": "/^((?!github.com\/go-gorm\/playground\/pull\/\\d\\d+).)*$/s" 109 | }, 110 | { 111 | "type": "titleMatches", 112 | "pattern": "/^((?!question).)*$/s" 113 | }, 114 | { 115 | "type": "descriptionMatches", 116 | "pattern": "/^((?!question).)*$/is" 117 | }, 118 | { 119 | "type": "descriptionMatches", 120 | "pattern": "/^((?!Describe the feature).)*$/is" 121 | }, 122 | { 123 | "type": "titleMatches", 124 | "pattern": "/^((?!critical|urgent).)*$/s" 125 | }, 126 | { 127 | "type": "descriptionMatches", 128 | "pattern": "/^((?!critical|urgent).)*$/s" 129 | } 130 | ] 131 | } 132 | }, 133 | "pr": { 134 | "critical": { 135 | "requires": 1, 136 | "conditions": [ 137 | { 138 | "type": "descriptionMatches", 139 | "pattern": "/(critical|urgent)/i" 140 | }, 141 | { 142 | "type": "titleMatches", 143 | "pattern": "/(critical|urgent)/i" 144 | } 145 | ] 146 | }, 147 | "not_tested": { 148 | "requires": 1, 149 | "conditions": [ 150 | { 151 | "type": "descriptionMatches", 152 | "pattern": "/\\[\\] Tested/" 153 | } 154 | ] 155 | }, 156 | "breaking_change": { 157 | "requires": 1, 158 | 
"conditions": [ 159 | { 160 | "type": "descriptionMatches", 161 | "pattern": "/\\[\\] Non breaking API changes/" 162 | } 163 | ] 164 | } 165 | } 166 | } 167 | -------------------------------------------------------------------------------- /.github/release-drafter.yml: -------------------------------------------------------------------------------- 1 | name-template: 'v Release $NEXT_PATCH_VERSION 🌈' 2 | tag-template: 'v$NEXT_PATCH_VERSION' 3 | categories: 4 | - title: '🚀 Features' 5 | labels: 6 | - 'feature' 7 | - 'enhancement' 8 | - title: '🐛 Bug Fixes' 9 | labels: 10 | - 'fix' 11 | - 'bugfix' 12 | - 'bug' 13 | - title: '🧰 Maintenance' 14 | label: 'chore' 15 | change-template: '- $TITLE @$AUTHOR (#$NUMBER)' 16 | change-title-escapes: '\<*_&' 17 | template: | 18 | ## Changes 19 | 20 | $CHANGES -------------------------------------------------------------------------------- /.github/workflows/create-release.yml: -------------------------------------------------------------------------------- 1 | name: Create Release 2 | 3 | on: 4 | push: 5 | tags: 6 | - 'v*.*.*' 7 | 8 | permissions: 9 | contents: write 10 | pull-requests: read 11 | 12 | jobs: 13 | create_release: 14 | name: Create Release 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - name: Checkout code 19 | uses: actions/checkout@v4 20 | 21 | - name: Generate Release Notes and Publish 22 | id: generate_release_notes 23 | uses: release-drafter/release-drafter@v6 24 | with: 25 | config-name: 'release-drafter.yml' 26 | name: "Release ${{ github.ref_name }}" 27 | tag: ${{ github.ref_name }} 28 | publish: true 29 | prerelease: false 30 | env: 31 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 32 | -------------------------------------------------------------------------------- /.github/workflows/golangci-lint.yml: -------------------------------------------------------------------------------- 1 | name: golangci-lint 2 | on: 3 | push: 4 | branches: 5 | - main 6 | - master 7 | pull_request: 8 | 9 | permissions: 10 | 
contents: read 11 | pull-requests: read 12 | 13 | jobs: 14 | golangci: 15 | name: lint 16 | runs-on: ubuntu-latest 17 | steps: 18 | - uses: actions/checkout@v4 19 | - uses: actions/setup-go@v5 20 | with: 21 | go-version: stable 22 | - name: golangci-lint 23 | uses: golangci/golangci-lint-action@v7 24 | with: 25 | version: v2.0 26 | only-new-issues: true 27 | -------------------------------------------------------------------------------- /.github/workflows/invalid_question.yml: -------------------------------------------------------------------------------- 1 | name: "Close invalid questions issues" 2 | on: 3 | schedule: 4 | - cron: "*/10 * * * *" 5 | 6 | permissions: 7 | contents: read 8 | 9 | jobs: 10 | stale: 11 | permissions: 12 | issues: write # for actions/stale to close stale issues 13 | pull-requests: write # for actions/stale to close stale PRs 14 | runs-on: ubuntu-latest 15 | env: 16 | ACTIONS_STEP_DEBUG: true 17 | steps: 18 | - name: Close Stale Issues 19 | uses: actions/stale@v8 20 | with: 21 | repo-token: ${{ secrets.GITHUB_TOKEN }} 22 | stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 30 days if no further activity occurs. 
most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" 23 | stale-issue-label: "status:stale" 24 | days-before-stale: 0 25 | days-before-close: 30 26 | remove-stale-when-updated: true 27 | only-labels: "type:invalid question" 28 | 29 | -------------------------------------------------------------------------------- /.github/workflows/labeler.yml: -------------------------------------------------------------------------------- 1 | name: "Issue Labeler" 2 | on: 3 | issues: 4 | types: [opened, edited, reopened] 5 | pull_request: 6 | types: [opened, edited, reopened] 7 | 8 | jobs: 9 | triage: 10 | runs-on: ubuntu-latest 11 | name: Label issues and pull requests 12 | steps: 13 | - name: check out 14 | uses: actions/checkout@v4 15 | 16 | - name: labeler 17 | uses: jinzhu/super-labeler-action@develop 18 | with: 19 | GITHUB_TOKEN: "${{ secrets.GITHUB_TOKEN }}" 20 | -------------------------------------------------------------------------------- /.github/workflows/missing_playground.yml: -------------------------------------------------------------------------------- 1 | name: "Close Missing Playground issues" 2 | on: 3 | schedule: 4 | - cron: "*/10 * * * *" 5 | 6 | permissions: 7 | contents: read 8 | 9 | jobs: 10 | stale: 11 | permissions: 12 | issues: write # for actions/stale to close stale issues 13 | pull-requests: write # for actions/stale to close stale PRs 14 | runs-on: ubuntu-latest 15 | env: 16 | ACTIONS_STEP_DEBUG: true 17 | steps: 18 | - name: Close Stale Issues 19 | uses: actions/stale@v8 20 | with: 21 | repo-token: ${{ secrets.GITHUB_TOKEN }} 22 | stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout 
[https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 30 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" 23 | stale-issue-label: "status:stale" 24 | days-before-stale: 0 25 | days-before-close: 30 26 | remove-stale-when-updated: true 27 | only-labels: "type:missing reproduction steps" 28 | -------------------------------------------------------------------------------- /.github/workflows/stale.yml: -------------------------------------------------------------------------------- 1 | name: "Stale" 2 | on: 3 | schedule: 4 | - cron: "0 2 * * *" 5 | 6 | permissions: 7 | contents: read 8 | 9 | jobs: 10 | stale: 11 | permissions: 12 | issues: write # for actions/stale to close stale issues 13 | pull-requests: write # for actions/stale to close stale PRs 14 | runs-on: ubuntu-latest 15 | env: 16 | ACTIONS_STEP_DEBUG: true 17 | steps: 18 | - name: Close Stale Issues 19 | uses: actions/stale@v8 20 | with: 21 | repo-token: ${{ secrets.GITHUB_TOKEN }} 22 | stale-issue-message: "This issue has been automatically marked as stale because it has been open 360 days with no activity. 
Remove stale label or comment or this will be closed in 180 days" 23 | days-before-stale: 360 24 | days-before-close: 180 25 | stale-issue-label: "status:stale" 26 | exempt-issue-labels: 'type:feature,type:with reproduction steps,type:has pull request' 27 | stale-pr-label: 'status:stale' 28 | exempt-pr-labels: 'type:feature,type:with reproduction steps,type:has pull request' 29 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | TODO* 2 | documents 3 | coverage.txt 4 | _book 5 | .idea 6 | vendor 7 | .vscode 8 | -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | version: "2" 2 | 3 | linters: 4 | default: standard 5 | enable: 6 | - cyclop 7 | - gocritic 8 | - gosec 9 | - ineffassign 10 | - misspell 11 | - prealloc 12 | - unconvert 13 | - unparam 14 | - whitespace 15 | 16 | formatters: 17 | enable: 18 | - gofumpt 19 | - goimports 20 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, religion, or sexual identity 10 | and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community.
14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces and also applies when 54 | an individual is officially representing the community in public spaces.
55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | . 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series 86 | of actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period. This 91 | includes avoiding interactions in community spaces and external channels 92 | like social media. Violating these terms may lead to a temporary or 93 | permanent ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 
99 | 100 | **Consequence**: A temporary ban from any interaction or public 101 | communication with the community for a specified period. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within 113 | the community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.0, available at 119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. 120 | 121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct 122 | enforcement ladder](https://github.com/mozilla/diversity). 123 | 124 | [homepage]: https://www.contributor-covenant.org 125 | 126 | For answers to common questions about this code of conduct, see the FAQ at 127 | https://www.contributor-covenant.org/faq. Translations are available at 128 | https://www.contributor-covenant.org/translations. 
129 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2013-present Jinzhu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GORM 2 | 3 | The fantastic ORM library for Golang, aims to be developer friendly. 
4 | 5 | [![go report card](https://goreportcard.com/badge/github.com/go-gorm/gorm "go report card")](https://goreportcard.com/report/github.com/go-gorm/gorm) 6 | [![test status](https://github.com/go-gorm/gorm/workflows/tests/badge.svg?branch=master "test status")](https://github.com/go-gorm/gorm/actions) 7 | [![MIT license](https://img.shields.io/badge/license-MIT-brightgreen.svg)](https://opensource.org/licenses/MIT) 8 | [![Go.Dev reference](https://img.shields.io/badge/go.dev-reference-blue?logo=go&logoColor=white)](https://pkg.go.dev/gorm.io/gorm?tab=doc) 9 | 10 | ## Overview 11 | 12 | * Full-Featured ORM 13 | * Associations (Has One, Has Many, Belongs To, Many To Many, Polymorphism, Single-table inheritance) 14 | * Hooks (Before/After Create/Save/Update/Delete/Find) 15 | * Eager loading with `Preload`, `Joins` 16 | * Transactions, Nested Transactions, Save Point, RollbackTo to Saved Point 17 | * Context, Prepared Statement Mode, DryRun Mode 18 | * Batch Insert, FindInBatches, Find To Map 19 | * SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg, Search/Update/Create with SQL Expr 20 | * Composite Primary Key 21 | * Auto Migrations 22 | * Logger 23 | * Extendable, flexible plugin API: Database Resolver (Multiple Databases, Read/Write Splitting) / Prometheus… 24 | * Every feature comes with tests 25 | * Developer Friendly 26 | 27 | ## Getting Started 28 | 29 | * GORM Guides [https://gorm.io](https://gorm.io) 30 | * Gen Guides [https://gorm.io/gen/index.html](https://gorm.io/gen/index.html) 31 | 32 | ## Contributing 33 | 34 | [You can help to deliver a better GORM, check out things you can do](https://gorm.io/contribute.html) 35 | 36 | ## Contributors 37 | 38 | [Thank you](https://github.com/go-gorm/gorm/graphs/contributors) for contributing to the GORM framework! 
39 | 
40 | ## License
41 | 
42 | © Jinzhu, 2013~time.Now
43 | 
44 | Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/LICENSE)
45 | 
--------------------------------------------------------------------------------
/callbacks/callbacks.go:
--------------------------------------------------------------------------------
1 | package callbacks
2 | 
3 | import (
4 | 	"gorm.io/gorm"
5 | )
6 | 
7 | // Default clause lists, used for any Config clause field left empty.
8 | var (
9 | 	createClauses = []string{"INSERT", "VALUES", "ON CONFLICT"}
10 | 	queryClauses  = []string{"SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR"}
11 | 	updateClauses = []string{"UPDATE", "SET", "WHERE"}
12 | 	deleteClauses = []string{"DELETE", "FROM", "WHERE"}
13 | )
14 | 
15 | // Config customizes the callbacks installed by RegisterDefaultCallbacks.
16 | // CreateClauses, QueryClauses, UpdateClauses and DeleteClauses set the
17 | // clause list of the matching processor (QueryClauses is also used by the
18 | // row and raw processors); an empty list is replaced with the package
19 | // default. LastInsertIDReversed is not read in this file; the whole Config
20 | // is passed on to the Create, Update and Delete callback constructors.
21 | type Config struct {
22 | 	LastInsertIDReversed bool
23 | 	CreateClauses        []string
24 | 	QueryClauses         []string
25 | 	UpdateClauses        []string
26 | 	DeleteClauses        []string
27 | }
28 | 
29 | // RegisterDefaultCallbacks installs GORM's built-in create, query, delete,
30 | // update, row and raw callbacks on db and sets each processor's clause list
31 | // from config. Empty clause lists in config are filled in with the package
32 | // defaults, so config is modified in place. In the create, delete and
33 | // update chains the begin/commit-or-rollback transaction callbacks are
34 | // registered through Match(enableTransaction), which checks
35 | // !db.SkipDefaultTransaction.
36 | func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
37 | 	enableTransaction := func(db *gorm.DB) bool {
38 | 		return !db.SkipDefaultTransaction
39 | 	}
40 | 
41 | 	if len(config.CreateClauses) == 0 {
42 | 		config.CreateClauses = createClauses
43 | 	}
44 | 	if len(config.QueryClauses) == 0 {
45 | 		config.QueryClauses = queryClauses
46 | 	}
47 | 	if len(config.DeleteClauses) == 0 {
48 | 		config.DeleteClauses = deleteClauses
49 | 	}
50 | 	if len(config.UpdateClauses) == 0 {
51 | 		config.UpdateClauses = updateClauses
52 | 	}
53 | 
54 | 	createCallback := db.Callback().Create()
55 | 	createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
56 | 	createCallback.Register("gorm:before_create", BeforeCreate)
57 | 	createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(true))
58 | 	createCallback.Register("gorm:create", Create(config))
59 | 	createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true))
60 | 	createCallback.Register("gorm:after_create", AfterCreate)
61 | 	createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
62 | 	createCallback.Clauses = config.CreateClauses
63 | 
64 | 	queryCallback := db.Callback().Query()
65 | 	queryCallback.Register("gorm:query", Query)
66 | 	queryCallback.Register("gorm:preload", Preload)
67 | 	queryCallback.Register("gorm:after_query", AfterQuery)
68 | 	queryCallback.Clauses = config.QueryClauses
69 | 
70 | 	deleteCallback := db.Callback().Delete()
71 | 	deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
72 | 	deleteCallback.Register("gorm:before_delete", BeforeDelete)
73 | 	deleteCallback.Register("gorm:delete_before_associations", DeleteBeforeAssociations)
74 | 	deleteCallback.Register("gorm:delete", Delete(config))
75 | 	deleteCallback.Register("gorm:after_delete", AfterDelete)
76 | 	deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
77 | 	deleteCallback.Clauses = config.DeleteClauses
78 | 
79 | 	updateCallback := db.Callback().Update()
80 | 	updateCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
81 | 	updateCallback.Register("gorm:setup_reflect_value", SetupUpdateReflectValue)
82 | 	updateCallback.Register("gorm:before_update", BeforeUpdate)
83 | 	updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(false))
84 | 	updateCallback.Register("gorm:update", Update(config))
85 | 	updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false))
86 | 	updateCallback.Register("gorm:after_update", AfterUpdate)
87 | 	updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
88 | 	updateCallback.Clauses = config.UpdateClauses
89 | 
90 | 	rowCallback := db.Callback().Row()
91 | 	rowCallback.Register("gorm:row", RowQuery)
92 | 	rowCallback.Clauses = config.QueryClauses
93 | 
94 | 	rawCallback := db.Callback().Raw()
95 | 	rawCallback.Register("gorm:raw", RawExec)
96 | 	rawCallback.Clauses = config.QueryClauses
97 | }
98 | 
--------------------------------------------------------------------------------
/callbacks/callmethod.go:
--------------------------------------------------------------------------------
1 | package callbacks
2 | 
3 | import (
4 | 	"reflect"
5 | 
6 | 	"gorm.io/gorm"
7 | )
8 | 
9 | // callMethod calls fc with the statement's reflect value and tx, a new
10 | // session of db (NewDB: true). If fc returns false for the value as a
11 | // whole, callMethod retries per element: once for each element of a slice
12 | // or array (pointers dereferenced, db.Statement.CurDestIndex advanced), or
13 | // once with the address of a struct. An unaddressable element or struct
14 | // adds gorm.ErrInvalidValue to db; for slices it also stops the iteration.
15 | func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) {
16 | 	tx := db.Session(&gorm.Session{NewDB: true})
17 | 	if called := fc(db.Statement.ReflectValue.Interface(), tx); !called {
18 | 		switch db.Statement.ReflectValue.Kind() {
19 | 		case reflect.Slice, reflect.Array:
20 | 			db.Statement.CurDestIndex = 0
21 | 			for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
22 | 				if value := reflect.Indirect(db.Statement.ReflectValue.Index(i)); value.CanAddr() {
23 | 					fc(value.Addr().Interface(), tx)
24 | 				} else {
25 | 					db.AddError(gorm.ErrInvalidValue)
26 | 					return
27 | 				}
28 | 				db.Statement.CurDestIndex++
29 | 			}
30 | 		case reflect.Struct:
31 | 			if db.Statement.ReflectValue.CanAddr() {
32 | 				fc(db.Statement.ReflectValue.Addr().Interface(), tx)
33 | 			} else {
34 | 				db.AddError(gorm.ErrInvalidValue)
35 | 			}
36 | 		}
37 | 	}
38 | }
39 | 
--------------------------------------------------------------------------------
/callbacks/create_test.go:
--------------------------------------------------------------------------------
1 | package callbacks
2 | 
3 | import (
4 | 	"reflect"
5 | 	"sync"
6 | 	"testing"
7 | 	"time"
8 | 
9 | 	"gorm.io/gorm"
10 | 	"gorm.io/gorm/clause"
11 | 	"gorm.io/gorm/schema"
12 | )
13 | 
14 | var schemaCache = &sync.Map{}
15 | 
16 | func TestConvertToCreateValues_DestType_Slice(t *testing.T) {
17 | 	type user struct {
18 | 		ID    int `gorm:"primaryKey"`
19 | 		Name  string
20 | 		Email string `gorm:"default:(-)"`
21 | 		Age   int    `gorm:"default:(-)"`
22 | 	}
23 | 
24 | 	s, err := schema.Parse(&user{}, schemaCache, schema.NamingStrategy{})
25 | 	if err != nil {
26 | 		t.Errorf("parse schema error: %v, is not expected", err)
27 | 		return
28 | 	}
29 | 	dest := []*user{
30 | 		{
31 | 			ID:    1,
32 | 			Name:  "alice",
33 | 			Email: "email",
34 | 			Age:   18,
35 | 		},
36 | 		{
37 | 			ID:    2,
38 | 			Name:  "bob",
39 | 			Email: "email",
40 | 			Age:   19,
41 | 		},
42 | 	}
43 | 	stmt := &gorm.Statement{
44 | 		DB: &gorm.DB{
45 | 			Config: &gorm.Config{
46 | 				NowFunc: func() time.Time { return time.Time{} },
47 | 			},
48 | 			Statement: &gorm.Statement{
49 | 				Settings: sync.Map{},
50 | 				Schema:   s,
51 | 			},
52 | 		},
53 | 		ReflectValue: reflect.ValueOf(dest),
54 | 		Dest:         dest,
55 | 	}
56 | 
57 | 	stmt.Schema = s
58 | 
59 | 	values := ConvertToCreateValues(stmt)
60 | 	expected := clause.Values{
61 | 		// column has value + defaultValue column has value (which should have a stable order)
62 | 		Columns: []clause.Column{{Name: "name"}, {Name: "email"}, {Name: "age"}, {Name: "id"}},
63 | 		Values: [][]interface{}{
64 | 			{"alice", "email", 18, 1},
65 | 			{"bob", "email", 19, 2},
66 | 		},
67 | 	}
68 | 	if !reflect.DeepEqual(expected, values) {
69 | 		t.Errorf("expected: %v got %v", expected, values)
70 | 	}
71 | }
72 | 
--------------------------------------------------------------------------------
/callbacks/helper.go:
--------------------------------------------------------------------------------
1 | package callbacks
2 | 
3 | import (
4 | 	"reflect"
5 | 	"sort"
6 | 
7 | 	"gorm.io/gorm"
8 | 	"gorm.io/gorm/clause"
9 | )
10 | 
11 | // ConvertMapToValuesForCreate converts a single column->value map into a
12 | // one-row clause.Values. Keys are visited in sorted order so the column
13 | // order is deterministic; a key for which stmt.Schema.LookUpField finds a
14 | // field is rewritten to that field's DBName before the Select/Omit filter.
15 | func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]interface{}) (values clause.Values) {
16 | 	values.Columns = make([]clause.Column, 0, len(mapValue))
17 | 	selectColumns, restricted := stmt.SelectAndOmitColumns(true, false)
18 | 
19 | 	keys := make([]string, 0, len(mapValue))
20 | 	for k := range mapValue {
21 | 		keys = append(keys, k)
22 | 	}
23 | 	sort.Strings(keys)
24 | 
25 | 	for _, k := range keys {
26 | 		value := mapValue[k]
27 | 		if stmt.Schema != nil {
28 | 			if field := stmt.Schema.LookUpField(k); field != nil {
29 | 				k = field.DBName
30 | 			}
31 | 		}
32 | 
33 | 		// Keep columns marked as selected, plus columns not mentioned at all
34 | 		// when the statement is not restricted.
35 | 		if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
36 | 			values.Columns = append(values.Columns, clause.Column{Name: k})
37 | 			if len(values.Values) == 0 {
38 | 				values.Values = [][]interface{}{{}}
39 | 			}
40 | 
41 | 			values.Values[0] = append(values.Values[0], value)
42 | 		}
43 | 	}
44 | 	return
45 | }
46 | 
47 | // ConvertSliceOfMapToValuesForCreate converts a slice of column->value maps
48 | // into a multi-row clause.Values. The columns are the union of all maps'
49 | // keys (after DBName resolution and the Select/Omit filter), sorted so the
50 | // order is deterministic; a row that lacks a column gets nil for it.
51 | // An empty mapValues adds gorm.ErrEmptySlice to stmt and returns no values.
52 | func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[string]interface{}) (values clause.Values) {
53 | 	// Return early on empty input; there is no need to call
54 | 	// stmt.SelectAndOmitColumns or allocate anything.
55 | 	if len(mapValues) == 0 {
56 | 		stmt.AddError(gorm.ErrEmptySlice)
57 | 		return
58 | 	}
59 | 
60 | 	// The first row's key count is a cheap size hint for the column set.
61 | 	sizeHint := len(mapValues[0])
62 | 	var (
63 | 		columns                   = make([]string, 0, sizeHint)
64 | 		result                    = make(map[string][]interface{}, sizeHint)
65 | 		selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
66 | 	)
67 | 
68 | 	for idx, mapValue := range mapValues {
69 | 		for k, v := range mapValue {
70 | 			if stmt.Schema != nil {
71 | 				if field := stmt.Schema.LookUpField(k); field != nil {
72 | 					k = field.DBName
73 | 				}
74 | 			}
75 | 
76 | 			if _, ok := result[k]; !ok {
77 | 				// First sighting of this column: admit it only if it passes the
78 | 				// Select/Omit filter, otherwise skip it for this row.
79 | 				if selected, ok := selectColumns[k]; (ok && selected) || (!ok && !restricted) {
80 | 					result[k] = make([]interface{}, len(mapValues))
81 | 					columns = append(columns, k)
82 | 				} else {
83 | 					continue
84 | 				}
85 | 			}
86 | 
87 | 			result[k][idx] = v
88 | 		}
89 | 	}
90 | 
91 | 	// Transpose the per-column slices into rows, in sorted column order.
92 | 	sort.Strings(columns)
93 | 	values.Values = make([][]interface{}, len(mapValues))
94 | 	values.Columns = make([]clause.Column, len(columns))
95 | 	for idx, column := range columns {
96 | 		values.Columns[idx] = clause.Column{Name: column}
97 | 
98 | 		for i, v := range result[column] {
99 | 			if len(values.Values[i]) == 0 {
100 | 				values.Values[i] = make([]interface{}, len(columns))
101 | 			}
102 | 
103 | 			values.Values[i][idx] = v
104 | 		}
105 | 	}
106 | 	return
107 | }
108 | 
109 | // hasReturning reports whether tx has a RETURNING clause and supportReturning
110 | // is true. The ScanMode is 0 when RETURNING lists no columns or only "*",
111 | // and gorm.ScanUpdate when specific columns are listed.
112 | func hasReturning(tx *gorm.DB, supportReturning bool) (bool, gorm.ScanMode) {
113 | 	if supportReturning {
114 | 		if c, ok := tx.Statement.Clauses["RETURNING"]; ok {
115 | 			returning, _ := c.Expression.(clause.Returning)
116 | 			if len(returning.Columns) == 0 || (len(returning.Columns) == 1 && returning.Columns[0].Name == "*") {
117 | 				return true, 0
118 | 			}
119 | 			return true, gorm.ScanUpdate
120 | 		}
121 | 	}
122 | 	return false, 0
123 | }
124 | 
125 | // checkMissingWhereConditions adds gorm.ErrMissingWhereClause to db when the
126 | // statement has no WHERE conditions. It does nothing when db.AllowGlobalUpdate
127 | // is set or db already carries an error.
128 | func checkMissingWhereConditions(db *gorm.DB) {
129 | 	if db.AllowGlobalUpdate || db.Error != nil {
130 | 		return
131 | 	}
132 | 
133 | 	where, withCondition := db.Statement.Clauses["WHERE"]
134 | 	if withCondition {
135 | 		// When the "soft_delete_enabled" marker clause is present, one WHERE
136 | 		// expression is presumably the soft-delete filter itself, so more
137 | 		// than one is required to count as a real condition.
138 | 		if _, withSoftDelete := db.Statement.Clauses["soft_delete_enabled"]; withSoftDelete {
139 | 			whereClause, _ := where.Expression.(clause.Where)
140 | 			withCondition = len(whereClause.Exprs) > 1
141 | 		}
142 | 	}
143 | 	if !withCondition {
144 | 		db.AddError(gorm.ErrMissingWhereClause)
145 | 	}
146 | }
147 | 
148 | // visitMap records the addresses of values already visited while walking a
149 | // value graph, so circular references can be detected.
150 | type visitMap = map[reflect.Value]bool
151 | 
152 | // loadOrStoreVisitMap reports whether v was already visited (loaded) and
153 | // otherwise marks it as visited. A pointer is dereferenced once. Slices and
154 | // arrays are walked element by element and report true only if every element
155 | // was already visited (so an empty slice reports true). Structs and
156 | // interfaces are tracked by address; unaddressable values and other kinds
157 | // are never recorded and report false.
158 | func loadOrStoreVisitMap(visitMap *visitMap, v reflect.Value) (loaded bool) {
159 | 	if v.Kind() == reflect.Ptr {
160 | 		v = v.Elem()
161 | 	}
162 | 
163 | 	switch v.Kind() {
164 | 	case reflect.Slice, reflect.Array:
165 | 		loaded = true
166 | 		for i := 0; i < v.Len(); i++ {
167 | 			if !loadOrStoreVisitMap(visitMap, v.Index(i)) {
168 | 				loaded = false
169 | 			}
170 | 		}
171 | 	case reflect.Struct, reflect.Interface:
172 | 		if v.CanAddr() {
173 | 			p := v.Addr()
174 | 			if _, ok := (*visitMap)[p]; ok {
175 | 				return true
176 | 			}
177 | 			(*visitMap)[p] = true
178 | 		}
179 | 	}
180 | 
181 | 	return
182 | }
183 | 
--------------------------------------------------------------------------------
/callbacks/helper_test.go:
--------------------------------------------------------------------------------
1 | package callbacks
2 | 
3 | import (
4 | 	"reflect"
5 | 	"testing"
6 | 
7 | 	"gorm.io/gorm"
8 | 	"gorm.io/gorm/clause"
9 | )
10 | 
11 | func TestLoadOrStoreVisitMap(t *testing.T) {
12 | 	var vm visitMap
13 | 	var loaded bool
14 | 	type testM
struct { 15 | Name string 16 | } 17 | 18 | t1 := testM{Name: "t1"} 19 | t2 := testM{Name: "t2"} 20 | t3 := testM{Name: "t3"} 21 | 22 | vm = make(visitMap) 23 | if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); loaded { 24 | t.Fatalf("loaded should be false") 25 | } 26 | 27 | if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); !loaded { 28 | t.Fatalf("loaded should be true") 29 | } 30 | 31 | // t1 already exist but t2 not 32 | if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t1, &t2, &t3})); loaded { 33 | t.Fatalf("loaded should be false") 34 | } 35 | 36 | if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t2, &t3})); !loaded { 37 | t.Fatalf("loaded should be true") 38 | } 39 | } 40 | 41 | func TestConvertMapToValuesForCreate(t *testing.T) { 42 | testCase := []struct { 43 | name string 44 | input map[string]interface{} 45 | expect clause.Values 46 | }{ 47 | { 48 | name: "Test convert string value", 49 | input: map[string]interface{}{ 50 | "name": "my name", 51 | }, 52 | expect: clause.Values{ 53 | Columns: []clause.Column{{Name: "name"}}, 54 | Values: [][]interface{}{{"my name"}}, 55 | }, 56 | }, 57 | { 58 | name: "Test convert int value", 59 | input: map[string]interface{}{ 60 | "age": 18, 61 | }, 62 | expect: clause.Values{ 63 | Columns: []clause.Column{{Name: "age"}}, 64 | Values: [][]interface{}{{18}}, 65 | }, 66 | }, 67 | { 68 | name: "Test convert float value", 69 | input: map[string]interface{}{ 70 | "score": 99.5, 71 | }, 72 | expect: clause.Values{ 73 | Columns: []clause.Column{{Name: "score"}}, 74 | Values: [][]interface{}{{99.5}}, 75 | }, 76 | }, 77 | { 78 | name: "Test convert bool value", 79 | input: map[string]interface{}{ 80 | "active": true, 81 | }, 82 | expect: clause.Values{ 83 | Columns: []clause.Column{{Name: "active"}}, 84 | Values: [][]interface{}{{true}}, 85 | }, 86 | }, 87 | } 88 | 89 | for _, tc := range testCase { 90 | t.Run(tc.name, func(t *testing.T) { 91 | actual := 
ConvertMapToValuesForCreate(&gorm.Statement{}, tc.input) 92 | if !reflect.DeepEqual(actual, tc.expect) { 93 | t.Errorf("expect %v got %v", tc.expect, actual) 94 | } 95 | }) 96 | } 97 | } 98 | 99 | func TestConvertSliceOfMapToValuesForCreate(t *testing.T) { 100 | testCase := []struct { 101 | name string 102 | input []map[string]interface{} 103 | expect clause.Values 104 | }{ 105 | { 106 | name: "Test convert slice of string value", 107 | input: []map[string]interface{}{ 108 | {"name": "my name"}, 109 | }, 110 | expect: clause.Values{ 111 | Columns: []clause.Column{{Name: "name"}}, 112 | Values: [][]interface{}{{"my name"}}, 113 | }, 114 | }, 115 | { 116 | name: "Test convert slice of int value", 117 | input: []map[string]interface{}{ 118 | {"age": 18}, 119 | }, 120 | expect: clause.Values{ 121 | Columns: []clause.Column{{Name: "age"}}, 122 | Values: [][]interface{}{{18}}, 123 | }, 124 | }, 125 | { 126 | name: "Test convert slice of float value", 127 | input: []map[string]interface{}{ 128 | {"score": 99.5}, 129 | }, 130 | expect: clause.Values{ 131 | Columns: []clause.Column{{Name: "score"}}, 132 | Values: [][]interface{}{{99.5}}, 133 | }, 134 | }, 135 | { 136 | name: "Test convert slice of bool value", 137 | input: []map[string]interface{}{ 138 | {"active": true}, 139 | }, 140 | expect: clause.Values{ 141 | Columns: []clause.Column{{Name: "active"}}, 142 | Values: [][]interface{}{{true}}, 143 | }, 144 | }, 145 | } 146 | 147 | for _, tc := range testCase { 148 | t.Run(tc.name, func(t *testing.T) { 149 | actual := ConvertSliceOfMapToValuesForCreate(&gorm.Statement{}, tc.input) 150 | 151 | if !reflect.DeepEqual(actual, tc.expect) { 152 | t.Errorf("expected %v but got %v", tc.expect, actual) 153 | } 154 | }) 155 | } 156 | 157 | } 158 | -------------------------------------------------------------------------------- /callbacks/interfaces.go: -------------------------------------------------------------------------------- 1 | package callbacks 2 | 3 | import 
"gorm.io/gorm" 4 | 5 | type BeforeCreateInterface interface { 6 | BeforeCreate(*gorm.DB) error 7 | } 8 | 9 | type AfterCreateInterface interface { 10 | AfterCreate(*gorm.DB) error 11 | } 12 | 13 | type BeforeUpdateInterface interface { 14 | BeforeUpdate(*gorm.DB) error 15 | } 16 | 17 | type AfterUpdateInterface interface { 18 | AfterUpdate(*gorm.DB) error 19 | } 20 | 21 | type BeforeSaveInterface interface { 22 | BeforeSave(*gorm.DB) error 23 | } 24 | 25 | type AfterSaveInterface interface { 26 | AfterSave(*gorm.DB) error 27 | } 28 | 29 | type BeforeDeleteInterface interface { 30 | BeforeDelete(*gorm.DB) error 31 | } 32 | 33 | type AfterDeleteInterface interface { 34 | AfterDelete(*gorm.DB) error 35 | } 36 | 37 | type AfterFindInterface interface { 38 | AfterFind(*gorm.DB) error 39 | } 40 | -------------------------------------------------------------------------------- /callbacks/raw.go: -------------------------------------------------------------------------------- 1 | package callbacks 2 | 3 | import ( 4 | "gorm.io/gorm" 5 | ) 6 | 7 | func RawExec(db *gorm.DB) { 8 | if db.Error == nil && !db.DryRun { 9 | result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) 
10 | if err != nil { 11 | db.AddError(err) 12 | return 13 | } 14 | 15 | db.RowsAffected, _ = result.RowsAffected() 16 | 17 | if db.Statement.Result != nil { 18 | db.Statement.Result.Result = result 19 | db.Statement.Result.RowsAffected = db.RowsAffected 20 | } 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /callbacks/row.go: -------------------------------------------------------------------------------- 1 | package callbacks 2 | 3 | import ( 4 | "gorm.io/gorm" 5 | ) 6 | 7 | func RowQuery(db *gorm.DB) { 8 | if db.Error == nil { 9 | BuildQuerySQL(db) 10 | if db.DryRun || db.Error != nil { 11 | return 12 | } 13 | 14 | if isRows, ok := db.Get("rows"); ok && isRows.(bool) { 15 | db.Statement.Settings.Delete("rows") 16 | db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) 17 | } else { 18 | db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) 
19 | } 20 | 21 | db.RowsAffected = -1 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /callbacks/transaction.go: -------------------------------------------------------------------------------- 1 | package callbacks 2 | 3 | import ( 4 | "gorm.io/gorm" 5 | ) 6 | 7 | func BeginTransaction(db *gorm.DB) { 8 | if !db.Config.SkipDefaultTransaction && db.Error == nil { 9 | if tx := db.Begin(); tx.Error == nil { 10 | db.Statement.ConnPool = tx.Statement.ConnPool 11 | db.InstanceSet("gorm:started_transaction", true) 12 | } else if tx.Error == gorm.ErrInvalidTransaction { 13 | tx.Error = nil 14 | } else { 15 | db.Error = tx.Error 16 | } 17 | } 18 | } 19 | 20 | func CommitOrRollbackTransaction(db *gorm.DB) { 21 | if !db.Config.SkipDefaultTransaction { 22 | if _, ok := db.InstanceGet("gorm:started_transaction"); ok { 23 | if db.Error != nil { 24 | db.Rollback() 25 | } else { 26 | db.Commit() 27 | } 28 | 29 | db.Statement.ConnPool = db.ConnPool 30 | } 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /clause/benchmarks_test.go: -------------------------------------------------------------------------------- 1 | package clause_test 2 | 3 | import ( 4 | "sync" 5 | "testing" 6 | 7 | "gorm.io/gorm" 8 | "gorm.io/gorm/clause" 9 | "gorm.io/gorm/schema" 10 | "gorm.io/gorm/utils/tests" 11 | ) 12 | 13 | func BenchmarkSelect(b *testing.B) { 14 | user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) 15 | 16 | for i := 0; i < b.N; i++ { 17 | stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} 18 | clauses := []clause.Interface{clause.Select{}, clause.From{}, clause.Where{Exprs: []clause.Expression{clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})}}} 19 | 20 | for _, clause := range clauses { 21 | stmt.AddClause(clause) 22 | } 23 | 
24 | stmt.Build("SELECT", "FROM", "WHERE") 25 | _ = stmt.SQL.String() 26 | } 27 | } 28 | 29 | func BenchmarkComplexSelect(b *testing.B) { 30 | user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) 31 | 32 | limit10 := 10 33 | for i := 0; i < b.N; i++ { 34 | stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} 35 | clauses := []clause.Interface{ 36 | clause.Select{}, 37 | clause.From{}, 38 | clause.Where{Exprs: []clause.Expression{ 39 | clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, 40 | clause.Gt{Column: "age", Value: 18}, 41 | clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}), 42 | }}, 43 | clause.Where{Exprs: []clause.Expression{ 44 | clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"}), 45 | }}, 46 | clause.GroupBy{Columns: []clause.Column{{Name: "role"}}, Having: []clause.Expression{clause.Eq{"role", "admin"}}}, 47 | clause.Limit{Limit: &limit10, Offset: 20}, 48 | clause.OrderBy{Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}}, 49 | } 50 | 51 | for _, clause := range clauses { 52 | stmt.AddClause(clause) 53 | } 54 | 55 | stmt.Build("SELECT", "FROM", "WHERE", "GROUP BY", "LIMIT", "ORDER BY") 56 | _ = stmt.SQL.String() 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /clause/clause.go: -------------------------------------------------------------------------------- 1 | package clause 2 | 3 | // Interface clause interface 4 | type Interface interface { 5 | Name() string 6 | Build(Builder) 7 | MergeClause(*Clause) 8 | } 9 | 10 | // ClauseBuilder clause builder, allows to customize how to build clause 11 | type ClauseBuilder func(Clause, Builder) 12 | 13 | type Writer interface { 14 | WriteByte(byte) error 15 | WriteString(string) (int, error) 16 | } 17 | 18 | // Builder builder interface 19 | type Builder interface { 20 | Writer 21 | WriteQuoted(field interface{}) 22 | 
AddVar(Writer, ...interface{}) 23 | AddError(error) error 24 | } 25 | 26 | // Clause 27 | type Clause struct { 28 | Name string // WHERE 29 | BeforeExpression Expression 30 | AfterNameExpression Expression 31 | AfterExpression Expression 32 | Expression Expression 33 | Builder ClauseBuilder 34 | } 35 | 36 | // Build build clause 37 | func (c Clause) Build(builder Builder) { 38 | if c.Builder != nil { 39 | c.Builder(c, builder) 40 | } else if c.Expression != nil { 41 | if c.BeforeExpression != nil { 42 | c.BeforeExpression.Build(builder) 43 | builder.WriteByte(' ') 44 | } 45 | 46 | if c.Name != "" { 47 | builder.WriteString(c.Name) 48 | builder.WriteByte(' ') 49 | } 50 | 51 | if c.AfterNameExpression != nil { 52 | c.AfterNameExpression.Build(builder) 53 | builder.WriteByte(' ') 54 | } 55 | 56 | c.Expression.Build(builder) 57 | 58 | if c.AfterExpression != nil { 59 | builder.WriteByte(' ') 60 | c.AfterExpression.Build(builder) 61 | } 62 | } 63 | } 64 | 65 | const ( 66 | PrimaryKey string = "~~~py~~~" // primary key 67 | CurrentTable string = "~~~ct~~~" // current table 68 | Associations string = "~~~as~~~" // associations 69 | ) 70 | 71 | var ( 72 | currentTable = Table{Name: CurrentTable} 73 | PrimaryColumn = Column{Table: CurrentTable, Name: PrimaryKey} 74 | ) 75 | 76 | // Column quote with name 77 | type Column struct { 78 | Table string 79 | Name string 80 | Alias string 81 | Raw bool 82 | } 83 | 84 | // Table quote with name 85 | type Table struct { 86 | Name string 87 | Alias string 88 | Raw bool 89 | } 90 | -------------------------------------------------------------------------------- /clause/clause_test.go: -------------------------------------------------------------------------------- 1 | package clause_test 2 | 3 | import ( 4 | "reflect" 5 | "strings" 6 | "sync" 7 | "testing" 8 | 9 | "gorm.io/gorm" 10 | "gorm.io/gorm/clause" 11 | "gorm.io/gorm/schema" 12 | "gorm.io/gorm/utils/tests" 13 | ) 14 | 15 | var db, _ = gorm.Open(tests.DummyDialector{}, nil) 16 | 
17 | func checkBuildClauses(t *testing.T, clauses []clause.Interface, result string, vars []interface{}) { 18 | var ( 19 | buildNames []string 20 | buildNamesMap = map[string]bool{} 21 | user, _ = schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) 22 | stmt = gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} 23 | ) 24 | 25 | for _, c := range clauses { 26 | if _, ok := buildNamesMap[c.Name()]; !ok { 27 | buildNames = append(buildNames, c.Name()) 28 | buildNamesMap[c.Name()] = true 29 | } 30 | 31 | stmt.AddClause(c) 32 | } 33 | 34 | stmt.Build(buildNames...) 35 | 36 | if strings.TrimSpace(stmt.SQL.String()) != result { 37 | t.Errorf("SQL expects %v got %v", result, stmt.SQL.String()) 38 | } 39 | 40 | if !reflect.DeepEqual(stmt.Vars, vars) { 41 | t.Errorf("Vars expects %+v got %v", stmt.Vars, vars) 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /clause/delete.go: -------------------------------------------------------------------------------- 1 | package clause 2 | 3 | type Delete struct { 4 | Modifier string 5 | } 6 | 7 | func (d Delete) Name() string { 8 | return "DELETE" 9 | } 10 | 11 | func (d Delete) Build(builder Builder) { 12 | builder.WriteString("DELETE") 13 | 14 | if d.Modifier != "" { 15 | builder.WriteByte(' ') 16 | builder.WriteString(d.Modifier) 17 | } 18 | } 19 | 20 | func (d Delete) MergeClause(clause *Clause) { 21 | clause.Name = "" 22 | clause.Expression = d 23 | } 24 | -------------------------------------------------------------------------------- /clause/delete_test.go: -------------------------------------------------------------------------------- 1 | package clause_test 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "gorm.io/gorm/clause" 8 | ) 9 | 10 | func TestDelete(t *testing.T) { 11 | results := []struct { 12 | Clauses []clause.Interface 13 | Result string 14 | Vars []interface{} 15 | }{ 16 | { 17 | []clause.Interface{clause.Delete{}, 
clause.From{}}, 18 | "DELETE FROM `users`", nil, 19 | }, 20 | { 21 | []clause.Interface{clause.Delete{Modifier: "LOW_PRIORITY"}, clause.From{}}, 22 | "DELETE LOW_PRIORITY FROM `users`", nil, 23 | }, 24 | } 25 | 26 | for idx, result := range results { 27 | t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { 28 | checkBuildClauses(t, result.Clauses, result.Result, result.Vars) 29 | }) 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /clause/from.go: -------------------------------------------------------------------------------- 1 | package clause 2 | 3 | // From from clause 4 | type From struct { 5 | Tables []Table 6 | Joins []Join 7 | } 8 | 9 | // Name from clause name 10 | func (from From) Name() string { 11 | return "FROM" 12 | } 13 | 14 | // Build build from clause 15 | func (from From) Build(builder Builder) { 16 | if len(from.Tables) > 0 { 17 | for idx, table := range from.Tables { 18 | if idx > 0 { 19 | builder.WriteByte(',') 20 | } 21 | 22 | builder.WriteQuoted(table) 23 | } 24 | } else { 25 | builder.WriteQuoted(currentTable) 26 | } 27 | 28 | for _, join := range from.Joins { 29 | builder.WriteByte(' ') 30 | join.Build(builder) 31 | } 32 | } 33 | 34 | // MergeClause merge from clause 35 | func (from From) MergeClause(clause *Clause) { 36 | clause.Expression = from 37 | } 38 | -------------------------------------------------------------------------------- /clause/from_test.go: -------------------------------------------------------------------------------- 1 | package clause_test 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "gorm.io/gorm/clause" 8 | ) 9 | 10 | func TestFrom(t *testing.T) { 11 | results := []struct { 12 | Clauses []clause.Interface 13 | Result string 14 | Vars []interface{} 15 | }{ 16 | { 17 | []clause.Interface{clause.Select{}, clause.From{}}, 18 | "SELECT * FROM `users`", nil, 19 | }, 20 | { 21 | []clause.Interface{ 22 | clause.Select{}, clause.From{ 23 | Tables: 
[]clause.Table{{Name: "users"}}, 24 | Joins: []clause.Join{ 25 | { 26 | Type: clause.InnerJoin, 27 | Table: clause.Table{Name: "articles"}, 28 | ON: clause.Where{ 29 | []clause.Expression{clause.Eq{clause.Column{Table: "articles", Name: "id"}, clause.PrimaryColumn}}, 30 | }, 31 | }, 32 | }, 33 | }, 34 | }, 35 | "SELECT * FROM `users` INNER JOIN `articles` ON `articles`.`id` = `users`.`id`", nil, 36 | }, 37 | { 38 | []clause.Interface{ 39 | clause.Select{}, clause.From{ 40 | Tables: []clause.Table{{Name: "users"}}, 41 | Joins: []clause.Join{ 42 | { 43 | Type: clause.RightJoin, 44 | Table: clause.Table{Name: "profiles"}, 45 | ON: clause.Where{ 46 | []clause.Expression{clause.Eq{clause.Column{Table: "profiles", Name: "email"}, clause.Column{Table: clause.CurrentTable, Name: "email"}}}, 47 | }, 48 | }, 49 | }, 50 | }, clause.From{ 51 | Joins: []clause.Join{ 52 | { 53 | Type: clause.InnerJoin, 54 | Table: clause.Table{Name: "articles"}, 55 | ON: clause.Where{ 56 | []clause.Expression{clause.Eq{clause.Column{Table: "articles", Name: "id"}, clause.PrimaryColumn}}, 57 | }, 58 | }, { 59 | Type: clause.LeftJoin, 60 | Table: clause.Table{Name: "companies"}, 61 | Using: []string{"company_name"}, 62 | }, 63 | }, 64 | }, 65 | }, 66 | "SELECT * FROM `users` INNER JOIN `articles` ON `articles`.`id` = `users`.`id` LEFT JOIN `companies` USING (`company_name`)", nil, 67 | }, 68 | } 69 | 70 | for idx, result := range results { 71 | t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { 72 | checkBuildClauses(t, result.Clauses, result.Result, result.Vars) 73 | }) 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /clause/group_by.go: -------------------------------------------------------------------------------- 1 | package clause 2 | 3 | // GroupBy group by clause 4 | type GroupBy struct { 5 | Columns []Column 6 | Having []Expression 7 | } 8 | 9 | // Name from clause name 10 | func (groupBy GroupBy) Name() string { 11 | return "GROUP 
BY" 12 | } 13 | 14 | // Build build group by clause 15 | func (groupBy GroupBy) Build(builder Builder) { 16 | for idx, column := range groupBy.Columns { 17 | if idx > 0 { 18 | builder.WriteByte(',') 19 | } 20 | 21 | builder.WriteQuoted(column) 22 | } 23 | 24 | if len(groupBy.Having) > 0 { 25 | builder.WriteString(" HAVING ") 26 | Where{Exprs: groupBy.Having}.Build(builder) 27 | } 28 | } 29 | 30 | // MergeClause merge group by clause 31 | func (groupBy GroupBy) MergeClause(clause *Clause) { 32 | if v, ok := clause.Expression.(GroupBy); ok { 33 | copiedColumns := make([]Column, len(v.Columns)) 34 | copy(copiedColumns, v.Columns) 35 | groupBy.Columns = append(copiedColumns, groupBy.Columns...) 36 | 37 | copiedHaving := make([]Expression, len(v.Having)) 38 | copy(copiedHaving, v.Having) 39 | groupBy.Having = append(copiedHaving, groupBy.Having...) 40 | } 41 | clause.Expression = groupBy 42 | 43 | if len(groupBy.Columns) == 0 { 44 | clause.Name = "" 45 | } else { 46 | clause.Name = groupBy.Name() 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /clause/group_by_test.go: -------------------------------------------------------------------------------- 1 | package clause_test 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "gorm.io/gorm/clause" 8 | ) 9 | 10 | func TestGroupBy(t *testing.T) { 11 | results := []struct { 12 | Clauses []clause.Interface 13 | Result string 14 | Vars []interface{} 15 | }{ 16 | { 17 | []clause.Interface{clause.Select{}, clause.From{}, clause.GroupBy{ 18 | Columns: []clause.Column{{Name: "role"}}, 19 | Having: []clause.Expression{clause.Eq{"role", "admin"}}, 20 | }}, 21 | "SELECT * FROM `users` GROUP BY `role` HAVING `role` = ?", 22 | []interface{}{"admin"}, 23 | }, 24 | { 25 | []clause.Interface{clause.Select{}, clause.From{}, clause.GroupBy{ 26 | Columns: []clause.Column{{Name: "role"}}, 27 | Having: []clause.Expression{clause.Eq{"role", "admin"}}, 28 | }, clause.GroupBy{ 29 | Columns: 
[]clause.Column{{Name: "gender"}}, 30 | Having: []clause.Expression{clause.Neq{"gender", "U"}}, 31 | }}, 32 | "SELECT * FROM `users` GROUP BY `role`,`gender` HAVING `role` = ? AND `gender` <> ?", 33 | []interface{}{"admin", "U"}, 34 | }, 35 | } 36 | 37 | for idx, result := range results { 38 | t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { 39 | checkBuildClauses(t, result.Clauses, result.Result, result.Vars) 40 | }) 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /clause/insert.go: -------------------------------------------------------------------------------- 1 | package clause 2 | 3 | type Insert struct { 4 | Table Table 5 | Modifier string 6 | } 7 | 8 | // Name insert clause name 9 | func (insert Insert) Name() string { 10 | return "INSERT" 11 | } 12 | 13 | // Build build insert clause 14 | func (insert Insert) Build(builder Builder) { 15 | if insert.Modifier != "" { 16 | builder.WriteString(insert.Modifier) 17 | builder.WriteByte(' ') 18 | } 19 | 20 | builder.WriteString("INTO ") 21 | if insert.Table.Name == "" { 22 | builder.WriteQuoted(currentTable) 23 | } else { 24 | builder.WriteQuoted(insert.Table) 25 | } 26 | } 27 | 28 | // MergeClause merge insert clause 29 | func (insert Insert) MergeClause(clause *Clause) { 30 | if v, ok := clause.Expression.(Insert); ok { 31 | if insert.Modifier == "" { 32 | insert.Modifier = v.Modifier 33 | } 34 | if insert.Table.Name == "" { 35 | insert.Table = v.Table 36 | } 37 | } 38 | clause.Expression = insert 39 | } 40 | -------------------------------------------------------------------------------- /clause/insert_test.go: -------------------------------------------------------------------------------- 1 | package clause_test 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "gorm.io/gorm/clause" 8 | ) 9 | 10 | func TestInsert(t *testing.T) { 11 | results := []struct { 12 | Clauses []clause.Interface 13 | Result string 14 | Vars []interface{} 15 | }{ 16 | { 17 | 
[]clause.Interface{clause.Insert{}}, 18 | "INSERT INTO `users`", nil, 19 | }, 20 | { 21 | []clause.Interface{clause.Insert{Modifier: "LOW_PRIORITY"}}, 22 | "INSERT LOW_PRIORITY INTO `users`", nil, 23 | }, 24 | { 25 | []clause.Interface{clause.Insert{Table: clause.Table{Name: "products"}, Modifier: "LOW_PRIORITY"}}, 26 | "INSERT LOW_PRIORITY INTO `products`", nil, 27 | }, 28 | } 29 | 30 | for idx, result := range results { 31 | t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { 32 | checkBuildClauses(t, result.Clauses, result.Result, result.Vars) 33 | }) 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /clause/joins.go: -------------------------------------------------------------------------------- 1 | package clause 2 | 3 | import "gorm.io/gorm/utils" 4 | 5 | type JoinType string 6 | 7 | const ( 8 | CrossJoin JoinType = "CROSS" 9 | InnerJoin JoinType = "INNER" 10 | LeftJoin JoinType = "LEFT" 11 | RightJoin JoinType = "RIGHT" 12 | ) 13 | 14 | type JoinTarget struct { 15 | Type JoinType 16 | Association string 17 | Subquery Expression 18 | Table string 19 | } 20 | 21 | func Has(name string) JoinTarget { 22 | return JoinTarget{Type: InnerJoin, Association: name} 23 | } 24 | 25 | func (jt JoinType) Association(name string) JoinTarget { 26 | return JoinTarget{Type: jt, Association: name} 27 | } 28 | 29 | func (jt JoinType) AssociationFrom(name string, subquery Expression) JoinTarget { 30 | return JoinTarget{Type: jt, Association: name, Subquery: subquery} 31 | } 32 | 33 | func (jt JoinTarget) As(name string) JoinTarget { 34 | jt.Table = name 35 | return jt 36 | } 37 | 38 | // Join clause for from 39 | type Join struct { 40 | Type JoinType 41 | Table Table 42 | ON Where 43 | Using []string 44 | Expression Expression 45 | } 46 | 47 | func JoinTable(names ...string) Table { 48 | return Table{ 49 | Name: utils.JoinNestedRelationNames(names), 50 | } 51 | } 52 | 53 | func (join Join) Build(builder Builder) { 54 | if 
join.Expression != nil { 55 | join.Expression.Build(builder) 56 | } else { 57 | if join.Type != "" { 58 | builder.WriteString(string(join.Type)) 59 | builder.WriteByte(' ') 60 | } 61 | 62 | builder.WriteString("JOIN ") 63 | builder.WriteQuoted(join.Table) 64 | 65 | if len(join.ON.Exprs) > 0 { 66 | builder.WriteString(" ON ") 67 | join.ON.Build(builder) 68 | } else if len(join.Using) > 0 { 69 | builder.WriteString(" USING (") 70 | for idx, c := range join.Using { 71 | if idx > 0 { 72 | builder.WriteByte(',') 73 | } 74 | builder.WriteQuoted(c) 75 | } 76 | builder.WriteByte(')') 77 | } 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /clause/joins_test.go: -------------------------------------------------------------------------------- 1 | package clause_test 2 | 3 | import ( 4 | "sync" 5 | "testing" 6 | 7 | "gorm.io/gorm" 8 | "gorm.io/gorm/clause" 9 | "gorm.io/gorm/schema" 10 | "gorm.io/gorm/utils/tests" 11 | ) 12 | 13 | func TestJoin(t *testing.T) { 14 | results := []struct { 15 | name string 16 | join clause.Join 17 | sql string 18 | }{ 19 | { 20 | name: "LEFT JOIN", 21 | join: clause.Join{ 22 | Type: clause.LeftJoin, 23 | Table: clause.Table{Name: "user"}, 24 | ON: clause.Where{ 25 | Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}}, 26 | }, 27 | }, 28 | sql: "LEFT JOIN `user` ON `user_info`.`user_id` = `users`.`id`", 29 | }, 30 | { 31 | name: "RIGHT JOIN", 32 | join: clause.Join{ 33 | Type: clause.RightJoin, 34 | Table: clause.Table{Name: "user"}, 35 | ON: clause.Where{ 36 | Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}}, 37 | }, 38 | }, 39 | sql: "RIGHT JOIN `user` ON `user_info`.`user_id` = `users`.`id`", 40 | }, 41 | { 42 | name: "INNER JOIN", 43 | join: clause.Join{ 44 | Type: clause.InnerJoin, 45 | Table: clause.Table{Name: "user"}, 46 | ON: clause.Where{ 47 | Exprs: 
[]clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}}, 48 | }, 49 | }, 50 | sql: "INNER JOIN `user` ON `user_info`.`user_id` = `users`.`id`", 51 | }, 52 | { 53 | name: "CROSS JOIN", 54 | join: clause.Join{ 55 | Type: clause.CrossJoin, 56 | Table: clause.Table{Name: "user"}, 57 | ON: clause.Where{ 58 | Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}}, 59 | }, 60 | }, 61 | sql: "CROSS JOIN `user` ON `user_info`.`user_id` = `users`.`id`", 62 | }, 63 | { 64 | name: "USING", 65 | join: clause.Join{ 66 | Type: clause.InnerJoin, 67 | Table: clause.Table{Name: "user"}, 68 | Using: []string{"id"}, 69 | }, 70 | sql: "INNER JOIN `user` USING (`id`)", 71 | }, 72 | { 73 | name: "Expression", 74 | join: clause.Join{ 75 | // Invalid 76 | Type: clause.LeftJoin, 77 | Table: clause.Table{Name: "user"}, 78 | ON: clause.Where{ 79 | Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}}, 80 | }, 81 | // Valid 82 | Expression: clause.Join{ 83 | Type: clause.InnerJoin, 84 | Table: clause.Table{Name: "user"}, 85 | Using: []string{"id"}, 86 | }, 87 | }, 88 | sql: "INNER JOIN `user` USING (`id`)", 89 | }, 90 | } 91 | for _, result := range results { 92 | t.Run(result.name, func(t *testing.T) { 93 | user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) 94 | stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} 95 | result.join.Build(stmt) 96 | if result.sql != stmt.SQL.String() { 97 | t.Errorf("want: %s, got: %s", result.sql, stmt.SQL.String()) 98 | } 99 | }) 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /clause/limit.go: -------------------------------------------------------------------------------- 1 | package clause 2 | 3 | // Limit limit clause 4 | type Limit struct { 5 | Limit *int 6 | Offset int 7 
| } 8 | 9 | // Name where clause name 10 | func (limit Limit) Name() string { 11 | return "LIMIT" 12 | } 13 | 14 | // Build build where clause 15 | func (limit Limit) Build(builder Builder) { 16 | if limit.Limit != nil && *limit.Limit >= 0 { 17 | builder.WriteString("LIMIT ") 18 | builder.AddVar(builder, *limit.Limit) 19 | } 20 | if limit.Offset > 0 { 21 | if limit.Limit != nil && *limit.Limit >= 0 { 22 | builder.WriteByte(' ') 23 | } 24 | builder.WriteString("OFFSET ") 25 | builder.AddVar(builder, limit.Offset) 26 | } 27 | } 28 | 29 | // MergeClause merge order by clauses 30 | func (limit Limit) MergeClause(clause *Clause) { 31 | clause.Name = "" 32 | 33 | if v, ok := clause.Expression.(Limit); ok { 34 | if (limit.Limit == nil || *limit.Limit == 0) && v.Limit != nil { 35 | limit.Limit = v.Limit 36 | } 37 | 38 | if limit.Offset == 0 && v.Offset > 0 { 39 | limit.Offset = v.Offset 40 | } else if limit.Offset < 0 { 41 | limit.Offset = 0 42 | } 43 | } 44 | 45 | clause.Expression = limit 46 | } 47 | -------------------------------------------------------------------------------- /clause/limit_test.go: -------------------------------------------------------------------------------- 1 | package clause_test 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "gorm.io/gorm/clause" 8 | ) 9 | 10 | func TestLimit(t *testing.T) { 11 | limit0 := 0 12 | limit10 := 10 13 | limit50 := 50 14 | limitNeg10 := -10 15 | results := []struct { 16 | Clauses []clause.Interface 17 | Result string 18 | Vars []interface{} 19 | }{ 20 | { 21 | []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{ 22 | Limit: &limit10, 23 | Offset: 20, 24 | }}, 25 | "SELECT * FROM `users` LIMIT ? 
OFFSET ?", 26 | []interface{}{limit10, 20}, 27 | }, 28 | { 29 | []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}}, 30 | "SELECT * FROM `users` LIMIT ?", 31 | []interface{}{limit0}, 32 | }, 33 | { 34 | []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}, clause.Limit{Offset: 0}}, 35 | "SELECT * FROM `users` LIMIT ?", 36 | []interface{}{limit0}, 37 | }, 38 | { 39 | []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}}, 40 | "SELECT * FROM `users` OFFSET ?", 41 | []interface{}{20}, 42 | }, 43 | { 44 | []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Offset: 30}}, 45 | "SELECT * FROM `users` OFFSET ?", 46 | []interface{}{30}, 47 | }, 48 | { 49 | []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Limit: &limit10}}, 50 | "SELECT * FROM `users` LIMIT ? OFFSET ?", 51 | []interface{}{limit10, 20}, 52 | }, 53 | { 54 | []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}}, 55 | "SELECT * FROM `users` LIMIT ? OFFSET ?", 56 | []interface{}{limit10, 30}, 57 | }, 58 | { 59 | []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}}, 60 | "SELECT * FROM `users` LIMIT ?", 61 | []interface{}{limit10}, 62 | }, 63 | { 64 | []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limitNeg10}}, 65 | "SELECT * FROM `users` OFFSET ?", 66 | []interface{}{30}, 67 | }, 68 | { 69 | []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limit50}}, 70 | "SELECT * FROM `users` LIMIT ? 
OFFSET ?", 71 | []interface{}{limit50, 30}, 72 | }, 73 | } 74 | 75 | for idx, result := range results { 76 | t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { 77 | checkBuildClauses(t, result.Clauses, result.Result, result.Vars) 78 | }) 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /clause/locking.go: -------------------------------------------------------------------------------- 1 | package clause 2 | 3 | const ( 4 | LockingStrengthUpdate = "UPDATE" 5 | LockingStrengthShare = "SHARE" 6 | LockingOptionsSkipLocked = "SKIP LOCKED" 7 | LockingOptionsNoWait = "NOWAIT" 8 | ) 9 | 10 | type Locking struct { 11 | Strength string 12 | Table Table 13 | Options string 14 | } 15 | 16 | // Name where clause name 17 | func (locking Locking) Name() string { 18 | return "FOR" 19 | } 20 | 21 | // Build build where clause 22 | func (locking Locking) Build(builder Builder) { 23 | builder.WriteString(locking.Strength) 24 | if locking.Table.Name != "" { 25 | builder.WriteString(" OF ") 26 | builder.WriteQuoted(locking.Table) 27 | } 28 | 29 | if locking.Options != "" { 30 | builder.WriteByte(' ') 31 | builder.WriteString(locking.Options) 32 | } 33 | } 34 | 35 | // MergeClause merge order by clauses 36 | func (locking Locking) MergeClause(clause *Clause) { 37 | clause.Expression = locking 38 | } 39 | -------------------------------------------------------------------------------- /clause/locking_test.go: -------------------------------------------------------------------------------- 1 | package clause_test 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "gorm.io/gorm/clause" 8 | ) 9 | 10 | func TestLocking(t *testing.T) { 11 | results := []struct { 12 | Clauses []clause.Interface 13 | Result string 14 | Vars []interface{} 15 | }{ 16 | { 17 | []clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate}}, 18 | "SELECT * FROM `users` FOR UPDATE", nil, 19 | }, 20 | { 21 | 
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthShare, Table: clause.Table{Name: clause.CurrentTable}}}, 22 | "SELECT * FROM `users` FOR SHARE OF `users`", nil, 23 | }, 24 | { 25 | []clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate, Options: clause.LockingOptionsNoWait}}, 26 | "SELECT * FROM `users` FOR UPDATE NOWAIT", nil, 27 | }, 28 | { 29 | []clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate, Options: clause.LockingOptionsSkipLocked}}, 30 | "SELECT * FROM `users` FOR UPDATE SKIP LOCKED", nil, 31 | }, 32 | } 33 | 34 | for idx, result := range results { 35 | t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { 36 | checkBuildClauses(t, result.Clauses, result.Result, result.Vars) 37 | }) 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /clause/on_conflict.go: -------------------------------------------------------------------------------- 1 | package clause 2 | 3 | type OnConflict struct { 4 | Columns []Column 5 | Where Where 6 | TargetWhere Where 7 | OnConstraint string 8 | DoNothing bool 9 | DoUpdates Set 10 | UpdateAll bool 11 | } 12 | 13 | func (OnConflict) Name() string { 14 | return "ON CONFLICT" 15 | } 16 | 17 | // Build build onConflict clause 18 | func (onConflict OnConflict) Build(builder Builder) { 19 | if onConflict.OnConstraint != "" { 20 | builder.WriteString("ON CONSTRAINT ") 21 | builder.WriteString(onConflict.OnConstraint) 22 | builder.WriteByte(' ') 23 | } else { 24 | if len(onConflict.Columns) > 0 { 25 | builder.WriteByte('(') 26 | for idx, column := range onConflict.Columns { 27 | if idx > 0 { 28 | builder.WriteByte(',') 29 | } 30 | builder.WriteQuoted(column) 31 | } 32 | builder.WriteString(`) `) 33 | } 34 | 35 | if len(onConflict.TargetWhere.Exprs) > 0 { 36 | builder.WriteString(" WHERE ") 37 | onConflict.TargetWhere.Build(builder) 38 
| builder.WriteByte(' ') 39 | } 40 | } 41 | 42 | if onConflict.DoNothing { 43 | builder.WriteString("DO NOTHING") 44 | } else { 45 | builder.WriteString("DO UPDATE SET ") 46 | onConflict.DoUpdates.Build(builder) 47 | } 48 | 49 | if len(onConflict.Where.Exprs) > 0 { 50 | builder.WriteString(" WHERE ") 51 | onConflict.Where.Build(builder) 52 | builder.WriteByte(' ') 53 | } 54 | } 55 | 56 | // MergeClause merge onConflict clauses 57 | func (onConflict OnConflict) MergeClause(clause *Clause) { 58 | clause.Expression = onConflict 59 | } 60 | -------------------------------------------------------------------------------- /clause/order_by.go: -------------------------------------------------------------------------------- 1 | package clause 2 | 3 | type OrderByColumn struct { 4 | Column Column 5 | Desc bool 6 | Reorder bool 7 | } 8 | 9 | type OrderBy struct { 10 | Columns []OrderByColumn 11 | Expression Expression 12 | } 13 | 14 | // Name where clause name 15 | func (orderBy OrderBy) Name() string { 16 | return "ORDER BY" 17 | } 18 | 19 | // Build build where clause 20 | func (orderBy OrderBy) Build(builder Builder) { 21 | if orderBy.Expression != nil { 22 | orderBy.Expression.Build(builder) 23 | } else { 24 | for idx, column := range orderBy.Columns { 25 | if idx > 0 { 26 | builder.WriteByte(',') 27 | } 28 | 29 | builder.WriteQuoted(column.Column) 30 | if column.Desc { 31 | builder.WriteString(" DESC") 32 | } 33 | } 34 | } 35 | } 36 | 37 | // MergeClause merge order by clauses 38 | func (orderBy OrderBy) MergeClause(clause *Clause) { 39 | if v, ok := clause.Expression.(OrderBy); ok { 40 | for i := len(orderBy.Columns) - 1; i >= 0; i-- { 41 | if orderBy.Columns[i].Reorder { 42 | orderBy.Columns = orderBy.Columns[i:] 43 | clause.Expression = orderBy 44 | return 45 | } 46 | } 47 | 48 | copiedColumns := make([]OrderByColumn, len(v.Columns)) 49 | copy(copiedColumns, v.Columns) 50 | orderBy.Columns = append(copiedColumns, orderBy.Columns...) 
51 | } 52 | 53 | clause.Expression = orderBy 54 | } 55 | -------------------------------------------------------------------------------- /clause/order_by_test.go: -------------------------------------------------------------------------------- 1 | package clause_test 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "gorm.io/gorm/clause" 8 | ) 9 | 10 | func TestOrderBy(t *testing.T) { 11 | results := []struct { 12 | Clauses []clause.Interface 13 | Result string 14 | Vars []interface{} 15 | }{ 16 | { 17 | []clause.Interface{clause.Select{}, clause.From{}, clause.OrderBy{ 18 | Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}, 19 | }}, 20 | "SELECT * FROM `users` ORDER BY `users`.`id` DESC", nil, 21 | }, 22 | { 23 | []clause.Interface{ 24 | clause.Select{}, clause.From{}, clause.OrderBy{ 25 | Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}, 26 | }, clause.OrderBy{ 27 | Columns: []clause.OrderByColumn{{Column: clause.Column{Name: "name"}}}, 28 | }, 29 | }, 30 | "SELECT * FROM `users` ORDER BY `users`.`id` DESC,`name`", nil, 31 | }, 32 | { 33 | []clause.Interface{ 34 | clause.Select{}, clause.From{}, clause.OrderBy{ 35 | Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}, 36 | }, clause.OrderBy{ 37 | Columns: []clause.OrderByColumn{{Column: clause.Column{Name: "name"}, Reorder: true}}, 38 | }, 39 | }, 40 | "SELECT * FROM `users` ORDER BY `name`", nil, 41 | }, 42 | { 43 | []clause.Interface{ 44 | clause.Select{}, clause.From{}, clause.OrderBy{ 45 | Expression: clause.Expr{SQL: "FIELD(id, ?)", Vars: []interface{}{[]int{1, 2, 3}}, WithoutParentheses: true}, 46 | }, 47 | }, 48 | "SELECT * FROM `users` ORDER BY FIELD(id, ?,?,?)", 49 | []interface{}{1, 2, 3}, 50 | }, 51 | } 52 | 53 | for idx, result := range results { 54 | t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { 55 | checkBuildClauses(t, result.Clauses, result.Result, result.Vars) 56 | }) 57 | } 58 | } 59 | 
-------------------------------------------------------------------------------- /clause/returning.go: -------------------------------------------------------------------------------- 1 | package clause 2 | 3 | type Returning struct { 4 | Columns []Column 5 | } 6 | 7 | // Name where clause name 8 | func (returning Returning) Name() string { 9 | return "RETURNING" 10 | } 11 | 12 | // Build build where clause 13 | func (returning Returning) Build(builder Builder) { 14 | if len(returning.Columns) > 0 { 15 | for idx, column := range returning.Columns { 16 | if idx > 0 { 17 | builder.WriteByte(',') 18 | } 19 | 20 | builder.WriteQuoted(column) 21 | } 22 | } else { 23 | builder.WriteByte('*') 24 | } 25 | } 26 | 27 | // MergeClause merge order by clauses 28 | func (returning Returning) MergeClause(clause *Clause) { 29 | if v, ok := clause.Expression.(Returning); ok && len(returning.Columns) > 0 { 30 | if v.Columns != nil { 31 | returning.Columns = append(v.Columns, returning.Columns...) 32 | } else { 33 | returning.Columns = nil 34 | } 35 | } 36 | clause.Expression = returning 37 | } 38 | -------------------------------------------------------------------------------- /clause/returning_test.go: -------------------------------------------------------------------------------- 1 | package clause_test 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "gorm.io/gorm/clause" 8 | ) 9 | 10 | func TestReturning(t *testing.T) { 11 | results := []struct { 12 | Clauses []clause.Interface 13 | Result string 14 | Vars []interface{} 15 | }{ 16 | { 17 | []clause.Interface{clause.Select{}, clause.From{}, clause.Returning{ 18 | []clause.Column{clause.PrimaryColumn}, 19 | }}, 20 | "SELECT * FROM `users` RETURNING `users`.`id`", nil, 21 | }, { 22 | []clause.Interface{clause.Select{}, clause.From{}, clause.Returning{ 23 | []clause.Column{clause.PrimaryColumn}, 24 | }, clause.Returning{ 25 | []clause.Column{{Name: "name"}, {Name: "age"}}, 26 | }}, 27 | "SELECT * FROM `users` RETURNING 
`users`.`id`,`name`,`age`", nil, 28 | }, 29 | { 30 | []clause.Interface{clause.Select{}, clause.From{}, clause.Returning{ 31 | []clause.Column{clause.PrimaryColumn}, 32 | }, clause.Returning{}, clause.Returning{ 33 | []clause.Column{{Name: "name"}, {Name: "age"}}, 34 | }}, 35 | "SELECT * FROM `users` RETURNING *", nil, 36 | }, 37 | { 38 | []clause.Interface{clause.Select{}, clause.From{}, clause.Returning{ 39 | []clause.Column{clause.PrimaryColumn}, 40 | }, clause.Returning{ 41 | []clause.Column{{Name: "name"}, {Name: "age"}}, 42 | }, clause.Returning{}}, 43 | "SELECT * FROM `users` RETURNING *", nil, 44 | }, 45 | } 46 | 47 | for idx, result := range results { 48 | t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { 49 | checkBuildClauses(t, result.Clauses, result.Result, result.Vars) 50 | }) 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /clause/select.go: -------------------------------------------------------------------------------- 1 | package clause 2 | 3 | // Select select attrs when querying, updating, creating 4 | type Select struct { 5 | Distinct bool 6 | Columns []Column 7 | Expression Expression 8 | } 9 | 10 | func (s Select) Name() string { 11 | return "SELECT" 12 | } 13 | 14 | func (s Select) Build(builder Builder) { 15 | if len(s.Columns) > 0 { 16 | if s.Distinct { 17 | builder.WriteString("DISTINCT ") 18 | } 19 | 20 | for idx, column := range s.Columns { 21 | if idx > 0 { 22 | builder.WriteByte(',') 23 | } 24 | builder.WriteQuoted(column) 25 | } 26 | } else { 27 | builder.WriteByte('*') 28 | } 29 | } 30 | 31 | func (s Select) MergeClause(clause *Clause) { 32 | if s.Expression != nil { 33 | if s.Distinct { 34 | if expr, ok := s.Expression.(Expr); ok { 35 | expr.SQL = "DISTINCT " + expr.SQL 36 | clause.Expression = expr 37 | return 38 | } 39 | } 40 | 41 | clause.Expression = s.Expression 42 | } else { 43 | clause.Expression = s 44 | } 45 | } 46 | 47 | // CommaExpression represents a group of 
expressions separated by commas. 48 | type CommaExpression struct { 49 | Exprs []Expression 50 | } 51 | 52 | func (comma CommaExpression) Build(builder Builder) { 53 | for idx, expr := range comma.Exprs { 54 | if idx > 0 { 55 | _, _ = builder.WriteString(", ") 56 | } 57 | expr.Build(builder) 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /clause/select_test.go: -------------------------------------------------------------------------------- 1 | package clause_test 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "gorm.io/gorm/clause" 8 | ) 9 | 10 | func TestSelect(t *testing.T) { 11 | results := []struct { 12 | Clauses []clause.Interface 13 | Result string 14 | Vars []interface{} 15 | }{ 16 | { 17 | []clause.Interface{clause.Select{}, clause.From{}}, 18 | "SELECT * FROM `users`", nil, 19 | }, 20 | { 21 | []clause.Interface{clause.Select{ 22 | Columns: []clause.Column{clause.PrimaryColumn}, 23 | }, clause.From{}}, 24 | "SELECT `users`.`id` FROM `users`", nil, 25 | }, 26 | { 27 | []clause.Interface{clause.Select{ 28 | Columns: []clause.Column{clause.PrimaryColumn}, 29 | }, clause.Select{ 30 | Columns: []clause.Column{{Name: "name"}}, 31 | }, clause.From{}}, 32 | "SELECT `name` FROM `users`", nil, 33 | }, 34 | { 35 | []clause.Interface{clause.Select{ 36 | Expression: clause.CommaExpression{ 37 | Exprs: []clause.Expression{ 38 | clause.NamedExpr{"?", []interface{}{clause.Column{Name: "id"}}}, 39 | clause.NamedExpr{"?", []interface{}{clause.Column{Name: "name"}}}, 40 | clause.NamedExpr{"LENGTH(?)", []interface{}{clause.Column{Name: "mobile"}}}, 41 | }, 42 | }, 43 | }, clause.From{}}, 44 | "SELECT `id`, `name`, LENGTH(`mobile`) FROM `users`", nil, 45 | }, 46 | { 47 | []clause.Interface{clause.Select{ 48 | Expression: clause.CommaExpression{ 49 | Exprs: []clause.Expression{ 50 | clause.Expr{ 51 | SQL: "? 
as name", 52 | Vars: []interface{}{ 53 | clause.Eq{ 54 | Column: clause.Column{Name: "age"}, 55 | Value: 18, 56 | }, 57 | }, 58 | }, 59 | }, 60 | }, 61 | }, clause.From{}}, 62 | "SELECT `age` = ? as name FROM `users`", 63 | []interface{}{18}, 64 | }, 65 | } 66 | 67 | for idx, result := range results { 68 | t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { 69 | checkBuildClauses(t, result.Clauses, result.Result, result.Vars) 70 | }) 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /clause/set.go: -------------------------------------------------------------------------------- 1 | package clause 2 | 3 | import "sort" 4 | 5 | type Set []Assignment 6 | 7 | type Assignment struct { 8 | Column Column 9 | Value interface{} 10 | } 11 | 12 | func (set Set) Name() string { 13 | return "SET" 14 | } 15 | 16 | func (set Set) Build(builder Builder) { 17 | if len(set) > 0 { 18 | for idx, assignment := range set { 19 | if idx > 0 { 20 | builder.WriteByte(',') 21 | } 22 | builder.WriteQuoted(assignment.Column) 23 | builder.WriteByte('=') 24 | builder.AddVar(builder, assignment.Value) 25 | } 26 | } else { 27 | builder.WriteQuoted(Column{Name: PrimaryKey}) 28 | builder.WriteByte('=') 29 | builder.WriteQuoted(Column{Name: PrimaryKey}) 30 | } 31 | } 32 | 33 | // MergeClause merge assignments clauses 34 | func (set Set) MergeClause(clause *Clause) { 35 | copiedAssignments := make([]Assignment, len(set)) 36 | copy(copiedAssignments, set) 37 | clause.Expression = Set(copiedAssignments) 38 | } 39 | 40 | func Assignments(values map[string]interface{}) Set { 41 | keys := make([]string, 0, len(values)) 42 | for key := range values { 43 | keys = append(keys, key) 44 | } 45 | sort.Strings(keys) 46 | 47 | assignments := make([]Assignment, len(keys)) 48 | for idx, key := range keys { 49 | assignments[idx] = Assignment{Column: Column{Name: key}, Value: values[key]} 50 | } 51 | return assignments 52 | } 53 | 54 | func 
AssignmentColumns(values []string) Set { 55 | assignments := make([]Assignment, len(values)) 56 | for idx, value := range values { 57 | assignments[idx] = Assignment{Column: Column{Name: value}, Value: Column{Table: "excluded", Name: value}} 58 | } 59 | return assignments 60 | } 61 | -------------------------------------------------------------------------------- /clause/set_test.go: -------------------------------------------------------------------------------- 1 | package clause_test 2 | 3 | import ( 4 | "fmt" 5 | "sort" 6 | "strings" 7 | "testing" 8 | 9 | "gorm.io/gorm/clause" 10 | ) 11 | 12 | func TestSet(t *testing.T) { 13 | results := []struct { 14 | Clauses []clause.Interface 15 | Result string 16 | Vars []interface{} 17 | }{ 18 | { 19 | []clause.Interface{ 20 | clause.Update{}, 21 | clause.Set([]clause.Assignment{{clause.PrimaryColumn, 1}}), 22 | }, 23 | "UPDATE `users` SET `users`.`id`=?", 24 | []interface{}{1}, 25 | }, 26 | { 27 | []clause.Interface{ 28 | clause.Update{}, 29 | clause.Set([]clause.Assignment{{clause.PrimaryColumn, 1}}), 30 | clause.Set([]clause.Assignment{{clause.Column{Name: "name"}, "jinzhu"}}), 31 | }, 32 | "UPDATE `users` SET `name`=?", 33 | []interface{}{"jinzhu"}, 34 | }, 35 | } 36 | 37 | for idx, result := range results { 38 | t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { 39 | checkBuildClauses(t, result.Clauses, result.Result, result.Vars) 40 | }) 41 | } 42 | } 43 | 44 | func TestAssignments(t *testing.T) { 45 | set := clause.Assignments(map[string]interface{}{ 46 | "name": "jinzhu", 47 | "age": 18, 48 | }) 49 | 50 | assignments := []clause.Assignment(set) 51 | 52 | sort.Slice(assignments, func(i, j int) bool { 53 | return strings.Compare(assignments[i].Column.Name, assignments[j].Column.Name) > 0 54 | }) 55 | 56 | if len(assignments) != 2 || assignments[0].Column.Name != "name" || assignments[0].Value.(string) != "jinzhu" || assignments[1].Column.Name != "age" || assignments[1].Value.(int) != 18 { 57 | 
t.Errorf("invalid assignments, got %v", assignments) 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /clause/update.go: -------------------------------------------------------------------------------- 1 | package clause 2 | 3 | type Update struct { 4 | Modifier string 5 | Table Table 6 | } 7 | 8 | // Name update clause name 9 | func (update Update) Name() string { 10 | return "UPDATE" 11 | } 12 | 13 | // Build build update clause 14 | func (update Update) Build(builder Builder) { 15 | if update.Modifier != "" { 16 | builder.WriteString(update.Modifier) 17 | builder.WriteByte(' ') 18 | } 19 | 20 | if update.Table.Name == "" { 21 | builder.WriteQuoted(currentTable) 22 | } else { 23 | builder.WriteQuoted(update.Table) 24 | } 25 | } 26 | 27 | // MergeClause merge update clause 28 | func (update Update) MergeClause(clause *Clause) { 29 | if v, ok := clause.Expression.(Update); ok { 30 | if update.Modifier == "" { 31 | update.Modifier = v.Modifier 32 | } 33 | if update.Table.Name == "" { 34 | update.Table = v.Table 35 | } 36 | } 37 | clause.Expression = update 38 | } 39 | -------------------------------------------------------------------------------- /clause/update_test.go: -------------------------------------------------------------------------------- 1 | package clause_test 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "gorm.io/gorm/clause" 8 | ) 9 | 10 | func TestUpdate(t *testing.T) { 11 | results := []struct { 12 | Clauses []clause.Interface 13 | Result string 14 | Vars []interface{} 15 | }{ 16 | { 17 | []clause.Interface{clause.Update{}}, 18 | "UPDATE `users`", nil, 19 | }, 20 | { 21 | []clause.Interface{clause.Update{Modifier: "LOW_PRIORITY"}}, 22 | "UPDATE LOW_PRIORITY `users`", nil, 23 | }, 24 | { 25 | []clause.Interface{clause.Update{Table: clause.Table{Name: "products"}, Modifier: "LOW_PRIORITY"}}, 26 | "UPDATE LOW_PRIORITY `products`", nil, 27 | }, 28 | } 29 | 30 | for idx, result := range results { 31 
| t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { 32 | checkBuildClauses(t, result.Clauses, result.Result, result.Vars) 33 | }) 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /clause/values.go: -------------------------------------------------------------------------------- 1 | package clause 2 | 3 | type Values struct { 4 | Columns []Column 5 | Values [][]interface{} 6 | } 7 | 8 | // Name from clause name 9 | func (Values) Name() string { 10 | return "VALUES" 11 | } 12 | 13 | // Build build from clause 14 | func (values Values) Build(builder Builder) { 15 | if len(values.Columns) > 0 { 16 | builder.WriteByte('(') 17 | for idx, column := range values.Columns { 18 | if idx > 0 { 19 | builder.WriteByte(',') 20 | } 21 | builder.WriteQuoted(column) 22 | } 23 | builder.WriteByte(')') 24 | 25 | builder.WriteString(" VALUES ") 26 | 27 | for idx, value := range values.Values { 28 | if idx > 0 { 29 | builder.WriteByte(',') 30 | } 31 | 32 | builder.WriteByte('(') 33 | builder.AddVar(builder, value...) 
34 | builder.WriteByte(')') 35 | } 36 | } else { 37 | builder.WriteString("DEFAULT VALUES") 38 | } 39 | } 40 | 41 | // MergeClause merge values clauses 42 | func (values Values) MergeClause(clause *Clause) { 43 | clause.Name = "" 44 | clause.Expression = values 45 | } 46 | -------------------------------------------------------------------------------- /clause/values_test.go: -------------------------------------------------------------------------------- 1 | package clause_test 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "gorm.io/gorm/clause" 8 | ) 9 | 10 | func TestValues(t *testing.T) { 11 | results := []struct { 12 | Clauses []clause.Interface 13 | Result string 14 | Vars []interface{} 15 | }{ 16 | { 17 | []clause.Interface{ 18 | clause.Insert{}, 19 | clause.Values{ 20 | Columns: []clause.Column{{Name: "name"}, {Name: "age"}}, 21 | Values: [][]interface{}{{"jinzhu", 18}, {"josh", 1}}, 22 | }, 23 | }, 24 | "INSERT INTO `users` (`name`,`age`) VALUES (?,?),(?,?)", 25 | []interface{}{"jinzhu", 18, "josh", 1}, 26 | }, 27 | } 28 | 29 | for idx, result := range results { 30 | t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { 31 | checkBuildClauses(t, result.Clauses, result.Result, result.Vars) 32 | }) 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /clause/with.go: -------------------------------------------------------------------------------- 1 | package clause 2 | 3 | type With struct{} 4 | -------------------------------------------------------------------------------- /errors.go: -------------------------------------------------------------------------------- 1 | package gorm 2 | 3 | import ( 4 | "errors" 5 | 6 | "gorm.io/gorm/logger" 7 | ) 8 | 9 | var ( 10 | // ErrRecordNotFound record not found error 11 | ErrRecordNotFound = logger.ErrRecordNotFound 12 | // ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback` 13 | ErrInvalidTransaction = errors.New("invalid 
transaction") 14 | // ErrNotImplemented not implemented 15 | ErrNotImplemented = errors.New("not implemented") 16 | // ErrMissingWhereClause missing where clause 17 | ErrMissingWhereClause = errors.New("WHERE conditions required") 18 | // ErrUnsupportedRelation unsupported relations 19 | ErrUnsupportedRelation = errors.New("unsupported relations") 20 | // ErrPrimaryKeyRequired primary keys required 21 | ErrPrimaryKeyRequired = errors.New("primary key required") 22 | // ErrModelValueRequired model value required 23 | ErrModelValueRequired = errors.New("model value required") 24 | // ErrModelAccessibleFieldsRequired model accessible fields required 25 | ErrModelAccessibleFieldsRequired = errors.New("model accessible fields required") 26 | // ErrSubQueryRequired sub query required 27 | ErrSubQueryRequired = errors.New("sub query required") 28 | // ErrInvalidData unsupported data 29 | ErrInvalidData = errors.New("unsupported data") 30 | // ErrUnsupportedDriver unsupported driver 31 | ErrUnsupportedDriver = errors.New("unsupported driver") 32 | // ErrRegistered registered 33 | ErrRegistered = errors.New("registered") 34 | // ErrInvalidField invalid field 35 | ErrInvalidField = errors.New("invalid field") 36 | // ErrEmptySlice empty slice found 37 | ErrEmptySlice = errors.New("empty slice found") 38 | // ErrDryRunModeUnsupported dry run mode unsupported 39 | ErrDryRunModeUnsupported = errors.New("dry run mode unsupported") 40 | // ErrInvalidDB invalid db 41 | ErrInvalidDB = errors.New("invalid db") 42 | // ErrInvalidValue invalid value 43 | ErrInvalidValue = errors.New("invalid value, should be pointer to struct or slice") 44 | // ErrInvalidValueOfLength invalid values do not match length 45 | ErrInvalidValueOfLength = errors.New("invalid association values, length doesn't match") 46 | // ErrPreloadNotAllowed preload is not allowed when count is used 47 | ErrPreloadNotAllowed = errors.New("preload is not allowed when count is used") 48 | // ErrDuplicatedKey occurs when 
there is a unique key constraint violation 49 | ErrDuplicatedKey = errors.New("duplicated key not allowed") 50 | // ErrForeignKeyViolated occurs when there is a foreign key constraint violation 51 | ErrForeignKeyViolated = errors.New("violates foreign key constraint") 52 | // ErrCheckConstraintViolated occurs when there is a check constraint violation 53 | ErrCheckConstraintViolated = errors.New("violates check constraint") 54 | ) 55 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module gorm.io/gorm 2 | 3 | go 1.18 4 | 5 | require ( 6 | github.com/jinzhu/inflection v1.0.0 7 | github.com/jinzhu/now v1.1.5 8 | golang.org/x/text v0.20.0 9 | ) 10 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= 2 | github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= 3 | github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= 4 | github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= 5 | golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug= 6 | golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4= 7 | -------------------------------------------------------------------------------- /interfaces.go: -------------------------------------------------------------------------------- 1 | package gorm 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | 7 | "gorm.io/gorm/clause" 8 | "gorm.io/gorm/schema" 9 | ) 10 | 11 | // Dialector GORM database dialector 12 | type Dialector interface { 13 | Name() string 14 | Initialize(*DB) error 15 | Migrator(db *DB) Migrator 16 | DataTypeOf(*schema.Field) string 17 | 
DefaultValueOf(*schema.Field) clause.Expression 18 | BindVarTo(writer clause.Writer, stmt *Statement, v interface{}) 19 | QuoteTo(clause.Writer, string) 20 | Explain(sql string, vars ...interface{}) string 21 | } 22 | 23 | // Plugin GORM plugin interface 24 | type Plugin interface { 25 | Name() string 26 | Initialize(*DB) error 27 | } 28 | 29 | type ParamsFilter interface { 30 | ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) 31 | } 32 | 33 | // ConnPool db conns pool interface 34 | type ConnPool interface { 35 | PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) 36 | ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) 37 | QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) 38 | QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row 39 | } 40 | 41 | // SavePointerDialectorInterface save pointer interface 42 | type SavePointerDialectorInterface interface { 43 | SavePoint(tx *DB, name string) error 44 | RollbackTo(tx *DB, name string) error 45 | } 46 | 47 | // TxBeginner tx beginner 48 | type TxBeginner interface { 49 | BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) 50 | } 51 | 52 | // ConnPoolBeginner conn pool beginner 53 | type ConnPoolBeginner interface { 54 | BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error) 55 | } 56 | 57 | // TxCommitter tx committer 58 | type TxCommitter interface { 59 | Commit() error 60 | Rollback() error 61 | } 62 | 63 | // Tx sql.Tx interface 64 | type Tx interface { 65 | ConnPool 66 | TxCommitter 67 | StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt 68 | } 69 | 70 | // Valuer gorm valuer interface 71 | type Valuer interface { 72 | GormValue(context.Context, *DB) clause.Expr 73 | } 74 | 75 | // GetDBConnector SQL db connector 76 | type GetDBConnector interface { 77 | GetDBConn() (*sql.DB, error) 78 | } 79 | 80 | // Rows rows 
interface 81 | type Rows interface { 82 | Columns() ([]string, error) 83 | ColumnTypes() ([]*sql.ColumnType, error) 84 | Next() bool 85 | Scan(dest ...interface{}) error 86 | Err() error 87 | Close() error 88 | } 89 | 90 | type ErrorTranslator interface { 91 | Translate(err error) error 92 | } 93 | -------------------------------------------------------------------------------- /logger/sql.go: -------------------------------------------------------------------------------- 1 | package logger 2 | 3 | import ( 4 | "database/sql/driver" 5 | "fmt" 6 | "reflect" 7 | "regexp" 8 | "strconv" 9 | "strings" 10 | "time" 11 | "unicode" 12 | 13 | "gorm.io/gorm/utils" 14 | ) 15 | 16 | const ( 17 | tmFmtWithMS = "2006-01-02 15:04:05.999" 18 | tmFmtZero = "0000-00-00 00:00:00" 19 | nullStr = "NULL" 20 | ) 21 | 22 | func isPrintable(s string) bool { 23 | for _, r := range s { 24 | if !unicode.IsPrint(r) { 25 | return false 26 | } 27 | } 28 | return true 29 | } 30 | 31 | // A list of Go types that should be converted to SQL primitives 32 | var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})} 33 | 34 | // RegEx matches only numeric values 35 | var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`) 36 | 37 | func isNumeric(k reflect.Kind) bool { 38 | switch k { 39 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 40 | return true 41 | case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 42 | return true 43 | case reflect.Float32, reflect.Float64: 44 | return true 45 | default: 46 | return false 47 | } 48 | } 49 | 50 | // ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability 51 | func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string { 52 | var ( 53 | convertParams func(interface{}, int) 54 | vars = 
make([]string, len(avars)) 55 | ) 56 | 57 | convertParams = func(v interface{}, idx int) { 58 | switch v := v.(type) { 59 | case bool: 60 | vars[idx] = strconv.FormatBool(v) 61 | case time.Time: 62 | if v.IsZero() { 63 | vars[idx] = escaper + tmFmtZero + escaper 64 | } else { 65 | vars[idx] = escaper + v.Format(tmFmtWithMS) + escaper 66 | } 67 | case *time.Time: 68 | if v != nil { 69 | if v.IsZero() { 70 | vars[idx] = escaper + tmFmtZero + escaper 71 | } else { 72 | vars[idx] = escaper + v.Format(tmFmtWithMS) + escaper 73 | } 74 | } else { 75 | vars[idx] = nullStr 76 | } 77 | case driver.Valuer: 78 | reflectValue := reflect.ValueOf(v) 79 | if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { 80 | r, _ := v.Value() 81 | convertParams(r, idx) 82 | } else { 83 | vars[idx] = nullStr 84 | } 85 | case fmt.Stringer: 86 | reflectValue := reflect.ValueOf(v) 87 | switch reflectValue.Kind() { 88 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 89 | vars[idx] = fmt.Sprintf("%d", reflectValue.Interface()) 90 | case reflect.Float32, reflect.Float64: 91 | vars[idx] = fmt.Sprintf("%.6f", reflectValue.Interface()) 92 | case reflect.Bool: 93 | vars[idx] = fmt.Sprintf("%t", reflectValue.Interface()) 94 | case reflect.String: 95 | vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper 96 | default: 97 | if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { 98 | vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper 99 | } else { 100 | vars[idx] = nullStr 101 | } 102 | } 103 | case []byte: 104 | if s := string(v); isPrintable(s) { 105 | vars[idx] = escaper + strings.ReplaceAll(s, escaper, escaper+escaper) + escaper 
106 | } else { 107 | vars[idx] = escaper + "" + escaper 108 | } 109 | case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: 110 | vars[idx] = utils.ToString(v) 111 | case float32: 112 | vars[idx] = strconv.FormatFloat(float64(v), 'f', -1, 32) 113 | case float64: 114 | vars[idx] = strconv.FormatFloat(v, 'f', -1, 64) 115 | case string: 116 | vars[idx] = escaper + strings.ReplaceAll(v, escaper, escaper+escaper) + escaper 117 | default: 118 | rv := reflect.ValueOf(v) 119 | if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() { 120 | vars[idx] = nullStr 121 | } else if valuer, ok := v.(driver.Valuer); ok { 122 | v, _ = valuer.Value() 123 | convertParams(v, idx) 124 | } else if rv.Kind() == reflect.Ptr && !rv.IsZero() { 125 | convertParams(reflect.Indirect(rv).Interface(), idx) 126 | } else if isNumeric(rv.Kind()) { 127 | if rv.CanInt() || rv.CanUint() { 128 | vars[idx] = fmt.Sprintf("%d", rv.Interface()) 129 | } else { 130 | vars[idx] = fmt.Sprintf("%.6f", rv.Interface()) 131 | } 132 | } else { 133 | for _, t := range convertibleTypes { 134 | if rv.Type().ConvertibleTo(t) { 135 | convertParams(rv.Convert(t).Interface(), idx) 136 | return 137 | } 138 | } 139 | vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, escaper+escaper) + escaper 140 | } 141 | } 142 | } 143 | 144 | for idx, v := range avars { 145 | convertParams(v, idx) 146 | } 147 | 148 | if numericPlaceholder == nil { 149 | var idx int 150 | var newSQL strings.Builder 151 | 152 | for _, v := range []byte(sql) { 153 | if v == '?' 
{ 154 | if len(vars) > idx { 155 | newSQL.WriteString(vars[idx]) 156 | idx++ 157 | continue 158 | } 159 | } 160 | newSQL.WriteByte(v) 161 | } 162 | 163 | sql = newSQL.String() 164 | } else { 165 | sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$") 166 | 167 | sql = numericPlaceholderRe.ReplaceAllStringFunc(sql, func(v string) string { 168 | num := v[1 : len(v)-1] 169 | n, _ := strconv.Atoi(num) 170 | 171 | // position var start from 1 ($1, $2) 172 | n -= 1 173 | if n >= 0 && n <= len(vars)-1 { 174 | return vars[n] 175 | } 176 | return v 177 | }) 178 | } 179 | 180 | return sql 181 | } 182 | -------------------------------------------------------------------------------- /migrator.go: -------------------------------------------------------------------------------- 1 | package gorm 2 | 3 | import ( 4 | "reflect" 5 | 6 | "gorm.io/gorm/clause" 7 | "gorm.io/gorm/schema" 8 | ) 9 | 10 | // Migrator returns migrator 11 | func (db *DB) Migrator() Migrator { 12 | tx := db.getInstance() 13 | 14 | // apply scopes to migrator 15 | for len(tx.Statement.scopes) > 0 { 16 | tx = tx.executeScopes() 17 | } 18 | 19 | return tx.Dialector.Migrator(tx.Session(&Session{})) 20 | } 21 | 22 | // AutoMigrate run auto migration for given models 23 | func (db *DB) AutoMigrate(dst ...interface{}) error { 24 | return db.Migrator().AutoMigrate(dst...) 25 | } 26 | 27 | // ViewOption view option 28 | type ViewOption struct { 29 | Replace bool // If true, exec `CREATE`. If false, exec `CREATE OR REPLACE` 30 | CheckOption string // optional. e.g. `WITH [ CASCADED | LOCAL ] CHECK OPTION` 31 | Query *DB // required subquery. 
32 | } 33 | 34 | // ColumnType column type interface 35 | type ColumnType interface { 36 | Name() string 37 | DatabaseTypeName() string // varchar 38 | ColumnType() (columnType string, ok bool) // varchar(64) 39 | PrimaryKey() (isPrimaryKey bool, ok bool) 40 | AutoIncrement() (isAutoIncrement bool, ok bool) 41 | Length() (length int64, ok bool) 42 | DecimalSize() (precision int64, scale int64, ok bool) 43 | Nullable() (nullable bool, ok bool) 44 | Unique() (unique bool, ok bool) 45 | ScanType() reflect.Type 46 | Comment() (value string, ok bool) 47 | DefaultValue() (value string, ok bool) 48 | } 49 | 50 | type Index interface { 51 | Table() string 52 | Name() string 53 | Columns() []string 54 | PrimaryKey() (isPrimaryKey bool, ok bool) 55 | Unique() (unique bool, ok bool) 56 | Option() string 57 | } 58 | 59 | // TableType table type interface 60 | type TableType interface { 61 | Schema() string 62 | Name() string 63 | Type() string 64 | Comment() (comment string, ok bool) 65 | } 66 | 67 | // Migrator migrator interface 68 | type Migrator interface { 69 | // AutoMigrate 70 | AutoMigrate(dst ...interface{}) error 71 | 72 | // Database 73 | CurrentDatabase() string 74 | FullDataTypeOf(*schema.Field) clause.Expr 75 | GetTypeAliases(databaseTypeName string) []string 76 | 77 | // Tables 78 | CreateTable(dst ...interface{}) error 79 | DropTable(dst ...interface{}) error 80 | HasTable(dst interface{}) bool 81 | RenameTable(oldName, newName interface{}) error 82 | GetTables() (tableList []string, err error) 83 | TableType(dst interface{}) (TableType, error) 84 | 85 | // Columns 86 | AddColumn(dst interface{}, field string) error 87 | DropColumn(dst interface{}, field string) error 88 | AlterColumn(dst interface{}, field string) error 89 | MigrateColumn(dst interface{}, field *schema.Field, columnType ColumnType) error 90 | // MigrateColumnUnique migrate column's UNIQUE constraint, it's part of MigrateColumn. 
91 | MigrateColumnUnique(dst interface{}, field *schema.Field, columnType ColumnType) error 92 | HasColumn(dst interface{}, field string) bool 93 | RenameColumn(dst interface{}, oldName, field string) error 94 | ColumnTypes(dst interface{}) ([]ColumnType, error) 95 | 96 | // Views 97 | CreateView(name string, option ViewOption) error 98 | DropView(name string) error 99 | 100 | // Constraints 101 | CreateConstraint(dst interface{}, name string) error 102 | DropConstraint(dst interface{}, name string) error 103 | HasConstraint(dst interface{}, name string) bool 104 | 105 | // Indexes 106 | CreateIndex(dst interface{}, name string) error 107 | DropIndex(dst interface{}, name string) error 108 | HasIndex(dst interface{}, name string) bool 109 | RenameIndex(dst interface{}, oldName, newName string) error 110 | GetIndexes(dst interface{}) ([]Index, error) 111 | } 112 | -------------------------------------------------------------------------------- /migrator/column_type.go: -------------------------------------------------------------------------------- 1 | package migrator 2 | 3 | import ( 4 | "database/sql" 5 | "reflect" 6 | ) 7 | 8 | // ColumnType column type implements ColumnType interface 9 | type ColumnType struct { 10 | SQLColumnType *sql.ColumnType 11 | NameValue sql.NullString 12 | DataTypeValue sql.NullString 13 | ColumnTypeValue sql.NullString 14 | PrimaryKeyValue sql.NullBool 15 | UniqueValue sql.NullBool 16 | AutoIncrementValue sql.NullBool 17 | LengthValue sql.NullInt64 18 | DecimalSizeValue sql.NullInt64 19 | ScaleValue sql.NullInt64 20 | NullableValue sql.NullBool 21 | ScanTypeValue reflect.Type 22 | CommentValue sql.NullString 23 | DefaultValueValue sql.NullString 24 | } 25 | 26 | // Name returns the name or alias of the column. 
27 | func (ct ColumnType) Name() string { 28 | if ct.NameValue.Valid { 29 | return ct.NameValue.String 30 | } 31 | return ct.SQLColumnType.Name() 32 | } 33 | 34 | // DatabaseTypeName returns the database system name of the column type. If an empty 35 | // string is returned, then the driver type name is not supported. 36 | // Consult your driver documentation for a list of driver data types. Length specifiers 37 | // are not included. 38 | // Common type names include "VARCHAR", "TEXT", "NVARCHAR", "DECIMAL", "BOOL", 39 | // "INT", and "BIGINT". 40 | func (ct ColumnType) DatabaseTypeName() string { 41 | if ct.DataTypeValue.Valid { 42 | return ct.DataTypeValue.String 43 | } 44 | return ct.SQLColumnType.DatabaseTypeName() 45 | } 46 | 47 | // ColumnType returns the database type of the column. like `varchar(16)` 48 | func (ct ColumnType) ColumnType() (columnType string, ok bool) { 49 | return ct.ColumnTypeValue.String, ct.ColumnTypeValue.Valid 50 | } 51 | 52 | // PrimaryKey returns the column is primary key or not. 53 | func (ct ColumnType) PrimaryKey() (isPrimaryKey bool, ok bool) { 54 | return ct.PrimaryKeyValue.Bool, ct.PrimaryKeyValue.Valid 55 | } 56 | 57 | // AutoIncrement returns the column is auto increment or not. 58 | func (ct ColumnType) AutoIncrement() (isAutoIncrement bool, ok bool) { 59 | return ct.AutoIncrementValue.Bool, ct.AutoIncrementValue.Valid 60 | } 61 | 62 | // Length returns the column type length for variable length column types 63 | func (ct ColumnType) Length() (length int64, ok bool) { 64 | if ct.LengthValue.Valid { 65 | return ct.LengthValue.Int64, true 66 | } 67 | return ct.SQLColumnType.Length() 68 | } 69 | 70 | // DecimalSize returns the scale and precision of a decimal type. 
71 | func (ct ColumnType) DecimalSize() (precision int64, scale int64, ok bool) { 72 | if ct.DecimalSizeValue.Valid { 73 | return ct.DecimalSizeValue.Int64, ct.ScaleValue.Int64, true 74 | } 75 | return ct.SQLColumnType.DecimalSize() 76 | } 77 | 78 | // Nullable reports whether the column may be null. 79 | func (ct ColumnType) Nullable() (nullable bool, ok bool) { 80 | if ct.NullableValue.Valid { 81 | return ct.NullableValue.Bool, true 82 | } 83 | return ct.SQLColumnType.Nullable() 84 | } 85 | 86 | // Unique reports whether the column may be unique. 87 | func (ct ColumnType) Unique() (unique bool, ok bool) { 88 | return ct.UniqueValue.Bool, ct.UniqueValue.Valid 89 | } 90 | 91 | // ScanType returns a Go type suitable for scanning into using Rows.Scan. 92 | func (ct ColumnType) ScanType() reflect.Type { 93 | if ct.ScanTypeValue != nil { 94 | return ct.ScanTypeValue 95 | } 96 | return ct.SQLColumnType.ScanType() 97 | } 98 | 99 | // Comment returns the comment of current column. 100 | func (ct ColumnType) Comment() (value string, ok bool) { 101 | return ct.CommentValue.String, ct.CommentValue.Valid 102 | } 103 | 104 | // DefaultValue returns the default value of current column. 105 | func (ct ColumnType) DefaultValue() (value string, ok bool) { 106 | return ct.DefaultValueValue.String, ct.DefaultValueValue.Valid 107 | } 108 | -------------------------------------------------------------------------------- /migrator/index.go: -------------------------------------------------------------------------------- 1 | package migrator 2 | 3 | import "database/sql" 4 | 5 | // Index implements gorm.Index interface 6 | type Index struct { 7 | TableName string 8 | NameValue string 9 | ColumnList []string 10 | PrimaryKeyValue sql.NullBool 11 | UniqueValue sql.NullBool 12 | OptionValue string 13 | } 14 | 15 | // Table return the table name of the index. 16 | func (idx Index) Table() string { 17 | return idx.TableName 18 | } 19 | 20 | // Name return the name of the index. 
21 | func (idx Index) Name() string { 22 | return idx.NameValue 23 | } 24 | 25 | // Columns return the columns of the index 26 | func (idx Index) Columns() []string { 27 | return idx.ColumnList 28 | } 29 | 30 | // PrimaryKey returns the index is primary key or not. 31 | func (idx Index) PrimaryKey() (isPrimaryKey bool, ok bool) { 32 | return idx.PrimaryKeyValue.Bool, idx.PrimaryKeyValue.Valid 33 | } 34 | 35 | // Unique returns whether the index is unique or not. 36 | func (idx Index) Unique() (unique bool, ok bool) { 37 | return idx.UniqueValue.Bool, idx.UniqueValue.Valid 38 | } 39 | 40 | // Option return the optional attribute of the index 41 | func (idx Index) Option() string { 42 | return idx.OptionValue 43 | } 44 | -------------------------------------------------------------------------------- /migrator/table_type.go: -------------------------------------------------------------------------------- 1 | package migrator 2 | 3 | import ( 4 | "database/sql" 5 | ) 6 | 7 | // TableType table type implements TableType interface 8 | type TableType struct { 9 | SchemaValue string 10 | NameValue string 11 | TypeValue string 12 | CommentValue sql.NullString 13 | } 14 | 15 | // Schema returns the schema of the table. 16 | func (ct TableType) Schema() string { 17 | return ct.SchemaValue 18 | } 19 | 20 | // Name returns the name of the table. 21 | func (ct TableType) Name() string { 22 | return ct.NameValue 23 | } 24 | 25 | // Type returns the type of the table. 26 | func (ct TableType) Type() string { 27 | return ct.TypeValue 28 | } 29 | 30 | // Comment returns the comment of current table. 
31 | func (ct TableType) Comment() (comment string, ok bool) { 32 | return ct.CommentValue.String, ct.CommentValue.Valid 33 | } 34 | -------------------------------------------------------------------------------- /model.go: -------------------------------------------------------------------------------- 1 | package gorm 2 | 3 | import "time" 4 | 5 | // Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt 6 | // It may be embedded into your model or you may build your own model without it 7 | // 8 | // type User struct { 9 | // gorm.Model 10 | // } 11 | type Model struct { 12 | ID uint `gorm:"primarykey"` 13 | CreatedAt time.Time 14 | UpdatedAt time.Time 15 | DeletedAt DeletedAt `gorm:"index"` 16 | } 17 | -------------------------------------------------------------------------------- /schema/callbacks_test.go: -------------------------------------------------------------------------------- 1 | package schema_test 2 | 3 | import ( 4 | "reflect" 5 | "sync" 6 | "testing" 7 | 8 | "gorm.io/gorm" 9 | "gorm.io/gorm/schema" 10 | ) 11 | 12 | type UserWithCallback struct{} 13 | 14 | func (UserWithCallback) BeforeSave(*gorm.DB) error { 15 | return nil 16 | } 17 | 18 | func (UserWithCallback) AfterCreate(*gorm.DB) error { 19 | return nil 20 | } 21 | 22 | func TestCallback(t *testing.T) { 23 | user, err := schema.Parse(&UserWithCallback{}, &sync.Map{}, schema.NamingStrategy{}) 24 | if err != nil { 25 | t.Fatalf("failed to parse user with callback, got error %v", err) 26 | } 27 | 28 | for _, str := range []string{"BeforeSave", "AfterCreate"} { 29 | if !reflect.Indirect(reflect.ValueOf(user)).FieldByName(str).Interface().(bool) { 30 | t.Errorf("%v should be true", str) 31 | } 32 | } 33 | 34 | for _, str := range []string{"BeforeCreate", "BeforeUpdate", "AfterUpdate", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"} { 35 | if reflect.Indirect(reflect.ValueOf(user)).FieldByName(str).Interface().(bool) { 36 | t.Errorf("%v 
should be false", str) 37 | } 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /schema/constraint.go: -------------------------------------------------------------------------------- 1 | package schema 2 | 3 | import ( 4 | "regexp" 5 | "strings" 6 | 7 | "gorm.io/gorm/clause" 8 | ) 9 | 10 | // reg match english letters and midline 11 | var regEnLetterAndMidline = regexp.MustCompile(`^[\w-]+$`) 12 | 13 | type CheckConstraint struct { 14 | Name string 15 | Constraint string // length(phone) >= 10 16 | *Field 17 | } 18 | 19 | func (chk *CheckConstraint) GetName() string { return chk.Name } 20 | 21 | func (chk *CheckConstraint) Build() (sql string, vars []interface{}) { 22 | return "CONSTRAINT ? CHECK (?)", []interface{}{clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}} 23 | } 24 | 25 | // ParseCheckConstraints parse schema check constraints 26 | func (schema *Schema) ParseCheckConstraints() map[string]CheckConstraint { 27 | checks := map[string]CheckConstraint{} 28 | for _, field := range schema.FieldsByDBName { 29 | if chk := field.TagSettings["CHECK"]; chk != "" { 30 | names := strings.Split(chk, ",") 31 | if len(names) > 1 && regEnLetterAndMidline.MatchString(names[0]) { 32 | checks[names[0]] = CheckConstraint{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field} 33 | } else { 34 | if names[0] == "" { 35 | chk = strings.Join(names[1:], ",") 36 | } 37 | name := schema.namer.CheckerName(schema.Table, field.DBName) 38 | checks[name] = CheckConstraint{Name: name, Constraint: chk, Field: field} 39 | } 40 | } 41 | } 42 | return checks 43 | } 44 | 45 | type UniqueConstraint struct { 46 | Name string 47 | Field *Field 48 | } 49 | 50 | func (uni *UniqueConstraint) GetName() string { return uni.Name } 51 | 52 | func (uni *UniqueConstraint) Build() (sql string, vars []interface{}) { 53 | return "CONSTRAINT ? 
UNIQUE (?)", []interface{}{clause.Column{Name: uni.Name}, clause.Column{Name: uni.Field.DBName}} 54 | } 55 | 56 | // ParseUniqueConstraints parse schema unique constraints 57 | func (schema *Schema) ParseUniqueConstraints() map[string]UniqueConstraint { 58 | uniques := make(map[string]UniqueConstraint) 59 | for _, field := range schema.Fields { 60 | if field.Unique { 61 | name := schema.namer.UniqueName(schema.Table, field.DBName) 62 | uniques[name] = UniqueConstraint{Name: name, Field: field} 63 | } 64 | } 65 | return uniques 66 | } 67 | -------------------------------------------------------------------------------- /schema/constraint_test.go: -------------------------------------------------------------------------------- 1 | package schema_test 2 | 3 | import ( 4 | "reflect" 5 | "sync" 6 | "testing" 7 | 8 | "gorm.io/gorm/schema" 9 | "gorm.io/gorm/utils/tests" 10 | ) 11 | 12 | type UserCheck struct { 13 | Name string `gorm:"check:name_checker,name <> 'jinzhu'"` 14 | Name2 string `gorm:"check:name <> 'jinzhu'"` 15 | Name3 string `gorm:"check:,name <> 'jinzhu'"` 16 | } 17 | 18 | func TestParseCheck(t *testing.T) { 19 | user, err := schema.Parse(&UserCheck{}, &sync.Map{}, schema.NamingStrategy{}) 20 | if err != nil { 21 | t.Fatalf("failed to parse user check, got error %v", err) 22 | } 23 | 24 | results := map[string]schema.CheckConstraint{ 25 | "name_checker": { 26 | Name: "name_checker", 27 | Constraint: "name <> 'jinzhu'", 28 | }, 29 | "chk_user_checks_name2": { 30 | Name: "chk_user_checks_name2", 31 | Constraint: "name <> 'jinzhu'", 32 | }, 33 | "chk_user_checks_name3": { 34 | Name: "chk_user_checks_name3", 35 | Constraint: "name <> 'jinzhu'", 36 | }, 37 | } 38 | 39 | checks := user.ParseCheckConstraints() 40 | 41 | for k, result := range results { 42 | v, ok := checks[k] 43 | if !ok { 44 | t.Errorf("Failed to found check %v from parsed checks %+v", k, checks) 45 | } 46 | 47 | for _, name := range []string{"Name", "Constraint"} { 48 | if 
reflect.ValueOf(result).FieldByName(name).Interface() != reflect.ValueOf(v).FieldByName(name).Interface() { 49 | t.Errorf( 50 | "check %v %v should equal, expects %v, got %v", 51 | k, name, reflect.ValueOf(result).FieldByName(name).Interface(), reflect.ValueOf(v).FieldByName(name).Interface(), 52 | ) 53 | } 54 | } 55 | } 56 | } 57 | 58 | func TestParseUniqueConstraints(t *testing.T) { 59 | type UserUnique struct { 60 | Name1 string `gorm:"unique"` 61 | Name2 string `gorm:"uniqueIndex"` 62 | } 63 | 64 | user, err := schema.Parse(&UserUnique{}, &sync.Map{}, schema.NamingStrategy{}) 65 | if err != nil { 66 | t.Fatalf("failed to parse user unique, got error %v", err) 67 | } 68 | constraints := user.ParseUniqueConstraints() 69 | 70 | results := map[string]schema.UniqueConstraint{ 71 | "uni_user_uniques_name1": { 72 | Name: "uni_user_uniques_name1", 73 | Field: &schema.Field{Name: "Name1", Unique: true}, 74 | }, 75 | } 76 | for k, result := range results { 77 | v, ok := constraints[k] 78 | if !ok { 79 | t.Errorf("Failed to found unique constraint %v from parsed constraints %+v", k, constraints) 80 | } 81 | tests.AssertObjEqual(t, result, v, "Name") 82 | tests.AssertObjEqual(t, result.Field, v.Field, "Name", "Unique", "UniqueIndex") 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /schema/index.go: -------------------------------------------------------------------------------- 1 | package schema 2 | 3 | import ( 4 | "fmt" 5 | "sort" 6 | "strconv" 7 | "strings" 8 | ) 9 | 10 | type Index struct { 11 | Name string 12 | Class string // UNIQUE | FULLTEXT | SPATIAL 13 | Type string // btree, hash, gist, spgist, gin, and brin 14 | Where string 15 | Comment string 16 | Option string // WITH PARSER parser_name 17 | Fields []IndexOption // Note: IndexOption's Field maybe the same 18 | } 19 | 20 | type IndexOption struct { 21 | *Field 22 | Expression string 23 | Sort string // DESC, ASC 24 | Collate string 25 | Length int 26 | 
Priority int 27 | } 28 | 29 | // ParseIndexes parse schema indexes 30 | func (schema *Schema) ParseIndexes() []*Index { 31 | indexesByName := map[string]*Index{} 32 | indexes := []*Index{} 33 | 34 | for _, field := range schema.Fields { 35 | if field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUEINDEX"] != "" { 36 | fieldIndexes, err := parseFieldIndexes(field) 37 | if err != nil { 38 | schema.err = err 39 | break 40 | } 41 | for _, index := range fieldIndexes { 42 | idx := indexesByName[index.Name] 43 | if idx == nil { 44 | idx = &Index{Name: index.Name} 45 | indexesByName[index.Name] = idx 46 | indexes = append(indexes, idx) 47 | } 48 | idx.Name = index.Name 49 | if idx.Class == "" { 50 | idx.Class = index.Class 51 | } 52 | if idx.Type == "" { 53 | idx.Type = index.Type 54 | } 55 | if idx.Where == "" { 56 | idx.Where = index.Where 57 | } 58 | if idx.Comment == "" { 59 | idx.Comment = index.Comment 60 | } 61 | if idx.Option == "" { 62 | idx.Option = index.Option 63 | } 64 | 65 | idx.Fields = append(idx.Fields, index.Fields...) 
66 | sort.Slice(idx.Fields, func(i, j int) bool { 67 | return idx.Fields[i].Priority < idx.Fields[j].Priority 68 | }) 69 | } 70 | } 71 | } 72 | for _, index := range indexes { 73 | if index.Class == "UNIQUE" && len(index.Fields) == 1 { 74 | index.Fields[0].Field.UniqueIndex = index.Name 75 | } 76 | } 77 | return indexes 78 | } 79 | 80 | func (schema *Schema) LookIndex(name string) *Index { 81 | if schema != nil { 82 | indexes := schema.ParseIndexes() 83 | for _, index := range indexes { 84 | if index.Name == name { 85 | return index 86 | } 87 | 88 | for _, field := range index.Fields { 89 | if field.Name == name { 90 | return index 91 | } 92 | } 93 | } 94 | } 95 | 96 | return nil 97 | } 98 | 99 | func parseFieldIndexes(field *Field) (indexes []Index, err error) { 100 | for _, value := range strings.Split(field.Tag.Get("gorm"), ";") { 101 | if value != "" { 102 | v := strings.Split(value, ":") 103 | k := strings.TrimSpace(strings.ToUpper(v[0])) 104 | if k == "INDEX" || k == "UNIQUEINDEX" { 105 | var ( 106 | name string 107 | tag = strings.Join(v[1:], ":") 108 | idx = strings.IndexByte(tag, ',') 109 | tagSetting = strings.Join(strings.Split(tag, ",")[1:], ",") 110 | settings = ParseTagSetting(tagSetting, ",") 111 | length, _ = strconv.Atoi(settings["LENGTH"]) 112 | ) 113 | 114 | if idx == -1 { 115 | idx = len(tag) 116 | } 117 | 118 | name = tag[0:idx] 119 | if name == "" { 120 | subName := field.Name 121 | const key = "COMPOSITE" 122 | if composite, found := settings[key]; found { 123 | if len(composite) == 0 || composite == key { 124 | err = fmt.Errorf( 125 | "the composite tag of %s.%s cannot be empty", 126 | field.Schema.Name, 127 | field.Name) 128 | return 129 | } 130 | subName = composite 131 | } 132 | name = field.Schema.namer.IndexName( 133 | field.Schema.Table, subName) 134 | } 135 | 136 | if (k == "UNIQUEINDEX") || settings["UNIQUE"] != "" { 137 | settings["CLASS"] = "UNIQUE" 138 | } 139 | 140 | priority, err := strconv.Atoi(settings["PRIORITY"]) 141 | if 
err != nil { 142 | priority = 10 143 | } 144 | 145 | indexes = append(indexes, Index{ 146 | Name: name, 147 | Class: settings["CLASS"], 148 | Type: settings["TYPE"], 149 | Where: settings["WHERE"], 150 | Comment: settings["COMMENT"], 151 | Option: settings["OPTION"], 152 | Fields: []IndexOption{{ 153 | Field: field, 154 | Expression: settings["EXPRESSION"], 155 | Sort: settings["SORT"], 156 | Collate: settings["COLLATE"], 157 | Length: length, 158 | Priority: priority, 159 | }}, 160 | }) 161 | } 162 | } 163 | } 164 | 165 | err = nil 166 | return 167 | } 168 | -------------------------------------------------------------------------------- /schema/interfaces.go: -------------------------------------------------------------------------------- 1 | package schema 2 | 3 | import ( 4 | "gorm.io/gorm/clause" 5 | ) 6 | 7 | // ConstraintInterface database constraint interface 8 | type ConstraintInterface interface { 9 | GetName() string 10 | Build() (sql string, vars []interface{}) 11 | } 12 | 13 | // GormDataTypeInterface gorm data type interface 14 | type GormDataTypeInterface interface { 15 | GormDataType() string 16 | } 17 | 18 | // FieldNewValuePool field new scan value pool 19 | type FieldNewValuePool interface { 20 | Get() interface{} 21 | Put(interface{}) 22 | } 23 | 24 | // CreateClausesInterface create clauses interface 25 | type CreateClausesInterface interface { 26 | CreateClauses(*Field) []clause.Interface 27 | } 28 | 29 | // QueryClausesInterface query clauses interface 30 | type QueryClausesInterface interface { 31 | QueryClauses(*Field) []clause.Interface 32 | } 33 | 34 | // UpdateClausesInterface update clauses interface 35 | type UpdateClausesInterface interface { 36 | UpdateClauses(*Field) []clause.Interface 37 | } 38 | 39 | // DeleteClausesInterface delete clauses interface 40 | type DeleteClausesInterface interface { 41 | DeleteClauses(*Field) []clause.Interface 42 | } 43 | -------------------------------------------------------------------------------- 
/schema/model_test.go: -------------------------------------------------------------------------------- 1 | package schema_test 2 | 3 | import ( 4 | "database/sql" 5 | "time" 6 | 7 | "gorm.io/gorm" 8 | "gorm.io/gorm/utils/tests" 9 | ) 10 | 11 | type User struct { 12 | *gorm.Model 13 | Name *string 14 | Age *uint 15 | Birthday *time.Time 16 | Account *tests.Account 17 | Pets []*tests.Pet 18 | Toys []*tests.Toy `gorm:"polymorphic:Owner"` 19 | CompanyID *int 20 | Company *tests.Company 21 | ManagerID *uint 22 | Manager *User 23 | Team []*User `gorm:"foreignkey:ManagerID"` 24 | Languages []*tests.Language `gorm:"many2many:UserSpeak"` 25 | Friends []*User `gorm:"many2many:user_friends"` 26 | Active *bool 27 | } 28 | 29 | type ( 30 | mytime time.Time 31 | myint int 32 | mybool = bool 33 | ) 34 | 35 | type AdvancedDataTypeUser struct { 36 | ID sql.NullInt64 37 | Name *sql.NullString 38 | Birthday sql.NullTime 39 | RegisteredAt mytime 40 | DeletedAt *mytime 41 | Active mybool 42 | Admin *mybool 43 | } 44 | 45 | type BaseModel struct { 46 | ID uint 47 | CreatedAt time.Time 48 | CreatedBy *int 49 | Created *VersionUser `gorm:"foreignKey:CreatedBy"` 50 | UpdatedAt time.Time 51 | DeletedAt gorm.DeletedAt `gorm:"index"` 52 | } 53 | 54 | type VersionModel struct { 55 | BaseModel 56 | Version int 57 | } 58 | 59 | type VersionUser struct { 60 | VersionModel 61 | Name string 62 | Age uint 63 | Birthday *time.Time 64 | } 65 | -------------------------------------------------------------------------------- /schema/naming.go: -------------------------------------------------------------------------------- 1 | package schema 2 | 3 | import ( 4 | "crypto/sha1" 5 | "encoding/hex" 6 | "regexp" 7 | "strings" 8 | "unicode/utf8" 9 | 10 | "github.com/jinzhu/inflection" 11 | "golang.org/x/text/cases" 12 | "golang.org/x/text/language" 13 | ) 14 | 15 | // Namer namer interface 16 | type Namer interface { 17 | TableName(table string) string 18 | SchemaName(table string) string 19 | 
ColumnName(table, column string) string 20 | JoinTableName(joinTable string) string 21 | RelationshipFKName(Relationship) string 22 | CheckerName(table, column string) string 23 | IndexName(table, column string) string 24 | UniqueName(table, column string) string 25 | } 26 | 27 | // Replacer replacer interface like strings.Replacer 28 | type Replacer interface { 29 | Replace(name string) string 30 | } 31 | 32 | var _ Namer = (*NamingStrategy)(nil) 33 | 34 | // NamingStrategy tables, columns naming strategy 35 | type NamingStrategy struct { 36 | TablePrefix string 37 | SingularTable bool 38 | NameReplacer Replacer 39 | NoLowerCase bool 40 | IdentifierMaxLength int 41 | } 42 | 43 | // TableName convert string to table name 44 | func (ns NamingStrategy) TableName(str string) string { 45 | if ns.SingularTable { 46 | return ns.TablePrefix + ns.toDBName(str) 47 | } 48 | return ns.TablePrefix + inflection.Plural(ns.toDBName(str)) 49 | } 50 | 51 | // SchemaName generate schema name from table name, don't guarantee it is the reverse value of TableName 52 | func (ns NamingStrategy) SchemaName(table string) string { 53 | table = strings.TrimPrefix(table, ns.TablePrefix) 54 | 55 | if ns.SingularTable { 56 | return ns.toSchemaName(table) 57 | } 58 | return ns.toSchemaName(inflection.Singular(table)) 59 | } 60 | 61 | // ColumnName convert string to column name 62 | func (ns NamingStrategy) ColumnName(table, column string) string { 63 | return ns.toDBName(column) 64 | } 65 | 66 | // JoinTableName convert string to join table name 67 | func (ns NamingStrategy) JoinTableName(str string) string { 68 | if !ns.NoLowerCase && strings.ToLower(str) == str { 69 | return ns.TablePrefix + str 70 | } 71 | 72 | if ns.SingularTable { 73 | return ns.TablePrefix + ns.toDBName(str) 74 | } 75 | return ns.TablePrefix + inflection.Plural(ns.toDBName(str)) 76 | } 77 | 78 | // RelationshipFKName generate fk name for relation 79 | func (ns NamingStrategy) RelationshipFKName(rel Relationship) string { 
80 | return ns.formatName("fk", rel.Schema.Table, ns.toDBName(rel.Name)) 81 | } 82 | 83 | // CheckerName generate checker name 84 | func (ns NamingStrategy) CheckerName(table, column string) string { 85 | return ns.formatName("chk", table, column) 86 | } 87 | 88 | // IndexName generate index name 89 | func (ns NamingStrategy) IndexName(table, column string) string { 90 | return ns.formatName("idx", table, ns.toDBName(column)) 91 | } 92 | 93 | // UniqueName generate unique constraint name 94 | func (ns NamingStrategy) UniqueName(table, column string) string { 95 | return ns.formatName("uni", table, ns.toDBName(column)) 96 | } 97 | // formatName joins prefix, table and name with "_" (dots become "_"); if the result is longer than IdentifierMaxLength runes (0 means 64) it is truncated and suffixed with the first 8 hex digits of its SHA-1 hash. 98 | func (ns NamingStrategy) formatName(prefix, table, name string) string { 99 | formattedName := strings.ReplaceAll(strings.Join([]string{ 100 | prefix, table, name, 101 | }, "_"), ".", "_") 102 | 103 | if ns.IdentifierMaxLength == 0 { 104 | ns.IdentifierMaxLength = 64 105 | } 106 | 107 | if utf8.RuneCountInString(formattedName) > ns.IdentifierMaxLength { 108 | h := sha1.New() 109 | h.Write([]byte(formattedName)) 110 | bs := h.Sum(nil) 111 | // slice by runes, not bytes: the check above counts runes, and a byte slice could cut a multi-byte character in half 112 | formattedName = string([]rune(formattedName)[:ns.IdentifierMaxLength-8]) + hex.EncodeToString(bs)[:8] 113 | } 114 | return formattedName 115 | } 116 | 117 | var ( 118 | // https://github.com/golang/lint/blob/master/lint.go#L770 119 | commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} 120 | commonInitialismsReplacer *strings.Replacer 121 | ) 122 | 123 | func init() { 124 | commonInitialismsForReplacer := make([]string, 0, len(commonInitialisms)) 125 | for _, initialism := range commonInitialisms { 126 | commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, cases.Title(language.Und).String(initialism)) 127 | } 128 | commonInitialismsReplacer =
strings.NewReplacer(commonInitialismsForReplacer...) 129 | } 130 | 131 | func (ns NamingStrategy) toDBName(name string) string { 132 | if name == "" { 133 | return "" 134 | } 135 | 136 | if ns.NameReplacer != nil { 137 | tmpName := ns.NameReplacer.Replace(name) 138 | 139 | if tmpName == "" { 140 | return name 141 | } 142 | 143 | name = tmpName 144 | } 145 | 146 | if ns.NoLowerCase { 147 | return name 148 | } 149 | 150 | var ( 151 | value = commonInitialismsReplacer.Replace(name) 152 | buf strings.Builder 153 | lastCase, nextCase, nextNumber bool // upper case == true 154 | curCase = value[0] <= 'Z' && value[0] >= 'A' 155 | ) 156 | 157 | for i, v := range value[:len(value)-1] { 158 | nextCase = value[i+1] <= 'Z' && value[i+1] >= 'A' 159 | nextNumber = value[i+1] >= '0' && value[i+1] <= '9' 160 | 161 | if curCase { 162 | if lastCase && (nextCase || nextNumber) { 163 | buf.WriteRune(v + 32) 164 | } else { 165 | if i > 0 && value[i-1] != '_' && value[i+1] != '_' { 166 | buf.WriteByte('_') 167 | } 168 | buf.WriteRune(v + 32) 169 | } 170 | } else { 171 | buf.WriteRune(v) 172 | } 173 | 174 | lastCase = curCase 175 | curCase = nextCase 176 | } 177 | 178 | if curCase { 179 | if !lastCase && len(value) > 1 { 180 | buf.WriteByte('_') 181 | } 182 | buf.WriteByte(value[len(value)-1] + 32) 183 | } else { 184 | buf.WriteByte(value[len(value)-1]) 185 | } 186 | ret := buf.String() 187 | return ret 188 | } 189 | 190 | func (ns NamingStrategy) toSchemaName(name string) string { 191 | result := strings.ReplaceAll(cases.Title(language.Und, cases.NoLower).String(strings.ReplaceAll(name, "_", " ")), " ", "") 192 | for _, initialism := range commonInitialisms { 193 | result = regexp.MustCompile(cases.Title(language.Und, cases.NoLower).String(strings.ToLower(initialism))+"([A-Z]|$|_)").ReplaceAllString(result, initialism+"$1") 194 | } 195 | return result 196 | } 197 | -------------------------------------------------------------------------------- /schema/pool.go: 
-------------------------------------------------------------------------------- 1 | package schema 2 | 3 | import ( 4 | "reflect" 5 | "sync" 6 | ) 7 | 8 | // sync pools 9 | var ( 10 | normalPool sync.Map 11 | poolInitializer = func(reflectType reflect.Type) FieldNewValuePool { 12 | v, _ := normalPool.LoadOrStore(reflectType, &sync.Pool{ 13 | New: func() interface{} { 14 | return reflect.New(reflectType).Interface() 15 | }, 16 | }) 17 | return v.(FieldNewValuePool) 18 | } 19 | ) 20 | -------------------------------------------------------------------------------- /schema/serializer.go: -------------------------------------------------------------------------------- 1 | package schema 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "database/sql" 7 | "database/sql/driver" 8 | "encoding/gob" 9 | "encoding/json" 10 | "fmt" 11 | "reflect" 12 | "strings" 13 | "sync" 14 | "time" 15 | ) 16 | 17 | var serializerMap = sync.Map{} 18 | 19 | // RegisterSerializer register serializer 20 | func RegisterSerializer(name string, serializer SerializerInterface) { 21 | serializerMap.Store(strings.ToLower(name), serializer) 22 | } 23 | 24 | // GetSerializer get serializer 25 | func GetSerializer(name string) (serializer SerializerInterface, ok bool) { 26 | v, ok := serializerMap.Load(strings.ToLower(name)) 27 | if ok { 28 | serializer, ok = v.(SerializerInterface) 29 | } 30 | return serializer, ok 31 | } 32 | 33 | func init() { 34 | RegisterSerializer("json", JSONSerializer{}) 35 | RegisterSerializer("unixtime", UnixSecondSerializer{}) 36 | RegisterSerializer("gob", GobSerializer{}) 37 | } 38 | 39 | // Serializer field value serializer 40 | type serializer struct { 41 | Field *Field 42 | Serializer SerializerInterface 43 | SerializeValuer SerializerValuerInterface 44 | Destination reflect.Value 45 | Context context.Context 46 | value interface{} 47 | fieldValue interface{} 48 | } 49 | 50 | // Scan implements sql.Scanner interface 51 | func (s *serializer) Scan(value interface{}) error 
{ 52 | s.value = value 53 | return nil 54 | } 55 | 56 | // Value implements driver.Valuer interface 57 | func (s serializer) Value() (driver.Value, error) { 58 | return s.SerializeValuer.Value(s.Context, s.Field, s.Destination, s.fieldValue) 59 | } 60 | 61 | // SerializerInterface serializer interface 62 | type SerializerInterface interface { 63 | Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) error 64 | SerializerValuerInterface 65 | } 66 | 67 | // SerializerValuerInterface serializer valuer interface 68 | type SerializerValuerInterface interface { 69 | Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) 70 | } 71 | 72 | // JSONSerializer json serializer 73 | type JSONSerializer struct{} 74 | 75 | // Scan implements serializer interface 76 | func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) { 77 | fieldValue := reflect.New(field.FieldType) 78 | 79 | if dbValue != nil { 80 | var bytes []byte 81 | switch v := dbValue.(type) { 82 | case []byte: 83 | bytes = v 84 | case string: 85 | bytes = []byte(v) 86 | default: 87 | bytes, err = json.Marshal(v) 88 | if err != nil { 89 | return err 90 | } 91 | } 92 | 93 | if len(bytes) > 0 { 94 | err = json.Unmarshal(bytes, fieldValue.Interface()) 95 | } 96 | } 97 | 98 | field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem()) 99 | return 100 | } 101 | 102 | // Value implements serializer interface 103 | func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) { 104 | result, err := json.Marshal(fieldValue) 105 | if string(result) == "null" { 106 | if field.TagSettings["NOT NULL"] != "" { 107 | return "", nil 108 | } 109 | return nil, err 110 | } 111 | return string(result), err 112 | } 113 | 114 | // UnixSecondSerializer json serializer 115 | type UnixSecondSerializer struct{} 116 | 117 | // Scan implements serializer 
interface 118 | func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) { 119 | t := sql.NullTime{} 120 | if err = t.Scan(dbValue); err == nil && t.Valid { 121 | err = field.Set(ctx, dst, t.Time.Unix()) 122 | } 123 | 124 | return 125 | } 126 | 127 | // Value implements serializer interface: an integer field (or non-nil pointer to one) becomes a UTC time.Time at that many unix seconds; a nil pointer yields nil; other types return an error. 128 | func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (result interface{}, err error) { 129 | rv := reflect.ValueOf(fieldValue) 130 | switch v := fieldValue.(type) { 131 | case int64, int, uint, uint64, int32, uint32, int16, uint16: 132 | result = time.Unix(reflect.Indirect(rv).Convert(reflect.TypeOf(int64(0))).Int(), 0).UTC() // convert first: reflect.Value.Int panics on unsigned kinds 133 | case *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16: 134 | if rv.IsZero() { 135 | return nil, nil 136 | } 137 | result = time.Unix(reflect.Indirect(rv).Convert(reflect.TypeOf(int64(0))).Int(), 0).UTC() 138 | default: 139 | err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v) 140 | } 141 | return 142 | } 143 | 144 | // GobSerializer gob serializer 145 | type GobSerializer struct{} 146 | 147 | // Scan implements serializer interface 148 | func (GobSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) { 149 | fieldValue := reflect.New(field.FieldType) 150 | 151 | if dbValue != nil { 152 | var bytesValue []byte 153 | switch v := dbValue.(type) { 154 | case []byte: 155 | bytesValue = v 156 | default: 157 | return fmt.Errorf("failed to unmarshal gob value: %#v", dbValue) 158 | } 159 | if len(bytesValue) > 0 { 160 | decoder := gob.NewDecoder(bytes.NewBuffer(bytesValue)) 161 | err = decoder.Decode(fieldValue.Interface()) 162 | } 163 | } 164 | field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem()) 165 | return 166 | } 167 | 168 | // Value implements serializer interface 169 | func (GobSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{},
error) { 170 | buf := new(bytes.Buffer) 171 | err := gob.NewEncoder(buf).Encode(fieldValue) 172 | return buf.Bytes(), err 173 | } 174 | -------------------------------------------------------------------------------- /schema/utils_test.go: -------------------------------------------------------------------------------- 1 | package schema 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | ) 7 | 8 | func TestRemoveSettingFromTag(t *testing.T) { 9 | tags := map[string]string{ 10 | `gorm:"before:value;column:db;after:value" other:"before:value;column:db;after:value"`: `gorm:"before:value;after:value" other:"before:value;column:db;after:value"`, 11 | `gorm:"before:value;column:db;" other:"before:value;column:db;after:value"`: `gorm:"before:value;" other:"before:value;column:db;after:value"`, 12 | `gorm:"before:value;column:db" other:"before:value;column:db;after:value"`: `gorm:"before:value;" other:"before:value;column:db;after:value"`, 13 | `gorm:"column:db" other:"before:value;column:db;after:value"`: `gorm:"" other:"before:value;column:db;after:value"`, 14 | `gorm:"before:value;column:db ;after:value" other:"before:value;column:db;after:value"`: `gorm:"before:value;after:value" other:"before:value;column:db;after:value"`, 15 | `gorm:"before:value;column:db; after:value" other:"before:value;column:db;after:value"`: `gorm:"before:value; after:value" other:"before:value;column:db;after:value"`, 16 | `gorm:"before:value;column; after:value" other:"before:value;column:db;after:value"`: `gorm:"before:value; after:value" other:"before:value;column:db;after:value"`, 17 | } 18 | 19 | for k, v := range tags { 20 | if string(removeSettingFromTag(reflect.StructTag(k), "column")) != v { 21 | t.Errorf("%v after removeSettingFromTag should equal %v, but got %v", k, v, removeSettingFromTag(reflect.StructTag(k), "column")) 22 | } 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /soft_delete.go: 
-------------------------------------------------------------------------------- 1 | package gorm 2 | 3 | import ( 4 | "database/sql" 5 | "database/sql/driver" 6 | "encoding/json" 7 | "reflect" 8 | 9 | "github.com/jinzhu/now" 10 | "gorm.io/gorm/clause" 11 | "gorm.io/gorm/schema" 12 | ) 13 | 14 | type DeletedAt sql.NullTime 15 | 16 | // Scan implements the Scanner interface. 17 | func (n *DeletedAt) Scan(value interface{}) error { 18 | return (*sql.NullTime)(n).Scan(value) 19 | } 20 | 21 | // Value implements the driver Valuer interface. 22 | func (n DeletedAt) Value() (driver.Value, error) { 23 | if !n.Valid { 24 | return nil, nil 25 | } 26 | return n.Time, nil 27 | } 28 | 29 | func (n DeletedAt) MarshalJSON() ([]byte, error) { 30 | if n.Valid { 31 | return json.Marshal(n.Time) 32 | } 33 | return json.Marshal(nil) 34 | } 35 | 36 | func (n *DeletedAt) UnmarshalJSON(b []byte) error { 37 | if string(b) == "null" { 38 | n.Valid = false 39 | return nil 40 | } 41 | err := json.Unmarshal(b, &n.Time) 42 | if err == nil { 43 | n.Valid = true 44 | } 45 | return err 46 | } 47 | 48 | func (DeletedAt) QueryClauses(f *schema.Field) []clause.Interface { 49 | return []clause.Interface{SoftDeleteQueryClause{Field: f, ZeroValue: parseZeroValueTag(f)}} 50 | } 51 | 52 | func parseZeroValueTag(f *schema.Field) sql.NullString { 53 | if v, ok := f.TagSettings["ZEROVALUE"]; ok { 54 | if _, err := now.Parse(v); err == nil { 55 | return sql.NullString{String: v, Valid: true} 56 | } 57 | } 58 | return sql.NullString{Valid: false} 59 | } 60 | 61 | type SoftDeleteQueryClause struct { 62 | ZeroValue sql.NullString 63 | Field *schema.Field 64 | } 65 | 66 | func (sd SoftDeleteQueryClause) Name() string { 67 | return "" 68 | } 69 | 70 | func (sd SoftDeleteQueryClause) Build(clause.Builder) { 71 | } 72 | 73 | func (sd SoftDeleteQueryClause) MergeClause(*clause.Clause) { 74 | } 75 | 76 | func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) { 77 | if _, ok := 
stmt.Clauses["soft_delete_enabled"]; !ok && !stmt.Statement.Unscoped { 78 | if c, ok := stmt.Clauses["WHERE"]; ok { 79 | if where, ok := c.Expression.(clause.Where); ok && len(where.Exprs) >= 1 { 80 | for _, expr := range where.Exprs { 81 | if orCond, ok := expr.(clause.OrConditions); ok && len(orCond.Exprs) == 1 { 82 | where.Exprs = []clause.Expression{clause.And(where.Exprs...)} 83 | c.Expression = where 84 | stmt.Clauses["WHERE"] = c 85 | break 86 | } 87 | } 88 | } 89 | } 90 | 91 | stmt.AddClause(clause.Where{Exprs: []clause.Expression{ 92 | clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: sd.ZeroValue}, 93 | }}) 94 | stmt.Clauses["soft_delete_enabled"] = clause.Clause{} 95 | } 96 | } 97 | 98 | func (DeletedAt) UpdateClauses(f *schema.Field) []clause.Interface { 99 | return []clause.Interface{SoftDeleteUpdateClause{Field: f, ZeroValue: parseZeroValueTag(f)}} 100 | } 101 | 102 | type SoftDeleteUpdateClause struct { 103 | ZeroValue sql.NullString 104 | Field *schema.Field 105 | } 106 | 107 | func (sd SoftDeleteUpdateClause) Name() string { 108 | return "" 109 | } 110 | 111 | func (sd SoftDeleteUpdateClause) Build(clause.Builder) { 112 | } 113 | 114 | func (sd SoftDeleteUpdateClause) MergeClause(*clause.Clause) { 115 | } 116 | 117 | func (sd SoftDeleteUpdateClause) ModifyStatement(stmt *Statement) { 118 | if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped { 119 | SoftDeleteQueryClause(sd).ModifyStatement(stmt) 120 | } 121 | } 122 | 123 | func (DeletedAt) DeleteClauses(f *schema.Field) []clause.Interface { 124 | return []clause.Interface{SoftDeleteDeleteClause{Field: f, ZeroValue: parseZeroValueTag(f)}} 125 | } 126 | 127 | type SoftDeleteDeleteClause struct { 128 | ZeroValue sql.NullString 129 | Field *schema.Field 130 | } 131 | 132 | func (sd SoftDeleteDeleteClause) Name() string { 133 | return "" 134 | } 135 | 136 | func (sd SoftDeleteDeleteClause) Build(clause.Builder) { 137 | } 138 | 139 | func (sd 
SoftDeleteDeleteClause) MergeClause(*clause.Clause) { 140 | } 141 | 142 | func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { 143 | if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped { 144 | curTime := stmt.DB.NowFunc() 145 | stmt.AddClause(clause.Set{{Column: clause.Column{Name: sd.Field.DBName}, Value: curTime}}) 146 | stmt.SetColumn(sd.Field.DBName, curTime, true) 147 | 148 | if stmt.Schema != nil { 149 | _, queryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields) 150 | column, values := schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues) 151 | 152 | if len(values) > 0 { 153 | stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) 154 | } 155 | 156 | if stmt.ReflectValue.CanAddr() && stmt.Dest != stmt.Model && stmt.Model != nil { 157 | _, queryValues = schema.GetIdentityFieldValuesMap(stmt.Context, reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields) 158 | column, values = schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues) 159 | 160 | if len(values) > 0 { 161 | stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) 162 | } 163 | } 164 | } 165 | 166 | SoftDeleteQueryClause(sd).ModifyStatement(stmt) 167 | stmt.AddClauseIfNotExists(clause.Update{}) 168 | stmt.Build(stmt.DB.Callback().Update().Clauses...) 
169 | } 170 | } 171 | -------------------------------------------------------------------------------- /statement_test.go: -------------------------------------------------------------------------------- 1 | package gorm 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "testing" 7 | 8 | "gorm.io/gorm/clause" 9 | ) 10 | 11 | func TestWhereCloneCorruption(t *testing.T) { 12 | for whereCount := 1; whereCount <= 8; whereCount++ { 13 | t.Run(fmt.Sprintf("w=%d", whereCount), func(t *testing.T) { 14 | s := new(Statement) 15 | for w := 0; w < whereCount; w++ { 16 | s = s.clone() 17 | s.AddClause(clause.Where{ 18 | Exprs: s.BuildCondition(fmt.Sprintf("where%d", w)), 19 | }) 20 | } 21 | 22 | s1 := s.clone() 23 | s1.AddClause(clause.Where{ 24 | Exprs: s.BuildCondition("FINAL1"), 25 | }) 26 | s2 := s.clone() 27 | s2.AddClause(clause.Where{ 28 | Exprs: s.BuildCondition("FINAL2"), 29 | }) 30 | 31 | if reflect.DeepEqual(s1.Clauses["WHERE"], s2.Clauses["WHERE"]) { 32 | t.Errorf("Where conditions should be different") 33 | } 34 | }) 35 | } 36 | } 37 | 38 | func TestNilCondition(t *testing.T) { 39 | s := new(Statement) 40 | if len(s.BuildCondition(nil)) != 0 { 41 | t.Errorf("Nil condition should be empty") 42 | } 43 | } 44 | 45 | func TestNameMatcher(t *testing.T) { 46 | for k, v := range map[string][]string{ 47 | "table.name": {"table", "name"}, 48 | "`table`.`name`": {"table", "name"}, 49 | "'table'.'name'": {"table", "name"}, 50 | "'table'.name": {"table", "name"}, 51 | "table1.name_23": {"table1", "name_23"}, 52 | "`table_1`.`name23`": {"table_1", "name23"}, 53 | "'table23'.'name_1'": {"table23", "name_1"}, 54 | "'table23'.name1": {"table23", "name1"}, 55 | "'name1'": {"", "name1"}, 56 | "`name_1`": {"", "name_1"}, 57 | "`Name_1`": {"", "Name_1"}, 58 | "`Table`.`nAme`": {"Table", "nAme"}, 59 | "my_table.*": {"my_table", "*"}, 60 | "`my_table`.*": {"my_table", "*"}, 61 | "User__Company.*": {"User__Company", "*"}, 62 | "`User__Company`.*": {"User__Company", "*"}, 63 | 
`"User__Company".*`: {"User__Company", "*"}, 64 | `"table"."*"`: {"", ""}, 65 | } { 66 | if table, column := matchName(k); table != v[0] || column != v[1] { 67 | t.Errorf("failed to match value: %v, got %v, expect: %v", k, []string{table, column}, v) 68 | } 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /tests/.gitignore: -------------------------------------------------------------------------------- 1 | go.sum 2 | -------------------------------------------------------------------------------- /tests/README.md: -------------------------------------------------------------------------------- 1 | # Test Guide 2 | 3 | ```bash 4 | cd tests 5 | # prepare test databases 6 | docker-compose up 7 | 8 | # run all tests 9 | ./tests_all.sh 10 | ``` 11 | -------------------------------------------------------------------------------- /tests/benchmark_test.go: -------------------------------------------------------------------------------- 1 | package tests_test 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | . 
"gorm.io/gorm/utils/tests" 8 | ) 9 | 10 | func BenchmarkCreate(b *testing.B) { 11 | user := *GetUser("bench", Config{}) 12 | 13 | for x := 0; x < b.N; x++ { 14 | user.ID = 0 15 | DB.Create(&user) 16 | } 17 | } 18 | 19 | func BenchmarkFind(b *testing.B) { 20 | user := *GetUser("find", Config{}) 21 | DB.Create(&user) 22 | 23 | for x := 0; x < b.N; x++ { 24 | DB.Find(&User{}, "id = ?", user.ID) 25 | } 26 | } 27 | 28 | func BenchmarkScan(b *testing.B) { 29 | user := *GetUser("scan", Config{}) 30 | DB.Create(&user) 31 | 32 | var u User 33 | b.ResetTimer() 34 | for x := 0; x < b.N; x++ { 35 | DB.Raw("select * from users where id = ?", user.ID).Scan(&u) 36 | } 37 | } 38 | 39 | func BenchmarkScanSlice(b *testing.B) { 40 | DB.Exec("delete from users") 41 | for i := 0; i < 10_000; i++ { 42 | user := *GetUser(fmt.Sprintf("scan-%d", i), Config{}) 43 | DB.Create(&user) 44 | } 45 | 46 | var u []User 47 | b.ResetTimer() 48 | for x := 0; x < b.N; x++ { 49 | DB.Raw("select * from users").Scan(&u) 50 | } 51 | } 52 | 53 | func BenchmarkScanSlicePointer(b *testing.B) { 54 | DB.Exec("delete from users") 55 | for i := 0; i < 10_000; i++ { 56 | user := *GetUser(fmt.Sprintf("scan-%d", i), Config{}) 57 | DB.Create(&user) 58 | } 59 | 60 | var u []*User 61 | b.ResetTimer() 62 | for x := 0; x < b.N; x++ { 63 | DB.Raw("select * from users").Scan(&u) 64 | } 65 | } 66 | 67 | func BenchmarkUpdate(b *testing.B) { 68 | user := *GetUser("find", Config{}) 69 | DB.Create(&user) 70 | 71 | for x := 0; x < b.N; x++ { 72 | DB.Model(&user).Updates(map[string]interface{}{"Age": x}) 73 | } 74 | } 75 | 76 | func BenchmarkDelete(b *testing.B) { 77 | user := *GetUser("find", Config{}) 78 | 79 | for x := 0; x < b.N; x++ { 80 | user.ID = 0 81 | DB.Create(&user) 82 | DB.Delete(&user) 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /tests/compose.yml: -------------------------------------------------------------------------------- 1 | services: 2 | mysql: 3 | 
image: 'mysql/mysql-server:latest' 4 | ports: 5 | - "127.0.0.1:9910:3306" 6 | environment: 7 | - MYSQL_DATABASE=gorm 8 | - MYSQL_USER=gorm 9 | - MYSQL_PASSWORD=gorm 10 | - MYSQL_RANDOM_ROOT_PASSWORD="yes" 11 | postgres: 12 | image: 'postgres:latest' 13 | ports: 14 | - "127.0.0.1:9920:5432" 15 | environment: 16 | - TZ=Asia/Shanghai 17 | - POSTGRES_DB=gorm 18 | - POSTGRES_USER=gorm 19 | - POSTGRES_PASSWORD=gorm 20 | mssql: 21 | image: '${MSSQL_IMAGE}:latest' 22 | ports: 23 | - "127.0.0.1:9930:1433" 24 | environment: 25 | - TZ=Asia/Shanghai 26 | - ACCEPT_EULA=Y 27 | - MSSQL_SA_PASSWORD=LoremIpsum86 28 | tidb: 29 | image: 'pingcap/tidb:v6.5.0' 30 | ports: 31 | - "127.0.0.1:9940:4000" 32 | command: /tidb-server -store unistore -path "" -lease 0s > tidb.log 2>&1 & 33 | -------------------------------------------------------------------------------- /tests/connection_test.go: -------------------------------------------------------------------------------- 1 | package tests_test 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "gorm.io/driver/mysql" 8 | "gorm.io/gorm" 9 | ) 10 | 11 | func TestWithSingleConnection(t *testing.T) { 12 | expectedName := "test" 13 | var actualName string 14 | 15 | setSQL, getSQL := getSetSQL(DB.Dialector.Name()) 16 | if len(setSQL) == 0 || len(getSQL) == 0 { 17 | return 18 | } 19 | 20 | err := DB.Connection(func(tx *gorm.DB) error { 21 | if err := tx.Exec(setSQL, expectedName).Error; err != nil { 22 | return err 23 | } 24 | 25 | if err := tx.Raw(getSQL).Scan(&actualName).Error; err != nil { 26 | return err 27 | } 28 | return nil 29 | }) 30 | if err != nil { 31 | t.Errorf(fmt.Sprintf("WithSingleConnection should work, but got err %v", err)) 32 | } 33 | 34 | if actualName != expectedName { 35 | t.Errorf("WithSingleConnection() method should get correct value, expect: %v, got %v", expectedName, actualName) 36 | } 37 | } 38 | 39 | func getSetSQL(driverName string) (string, string) { 40 | switch driverName { 41 | case mysql.Dialector{}.Name(): 42 | 
return "SET @testName := ?", "SELECT @testName" 43 | default: 44 | return "", "" 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /tests/default_value_test.go: -------------------------------------------------------------------------------- 1 | package tests_test 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "gorm.io/gorm" 8 | ) 9 | 10 | func TestDefaultValue(t *testing.T) { 11 | type Harumph struct { 12 | gorm.Model 13 | Email string `gorm:"not null;index:,unique"` 14 | Name string `gorm:"notNull;default:foo"` 15 | Name2 string `gorm:"size:233;not null;default:'foo'"` 16 | Name3 string `gorm:"size:233;notNull;default:''"` 17 | Age int `gorm:"default:18"` 18 | Created time.Time `gorm:"default:2000-01-02"` 19 | Enabled bool `gorm:"default:true"` 20 | } 21 | 22 | DB.Migrator().DropTable(&Harumph{}) 23 | 24 | if err := DB.AutoMigrate(&Harumph{}); err != nil { 25 | t.Fatalf("Failed to migrate with default value, got error: %v", err) 26 | } 27 | 28 | harumph := Harumph{Email: "hello@gorm.io"} 29 | if err := DB.Create(&harumph).Error; err != nil { 30 | t.Fatalf("Failed to create data with default value, got error: %v", err) 31 | } else if harumph.Name != "foo" || harumph.Name2 != "foo" || harumph.Name3 != "" || harumph.Age != 18 || !harumph.Enabled || harumph.Created.Format("20060102") != "20000102" { 32 | t.Fatalf("Failed to create data with default value, got: %+v", harumph) 33 | } 34 | 35 | var result Harumph 36 | if err := DB.First(&result, "email = ?", "hello@gorm.io").Error; err != nil { 37 | t.Fatalf("Failed to find created data, got error: %v", err) 38 | } else if result.Name != "foo" || result.Name2 != "foo" || result.Name3 != "" || result.Age != 18 || !result.Enabled || result.Created.Format("20060102") != "20000102" { 39 | t.Fatalf("Failed to find created data with default data, got %+v", result) 40 | } 41 | 42 | type Harumph2 struct { 43 | ID int `gorm:"default:0"` 44 | Email string `gorm:"not 
null;index:,unique"` 45 | Name string `gorm:"notNull;default:foo"` 46 | Name2 string `gorm:"size:233;not null;default:'foo'"` 47 | Name3 string `gorm:"size:233;notNull;default:''"` 48 | Age int `gorm:"default:18"` 49 | Created time.Time `gorm:"default:2000-01-02"` 50 | Enabled bool `gorm:"default:true"` 51 | } 52 | 53 | harumph2 := Harumph2{ID: 2, Email: "hello2@gorm.io"} 54 | if err := DB.Table("harumphs").Create(&harumph2).Error; err != nil { 55 | t.Fatalf("Failed to create data with default value, got error: %v", err) 56 | } else if harumph2.ID != 2 || harumph2.Name != "foo" || harumph2.Name2 != "foo" || harumph2.Name3 != "" || harumph2.Age != 18 || !harumph2.Enabled || harumph2.Created.Format("20060102") != "20000102" { 57 | t.Fatalf("Failed to create data with default value, got: %+v", harumph2) 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /tests/distinct_test.go: -------------------------------------------------------------------------------- 1 | package tests_test 2 | 3 | import ( 4 | "regexp" 5 | "testing" 6 | 7 | "gorm.io/gorm" 8 | . 
"gorm.io/gorm/utils/tests" 9 | ) 10 | 11 | func TestDistinct(t *testing.T) { 12 | users := []User{ 13 | *GetUser("distinct", Config{}), 14 | *GetUser("distinct", Config{}), 15 | *GetUser("distinct", Config{}), 16 | *GetUser("distinct-2", Config{}), 17 | *GetUser("distinct-3", Config{}), 18 | } 19 | users[0].Age = 20 20 | 21 | if err := DB.Create(&users).Error; err != nil { 22 | t.Fatalf("errors happened when create users: %v", err) 23 | } 24 | 25 | var names []string 26 | DB.Table("users").Where("name like ?", "distinct%").Order("name").Pluck("name", &names) 27 | AssertEqual(t, names, []string{"distinct", "distinct", "distinct", "distinct-2", "distinct-3"}) 28 | 29 | var names1 []string 30 | DB.Model(&User{}).Where("name like ?", "distinct%").Distinct().Order("name").Pluck("Name", &names1) 31 | 32 | AssertEqual(t, names1, []string{"distinct", "distinct-2", "distinct-3"}) 33 | 34 | var names2 []string 35 | DB.Scopes(func(db *gorm.DB) *gorm.DB { 36 | return db.Table("users") 37 | }).Where("name like ?", "distinct%").Order("name").Pluck("name", &names2) 38 | AssertEqual(t, names2, []string{"distinct", "distinct", "distinct", "distinct-2", "distinct-3"}) 39 | 40 | var results []User 41 | if err := DB.Distinct("name", "age").Where("name like ?", "distinct%").Order("name, age desc").Find(&results).Error; err != nil { 42 | t.Errorf("failed to query users, got error: %v", err) 43 | } 44 | 45 | expects := []User{ 46 | {Name: "distinct", Age: 20}, 47 | {Name: "distinct", Age: 18}, 48 | {Name: "distinct-2", Age: 18}, 49 | {Name: "distinct-3", Age: 18}, 50 | } 51 | 52 | if len(results) != 4 { 53 | t.Fatalf("invalid results length found, expects: %v, got %v", len(expects), len(results)) 54 | } 55 | 56 | for idx, expect := range expects { 57 | AssertObjEqual(t, results[idx], expect, "Name", "Age") 58 | } 59 | 60 | var count int64 61 | if err := DB.Model(&User{}).Where("name like ?", "distinct%").Count(&count).Error; err != nil || count != 5 { 62 | t.Errorf("failed to query users 
count, got error: %v, count: %v", err, count) 63 | } 64 | 65 | if err := DB.Model(&User{}).Distinct("name").Where("name like ?", "distinct%").Count(&count).Error; err != nil || count != 3 { 66 | t.Errorf("failed to query users count, got error: %v, count %v", err, count) 67 | } 68 | 69 | dryDB := DB.Session(&gorm.Session{DryRun: true}) 70 | r := dryDB.Distinct("u.id, u.*").Table("user_speaks as s").Joins("inner join users as u on u.id = s.user_id").Where("s.language_code ='US' or s.language_code ='ES'").Find(&User{}) 71 | if !regexp.MustCompile(`SELECT DISTINCT u\.id, u\.\* FROM user_speaks as s inner join users as u`).MatchString(r.Statement.SQL.String()) { 72 | t.Fatalf("Build Distinct with u.*, but got %v", r.Statement.SQL.String()) 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /tests/error_translator_test.go: -------------------------------------------------------------------------------- 1 | package tests_test 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | 7 | "gorm.io/gorm" 8 | "gorm.io/gorm/utils/tests" 9 | ) 10 | 11 | func TestDialectorWithErrorTranslatorSupport(t *testing.T) { 12 | // it shouldn't translate error when the TranslateError flag is false 13 | translatedErr := errors.New("translated error") 14 | untranslatedErr := errors.New("some random error") 15 | db, _ := gorm.Open(tests.DummyDialector{TranslatedErr: translatedErr}) 16 | 17 | err := db.AddError(untranslatedErr) 18 | if !errors.Is(err, untranslatedErr) { 19 | t.Fatalf("expected err: %v got err: %v", untranslatedErr, err) 20 | } 21 | 22 | // it should translate error when the TranslateError flag is true 23 | db, _ = gorm.Open(tests.DummyDialector{TranslatedErr: translatedErr}, &gorm.Config{TranslateError: true}) 24 | 25 | err = db.AddError(untranslatedErr) 26 | if !errors.Is(err, translatedErr) { 27 | t.Fatalf("expected err: %v got err: %v", translatedErr, err) 28 | } 29 | } 30 | 31 | func TestSupportedDialectorWithErrDuplicatedKey(t 
*testing.T) { 32 | type City struct { 33 | gorm.Model 34 | Name string `gorm:"unique"` 35 | } 36 | 37 | db, err := OpenTestConnection(&gorm.Config{TranslateError: true}) 38 | if err != nil { 39 | t.Fatalf("failed to connect database, got error %v", err) 40 | } 41 | 42 | dialectors := map[string]bool{"sqlite": true, "postgres": true, "mysql": true, "sqlserver": true} 43 | if supported, found := dialectors[db.Dialector.Name()]; !(found && supported) { 44 | return 45 | } 46 | 47 | DB.Migrator().DropTable(&City{}) 48 | 49 | if err = db.AutoMigrate(&City{}); err != nil { 50 | t.Fatalf("failed to migrate cities table, got error: %v", err) 51 | } 52 | 53 | err = db.Create(&City{Name: "Kabul"}).Error 54 | if err != nil { 55 | t.Fatalf("failed to create record: %v", err) 56 | } 57 | 58 | err = db.Create(&City{Name: "Kabul"}).Error 59 | if !errors.Is(err, gorm.ErrDuplicatedKey) { 60 | t.Fatalf("expected err: %v got err: %v", gorm.ErrDuplicatedKey, err) 61 | } 62 | } 63 | 64 | func TestSupportedDialectorWithErrForeignKeyViolated(t *testing.T) { 65 | tidbSkip(t, "not support the foreign key feature") 66 | 67 | type City struct { 68 | gorm.Model 69 | Name string `gorm:"unique"` 70 | } 71 | 72 | type Museum struct { 73 | gorm.Model 74 | Name string `gorm:"unique"` 75 | CityID uint 76 | City City `gorm:"Constraint:OnUpdate:CASCADE,OnDelete:CASCADE;FOREIGNKEY:CityID;References:ID"` 77 | } 78 | 79 | db, err := OpenTestConnection(&gorm.Config{TranslateError: true}) 80 | if err != nil { 81 | t.Fatalf("failed to connect database, got error %v", err) 82 | } 83 | 84 | dialectors := map[string]bool{"sqlite": true, "postgres": true, "mysql": true, "sqlserver": true} 85 | if supported, found := dialectors[db.Dialector.Name()]; !(found && supported) { 86 | return 87 | } 88 | 89 | DB.Migrator().DropTable(&City{}, &Museum{}) 90 | 91 | if err = db.AutoMigrate(&City{}, &Museum{}); err != nil { 92 | t.Fatalf("failed to migrate countries & cities tables, got error: %v", err) 93 | } 94 | 95 | city 
:= City{Name: "Amsterdam"} 96 | 97 | err = db.Create(&city).Error 98 | if err != nil { 99 | t.Fatalf("failed to create city: %v", err) 100 | } 101 | 102 | err = db.Create(&Museum{Name: "Eye Filmmuseum", CityID: city.ID}).Error 103 | if err != nil { 104 | t.Fatalf("failed to create museum: %v", err) 105 | } 106 | 107 | err = db.Create(&Museum{Name: "Dungeon", CityID: 123}).Error 108 | if !errors.Is(err, gorm.ErrForeignKeyViolated) { 109 | t.Fatalf("expected err: %v got err: %v", gorm.ErrForeignKeyViolated, err) 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /tests/go.mod: -------------------------------------------------------------------------------- 1 | module gorm.io/gorm/tests 2 | 3 | go 1.23.0 4 | 5 | require ( 6 | github.com/google/uuid v1.6.0 7 | github.com/jinzhu/now v1.1.5 8 | github.com/lib/pq v1.10.9 9 | github.com/stretchr/testify v1.10.0 10 | gorm.io/driver/mysql v1.5.7 11 | gorm.io/driver/postgres v1.6.0 12 | gorm.io/driver/sqlite v1.5.7 13 | gorm.io/driver/sqlserver v1.6.0 14 | gorm.io/gorm v1.30.0 15 | ) 16 | 17 | require ( 18 | filippo.io/edwards25519 v1.1.0 // indirect 19 | github.com/davecgh/go-spew v1.1.1 // indirect 20 | github.com/go-sql-driver/mysql v1.9.2 // indirect 21 | github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect 22 | github.com/golang-sql/sqlexp v0.1.0 // indirect 23 | github.com/jackc/pgpassfile v1.0.0 // indirect 24 | github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect 25 | github.com/jackc/pgx/v5 v5.7.5 // indirect 26 | github.com/jackc/puddle/v2 v2.2.2 // indirect 27 | github.com/jinzhu/inflection v1.0.0 // indirect 28 | github.com/kr/text v0.2.0 // indirect 29 | github.com/mattn/go-sqlite3 v1.14.28 // indirect 30 | github.com/microsoft/go-mssqldb v1.8.1 // indirect 31 | github.com/pmezard/go-difflib v1.0.0 // indirect 32 | github.com/rogpeppe/go-internal v1.12.0 // indirect 33 | golang.org/x/crypto v0.38.0 // indirect 34 | 
golang.org/x/sync v0.14.0 // indirect 35 | golang.org/x/text v0.25.0 // indirect 36 | gopkg.in/yaml.v3 v3.0.1 // indirect 37 | ) 38 | 39 | replace gorm.io/gorm => ../ 40 | -------------------------------------------------------------------------------- /tests/gorm_test.go: -------------------------------------------------------------------------------- 1 | package tests_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "gorm.io/driver/mysql" 7 | 8 | "gorm.io/gorm" 9 | ) 10 | 11 | func TestOpen(t *testing.T) { 12 | dsn := "gorm:gorm@tcp(localhost:9910)/gorm?loc=Asia%2FHongKong" // invalid loc 13 | _, err := gorm.Open(mysql.Open(dsn), &gorm.Config{}) 14 | if err == nil { 15 | t.Fatalf("should returns error but got nil") 16 | } 17 | } 18 | 19 | func TestReturningWithNullToZeroValues(t *testing.T) { 20 | dialect := DB.Dialector.Name() 21 | switch dialect { 22 | case "mysql", "sqlserver": 23 | // these dialects do not support the "returning" clause 24 | return 25 | default: 26 | // This user struct will leverage the existing users table, but override 27 | // the Name field to default to null. 
28 | type user struct { 29 | gorm.Model 30 | Name string `gorm:"default:null"` 31 | } 32 | u1 := user{} 33 | 34 | if results := DB.Create(&u1); results.Error != nil { 35 | t.Fatalf("errors happened on create: %v", results.Error) 36 | } else if results.RowsAffected != 1 { 37 | t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected) 38 | } else if u1.ID == 0 { 39 | t.Fatalf("ID expects : not equal 0, got %v", u1.ID) 40 | } 41 | 42 | got := user{} 43 | results := DB.First(&got, "id = ?", u1.ID) 44 | if results.Error != nil { 45 | t.Fatalf("errors happened on first: %v", results.Error) 46 | } else if results.RowsAffected != 1 { 47 | t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected) 48 | } else if got.ID != u1.ID { 49 | t.Fatalf("first expects: %v, got %v", u1, got) 50 | } 51 | 52 | results = DB.Select("id, name").Find(&got) 53 | if results.Error != nil { 54 | t.Fatalf("errors happened on first: %v", results.Error) 55 | } else if results.RowsAffected != 1 { 56 | t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected) 57 | } else if got.ID != u1.ID { 58 | t.Fatalf("select expects: %v, got %v", u1, got) 59 | } 60 | 61 | u1.Name = "jinzhu" 62 | if results := DB.Save(&u1); results.Error != nil { 63 | t.Fatalf("errors happened on update: %v", results.Error) 64 | } else if results.RowsAffected != 1 { 65 | t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected) 66 | } 67 | 68 | u1 = user{} // important to reinitialize this before creating it again 69 | u2 := user{} 70 | db := DB.Session(&gorm.Session{CreateBatchSize: 10}) 71 | 72 | if results := db.Create([]*user{&u1, &u2}); results.Error != nil { 73 | t.Fatalf("errors happened on create: %v", results.Error) 74 | } else if results.RowsAffected != 2 { 75 | t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected) 76 | } else if u1.ID == 0 { 77 | t.Fatalf("ID expects : not equal 0, got %v", u1.ID) 78 | } else if u2.ID == 0 { 79 | t.Fatalf("ID expects 
: not equal 0, got %v", u2.ID) 80 | } 81 | 82 | var gotUsers []user 83 | results = DB.Where("id in (?, ?)", u1.ID, u2.ID).Order("id asc").Select("id, name").Find(&gotUsers) 84 | if results.Error != nil { 85 | t.Fatalf("errors happened on first: %v", results.Error) 86 | } else if results.RowsAffected != 2 { 87 | t.Fatalf("rows affected expects: %v, got %v", 2, results.RowsAffected) 88 | } else if gotUsers[0].ID != u1.ID { 89 | t.Fatalf("select expects: %v, got %v", u1.ID, gotUsers[0].ID) 90 | } else if gotUsers[1].ID != u2.ID { 91 | t.Fatalf("select expects: %v, got %v", u2.ID, gotUsers[1].ID) 92 | } 93 | 94 | u1.Name = "Jinzhu" 95 | u2.Name = "Zhang" 96 | if results := DB.Save([]*user{&u1, &u2}); results.Error != nil { 97 | t.Fatalf("errors happened on update: %v", results.Error) 98 | } else if results.RowsAffected != 2 { 99 | t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected) 100 | } 101 | 102 | } 103 | } 104 | -------------------------------------------------------------------------------- /tests/group_by_test.go: -------------------------------------------------------------------------------- 1 | package tests_test 2 | 3 | import ( 4 | "testing" 5 | 6 | . 
"gorm.io/gorm/utils/tests" 7 | ) 8 | 9 | func TestGroupBy(t *testing.T) { 10 | users := []User{{ 11 | Name: "groupby", 12 | Age: 10, 13 | Birthday: Now(), 14 | Active: true, 15 | }, { 16 | Name: "groupby", 17 | Age: 20, 18 | Birthday: Now(), 19 | }, { 20 | Name: "groupby", 21 | Age: 30, 22 | Birthday: Now(), 23 | Active: true, 24 | }, { 25 | Name: "groupby1", 26 | Age: 110, 27 | Birthday: Now(), 28 | }, { 29 | Name: "groupby1", 30 | Age: 220, 31 | Birthday: Now(), 32 | Active: true, 33 | }, { 34 | Name: "groupby1", 35 | Age: 330, 36 | Birthday: Now(), 37 | Active: true, 38 | }} 39 | 40 | if err := DB.Create(&users).Error; err != nil { 41 | t.Errorf("errors happened when create: %v", err) 42 | } 43 | 44 | var name string 45 | var total int 46 | if err := DB.Model(&User{}).Select("name, sum(age)").Where("name = ?", "groupby").Group("name").Row().Scan(&name, &total); err != nil { 47 | t.Errorf("no error should happen, but got %v", err) 48 | } 49 | 50 | if name != "groupby" || total != 60 { 51 | t.Errorf("name should be groupby, but got %v, total should be 60, but got %v", name, total) 52 | } 53 | 54 | if err := DB.Model(&User{}).Select("name, sum(age)").Where("name = ?", "groupby").Group("users.name").Row().Scan(&name, &total); err != nil { 55 | t.Errorf("no error should happen, but got %v", err) 56 | } 57 | 58 | if name != "groupby" || total != 60 { 59 | t.Errorf("name should be groupby, but got %v, total should be 60, but got %v", name, total) 60 | } 61 | 62 | if err := DB.Model(&User{}).Select("name, sum(age) as total").Where("name LIKE ?", "groupby%").Group("name").Having("name = ?", "groupby1").Row().Scan(&name, &total); err != nil { 63 | t.Errorf("no error should happen, but got %v", err) 64 | } 65 | 66 | if name != "groupby1" || total != 660 { 67 | t.Errorf("name should be groupby, but got %v, total should be 660, but got %v", name, total) 68 | } 69 | 70 | result := struct { 71 | Name string 72 | Total int64 73 | }{} 74 | 75 | if err := 
DB.Model(&User{}).Select("name, sum(age) as total").Where("name LIKE ?", "groupby%").Group("name").Having("name = ?", "groupby1").Find(&result).Error; err != nil { 76 | t.Errorf("no error should happen, but got %v", err) 77 | } 78 | 79 | if result.Name != "groupby1" || result.Total != 660 { 80 | t.Errorf("name should be groupby, total should be 660, but got %+v", result) 81 | } 82 | 83 | if err := DB.Model(&User{}).Select("name, sum(age) as total").Where("name LIKE ?", "groupby%").Group("name").Having("name = ?", "groupby1").Scan(&result).Error; err != nil { 84 | t.Errorf("no error should happen, but got %v", err) 85 | } 86 | 87 | if result.Name != "groupby1" || result.Total != 660 { 88 | t.Errorf("name should be groupby, total should be 660, but got %+v", result) 89 | } 90 | 91 | var active bool 92 | if err := DB.Model(&User{}).Select("name, active, sum(age)").Where("name = ? and active = ?", "groupby", true).Group("name").Group("active").Row().Scan(&name, &active, &total); err != nil { 93 | t.Errorf("no error should happen, but got %v", err) 94 | } 95 | 96 | if name != "groupby" || active != true || total != 40 { 97 | t.Errorf("group by two columns, name %v, age %v, active: %v", name, total, active) 98 | } 99 | 100 | if DB.Dialector.Name() == "mysql" { 101 | if err := DB.Model(&User{}).Select("name, age as total").Where("name LIKE ?", "groupby%").Having("total > ?", 300).Scan(&result).Error; err != nil { 102 | t.Errorf("no error should happen, but got %v", err) 103 | } 104 | 105 | if result.Name != "groupby1" || result.Total != 330 { 106 | t.Errorf("name should be groupby, total should be 660, but got %+v", result) 107 | } 108 | } 109 | } 110 | -------------------------------------------------------------------------------- /tests/joins_table_test.go: -------------------------------------------------------------------------------- 1 | package tests_test 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "gorm.io/gorm" 8 | "gorm.io/gorm/clause" 9 | ) 10 | 11 | type 
Person struct { 12 | ID int 13 | Name string 14 | Addresses []Address `gorm:"many2many:person_addresses;"` 15 | DeletedAt gorm.DeletedAt 16 | } 17 | 18 | type Address struct { 19 | ID uint 20 | Name string 21 | } 22 | 23 | type PersonAddress struct { 24 | PersonID int 25 | AddressID int 26 | CreatedAt time.Time 27 | DeletedAt gorm.DeletedAt 28 | } 29 | 30 | func TestOverrideJoinTable(t *testing.T) { 31 | DB.Migrator().DropTable(&Person{}, &Address{}, &PersonAddress{}) 32 | 33 | if err := DB.SetupJoinTable(&Person{}, "Addresses", &PersonAddress{}); err != nil { 34 | t.Fatalf("Failed to setup join table for person, got error %v", err) 35 | } 36 | 37 | if err := DB.AutoMigrate(&Person{}, &Address{}); err != nil { 38 | t.Fatalf("Failed to migrate, got %v", err) 39 | } 40 | 41 | address1 := Address{Name: "address 1"} 42 | address2 := Address{Name: "address 2"} 43 | person := Person{Name: "person", Addresses: []Address{address1, address2}} 44 | DB.Create(&person) 45 | 46 | var addresses1 []Address 47 | if err := DB.Model(&person).Association("Addresses").Find(&addresses1); err != nil || len(addresses1) != 2 { 48 | t.Fatalf("Failed to find address, got error %v, length: %v", err, len(addresses1)) 49 | } 50 | 51 | if err := DB.Model(&person).Association("Addresses").Delete(&person.Addresses[0]); err != nil { 52 | t.Fatalf("Failed to delete address, got error %v", err) 53 | } 54 | 55 | if len(person.Addresses) != 1 { 56 | t.Fatalf("Should have one address left") 57 | } 58 | 59 | if DB.Find(&[]PersonAddress{}, "person_id = ?", person.ID).RowsAffected != 1 { 60 | t.Fatalf("Should found one address") 61 | } 62 | 63 | var addresses2 []Address 64 | if err := DB.Model(&person).Association("Addresses").Find(&addresses2); err != nil || len(addresses2) != 1 { 65 | t.Fatalf("Failed to find address, got error %v, length: %v", err, len(addresses2)) 66 | } 67 | 68 | if DB.Model(&person).Association("Addresses").Count() != 1 { 69 | t.Fatalf("Should found one address") 70 | } 71 | 72 | 
var addresses3 []Address 73 | if err := DB.Unscoped().Model(&person).Association("Addresses").Find(&addresses3); err != nil || len(addresses3) != 2 { 74 | t.Fatalf("Failed to find address, got error %v, length: %v", err, len(addresses3)) 75 | } 76 | 77 | if DB.Unscoped().Find(&[]PersonAddress{}, "person_id = ?", person.ID).RowsAffected != 2 { 78 | t.Fatalf("Should found soft deleted addresses with unscoped") 79 | } 80 | 81 | if DB.Unscoped().Model(&person).Association("Addresses").Count() != 2 { 82 | t.Fatalf("Should found soft deleted addresses with unscoped") 83 | } 84 | 85 | DB.Model(&person).Association("Addresses").Clear() 86 | 87 | if DB.Model(&person).Association("Addresses").Count() != 0 { 88 | t.Fatalf("Should deleted all addresses") 89 | } 90 | 91 | if DB.Unscoped().Model(&person).Association("Addresses").Count() != 2 { 92 | t.Fatalf("Should found soft deleted addresses with unscoped") 93 | } 94 | 95 | DB.Unscoped().Model(&person).Association("Addresses").Clear() 96 | 97 | if DB.Unscoped().Model(&person).Association("Addresses").Count() != 0 { 98 | t.Fatalf("address should be deleted when clear with unscoped") 99 | } 100 | 101 | address2_1 := Address{Name: "address 2-1"} 102 | address2_2 := Address{Name: "address 2-2"} 103 | person2 := Person{Name: "person_2", Addresses: []Address{address2_1, address2_2}} 104 | DB.Create(&person2) 105 | if err := DB.Select(clause.Associations).Delete(&person2).Error; err != nil { 106 | t.Fatalf("failed to delete person, got error: %v", err) 107 | } 108 | 109 | if count := DB.Unscoped().Model(&person2).Association("Addresses").Count(); count != 2 { 110 | t.Errorf("person's addresses expects 2, got %v", count) 111 | } 112 | 113 | if count := DB.Model(&person2).Association("Addresses").Count(); count != 0 { 114 | t.Errorf("person's addresses expects 2, got %v", count) 115 | } 116 | } 117 | -------------------------------------------------------------------------------- /tests/main_test.go: 
-------------------------------------------------------------------------------- 1 | package tests_test 2 | 3 | import ( 4 | "testing" 5 | 6 | . "gorm.io/gorm/utils/tests" 7 | ) 8 | 9 | func TestExceptionsWithInvalidSql(t *testing.T) { 10 | if name := DB.Dialector.Name(); name == "sqlserver" { 11 | t.Skip("skip sqlserver due to it will raise data race for invalid sql") 12 | } 13 | 14 | var columns []string 15 | if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil { 16 | t.Errorf("Should got error with invalid SQL") 17 | } 18 | 19 | if DB.Model(&User{}).Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil { 20 | t.Errorf("Should got error with invalid SQL") 21 | } 22 | 23 | if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Find(&User{}).Error == nil { 24 | t.Errorf("Should got error with invalid SQL") 25 | } 26 | 27 | var count1, count2 int64 28 | DB.Model(&User{}).Count(&count1) 29 | if count1 <= 0 { 30 | t.Errorf("Should find some users") 31 | } 32 | 33 | if DB.Where("name = ?", "jinzhu; delete * from users").First(&User{}).Error == nil { 34 | t.Errorf("Should got error with invalid SQL") 35 | } 36 | 37 | DB.Model(&User{}).Count(&count2) 38 | if count1 != count2 { 39 | t.Errorf("No user should not be deleted by invalid SQL") 40 | } 41 | } 42 | 43 | func TestSetAndGet(t *testing.T) { 44 | if value, ok := DB.Set("hello", "world").Get("hello"); !ok { 45 | t.Errorf("Should be able to get setting after set") 46 | } else if value.(string) != "world" { 47 | t.Errorf("Set value should not be changed") 48 | } 49 | 50 | if _, ok := DB.Get("non_existing"); ok { 51 | t.Errorf("Get non existing key should return error") 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /tests/named_argument_test.go: -------------------------------------------------------------------------------- 1 | package tests_test 2 | 3 | import ( 4 | "database/sql" 5 | "errors" 6 | "testing" 7 | 8 | "gorm.io/gorm" 9 | . 
"gorm.io/gorm/utils/tests" 10 | ) 11 | 12 | func TestNamedArg(t *testing.T) { 13 | type NamedUser struct { 14 | gorm.Model 15 | Name1 string 16 | Name2 string 17 | Name3 string 18 | } 19 | 20 | DB.Migrator().DropTable(&NamedUser{}) 21 | DB.AutoMigrate(&NamedUser{}) 22 | 23 | namedUser := NamedUser{Name1: "jinzhu1", Name2: "jinzhu2", Name3: "jinzhu3"} 24 | DB.Create(&namedUser) 25 | 26 | var result NamedUser 27 | DB.First(&result, "name1 = @name OR name2 = @name OR name3 = @name", sql.Named("name", "jinzhu2")) 28 | 29 | AssertEqual(t, result, namedUser) 30 | 31 | var result2 NamedUser 32 | DB.Where("name1 = @name OR name2 = @name OR name3 = @name", sql.Named("name", "jinzhu2")).First(&result2) 33 | 34 | AssertEqual(t, result2, namedUser) 35 | 36 | var result3 NamedUser 37 | DB.Where("name1 = @name OR name2 = @name OR name3 = @name", map[string]interface{}{"name": "jinzhu2"}).First(&result3) 38 | 39 | AssertEqual(t, result3, namedUser) 40 | 41 | var result4 NamedUser 42 | if err := DB.Raw("SELECT * FROM named_users WHERE name1 = @name OR name2 = @name2 OR name3 = @name", sql.Named("name", "jinzhu-none"), sql.Named("name2", "jinzhu2")).Find(&result4).Error; err != nil { 43 | t.Errorf("failed to update with named arg") 44 | } 45 | 46 | AssertEqual(t, result4, namedUser) 47 | 48 | if err := DB.Exec("UPDATE named_users SET name1 = @name, name2 = @name2, name3 = @name", sql.Named("name", "jinzhu-new"), sql.Named("name2", "jinzhu-new2")).Error; err != nil { 49 | t.Errorf("failed to update with named arg") 50 | } 51 | 52 | namedUser.Name1 = "jinzhu-new" 53 | namedUser.Name2 = "jinzhu-new2" 54 | namedUser.Name3 = "jinzhu-new" 55 | 56 | var result5 NamedUser 57 | if err := DB.Raw("SELECT * FROM named_users WHERE (name1 = @name AND name3 = @name) AND name2 = @name2", map[string]interface{}{"name": "jinzhu-new", "name2": "jinzhu-new2"}).Find(&result5).Error; err != nil { 58 | t.Errorf("failed to update with named arg") 59 | } 60 | 61 | AssertEqual(t, result5, namedUser) 62 | 63 
| var result6 NamedUser 64 | if err := DB.Raw(`SELECT * FROM named_users WHERE (name1 = @name 65 | AND name3 = @name) AND name2 = @name2`, map[string]interface{}{"name": "jinzhu-new", "name2": "jinzhu-new2"}).Find(&result6).Error; err != nil { 66 | t.Errorf("failed to update with named arg") 67 | } 68 | 69 | AssertEqual(t, result6, namedUser) 70 | 71 | var result7 NamedUser 72 | if err := DB.Where("name1 = @name OR name2 = @name", sql.Named("name", "jinzhu-new")).Where("name3 = 'jinzhu-new3'").First(&result7).Error; err == nil || !errors.Is(err, gorm.ErrRecordNotFound) { 73 | t.Errorf("should return record not found error, but got %v", err) 74 | } 75 | 76 | DB.Delete(&namedUser) 77 | 78 | var result8 NamedUser 79 | if err := DB.Where("name1 = @name OR name2 = @name", map[string]interface{}{"name": "jinzhu-new"}).First(&result8).Error; err == nil || !errors.Is(err, gorm.ErrRecordNotFound) { 80 | t.Errorf("should return record not found error, but got %v", err) 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /tests/named_polymorphic_test.go: -------------------------------------------------------------------------------- 1 | package tests_test 2 | 3 | import ( 4 | "testing" 5 | 6 | . "gorm.io/gorm/utils/tests" 7 | ) 8 | 9 | type Hamster struct { 10 | Id int 11 | Name string 12 | PreferredToy Toy `gorm:"polymorphic:Owner;polymorphicValue:hamster_preferred"` 13 | OtherToy Toy `gorm:"polymorphic:Owner;polymorphicValue:hamster_other"` 14 | } 15 | 16 | func TestNamedPolymorphic(t *testing.T) { 17 | DB.Migrator().DropTable(&Hamster{}) 18 | DB.AutoMigrate(&Hamster{}) 19 | 20 | hamster := Hamster{Name: "Mr. 
Hammond", PreferredToy: Toy{Name: "bike"}, OtherToy: Toy{Name: "treadmill"}} 21 | DB.Save(&hamster) 22 | 23 | hamster2 := Hamster{} 24 | DB.Preload("PreferredToy").Preload("OtherToy").Find(&hamster2, hamster.Id) 25 | 26 | if hamster2.PreferredToy.ID != hamster.PreferredToy.ID || hamster2.PreferredToy.Name != hamster.PreferredToy.Name { 27 | t.Errorf("Hamster's preferred toy failed to preload") 28 | } 29 | 30 | if hamster2.OtherToy.ID != hamster.OtherToy.ID || hamster2.OtherToy.Name != hamster.OtherToy.Name { 31 | t.Errorf("Hamster's other toy failed to preload") 32 | } 33 | 34 | // clear to omit Toy.ID in count 35 | hamster2.PreferredToy = Toy{} 36 | hamster2.OtherToy = Toy{} 37 | 38 | if DB.Model(&hamster2).Association("PreferredToy").Count() != 1 { 39 | t.Errorf("Hamster's preferred toy count should be 1") 40 | } 41 | 42 | if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { 43 | t.Errorf("Hamster's other toy count should be 1") 44 | } 45 | 46 | // Query 47 | hamsterToy := Toy{} 48 | DB.Model(&hamster).Association("PreferredToy").Find(&hamsterToy) 49 | if hamsterToy.Name != hamster.PreferredToy.Name { 50 | t.Errorf("Should find has one polymorphic association") 51 | } 52 | 53 | hamsterToy = Toy{} 54 | DB.Model(&hamster).Association("OtherToy").Find(&hamsterToy) 55 | if hamsterToy.Name != hamster.OtherToy.Name { 56 | t.Errorf("Should find has one polymorphic association") 57 | } 58 | 59 | // Append 60 | DB.Model(&hamster).Association("PreferredToy").Append(&Toy{ 61 | Name: "bike 2", 62 | }) 63 | 64 | DB.Model(&hamster).Association("OtherToy").Append(&Toy{ 65 | Name: "treadmill 2", 66 | }) 67 | 68 | hamsterToy = Toy{} 69 | DB.Model(&hamster).Association("PreferredToy").Find(&hamsterToy) 70 | if hamsterToy.Name != "bike 2" { 71 | t.Errorf("Should update has one polymorphic association with Append") 72 | } 73 | 74 | hamsterToy = Toy{} 75 | DB.Model(&hamster).Association("OtherToy").Find(&hamsterToy) 76 | if hamsterToy.Name != "treadmill 2" { 77 | 
t.Errorf("Should update has one polymorphic association with Append") 78 | } 79 | 80 | if DB.Model(&hamster2).Association("PreferredToy").Count() != 1 { 81 | t.Errorf("Hamster's toys count should be 1 after Append") 82 | } 83 | 84 | if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { 85 | t.Errorf("Hamster's toys count should be 1 after Append") 86 | } 87 | 88 | // Replace 89 | DB.Model(&hamster).Association("PreferredToy").Replace(&Toy{ 90 | Name: "bike 3", 91 | }) 92 | 93 | DB.Model(&hamster).Association("OtherToy").Replace(&Toy{ 94 | Name: "treadmill 3", 95 | }) 96 | 97 | hamsterToy = Toy{} 98 | DB.Model(&hamster).Association("PreferredToy").Find(&hamsterToy) 99 | if hamsterToy.Name != "bike 3" { 100 | t.Errorf("Should update has one polymorphic association with Replace") 101 | } 102 | 103 | hamsterToy = Toy{} 104 | DB.Model(&hamster).Association("OtherToy").Find(&hamsterToy) 105 | if hamsterToy.Name != "treadmill 3" { 106 | t.Errorf("Should update has one polymorphic association with Replace") 107 | } 108 | 109 | if DB.Model(&hamster2).Association("PreferredToy").Count() != 1 { 110 | t.Errorf("hamster's toys count should be 1 after Replace") 111 | } 112 | 113 | if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { 114 | t.Errorf("hamster's toys count should be 1 after Replace") 115 | } 116 | 117 | // Clear 118 | DB.Model(&hamster).Association("PreferredToy").Append(&Toy{ 119 | Name: "bike 2", 120 | }) 121 | DB.Model(&hamster).Association("OtherToy").Append(&Toy{ 122 | Name: "treadmill 2", 123 | }) 124 | 125 | if DB.Model(&hamster).Association("PreferredToy").Count() != 1 { 126 | t.Errorf("Hamster's toys should be added with Append") 127 | } 128 | 129 | if DB.Model(&hamster).Association("OtherToy").Count() != 1 { 130 | t.Errorf("Hamster's toys should be added with Append") 131 | } 132 | 133 | DB.Model(&hamster).Association("PreferredToy").Clear() 134 | 135 | if DB.Model(&hamster2).Association("PreferredToy").Count() != 0 { 136 | 
t.Errorf("Hamster's preferred toy should be cleared with Clear") 137 | } 138 | 139 | if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { 140 | t.Errorf("Hamster's other toy should be still available") 141 | } 142 | 143 | DB.Model(&hamster).Association("OtherToy").Clear() 144 | if DB.Model(&hamster).Association("OtherToy").Count() != 0 { 145 | t.Errorf("Hamster's other toy should be cleared with Clear") 146 | } 147 | } 148 | -------------------------------------------------------------------------------- /tests/non_std_test.go: -------------------------------------------------------------------------------- 1 | package tests_test 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | ) 7 | 8 | type Animal struct { 9 | Counter uint64 `gorm:"primary_key:yes"` 10 | Name string `gorm:"DEFAULT:'galeone'"` 11 | From string // test reserved sql keyword as field name 12 | Age *time.Time 13 | unexported string // unexported value 14 | CreatedAt time.Time 15 | UpdatedAt time.Time 16 | } 17 | 18 | func TestNonStdPrimaryKeyAndDefaultValues(t *testing.T) { 19 | DB.Migrator().DropTable(&Animal{}) 20 | if err := DB.AutoMigrate(&Animal{}); err != nil { 21 | t.Fatalf("no error should happen when migrate but got %v", err) 22 | } 23 | 24 | animal := Animal{Name: "Ferdinand"} 25 | DB.Save(&animal) 26 | updatedAt1 := animal.UpdatedAt 27 | 28 | DB.Save(&animal).Update("name", "Francis") 29 | if updatedAt1.Format(time.RFC3339Nano) == animal.UpdatedAt.Format(time.RFC3339Nano) { 30 | t.Errorf("UpdatedAt should be updated") 31 | } 32 | 33 | var animals []Animal 34 | DB.Find(&animals) 35 | if count := DB.Model(Animal{}).Where("1=1").Update("CreatedAt", time.Now().Add(2*time.Hour)).RowsAffected; count != int64(len(animals)) { 36 | t.Error("RowsAffected should be correct when do batch update") 37 | } 38 | 39 | animal = Animal{From: "somewhere"} // No name fields, should be filled with the default value (galeone) 40 | DB.Save(&animal).Update("From", "a nice place") // The name field should be 
untouched 41 | DB.First(&animal, animal.Counter) 42 | if animal.Name != "galeone" { 43 | t.Errorf("Name fields shouldn't be changed if untouched, but got %v", animal.Name) 44 | } 45 | 46 | // When changing a field with a default value, the change must occur 47 | animal.Name = "amazing horse" 48 | DB.Save(&animal) 49 | DB.First(&animal, animal.Counter) 50 | if animal.Name != "amazing horse" { 51 | t.Errorf("Update a filed with a default value should occur. But got %v\n", animal.Name) 52 | } 53 | 54 | // When changing a field with a default value with blank value 55 | animal.Name = "" 56 | DB.Save(&animal) 57 | DB.First(&animal, animal.Counter) 58 | if animal.Name != "" { 59 | t.Errorf("Update a filed to blank with a default value should occur. But got %v\n", animal.Name) 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /tests/scopes_test.go: -------------------------------------------------------------------------------- 1 | package tests_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "gorm.io/gorm" 8 | . 
"gorm.io/gorm/utils/tests" 9 | ) 10 | 11 | func NameIn1And2(d *gorm.DB) *gorm.DB { 12 | return d.Where("name in (?)", []string{"ScopeUser1", "ScopeUser2"}) 13 | } 14 | 15 | func NameIn2And3(d *gorm.DB) *gorm.DB { 16 | return d.Where("name in (?)", []string{"ScopeUser2", "ScopeUser3"}) 17 | } 18 | 19 | func NameIn(names []string) func(d *gorm.DB) *gorm.DB { 20 | return func(d *gorm.DB) *gorm.DB { 21 | return d.Where("name in (?)", names) 22 | } 23 | } 24 | 25 | func TestScopes(t *testing.T) { 26 | users := []*User{ 27 | GetUser("ScopeUser1", Config{}), 28 | GetUser("ScopeUser2", Config{}), 29 | GetUser("ScopeUser3", Config{}), 30 | } 31 | 32 | DB.Create(&users) 33 | 34 | var users1, users2, users3 []User 35 | DB.Scopes(NameIn1And2).Find(&users1) 36 | if len(users1) != 2 { 37 | t.Errorf("Should found two users's name in 1, 2, but got %v", len(users1)) 38 | } 39 | 40 | DB.Scopes(NameIn1And2, NameIn2And3).Find(&users2) 41 | if len(users2) != 1 { 42 | t.Errorf("Should found one user's name is 2, but got %v", len(users2)) 43 | } 44 | 45 | DB.Scopes(NameIn([]string{users[0].Name, users[2].Name})).Find(&users3) 46 | if len(users3) != 2 { 47 | t.Errorf("Should found two users's name in 1, 3, but got %v", len(users3)) 48 | } 49 | 50 | db := DB.Scopes(func(tx *gorm.DB) *gorm.DB { 51 | return tx.Table("custom_table") 52 | }).Session(&gorm.Session{}) 53 | 54 | db.AutoMigrate(&User{}) 55 | if db.Find(&User{}).Statement.Table != "custom_table" { 56 | t.Errorf("failed to call Scopes") 57 | } 58 | 59 | result := DB.Scopes(NameIn1And2, func(tx *gorm.DB) *gorm.DB { 60 | return tx.Session(&gorm.Session{}) 61 | }).Find(&users1) 62 | 63 | if result.RowsAffected != 2 { 64 | t.Errorf("Should found two users's name in 1, 2, but got %v", result.RowsAffected) 65 | } 66 | 67 | var maxId int64 68 | userTable := func(db *gorm.DB) *gorm.DB { 69 | return db.WithContext(context.Background()).Table("users") 70 | } 71 | if err := DB.Scopes(userTable).Select("max(id)").Scan(&maxId).Error; err != nil 
{ 72 | t.Errorf("select max(id)") 73 | } 74 | } 75 | 76 | func TestComplexScopes(t *testing.T) { 77 | tests := []struct { 78 | name string 79 | queryFn func(tx *gorm.DB) *gorm.DB 80 | expected string 81 | }{ 82 | { 83 | name: "depth_1", 84 | queryFn: func(tx *gorm.DB) *gorm.DB { 85 | return tx.Scopes( 86 | func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") }, 87 | func(d *gorm.DB) *gorm.DB { 88 | return d.Where(DB.Or("b = 2").Or("c = 3")) 89 | }, 90 | ).Find(&Language{}) 91 | }, 92 | expected: `SELECT * FROM "languages" WHERE a = 1 AND (b = 2 OR c = 3)`, 93 | }, { 94 | name: "depth_1_pre_cond", 95 | queryFn: func(tx *gorm.DB) *gorm.DB { 96 | return tx.Where("z = 0").Scopes( 97 | func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") }, 98 | func(d *gorm.DB) *gorm.DB { 99 | return d.Or(DB.Where("b = 2").Or("c = 3")) 100 | }, 101 | ).Find(&Language{}) 102 | }, 103 | expected: `SELECT * FROM "languages" WHERE z = 0 AND a = 1 OR (b = 2 OR c = 3)`, 104 | }, { 105 | name: "depth_2", 106 | queryFn: func(tx *gorm.DB) *gorm.DB { 107 | return tx.Scopes( 108 | func(d *gorm.DB) *gorm.DB { return d.Model(&Language{}) }, 109 | func(d *gorm.DB) *gorm.DB { 110 | return d. 111 | Or(DB.Scopes( 112 | func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") }, 113 | func(d *gorm.DB) *gorm.DB { return d.Where("b = 2") }, 114 | )). 
115 | Or("c = 3") 116 | }, 117 | func(d *gorm.DB) *gorm.DB { return d.Where("d = 4") }, 118 | ).Find(&Language{}) 119 | }, 120 | expected: `SELECT * FROM "languages" WHERE d = 4 OR c = 3 OR (a = 1 AND b = 2)`, 121 | }, 122 | } 123 | 124 | for _, test := range tests { 125 | t.Run(test.name, func(t *testing.T) { 126 | assertEqualSQL(t, test.expected, DB.ToSQL(test.queryFn)) 127 | }) 128 | } 129 | } 130 | -------------------------------------------------------------------------------- /tests/tests_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -e 2 | 3 | dialects=("sqlite" "mysql" "postgres" "sqlserver" "tidb") 4 | 5 | if [[ $(pwd) == *"gorm/tests"* ]]; then 6 | cd .. 7 | fi 8 | 9 | if [ -d tests ] 10 | then 11 | cd tests 12 | go get -u -t ./... 13 | go mod download 14 | go mod tidy 15 | cd .. 16 | fi 17 | 18 | # SqlServer for Mac M1 19 | if [[ -z $GITHUB_ACTION && -d tests ]]; then 20 | cd tests 21 | if [[ $(uname -a) == *" arm64" ]]; then 22 | MSSQL_IMAGE=mcr.microsoft.com/azure-sql-edge docker compose up -d --wait 23 | go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest || true 24 | for query in \ 25 | "IF DB_ID('gorm') IS NULL CREATE DATABASE gorm" \ 26 | "IF SUSER_ID (N'gorm') IS NULL CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86';" \ 27 | "IF USER_ID (N'gorm') IS NULL CREATE USER gorm FROM LOGIN gorm; ALTER SERVER ROLE sysadmin ADD MEMBER [gorm];" 28 | do 29 | SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "$query" > /dev/null || true 30 | done 31 | else 32 | MSSQL_IMAGE=mcr.microsoft.com/mssql/server docker compose up -d --wait 33 | fi 34 | cd .. 35 | fi 36 | 37 | 38 | for dialect in "${dialects[@]}" ; do 39 | if [ "$GORM_DIALECT" = "" ] || [ "$GORM_DIALECT" = "${dialect}" ] 40 | then 41 | echo "testing ${dialect}..." 42 | 43 | if [ "$GORM_VERBOSE" = "" ] 44 | then 45 | GORM_DIALECT=${dialect} go test -race -count=1 ./... 
46 | if [ -d tests ] 47 | then 48 | cd tests 49 | GORM_DIALECT=${dialect} go test -race -count=1 ./... 50 | cd .. 51 | fi 52 | else 53 | GORM_DIALECT=${dialect} go test -race -count=1 -v ./... 54 | if [ -d tests ] 55 | then 56 | cd tests 57 | GORM_DIALECT=${dialect} go test -race -count=1 -v ./... 58 | cd .. 59 | fi 60 | fi 61 | fi 62 | done 63 | -------------------------------------------------------------------------------- /tests/tests_test.go: -------------------------------------------------------------------------------- 1 | //go:debug x509negativeserial=1 2 | package tests_test 3 | 4 | import ( 5 | "log" 6 | "math/rand" 7 | "os" 8 | "path/filepath" 9 | "time" 10 | 11 | "gorm.io/driver/mysql" 12 | "gorm.io/driver/postgres" 13 | "gorm.io/driver/sqlite" 14 | "gorm.io/driver/sqlserver" 15 | "gorm.io/gorm" 16 | "gorm.io/gorm/logger" 17 | . "gorm.io/gorm/utils/tests" 18 | ) 19 | 20 | var DB *gorm.DB 21 | var ( 22 | mysqlDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local" 23 | postgresDSN = "user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai" 24 | sqlserverDSN = "sqlserver://sa:LoremIpsum86@localhost:9930?database=master" 25 | tidbDSN = "root:@tcp(localhost:9940)/test?charset=utf8&parseTime=True&loc=Local" 26 | ) 27 | 28 | func init() { 29 | var err error 30 | if DB, err = OpenTestConnection(&gorm.Config{}); err != nil { 31 | log.Printf("failed to connect database, got error %v", err) 32 | os.Exit(1) 33 | } else { 34 | sqlDB, err := DB.DB() 35 | if err != nil { 36 | log.Printf("failed to connect database, got error %v", err) 37 | os.Exit(1) 38 | } 39 | 40 | err = sqlDB.Ping() 41 | if err != nil { 42 | log.Printf("failed to ping sqlDB, got error %v", err) 43 | os.Exit(1) 44 | } 45 | 46 | RunMigrations() 47 | } 48 | } 49 | 50 | func OpenTestConnection(cfg *gorm.Config) (db *gorm.DB, err error) { 51 | dbDSN := os.Getenv("GORM_DSN") 52 | switch os.Getenv("GORM_DIALECT") { 53 | case "mysql": 
54 | log.Println("testing mysql...") 55 | if dbDSN == "" { 56 | dbDSN = mysqlDSN 57 | } 58 | db, err = gorm.Open(mysql.Open(dbDSN), cfg) 59 | case "postgres": 60 | log.Println("testing postgres...") 61 | if dbDSN == "" { 62 | dbDSN = postgresDSN 63 | } 64 | db, err = gorm.Open(postgres.New(postgres.Config{ 65 | DSN: dbDSN, 66 | PreferSimpleProtocol: true, 67 | }), cfg) 68 | case "sqlserver": 69 | // go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest 70 | // SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 71 | // CREATE DATABASE gorm; 72 | // GO 73 | // CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86'; 74 | // CREATE USER gorm FROM LOGIN gorm; 75 | // ALTER SERVER ROLE sysadmin ADD MEMBER [gorm]; 76 | // GO 77 | log.Println("testing sqlserver...") 78 | if dbDSN == "" { 79 | dbDSN = sqlserverDSN 80 | } 81 | db, err = gorm.Open(sqlserver.Open(dbDSN), cfg) 82 | case "tidb": 83 | log.Println("testing tidb...") 84 | if dbDSN == "" { 85 | dbDSN = tidbDSN 86 | } 87 | db, err = gorm.Open(mysql.Open(dbDSN), cfg) 88 | default: 89 | log.Println("testing sqlite3...") 90 | db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), cfg) 91 | if err == nil { 92 | db.Exec("PRAGMA foreign_keys = ON") 93 | } 94 | } 95 | 96 | if err != nil { 97 | return 98 | } 99 | 100 | if debug := os.Getenv("DEBUG"); debug == "true" { 101 | db.Logger = db.Logger.LogMode(logger.Info) 102 | } else if debug == "false" { 103 | db.Logger = db.Logger.LogMode(logger.Silent) 104 | } 105 | 106 | return 107 | } 108 | 109 | func RunMigrations() { 110 | var err error 111 | allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}, &Parent{}, &Child{}, &Tools{}} 112 | rand.Seed(time.Now().UnixNano()) 113 | rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) 114 | 115 | DB.Migrator().DropTable("user_friends", "user_speaks") 116 | 117 | if err = 
DB.Migrator().DropTable(allModels...); err != nil { 118 | log.Printf("Failed to drop table, got error %v\n", err) 119 | os.Exit(1) 120 | } 121 | 122 | if err = DB.AutoMigrate(allModels...); err != nil { 123 | log.Printf("Failed to auto migrate, but got error %v\n", err) 124 | os.Exit(1) 125 | } 126 | 127 | for _, m := range allModels { 128 | if !DB.Migrator().HasTable(m) { 129 | log.Printf("Failed to create table for %#v\n", m) 130 | os.Exit(1) 131 | } 132 | } 133 | } 134 | -------------------------------------------------------------------------------- /tests/tracer_test.go: -------------------------------------------------------------------------------- 1 | package tests_test 2 | 3 | import ( 4 | "context" 5 | "time" 6 | 7 | "gorm.io/gorm/logger" 8 | ) 9 | 10 | type Tracer struct { 11 | Logger logger.Interface 12 | Test func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) 13 | } 14 | 15 | func (S Tracer) LogMode(level logger.LogLevel) logger.Interface { 16 | return S.Logger.LogMode(level) 17 | } 18 | 19 | func (S Tracer) Info(ctx context.Context, s string, i ...interface{}) { 20 | S.Logger.Info(ctx, s, i...) 21 | } 22 | 23 | func (S Tracer) Warn(ctx context.Context, s string, i ...interface{}) { 24 | S.Logger.Warn(ctx, s, i...) 25 | } 26 | 27 | func (S Tracer) Error(ctx context.Context, s string, i ...interface{}) { 28 | S.Logger.Error(ctx, s, i...) 29 | } 30 | 31 | func (S Tracer) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { 32 | S.Logger.Trace(ctx, begin, fc, err) 33 | S.Test(ctx, begin, fc, err) 34 | } 35 | -------------------------------------------------------------------------------- /tests/update_belongs_to_test.go: -------------------------------------------------------------------------------- 1 | package tests_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "gorm.io/gorm" 7 | . 
"gorm.io/gorm/utils/tests" 8 | ) 9 | 10 | func TestUpdateBelongsTo(t *testing.T) { 11 | user := *GetUser("update-belongs-to", Config{}) 12 | 13 | if err := DB.Create(&user).Error; err != nil { 14 | t.Fatalf("errors happened when create: %v", err) 15 | } 16 | 17 | user.Company = Company{Name: "company-belongs-to-association"} 18 | user.Manager = &User{Name: "manager-belongs-to-association"} 19 | if err := DB.Save(&user).Error; err != nil { 20 | t.Fatalf("errors happened when update: %v", err) 21 | } 22 | 23 | var user2 User 24 | DB.Preload("Company").Preload("Manager").Find(&user2, "id = ?", user.ID) 25 | CheckUser(t, user2, user) 26 | 27 | user.Company.Name += "new" 28 | user.Manager.Name += "new" 29 | if err := DB.Save(&user).Error; err != nil { 30 | t.Fatalf("errors happened when update: %v", err) 31 | } 32 | 33 | var user3 User 34 | DB.Preload("Company").Preload("Manager").Find(&user3, "id = ?", user.ID) 35 | CheckUser(t, user2, user3) 36 | 37 | if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { 38 | t.Fatalf("errors happened when update: %v", err) 39 | } 40 | 41 | var user4 User 42 | DB.Preload("Company").Preload("Manager").Find(&user4, "id = ?", user.ID) 43 | CheckUser(t, user4, user) 44 | 45 | user.Company.Name += "new2" 46 | user.Manager.Name += "new2" 47 | if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Select("`Company`").Save(&user).Error; err != nil { 48 | t.Fatalf("errors happened when update: %v", err) 49 | } 50 | 51 | var user5 User 52 | DB.Preload("Company").Preload("Manager").Find(&user5, "id = ?", user.ID) 53 | if user5.Manager.Name != user4.Manager.Name { 54 | t.Errorf("should not update user's manager") 55 | } else { 56 | user.Manager.Name = user4.Manager.Name 57 | } 58 | CheckUser(t, user, user5) 59 | } 60 | -------------------------------------------------------------------------------- /tests/update_has_many_test.go: 
-------------------------------------------------------------------------------- 1 | package tests_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "gorm.io/gorm" 7 | . "gorm.io/gorm/utils/tests" 8 | ) 9 | 10 | func TestUpdateHasManyAssociations(t *testing.T) { 11 | user := *GetUser("update-has-many", Config{}) 12 | 13 | if err := DB.Create(&user).Error; err != nil { 14 | t.Fatalf("errors happened when create: %v", err) 15 | } 16 | 17 | user.Pets = []*Pet{{Name: "pet1"}, {Name: "pet2"}} 18 | if err := DB.Save(&user).Error; err != nil { 19 | t.Fatalf("errors happened when update: %v", err) 20 | } 21 | 22 | var user2 User 23 | DB.Preload("Pets").Find(&user2, "id = ?", user.ID) 24 | CheckUser(t, user2, user) 25 | 26 | for _, pet := range user.Pets { 27 | pet.Name += "new" 28 | } 29 | 30 | if err := DB.Save(&user).Error; err != nil { 31 | t.Fatalf("errors happened when update: %v", err) 32 | } 33 | 34 | var user3 User 35 | DB.Preload("Pets").Find(&user3, "id = ?", user.ID) 36 | CheckUser(t, user2, user3) 37 | 38 | if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { 39 | t.Fatalf("errors happened when update: %v", err) 40 | } 41 | 42 | var user4 User 43 | DB.Preload("Pets").Find(&user4, "id = ?", user.ID) 44 | CheckUser(t, user4, user) 45 | 46 | t.Run("Polymorphic", func(t *testing.T) { 47 | user := *GetUser("update-has-many", Config{}) 48 | 49 | if err := DB.Create(&user).Error; err != nil { 50 | t.Fatalf("errors happened when create: %v", err) 51 | } 52 | 53 | user.Toys = []Toy{{Name: "toy1"}, {Name: "toy2"}} 54 | if err := DB.Save(&user).Error; err != nil { 55 | t.Fatalf("errors happened when update: %v", err) 56 | } 57 | 58 | var user2 User 59 | DB.Preload("Toys").Find(&user2, "id = ?", user.ID) 60 | CheckUser(t, user2, user) 61 | 62 | for idx := range user.Toys { 63 | user.Toys[idx].Name += "new" 64 | } 65 | 66 | if err := DB.Save(&user).Error; err != nil { 67 | t.Fatalf("errors happened when update: %v", err) 68 | } 69 | 70 | 
var user3 User 71 | DB.Preload("Toys").Find(&user3, "id = ?", user.ID) 72 | CheckUser(t, user2, user3) 73 | 74 | if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { 75 | t.Fatalf("errors happened when update: %v", err) 76 | } 77 | 78 | var user4 User 79 | DB.Preload("Toys").Find(&user4, "id = ?", user.ID) 80 | CheckUser(t, user4, user) 81 | }) 82 | } 83 | -------------------------------------------------------------------------------- /tests/update_has_one_test.go: -------------------------------------------------------------------------------- 1 | package tests_test 2 | 3 | import ( 4 | "database/sql" 5 | "testing" 6 | "time" 7 | 8 | "gorm.io/gorm" 9 | . "gorm.io/gorm/utils/tests" 10 | ) 11 | 12 | func TestUpdateHasOne(t *testing.T) { 13 | user := *GetUser("update-has-one", Config{}) 14 | 15 | if err := DB.Create(&user).Error; err != nil { 16 | t.Fatalf("errors happened when create: %v", err) 17 | } 18 | 19 | user.Account = Account{Number: "account-has-one-association"} 20 | 21 | if err := DB.Save(&user).Error; err != nil { 22 | t.Fatalf("errors happened when update: %v", err) 23 | } 24 | 25 | var user2 User 26 | DB.Preload("Account").Find(&user2, "id = ?", user.ID) 27 | CheckUser(t, user2, user) 28 | 29 | user.Account.Number += "new" 30 | if err := DB.Save(&user).Error; err != nil { 31 | t.Fatalf("errors happened when update: %v", err) 32 | } 33 | 34 | var user3 User 35 | DB.Preload("Account").Find(&user3, "id = ?", user.ID) 36 | 37 | CheckUser(t, user2, user3) 38 | lastUpdatedAt := user2.Account.UpdatedAt 39 | time.Sleep(time.Second) 40 | 41 | if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { 42 | t.Fatalf("errors happened when update: %v", err) 43 | } 44 | 45 | var user4 User 46 | DB.Preload("Account").Find(&user4, "id = ?", user.ID) 47 | 48 | if lastUpdatedAt.Format(time.RFC3339) == user4.Account.UpdatedAt.Format(time.RFC3339) { 49 | t.Fatalf("updated at should be
updated, but not, old: %v, new %v", lastUpdatedAt.Format(time.RFC3339), user4.Account.UpdatedAt.Format(time.RFC3339)) // report user4, the row reloaded after the FullSaveAssociations save that is being checked 50 | } else { 51 | user.Account.UpdatedAt = user4.Account.UpdatedAt 52 | CheckUser(t, user4, user) 53 | } 54 | 55 | t.Run("Polymorphic", func(t *testing.T) { 56 | pet := Pet{Name: "create"} 57 | 58 | if err := DB.Create(&pet).Error; err != nil { 59 | t.Fatalf("errors happened when create: %v", err) 60 | } 61 | 62 | pet.Toy = Toy{Name: "Update-HasOneAssociation-Polymorphic"} 63 | 64 | if err := DB.Save(&pet).Error; err != nil { 65 | t.Fatalf("errors happened when update: %v", err) 66 | } 67 | 68 | var pet2 Pet 69 | DB.Preload("Toy").Find(&pet2, "id = ?", pet.ID) 70 | CheckPet(t, pet2, pet) 71 | 72 | pet.Toy.Name += "new" 73 | if err := DB.Save(&pet).Error; err != nil { 74 | t.Fatalf("errors happened when update: %v", err) 75 | } 76 | 77 | var pet3 Pet 78 | DB.Preload("Toy").Find(&pet3, "id = ?", pet.ID) 79 | CheckPet(t, pet2, pet3) 80 | 81 | if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&pet).Error; err != nil { 82 | t.Fatalf("errors happened when update: %v", err) 83 | } 84 | 85 | var pet4 Pet 86 | DB.Preload("Toy").Find(&pet4, "id = ?", pet.ID) 87 | CheckPet(t, pet4, pet) 88 | }) 89 | 90 | t.Run("Restriction", func(t *testing.T) { 91 | type CustomizeAccount struct { 92 | gorm.Model 93 | UserID sql.NullInt64 94 | Number string `gorm:"<-:create"` 95 | Number2 string 96 | } 97 | 98 | type CustomizeUser struct { 99 | gorm.Model 100 | Name string 101 | Account CustomizeAccount `gorm:"foreignkey:UserID"` 102 | } 103 | 104 | DB.Migrator().DropTable(&CustomizeUser{}) 105 | DB.Migrator().DropTable(&CustomizeAccount{}) 106 | 107 | if err := DB.AutoMigrate(&CustomizeUser{}); err != nil { 108 | t.Fatalf("failed to migrate, got error: %v", err) 109 | } 110 | if err := DB.AutoMigrate(&CustomizeAccount{}); err != nil { 111 | t.Fatalf("failed to migrate, got error: %v", err) 
112 | } 113 | 114 | number := "number-has-one-associations" 115 |
cusUser := CustomizeUser{ 116 | Name: "update-has-one-associations", 117 | Account: CustomizeAccount{ 118 | Number: number, 119 | Number2: number, 120 | }, 121 | } 122 | 123 | if err := DB.Create(&cusUser).Error; err != nil { 124 | t.Fatalf("errors happened when create: %v", err) 125 | } 126 | cusUser.Account.Number += "-update" 127 | cusUser.Account.Number2 += "-update" 128 | if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Updates(&cusUser).Error; err != nil { 129 | t.Fatalf("errors happened when update: %v", err) 130 | } 131 | 132 | var account2 CustomizeAccount 133 | DB.Find(&account2, "user_id = ?", cusUser.ID) 134 | AssertEqual(t, account2.Number, number) 135 | AssertEqual(t, account2.Number2, cusUser.Account.Number2) 136 | }) 137 | } 138 | -------------------------------------------------------------------------------- /tests/update_many2many_test.go: -------------------------------------------------------------------------------- 1 | package tests_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "gorm.io/gorm" 7 | .
"gorm.io/gorm/utils/tests" 8 | ) 9 | 10 | func TestUpdateMany2ManyAssociations(t *testing.T) { 11 | user := *GetUser("update-many2many", Config{}) 12 | 13 | if err := DB.Create(&user).Error; err != nil { 14 | t.Fatalf("errors happened when create: %v", err) 15 | } 16 | 17 | user.Languages = []Language{{Code: "zh-CN", Name: "Chinese"}, {Code: "en", Name: "English"}} 18 | for _, lang := range user.Languages { 19 | DB.Create(&lang) 20 | } 21 | user.Friends = []*User{{Name: "friend-1"}, {Name: "friend-2"}} 22 | 23 | if err := DB.Save(&user).Error; err != nil { 24 | t.Fatalf("errors happened when update: %v", err) 25 | } 26 | 27 | var user2 User 28 | DB.Preload("Languages").Preload("Friends").Find(&user2, "id = ?", user.ID) 29 | CheckUser(t, user2, user) 30 | 31 | for idx := range user.Friends { 32 | user.Friends[idx].Name += "new" 33 | } 34 | 35 | for idx := range user.Languages { 36 | user.Languages[idx].Name += "new" 37 | } 38 | 39 | if err := DB.Save(&user).Error; err != nil { 40 | t.Fatalf("errors happened when update: %v", err) 41 | } 42 | 43 | var user3 User 44 | DB.Preload("Languages").Preload("Friends").Find(&user3, "id = ?", user.ID) 45 | CheckUser(t, user2, user3) 46 | 47 | if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { 48 | t.Fatalf("errors happened when update: %v", err) 49 | } 50 | 51 | var user4 User 52 | DB.Preload("Languages").Preload("Friends").Find(&user4, "id = ?", user.ID) 53 | CheckUser(t, user4, user) 54 | } 55 | -------------------------------------------------------------------------------- /utils/tests/dummy_dialecter.go: -------------------------------------------------------------------------------- 1 | package tests 2 | 3 | import ( 4 | "gorm.io/gorm" 5 | "gorm.io/gorm/callbacks" 6 | "gorm.io/gorm/clause" 7 | "gorm.io/gorm/logger" 8 | "gorm.io/gorm/schema" 9 | ) 10 | 11 | type DummyDialector struct { 12 | TranslatedErr error 13 | } 14 | 15 | func (DummyDialector) Name() string { 16 | return 
"dummy" 17 | } 18 | 19 | func (DummyDialector) Initialize(db *gorm.DB) error { 20 | callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{ 21 | CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"}, 22 | UpdateClauses: []string{"UPDATE", "SET", "WHERE", "RETURNING"}, 23 | DeleteClauses: []string{"DELETE", "FROM", "WHERE", "RETURNING"}, 24 | LastInsertIDReversed: true, 25 | }) 26 | 27 | return nil 28 | } 29 | 30 | func (DummyDialector) DefaultValueOf(field *schema.Field) clause.Expression { 31 | return clause.Expr{SQL: "DEFAULT"} 32 | } 33 | 34 | func (DummyDialector) Migrator(*gorm.DB) gorm.Migrator { 35 | return nil 36 | } 37 | 38 | func (DummyDialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { 39 | writer.WriteByte('?') 40 | } 41 | 42 | func (DummyDialector) QuoteTo(writer clause.Writer, str string) { 43 | var ( 44 | underQuoted, selfQuoted bool 45 | continuousBacktick int8 46 | shiftDelimiter int8 47 | ) 48 | 49 | for _, v := range []byte(str) { 50 | switch v { 51 | case '`': 52 | continuousBacktick++ 53 | if continuousBacktick == 2 { 54 | writer.WriteString("``") 55 | continuousBacktick = 0 56 | } 57 | case '.': 58 | if continuousBacktick > 0 || !selfQuoted { 59 | shiftDelimiter = 0 60 | underQuoted = false 61 | continuousBacktick = 0 62 | writer.WriteByte('`') 63 | } 64 | writer.WriteByte(v) 65 | continue 66 | default: 67 | if shiftDelimiter-continuousBacktick <= 0 && !underQuoted { 68 | writer.WriteByte('`') 69 | underQuoted = true 70 | if selfQuoted = continuousBacktick > 0; selfQuoted { 71 | continuousBacktick -= 1 72 | } 73 | } 74 | 75 | for ; continuousBacktick > 0; continuousBacktick -= 1 { 76 | writer.WriteString("``") 77 | } 78 | 79 | writer.WriteByte(v) 80 | } 81 | shiftDelimiter++ 82 | } 83 | 84 | if continuousBacktick > 0 && !selfQuoted { 85 | writer.WriteString("``") 86 | } 87 | writer.WriteByte('`') 88 | } 89 | 90 | func (DummyDialector) Explain(sql string, vars ...interface{}) string { 91 | 
return logger.ExplainSQL(sql, nil, `"`, vars...) 92 | } 93 | 94 | func (DummyDialector) DataTypeOf(*schema.Field) string { 95 | return "" 96 | } 97 | 98 | func (d DummyDialector) Translate(err error) error { 99 | return d.TranslatedErr 100 | } 101 | -------------------------------------------------------------------------------- /utils/tests/models.go: -------------------------------------------------------------------------------- 1 | package tests 2 | 3 | import ( 4 | "database/sql" 5 | "time" 6 | 7 | "gorm.io/gorm" 8 | ) 9 | 10 | // User has one `Account` (has one), many `Pets` (has many) and `Toys` (has many - polymorphic) 11 | // He works in a Company (belongs to), he has a Manager (belongs to - single-table), and also managed a Team (has many - single-table) 12 | // He speaks many languages (many to many) and has many friends (many to many - single-table) 13 | // His pet also has one Toy (has one - polymorphic) 14 | // NamedPet is a reference to a named `Pet` (has one) 15 | type User struct { 16 | gorm.Model 17 | Name string 18 | Age uint 19 | Birthday *time.Time 20 | Account Account 21 | Pets []*Pet 22 | NamedPet *Pet 23 | Toys []Toy `gorm:"polymorphic:Owner"` 24 | Tools []Tools `gorm:"polymorphicType:Type;polymorphicId:CustomID"` 25 | CompanyID *int 26 | Company Company 27 | ManagerID *uint 28 | Manager *User 29 | Team []User `gorm:"foreignkey:ManagerID"` 30 | Languages []Language `gorm:"many2many:UserSpeak;"` 31 | Friends []*User `gorm:"many2many:user_friends;"` 32 | Active bool 33 | } 34 | 35 | type Account struct { 36 | gorm.Model 37 | UserID sql.NullInt64 38 | Number string 39 | } 40 | 41 | type Pet struct { 42 | gorm.Model 43 | UserID *uint 44 | Name string 45 | Toy Toy `gorm:"polymorphic:Owner;"` 46 | } 47 | 48 | type Toy struct { 49 | gorm.Model 50 | Name string 51 | OwnerID string 52 | OwnerType string 53 | } 54 | 55 | type Tools struct { 56 | gorm.Model 57 | Name string 58 | CustomID string 59 | Type string 60 | } 61 | 62 | type Company struct { 63 
| ID int 64 | Name string 65 | } 66 | 67 | type Language struct { 68 | Code string `gorm:"primarykey"` 69 | Name string 70 | } 71 | 72 | type Coupon struct { 73 | ID int `gorm:"primarykey; size:255"` 74 | AppliesToProduct []*CouponProduct `gorm:"foreignKey:CouponId;constraint:OnDelete:CASCADE"` 75 | AmountOff uint32 `gorm:"column:amount_off"` 76 | PercentOff float32 `gorm:"column:percent_off"` 77 | } 78 | 79 | type CouponProduct struct { 80 | CouponId int `gorm:"primarykey;size:255"` 81 | ProductId string `gorm:"primarykey;size:255"` 82 | Desc string 83 | } 84 | 85 | type Order struct { 86 | gorm.Model 87 | Num string 88 | Coupon *Coupon 89 | CouponID string 90 | } 91 | 92 | type Parent struct { 93 | gorm.Model 94 | FavChildID uint 95 | FavChild *Child 96 | Children []*Child 97 | } 98 | 99 | type Child struct { 100 | gorm.Model 101 | Name string 102 | ParentID *uint 103 | Parent *Parent 104 | } 105 | -------------------------------------------------------------------------------- /utils/tests/utils.go: -------------------------------------------------------------------------------- 1 | package tests 2 | 3 | import ( 4 | "database/sql/driver" 5 | "fmt" 6 | "go/ast" 7 | "reflect" 8 | "testing" 9 | "time" 10 | 11 | "gorm.io/gorm/utils" 12 | ) 13 | 14 | func AssertObjEqual(t *testing.T, r, e interface{}, names ...string) { 15 | for _, name := range names { 16 | rv := reflect.Indirect(reflect.ValueOf(r)) 17 | ev := reflect.Indirect(reflect.ValueOf(e)) 18 | if rv.IsValid() != ev.IsValid() { 19 | t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), e, r) // e is the expected value, r the actual one 20 | return 21 | } 22 | got := rv.FieldByName(name).Interface() 23 | expect := ev.FieldByName(name).Interface() 24 | t.Run(name, func(t *testing.T) { 25 | AssertEqual(t, got, expect) 26 | }) 27 | } 28 | } 29 | 30 | func AssertEqual(t *testing.T, got, expect interface{}) { 31 | if !reflect.DeepEqual(got, expect) { 32 | isEqual := func() { 33 | if curTime, ok := got.(time.Time); ok { 34 | format :=
"2006-01-02T15:04:05Z07:00" 35 | 36 | if curTime.Round(time.Second).UTC().Format(format) != expect.(time.Time).Round(time.Second).UTC().Format(format) && curTime.Truncate(time.Second).UTC().Format(format) != expect.(time.Time).Truncate(time.Second).UTC().Format(format) { 37 | t.Errorf("%v: expect: %v, got %v after time round", utils.FileWithLineNum(), expect.(time.Time), curTime) 38 | } 39 | } else if fmt.Sprint(got) != fmt.Sprint(expect) { 40 | t.Errorf("%v: expect: %#v, got %#v", utils.FileWithLineNum(), expect, got) 41 | } 42 | } 43 | 44 | if fmt.Sprint(got) == fmt.Sprint(expect) { 45 | return 46 | } 47 | 48 | if reflect.Indirect(reflect.ValueOf(got)).IsValid() != reflect.Indirect(reflect.ValueOf(expect)).IsValid() { 49 | t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got) 50 | return 51 | } 52 | 53 | if valuer, ok := got.(driver.Valuer); ok { 54 | got, _ = valuer.Value() 55 | } 56 | 57 | if valuer, ok := expect.(driver.Valuer); ok { 58 | expect, _ = valuer.Value() 59 | } 60 | 61 | if got != nil { 62 | got = reflect.Indirect(reflect.ValueOf(got)).Interface() 63 | } 64 | 65 | if expect != nil { 66 | expect = reflect.Indirect(reflect.ValueOf(expect)).Interface() 67 | } 68 | 69 | if reflect.ValueOf(got).IsValid() != reflect.ValueOf(expect).IsValid() { 70 | t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got) 71 | return 72 | } 73 | if !reflect.ValueOf(got).IsValid() { return } // both are nil after Valuer unwrapping (e.g. two NULL sql.Null* values): equal, and Type() below would panic on a zero Value 74 | if reflect.ValueOf(got).Kind() == reflect.Slice { 75 | if reflect.ValueOf(expect).Kind() == reflect.Slice { 76 | if reflect.ValueOf(got).Len() == reflect.ValueOf(expect).Len() { 77 | for i := 0; i < reflect.ValueOf(got).Len(); i++ { 78 | name := fmt.Sprintf(reflect.ValueOf(got).Type().Name()+" #%v", i) 79 | t.Run(name, func(t *testing.T) { 80 | AssertEqual(t, reflect.ValueOf(got).Index(i).Interface(), reflect.ValueOf(expect).Index(i).Interface()) 81 | }) 
82 | } 83 | } else { 84 | name := reflect.ValueOf(got).Type().Elem().Name() 85 | t.Errorf("%v expects length: %v, got %v (expects: %+v,
got %+v)", name, reflect.ValueOf(expect).Len(), reflect.ValueOf(got).Len(), expect, got) 86 | } 87 | return 88 | } 89 | } 90 | 91 | if reflect.ValueOf(got).Kind() == reflect.Struct { 92 | if reflect.ValueOf(expect).Kind() == reflect.Struct { 93 | if reflect.ValueOf(got).NumField() == reflect.ValueOf(expect).NumField() { 94 | exported := false 95 | for i := 0; i < reflect.ValueOf(got).NumField(); i++ { 96 | if fieldStruct := reflect.ValueOf(got).Type().Field(i); ast.IsExported(fieldStruct.Name) { 97 | exported = true 98 | field := reflect.ValueOf(got).Field(i) 99 | t.Run(fieldStruct.Name, func(t *testing.T) { 100 | AssertEqual(t, field.Interface(), reflect.ValueOf(expect).Field(i).Interface()) 101 | }) 102 | } 103 | } 104 | 105 | if exported { 106 | return 107 | } 108 | } 109 | } 110 | } 111 | 112 | if reflect.ValueOf(got).Type().ConvertibleTo(reflect.ValueOf(expect).Type()) { 113 | got = reflect.ValueOf(got).Convert(reflect.ValueOf(expect).Type()).Interface() 114 | isEqual() 115 | } else if reflect.ValueOf(expect).Type().ConvertibleTo(reflect.ValueOf(got).Type()) { 116 | expect = reflect.ValueOf(expect).Convert(reflect.ValueOf(got).Type()).Interface() // convert expect itself; converting got here would make the comparison vacuous 117 | isEqual() 118 | } else { 119 | t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got) 120 | return 121 | } 122 | } 123 | } 124 | 125 | func Now() *time.Time { 126 | now := time.Now() 127 | return &now 128 | } 129 | -------------------------------------------------------------------------------- /utils/utils.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "database/sql/driver" 5 | "fmt" 6 | "path/filepath" 7 | "reflect" 8 | "runtime" 9 | "strconv" 10 | "strings" 11 | "unicode" 12 | ) 13 | 14 | var gormSourceDir string 15 | 16 | func init() { 17 | _, file, _, _ := runtime.Caller(0) 18 | // compatible solution to get gorm source directory with various operating systems 19 | 
gormSourceDir = sourceDir(file) 20 | } 21 | 22 |
func sourceDir(file string) string { 23 | dir := filepath.Dir(file) 24 | dir = filepath.Dir(dir) 25 | 26 | s := filepath.Dir(dir) 27 | if filepath.Base(s) != "gorm.io" { 28 | s = dir 29 | } 30 | return filepath.ToSlash(s) + "/" 31 | } 32 | 33 | // FileWithLineNum return the file name and line number of the current file 34 | func FileWithLineNum() string { 35 | pcs := [13]uintptr{} 36 | // the third caller usually from gorm internal 37 | len := runtime.Callers(3, pcs[:]) 38 | frames := runtime.CallersFrames(pcs[:len]) 39 | for i := 0; i < len; i++ { 40 | // second return value is "more", not "ok" 41 | frame, _ := frames.Next() 42 | if (!strings.HasPrefix(frame.File, gormSourceDir) || 43 | strings.HasSuffix(frame.File, "_test.go")) && !strings.HasSuffix(frame.File, ".gen.go") { 44 | return string(strconv.AppendInt(append([]byte(frame.File), ':'), int64(frame.Line), 10)) 45 | } 46 | } 47 | 48 | return "" 49 | } 50 | 51 | func IsValidDBNameChar(c rune) bool { 52 | return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' 
&& c != '*' && c != '_' && c != '$' && c != '@' 53 | } 54 | 55 | // CheckTruth check string true or not 56 | func CheckTruth(vals ...string) bool { 57 | for _, val := range vals { 58 | if val != "" && !strings.EqualFold(val, "false") { 59 | return true 60 | } 61 | } 62 | return false 63 | } 64 | 65 | func ToStringKey(values ...interface{}) string { 66 | results := make([]string, len(values)) 67 | 68 | for idx, value := range values { 69 | if valuer, ok := value.(driver.Valuer); ok { 70 | value, _ = valuer.Value() 71 | } 72 | 73 | switch v := value.(type) { 74 | case string: 75 | results[idx] = v 76 | case []byte: 77 | results[idx] = string(v) 78 | case uint: 79 | results[idx] = strconv.FormatUint(uint64(v), 10) 80 | default: 81 | results[idx] = "nil" 82 | vv := reflect.ValueOf(v) 83 | if vv.IsValid() && !vv.IsZero() { 84 | results[idx] = fmt.Sprint(reflect.Indirect(vv).Interface()) 85 | } 86 | } 87 | } 88 | 89 | return strings.Join(results, "_") 90 | } 91 | 92 | func Contains(elems []string, elem string) bool { 93 | for _, e := range elems { 94 | if elem == e { 95 | return true 96 | } 97 | } 98 | return false 99 | } 100 | 101 | func AssertEqual(x, y interface{}) bool { 102 | if reflect.DeepEqual(x, y) { 103 | return true 104 | } 105 | if x == nil || y == nil { 106 | return false 107 | } 108 | 109 | xval := reflect.ValueOf(x) 110 | yval := reflect.ValueOf(y) 111 | if xval.Kind() == reflect.Ptr && xval.IsNil() || 112 | yval.Kind() == reflect.Ptr && yval.IsNil() { 113 | return false 114 | } 115 | 116 | if valuer, ok := x.(driver.Valuer); ok { 117 | x, _ = valuer.Value() 118 | } 119 | if valuer, ok := y.(driver.Valuer); ok { 120 | y, _ = valuer.Value() 121 | } 122 | return reflect.DeepEqual(x, y) 123 | } 124 | 125 | func ToString(value interface{}) string { 126 | switch v := value.(type) { 127 | case string: 128 | return v 129 | case int: 130 | return strconv.FormatInt(int64(v), 10) 131 | case int8: 132 | return strconv.FormatInt(int64(v), 10) 133 | case int16: 134 | 
return strconv.FormatInt(int64(v), 10) 135 | case int32: 136 | return strconv.FormatInt(int64(v), 10) 137 | case int64: 138 | return strconv.FormatInt(v, 10) 139 | case uint: 140 | return strconv.FormatUint(uint64(v), 10) 141 | case uint8: 142 | return strconv.FormatUint(uint64(v), 10) 143 | case uint16: 144 | return strconv.FormatUint(uint64(v), 10) 145 | case uint32: 146 | return strconv.FormatUint(uint64(v), 10) 147 | case uint64: 148 | return strconv.FormatUint(v, 10) 149 | } 150 | return "" 151 | } 152 | 153 | const nestedRelationSplit = "__" 154 | 155 | // NestedRelationName nested relationships like `Manager__Company` 156 | func NestedRelationName(prefix, name string) string { 157 | return prefix + nestedRelationSplit + name 158 | } 159 | 160 | // SplitNestedRelationName Split nested relationships to `[]string{"Manager","Company"}` 161 | func SplitNestedRelationName(name string) []string { 162 | return strings.Split(name, nestedRelationSplit) 163 | } 164 | 165 | // JoinNestedRelationNames nested relationships like `Manager__Company` 166 | func JoinNestedRelationNames(relationNames []string) string { 167 | return strings.Join(relationNames, nestedRelationSplit) 168 | } 169 | 170 | // RTrimSlice Right trims the given slice by given length 171 | func RTrimSlice[T any](v []T, trimLen int) []T { 172 | if trimLen >= len(v) { // trimLen greater than slice len means fully sliced 173 | return v[:0] 174 | } 175 | if trimLen < 0 { // negative trimLen is ignored 176 | return v[:] 177 | } 178 | return v[:len(v)-trimLen] 179 | } 180 | -------------------------------------------------------------------------------- /utils/utils_test.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "database/sql" 5 | "database/sql/driver" 6 | "errors" 7 | "math" 8 | "strings" 9 | "testing" 10 | "time" 11 | ) 12 | 13 | func TestIsValidDBNameChar(t *testing.T) { 14 | for _, db := range []string{"db", "dbName", 
"db_name", "db1", "1dbname", "db$name"} { 15 | if fields := strings.FieldsFunc(db, IsValidDBNameChar); len(fields) != 1 { 16 | t.Fatalf("failed to parse db name %v", db) 17 | } 18 | } 19 | } 20 | 21 | func TestCheckTruth(t *testing.T) { 22 | checkTruthTests := []struct { 23 | v string 24 | out bool 25 | }{ 26 | {"123", true}, 27 | {"true", true}, 28 | {"", false}, 29 | {"false", false}, 30 | {"False", false}, 31 | {"FALSE", false}, 32 | {"\u0046alse", false}, 33 | } 34 | 35 | for _, test := range checkTruthTests { 36 | t.Run(test.v, func(t *testing.T) { 37 | if out := CheckTruth(test.v); out != test.out { 38 | t.Errorf("CheckTruth(%s) want: %t, got: %t", test.v, test.out, out) 39 | } 40 | }) 41 | } 42 | } 43 | 44 | func TestToStringKey(t *testing.T) { 45 | cases := []struct { 46 | values []interface{} 47 | key string 48 | }{ 49 | {[]interface{}{"a"}, "a"}, 50 | {[]interface{}{1, 2, 3}, "1_2_3"}, 51 | {[]interface{}{1, nil, 3}, "1_nil_3"}, 52 | {[]interface{}{[]interface{}{1, 2, 3}}, "[1 2 3]"}, 53 | {[]interface{}{[]interface{}{"1", "2", "3"}}, "[1 2 3]"}, 54 | {[]interface{}{[]interface{}{"1", nil, "3"}}, "[1 3]"}, 55 | } 56 | for _, c := range cases { 57 | if key := ToStringKey(c.values...); key != c.key { 58 | t.Errorf("%v: expected %v, got %v", c.values, c.key, key) 59 | } 60 | } 61 | } 62 | 63 | func TestContains(t *testing.T) { 64 | containsTests := []struct { 65 | name string 66 | elems []string 67 | elem string 68 | out bool 69 | }{ 70 | {"exists", []string{"1", "2", "3"}, "1", true}, 71 | {"not exists", []string{"1", "2", "3"}, "4", false}, 72 | } 73 | for _, test := range containsTests { 74 | t.Run(test.name, func(t *testing.T) { 75 | if out := Contains(test.elems, test.elem); test.out != out { 76 | t.Errorf("Contains(%v, %s) want: %t, got: %t", test.elems, test.elem, test.out, out) 77 | } 78 | }) 79 | } 80 | } 81 | 82 | type ModifyAt sql.NullTime 83 | 84 | // Value return a Unix time. 
85 | func (n ModifyAt) Value() (driver.Value, error) { 86 | if !n.Valid { 87 | return nil, nil 88 | } 89 | return n.Time.Unix(), nil 90 | } 91 | 92 | func TestAssertEqual(t *testing.T) { 93 | now := time.Now() 94 | assertEqualTests := []struct { 95 | name string 96 | src, dst interface{} 97 | out bool 98 | }{ 99 | {"error equal", errors.New("1"), errors.New("1"), true}, 100 | {"error not equal", errors.New("1"), errors.New("2"), false}, 101 | {"driver.Valuer equal", ModifyAt{Time: now, Valid: true}, ModifyAt{Time: now, Valid: true}, true}, 102 | {"driver.Valuer not equal", ModifyAt{Time: now, Valid: true}, ModifyAt{Time: now.Add(time.Second), Valid: true}, false}, 103 | {"driver.Valuer equal (ptr to nil ptr)", (*ModifyAt)(nil), &ModifyAt{}, false}, 104 | } 105 | for _, test := range assertEqualTests { 106 | t.Run(test.name, func(t *testing.T) { 107 | if out := AssertEqual(test.src, test.dst); test.out != out { 108 | t.Errorf("AssertEqual(%v, %v) want: %t, got: %t", test.src, test.dst, test.out, out) 109 | } 110 | }) 111 | } 112 | } 113 | 114 | func TestToString(t *testing.T) { 115 | tests := []struct { 116 | name string 117 | in interface{} 118 | out string 119 | }{ 120 | {"int", math.MaxInt64, "9223372036854775807"}, 121 | {"int8", int8(math.MaxInt8), "127"}, 122 | {"int16", int16(math.MaxInt16), "32767"}, 123 | {"int32", int32(math.MaxInt32), "2147483647"}, 124 | {"int64", int64(math.MaxInt64), "9223372036854775807"}, 125 | {"uint", uint(math.MaxUint64), "18446744073709551615"}, 126 | {"uint8", uint8(math.MaxUint8), "255"}, 127 | {"uint16", uint16(math.MaxUint16), "65535"}, 128 | {"uint32", uint32(math.MaxUint32), "4294967295"}, 129 | {"uint64", uint64(math.MaxUint64), "18446744073709551615"}, 130 | {"string", "abc", "abc"}, 131 | {"other", true, ""}, 132 | } 133 | for _, test := range tests { 134 | t.Run(test.name, func(t *testing.T) { 135 | if out := ToString(test.in); test.out != out { 136 | t.Fatalf("ToString(%v) want: %s, got: %s", test.in, test.out, out) 
137 | } 138 | }) 139 | } 140 | } 141 | 142 | func TestRTrimSlice(t *testing.T) { 143 | tests := []struct { 144 | name string 145 | input []int 146 | trimLen int 147 | expected []int 148 | }{ 149 | { 150 | name: "Trim two elements from end", 151 | input: []int{1, 2, 3, 4, 5}, 152 | trimLen: 2, 153 | expected: []int{1, 2, 3}, 154 | }, 155 | { 156 | name: "Trim entire slice", 157 | input: []int{1, 2, 3}, 158 | trimLen: 3, 159 | expected: []int{}, 160 | }, 161 | { 162 | name: "Trim length greater than slice length", 163 | input: []int{1, 2, 3}, 164 | trimLen: 5, 165 | expected: []int{}, 166 | }, 167 | { 168 | name: "Zero trim length", 169 | input: []int{1, 2, 3}, 170 | trimLen: 0, 171 | expected: []int{1, 2, 3}, 172 | }, 173 | { 174 | name: "Trim one element from end", 175 | input: []int{1, 2, 3}, 176 | trimLen: 1, 177 | expected: []int{1, 2}, 178 | }, 179 | { 180 | name: "Empty slice", 181 | input: []int{}, 182 | trimLen: 2, 183 | expected: []int{}, 184 | }, 185 | { 186 | name: "Negative trim length (should be treated as zero)", 187 | input: []int{1, 2, 3}, 188 | trimLen: -1, 189 | expected: []int{1, 2, 3}, 190 | }, 191 | } 192 | 193 | for _, testcase := range tests { 194 | t.Run(testcase.name, func(t *testing.T) { 195 | result := RTrimSlice(testcase.input, testcase.trimLen) 196 | if !AssertEqual(result, testcase.expected) { 197 | t.Errorf("RTrimSlice(%v, %d) = %v; want %v", testcase.input, testcase.trimLen, result, testcase.expected) 198 | } 199 | }) 200 | } 201 | } 202 | -------------------------------------------------------------------------------- /utils/utils_unix_test.go: -------------------------------------------------------------------------------- 1 | //go:build unix 2 | // +build unix 3 | 4 | package utils 5 | 6 | import ( 7 | "testing" 8 | ) 9 | 10 | func TestSourceDir(t *testing.T) { 11 | cases := []struct { 12 | file string 13 | want string 14 | }{ 15 | { 16 | file: "/Users/name/go/pkg/mod/gorm.io/gorm@v1.2.3/utils/utils.go", 17 | want: 
"/Users/name/go/pkg/mod/gorm.io/", 18 | }, 19 | { 20 | file: "/go/work/proj/gorm/utils/utils.go", 21 | want: "/go/work/proj/gorm/", 22 | }, 23 | { 24 | file: "/go/work/proj/gorm_alias/utils/utils.go", 25 | want: "/go/work/proj/gorm_alias/", 26 | }, 27 | { 28 | file: "/go/work/proj/my.gorm.io/gorm@v1.2.3/utils/utils.go", 29 | want: "/go/work/proj/my.gorm.io/gorm@v1.2.3/", 30 | }, 31 | } 32 | for _, c := range cases { 33 | s := sourceDir(c.file) 34 | if s != c.want { 35 | t.Fatalf("%s: expected %s, got %s", c.file, c.want, s) 36 | } 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /utils/utils_windows_test.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestSourceDir(t *testing.T) { 8 | cases := []struct { 9 | file string 10 | want string 11 | }{ 12 | { 13 | file: `C:/Users/name/go/pkg/mod/gorm.io/gorm@v1.2.3/utils/utils.go`, 14 | want: `C:/Users/name/go/pkg/mod/gorm.io/`, 15 | }, 16 | { 17 | file: `C:/go/work/proj/gorm/utils/utils.go`, 18 | want: `C:/go/work/proj/gorm/`, 19 | }, 20 | { 21 | file: `C:/go/work/proj/gorm_alias/utils/utils.go`, 22 | want: `C:/go/work/proj/gorm_alias/`, 23 | }, 24 | { 25 | file: `C:/go/work/proj/my.gorm.io/gorm@v1.2.3/utils/utils.go`, 26 | want: `C:/go/work/proj/my.gorm.io/gorm@v1.2.3/`, 27 | }, 28 | } 29 | for _, c := range cases { 30 | s := sourceDir(c.file) 31 | if s != c.want { 32 | t.Fatalf("%s: expected %s, got %s", c.file, c.want, s) 33 | } 34 | } 35 | } 36 | --------------------------------------------------------------------------------