├── .github ├── dependabot.yml └── workflows │ ├── close_inactive.yml │ ├── contrib-readme.yml │ ├── release.yml │ └── tests.yml ├── .gitignore ├── LICENSE ├── README.md ├── _dev ├── config │ ├── CHANGELOG.tpl.md │ └── chglog.config.yml ├── goland │ ├── go-mod.run.xml │ ├── go-staticcheck.run.xml │ └── go-vulncheck.run.xml └── script │ ├── done-time-pause.bat │ ├── go-mod.bat │ ├── go-staticcheck.bat │ └── go-vulncheck.bat ├── create.go ├── create_test.go ├── datatypes_json_map_test.go ├── go.mod ├── migrator.go ├── migrator_test.go ├── namer.go ├── oracle.go ├── oracle_ora.go ├── oracle_ora_test.go ├── oracle_test.go ├── reserved.go └── update.go /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # To get started with Dependabot version updates, you'll need to specify which 2 | # package ecosystems to update and where the package manifests are located. 3 | # Please see the documentation for all configuration options: 4 | # https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file 5 | 6 | version: 2 7 | updates: 8 | - package-ecosystem: "github-actions" 9 | directory: "/" 10 | schedule: 11 | interval: "daily" 12 | commit-message: 13 | prefix: "⬆️ " 14 | - package-ecosystem: "gomod" 15 | directory: "/" 16 | schedule: 17 | interval: "daily" 18 | commit-message: 19 | prefix: "⬆️ " 20 | -------------------------------------------------------------------------------- /.github/workflows/close_inactive.yml: -------------------------------------------------------------------------------- 1 | name: Close inactive issues and PRs 2 | on: 3 | schedule: 4 | - cron: "30 1 * * *" 5 | 6 | jobs: 7 | close-issues: 8 | runs-on: ubuntu-latest 9 | permissions: 10 | issues: write 11 | pull-requests: write 12 | steps: 13 | - uses: actions/stale@v9 14 | with: 15 | days-before-issue-stale: 90 16 | days-before-issue-close: 30 17 | stale-issue-label: "stale" 18 | 
stale-issue-message: "This issue is stale because it has been open for 90 days with no activity. Remove stale label or comment or this will be closed in 30 days." 19 | close-issue-message: "This issue was closed because it has been inactive for 30 days since being marked as stale. Please leave a comment tagging an assigned team member if you need this issue to be reopened, and we will be happy to investigate. Thank you!" 20 | days-before-pr-stale: 90 21 | days-before-pr-close: 30 22 | stale-pr-message: "This pull request is stale because it has been open for 90 days with no activity. Remove stale label or comment or this will be closed in 30 days." 23 | close-pr-message: "This pull request was closed because it has been inactive for 30 days since being marked as stale." 24 | exempt-assignees: iTanken 25 | operations-per-run: 500 26 | repo-token: ${{ secrets.GITHUB_TOKEN }} 27 | -------------------------------------------------------------------------------- /.github/workflows/contrib-readme.yml: -------------------------------------------------------------------------------- 1 | name: A job to automate contrib in readme 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | contrib-readme-job: 10 | runs-on: ubuntu-latest 11 | name: A job to automate contribute list in readme.md 12 | steps: 13 | - name: Contribute List 14 | uses: akhilmhdh/contributors-readme-action@v2.3.10 15 | with: 16 | # Size of square images in the stack 17 | image_size: 100 18 | # Path of the readme file you want to update 19 | readme_path: "README.md " 20 | # To use github-id instead of profile name 21 | use_username: true 22 | # Number of columns in a row 23 | columns_per_row: 6 24 | # Type of collaborators options: all/direct/outside 25 | collaborators: all 26 | # Commit message of the GitHub action 27 | commit_message: "👥 自动更新 README 中的贡献者信息" 28 | # Username on commit 29 | committer_username: iTanken 30 | # email id of committer 31 | committer_email: zixizixi@vip.qq.com 32 | # 
check if branch is protected 33 | auto_detect_branch_protection: true 34 | env: 35 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 36 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: release 2 | 3 | on: 4 | push: 5 | tags: 6 | - "v*" 7 | 8 | permissions: 9 | contents: write 10 | 11 | jobs: 12 | releaser: 13 | name: release 14 | runs-on: ubuntu-latest 15 | steps: 16 | - name: Checkout 17 | uses: actions/checkout@v4 18 | with: 19 | fetch-depth: 0 20 | 21 | - name: Set up Go 22 | uses: actions/setup-go@v5 23 | with: 24 | go-version: 'stable' 25 | cache: false 26 | 27 | - name: Get Tag Name 28 | id: tag 29 | run: echo "tagName=${GITHUB_REF#refs/tags/v}" >> $GITHUB_OUTPUT 30 | 31 | - name: Generate Changelog 32 | id: gen_changelog 33 | run: | 34 | config=./_dev/config/chglog.config.yml 35 | tagName=v${{ steps.tag.outputs.tagName }} 36 | go run github.com/git-chglog/git-chglog/cmd/git-chglog@latest --config $config $tagName 37 | go run github.com/git-chglog/git-chglog/cmd/git-chglog@latest --config $config $tagName > changelog.md 38 | 39 | - name: Create Release 40 | uses: ncipollo/release-action@v1 41 | with: 42 | bodyFile: changelog.md 43 | token: ${{ secrets.GITHUB_TOKEN }} 44 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | # This workflow will build a golang project 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go 3 | 4 | name: tests 5 | 6 | on: 7 | push: 8 | branches: [ "main" ] 9 | pull_request: 10 | branches: [ "main" ] 11 | 12 | jobs: 13 | 14 | build: 15 | runs-on: ubuntu-latest 16 | timeout-minutes: 30 17 | 18 | services: 19 | oracle: 20 | image: truevoly/oracle-12c 21 | env: 22 | TZ: 
Asia/Shanghai 23 | WEB_CONSOLE: false 24 | ports: 25 | - 30256:1521 26 | # volumes: 27 | # - /data/oracle_test:/u01/app/oracle 28 | options: >- 29 | --restart=always 30 | --privileged=true 31 | 32 | steps: 33 | - uses: actions/checkout@v4 34 | 35 | - name: Set up Go 36 | uses: actions/setup-go@v5 37 | with: 38 | go-version: 'stable' 39 | cache: false 40 | 41 | - name: Tidy 42 | run: go mod tidy 43 | 44 | - name: Build 45 | run: go build -v ./... 46 | 47 | - name: Check oracle port 48 | run: | 49 | if ss -tln | grep -q ":30256 "; then 50 | echo "oracle 服务端口号正常!" 51 | else 52 | echo "oracle 服务端口号异常!" 53 | fi 54 | go run github.com/cloverstd/tcping@latest 127.0.0.1:30256 55 | 56 | - name: Test 57 | env: 58 | GORM_ORA_DSN: "oracle://system:oracle@localhost:30256/xe?LANGUAGE=SIMPLIFIED+CHINESE&TERRITORY=CHINA" 59 | GORM_ORA_WAIT_MIN: 5 # wait for the database initialization to complete 60 | run: go test -timeout 20m -v ./... 61 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | vendor/ 3 | go.sum 4 | CHANGELOG.md 5 | /test_local/ 6 | /go.work* 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 iTanken 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial 
portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GORM Oracle Driver 2 | 3 | ## Description 4 | 5 | A GORM driver for connecting to and managing Oracle databases, based on [CengSin/oracle](https://github.com/CengSin/oracle) 6 | and [sijms/go-ora](https://github.com/sijms/go-ora) (a pure Go Oracle client). *Not recommended for use in a production environment.*
7 | 8 | ## Required dependency Install 9 | 10 | - Oracle `11g` + (*`v1.6.3` and earlier versions support only `12c` +*) 11 | - Golang 12 | - `v1.6.1`: `go1.16` + 13 | - `v1.6.2`: `go1.18` + 14 | - gorm `1.24.0` + 15 | 16 | ## Quick Start 17 | 18 | ### How to install 19 | 20 | ```bash 21 | go get -d github.com/godoes/gorm-oracle 22 | ``` 23 | 24 | ### Usage 25 | 26 | ```go 27 | package main 28 | 29 | import ( 30 | oracle "github.com/godoes/gorm-oracle" 31 | "gorm.io/gorm" 32 | ) 33 | 34 | func main() { 35 | options := map[string]string{ 36 | "CONNECTION TIMEOUT": "90", 37 | "LANGUAGE": "SIMPLIFIED CHINESE", 38 | "TERRITORY": "CHINA", 39 | "SSL": "false", 40 | } 41 | // oracle://user:password@127.0.0.1:1521/service 42 | url := oracle.BuildUrl("127.0.0.1", "1521", "service", "user", "password", options) 43 | dialector := oracle.New(oracle.Config{ 44 | DSN: url, 45 | IgnoreCase: false, // query conditions are not case-sensitive 46 | NamingCaseSensitive: true, // whether naming is case-sensitive 47 | VarcharSizeIsCharLength: true, // whether VARCHAR type size is character length, defaulting to byte length 48 | 49 | // RowNumberAliasForOracle11 is the alias for ROW_NUMBER() in Oracle 11g, defaulting to ROW_NUM 50 | RowNumberAliasForOracle11: "ROW_NUM", 51 | }) 52 | db, err := gorm.Open(dialector, &gorm.Config{ 53 | SkipDefaultTransaction: true, // 是否禁用默认在事务中执行单次创建、更新、删除操作 54 | DisableForeignKeyConstraintWhenMigrating: true, // 是否禁止在自动迁移或创建表时自动创建外键约束 55 | // 自定义命名策略 56 | NamingStrategy: schema.NamingStrategy{ 57 | NoLowerCase: true, // 是否不自动转换小写表名 58 | IdentifierMaxLength: 30, // Oracle: 30, PostgreSQL:63, MySQL: 64, SQL Server、SQLite、DM: 128 59 | }, 60 | PrepareStmt: false, // 创建并缓存预编译语句,启用后可能会报 ORA-01002 错误 61 | CreateBatchSize: 50, // 插入数据默认批处理大小 62 | }) 63 | if err != nil { 64 | // panic error or log error info 65 | } 66 | 67 | // set session parameters 68 | if sqlDB, err := db.DB(); err == nil { 69 | _, _ = oracle.AddSessionParams(sqlDB, map[string]string{ 70 | 
"TIME_ZONE": "+08:00", // ALTER SESSION SET TIME_ZONE = '+08:00'; 71 | "NLS_DATE_FORMAT": "YYYY-MM-DD", // ALTER SESSION SET NLS_DATE_FORMAT = 'YYYY-MM-DD'; 72 | "NLS_TIME_FORMAT": "HH24:MI:SSXFF", // ALTER SESSION SET NLS_TIME_FORMAT = 'HH24:MI:SS.FF3'; 73 | "NLS_TIMESTAMP_FORMAT": "YYYY-MM-DD HH24:MI:SSXFF", // ALTER SESSION SET NLS_TIMESTAMP_FORMAT = 'YYYY-MM-DD HH24:MI:SS.FF3'; 74 | "NLS_TIME_TZ_FORMAT": "HH24:MI:SS.FF TZR", // ALTER SESSION SET NLS_TIME_TZ_FORMAT = 'HH24:MI:SS.FF3 TZR'; 75 | "NLS_TIMESTAMP_TZ_FORMAT": "YYYY-MM-DD HH24:MI:SSXFF TZR", // ALTER SESSION SET NLS_TIMESTAMP_TZ_FORMAT = 'YYYY-MM-DD HH24:MI:SS.FF3 TZR'; 76 | }) 77 | } 78 | 79 | // do somethings 80 | } 81 | 82 | ``` 83 | 84 | ## Questions 85 | 86 | 87 |
88 | ORA-01000: 超出打开游标的最大数 89 | 90 | > ORA-00604: 递归 SQL 级别 1 出现错误 91 | > 92 | > ORA-01000: 超出打开游标的最大数 93 | 94 | ```shell 95 | show parameter OPEN_CURSORS; 96 | ``` 97 | 98 | ```sql 99 | alter system set OPEN_CURSORS = 1000; -- or bigger 100 | commit; 101 | ``` 102 | 103 |
104 | 105 |
106 | ORA-01002: 提取违反顺序 107 | 108 | > 如果重复执行同一查询,第一次查询成功,第二次报 `ORA-01002` 错误,可能是因为启用了 `PrepareStmt`,关闭此配置即可。 109 | 110 | 推荐配置: 111 | 112 | ```go 113 | &gorm.Config{ 114 | SkipDefaultTransaction: true, // 是否禁用默认在事务中执行单次创建、更新、删除操作 115 | DisableForeignKeyConstraintWhenMigrating: true, // 是否禁止在自动迁移或创建表时自动创建外键约束 116 | // 自定义命名策略 117 | NamingStrategy: schema.NamingStrategy{ 118 | NoLowerCase: true, // 是否不自动转换小写表名 119 | IdentifierMaxLength: 30, // Oracle: 30, PostgreSQL:63, MySQL: 64, SQL Server、SQLite、DM: 128 120 | }, 121 | PrepareStmt: false, // 创建并缓存预编译语句,启用后可能会报 ORA-01002 错误 122 | CreateBatchSize: 50, // 插入数据默认批处理大小 123 | } 124 | ``` 125 | 126 |
127 | 128 | ## Contributors 129 | 130 | 131 | 132 | 133 | 134 | 141 | 148 | 155 | 162 | 169 | 176 | 177 | 178 | 185 | 192 | 193 | 194 |
135 | 136 | iTanken 137 |
138 | iTanken 139 |
140 |
142 | 143 | stevefan1999-personal 144 |
145 | stevefan1999-personal 146 |
147 |
149 | 150 | CengSin 151 |
152 | CengSin 153 |
154 |
156 | 157 | jinzhu 158 |
159 | jinzhu 160 |
161 |
163 | 164 | dzwvip 165 |
166 | dzwvip 167 |
168 |
170 | 171 | miclle 172 |
173 | miclle 174 |
175 |
179 | 180 | dk333 181 |
182 | dk333 183 |
184 |
186 | 187 | cloorc 188 |
189 | cloorc 190 |
191 |
195 | 196 | 197 | ## LICENSE 198 | 199 | [MIT license](./LICENSE) 200 | 201 | - Copyright (c) 2020 [Jinzhu](https://github.com/jinzhu) 202 | - Copyright (c) 2020 [Steve Fan](https://github.com/stevefan1999-personal) 203 | - Copyright (c) 2020 [CengSin](https://github.com/CengSin) 204 | - Copyright (c) 2022 [dzwvip](https://github.com/dzwvip) 205 | - Copyright (c) 2022-present [iTanken](https://github.com/iTanken) 206 | -------------------------------------------------------------------------------- /_dev/config/CHANGELOG.tpl.md: -------------------------------------------------------------------------------- 1 | {{ if .Versions -}} 2 | {{ if .Unreleased.CommitGroups -}} 3 | 4 | ## ⭐ [最新变更]({{ .Info.RepositoryURL }}/compare/{{ $latest := index .Versions 0 }}{{ $latest.Tag.Name }}...main) 5 | 6 | {{ range .Unreleased.CommitGroups -}} 7 | ### {{ .RawTitle }} {{ .Title }} 8 | 9 | {{ range .Commits -}} 10 | {{/* SKIPPING RULES - START */ -}} 11 | {{- if not (contains .Subject " CHANGELOG") -}} 12 | {{- if not (contains .Subject "[ci skip]") -}} 13 | {{- if not (contains .Subject "[skip ci]") -}} 14 | {{- if not (hasPrefix .Subject "Merge pull request ") -}} 15 | {{- if not (hasPrefix .Subject "Merge remote-tracking ") -}} 16 | {{- /* SKIPPING RULES - END */ -}} 17 | - [{{ if .Type }}`{{ .Type }}`{{ end }}{{ .Subject }}]({{ $.Info.RepositoryURL }}/commit/{{ .Hash.Short }}) - `{{ datetime "2006-01-02 15:04" .Committer.Date }}` 18 | {{- if .TrimmedBody }} 19 |
20 | 21 | {{ indent .TrimmedBody 2 }} 22 |
23 | 24 | {{ end -}} 25 | {{/* SKIPPING RULES - START */ -}} 26 | {{ end -}} 27 | {{ end -}} 28 | {{ end -}} 29 | {{ end -}} 30 | {{ end -}} 31 | {{/* SKIPPING RULES - END */ -}} 32 | {{ end -}} 33 | {{ end -}} 34 | {{ else }} 35 | {{- range .Unreleased.Commits -}} 36 | {{/* SKIPPING RULES - START */ -}} 37 | {{- if not (contains .Subject " CHANGELOG") -}} 38 | {{- if not (contains .Subject "[ci skip]") -}} 39 | {{- if not (contains .Subject "[skip ci]") -}} 40 | {{- if not (hasPrefix .Subject "Merge pull request ") -}} 41 | {{- if not (hasPrefix .Subject "Merge remote-tracking ") -}} 42 | {{- /* SKIPPING RULES - END */ -}} 43 | - [{{ if .Type }}`{{ .Type }}`{{ end }}{{ .Subject }}]({{ $.Info.RepositoryURL }}/commit/{{ .Hash.Short }}) 44 | {{- if .TrimmedBody }} 45 |
46 | 47 | {{ indent .TrimmedBody 2 }} 48 |
49 | 50 | {{ end -}} 51 | {{/* SKIPPING RULES - START */ -}} 52 | {{ end -}} 53 | {{ end -}} 54 | {{ end -}} 55 | {{ end -}} 56 | {{ end -}} 57 | {{/* SKIPPING RULES - END */ -}} 58 | {{ end -}} 59 | {{ end -}} 60 | {{ end -}} 61 | 62 | {{ range .Versions -}} 63 | ## 🔖 {{ if .Tag.Previous -}} 64 | [`{{ .Tag.Name }}`]({{ $.Info.RepositoryURL }}/compare/{{ .Tag.Previous.Name }}...{{ .Tag.Name }}) 65 | {{- else }}`{{ .Tag.Name }}`{{ end }} - `{{ datetime "2006-01-02" .Tag.Date }}` 66 | {{ if .CommitGroups -}} 67 | {{ range .CommitGroups }} 68 | ### {{ .RawTitle }} {{ .Title }} 69 | 70 | {{ range .Commits -}} 71 | {{/* SKIPPING RULES - START */ -}} 72 | {{- if not (contains .Subject " CHANGELOG") -}} 73 | {{- if not (contains .Subject "[ci skip]") -}} 74 | {{- if not (contains .Subject "[skip ci]") -}} 75 | {{- if not (hasPrefix .Subject "Merge pull request ") -}} 76 | {{- if not (hasPrefix .Subject "Merge remote-tracking ") -}} 77 | {{- /* SKIPPING RULES - END */ -}} 78 | - [{{ if .Type }}`{{ .Type }}`{{ end }}{{ .Subject }}]({{ $.Info.RepositoryURL }}/commit/{{ .Hash.Short }}) 79 | {{- if .TrimmedBody }} 80 |
81 | 82 | {{ indent .TrimmedBody 2 }} 83 |
84 | {{- end }} 85 | {{/* SKIPPING RULES - START */ -}} 86 | {{ end -}} 87 | {{ end -}} 88 | {{ end -}} 89 | {{ end -}} 90 | {{ end -}} 91 | {{/* SKIPPING RULES - END */ -}} 92 | {{ end -}} 93 | {{ end -}} 94 | {{ else }}{{ range .Commits -}} 95 | {{/* SKIPPING RULES - START */ -}} 96 | {{- if not (contains .Subject " CHANGELOG") -}} 97 | {{- if not (contains .Subject "[ci skip]") -}} 98 | {{- if not (contains .Subject "[skip ci]") -}} 99 | {{- if not (hasPrefix .Subject "Merge pull request ") -}} 100 | {{- if not (hasPrefix .Subject "Merge remote-tracking ") }} 101 | {{/* SKIPPING RULES - END */ -}} 102 | - [{{ if .Type }}`{{ .Type }}`{{ end }}{{ .Subject }}]({{ $.Info.RepositoryURL }}/commit/{{ .Hash.Short }}) 103 | {{- if .TrimmedBody }} 104 |
105 | 106 | {{ indent .TrimmedBody 2 }} 107 |
108 | 109 | {{ end -}} 110 | {{/* SKIPPING RULES - START */ -}} 111 | {{ end -}} 112 | {{ end -}} 113 | {{ end -}} 114 | {{ end -}} 115 | {{ end -}} 116 | {{/* SKIPPING RULES - END */ -}} 117 | {{ end -}} 118 | {{ end -}} 119 | {{- if .NoteGroups -}} 120 | {{ range .NoteGroups -}} 121 | 122 | ### {{ .Title }} 123 | 124 | {{ range .Notes }} 125 | {{ .Body }} 126 | {{ end }} 127 | {{ end -}} 128 | {{ end -}} 129 | {{ end -}} 130 | -------------------------------------------------------------------------------- /_dev/config/chglog.config.yml: -------------------------------------------------------------------------------- 1 | style: gitlab 2 | template: CHANGELOG.tpl.md 3 | info: 4 | title: 🕔 变更记录 5 | repository_url: https://github.com/godoes/gorm-oracle 6 | options: 7 | tag_filter_pattern: '^v' 8 | sort: "date" 9 | commits: 10 | filters: 11 | Type: 12 | - 🎉 13 | - ✨ 14 | - 🚩 15 | - 🔖 16 | - 🩹 17 | - 🐛 18 | - 🚑️ 19 | - 🔒️ 20 | - 🔐 21 | - 🚨 22 | - ♻️ 23 | - 👔 24 | - 🍱 25 | - 🔨 26 | - 🔧 27 | - 🙈 28 | - 💡 29 | - 📝 30 | - ⬆️ 31 | - ⬇️ 32 | - 🚚 33 | - 🏗️ 34 | - 👽️ 35 | - 💥 36 | - ⚡️ 37 | - 🚸 38 | - 🔊 39 | - 🔇 40 | - 💄 41 | - 🎨 42 | - 🧱 43 | - 👷 44 | - 💚 45 | - 🚀 46 | - 🧵 47 | - 🛂 48 | - 🦺 49 | - 🔥 50 | - 🏁 51 | - 🐧 52 | - 🗃️ 53 | - 👥 54 | - 📄 55 | - 🌐 56 | - ✅ 57 | - 🧪 58 | - "Merge" 59 | commit_groups: 60 | sort_by: RawTitle 61 | title_maps: 62 | 🎉: 初始化 63 | ✨: 新特性 64 | 🚩: 功能标识 65 | 🔖: 版本发布 66 | 🩹: 修改优化 67 | 🐛: 问题修复 68 | 🚑️: 紧急修复 69 | 🔒️: 安全修复 70 | 🔐: 加密相关 71 | 🚨: 修复警告 72 | ♻️: 代码重构 73 | 👔: 业务逻辑 74 | 🍱: 资源文件 75 | 🔨: 开发配置 76 | 🔧: 程序配置 77 | 🙈: 忽略配置 78 | 💡: 注释文档 79 | 📝: 说明文档 80 | ⬆️: 依赖升级 81 | ⬇️: 依赖降级 82 | 🚚: 移动文件 83 | 🏗️: 架构变更 84 | 👽️: 外部变更 85 | 💥: 突破变更 86 | ⚡️: 性能优化 87 | 🚸: 用户体验 88 | 🔊: 更新日志 89 | 🔇: 移除日志 90 | 💄: 界面样式 91 | 🎨: 代码格式 92 | 🧱: 基础设施 93 | 👷: 持续集成 94 | 💚: 修复构建 95 | 🚀: 部署 96 | 🧵: 并发 97 | 🛂: 权限 98 | 🦺: 验证 99 | 🔥: 删除 100 | 🏁: Windows 101 | 🐧: Unix 102 | 🗃️: 数据库 103 | 👥: 贡献者 104 | 📄: 许可证 105 | 🌐: 国际化 106 | ✅: 通过的测试 107 | 🧪: 失败的测试 108 | "Merge": 合并 109 | header: 110 | 
# pattern: "^(.*)$" 111 | # https://regex101.com/r/wEipOM/1 112 | # https://emojipedia.org/emoji/%E2%9C%A8/ 113 | pattern: "^(Merge|[\\x{1F300}-\\x{1F5FF}\\x{1F600}-\\x{1F64F}\\x{1F680}-\\x{1F6FF}\\x{1F700}-\\x{1F77F}\\x{1F780}-\\x{1F7FF}\\x{1F800}-\\x{1F8FF}\\x{1F900}-\\x{1F9FF}\\x{1FA00}-\\x{1FA6F}\\x{1FA70}-\\x{1FAFF}\\x{2600}-\\x{26FF}\\x{2700}-\\x{27BF}]|[\\x{2B05}\\x{2B06}\\x{2B07}\\x{23EA}\\x{23EB}])(.*)$" 114 | pattern_maps: 115 | - Type 116 | - Subject 117 | notes: 118 | keywords: 119 | - 💥 120 | -------------------------------------------------------------------------------- /_dev/goland/go-mod.run.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 14 | -------------------------------------------------------------------------------- /_dev/goland/go-staticcheck.run.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 14 | -------------------------------------------------------------------------------- /_dev/goland/go-vulncheck.run.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 14 | -------------------------------------------------------------------------------- /_dev/script/done-time-pause.bat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/godoes/gorm-oracle/bb5f5c6dc7f5d2c17c42608fc33182ef20a53e38/_dev/script/done-time-pause.bat -------------------------------------------------------------------------------- /_dev/script/go-mod.bat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/godoes/gorm-oracle/bb5f5c6dc7f5d2c17c42608fc33182ef20a53e38/_dev/script/go-mod.bat -------------------------------------------------------------------------------- /_dev/script/go-staticcheck.bat: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/godoes/gorm-oracle/bb5f5c6dc7f5d2c17c42608fc33182ef20a53e38/_dev/script/go-staticcheck.bat -------------------------------------------------------------------------------- /_dev/script/go-vulncheck.bat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/godoes/gorm-oracle/bb5f5c6dc7f5d2c17c42608fc33182ef20a53e38/_dev/script/go-vulncheck.bat -------------------------------------------------------------------------------- /create.go: -------------------------------------------------------------------------------- 1 | package oracle 2 | 3 | import ( 4 | "reflect" 5 | 6 | "github.com/sijms/go-ora/v2" 7 | "gorm.io/gorm" 8 | "gorm.io/gorm/callbacks" 9 | "gorm.io/gorm/clause" 10 | ) 11 | 12 | func Create(db *gorm.DB) { 13 | if db.Error != nil || db.Statement == nil { 14 | return 15 | } 16 | 17 | stmt := db.Statement 18 | stmtSchema := stmt.Schema 19 | if stmtSchema != nil && !stmt.Unscoped { 20 | for _, c := range stmtSchema.CreateClauses { 21 | stmt.AddClause(c) 22 | } 23 | } 24 | 25 | if stmt.SQL.Len() == 0 { 26 | var ( 27 | createValues = callbacks.ConvertToCreateValues(stmt) 28 | onConflict, hasConflict = stmt.Clauses["ON CONFLICT"].Expression.(clause.OnConflict) 29 | ) 30 | 31 | if hasConflict { 32 | if stmtSchema != nil && len(stmtSchema.PrimaryFields) > 0 { 33 | columnsMap := map[string]bool{} 34 | for _, column := range createValues.Columns { 35 | columnsMap[column.Name] = true 36 | } 37 | 38 | for _, field := range stmtSchema.PrimaryFields { 39 | if _, ok := columnsMap[field.DBName]; !ok { 40 | hasConflict = false 41 | } 42 | } 43 | } else { 44 | hasConflict = false 45 | } 46 | } 47 | 48 | if hasConflict { 49 | MergeCreate(db, onConflict, createValues) 50 | } else { 51 | stmt.AddClauseIfNotExists(clause.Insert{}) 52 | stmt.AddClause(clause.Values{Columns: createValues.Columns, Values: [][]interface{}{createValues.Values[0]}}) 53 | 54 | stmt.Build("INSERT", "VALUES") 
55 | _ = outputInserted(db) 56 | } 57 | 58 | if !db.DryRun && db.Error == nil { 59 | if hasConflict { 60 | for i, val := range stmt.Vars { 61 | // HACK: replace values one by one, assuming its value layout will be the same all the time, i.e. aligned 62 | stmt.Vars[i] = convertValue(val) 63 | } 64 | 65 | result, err := stmt.ConnPool.ExecContext(stmt.Context, stmt.SQL.String(), stmt.Vars...) 66 | if db.AddError(err) == nil { 67 | db.RowsAffected, _ = result.RowsAffected() 68 | // TODO: get merged returning 69 | } 70 | } else { 71 | for idx, values := range createValues.Values { 72 | for i, val := range values { 73 | // HACK: replace values one by one, assuming its value layout will be the same all the time, i.e. aligned 74 | stmt.Vars[i] = convertValue(val) 75 | } 76 | 77 | result, err := stmt.ConnPool.ExecContext(stmt.Context, stmt.SQL.String(), stmt.Vars...) 78 | if db.AddError(err) == nil { 79 | rowsAffected, _ := result.RowsAffected() 80 | db.RowsAffected += rowsAffected 81 | 82 | if stmtSchema != nil && len(stmtSchema.FieldsWithDefaultDBValue) > 0 { 83 | getDefaultValues(db, idx) 84 | } 85 | } 86 | } 87 | } 88 | } 89 | } 90 | } 91 | 92 | func outputInserted(db *gorm.DB) (lenDefaultValue int) { 93 | stmtSchema := db.Statement.Schema 94 | if stmtSchema == nil { 95 | return 96 | } 97 | lenDefaultValue = len(stmtSchema.FieldsWithDefaultDBValue) 98 | if lenDefaultValue > 0 { 99 | columns := make([]clause.Column, lenDefaultValue) 100 | for idx, field := range stmtSchema.FieldsWithDefaultDBValue { 101 | columns[idx] = clause.Column{Name: field.DBName} 102 | } 103 | db.Statement.AddClauseIfNotExists(clause.Returning{Columns: columns}) 104 | } 105 | db.Statement.Build("RETURNING") 106 | 107 | _, _ = db.Statement.WriteString(" INTO ") 108 | for idx, field := range stmtSchema.FieldsWithDefaultDBValue { 109 | if idx > 0 { 110 | _ = db.Statement.WriteByte(',') 111 | } 112 | 113 | outVar := go_ora.Out{Dest: reflect.New(field.FieldType).Interface()} 114 | if field.Size > 0 { 
115 | outVar.Size = field.Size 116 | } 117 | db.Statement.AddVar(db.Statement, outVar) 118 | } 119 | _, _ = db.Statement.WriteString(" /*-go_ora.Out{}-*/") 120 | return 121 | } 122 | 123 | func MergeCreate(db *gorm.DB, onConflict clause.OnConflict, values clause.Values) { 124 | dummyTable := getDummyTable(db) 125 | 126 | _, _ = db.Statement.WriteString("MERGE INTO ") 127 | db.Statement.WriteQuoted(db.Statement.Table) 128 | _, _ = db.Statement.WriteString(" USING (") 129 | 130 | for idx, value := range values.Values { 131 | if idx > 0 { 132 | _, _ = db.Statement.WriteString(" UNION ALL ") 133 | } 134 | 135 | _, _ = db.Statement.WriteString("SELECT ") 136 | for i, v := range value { 137 | if i > 0 { 138 | _ = db.Statement.WriteByte(',') 139 | } 140 | column := values.Columns[i] 141 | db.Statement.AddVar(db.Statement, v) 142 | _, _ = db.Statement.WriteString(" AS ") 143 | db.Statement.WriteQuoted(column.Name) 144 | } 145 | _, _ = db.Statement.WriteString(" FROM ") 146 | _, _ = db.Statement.WriteString(dummyTable) 147 | } 148 | 149 | _, _ = db.Statement.WriteString(`) `) 150 | db.Statement.WriteQuoted("excluded") 151 | _, _ = db.Statement.WriteString(" ON (") 152 | 153 | var where clause.Where 154 | for _, field := range db.Statement.Schema.PrimaryFields { 155 | where.Exprs = append(where.Exprs, clause.Eq{ 156 | Column: clause.Column{Table: db.Statement.Table, Name: field.DBName}, 157 | Value: clause.Column{Table: "excluded", Name: field.DBName}, 158 | }) 159 | } 160 | where.Build(db.Statement) 161 | _ = db.Statement.WriteByte(')') 162 | 163 | if len(onConflict.DoUpdates) > 0 { 164 | _, _ = db.Statement.WriteString(" WHEN MATCHED THEN UPDATE SET ") 165 | onConflict.DoUpdates.Build(db.Statement) 166 | } 167 | 168 | _, _ = db.Statement.WriteString(" WHEN NOT MATCHED THEN INSERT (") 169 | 170 | written := false 171 | for _, column := range values.Columns { 172 | if db.Statement.Schema.PrioritizedPrimaryField == nil || 
!db.Statement.Schema.PrioritizedPrimaryField.AutoIncrement || db.Statement.Schema.PrioritizedPrimaryField.DBName != column.Name { 173 | if written { 174 | _ = db.Statement.WriteByte(',') 175 | } 176 | written = true 177 | db.Statement.WriteQuoted(column.Name) 178 | } 179 | } 180 | 181 | _, _ = db.Statement.WriteString(") VALUES (") 182 | 183 | written = false 184 | for _, column := range values.Columns { 185 | if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.AutoIncrement || db.Statement.Schema.PrioritizedPrimaryField.DBName != column.Name { 186 | if written { 187 | _ = db.Statement.WriteByte(',') 188 | } 189 | written = true 190 | db.Statement.WriteQuoted(clause.Column{ 191 | Table: "excluded", 192 | Name: column.Name, 193 | }) 194 | } 195 | } 196 | _, _ = db.Statement.WriteString(")") 197 | } 198 | 199 | func convertValue(val interface{}) interface{} { 200 | val = ptrDereference(val) 201 | switch v := val.(type) { 202 | case bool: 203 | if v { 204 | val = 1 205 | } else { 206 | val = 0 207 | } 208 | case string: 209 | if len(v) > 2000 { 210 | val = go_ora.Clob{String: v, Valid: true} 211 | } 212 | default: 213 | val = convertCustomType(val) 214 | } 215 | return val 216 | } 217 | 218 | func getDummyTable(db *gorm.DB) (dummyTable string) { 219 | switch d := ptrDereference(db.Dialector).(type) { 220 | case Dialector: 221 | dummyTable = d.DummyTableName() 222 | default: 223 | dummyTable = "DUAL" 224 | } 225 | return 226 | } 227 | 228 | func getDefaultValues(db *gorm.DB, idx int) { 229 | if db.Statement.Schema == nil || len(db.Statement.Schema.FieldsWithDefaultDBValue) == 0 { 230 | return 231 | } 232 | insertTo := db.Statement.ReflectValue 233 | switch insertTo.Kind() { 234 | case reflect.Slice, reflect.Array: 235 | insertTo = insertTo.Index(idx) 236 | default: 237 | } 238 | if insertTo.Kind() == reflect.Pointer { 239 | insertTo = insertTo.Elem() 240 | } 241 | 242 | for _, val := range db.Statement.Vars { 243 | 
switch v := val.(type) { 244 | case go_ora.Out: 245 | switch insertTo.Kind() { 246 | case reflect.Slice, reflect.Array: 247 | for i := insertTo.Len() - 1; i >= 0; i-- { 248 | rv := insertTo.Index(i) 249 | switch reflect.Indirect(rv).Kind() { 250 | case reflect.Struct: 251 | setStructFieldValue(db, rv, v) 252 | default: 253 | } 254 | } 255 | case reflect.Struct: 256 | setStructFieldValue(db, insertTo, v) 257 | default: 258 | } 259 | default: 260 | } 261 | } 262 | } 263 | 264 | func setStructFieldValue(db *gorm.DB, insertTo reflect.Value, out go_ora.Out) { 265 | if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, insertTo); !isZero { 266 | return 267 | } 268 | _ = db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, insertTo, out.Dest)) 269 | } 270 | -------------------------------------------------------------------------------- /create_test.go: -------------------------------------------------------------------------------- 1 | package oracle 2 | 3 | import ( 4 | "encoding/json" 5 | "strings" 6 | "testing" 7 | "time" 8 | ) 9 | 10 | func TestMergeCreate(t *testing.T) { 11 | db, err := dbNamingCase, dbErrors[0] 12 | if err != nil { 13 | t.Fatal(err) 14 | } 15 | if db == nil { 16 | t.Log("db is nil!") 17 | return 18 | } 19 | 20 | model := TestTableUser{} 21 | migrator := db.Set("gorm:table_comments", "用户信息表").Migrator() 22 | if migrator.HasTable(model) { 23 | if err = migrator.DropTable(model); err != nil { 24 | t.Fatalf("DropTable() error = %v", err) 25 | } 26 | } 27 | if err = migrator.AutoMigrate(model); err != nil { 28 | t.Fatalf("AutoMigrate() error = %v", err) 29 | } else { 30 | t.Log("AutoMigrate() success!") 31 | } 32 | 33 | data := []TestTableUser{ 34 | { 35 | UID: "U1", 36 | Name: "Lisa", 37 | Account: "lisa", 38 | Password: "H6aLDNr", 39 | PhoneNumber: "+8616666666666", 40 | Sex: "0", 41 | UserType: 1, 42 | Enabled: true, 43 | }, 44 | { 45 | UID: "U1", 46 | Name: "Lisa", 47 | Account: 
"lisa", 48 | Password: "H6aLDNr", 49 | PhoneNumber: "+8616666666666", 50 | Sex: "0", 51 | UserType: 1, 52 | Enabled: true, 53 | }, 54 | { 55 | UID: "U2", 56 | Name: "Daniela", 57 | Account: "daniela", 58 | Password: "Si7l1sRIC79", 59 | PhoneNumber: "+8619999999999", 60 | Sex: "1", 61 | UserType: 1, 62 | Enabled: true, 63 | }, 64 | } 65 | t.Run("MergeCreate", func(t *testing.T) { 66 | tx := db.Create(&data) 67 | if err = tx.Error; err != nil { 68 | t.Fatal(err) 69 | } 70 | dataJsonBytes, _ := json.MarshalIndent(data, "", " ") 71 | t.Logf("result: %s", dataJsonBytes) 72 | }) 73 | } 74 | 75 | type TestTableUserUnique struct { 76 | ID uint64 `gorm:"column:id;size:64;not null;autoIncrement:true;autoIncrementIncrement:1;primaryKey;comment:自增 ID" json:"id"` 77 | UID string `gorm:"column:uid;type:varchar(50);comment:用户身份标识;unique" json:"uid"` 78 | Name string `gorm:"column:name;size:50;comment:用户姓名" json:"name"` 79 | Account string `gorm:"column:account;type:varchar(50);comment:登录账号" json:"account"` 80 | Password string `gorm:"column:password;type:varchar(512);comment:登录密码(密文)" json:"password"` 81 | Email string `gorm:"column:email;type:varchar(128);comment:邮箱地址" json:"email"` 82 | PhoneNumber string `gorm:"column:phone_number;type:varchar(15);comment:E.164" json:"phoneNumber"` 83 | Sex string `gorm:"column:sex;type:char(1);comment:性别" json:"sex"` 84 | Birthday *time.Time `gorm:"column:birthday;->:false;<-:create;comment:生日" json:"birthday,omitempty"` 85 | UserType int `gorm:"column:user_type;size:8;comment:用户类型" json:"userType"` 86 | Enabled bool `gorm:"column:enabled;comment:是否可用" json:"enabled"` 87 | Remark string `gorm:"column:remark;size:1024;comment:备注信息" json:"remark"` 88 | } 89 | 90 | func (TestTableUserUnique) TableName() string { 91 | return "test_user_unique" 92 | } 93 | 94 | func TestMergeCreateUnique(t *testing.T) { 95 | db, err := dbNamingCase, dbErrors[0] 96 | if err != nil { 97 | t.Fatal(err) 98 | } 99 | if db == nil { 100 | t.Log("db is nil!") 101 | return 
102 | } 103 | 104 | model := TestTableUserUnique{} 105 | migrator := db.Set("gorm:table_comments", "用户信息表").Migrator() 106 | if migrator.HasTable(model) { 107 | if err = migrator.DropTable(model); err != nil { 108 | t.Fatalf("DropTable() error = %v", err) 109 | } 110 | } 111 | if err = migrator.AutoMigrate(model); err != nil { 112 | t.Fatalf("AutoMigrate() error = %v", err) 113 | } else { 114 | t.Log("AutoMigrate() success!") 115 | } 116 | 117 | data := []TestTableUserUnique{ 118 | { 119 | UID: "U1", 120 | Name: "Lisa", 121 | Account: "lisa", 122 | Password: "H6aLDNr", 123 | PhoneNumber: "+8616666666666", 124 | Sex: "0", 125 | UserType: 1, 126 | Enabled: true, 127 | }, 128 | { 129 | UID: "U2", 130 | Name: "Daniela", 131 | Account: "daniela", 132 | Password: "Si7l1sRIC79", 133 | PhoneNumber: "+8619999999999", 134 | Sex: "1", 135 | UserType: 1, 136 | Enabled: true, 137 | }, 138 | { 139 | UID: "U2", 140 | Name: "Daniela", 141 | Account: "daniela", 142 | Password: "Si7l1sRIC79", 143 | PhoneNumber: "+8619999999999", 144 | Sex: "1", 145 | UserType: 1, 146 | Enabled: true, 147 | }, 148 | } 149 | t.Run("MergeCreateUnique", func(t *testing.T) { 150 | tx := db.Create(&data) 151 | if err = tx.Error; err != nil { 152 | if strings.Contains(err.Error(), "ORA-00001") { 153 | t.Log(err) // ORA-00001: 违反唯一约束条件 154 | var gotData []TestTableUserUnique 155 | tx = db.Where(map[string]interface{}{"uid": []string{"U1", "U2"}}).Find(&gotData) 156 | if err = tx.Error; err != nil { 157 | t.Fatal(err) 158 | } else { 159 | if len(gotData) > 0 { 160 | t.Error("Unique constraint violation, but some data was inserted!") 161 | } else { 162 | t.Log("Unique constraint violation, rolled back!") 163 | } 164 | } 165 | } else { 166 | t.Fatal(err) 167 | } 168 | return 169 | } 170 | dataJsonBytes, _ := json.MarshalIndent(data, "", " ") 171 | t.Logf("result: %s", dataJsonBytes) 172 | }) 173 | } 174 | 175 | type testModelOra03146TTC struct { 176 | Id int64 
`gorm:"primaryKey;autoIncrement:false;column:SL_ID;type:uint;size:20;default:0;comment:id" json:"SL_ID"` 177 | ApiName string `gorm:"column:SL_API_NAME;type:VARCHAR2;size:100;default:null;comment:接口名称" json:"SL_API_NAME"` 178 | RawReceive string `gorm:"column:SL_RAW_RECEIVE_JSON;type:VARCHAR2;size:4000;default:null;comment:原始请求参数" json:"SL_RAW_RECEIVE_JSON"` 179 | RawSend string `gorm:"column:SL_RAW_SEND_JSON;type:VARCHAR2;size:4000;default:null;comment:原始响应参数" json:"SL_RAW_SEND_JSON"` 180 | DealReceive string `gorm:"column:SL_DEAL_RECEIVE_JSON;type:VARCHAR2;size:4000;default:null;comment:处理请求参数" json:"SL_DEAL_RECEIVE_JSON"` 181 | DealSend string `gorm:"column:SL_DEAL_SEND_JSON;type:VARCHAR2;size:4000;default:null;comment:处理响应参数" json:"SL_DEAL_SEND_JSON"` 182 | Code string `gorm:"column:SL_CODE;type:VARCHAR2;size:16;default:null;comment:http状态" json:"SL_CODE"` 183 | CreatedTime time.Time `gorm:"column:SL_CREATED_TIME;type:date;default:null;comment:创建时间" json:"SL_CREATED_TIME"` 184 | } 185 | 186 | func TestOra03146TTC(t *testing.T) { 187 | db, err := dbNamingCase, dbErrors[0] 188 | if err != nil { 189 | t.Fatal(err) 190 | } 191 | if db == nil { 192 | t.Log("db is nil!") 193 | return 194 | } 195 | 196 | model := testModelOra03146TTC{} 197 | migrator := db.Set("gorm:table_comments", "TTC 字段的缓冲区长度无效问题测试表").Migrator() 198 | if migrator.HasTable(model) { 199 | if err = migrator.DropTable(model); err != nil { 200 | t.Fatalf("DropTable() error = %v", err) 201 | } 202 | } 203 | if err = migrator.AutoMigrate(model); err != nil { 204 | t.Fatalf("AutoMigrate() error = %v", err) 205 | } else { 206 | t.Log("AutoMigrate() success!") 207 | } 208 | 209 | // INSERT INTO "T100_SCPTOAPI_LOG" ("SL_ID","SL_API_NAME","SL_RAW_RECEIVE_JSON","SL_RAW_SEND_JSON","SL_DEAL_RECEIVE_JSON","SL_DEAL_SEND_JSON","SL_CODE","SL_CREATED_TIME") 210 | // VALUES (9578529926701056,'/v1/t100/packingNum','11111','11111','11111','11111','111','2024-08-27 18:21:39.495') 211 | data := testModelOra03146TTC{ 212 | 
Id: 9578529926701056, 213 | ApiName: "/v1/t100/packingNum", 214 | RawReceive: "11111", 215 | RawSend: "11111", 216 | DealReceive: "11111", 217 | DealSend: "11111", 218 | Code: "111", 219 | CreatedTime: time.Now(), 220 | } 221 | result := db.Create(&data) 222 | if err = result.Error; err != nil { 223 | t.Fatalf("执行失败:%v", err) 224 | } 225 | t.Log("执行成功,影响行数:", result.RowsAffected) 226 | } 227 | 228 | func TestCreateInBatches(t *testing.T) { 229 | db, err := dbNamingCase, dbErrors[0] 230 | if err != nil { 231 | t.Fatal(err) 232 | } 233 | if db == nil { 234 | t.Log("db is nil!") 235 | return 236 | } 237 | 238 | model := TestTableUser{} 239 | migrator := db.Set("gorm:table_comments", "用户信息表").Migrator() 240 | if migrator.HasTable(model) { 241 | if err = migrator.DropTable(model); err != nil { 242 | t.Fatalf("DropTable() error = %v", err) 243 | } 244 | } 245 | if err = migrator.AutoMigrate(model); err != nil { 246 | t.Fatalf("AutoMigrate() error = %v", err) 247 | } else { 248 | t.Log("AutoMigrate() success!") 249 | } 250 | 251 | data := []TestTableUser{ 252 | {UID: "U1", Name: "Lisa", Account: "lisa", Password: "H6aLDNr", PhoneNumber: "+8616666666666", Sex: "0", UserType: 1, Enabled: true}, 253 | {UID: "U2", Name: "Daniela", Account: "daniela", Password: "Si7l1sRIC79", PhoneNumber: "+8619999999999", Sex: "1", UserType: 1, Enabled: true}, 254 | {UID: "U3", Name: "Tom", Account: "tom", Password: "********", PhoneNumber: "+8618888888888", Sex: "1", UserType: 1, Enabled: true}, 255 | {UID: "U4", Name: "James", Account: "james", Password: "********", PhoneNumber: "+8617777777777", Sex: "1", UserType: 2, Enabled: true}, 256 | {UID: "U5", Name: "John", Account: "john", Password: "********", PhoneNumber: "+8615555555555", Sex: "1", UserType: 1, Enabled: true}, 257 | } 258 | t.Run("CreateInBatches", func(t *testing.T) { 259 | tx := db.CreateInBatches(&data, 2) 260 | if err = tx.Error; err != nil { 261 | t.Fatal(err) 262 | } 263 | dataJsonBytes, _ := json.MarshalIndent(data, "", 
" ") 264 | t.Logf("result: %s", dataJsonBytes) 265 | }) 266 | } 267 | -------------------------------------------------------------------------------- /datatypes_json_map_test.go: -------------------------------------------------------------------------------- 1 | package oracle 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "database/sql/driver" 7 | "encoding/json" 8 | "errors" 9 | "fmt" 10 | "strings" 11 | 12 | "gorm.io/gorm" 13 | "gorm.io/gorm/clause" 14 | "gorm.io/gorm/schema" 15 | ) 16 | 17 | // JSONMap defined JSON data type, need to implement driver.Valuer, sql.Scanner interface 18 | type JSONMap map[string]interface{} 19 | 20 | // Value return json value, implement driver.Valuer interface 21 | // 22 | //goland:noinspection GoMixedReceiverTypes 23 | func (m JSONMap) Value() (driver.Value, error) { 24 | if m == nil { 25 | return nil, nil 26 | } 27 | ba, err := m.MarshalJSON() 28 | return string(ba), err 29 | } 30 | 31 | // Scan value into Jsonb, implements sql.Scanner interface 32 | // 33 | //goland:noinspection GoMixedReceiverTypes 34 | func (m *JSONMap) Scan(val interface{}) error { 35 | if val == nil { 36 | *m = make(JSONMap) 37 | return nil 38 | } 39 | var ba []byte 40 | switch v := val.(type) { 41 | case []byte: 42 | ba = v 43 | case string: 44 | ba = []byte(v) 45 | default: 46 | return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", val)) 47 | } 48 | t := map[string]interface{}{} 49 | rd := bytes.NewReader(ba) 50 | decoder := json.NewDecoder(rd) 51 | decoder.UseNumber() 52 | err := decoder.Decode(&t) 53 | *m = t 54 | return err 55 | } 56 | 57 | // MarshalJSON to output non base64 encoded []byte 58 | // 59 | //goland:noinspection GoMixedReceiverTypes 60 | func (m JSONMap) MarshalJSON() ([]byte, error) { 61 | if m == nil { 62 | return []byte("null"), nil 63 | } 64 | t := (map[string]interface{})(m) 65 | return json.Marshal(t) 66 | } 67 | 68 | // UnmarshalJSON to deserialize []byte 69 | // 70 | //goland:noinspection GoMixedReceiverTypes 71 | func 
(m *JSONMap) UnmarshalJSON(b []byte) error { 72 | t := map[string]interface{}{} 73 | err := json.Unmarshal(b, &t) 74 | *m = t 75 | return err 76 | } 77 | 78 | // GormDataType gorm common data type 79 | // 80 | //goland:noinspection GoMixedReceiverTypes 81 | func (m JSONMap) GormDataType() string { 82 | return "jsonmap" 83 | } 84 | 85 | // GormDBDataType gorm db data type 86 | // 87 | //goland:noinspection GoMixedReceiverTypes 88 | func (JSONMap) GormDBDataType(db *gorm.DB, field *schema.Field) string { 89 | switch db.Dialector.Name() { 90 | case "sqlite": 91 | return "JSON" 92 | case "mysql": 93 | return "JSON" 94 | case "postgres": 95 | return "JSONB" 96 | case "sqlserver": 97 | return "NVARCHAR(MAX)" 98 | case "oracle": 99 | //return "BLOB" 100 | // BLOB is only supported in Oracle databases version 12c r12.2.0.1.0 and above. 101 | // to support lower versions of Oracle databases, it is recommended to use CLOB. 102 | // see also: 103 | // https://stackoverflow.com/questions/43603905/oracle-12c-error-getting-while-create-blob-column-table-with-json-type 104 | return "CLOB" 105 | default: 106 | return getGormTypeFromTag(field) 107 | } 108 | } 109 | 110 | func getGormTypeFromTag(field *schema.Field) (dataType string) { 111 | if field != nil { 112 | if val, ok := field.TagSettings["TYPE"]; ok { 113 | dataType = strings.ToLower(val) 114 | } 115 | } 116 | return 117 | } 118 | 119 | //goland:noinspection GoMixedReceiverTypes 120 | func (m JSONMap) GormValue(_ context.Context, _ *gorm.DB) clause.Expr { 121 | data, _ := m.MarshalJSON() 122 | return gorm.Expr("?", string(data)) 123 | } 124 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/godoes/gorm-oracle 2 | 3 | go 1.18 4 | 5 | require ( 6 | github.com/emirpasic/gods v1.18.1 7 | github.com/sijms/go-ora/v2 v2.9.0 8 | gorm.io/gorm v1.30.0 9 | ) 10 | 11 | require ( 12 | 
github.com/jinzhu/inflection v1.0.0 // indirect 13 | github.com/jinzhu/now v1.1.5 // indirect 14 | golang.org/x/text v0.20.0 // indirect 15 | ) 16 | 17 | exclude ( 18 | github.com/sijms/go-ora/v2 v2.8.8 // ORA-03137: 来自客户机的格式错误的 TTC 包被拒绝: [opiexe: protocol violation] 19 | github.com/sijms/go-ora/v2 v2.8.9 // has bug 20 | ) 21 | 22 | retract ( 23 | v1.5.12 24 | v1.5.1 25 | v1.5.0 26 | ) 27 | -------------------------------------------------------------------------------- /migrator.go: -------------------------------------------------------------------------------- 1 | package oracle 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | "strconv" 7 | "strings" 8 | 9 | "gorm.io/gorm" 10 | "gorm.io/gorm/clause" 11 | "gorm.io/gorm/migrator" 12 | "gorm.io/gorm/schema" 13 | ) 14 | 15 | // Migrator implement gorm migrator interface 16 | type Migrator struct { 17 | migrator.Migrator 18 | } 19 | 20 | // AutoMigrate 自动迁移模型为表结构 21 | // 22 | // // 迁移并设置单个表注释 23 | // db.Set("gorm:table_comments", "用户信息表").AutoMigrate(&User{}) 24 | // 25 | // // 迁移并设置多个表注释 26 | // db.Set("gorm:table_comments", []string{"用户信息表", "公司信息表"}).AutoMigrate(&User{}, &Company{}) 27 | func (m Migrator) AutoMigrate(dst ...interface{}) error { 28 | if err := m.Migrator.AutoMigrate(dst...); err != nil { 29 | return err 30 | } 31 | // set table comment 32 | if tableComments, ok := m.DB.Get("gorm:table_comments"); ok { 33 | var comments []string 34 | switch c := tableComments.(type) { 35 | case string: 36 | comments = append(comments, c) 37 | case []string: 38 | comments = c 39 | default: 40 | return nil 41 | } 42 | for i := 0; i < len(dst) && i < len(comments); i++ { 43 | value := dst[i] 44 | tx := m.DB.Session(&gorm.Session{}) 45 | if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { 46 | return tx.Exec("COMMENT ON TABLE ? 
IS '?'", m.CurrentTable(stmt), GetStringExpr(comments[i])).Error 47 | }); err != nil { 48 | return err 49 | } 50 | } 51 | } 52 | return nil 53 | } 54 | 55 | // FullDataTypeOf returns field's db full data type 56 | func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { 57 | expr.SQL = m.DataTypeOf(field) 58 | 59 | if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") { 60 | if field.DefaultValueInterface != nil { 61 | defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}} 62 | m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface) 63 | expr.SQL += " DEFAULT " + m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface) 64 | } else if field.DefaultValue != "(-)" { 65 | expr.SQL += " DEFAULT " + field.DefaultValue 66 | } 67 | } 68 | 69 | if field.NotNull { 70 | expr.SQL += " NOT NULL" 71 | } 72 | 73 | // see https://github.com/go-gorm/gorm/pull/6822 74 | //if field.Unique { 75 | // expr.SQL += " UNIQUE" 76 | //} 77 | 78 | return 79 | } 80 | 81 | // CurrentDatabase returns current database name 82 | func (m Migrator) CurrentDatabase() (name string) { 83 | _ = m.DB.Raw( 84 | fmt.Sprintf(`SELECT ORA_DATABASE_NAME as "Current Database" FROM %s`, m.Dialector.(Dialector).DummyTableName()), 85 | ).Row().Scan(&name) 86 | return 87 | } 88 | 89 | // GetTypeAliases return database type aliases 90 | func (m Migrator) GetTypeAliases(databaseTypeName string) (types []string) { 91 | switch databaseTypeName { 92 | case "blob", "raw", "longraw", "ocibloblocator", "ocifilelocator": 93 | types = append(types, "blob", "raw", "longraw", "ocibloblocator", "ocifilelocator") 94 | case "clob", "nclob", "longvarchar", "ocicloblocator": 95 | types = append(types, "clob", "nclob", "longvarchar", "ocicloblocator") 96 | case "char", "nchar", "varchar", "varchar2", "nvarchar2": 97 | types = append(types, "char", "nchar", "varchar", "varchar2", "nvarchar2") 98 | case "number", 
"integer", "smallint": 99 | types = append(types, "number", "integer", "smallint") 100 | case "decimal", "numeric", "ibfloat", "ibdouble": 101 | types = append(types, "decimal", "numeric", "ibfloat", "ibdouble") 102 | case "timestampdty", "timestamp", "date": 103 | types = append(types, "timestampdty", "timestamp", "date") 104 | case "timestamptz_dty", "timestamp with time zone": 105 | types = append(types, "timestamptz_dty", "timestamp with time zone") 106 | case "timestampltz_dty", "timestampeltz", "timestamp with local time zone": 107 | types = append(types, "timestampltz_dty", "timestampeltz", "timestamp with local time zone") 108 | default: 109 | return 110 | } 111 | return 112 | } 113 | 114 | // CreateTable create table in database for values 115 | func (m Migrator) CreateTable(values ...interface{}) (err error) { 116 | ignoreCase := !m.Dialector.(Dialector).NamingCaseSensitive 117 | for _, value := range values { 118 | if ignoreCase { 119 | _ = m.TryQuotifyReservedWords(value) 120 | } 121 | _ = m.TryRemoveOnUpdate(value) 122 | } 123 | if err = m.Migrator.CreateTable(values...); err != nil { 124 | return 125 | } 126 | // set column comment 127 | for _, value := range m.ReorderModels(values, false) { 128 | if err = m.RunWithValue(value, func(stmt *gorm.Statement) (err error) { 129 | if stmt.Schema != nil { 130 | for _, fieldName := range stmt.Schema.DBNames { 131 | field := stmt.Schema.FieldsByDBName[fieldName] 132 | if err = m.setCommentForColumn(field, stmt); err != nil { 133 | return 134 | } 135 | } 136 | } 137 | return 138 | }); err != nil { 139 | return 140 | } 141 | } 142 | return 143 | } 144 | 145 | func (m Migrator) setCommentForColumn(field *schema.Field, stmt *gorm.Statement) (err error) { 146 | if field == nil || stmt == nil || field.Comment == "" { 147 | return 148 | } 149 | table := m.CurrentTable(stmt) 150 | column := clause.Column{Name: field.DBName} 151 | comment := GetStringExpr(field.Comment) 152 | err = m.DB.Exec("COMMENT ON COLUMN ?.? 
IS '?'", table, column, comment).Error 153 | return 154 | } 155 | 156 | // DropTable drop table for values 157 | // 158 | //goland:noinspection SqlNoDataSourceInspection 159 | func (m Migrator) DropTable(values ...interface{}) error { 160 | values = m.ReorderModels(values, false) 161 | for i := len(values) - 1; i >= 0; i-- { 162 | value := values[i] 163 | tx := m.DB.Session(&gorm.Session{}) 164 | if m.HasTable(value) { 165 | if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { 166 | return tx.Exec("DROP TABLE ? CASCADE CONSTRAINTS", clause.Table{Name: stmt.Table}).Error 167 | }); err != nil { 168 | return err 169 | } 170 | } 171 | } 172 | return nil 173 | } 174 | 175 | // HasTable returns table exists or not for value, value could be a struct or string 176 | func (m Migrator) HasTable(value interface{}) bool { 177 | var count int64 178 | 179 | _ = m.RunWithValue(value, func(stmt *gorm.Statement) error { 180 | if ownerName, tableName := m.getSchemaTable(stmt); ownerName != "" { 181 | return m.DB.Raw("SELECT COUNT(*) FROM ALL_TABLES WHERE OWNER = ? 
and TABLE_NAME = ?", ownerName, tableName).Row().Scan(&count) 182 | } else { 183 | return m.DB.Raw("SELECT COUNT(*) FROM USER_TABLES WHERE TABLE_NAME = ?", tableName).Row().Scan(&count) 184 | } 185 | }) 186 | 187 | return count > 0 188 | } 189 | 190 | func (m Migrator) getSchemaTable(stmt *gorm.Statement) (ownerName, tableName string) { 191 | if stmt == nil { 192 | return 193 | } 194 | if stmt.Schema == nil { 195 | tableName = stmt.Table 196 | } else { 197 | tableName = stmt.Schema.Table 198 | if strings.Contains(tableName, ".") { 199 | ownerTable := strings.Split(tableName, ".") 200 | ownerName, tableName = ownerTable[0], ownerTable[1] 201 | } 202 | } 203 | return 204 | } 205 | 206 | // ColumnTypes returns the column types of value's table, read from a "ROWNUM = 1" query; reserved-word column names are quoted when naming is case-insensitive. 207 | func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { 208 | columnTypes := make([]gorm.ColumnType, 0) 209 | execErr := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) { 210 | _, tableName := m.getSchemaTable(stmt) 211 | rows, err := m.DB.Session(&gorm.Session{}).Table(tableName).Where("ROWNUM = 1").Rows() 212 | if err != nil { 213 | return err 214 | } 215 | // Report a Close error only when nothing failed earlier, so it can never mask the ColumnTypes error. 216 | defer func() { 217 | if closeErr := rows.Close(); err == nil { 218 | err = closeErr 219 | } 220 | }() 221 | rawColumnTypes, err := rows.ColumnTypes() 222 | if err != nil { 223 | return err 224 | } 225 | 226 | ignoreCase := !m.Dialector.(Dialector).NamingCaseSensitive 227 | for _, c := range rawColumnTypes { 228 | columnType := migrator.ColumnType{SQLColumnType: c} 229 | if ignoreCase && IsReservedWord(c.Name()) { 230 | columnType.NameValue = sql.NullString{ 231 | String: strconv.Quote(c.Name()), 232 | Valid: true, 233 | } 234 | } 235 | columnTypes = append(columnTypes, columnType) 236 | } 237 | 238 | return 239 | }) 240 | 241 | return columnTypes, execErr 242 | } 243 | 244 | // RenameTable rename table from oldName to newName 245 | func (m Migrator) RenameTable(oldName, newName interface{}) (err error) { 246 | 
resolveTable := func(name interface{}) (result string, err error) { 247 | if v, ok := name.(string); ok { 248 | result = v 249 | } else { 250 | stmt := &gorm.Statement{DB: m.DB} 251 | if err = stmt.Parse(name); err == nil { 252 | result = stmt.Table 253 | } 254 | } 255 | return 256 | } 257 | 258 | var oldTable, newTable string 259 | 260 | if oldTable, err = resolveTable(oldName); err != nil { 261 | return 262 | } 263 | 264 | if newTable, err = resolveTable(newName); err != nil { 265 | return 266 | } 267 | 268 | if !m.HasTable(oldTable) { 269 | return 270 | } 271 | // Oracle has no RENAME TABLE statement (that is MySQL syntax); ALTER TABLE ... RENAME TO is the Oracle form. 272 | return m.DB.Exec("ALTER TABLE ? RENAME TO ?", 273 | clause.Table{Name: oldTable}, 274 | clause.Table{Name: newTable}, 275 | ).Error 276 | } 277 | 278 | // GetTables returns tables under the current user database 279 | func (m Migrator) GetTables() (tableList []string, err error) { 280 | err = m.DB.Raw(`SELECT TABLE_NAME FROM USER_TABLES 281 | WHERE TABLESPACE_NAME IS NOT NULL AND TABLESPACE_NAME <> 'SYSAUX' 282 | AND TABLE_NAME NOT LIKE 'AQ$%' AND TABLE_NAME NOT LIKE 'MVIEW$%' AND TABLE_NAME NOT LIKE 'ROLLING$%' 283 | AND TABLE_NAME NOT IN ('HELP', 'SQLPLUS_PRODUCT_PROFILE', 'LOGSTDBY$PARAMETERS', 'LOGMNRGGC_GTCS', 'LOGMNRGGC_GTLO', 'LOGMNR_PARAMETER$', 'LOGMNR_SESSION$', 'SCHEDULER_JOB_ARGS_TBL', 'SCHEDULER_PROGRAM_ARGS_TBL') 284 | `).Scan(&tableList).Error 285 | return 286 | } 287 | 288 | // AddColumn create "name" column for value 289 | func (m Migrator) AddColumn(value interface{}, name string) (err error) { 290 | if err = m.Migrator.AddColumn(value, name); err != nil { 291 | return err 292 | } 293 | // set column comment 294 | err = m.RunWithValue(value, func(stmt *gorm.Statement) (err error) { 295 | if field := stmt.Schema.LookUpField(name); field != nil { 296 | if err = m.setCommentForColumn(field, stmt); err != nil { 297 | return 298 | } 299 | } 300 | return 301 | }) 302 | return 303 | } 304 | 305 | // DropColumn drop value's "name" column 306 | func (m Migrator) DropColumn(value interface{}, name
string) error { 307 | return m.Migrator.DropColumn(value, name) 308 | } 309 | 310 | // AlterColumn alter value's "field" column's type based on schema definition 311 | // 312 | //goland:noinspection SqlNoDataSourceInspection 313 | func (m Migrator) AlterColumn(value interface{}, field string) error { 314 | if !m.HasColumn(value, field) { 315 | return nil 316 | } 317 | 318 | return m.RunWithValue(value, func(stmt *gorm.Statement) error { 319 | if field := stmt.Schema.LookUpField(field); field != nil { 320 | _, tableName := m.getSchemaTable(stmt) 321 | return m.DB.Exec( 322 | "ALTER TABLE ? MODIFY ? ?", 323 | clause.Table{Name: tableName}, 324 | clause.Column{Name: field.DBName}, 325 | m.AlterDataTypeOf(stmt, field), 326 | ).Error 327 | } 328 | return fmt.Errorf("failed to look up field with name: %s", field) 329 | }) 330 | } 331 | 332 | // HasColumn check has column "field" for value or not 333 | func (m Migrator) HasColumn(value interface{}, field string) bool { 334 | var count int64 335 | return m.RunWithValue(value, func(stmt *gorm.Statement) error { 336 | if ownerName, tableName := m.getSchemaTable(stmt); ownerName != "" { 337 | return m.DB.Raw("SELECT COUNT(*) FROM ALL_TAB_COLUMNS WHERE OWNER = ? and TABLE_NAME = ? AND COLUMN_NAME = ?", ownerName, tableName, field).Row().Scan(&count) 338 | } else { 339 | return m.DB.Raw("SELECT COUNT(*) FROM USER_TAB_COLUMNS WHERE TABLE_NAME = ? 
AND COLUMN_NAME = ?", tableName, field).Row().Scan(&count) 340 | } 341 | 342 | }) == nil && count > 0 343 | } 344 | 345 | // MigrateColumn migrate column 346 | func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) (err error) { 347 | if err = m.Migrator.MigrateColumn(value, field, columnType); err != nil { 348 | return 349 | } 350 | 351 | return m.RunWithValue(value, func(stmt *gorm.Statement) (err error) { 352 | var description string 353 | if ownerName, tableName := m.getSchemaTable(stmt); ownerName != "" { 354 | _ = m.DB.Raw( 355 | "SELECT COMMENTS FROM ALL_COL_COMMENTS WHERE OWNER = ? AND TABLE_NAME = ? AND COLUMN_NAME = ?", 356 | ownerName, tableName, field.DBName, 357 | ).Row().Scan(&description) 358 | } else { 359 | _ = m.DB.Raw( 360 | "SELECT COMMENTS FROM USER_COL_COMMENTS WHERE TABLE_NAME = ? AND COLUMN_NAME = ?", 361 | tableName, field.DBName, 362 | ).Row().Scan(&description) 363 | } 364 | if comment := field.Comment; comment != "" && comment != description { 365 | if err = m.setCommentForColumn(field, stmt); err != nil { 366 | return 367 | } 368 | } 369 | return 370 | }) 371 | } 372 | 373 | func (m Migrator) AlterDataTypeOf(stmt *gorm.Statement, field *schema.Field) (expr clause.Expr) { 374 | expr.SQL = m.DataTypeOf(field) 375 | 376 | var nullable = "" 377 | if ownerName, tableName := m.getSchemaTable(stmt); ownerName != "" { 378 | _ = m.DB.Raw("SELECT NULLABLE FROM ALL_TAB_COLUMNS WHERE OWNER = ? and TABLE_NAME = ? AND COLUMN_NAME = ?", ownerName, tableName, field.DBName).Row().Scan(&nullable) 379 | } else { 380 | _ = m.DB.Raw("SELECT NULLABLE FROM USER_TAB_COLUMNS WHERE TABLE_NAME = ? 
AND COLUMN_NAME = ?", tableName, field.DBName).Row().Scan(&nullable) 381 | } 382 | 383 | if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") { 384 | if field.DefaultValueInterface != nil { 385 | defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}} 386 | m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface) 387 | expr.SQL += " DEFAULT " + m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface) 388 | } else if field.DefaultValue != "(-)" { 389 | expr.SQL += " DEFAULT " + field.DefaultValue 390 | } 391 | } 392 | 393 | if field.NotNull && nullable == "Y" { 394 | expr.SQL += " NOT NULL" 395 | } 396 | if field.Unique { 397 | expr.SQL += " UNIQUE" 398 | } 399 | return 400 | } 401 | 402 | // CreateConstraint create constraint 403 | func (m Migrator) CreateConstraint(value interface{}, name string) error { 404 | _ = m.TryRemoveOnUpdate(value) 405 | return m.Migrator.CreateConstraint(value, name) 406 | } 407 | 408 | // DropConstraint drop constraint 409 | // 410 | //goland:noinspection SqlNoDataSourceInspection 411 | func (m Migrator) DropConstraint(value interface{}, name string) error { 412 | return m.RunWithValue(value, func(stmt *gorm.Statement) error { 413 | _, tableName := m.getSchemaTable(stmt) 414 | for _, chk := range stmt.Schema.ParseCheckConstraints() { 415 | if chk.Name == name { 416 | return m.DB.Exec( 417 | "ALTER TABLE ? DROP CHECK ?", 418 | clause.Table{Name: tableName}, clause.Column{Name: name}, 419 | ).Error 420 | } 421 | } 422 | 423 | return m.DB.Exec( 424 | "ALTER TABLE ? 
DROP CONSTRAINT ?", 425 | clause.Table{Name: tableName}, clause.Column{Name: name}, 426 | ).Error 427 | }) 428 | } 429 | 430 | // HasConstraint check has constraint or not 431 | func (m Migrator) HasConstraint(value interface{}, name string) bool { 432 | var count int64 433 | return m.RunWithValue(value, func(stmt *gorm.Statement) error { 434 | return m.DB.Raw( 435 | "SELECT COUNT(*) FROM USER_CONSTRAINTS WHERE TABLE_NAME = ? AND CONSTRAINT_NAME = ?", stmt.Table, name, 436 | ).Row().Scan(&count) 437 | }) == nil && count > 0 438 | } 439 | 440 | // DropIndex drops index "name", resolving it through the model's index definitions when it matches one. 441 | func (m Migrator) DropIndex(value interface{}, name string) error { 442 | return m.RunWithValue(value, func(stmt *gorm.Statement) error { 443 | if idx := stmt.Schema.LookIndex(name); idx != nil { 444 | name = idx.Name 445 | } 446 | 447 | // Oracle's DROP INDEX names only the index (no ON <table> clause), so no table argument may be bound. 448 | return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error 449 | }) 450 | } 451 | 452 | // HasIndex check has index "name" or not 453 | func (m Migrator) HasIndex(value interface{}, name string) bool { 454 | var count int64 455 | _ = m.RunWithValue(value, func(stmt *gorm.Statement) error { 456 | if idx := stmt.Schema.LookIndex(name); idx != nil { 457 | name = idx.Name 458 | } 459 | 460 | return m.DB.Raw( 461 | "SELECT COUNT(*) FROM USER_INDEXES WHERE TABLE_NAME = ? AND INDEX_NAME = ?", 462 | m.Migrator.DB.NamingStrategy.TableName(stmt.Table), 463 | name, 464 | ).Row().Scan(&count) 465 | }) 466 | 467 | return count > 0 468 | } 469 | 470 | // RenameIndex rename index from oldName to newName 471 | // 472 | // see also: 473 | // https://docs.oracle.com/database/121/SPATL/alter-index-rename.htm 474 | func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { 475 | return m.RunWithValue(value, func(stmt *gorm.Statement) error { 476 | return m.DB.Exec( 477 | "ALTER INDEX ?
RENAME TO ?", 478 | clause.Column{Name: oldName}, clause.Column{Name: newName}, 479 | ).Error 480 | }) 481 | } 482 | 483 | func (m Migrator) TryRemoveOnUpdate(values ...interface{}) error { 484 | for _, value := range values { 485 | if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { 486 | for _, rel := range stmt.Schema.Relationships.Relations { 487 | constraint := rel.ParseConstraint() 488 | if constraint != nil { 489 | rel.Field.TagSettings["CONSTRAINT"] = strings.ReplaceAll(rel.Field.TagSettings["CONSTRAINT"], fmt.Sprintf("ON UPDATE %s", constraint.OnUpdate), "") 490 | } 491 | } 492 | return nil 493 | }); err != nil { 494 | return err 495 | } 496 | } 497 | return nil 498 | } 499 | 500 | func (m Migrator) TryQuotifyReservedWords(values ...interface{}) error { 501 | for _, value := range values { 502 | if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { 503 | ignoreCase := !m.Dialector.(Dialector).NamingCaseSensitive 504 | for idx, v := range stmt.Schema.DBNames { 505 | if ignoreCase { 506 | v = strings.ToUpper(v) 507 | } 508 | if IsReservedWord(v) { 509 | v = strconv.Quote(v) 510 | } 511 | stmt.Schema.DBNames[idx] = v 512 | } 513 | 514 | for _, v := range stmt.Schema.Fields { 515 | fieldDBName := v.DBName 516 | if ignoreCase { 517 | v.DBName = strings.ToUpper(v.DBName) 518 | } 519 | if IsReservedWord(v.DBName) { 520 | v.DBName = strconv.Quote(v.DBName) 521 | } 522 | delete(stmt.Schema.FieldsByDBName, fieldDBName) 523 | stmt.Schema.FieldsByDBName[v.DBName] = v 524 | } 525 | return nil 526 | }); err != nil { 527 | return err 528 | } 529 | } 530 | return nil 531 | } 532 | -------------------------------------------------------------------------------- /migrator_test.go: -------------------------------------------------------------------------------- 1 | package oracle 2 | 3 | import ( 4 | "encoding/json" 5 | "reflect" 6 | "testing" 7 | "time" 8 | 9 | "gorm.io/gorm" 10 | ) 11 | 12 | func TestMigrator_AutoMigrate(t *testing.T) { 13 | db, 
err := dbNamingCase, dbErrors[0] 14 | if err != nil { 15 | t.Fatal(err) 16 | } 17 | if db == nil { 18 | t.Log("db is nil!") 19 | return 20 | } 21 | 22 | type args struct { 23 | drop bool 24 | models []interface{} 25 | comments []string 26 | } 27 | tests := []struct { 28 | name string 29 | args args 30 | wantErr bool 31 | }{ 32 | {name: "TestTableUser", args: args{models: []interface{}{TestTableUser{}}, comments: []string{"用户信息表"}}}, 33 | {name: "TestTableUserDrop", args: args{drop: true, models: []interface{}{TestTableUser{}}, comments: []string{"用户信息表"}}}, 34 | {name: "TestTableUserNoComments", args: args{drop: true, models: []interface{}{TestTableUserNoComments{}}, comments: []string{"用户信息表"}}}, 35 | {name: "TestTableUserAddColumn", args: args{models: []interface{}{TestTableUserAddColumn{}}, comments: []string{"用户信息表"}}}, 36 | {name: "TestTableUserMigrateColumn", args: args{models: []interface{}{TestTableUserMigrateColumn{}}, comments: []string{"用户信息表"}}}, 37 | } 38 | for idx, tt := range tests { 39 | t.Run(tt.name, func(t *testing.T) { 40 | if len(tt.args.models) == 0 { 41 | t.Fatal("models is nil") 42 | } 43 | migrator := db.Set("gorm:table_comments", tt.args.comments).Migrator() 44 | 45 | if tt.args.drop { 46 | for _, model := range tt.args.models { 47 | if !migrator.HasTable(model) { 48 | continue 49 | } 50 | if err = migrator.DropTable(model); err != nil { 51 | t.Fatalf("DropTable() error = %v", err) 52 | } 53 | } 54 | } 55 | 56 | if err = migrator.AutoMigrate(tt.args.models...); (err != nil) != tt.wantErr { 57 | t.Errorf("AutoMigrate() error = %v, wantErr %v", err, tt.wantErr) 58 | } else if err == nil { 59 | t.Log("AutoMigrate() success!") 60 | } 61 | 62 | if idx == len(tests)-1 { 63 | wantUser := TestTableUserMigrateColumn{ 64 | TestTableUser: TestTableUser{ 65 | UID: "U0", 66 | Name: "someone", 67 | Account: "guest", 68 | Password: "MAkOvrJ8JV", 69 | Email: "", 70 | PhoneNumber: "+8618888888888", 71 | Sex: "1", 72 | UserType: 1, 73 | Enabled: true, 74 | 
Remark: "Ahmad", 75 | }, 76 | AddNewColumn: "AddNewColumnValue", 77 | CommentSingleQuote: "CommentSingleQuoteValue", 78 | } 79 | 80 | result := db.Create(&wantUser) 81 | if err = result.Error; err != nil { 82 | t.Fatal(err) 83 | } 84 | 85 | var gotUser TestTableUserMigrateColumn 86 | result.Where(&TestTableUser{UID: "U0"}).Find(&gotUser) 87 | if err = result.Error; err != nil { 88 | t.Fatal(err) 89 | } 90 | gotUserBytes, _ := json.Marshal(gotUser) 91 | t.Logf("gotUser Result: %s", gotUserBytes) 92 | if !reflect.DeepEqual(gotUser, wantUser) { 93 | wantUserBytes, _ := json.Marshal(wantUser) 94 | t.Errorf("wantUser Info: %s", wantUserBytes) 95 | } 96 | } 97 | }) 98 | } 99 | } 100 | 101 | // TestTableUser 测试用户信息表模型 102 | type TestTableUser struct { 103 | ID uint64 `gorm:"column:id;size:64;not null;autoIncrement:true;autoIncrementIncrement:1;primaryKey;comment:自增 ID" json:"id"` 104 | UID string `gorm:"column:uid;type:varchar(50);comment:用户身份标识" json:"uid"` 105 | Name string `gorm:"column:name;size:50;comment:用户姓名" json:"name"` 106 | 107 | Account string `gorm:"column:account;type:varchar(50);comment:登录账号" json:"account"` 108 | Password string `gorm:"column:password;type:varchar(512);comment:登录密码(密文)" json:"password"` 109 | 110 | Email string `gorm:"column:email;type:varchar(128);comment:邮箱地址" json:"email"` 111 | PhoneNumber string `gorm:"column:phone_number;type:varchar(15);comment:E.164" json:"phoneNumber"` 112 | 113 | Sex string `gorm:"column:sex;type:char(1);comment:性别" json:"sex"` 114 | Birthday *time.Time `gorm:"column:birthday;->:false;<-:create;comment:生日" json:"birthday,omitempty"` 115 | 116 | UserType int `gorm:"column:user_type;size:8;comment:用户类型" json:"userType"` 117 | 118 | Enabled bool `gorm:"column:enabled;comment:是否可用" json:"enabled"` 119 | Remark string `gorm:"column:remark;size:1024;comment:备注信息" json:"remark"` 120 | } 121 | 122 | func (TestTableUser) TableName() string { 123 | return "test_user" 124 | } 125 | 126 | type TestTableUserNoComments struct 
{ 127 | ID uint64 `gorm:"column:id;size:64;not null;autoIncrement:true;autoIncrementIncrement:1;primaryKey" json:"id"` 128 | UID string `gorm:"column:name;type:varchar(50)" json:"uid"` 129 | Name string `gorm:"column:name;size:50" json:"name"` 130 | 131 | Account string `gorm:"column:account;type:varchar(50)" json:"account"` 132 | Password string `gorm:"column:password;type:varchar(512)" json:"password"` 133 | 134 | Email string `gorm:"column:email;type:varchar(128)" json:"email"` 135 | PhoneNumber string `gorm:"column:phone_number;type:varchar(15)" json:"phoneNumber"` 136 | 137 | Sex string `gorm:"column:sex;type:char(1)" json:"sex"` 138 | Birthday time.Time `gorm:"column:birthday" json:"birthday"` 139 | 140 | UserType int `gorm:"column:user_type;size:8" json:"userType"` 141 | 142 | Enabled bool `gorm:"column:enabled" json:"enabled"` 143 | Remark string `gorm:"column:remark;size:1024" json:"remark"` 144 | } 145 | 146 | func (TestTableUserNoComments) TableName() string { 147 | return "test_user" 148 | } 149 | 150 | type TestTableUserAddColumn struct { 151 | TestTableUser 152 | 153 | AddNewColumn string `gorm:"column:add_new_column;type:varchar(100);comment:添加新字段"` 154 | } 155 | 156 | func (TestTableUserAddColumn) TableName() string { 157 | return "test_user" 158 | } 159 | 160 | type TestTableUserMigrateColumn struct { 161 | TestTableUser 162 | 163 | AddNewColumn string `gorm:"column:add_new_column;type:varchar(100);comment:测试添加新字段"` 164 | CommentSingleQuote string `gorm:"column:comment_single_quote;comment:注释中存在单引号'[']'"` 165 | } 166 | 167 | func (TestTableUserMigrateColumn) TableName() string { 168 | return "test_user" 169 | } 170 | 171 | type testTableColumnTypeModel struct { 172 | ID int64 `gorm:"column:id;size:64;not null;autoIncrement:true;autoIncrementIncrement:1;primaryKey"` 173 | Name string `gorm:"column:name;size:50"` 174 | Age uint8 `gorm:"column:age;size:8"` 175 | 176 | Avatar []byte `gorm:"column:avatar;"` 177 | 178 | Balance float64 
`gorm:"column:balance;type:decimal(18, 2)"` 179 | Remark string `gorm:"column:remark;size:-1"` 180 | Enabled bool `gorm:"column:enabled;"` 181 | 182 | CreatedAt time.Time 183 | UpdatedAt time.Time 184 | DeletedAt gorm.DeletedAt 185 | } 186 | 187 | func (t testTableColumnTypeModel) TableName() string { 188 | return "test_table_column_type" 189 | } 190 | 191 | func TestMigrator_TableColumnType(t *testing.T) { 192 | db, err := dbNamingCase, dbErrors[0] 193 | if err != nil { 194 | t.Fatal(err) 195 | } 196 | if db == nil { 197 | t.Log("db is nil!") 198 | return 199 | } 200 | testModel := new(testTableColumnTypeModel) 201 | 202 | type args struct { 203 | model interface{} 204 | drop bool 205 | } 206 | tests := []struct { 207 | name string 208 | args args 209 | }{ 210 | {name: "create", args: args{model: testModel}}, 211 | {name: "alter", args: args{model: testModel, drop: true}}, 212 | } 213 | for _, tt := range tests { 214 | t.Run(tt.name, func(t *testing.T) { 215 | if err = db.AutoMigrate(tt.args.model); err != nil { 216 | t.Errorf("AutoMigrate failed:%v", err) 217 | } 218 | if tt.args.drop { 219 | _ = db.Migrator().DropTable(tt.args.model) 220 | } 221 | }) 222 | } 223 | } 224 | 225 | type testFieldNameIsReservedWord struct { 226 | ID int64 `gorm:"column:id;size:64;not null;autoIncrement:true;autoIncrementIncrement:1;primaryKey"` 227 | 228 | FLOAT float64 `gorm:"type:decimal(18, 2)"` 229 | DESC string `gorm:"size:-1"` 230 | ON bool 231 | 232 | Order int 233 | Sort int 234 | 235 | CREATE time.Time 236 | UPDATE time.Time 237 | DELETE gorm.DeletedAt 238 | } 239 | 240 | func (t testFieldNameIsReservedWord) TableName() string { 241 | return "test_name_is_reserved_word" 242 | } 243 | 244 | func TestMigrator_FieldNameIsReservedWord(t *testing.T) { 245 | if err := dbErrors[0]; err != nil { 246 | t.Fatal(err) 247 | } 248 | if dbNamingCase == nil { 249 | t.Log("dbNamingCase is nil!") 250 | return 251 | } 252 | if err := dbErrors[1]; err != nil { 253 | t.Fatal(err) 254 | } 255 | 
if dbIgnoreCase == nil { 256 | t.Log("dbNamingCase is nil!") 257 | return 258 | } 259 | 260 | testModel := new(testFieldNameIsReservedWord) 261 | _ = dbNamingCase.Migrator().DropTable(testModel) 262 | _ = dbIgnoreCase.Migrator().DropTable(testModel) 263 | 264 | type args struct { 265 | db *gorm.DB 266 | model interface{} 267 | drop bool 268 | } 269 | tests := []struct { 270 | name string 271 | args args 272 | }{ 273 | {name: "createNamingCase", args: args{db: dbNamingCase, model: testModel}}, 274 | {name: "alterNamingCase", args: args{db: dbNamingCase, model: testModel, drop: true}}, 275 | {name: "createIgnoreCase", args: args{db: dbIgnoreCase, model: testModel}}, 276 | {name: "alterIgnoreCase", args: args{db: dbIgnoreCase, model: testModel, drop: true}}, 277 | } 278 | for _, tt := range tests { 279 | t.Run(tt.name, func(t *testing.T) { 280 | db := tt.args.db 281 | if err := db.AutoMigrate(tt.args.model); err != nil { 282 | t.Errorf("AutoMigrate failed:%v", err) 283 | } 284 | if tt.args.drop { 285 | _ = db.Migrator().DropTable(tt.args.model) 286 | } 287 | }) 288 | } 289 | } 290 | 291 | func TestMigrator_DatatypesJsonMapNamingCase(t *testing.T) { 292 | if err := dbErrors[0]; err != nil { 293 | t.Fatal(err) 294 | } 295 | if dbNamingCase == nil { 296 | t.Log("dbNamingCase is nil!") 297 | return 298 | } 299 | 300 | type testJsonMapNamingCase struct { 301 | gorm.Model 302 | 303 | Extras JSONMap `gorm:"check:\"extras\" IS JSON"` 304 | } 305 | testModel := new(testJsonMapNamingCase) 306 | _ = dbNamingCase.Migrator().DropTable(testModel) 307 | 308 | type args struct { 309 | db *gorm.DB 310 | model interface{} 311 | drop bool 312 | } 313 | tests := []struct { 314 | name string 315 | args args 316 | }{ 317 | {name: "createDatatypesJsonMapNamingCase", args: args{db: dbNamingCase, model: testModel}}, 318 | {name: "alterDatatypesJsonMapNamingCase", args: args{db: dbNamingCase, model: testModel, drop: true}}, 319 | } 320 | for _, tt := range tests { 321 | t.Run(tt.name, func(t 
*testing.T) { 322 | db := tt.args.db 323 | if err := db.AutoMigrate(tt.args.model); err != nil { 324 | t.Errorf("AutoMigrate failed:%v", err) 325 | } 326 | if tt.args.drop { 327 | _ = db.Migrator().DropTable(tt.args.model) 328 | } 329 | }) 330 | } 331 | } 332 | 333 | func TestMigrator_DatatypesJsonMapIgnoreCase(t *testing.T) { 334 | if err := dbErrors[1]; err != nil { 335 | t.Fatal(err) 336 | } 337 | if dbIgnoreCase == nil { 338 | t.Log("dbNamingCase is nil!") 339 | return 340 | } 341 | 342 | type tesJsonMapIgnoreCase struct { 343 | gorm.Model 344 | 345 | Extras JSONMap `gorm:"check:extras IS JSON"` 346 | } 347 | testModel := new(tesJsonMapIgnoreCase) 348 | _ = dbIgnoreCase.Migrator().DropTable(testModel) 349 | 350 | type args struct { 351 | db *gorm.DB 352 | model interface{} 353 | drop bool 354 | } 355 | tests := []struct { 356 | name string 357 | args args 358 | }{ 359 | {name: "createDatatypesJsonMapIgnoreCase", args: args{db: dbIgnoreCase, model: testModel}}, 360 | {name: "alterDatatypesJsonMapIgnoreCase", args: args{db: dbIgnoreCase, model: testModel, drop: true}}, 361 | } 362 | for _, tt := range tests { 363 | t.Run(tt.name, func(t *testing.T) { 364 | db := tt.args.db 365 | if err := db.AutoMigrate(tt.args.model); err != nil { 366 | t.Errorf("AutoMigrate failed:%v", err) 367 | } 368 | if tt.args.drop { 369 | _ = db.Migrator().DropTable(tt.args.model) 370 | } 371 | }) 372 | } 373 | } 374 | -------------------------------------------------------------------------------- /namer.go: -------------------------------------------------------------------------------- 1 | package oracle 2 | 3 | import ( 4 | "strings" 5 | 6 | "gorm.io/gorm/schema" 7 | ) 8 | 9 | // Namer implement gorm schema namer interface 10 | type Namer struct { 11 | // NamingStrategy use custom naming strategy in gorm.Config on initialize 12 | NamingStrategy schema.Namer 13 | // CaseSensitive determines whether naming is case-sensitive 14 | CaseSensitive bool 15 | } 16 | 17 | // Deprecated: As of 
v1.5.0, use the Namer.ConvertNameToFormat instead. 18 | // 19 | //goland:noinspection GoUnusedExportedFunction 20 | func ConvertNameToFormat(x string) string { 21 | return (Namer{}).ConvertNameToFormat(x) 22 | } 23 | 24 | // ConvertNameToFormat return appropriate capitalization name based on CaseSensitive 25 | func (n Namer) ConvertNameToFormat(x string) string { 26 | if n.CaseSensitive { 27 | return x 28 | } 29 | return strings.ToUpper(x) 30 | } 31 | 32 | // TableName convert string to table name 33 | func (n Namer) TableName(table string) (name string) { 34 | return n.ConvertNameToFormat(n.NamingStrategy.TableName(table)) 35 | } 36 | 37 | // SchemaName generate schema name from table name, don't guarantee it is the reverse value of TableName 38 | func (n Namer) SchemaName(table string) string { 39 | return n.ConvertNameToFormat(n.NamingStrategy.SchemaName(table)) 40 | } 41 | 42 | // ColumnName convert string to column name 43 | func (n Namer) ColumnName(table, column string) (name string) { 44 | return n.ConvertNameToFormat(n.NamingStrategy.ColumnName(table, column)) 45 | } 46 | 47 | // JoinTableName convert string to join table name 48 | func (n Namer) JoinTableName(table string) (name string) { 49 | return n.ConvertNameToFormat(n.NamingStrategy.JoinTableName(table)) 50 | } 51 | 52 | // RelationshipFKName generate fk name for relation 53 | func (n Namer) RelationshipFKName(relationship schema.Relationship) (name string) { 54 | return n.ConvertNameToFormat(n.NamingStrategy.RelationshipFKName(relationship)) 55 | } 56 | 57 | // CheckerName generate checker name 58 | func (n Namer) CheckerName(table, column string) (name string) { 59 | return n.ConvertNameToFormat(n.NamingStrategy.CheckerName(table, column)) 60 | } 61 | 62 | // IndexName generate index name 63 | func (n Namer) IndexName(table, column string) (name string) { 64 | return n.ConvertNameToFormat(n.NamingStrategy.IndexName(table, column)) 65 | } 66 | 67 | // UniqueName generate unique constraint name 68 | 
func (n Namer) UniqueName(table, column string) string { 69 | return n.ConvertNameToFormat(n.NamingStrategy.UniqueName(table, column)) 70 | } 71 | -------------------------------------------------------------------------------- /oracle.go: -------------------------------------------------------------------------------- 1 | package oracle 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "fmt" 7 | "reflect" 8 | "regexp" 9 | "strconv" 10 | "strings" 11 | "time" 12 | 13 | "github.com/sijms/go-ora/v2" 14 | "gorm.io/gorm" 15 | "gorm.io/gorm/callbacks" 16 | "gorm.io/gorm/clause" 17 | "gorm.io/gorm/logger" 18 | "gorm.io/gorm/migrator" 19 | "gorm.io/gorm/schema" 20 | ) 21 | 22 | type Config struct { 23 | DriverName string 24 | DSN string 25 | Conn gorm.ConnPool //*sql.DB 26 | DefaultStringSize uint 27 | DBVer string 28 | 29 | IgnoreCase bool // warning: may cause performance issues 30 | NamingCaseSensitive bool // whether naming is case-sensitive 31 | // whether VARCHAR type size is character length, defaulting to byte length 32 | VarcharSizeIsCharLength bool 33 | 34 | // RowNumberAliasForOracle11 is the alias for ROW_NUMBER() in Oracle 11g, defaulting to ROW_NUM 35 | RowNumberAliasForOracle11 string 36 | } 37 | 38 | // Dialector implement GORM database dialector 39 | type Dialector struct { 40 | *Config 41 | } 42 | 43 | //goland:noinspection GoUnusedExportedFunction 44 | func Open(dsn string) gorm.Dialector { 45 | return &Dialector{Config: &Config{DSN: dsn}} 46 | } 47 | 48 | //goland:noinspection GoUnusedExportedFunction 49 | func New(config Config) gorm.Dialector { 50 | return &Dialector{Config: &config} 51 | } 52 | 53 | // BuildUrl create databaseURL from server, port, service, user, password, urlOptions 54 | // this function help build a will formed databaseURL and accept any character as it 55 | // convert special charters to corresponding values in URL 56 | // 57 | //goland:noinspection GoUnusedExportedFunction 58 | func BuildUrl(server string, port int, service, 
user, password string, options map[string]string) string { 59 | return go_ora.BuildUrl(server, port, service, user, password, options) 60 | } 61 | 62 | // GetStringExpr replace single quotes in the string with two single quotes 63 | // and return the expression for the string value 64 | // 65 | // quotes : if the SQL placeholder is ? then pass true, if it is '?' then do not pass or pass false. 66 | func GetStringExpr(value string, quotes ...bool) clause.Expr { 67 | if len(quotes) > 0 && quotes[0] { 68 | if strings.Contains(value, "'") { 69 | // escape single quotes 70 | if !strings.Contains(value, "]'") { 71 | value = fmt.Sprintf("q'[%s]'", value) 72 | } else if !strings.Contains(value, "}'") { 73 | value = fmt.Sprintf("q'{%s}'", value) 74 | } else if !strings.Contains(value, ">'") { 75 | value = fmt.Sprintf("q'<%s>'", value) 76 | } else if !strings.Contains(value, ")'") { 77 | value = fmt.Sprintf("q'(%s)'", value) 78 | } else { 79 | value = fmt.Sprintf("'%s'", strings.ReplaceAll(value, "'", "''")) 80 | } 81 | } else { 82 | value = fmt.Sprintf("'%s'", value) 83 | } 84 | } else { 85 | value = strings.ReplaceAll(value, "'", "''") 86 | } 87 | return gorm.Expr(value) 88 | } 89 | 90 | // AddSessionParams setting database connection session parameters, 91 | // the value is wrapped in single quotes. 92 | // 93 | // If the value doesn't need to be wrapped in single quotes, 94 | // please use the go_ora.AddSessionParam function directly, 95 | // or pass the originals parameter as true. 
96 | func AddSessionParams(db *sql.DB, params map[string]string, originals ...bool) (keys []string, err error) { 97 | if db == nil { 98 | return 99 | } 100 | if _, ok := db.Driver().(*go_ora.OracleDriver); !ok { 101 | return 102 | } 103 | var original bool 104 | if len(originals) > 0 { 105 | original = originals[0] 106 | } 107 | 108 | for key, value := range params { 109 | if key == "" || value == "" { 110 | continue 111 | } 112 | if !original { 113 | value = GetStringExpr(value, true).SQL 114 | } 115 | if err = go_ora.AddSessionParam(db, key, value); err != nil { 116 | return 117 | } 118 | keys = append(keys, key) 119 | } 120 | return 121 | } 122 | 123 | // DelSessionParams remove session parameters 124 | func DelSessionParams(db *sql.DB, keys []string) { 125 | if db == nil { 126 | return 127 | } 128 | if _, ok := db.Driver().(*go_ora.OracleDriver); !ok { 129 | return 130 | } 131 | 132 | for _, key := range keys { 133 | if key == "" { 134 | continue 135 | } 136 | go_ora.DelSessionParam(db, key) 137 | } 138 | } 139 | 140 | func convertCustomType(val interface{}) interface{} { 141 | rv := reflect.ValueOf(val) 142 | if !rv.IsValid() || rv.IsZero() { 143 | return val 144 | } 145 | ri := rv.Interface() 146 | typeName := reflect.TypeOf(ri).Name() 147 | if reflect.TypeOf(val).Kind() == reflect.Ptr { 148 | if rv.IsNil() { 149 | typeName = rv.Type().Elem().Name() 150 | } else { 151 | for rv.Kind() == reflect.Ptr { 152 | rv = rv.Elem() 153 | } 154 | ri = rv.Interface() 155 | typeName = reflect.TypeOf(ri).Name() 156 | } 157 | } 158 | if typeName == "DeletedAt" { 159 | // gorm.DeletedAt 160 | if rv.IsZero() { 161 | val = sql.NullTime{} 162 | } else { 163 | val = getTimeValue(ri.(gorm.DeletedAt).Time) 164 | } 165 | } else if m := rv.MethodByName("Time"); m.IsValid() && m.Type().NumIn() == 0 { 166 | // custom time type 167 | for _, result := range m.Call([]reflect.Value{}) { 168 | if reflect.TypeOf(result.Interface()).Name() == "Time" { 169 | val = 
getTimeValue(result.Interface().(time.Time)) 170 | } 171 | } 172 | } 173 | return val 174 | } 175 | 176 | func ptrDereference(obj interface{}) (value interface{}) { 177 | if obj == nil { 178 | return obj 179 | } 180 | if t := reflect.TypeOf(obj); t.Kind() != reflect.Ptr { 181 | return obj 182 | } 183 | 184 | v := reflect.ValueOf(obj) 185 | for v.Kind() == reflect.Ptr && !v.IsNil() { 186 | v = v.Elem() 187 | } 188 | if !v.IsValid() || v.Kind() == reflect.Ptr && v.IsNil() { 189 | return obj 190 | } 191 | value = v.Interface() 192 | return 193 | } 194 | 195 | func getTimeValue(t time.Time) interface{} { 196 | if t.IsZero() { 197 | return sql.NullTime{} 198 | } 199 | return t 200 | } 201 | 202 | func (d Dialector) DummyTableName() string { 203 | return "DUAL" 204 | } 205 | 206 | func (d Dialector) Name() string { 207 | return "oracle" 208 | } 209 | 210 | func (d Dialector) Initialize(db *gorm.DB) (err error) { 211 | db.NamingStrategy = Namer{ 212 | NamingStrategy: db.NamingStrategy, 213 | CaseSensitive: d.NamingCaseSensitive, 214 | } 215 | d.DefaultStringSize = 1024 216 | 217 | // register callbacks 218 | config := &callbacks.Config{ 219 | CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"}, 220 | UpdateClauses: []string{"UPDATE", "SET", "WHERE", "RETURNING"}, 221 | DeleteClauses: []string{"DELETE", "FROM", "WHERE", "RETURNING"}, 222 | } 223 | callbacks.RegisterDefaultCallbacks(db, config) 224 | 225 | d.DriverName = "oracle" 226 | 227 | if d.Conn != nil { 228 | db.ConnPool = d.Conn 229 | } else { 230 | db.ConnPool, err = sql.Open(d.DriverName, d.DSN) 231 | if err != nil { 232 | return 233 | } 234 | } 235 | if d.IgnoreCase { 236 | if sqlDB, ok := db.ConnPool.(*sql.DB); ok { 237 | // warning: may cause performance issues 238 | _ = go_ora.AddSessionParam(sqlDB, "NLS_COMP", "LINGUISTIC") 239 | _ = go_ora.AddSessionParam(sqlDB, "NLS_SORT", "BINARY_CI") 240 | } 241 | } 242 | err = db.ConnPool.QueryRowContext(context.Background(), "select version from 
product_component_version where rownum = 1").Scan(&d.DBVer) 243 | if err != nil { 244 | return err 245 | } 246 | //log.Println("DBVer:" + d.DBVer) 247 | if err = db.Callback().Create().Replace("gorm:create", Create); err != nil { 248 | return 249 | } 250 | if err = db.Callback().Update().Replace("gorm:update", Update(config)); err != nil { 251 | return 252 | } 253 | 254 | for k, v := range d.ClauseBuilders() { 255 | db.ClauseBuilders[k] = v 256 | } 257 | return 258 | } 259 | 260 | func (d Dialector) ClauseBuilders() (clauseBuilders map[string]clause.ClauseBuilder) { 261 | clauseBuilders = make(map[string]clause.ClauseBuilder) 262 | 263 | if dbVer, _ := strconv.Atoi(strings.Split(d.DBVer, ".")[0]); dbVer > 11 { 264 | clauseBuilders["LIMIT"] = d.RewriteLimit 265 | } else { 266 | clauseBuilders["LIMIT"] = d.RewriteLimit11 267 | } 268 | 269 | clauseBuilders["RETURNING"] = func(c clause.Clause, builder clause.Builder) { 270 | if returning, ok := c.Expression.(clause.Returning); ok { 271 | _, _ = builder.WriteString("/*- -*/") 272 | _, _ = builder.WriteString("RETURNING ") 273 | 274 | if len(returning.Columns) > 0 { 275 | for idx, column := range returning.Columns { 276 | if idx > 0 { 277 | _ = builder.WriteByte(',') 278 | } 279 | 280 | builder.WriteQuoted(column) 281 | } 282 | } else { 283 | _ = builder.WriteByte('*') 284 | } 285 | } 286 | } 287 | return 288 | } 289 | 290 | func (d Dialector) getLimitRows(limit clause.Limit) (limitRows int, hasLimit bool) { 291 | if l := limit.Limit; l != nil { 292 | limitRows = *l 293 | hasLimit = limitRows > 0 294 | } 295 | return 296 | } 297 | 298 | func (d Dialector) RewriteLimit(c clause.Clause, builder clause.Builder) { 299 | if limit, ok := c.Expression.(clause.Limit); ok { 300 | limitRows, hasLimit := d.getLimitRows(limit) 301 | 302 | if stmt, ok := builder.(*gorm.Statement); ok { 303 | if _, hasOrderBy := stmt.Clauses["ORDER BY"]; !hasOrderBy && hasLimit { 304 | s := stmt.Schema 305 | _, _ = builder.WriteString("ORDER BY ") 306 
| if s != nil && s.PrioritizedPrimaryField != nil { 307 | builder.WriteQuoted(s.PrioritizedPrimaryField.DBName) 308 | _ = builder.WriteByte(' ') 309 | } else { 310 | _, _ = builder.WriteString("(SELECT NULL FROM ") 311 | _, _ = builder.WriteString(d.DummyTableName()) 312 | _, _ = builder.WriteString(")") 313 | } 314 | } 315 | } 316 | 317 | if offset := limit.Offset; offset > 0 { 318 | _, _ = builder.WriteString(" OFFSET ") 319 | builder.AddVar(builder, offset) 320 | _, _ = builder.WriteString(" ROWS") 321 | } 322 | if hasLimit { 323 | _, _ = builder.WriteString(" FETCH NEXT ") 324 | builder.AddVar(builder, limitRows) 325 | _, _ = builder.WriteString(" ROWS ONLY") 326 | } 327 | } 328 | } 329 | 330 | // RewriteLimit11 rewrite the LIMIT clause in the query to accommodate pagination requirements for Oracle 11g and lower database versions 331 | // 332 | // # Limit and Offset 333 | // 334 | // SELECT * FROM (SELECT T.*, ROW_NUMBER() OVER (ORDER BY column) AS ROW_NUM FROM table_name T) 335 | // WHERE ROW_NUM BETWEEN offset+1 AND offset+limit 336 | // 337 | // # Only Limit 338 | // 339 | // SELECT * FROM table_name WHERE ROWNUM <= limit ORDER BY column 340 | // 341 | // # Only Offset 342 | // 343 | // SELECT * FROM table_name WHERE ROWNUM > offset ORDER BY column 344 | func (d Dialector) RewriteLimit11(c clause.Clause, builder clause.Builder) { 345 | limit, ok := c.Expression.(clause.Limit) 346 | if !ok { 347 | return 348 | } 349 | offsetRows := limit.Offset 350 | hasOffset := offsetRows > 0 351 | limitRows, hasLimit := d.getLimitRows(limit) 352 | if !hasOffset && !hasLimit { 353 | return 354 | } 355 | 356 | var stmt *gorm.Statement 357 | if stmt, ok = builder.(*gorm.Statement); !ok { 358 | return 359 | } 360 | 361 | if hasLimit && hasOffset { 362 | // 使用 ROW_NUMBER() 和子查询实现分页查询 363 | if d.RowNumberAliasForOracle11 == "" { 364 | d.RowNumberAliasForOracle11 = "ROW_NUM" 365 | } 366 | subQuerySQL := fmt.Sprintf( 367 | "SELECT * FROM (SELECT T.*, ROW_NUMBER() OVER (ORDER BY 
%s) AS %s FROM (%s) T) WHERE %s BETWEEN %d AND %d", 368 | d.getOrderByColumns(stmt), 369 | d.RowNumberAliasForOracle11, 370 | strings.TrimSpace(stmt.SQL.String()), 371 | d.RowNumberAliasForOracle11, 372 | offsetRows+1, 373 | offsetRows+limitRows, 374 | ) 375 | stmt.SQL.Reset() 376 | stmt.SQL.WriteString(subQuerySQL) 377 | } else if hasLimit { 378 | // 只有 Limit 的情况 379 | d.rewriteRownumStmt(stmt, builder, " <= ", limitRows) 380 | } else { 381 | // 只有 Offset 的情况 382 | d.rewriteRownumStmt(stmt, builder, " > ", offsetRows) 383 | } 384 | } 385 | 386 | func (d Dialector) rewriteRownumStmt(stmt *gorm.Statement, builder clause.Builder, operator string, rows int) { 387 | limitSql := strings.Builder{} 388 | if _, ok := stmt.Clauses["WHERE"]; !ok { 389 | limitSql.WriteString(" WHERE ") 390 | } else { 391 | limitSql.WriteString(" AND ") 392 | } 393 | limitSql.WriteString("ROWNUM") 394 | limitSql.WriteString(operator) 395 | limitSql.WriteString(strconv.Itoa(rows)) 396 | 397 | if _, hasOrderBy := stmt.Clauses["ORDER BY"]; !hasOrderBy { 398 | _, _ = builder.WriteString(limitSql.String()) 399 | } else { 400 | // "ORDER BY" before insert 401 | sqlTmp := strings.Builder{} 402 | sqlOld := stmt.SQL.String() 403 | orderIndex := strings.Index(sqlOld, "ORDER BY") - 1 404 | sqlTmp.WriteString(sqlOld[:orderIndex]) 405 | sqlTmp.WriteString(limitSql.String()) 406 | sqlTmp.WriteString(sqlOld[orderIndex:]) 407 | stmt.SQL = sqlTmp 408 | } 409 | } 410 | 411 | func (d Dialector) getOrderByColumns(stmt *gorm.Statement) string { 412 | if orderByClause, ok := stmt.Clauses["ORDER BY"]; ok { 413 | var orderBy clause.OrderBy 414 | if orderBy, ok = orderByClause.Expression.(clause.OrderBy); ok && len(orderBy.Columns) > 0 { 415 | orderByBuilder := strings.Builder{} 416 | for i, column := range orderBy.Columns { 417 | if i > 0 { 418 | orderByBuilder.WriteString(", ") 419 | } 420 | orderByBuilder.WriteString(column.Column.Name) 421 | if column.Desc { 422 | orderByBuilder.WriteString(" DESC") 423 | } 424 | 
} 425 | return orderByBuilder.String() 426 | } 427 | } 428 | return "NULL" 429 | } 430 | 431 | func (d Dialector) DefaultValueOf(*schema.Field) clause.Expression { 432 | return clause.Expr{SQL: "VALUES (DEFAULT)"} 433 | } 434 | 435 | func (d Dialector) Migrator(db *gorm.DB) gorm.Migrator { 436 | return Migrator{ 437 | Migrator: migrator.Migrator{ 438 | Config: migrator.Config{ 439 | DB: db, 440 | Dialector: d, 441 | CreateIndexAfterCreateTable: true, 442 | }, 443 | }, 444 | } 445 | } 446 | 447 | func (d Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, _ interface{}) { 448 | _, _ = writer.WriteString(":") 449 | _, _ = writer.WriteString(strconv.Itoa(len(stmt.Vars))) 450 | } 451 | 452 | func (d Dialector) QuoteTo(writer clause.Writer, str string) { 453 | if d.NamingCaseSensitive && str != "" { 454 | var ( 455 | underQuoted, selfQuoted bool 456 | continuousBacktick int8 457 | shiftDelimiter int8 458 | ) 459 | 460 | for _, v := range []byte(str) { 461 | switch v { 462 | case '"': 463 | continuousBacktick++ 464 | if continuousBacktick == 2 { 465 | _, _ = writer.WriteString(`""`) 466 | continuousBacktick = 0 467 | } 468 | case '.': 469 | if continuousBacktick > 0 || !selfQuoted { 470 | shiftDelimiter = 0 471 | underQuoted = false 472 | continuousBacktick = 0 473 | _ = writer.WriteByte('"') 474 | } 475 | _ = writer.WriteByte(v) 476 | continue 477 | default: 478 | if shiftDelimiter-continuousBacktick <= 0 && !underQuoted { 479 | _ = writer.WriteByte('"') 480 | underQuoted = true 481 | if selfQuoted = continuousBacktick > 0; selfQuoted { 482 | continuousBacktick -= 1 483 | } 484 | } 485 | 486 | for ; continuousBacktick > 0; continuousBacktick -= 1 { 487 | _, _ = writer.WriteString(`""`) 488 | } 489 | 490 | _ = writer.WriteByte(v) 491 | } 492 | shiftDelimiter++ 493 | } 494 | 495 | if continuousBacktick > 0 && !selfQuoted { 496 | _, _ = writer.WriteString(`""`) 497 | } 498 | _ = writer.WriteByte('"') 499 | } else { 500 | _, _ = writer.WriteString(str) 501 | } 
502 | } 503 | 504 | var numericPlaceholder = regexp.MustCompile(`:(\d+)`) 505 | 506 | func (d Dialector) Explain(sql string, vars ...interface{}) string { 507 | for idx, val := range vars { 508 | switch v := ptrDereference(val).(type) { 509 | case bool: 510 | if v { 511 | vars[idx] = 1 512 | } else { 513 | vars[idx] = 0 514 | } 515 | case go_ora.Clob: 516 | vars[idx] = v.String 517 | } 518 | } 519 | return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...) 520 | } 521 | 522 | func (d Dialector) DataTypeOf(field *schema.Field) string { 523 | delete(field.TagSettings, "RESTRICT") 524 | 525 | var sqlType string 526 | switch field.DataType { 527 | case schema.Bool: 528 | sqlType = "NUMBER(1)" 529 | case schema.Int, schema.Uint: 530 | sqlType = "INTEGER" 531 | if field.Size > 0 && field.Size <= 8 { 532 | sqlType = "SMALLINT" 533 | } 534 | 535 | if field.AutoIncrement { 536 | sqlType += " GENERATED BY DEFAULT AS IDENTITY" 537 | } 538 | case schema.Float: 539 | sqlType = "FLOAT" 540 | case schema.String, "VARCHAR2": 541 | size := field.Size 542 | defaultSize := d.DefaultStringSize 543 | 544 | if size == 0 { 545 | if defaultSize > 0 { 546 | size = int(defaultSize) 547 | } else { 548 | hasIndex := field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUE"] != "" 549 | // TEXT, GEOMETRY or JSON column can't have a default value 550 | if field.PrimaryKey || field.HasDefaultValue || hasIndex { 551 | size = 191 // utf8mb4 552 | } 553 | } 554 | } 555 | 556 | if size > 0 && size <= 4000 { 557 | // 默认情况下 VARCHAR2 可以指定一个不超过 4000 的正整数作为字节长度 558 | if d.VarcharSizeIsCharLength { 559 | if size*3 > 4000 { 560 | sqlType = "CLOB" 561 | } else { 562 | sqlType = fmt.Sprintf("VARCHAR2(%d CHAR)", size) // 字符长度(size * 3) 563 | } 564 | } else { 565 | sqlType = fmt.Sprintf("VARCHAR2(%d)", size) 566 | } 567 | } else { 568 | sqlType = "CLOB" 569 | } 570 | case schema.Time: 571 | sqlType = "TIMESTAMP WITH TIME ZONE" 572 | case schema.Bytes: 573 | sqlType = "BLOB" 574 | default: 575 | 
sqlType = string(field.DataType) 576 | 577 | if strings.EqualFold(sqlType, "text") { 578 | sqlType = "CLOB" 579 | } 580 | 581 | if sqlType == "" { 582 | panic(fmt.Sprintf("invalid sql type %s (%s) for oracle", field.FieldType.Name(), field.FieldType.String())) 583 | } 584 | } 585 | 586 | return sqlType 587 | } 588 | 589 | func (d Dialector) SavePoint(tx *gorm.DB, name string) error { 590 | tx.Exec("SAVEPOINT " + name) 591 | return tx.Error 592 | } 593 | 594 | func (d Dialector) RollbackTo(tx *gorm.DB, name string) error { 595 | tx.Exec("ROLLBACK TO SAVEPOINT " + name) 596 | return tx.Error 597 | } 598 | -------------------------------------------------------------------------------- /oracle_ora.go: -------------------------------------------------------------------------------- 1 | package oracle 2 | 3 | import "github.com/sijms/go-ora/v2" 4 | 5 | type ( 6 | RefCursor struct { 7 | go_ora.RefCursor 8 | } 9 | 10 | DataSet struct { 11 | go_ora.DataSet 12 | } 13 | 14 | Out struct { 15 | go_ora.Out 16 | } 17 | ) 18 | 19 | func (cursor *RefCursor) Query() (dataset *DataSet, err error) { 20 | var d *go_ora.DataSet 21 | if d, err = cursor.RefCursor.Query(); err != nil { 22 | return 23 | } 24 | dataset = &DataSet{DataSet: *d} 25 | return 26 | } 27 | -------------------------------------------------------------------------------- /oracle_ora_test.go: -------------------------------------------------------------------------------- 1 | package oracle 2 | 3 | import ( 4 | "database/sql" 5 | "database/sql/driver" 6 | "encoding/json" 7 | "errors" 8 | "fmt" 9 | "io" 10 | "log" 11 | "testing" 12 | ) 13 | 14 | const ( 15 | procCreateExamplePagingQuery = `-- example procedure 16 | create or replace PROCEDURE PRO_EXAMPLE_PAGING_QUERY ( 17 | BASIC_SQL IN VARCHAR2, -- 基础查询 SQL 18 | ORDER_FIELD IN VARCHAR2, -- 排序字段 19 | PAGE_NUM IN NUMBER, -- 当前页码 20 | PAGE_SIZE IN NUMBER, -- 每页条数 21 | 22 | TOTAL_NUM OUT NUMBER, -- 返回总条数 23 | RES_CURSOR OUT SYS_REFCURSOR -- 返回结果集 24 | ) 25 | AS 26 | 
BEGIN 27 | DECLARE 28 | PAGING_SQL VARCHAR2(4000) := ''; -- 分页查询 SQL 29 | TOTAL_SQL VARCHAR2(4000) := ''; -- 总条数查询 SQL 30 | OFFSET NUMBER(10); -- 分页查询偏移量 31 | BEGIN 32 | -- 查询总条数 33 | TOTAL_SQL := 'SELECT TO_NUMBER(COUNT(*)) FROM (' || BASIC_SQL || ') TB'; 34 | EXECUTE IMMEDIATE TOTAL_SQL INTO TOTAL_NUM; 35 | 36 | -- 分页查询 37 | OFFSET := (PAGE_NUM - 1) * PAGE_SIZE; 38 | PAGING_SQL := 'SELECT * FROM (SELECT T.*, ROW_NUMBER() OVER (ORDER BY ' || ORDER_FIELD || 39 | ') AS ROW_NUM FROM (' || BASIC_SQL || ') T) WHERE ROW_NUM BETWEEN ' || 40 | TO_CHAR(OFFSET+1) || ' AND ' || TO_CHAR(OFFSET+PAGE_SIZE); 41 | 42 | OPEN RES_CURSOR FOR PAGING_SQL; 43 | END; 44 | END PRO_EXAMPLE_PAGING_QUERY;` 45 | ) 46 | 47 | func ExampleRefCursor_Query() { 48 | db, err := dbNamingCase, dbErrors[0] 49 | if err != nil || db == nil { 50 | log.Fatal(err) 51 | } 52 | if err = db.Exec(procCreateExamplePagingQuery).Error; err != nil { 53 | log.Fatal(err) 54 | } 55 | var ( 56 | totalNum uint 57 | resCursor RefCursor 58 | 59 | values = []any{ 60 | "SELECT * FROM USER_TABLES", 61 | "TABLE_NAME", 62 | 1, 63 | 10, 64 | sql.Out{Dest: &totalNum}, 65 | sql.Out{Dest: &resCursor.RefCursor}, 66 | } 67 | ) 68 | // 执行存储过程 69 | if err = db.Exec(` 70 | BEGIN 71 | PRO_EXAMPLE_PAGING_QUERY(:BASIC_SQL, :ORDER_FIELD, :PAGE_NUM, :PAGE_SIZE, :TOTAL_NUM, :RES_CURSOR); 72 | END;`, values...).Error; err != nil { 73 | log.Fatal(err) 74 | } 75 | defer func(cursor *RefCursor) { 76 | _ = cursor.Close() 77 | }(&resCursor) 78 | 79 | // 读取游标 80 | var dataset *DataSet 81 | if dataset, err = resCursor.Query(); err != nil { 82 | log.Fatal(err) 83 | } 84 | defer func(dataset *DataSet) { 85 | _ = dataset.Close() 86 | }(dataset) 87 | 88 | var dataRows []map[string]any 89 | columns := dataset.Columns() 90 | dest := make([]driver.Value, len(columns)) 91 | for { 92 | if err = dataset.Next(dest); err != nil { 93 | if errors.Is(err, io.EOF) { 94 | err = nil 95 | } 96 | break 97 | } 98 | dataRow := make(map[string]any, len(columns)) 99 | for 
i, v := range dest { 100 | dataRow[columns[i]] = v 101 | } 102 | dataRows = append(dataRows, dataRow) 103 | } 104 | if err != nil { 105 | log.Fatal(err) 106 | } 107 | fmt.Println(len(dataRows) > 0) 108 | //Output: true 109 | } 110 | 111 | func TestExecProcedure(t *testing.T) { 112 | db, err := dbNamingCase, dbErrors[0] 113 | if err != nil { 114 | t.Fatal(err) 115 | } 116 | if db == nil { 117 | t.Log("db is nil!") 118 | return 119 | } 120 | if err = db.Exec(procCreateExamplePagingQuery).Error; err != nil { 121 | t.Fatal(err) 122 | } 123 | 124 | var ( 125 | totalNum uint 126 | resCursor RefCursor 127 | 128 | values = []any{ 129 | "SELECT * FROM USER_TABLES", // sql.Named("BASIC_SQL", "SELECT * FROM USER_TABLES"), 130 | "TABLE_NAME", // sql.Named("ORDER_FIELD", "TABLE_NAME"), 131 | 1, // sql.Named("PAGE_NUM", 1), 132 | 10, // sql.Named("PAGE_SIZE", 10), 133 | sql.Out{Dest: &totalNum}, // sql.Named("TOTAL_NUM", sql.Out{Dest: &totalNum}), 134 | sql.Out{Dest: &resCursor.RefCursor}, // sql.Named("RES_CURSOR", sql.Out{Dest: &resCursor.RefCursor}), 135 | } 136 | ) 137 | // 执行存储过程 138 | if err = db.Exec(` 139 | BEGIN 140 | PRO_EXAMPLE_PAGING_QUERY(:BASIC_SQL, :ORDER_FIELD, :PAGE_NUM, :PAGE_SIZE, :TOTAL_NUM, :RES_CURSOR); 141 | END;`, values...).Error; err != nil { 142 | t.Fatal(err) 143 | } 144 | defer func(cursor *RefCursor) { 145 | _ = cursor.Close() 146 | }(&resCursor) 147 | 148 | // 读取游标 149 | var dataset *DataSet 150 | if dataset, err = resCursor.Query(); err != nil { 151 | t.Fatal(err) 152 | } 153 | defer func(dataset *DataSet) { 154 | _ = dataset.Close() 155 | }(dataset) 156 | 157 | var dataRows []map[string]any 158 | columns := dataset.Columns() 159 | dest := make([]driver.Value, len(columns)) 160 | for { 161 | if err = dataset.Next(dest); err != nil { 162 | if errors.Is(err, io.EOF) { 163 | err = nil 164 | } 165 | break 166 | } 167 | dataRow := make(map[string]any, len(columns)) 168 | for i, v := range dest { 169 | dataRow[columns[i]] = v 170 | } 171 | dataRows = 
append(dataRows, dataRow) 172 | } 173 | if err != nil { 174 | t.Fatal(err) 175 | } 176 | got, _ := json.Marshal(dataRows) 177 | t.Logf("got total: %d, got size: %d, got data:\n%s", totalNum, len(dataRows), got) 178 | } 179 | -------------------------------------------------------------------------------- /oracle_test.go: -------------------------------------------------------------------------------- 1 | package oracle 2 | 3 | import ( 4 | "database/sql" 5 | "encoding/json" 6 | "log" 7 | "os" 8 | "reflect" 9 | "strconv" 10 | "strings" 11 | "testing" 12 | "time" 13 | 14 | "gorm.io/gorm" 15 | "gorm.io/gorm/logger" 16 | "gorm.io/gorm/schema" 17 | ) 18 | 19 | var ( 20 | dbNamingCase *gorm.DB 21 | dbIgnoreCase *gorm.DB 22 | 23 | dbErrors = make([]error, 2) 24 | ) 25 | 26 | func init() { 27 | if wait := os.Getenv("GORM_ORA_WAIT_MIN"); wait != "" { 28 | if min, e := strconv.Atoi(wait); e == nil { 29 | log.Println("wait for oracle database initialization to complete...") 30 | time.Sleep(time.Duration(min) * time.Minute) 31 | } 32 | } 33 | var err error 34 | if dbNamingCase, err = openTestConnection(true, true); err != nil { 35 | dbErrors[0] = err 36 | } 37 | if dbIgnoreCase, err = openTestConnection(true, false); err != nil { 38 | dbErrors[1] = err 39 | } 40 | } 41 | 42 | func openTestConnection(ignoreCase, namingCase bool) (db *gorm.DB, err error) { 43 | dsn := getTestDSN() 44 | 45 | db, err = gorm.Open(New(Config{ 46 | DSN: dsn, 47 | IgnoreCase: ignoreCase, 48 | NamingCaseSensitive: namingCase, 49 | }), getTestGormConfig()) 50 | if db != nil && err == nil { 51 | log.Println("open oracle database connection success!") 52 | } 53 | return 54 | } 55 | 56 | func getTestDSN() (dsn string) { 57 | dsn = os.Getenv("GORM_ORA_DSN") 58 | if dsn == "" { 59 | server := os.Getenv("GORM_ORA_SERVER") 60 | port, _ := strconv.Atoi(os.Getenv("GORM_ORA_PORT")) 61 | if server == "" || port < 1 { 62 | return 63 | } 64 | 65 | language := os.Getenv("GORM_ORA_LANG") 66 | if language == "" { 67 | 
language = "SIMPLIFIED CHINESE" 68 | } 69 | territory := os.Getenv("GORM_ORA_TERRITORY") 70 | if territory == "" { 71 | territory = "CHINA" 72 | } 73 | 74 | dsn = BuildUrl(server, port, 75 | os.Getenv("GORM_ORA_SID"), 76 | os.Getenv("GORM_ORA_USER"), 77 | os.Getenv("GORM_ORA_PASS"), 78 | map[string]string{ 79 | "CONNECTION TIMEOUT": "90", 80 | "LANGUAGE": language, 81 | "TERRITORY": territory, 82 | "SSL": "false", 83 | }) 84 | } 85 | return 86 | } 87 | 88 | func getTestGormConfig() *gorm.Config { 89 | logWriter := new(log.Logger) 90 | logWriter.SetOutput(os.Stdout) 91 | 92 | return &gorm.Config{ 93 | Logger: logger.New( 94 | logWriter, 95 | logger.Config{LogLevel: logger.Info}, 96 | ), 97 | DisableForeignKeyConstraintWhenMigrating: false, 98 | IgnoreRelationshipsWhenMigrating: false, 99 | NamingStrategy: schema.NamingStrategy{ 100 | IdentifierMaxLength: 30, 101 | }, 102 | } 103 | } 104 | 105 | func TestCountLimit0(t *testing.T) { 106 | db, err := dbNamingCase, dbErrors[0] 107 | if err != nil { 108 | t.Fatal(err) 109 | } 110 | if db == nil { 111 | t.Log("db is nil!") 112 | return 113 | } 114 | 115 | model := TestTableUser{} 116 | migrator := db.Set("gorm:table_comments", "用户信息表").Migrator() 117 | if migrator.HasTable(model) { 118 | if err = migrator.DropTable(model); err != nil { 119 | t.Fatalf("DropTable() error = %v", err) 120 | } 121 | } 122 | if err = migrator.AutoMigrate(model); err != nil { 123 | t.Fatalf("AutoMigrate() error = %v", err) 124 | } 125 | t.Log("AutoMigrate() success!") 126 | 127 | var count int64 128 | result := db.Model(&model).Limit(-1).Count(&count) 129 | if err = result.Error; err != nil { 130 | t.Fatal(err) 131 | } 132 | t.Logf("Limit(-1) count = %d", count) 133 | 134 | if countSql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { 135 | return tx.Model(&model).Limit(-1).Count(&count) 136 | }); strings.Contains(countSql, "ORDER BY") { 137 | t.Error(`The "count(*)" statement contains the "ORDER BY" clause!`) 138 | } 139 | } 140 | 141 | func TestLimit(t 
*testing.T) { 142 | db, err := dbNamingCase, dbErrors[0] 143 | if err != nil { 144 | t.Fatal(err) 145 | } 146 | if db == nil { 147 | t.Log("db is nil!") 148 | return 149 | } 150 | TestMergeCreate(t) 151 | 152 | type args struct { 153 | offset, limit int 154 | order string 155 | } 156 | tests := []struct { 157 | name string 158 | args args 159 | }{ 160 | {name: "OffsetLimit0", args: args{offset: 0, limit: 0}}, 161 | {name: "Offset10", args: args{offset: 10, limit: 0}}, 162 | {name: "Limit10", args: args{offset: 0, limit: 10}}, 163 | {name: "Offset10Limit10", args: args{offset: 10, limit: 10}}, 164 | {name: "Offset10Limit10Order", args: args{offset: 10, limit: 10, order: `"id"`}}, 165 | {name: "Offset10Limit10OrderDESC", args: args{offset: 10, limit: 10, order: `"id" DESC`}}, 166 | } 167 | 168 | for _, tt := range tests { 169 | t.Run(tt.name, func(t *testing.T) { 170 | var data []TestTableUser 171 | result := db.Model(&TestTableUser{}). 172 | Offset(tt.args.offset). 173 | Limit(tt.args.limit). 174 | Order(tt.args.order). 
175 | Find(&data) 176 | if err = result.Error; err != nil { 177 | t.Fatal(err) 178 | } 179 | dataBytes, _ := json.MarshalIndent(data, "", " ") 180 | t.Logf("Offset(%d) Limit(%d) got size = %d, data = %s", 181 | tt.args.offset, tt.args.limit, len(data), dataBytes) 182 | }) 183 | } 184 | } 185 | 186 | func TestAddSessionParams(t *testing.T) { 187 | db, err := dbIgnoreCase, dbErrors[1] 188 | if err != nil { 189 | t.Fatal(err) 190 | } 191 | if db == nil { 192 | t.Log("db is nil!") 193 | return 194 | } 195 | var sqlDB *sql.DB 196 | if sqlDB, err = db.DB(); err != nil { 197 | t.Fatal(err) 198 | } 199 | type args struct { 200 | params map[string]string 201 | } 202 | tests := []struct { 203 | name string 204 | args args 205 | }{ 206 | {name: "TimeParams", args: args{params: map[string]string{ 207 | "TIME_ZONE": "+08:00", // alter session set TIME_ZONE = '+08:00'; 208 | "NLS_DATE_FORMAT": "YYYY-MM-DD", // alter session set NLS_DATE_FORMAT = 'YYYY-MM-DD'; 209 | "NLS_TIME_FORMAT": "HH24:MI:SSXFF", // alter session set NLS_TIME_FORMAT = 'HH24:MI:SS.FF3'; 210 | "NLS_TIMESTAMP_FORMAT": "YYYY-MM-DD HH24:MI:SSXFF", // alter session set NLS_TIMESTAMP_FORMAT = 'YYYY-MM-DD HH24:MI:SS.FF3'; 211 | "NLS_TIME_TZ_FORMAT": "HH24:MI:SS.FF TZR", // alter session set NLS_TIME_TZ_FORMAT = 'HH24:MI:SS.FF3 TZR'; 212 | "NLS_TIMESTAMP_TZ_FORMAT": "YYYY-MM-DD HH24:MI:SSXFF TZR", // alter session set NLS_TIMESTAMP_TZ_FORMAT = 'YYYY-MM-DD HH24:MI:SS.FF3 TZR'; 213 | }}}, 214 | } 215 | for _, tt := range tests { 216 | t.Run(tt.name, func(t *testing.T) { 217 | //queryTime := `SELECT SYSDATE FROM DUAL` 218 | queryTime := `SELECT CAST(SYSDATE AS VARCHAR(30)) AS D FROM DUAL` 219 | var timeStr string 220 | if err = db.Raw(queryTime).Row().Scan(&timeStr); err != nil { 221 | t.Fatal(err) 222 | } 223 | t.Logf("SYSDATE 1: %s", timeStr) 224 | 225 | var keys []string 226 | if keys, err = AddSessionParams(sqlDB, tt.args.params); err != nil { 227 | t.Fatalf("AddSessionParams() error = %v", err) 228 | } 229 | if err 
= db.Raw(queryTime).Row().Scan(&timeStr); err != nil { 230 | t.Fatal(err) 231 | } 232 | defer DelSessionParams(sqlDB, keys) 233 | t.Logf("SYSDATE 2: %s", timeStr) 234 | t.Logf("keys: %#v", keys) 235 | }) 236 | } 237 | } 238 | 239 | func TestGetStringExpr(t *testing.T) { 240 | db, err := dbNamingCase, dbErrors[0] 241 | if err != nil { 242 | t.Fatal(err) 243 | } 244 | if db == nil { 245 | t.Log("db is nil!") 246 | return 247 | } 248 | 249 | type args struct { 250 | prepareSQL string 251 | value string 252 | quote bool 253 | } 254 | tests := []struct { 255 | name string 256 | args args 257 | wantSQL string 258 | }{ 259 | {"1", args{`SELECT ? AS HELLO FROM DUAL`, "Hi!", true}, `SELECT 'Hi!' AS HELLO FROM DUAL`}, 260 | {"2", args{`SELECT '?' AS HELLO FROM DUAL`, "Hi!", false}, `SELECT 'Hi!' AS HELLO FROM DUAL`}, 261 | {"3", args{`SELECT ? AS HELLO FROM DUAL`, "What's your name?", true}, `SELECT q'[What's your name?]' AS HELLO FROM DUAL`}, 262 | {"4", args{`SELECT '?' AS HELLO FROM DUAL`, "What's your name?", false}, `SELECT 'What''s your name?' AS HELLO FROM DUAL`}, 263 | {"5", args{`SELECT ? AS HELLO FROM DUAL`, "What's up]'?", true}, `SELECT q'{What's up]'?}' AS HELLO FROM DUAL`}, 264 | {"6", args{`SELECT ? AS HELLO FROM DUAL`, "What's up]'}'?", true}, `SELECT q'' AS HELLO FROM DUAL`}, 265 | {"7", args{`SELECT ? AS HELLO FROM DUAL`, "What's up]'}'>'?", true}, `SELECT q'(What's up]'}'>'?)' AS HELLO FROM DUAL`}, 266 | {"8", args{`SELECT ? 
AS HELLO FROM DUAL`, "What's up)'}'>'?", true}, `SELECT q'[What's up)'}'>'?]' AS HELLO FROM DUAL`}, 267 | } 268 | for _, tt := range tests { 269 | t.Run(tt.name, func(t *testing.T) { 270 | gotSQL := db.ToSQL(func(tx *gorm.DB) *gorm.DB { 271 | return tx.Raw(tt.args.prepareSQL, GetStringExpr(tt.args.value, tt.args.quote)) 272 | }) 273 | if !reflect.DeepEqual(gotSQL, tt.wantSQL) { 274 | t.Fatalf("ToSQL = %v, want %v", gotSQL, tt.wantSQL) 275 | } 276 | var results []map[string]interface{} 277 | if err = db.Raw(gotSQL).Find(&results).Error; err != nil { 278 | t.Fatalf("finds all records from raw sql got error: %v", err) 279 | } 280 | t.Log("result:", results) 281 | }) 282 | } 283 | } 284 | 285 | func TestVarcharSizeIsCharLength(t *testing.T) { 286 | dsn := getTestDSN() 287 | 288 | db, err := gorm.Open(New(Config{ 289 | DSN: dsn, 290 | IgnoreCase: true, 291 | NamingCaseSensitive: true, 292 | VarcharSizeIsCharLength: true, 293 | }), getTestGormConfig()) 294 | if db != nil && err == nil { 295 | log.Println("open oracle database connection success!") 296 | } else { 297 | t.Fatal(err) 298 | } 299 | 300 | model := TestTableUserVarcharSize{} 301 | migrator := db.Set("gorm:table_comments", "TestVarcharSizeIsCharLength").Migrator() 302 | if migrator.HasTable(model) { 303 | if err = migrator.DropTable(model); err != nil { 304 | t.Fatalf("DropTable() error = %v", err) 305 | } 306 | } 307 | if err = migrator.AutoMigrate(model); err != nil { 308 | t.Fatalf("AutoMigrate() error = %v", err) 309 | } 310 | t.Log("AutoMigrate() success!") 311 | 312 | type args struct { 313 | value string 314 | } 315 | tests := []struct { 316 | name string 317 | args args 318 | wantErr bool 319 | }{ 320 | {"50", args{strings.Repeat("姓名", 25)}, false}, 321 | {"60", args{strings.Repeat("姓名", 30)}, true}, 322 | } 323 | for _, tt := range tests { 324 | t.Run(tt.name, func(t *testing.T) { 325 | gotErr := db.Create(&TestTableUserVarcharSize{TestTableUser{Name: tt.args.value}}).Error 326 | if (gotErr != nil) != 
tt.wantErr { 327 | t.Error(gotErr) 328 | } else if gotErr != nil { 329 | t.Log(gotErr) 330 | } 331 | }) 332 | } 333 | } 334 | 335 | type TestTableUserVarcharSize struct { 336 | TestTableUser 337 | } 338 | 339 | func (TestTableUserVarcharSize) TableName() string { 340 | return "test_user_varchar_size" 341 | } 342 | -------------------------------------------------------------------------------- /reserved.go: -------------------------------------------------------------------------------- 1 | package oracle 2 | 3 | import ( 4 | "strings" 5 | 6 | "github.com/emirpasic/gods/sets/hashset" 7 | ) 8 | 9 | var ReservedWords = hashset.New(func() []interface{} { 10 | reservedWords := make([]interface{}, len(ReservedWordsList)) 11 | for i, word := range ReservedWordsList { 12 | reservedWords[i] = word 13 | } 14 | return reservedWords 15 | }()...) 16 | 17 | func IsReservedWord(v string) bool { 18 | return ReservedWords.Contains(strings.ToUpper(v)) 19 | } 20 | 21 | var ReservedWordsList = []string{ 22 | "AGGREGATE", "AGGREGATES", "ALL", "ALLOW", "ANALYZE", "ANCESTOR", "AND", "ANY", "AS", "ASC", "AT", "AVG", "BETWEEN", 23 | "BINARY_DOUBLE", "BINARY_FLOAT", "BLOB", "BRANCH", "BUILD", "BY", "BYTE", "CASE", "CAST", "CHAR", "CHILD", "CLEAR", 24 | "CLOB", "COMMIT", "COMPILE", "CONSIDER", "COUNT", "CREATE", "DATATYPE", "DATE", "DATE_MEASURE", "DAY", "DECIMAL", 25 | "DELETE", "DESC", "DESCENDANT", "DIMENSION", "DISALLOW", "DIVISION", "DML", "ELSE", "END", "ESCAPE", "EXECUTE", 26 | "FIRST", "FLOAT", "FOR", "FROM", "HIERARCHIES", "HIERARCHY", "HOUR", "IGNORE", "IN", "INFINITE", "INSERT", 27 | "INTEGER", "INTERVAL", "INTO", "IS", "LAST", "LEAF_DESCENDANT", "LEAVES", "LEVEL", "LIKE", "LIKEC", "LIKE2", 28 | "LIKE4", "LOAD", "LOCAL", "LOG_SPEC", "LONG", "MAINTAIN", "MAX", "MEASURE", "MEASURES", "MEMBER", "MEMBERS", 29 | "MERGE", "MLSLABEL", "MIN", "MINUTE", "MODEL", "MONTH", "NAN", "NCHAR", "NCLOB", "NO", "NONE", "NOT", "NULL", 30 | "NULLS", "NUMBER", "NVARCHAR2", "OF", "OLAP", 
"OLAP_DML_EXPRESSION", "ON", "ONLY", "OPERATOR", "OR", "ORDER", 31 | "OVER", "OVERFLOW", "PARALLEL", "PARENT", "PLSQL", "PRUNE", "RAW", "RELATIVE", "ROOT_ANCESTOR", "ROWID", "SCN", 32 | "SECOND", "SELF", "SERIAL", "SET", "SOLVE", "SOME", "SORT", "SPEC", "SUM", "SYNCH", "TEXT_MEASURE", "THEN", "TIME", 33 | "TIMESTAMP", "TO", "UNBRANCH", "UPDATE", "USING", "VALIDATE", "VALUES", "VARCHAR2", "WHEN", "WHERE", "WITHIN", 34 | "WITH", "YEAR", "ZERO", "ZONE", 35 | } 36 | -------------------------------------------------------------------------------- /update.go: -------------------------------------------------------------------------------- 1 | package oracle 2 | 3 | import ( 4 | "reflect" 5 | "sort" 6 | 7 | "gorm.io/gorm" 8 | "gorm.io/gorm/callbacks" 9 | "gorm.io/gorm/clause" 10 | "gorm.io/gorm/schema" 11 | "gorm.io/gorm/utils" 12 | ) 13 | 14 | func Update(config *callbacks.Config) func(db *gorm.DB) { 15 | supportReturning := utils.Contains(config.UpdateClauses, "RETURNING") 16 | 17 | return func(db *gorm.DB) { 18 | if db.Error != nil { 19 | return 20 | } 21 | 22 | stmt := db.Statement 23 | if stmt == nil { 24 | return 25 | } 26 | 27 | if stmtSchema := stmt.Schema; stmtSchema != nil { 28 | for _, c := range stmtSchema.UpdateClauses { 29 | stmt.AddClause(c) 30 | } 31 | } 32 | 33 | if stmt.SQL.Len() == 0 { 34 | stmt.SQL.Grow(180) 35 | stmt.AddClauseIfNotExists(clause.Update{}) 36 | if _, ok := stmt.Clauses["SET"]; !ok { 37 | if set := ConvertToAssignments(stmt); len(set) != 0 { 38 | defer delete(stmt.Clauses, "SET") 39 | stmt.AddClause(set) 40 | } else { 41 | return 42 | } 43 | } 44 | 45 | stmt.Build(stmt.BuildClauses...) 46 | } 47 | 48 | checkMissingWhereConditions(db) 49 | 50 | if !db.DryRun && db.Error == nil { 51 | for i, val := range stmt.Vars { 52 | // HACK: replace values one by one, assuming its value layout will be the same all the time, i.e. 
aligned 53 | stmt.Vars[i] = convertValue(val) 54 | } 55 | if ok, mode := hasReturning(db, supportReturning); ok { 56 | if rows, err := stmt.ConnPool.QueryContext(stmt.Context, stmt.SQL.String(), stmt.Vars...); db.AddError(err) == nil { 57 | dest := stmt.Dest 58 | stmt.Dest = stmt.ReflectValue.Addr().Interface() 59 | gorm.Scan(rows, db, mode) 60 | stmt.Dest = dest 61 | _ = db.AddError(rows.Close()) 62 | } 63 | } else { 64 | result, err := stmt.ConnPool.ExecContext(stmt.Context, stmt.SQL.String(), stmt.Vars...) 65 | 66 | if db.AddError(err) == nil { 67 | db.RowsAffected, _ = result.RowsAffected() 68 | } 69 | } 70 | } 71 | } 72 | } 73 | 74 | func checkMissingWhereConditions(db *gorm.DB) { 75 | if !db.AllowGlobalUpdate && db.Error == nil { 76 | where, withCondition := db.Statement.Clauses["WHERE"] 77 | if withCondition { 78 | if _, withSoftDelete := db.Statement.Clauses["soft_delete_enabled"]; withSoftDelete { 79 | whereClause, _ := where.Expression.(clause.Where) 80 | withCondition = len(whereClause.Exprs) > 1 81 | } 82 | } 83 | if !withCondition { 84 | _ = db.AddError(gorm.ErrMissingWhereClause) 85 | } 86 | return 87 | } 88 | } 89 | 90 | func hasReturning(tx *gorm.DB, supportReturning bool) (bool, gorm.ScanMode) { 91 | if supportReturning { 92 | if c, ok := tx.Statement.Clauses["RETURNING"]; ok { 93 | returning, _ := c.Expression.(clause.Returning) 94 | if len(returning.Columns) == 0 || (len(returning.Columns) == 1 && returning.Columns[0].Name == "*") { 95 | return true, 0 96 | } 97 | return true, gorm.ScanUpdate 98 | } 99 | } 100 | return false, 0 101 | } 102 | 103 | // ConvertToAssignments convert to update assignments 104 | func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { 105 | var ( 106 | selectColumns, restricted = stmt.SelectAndOmitColumns(false, true) 107 | assignValue func(field *schema.Field, value interface{}) 108 | ) 109 | 110 | switch stmt.ReflectValue.Kind() { 111 | case reflect.Slice, reflect.Array: 112 | assignValue = func(field 
*schema.Field, value interface{}) { 113 | for i := 0; i < stmt.ReflectValue.Len(); i++ { 114 | if stmt.ReflectValue.CanAddr() { 115 | _ = field.Set(stmt.Context, stmt.ReflectValue.Index(i), value) 116 | } 117 | } 118 | } 119 | case reflect.Struct: 120 | assignValue = func(field *schema.Field, value interface{}) { 121 | if stmt.ReflectValue.CanAddr() { 122 | _ = field.Set(stmt.Context, stmt.ReflectValue, value) 123 | } 124 | } 125 | default: 126 | assignValue = func(field *schema.Field, value interface{}) { 127 | } 128 | } 129 | 130 | updatingValue := reflect.ValueOf(stmt.Dest) 131 | for updatingValue.Kind() == reflect.Ptr { 132 | updatingValue = updatingValue.Elem() 133 | } 134 | 135 | if !updatingValue.CanAddr() || stmt.Dest != stmt.Model { 136 | switch stmt.ReflectValue.Kind() { 137 | case reflect.Slice, reflect.Array: 138 | if size := stmt.ReflectValue.Len(); size > 0 { 139 | var isZero bool 140 | for i := 0; i < size; i++ { 141 | for _, field := range stmt.Schema.PrimaryFields { 142 | _, isZero = field.ValueOf(stmt.Context, stmt.ReflectValue.Index(i)) 143 | if !isZero { 144 | break 145 | } 146 | } 147 | } 148 | 149 | if !isZero { 150 | _, primaryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields) 151 | column, values := schema.ToQueryValues("", stmt.Schema.PrimaryFieldDBNames, primaryValues) 152 | stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) 153 | } 154 | } 155 | case reflect.Struct: 156 | for _, field := range stmt.Schema.PrimaryFields { 157 | if value, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero { 158 | stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) 159 | } 160 | } 161 | default: 162 | } 163 | } 164 | 165 | switch value := updatingValue.Interface().(type) { 166 | case map[string]interface{}: 167 | set = make([]clause.Assignment, 0, len(value)) 168 | 169 | keys := make([]string, 
0, len(value)) 170 | for k := range value { 171 | keys = append(keys, k) 172 | } 173 | sort.Strings(keys) 174 | 175 | for _, k := range keys { 176 | kv := value[k] 177 | if _, ok := kv.(*gorm.DB); ok { 178 | kv = []interface{}{kv} 179 | } 180 | 181 | if stmt.Schema != nil { 182 | if field := stmt.Schema.LookUpField(k); field != nil { 183 | if field.DBName != "" { 184 | if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { 185 | set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: kv}) 186 | assignValue(field, value[k]) 187 | } 188 | } else if v, ok := selectColumns[field.Name]; (ok && v) || (!ok && !restricted) { 189 | assignValue(field, value[k]) 190 | } 191 | continue 192 | } 193 | } 194 | 195 | if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { 196 | set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: kv}) 197 | } 198 | } 199 | 200 | if !stmt.SkipHooks && stmt.Schema != nil { 201 | for _, dbName := range stmt.Schema.DBNames { 202 | field := stmt.Schema.LookUpField(dbName) 203 | if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil { 204 | if v, ok := selectColumns[field.DBName]; (ok && v) || !ok { 205 | now := stmt.DB.NowFunc() 206 | assignValue(field, now) 207 | 208 | if field.AutoUpdateTime == schema.UnixNanosecond { 209 | set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()}) 210 | } else if field.AutoUpdateTime == schema.UnixMillisecond { 211 | set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6}) 212 | } else if field.AutoUpdateTime == schema.UnixSecond { 213 | set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()}) 214 | } else { 215 | set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now}) 216 | } 217 | } 218 | } 219 | } 220 | } 221 | default: 222 | 
updatingSchema := stmt.Schema 223 | var isDiffSchema bool 224 | if !updatingValue.CanAddr() || stmt.Dest != stmt.Model { 225 | // different schema 226 | updatingStmt := &gorm.Statement{DB: stmt.DB} 227 | if err := updatingStmt.Parse(stmt.Dest); err == nil { 228 | updatingSchema = updatingStmt.Schema 229 | isDiffSchema = true 230 | } 231 | } 232 | 233 | switch updatingValue.Kind() { 234 | case reflect.Struct: 235 | set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName)) 236 | for _, dbName := range stmt.Schema.DBNames { 237 | if field := updatingSchema.LookUpField(dbName); field != nil { 238 | if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model { 239 | if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) { 240 | value, isZero := field.ValueOf(stmt.Context, updatingValue) 241 | if !stmt.SkipHooks && field.AutoUpdateTime > 0 { 242 | if field.AutoUpdateTime == schema.UnixNanosecond { 243 | value = stmt.DB.NowFunc().UnixNano() 244 | } else if field.AutoUpdateTime == schema.UnixMillisecond { 245 | value = stmt.DB.NowFunc().UnixNano() / 1e6 246 | } else if field.AutoUpdateTime == schema.UnixSecond { 247 | value = stmt.DB.NowFunc().Unix() 248 | } else { 249 | value = stmt.DB.NowFunc() 250 | } 251 | isZero = false 252 | } 253 | 254 | if (ok || !isZero) && field.Updatable { 255 | value = convertCustomType(value) 256 | set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value}) 257 | assignField := field 258 | if isDiffSchema { 259 | if originField := stmt.Schema.LookUpField(dbName); originField != nil { 260 | assignField = originField 261 | } 262 | } 263 | assignValue(assignField, value) 264 | } 265 | } 266 | } else { 267 | if value, isZero := field.ValueOf(stmt.Context, updatingValue); !isZero { 268 | stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) 269 | } 270 | } 271 | } 272 | 
} 273 | default: 274 | _ = stmt.AddError(gorm.ErrInvalidData) 275 | } 276 | } 277 | 278 | return 279 | } 280 | --------------------------------------------------------------------------------