├── .github
│   ├── dependabot.yml
│   └── workflows
│       ├── close_inactive.yml
│       ├── contrib-readme.yml
│       ├── release.yml
│       └── tests.yml
├── .gitignore
├── LICENSE
├── README.md
├── _dev
│   ├── config
│   │   ├── CHANGELOG.tpl.md
│   │   └── chglog.config.yml
│   ├── goland
│   │   ├── go-mod.run.xml
│   │   ├── go-staticcheck.run.xml
│   │   └── go-vulncheck.run.xml
│   └── script
│       ├── done-time-pause.bat
│       ├── go-mod.bat
│       ├── go-staticcheck.bat
│       └── go-vulncheck.bat
├── create.go
├── create_test.go
├── datatypes_json_map_test.go
├── go.mod
├── migrator.go
├── migrator_test.go
├── namer.go
├── oracle.go
├── oracle_ora.go
├── oracle_ora_test.go
├── oracle_test.go
├── reserved.go
└── update.go
/.github/dependabot.yml:
--------------------------------------------------------------------------------
# Dependabot version updates for this repository: GitHub Actions used by the
# workflows and the Go modules listed in the root go.mod are checked daily.
# Full option reference:
# https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file

version: 2
updates:
  # Action versions referenced under .github/workflows
  - package-ecosystem: 'github-actions'
    directory: '/'
    schedule:
      interval: 'daily'
    commit-message:
      prefix: '⬆️ '

  # Go module dependencies (go.mod lives at the repository root)
  - package-ecosystem: 'gomod'
    directory: '/'
    schedule:
      interval: 'daily'
    commit-message:
      prefix: '⬆️ '
20 |
--------------------------------------------------------------------------------
/.github/workflows/close_inactive.yml:
--------------------------------------------------------------------------------
name: Close inactive issues and PRs

on:
  schedule:
    # Daily at 01:30 (GitHub cron schedules are evaluated in UTC)
    - cron: '30 1 * * *'

jobs:
  close-issues:
    runs-on: ubuntu-latest
    permissions:
      issues: write
      pull-requests: write
    steps:
      - uses: actions/stale@v9
        with:
          repo-token: ${{ secrets.GITHUB_TOKEN }}
          operations-per-run: 500
          # Items assigned to this user are never marked stale or closed.
          exempt-assignees: 'iTanken'

          # Issues: stale after 90 idle days, closed 30 days after that.
          days-before-issue-stale: 90
          days-before-issue-close: 30
          stale-issue-label: 'stale'
          stale-issue-message: >-
            This issue is stale because it has been open for 90 days with no activity.
            Remove stale label or comment or this will be closed in 30 days.
          close-issue-message: >-
            This issue was closed because it has been inactive for 30 days since being marked as stale.
            Please leave a comment tagging an assigned team member if you need this issue to be reopened,
            and we will be happy to investigate. Thank you!

          # Pull requests: same 90 + 30 day policy.
          days-before-pr-stale: 90
          days-before-pr-close: 30
          stale-pr-message: >-
            This pull request is stale because it has been open for 90 days with no activity.
            Remove stale label or comment or this will be closed in 30 days.
          close-pr-message: 'This pull request was closed because it has been inactive for 30 days since being marked as stale.'
27 |
--------------------------------------------------------------------------------
/.github/workflows/contrib-readme.yml:
--------------------------------------------------------------------------------
name: A job to automate contrib in readme

on:
  push:
    branches:
      - main

# The action commits the regenerated README (and, when the branch is
# protected, opens a pull request instead), so the token needs write access
# to contents and pull requests even when the repository default is read-only.
permissions:
  contents: write
  pull-requests: write

jobs:
  contrib-readme-job:
    runs-on: ubuntu-latest
    name: A job to automate contribute list in readme.md
    steps:
      - name: Contribute List
        uses: akhilmhdh/contributors-readme-action@v2.3.10
        with:
          # Size of square images in the stack
          image_size: 100
          # Path of the readme file you want to update (no surrounding whitespace)
          readme_path: "README.md"
          # To use github-id instead of profile name
          use_username: true
          # Number of columns in a row
          columns_per_row: 6
          # Type of collaborators options: all/direct/outside
          collaborators: all
          # Commit message of the GitHub action
          commit_message: "👥 自动更新 README 中的贡献者信息"
          # Username on commit
          committer_username: iTanken
          # email id of committer
          committer_email: zixizixi@vip.qq.com
          # check if branch is protected
          auto_detect_branch_protection: true
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
36 |
--------------------------------------------------------------------------------
/.github/workflows/release.yml:
--------------------------------------------------------------------------------
name: release

on:
  push:
    tags:
      - "v*"

permissions:
  contents: write

jobs:
  releaser:
    name: release
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v4
        with:
          # Full history and all tags, so git-chglog can locate the previous tag.
          fetch-depth: 0

      - name: Set up Go
        uses: actions/setup-go@v5
        with:
          go-version: 'stable'
          cache: false

      - name: Get Tag Name
        id: tag
        # Pushed tag without its "refs/tags/v" prefix, e.g. "1.2.3" for v1.2.3.
        run: echo "tagName=${GITHUB_REF#refs/tags/v}" >> "$GITHUB_OUTPUT"

      - name: Generate Changelog
        id: gen_changelog
        # An explicit `shell: bash` runs with `-eo pipefail`, so a git-chglog
        # failure fails this step instead of being masked by `tee`.
        shell: bash
        env:
          # Pass the tag through the environment instead of splicing an
          # expression into the script text.
          TAG_NAME: v${{ steps.tag.outputs.tagName }}
          CHGLOG_CONFIG: ./_dev/config/chglog.config.yml
        run: |
          # Render the changelog once: tee echoes it to the job log and writes
          # the file used as the release body.
          go run github.com/git-chglog/git-chglog/cmd/git-chglog@latest --config "$CHGLOG_CONFIG" "$TAG_NAME" | tee changelog.md

      - name: Create Release
        uses: ncipollo/release-action@v1
        with:
          bodyFile: changelog.md
          token: ${{ secrets.GITHUB_TOKEN }}
44 |
--------------------------------------------------------------------------------
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
# Builds the Go module and runs its tests against an Oracle service container.
# Guide: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go

name: tests

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:

  build:
    runs-on: ubuntu-latest
    timeout-minutes: 30

    services:
      oracle:
        image: truevoly/oracle-12c
        env:
          TZ: 'Asia/Shanghai'
          WEB_CONSOLE: 'false'
        ports:
          # runner port 30256 -> container port 1521
          - '30256:1521'
        # volumes:
        #   - /data/oracle_test:/u01/app/oracle
        options: >-
          --restart=always
          --privileged=true

    steps:
      - uses: actions/checkout@v4

      - name: Set up Go
        uses: actions/setup-go@v5
        with:
          go-version: 'stable'
          cache: false

      - name: Tidy
        run: go mod tidy

      - name: Build
        run: go build -v ./...

      - name: Check oracle port
        # Informational only: reports whether the mapped port is listening,
        # then probes it with tcping.
        run: |
          if ss -tln | grep -q ":30256 "; then
            echo "oracle 服务端口号正常!"
          else
            echo "oracle 服务端口号异常!"
          fi
          go run github.com/cloverstd/tcping@latest 127.0.0.1:30256

      - name: Test
        env:
          GORM_ORA_DSN: 'oracle://system:oracle@localhost:30256/xe?LANGUAGE=SIMPLIFIED+CHINESE&TERRITORY=CHINA'
          # Wait for the database initialization to complete (unit presumably minutes, per the _MIN suffix)
          GORM_ORA_WAIT_MIN: '5'
        run: go test -timeout 20m -v ./...
61 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | vendor/
3 | go.sum
4 | CHANGELOG.md
5 | /test_local/
6 | /go.work*
7 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 iTanken
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GORM Oracle Driver
2 |
3 | ## Description
4 |
GORM Oracle driver for connecting to and managing Oracle databases, based on [CengSin/oracle](https://github.com/CengSin/oracle)
and [sijms/go-ora](https://github.com/sijms/go-ora) (a pure Go Oracle client). *Not recommended for use in a production environment.*
7 |
## Required Dependencies
9 |
10 | - Oracle `11g` + (*`v1.6.3` and earlier versions support only `12c` +*)
11 | - Golang
12 | - `v1.6.1`: `go1.16` +
13 | - `v1.6.2`: `go1.18` +
14 | - gorm `1.24.0` +
15 |
16 | ## Quick Start
17 |
18 | ### How to install
19 |
20 | ```bash
go get github.com/godoes/gorm-oracle
22 | ```
23 |
24 | ### Usage
25 |
26 | ```go
27 | package main
28 |
import (
	oracle "github.com/godoes/gorm-oracle"
	"gorm.io/gorm"
	"gorm.io/gorm/schema"
)
33 |
34 | func main() {
35 | options := map[string]string{
36 | "CONNECTION TIMEOUT": "90",
37 | "LANGUAGE": "SIMPLIFIED CHINESE",
38 | "TERRITORY": "CHINA",
39 | "SSL": "false",
40 | }
41 | // oracle://user:password@127.0.0.1:1521/service
42 | url := oracle.BuildUrl("127.0.0.1", "1521", "service", "user", "password", options)
43 | dialector := oracle.New(oracle.Config{
44 | DSN: url,
		IgnoreCase:              false, // whether query conditions are case-insensitive
46 | NamingCaseSensitive: true, // whether naming is case-sensitive
47 | VarcharSizeIsCharLength: true, // whether VARCHAR type size is character length, defaulting to byte length
48 |
49 | // RowNumberAliasForOracle11 is the alias for ROW_NUMBER() in Oracle 11g, defaulting to ROW_NUM
50 | RowNumberAliasForOracle11: "ROW_NUM",
51 | })
52 | db, err := gorm.Open(dialector, &gorm.Config{
53 | SkipDefaultTransaction: true, // 是否禁用默认在事务中执行单次创建、更新、删除操作
54 | DisableForeignKeyConstraintWhenMigrating: true, // 是否禁止在自动迁移或创建表时自动创建外键约束
55 | // 自定义命名策略
56 | NamingStrategy: schema.NamingStrategy{
57 | NoLowerCase: true, // 是否不自动转换小写表名
58 | IdentifierMaxLength: 30, // Oracle: 30, PostgreSQL:63, MySQL: 64, SQL Server、SQLite、DM: 128
59 | },
60 | PrepareStmt: false, // 创建并缓存预编译语句,启用后可能会报 ORA-01002 错误
61 | CreateBatchSize: 50, // 插入数据默认批处理大小
62 | })
63 | if err != nil {
64 | // panic error or log error info
65 | }
66 |
67 | // set session parameters
68 | if sqlDB, err := db.DB(); err == nil {
69 | _, _ = oracle.AddSessionParams(sqlDB, map[string]string{
70 | "TIME_ZONE": "+08:00", // ALTER SESSION SET TIME_ZONE = '+08:00';
71 | "NLS_DATE_FORMAT": "YYYY-MM-DD", // ALTER SESSION SET NLS_DATE_FORMAT = 'YYYY-MM-DD';
			"NLS_TIME_FORMAT":         "HH24:MI:SSXFF",                // ALTER SESSION SET NLS_TIME_FORMAT = 'HH24:MI:SSXFF';
			"NLS_TIMESTAMP_FORMAT":    "YYYY-MM-DD HH24:MI:SSXFF",     // ALTER SESSION SET NLS_TIMESTAMP_FORMAT = 'YYYY-MM-DD HH24:MI:SSXFF';
			"NLS_TIME_TZ_FORMAT":      "HH24:MI:SS.FF TZR",            // ALTER SESSION SET NLS_TIME_TZ_FORMAT = 'HH24:MI:SS.FF TZR';
			"NLS_TIMESTAMP_TZ_FORMAT": "YYYY-MM-DD HH24:MI:SSXFF TZR", // ALTER SESSION SET NLS_TIMESTAMP_TZ_FORMAT = 'YYYY-MM-DD HH24:MI:SSXFF TZR';
76 | })
77 | }
78 |
	// do something
80 | }
81 |
82 | ```
83 |
84 | ## Questions
85 |
86 |
87 |
88 | ORA-01000: 超出打开游标的最大数
89 |
90 | > ORA-00604: 递归 SQL 级别 1 出现错误
91 | >
92 | > ORA-01000: 超出打开游标的最大数
93 |
```sql
95 | show parameter OPEN_CURSORS;
96 | ```
97 |
98 | ```sql
99 | alter system set OPEN_CURSORS = 1000; -- or bigger
100 | commit;
101 | ```
102 |
103 |
104 |
105 |
106 | ORA-01002: 提取违反顺序
107 |
108 | > 如果重复执行同一查询,第一次查询成功,第二次报 `ORA-01002` 错误,可能是因为启用了 `PrepareStmt`,关闭此配置即可。
109 |
110 | 推荐配置:
111 |
112 | ```go
113 | &gorm.Config{
114 | SkipDefaultTransaction: true, // 是否禁用默认在事务中执行单次创建、更新、删除操作
115 | DisableForeignKeyConstraintWhenMigrating: true, // 是否禁止在自动迁移或创建表时自动创建外键约束
116 | // 自定义命名策略
117 | NamingStrategy: schema.NamingStrategy{
118 | NoLowerCase: true, // 是否不自动转换小写表名
119 | IdentifierMaxLength: 30, // Oracle: 30, PostgreSQL:63, MySQL: 64, SQL Server、SQLite、DM: 128
120 | },
121 | PrepareStmt: false, // 创建并缓存预编译语句,启用后可能会报 ORA-01002 错误
122 | CreateBatchSize: 50, // 插入数据默认批处理大小
123 | }
124 | ```
125 |
126 |
127 |
128 | ## Contributors
129 |
130 |
131 |
195 |
196 |
197 | ## LICENSE
198 |
199 | [MIT license](./LICENSE)
200 |
201 | - Copyright (c) 2020 [Jinzhu](https://github.com/jinzhu)
202 | - Copyright (c) 2020 [Steve Fan](https://github.com/stevefan1999-personal)
203 | - Copyright (c) 2020 [CengSin](https://github.com/CengSin)
204 | - Copyright (c) 2022 [dzwvip](https://github.com/dzwvip)
205 | - Copyright (c) 2022-present [iTanken](https://github.com/iTanken)
206 |
--------------------------------------------------------------------------------
/_dev/config/CHANGELOG.tpl.md:
--------------------------------------------------------------------------------
1 | {{ if .Versions -}}
2 | {{ if .Unreleased.CommitGroups -}}
3 |
4 | ## ⭐ [最新变更]({{ .Info.RepositoryURL }}/compare/{{ $latest := index .Versions 0 }}{{ $latest.Tag.Name }}...main)
5 |
6 | {{ range .Unreleased.CommitGroups -}}
7 | ### {{ .RawTitle }} {{ .Title }}
8 |
9 | {{ range .Commits -}}
10 | {{/* SKIPPING RULES - START */ -}}
11 | {{- if not (contains .Subject " CHANGELOG") -}}
12 | {{- if not (contains .Subject "[ci skip]") -}}
13 | {{- if not (contains .Subject "[skip ci]") -}}
14 | {{- if not (hasPrefix .Subject "Merge pull request ") -}}
15 | {{- if not (hasPrefix .Subject "Merge remote-tracking ") -}}
16 | {{- /* SKIPPING RULES - END */ -}}
17 | - [{{ if .Type }}`{{ .Type }}`{{ end }}{{ .Subject }}]({{ $.Info.RepositoryURL }}/commit/{{ .Hash.Short }}) - `{{ datetime "2006-01-02 15:04" .Committer.Date }}`
18 | {{- if .TrimmedBody }}
19 |
20 |
21 | {{ indent .TrimmedBody 2 }}
22 |
23 |
24 | {{ end -}}
25 | {{/* SKIPPING RULES - START */ -}}
26 | {{ end -}}
27 | {{ end -}}
28 | {{ end -}}
29 | {{ end -}}
30 | {{ end -}}
31 | {{/* SKIPPING RULES - END */ -}}
32 | {{ end -}}
33 | {{ end -}}
34 | {{ else }}
35 | {{- range .Unreleased.Commits -}}
36 | {{/* SKIPPING RULES - START */ -}}
37 | {{- if not (contains .Subject " CHANGELOG") -}}
38 | {{- if not (contains .Subject "[ci skip]") -}}
39 | {{- if not (contains .Subject "[skip ci]") -}}
40 | {{- if not (hasPrefix .Subject "Merge pull request ") -}}
41 | {{- if not (hasPrefix .Subject "Merge remote-tracking ") -}}
42 | {{- /* SKIPPING RULES - END */ -}}
43 | - [{{ if .Type }}`{{ .Type }}`{{ end }}{{ .Subject }}]({{ $.Info.RepositoryURL }}/commit/{{ .Hash.Short }})
44 | {{- if .TrimmedBody }}
45 |
46 |
47 | {{ indent .TrimmedBody 2 }}
48 |
49 |
50 | {{ end -}}
51 | {{/* SKIPPING RULES - START */ -}}
52 | {{ end -}}
53 | {{ end -}}
54 | {{ end -}}
55 | {{ end -}}
56 | {{ end -}}
57 | {{/* SKIPPING RULES - END */ -}}
58 | {{ end -}}
59 | {{ end -}}
60 | {{ end -}}
61 |
62 | {{ range .Versions -}}
63 | ## 🔖 {{ if .Tag.Previous -}}
64 | [`{{ .Tag.Name }}`]({{ $.Info.RepositoryURL }}/compare/{{ .Tag.Previous.Name }}...{{ .Tag.Name }})
65 | {{- else }}`{{ .Tag.Name }}`{{ end }} - `{{ datetime "2006-01-02" .Tag.Date }}`
66 | {{ if .CommitGroups -}}
67 | {{ range .CommitGroups }}
68 | ### {{ .RawTitle }} {{ .Title }}
69 |
70 | {{ range .Commits -}}
71 | {{/* SKIPPING RULES - START */ -}}
72 | {{- if not (contains .Subject " CHANGELOG") -}}
73 | {{- if not (contains .Subject "[ci skip]") -}}
74 | {{- if not (contains .Subject "[skip ci]") -}}
75 | {{- if not (hasPrefix .Subject "Merge pull request ") -}}
76 | {{- if not (hasPrefix .Subject "Merge remote-tracking ") -}}
77 | {{- /* SKIPPING RULES - END */ -}}
78 | - [{{ if .Type }}`{{ .Type }}`{{ end }}{{ .Subject }}]({{ $.Info.RepositoryURL }}/commit/{{ .Hash.Short }})
79 | {{- if .TrimmedBody }}
80 |
81 |
82 | {{ indent .TrimmedBody 2 }}
83 |
84 | {{- end }}
85 | {{/* SKIPPING RULES - START */ -}}
86 | {{ end -}}
87 | {{ end -}}
88 | {{ end -}}
89 | {{ end -}}
90 | {{ end -}}
91 | {{/* SKIPPING RULES - END */ -}}
92 | {{ end -}}
93 | {{ end -}}
94 | {{ else }}{{ range .Commits -}}
95 | {{/* SKIPPING RULES - START */ -}}
96 | {{- if not (contains .Subject " CHANGELOG") -}}
97 | {{- if not (contains .Subject "[ci skip]") -}}
98 | {{- if not (contains .Subject "[skip ci]") -}}
99 | {{- if not (hasPrefix .Subject "Merge pull request ") -}}
100 | {{- if not (hasPrefix .Subject "Merge remote-tracking ") }}
101 | {{/* SKIPPING RULES - END */ -}}
102 | - [{{ if .Type }}`{{ .Type }}`{{ end }}{{ .Subject }}]({{ $.Info.RepositoryURL }}/commit/{{ .Hash.Short }})
103 | {{- if .TrimmedBody }}
104 |
105 |
106 | {{ indent .TrimmedBody 2 }}
107 |
108 |
109 | {{ end -}}
110 | {{/* SKIPPING RULES - START */ -}}
111 | {{ end -}}
112 | {{ end -}}
113 | {{ end -}}
114 | {{ end -}}
115 | {{ end -}}
116 | {{/* SKIPPING RULES - END */ -}}
117 | {{ end -}}
118 | {{ end -}}
119 | {{- if .NoteGroups -}}
120 | {{ range .NoteGroups -}}
121 |
122 | ### {{ .Title }}
123 |
124 | {{ range .Notes }}
125 | {{ .Body }}
126 | {{ end }}
127 | {{ end -}}
128 | {{ end -}}
129 | {{ end -}}
130 |
--------------------------------------------------------------------------------
/_dev/config/chglog.config.yml:
--------------------------------------------------------------------------------
style: gitlab
template: CHANGELOG.tpl.md
info:
  title: 🕔 变更记录
  repository_url: https://github.com/godoes/gorm-oracle
options:
  tag_filter_pattern: '^v'
  sort: "date"
  commits:
    # Only commits whose parsed Type is listed here appear in the changelog.
    filters:
      Type:
        - 🎉
        - ✨
        - 🚩
        - 🔖
        - 🩹
        - 🐛
        - 🚑️
        - 🔒️
        - 🔐
        - 🚨
        - ♻️
        - 👔
        - 🍱
        - 🔨
        - 🔧
        - 🙈
        - 💡
        - 📝
        - ⬆️
        - ⬇️
        - 🚚
        - 🏗️
        - 👽️
        - 💥
        - ⚡️
        - 🚸
        - 🔊
        - 🔇
        - 💄
        - 🎨
        - 🧱
        - 👷
        - 💚
        - 🚀
        - 🧵
        - 🛂
        - 🦺
        - 🔥
        - 🏁
        - 🐧
        - 🗃️
        - 👥
        - 📄
        - 🌐
        - ✅
        - 🧪
        - "Merge"
  commit_groups:
    sort_by: RawTitle
    # Group headings: commit Type (raw title) -> display title.
    title_maps:
      🎉: 初始化
      ✨: 新特性
      🚩: 功能标识
      🔖: 版本发布
      🩹: 修改优化
      🐛: 问题修复
      🚑️: 紧急修复
      🔒️: 安全修复
      🔐: 加密相关
      🚨: 修复警告
      ♻️: 代码重构
      👔: 业务逻辑
      🍱: 资源文件
      🔨: 开发配置
      🔧: 程序配置
      🙈: 忽略配置
      💡: 注释文档
      📝: 说明文档
      ⬆️: 依赖升级
      ⬇️: 依赖降级
      🚚: 移动文件
      🏗️: 架构变更
      👽️: 外部变更
      💥: 突破变更
      ⚡️: 性能优化
      🚸: 用户体验
      🔊: 更新日志
      🔇: 移除日志
      💄: 界面样式
      🎨: 代码格式
      🧱: 基础设施
      👷: 持续集成
      💚: 修复构建
      🚀: 部署
      🧵: 并发
      🛂: 权限
      🦺: 验证
      🔥: 删除
      🏁: Windows
      🐧: Unix
      🗃️: 数据库
      👥: 贡献者
      📄: 许可证
      🌐: 国际化
      ✅: 通过的测试
      🧪: 失败的测试
      "Merge": 合并
  header:
    # pattern: "^(.*)$"
    # https://regex101.com/r/wEipOM/1
    # https://emojipedia.org/emoji/%E2%9C%A8/
    #
    # Group 1 (Type) is "Merge" or one emoji code point, plus an optional
    # U+FE0F variation selector. Several gitmoji above (e.g. ♻️ ⬆️ 🚑️ ⚡️)
    # include that selector in both the filters and title_maps, so it must be
    # captured as part of Type; otherwise Type never matches those keys and the
    # selector leaks into the start of Subject.
    pattern: "^((?:Merge|[\\x{1F300}-\\x{1F5FF}\\x{1F600}-\\x{1F64F}\\x{1F680}-\\x{1F6FF}\\x{1F700}-\\x{1F77F}\\x{1F780}-\\x{1F7FF}\\x{1F800}-\\x{1F8FF}\\x{1F900}-\\x{1F9FF}\\x{1FA00}-\\x{1FA6F}\\x{1FA70}-\\x{1FAFF}\\x{2600}-\\x{26FF}\\x{2700}-\\x{27BF}]|[\\x{2B05}\\x{2B06}\\x{2B07}\\x{23EA}\\x{23EB}])\\x{FE0F}?)(.*)$"
    pattern_maps:
      - Type
      - Subject
  notes:
    keywords:
      - 💥
120 |
--------------------------------------------------------------------------------
/_dev/goland/go-mod.run.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
--------------------------------------------------------------------------------
/_dev/goland/go-staticcheck.run.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
--------------------------------------------------------------------------------
/_dev/goland/go-vulncheck.run.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
--------------------------------------------------------------------------------
/_dev/script/done-time-pause.bat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/godoes/gorm-oracle/bb5f5c6dc7f5d2c17c42608fc33182ef20a53e38/_dev/script/done-time-pause.bat
--------------------------------------------------------------------------------
/_dev/script/go-mod.bat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/godoes/gorm-oracle/bb5f5c6dc7f5d2c17c42608fc33182ef20a53e38/_dev/script/go-mod.bat
--------------------------------------------------------------------------------
/_dev/script/go-staticcheck.bat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/godoes/gorm-oracle/bb5f5c6dc7f5d2c17c42608fc33182ef20a53e38/_dev/script/go-staticcheck.bat
--------------------------------------------------------------------------------
/_dev/script/go-vulncheck.bat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/godoes/gorm-oracle/bb5f5c6dc7f5d2c17c42608fc33182ef20a53e38/_dev/script/go-vulncheck.bat
--------------------------------------------------------------------------------
/create.go:
--------------------------------------------------------------------------------
1 | package oracle
2 |
3 | import (
4 | "reflect"
5 |
6 | "github.com/sijms/go-ora/v2"
7 | "gorm.io/gorm"
8 | "gorm.io/gorm/callbacks"
9 | "gorm.io/gorm/clause"
10 | )
11 |
12 | func Create(db *gorm.DB) {
13 | if db.Error != nil || db.Statement == nil {
14 | return
15 | }
16 |
17 | stmt := db.Statement
18 | stmtSchema := stmt.Schema
19 | if stmtSchema != nil && !stmt.Unscoped {
20 | for _, c := range stmtSchema.CreateClauses {
21 | stmt.AddClause(c)
22 | }
23 | }
24 |
25 | if stmt.SQL.Len() == 0 {
26 | var (
27 | createValues = callbacks.ConvertToCreateValues(stmt)
28 | onConflict, hasConflict = stmt.Clauses["ON CONFLICT"].Expression.(clause.OnConflict)
29 | )
30 |
31 | if hasConflict {
32 | if stmtSchema != nil && len(stmtSchema.PrimaryFields) > 0 {
33 | columnsMap := map[string]bool{}
34 | for _, column := range createValues.Columns {
35 | columnsMap[column.Name] = true
36 | }
37 |
38 | for _, field := range stmtSchema.PrimaryFields {
39 | if _, ok := columnsMap[field.DBName]; !ok {
40 | hasConflict = false
41 | }
42 | }
43 | } else {
44 | hasConflict = false
45 | }
46 | }
47 |
48 | if hasConflict {
49 | MergeCreate(db, onConflict, createValues)
50 | } else {
51 | stmt.AddClauseIfNotExists(clause.Insert{})
52 | stmt.AddClause(clause.Values{Columns: createValues.Columns, Values: [][]interface{}{createValues.Values[0]}})
53 |
54 | stmt.Build("INSERT", "VALUES")
55 | _ = outputInserted(db)
56 | }
57 |
58 | if !db.DryRun && db.Error == nil {
59 | if hasConflict {
60 | for i, val := range stmt.Vars {
61 | // HACK: replace values one by one, assuming its value layout will be the same all the time, i.e. aligned
62 | stmt.Vars[i] = convertValue(val)
63 | }
64 |
65 | result, err := stmt.ConnPool.ExecContext(stmt.Context, stmt.SQL.String(), stmt.Vars...)
66 | if db.AddError(err) == nil {
67 | db.RowsAffected, _ = result.RowsAffected()
68 | // TODO: get merged returning
69 | }
70 | } else {
71 | for idx, values := range createValues.Values {
72 | for i, val := range values {
73 | // HACK: replace values one by one, assuming its value layout will be the same all the time, i.e. aligned
74 | stmt.Vars[i] = convertValue(val)
75 | }
76 |
77 | result, err := stmt.ConnPool.ExecContext(stmt.Context, stmt.SQL.String(), stmt.Vars...)
78 | if db.AddError(err) == nil {
79 | rowsAffected, _ := result.RowsAffected()
80 | db.RowsAffected += rowsAffected
81 |
82 | if stmtSchema != nil && len(stmtSchema.FieldsWithDefaultDBValue) > 0 {
83 | getDefaultValues(db, idx)
84 | }
85 | }
86 | }
87 | }
88 | }
89 | }
90 | }
91 |
92 | func outputInserted(db *gorm.DB) (lenDefaultValue int) {
93 | stmtSchema := db.Statement.Schema
94 | if stmtSchema == nil {
95 | return
96 | }
97 | lenDefaultValue = len(stmtSchema.FieldsWithDefaultDBValue)
98 | if lenDefaultValue > 0 {
99 | columns := make([]clause.Column, lenDefaultValue)
100 | for idx, field := range stmtSchema.FieldsWithDefaultDBValue {
101 | columns[idx] = clause.Column{Name: field.DBName}
102 | }
103 | db.Statement.AddClauseIfNotExists(clause.Returning{Columns: columns})
104 | }
105 | db.Statement.Build("RETURNING")
106 |
107 | _, _ = db.Statement.WriteString(" INTO ")
108 | for idx, field := range stmtSchema.FieldsWithDefaultDBValue {
109 | if idx > 0 {
110 | _ = db.Statement.WriteByte(',')
111 | }
112 |
113 | outVar := go_ora.Out{Dest: reflect.New(field.FieldType).Interface()}
114 | if field.Size > 0 {
115 | outVar.Size = field.Size
116 | }
117 | db.Statement.AddVar(db.Statement, outVar)
118 | }
119 | _, _ = db.Statement.WriteString(" /*-go_ora.Out{}-*/")
120 | return
121 | }
122 |
123 | func MergeCreate(db *gorm.DB, onConflict clause.OnConflict, values clause.Values) {
124 | dummyTable := getDummyTable(db)
125 |
126 | _, _ = db.Statement.WriteString("MERGE INTO ")
127 | db.Statement.WriteQuoted(db.Statement.Table)
128 | _, _ = db.Statement.WriteString(" USING (")
129 |
130 | for idx, value := range values.Values {
131 | if idx > 0 {
132 | _, _ = db.Statement.WriteString(" UNION ALL ")
133 | }
134 |
135 | _, _ = db.Statement.WriteString("SELECT ")
136 | for i, v := range value {
137 | if i > 0 {
138 | _ = db.Statement.WriteByte(',')
139 | }
140 | column := values.Columns[i]
141 | db.Statement.AddVar(db.Statement, v)
142 | _, _ = db.Statement.WriteString(" AS ")
143 | db.Statement.WriteQuoted(column.Name)
144 | }
145 | _, _ = db.Statement.WriteString(" FROM ")
146 | _, _ = db.Statement.WriteString(dummyTable)
147 | }
148 |
149 | _, _ = db.Statement.WriteString(`) `)
150 | db.Statement.WriteQuoted("excluded")
151 | _, _ = db.Statement.WriteString(" ON (")
152 |
153 | var where clause.Where
154 | for _, field := range db.Statement.Schema.PrimaryFields {
155 | where.Exprs = append(where.Exprs, clause.Eq{
156 | Column: clause.Column{Table: db.Statement.Table, Name: field.DBName},
157 | Value: clause.Column{Table: "excluded", Name: field.DBName},
158 | })
159 | }
160 | where.Build(db.Statement)
161 | _ = db.Statement.WriteByte(')')
162 |
163 | if len(onConflict.DoUpdates) > 0 {
164 | _, _ = db.Statement.WriteString(" WHEN MATCHED THEN UPDATE SET ")
165 | onConflict.DoUpdates.Build(db.Statement)
166 | }
167 |
168 | _, _ = db.Statement.WriteString(" WHEN NOT MATCHED THEN INSERT (")
169 |
170 | written := false
171 | for _, column := range values.Columns {
172 | if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.AutoIncrement || db.Statement.Schema.PrioritizedPrimaryField.DBName != column.Name {
173 | if written {
174 | _ = db.Statement.WriteByte(',')
175 | }
176 | written = true
177 | db.Statement.WriteQuoted(column.Name)
178 | }
179 | }
180 |
181 | _, _ = db.Statement.WriteString(") VALUES (")
182 |
183 | written = false
184 | for _, column := range values.Columns {
185 | if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.AutoIncrement || db.Statement.Schema.PrioritizedPrimaryField.DBName != column.Name {
186 | if written {
187 | _ = db.Statement.WriteByte(',')
188 | }
189 | written = true
190 | db.Statement.WriteQuoted(clause.Column{
191 | Table: "excluded",
192 | Name: column.Name,
193 | })
194 | }
195 | }
196 | _, _ = db.Statement.WriteString(")")
197 | }
198 |
199 | func convertValue(val interface{}) interface{} {
200 | val = ptrDereference(val)
201 | switch v := val.(type) {
202 | case bool:
203 | if v {
204 | val = 1
205 | } else {
206 | val = 0
207 | }
208 | case string:
209 | if len(v) > 2000 {
210 | val = go_ora.Clob{String: v, Valid: true}
211 | }
212 | default:
213 | val = convertCustomType(val)
214 | }
215 | return val
216 | }
217 |
218 | func getDummyTable(db *gorm.DB) (dummyTable string) {
219 | switch d := ptrDereference(db.Dialector).(type) {
220 | case Dialector:
221 | dummyTable = d.DummyTableName()
222 | default:
223 | dummyTable = "DUAL"
224 | }
225 | return
226 | }
227 |
228 | func getDefaultValues(db *gorm.DB, idx int) {
229 | if db.Statement.Schema == nil || len(db.Statement.Schema.FieldsWithDefaultDBValue) == 0 {
230 | return
231 | }
232 | insertTo := db.Statement.ReflectValue
233 | switch insertTo.Kind() {
234 | case reflect.Slice, reflect.Array:
235 | insertTo = insertTo.Index(idx)
236 | default:
237 | }
238 | if insertTo.Kind() == reflect.Pointer {
239 | insertTo = insertTo.Elem()
240 | }
241 |
242 | for _, val := range db.Statement.Vars {
243 | switch v := val.(type) {
244 | case go_ora.Out:
245 | switch insertTo.Kind() {
246 | case reflect.Slice, reflect.Array:
247 | for i := insertTo.Len() - 1; i >= 0; i-- {
248 | rv := insertTo.Index(i)
249 | switch reflect.Indirect(rv).Kind() {
250 | case reflect.Struct:
251 | setStructFieldValue(db, rv, v)
252 | default:
253 | }
254 | }
255 | case reflect.Struct:
256 | setStructFieldValue(db, insertTo, v)
257 | default:
258 | }
259 | default:
260 | }
261 | }
262 | }
263 |
264 | func setStructFieldValue(db *gorm.DB, insertTo reflect.Value, out go_ora.Out) {
265 | if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, insertTo); !isZero {
266 | return
267 | }
268 | _ = db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, insertTo, out.Dest))
269 | }
270 |
--------------------------------------------------------------------------------
/create_test.go:
--------------------------------------------------------------------------------
1 | package oracle
2 |
3 | import (
4 | "encoding/json"
5 | "strings"
6 | "testing"
7 | "time"
8 | )
9 |
10 | func TestMergeCreate(t *testing.T) {
11 | db, err := dbNamingCase, dbErrors[0]
12 | if err != nil {
13 | t.Fatal(err)
14 | }
15 | if db == nil {
16 | t.Log("db is nil!")
17 | return
18 | }
19 |
20 | model := TestTableUser{}
21 | migrator := db.Set("gorm:table_comments", "用户信息表").Migrator()
22 | if migrator.HasTable(model) {
23 | if err = migrator.DropTable(model); err != nil {
24 | t.Fatalf("DropTable() error = %v", err)
25 | }
26 | }
27 | if err = migrator.AutoMigrate(model); err != nil {
28 | t.Fatalf("AutoMigrate() error = %v", err)
29 | } else {
30 | t.Log("AutoMigrate() success!")
31 | }
32 |
33 | data := []TestTableUser{
34 | {
35 | UID: "U1",
36 | Name: "Lisa",
37 | Account: "lisa",
38 | Password: "H6aLDNr",
39 | PhoneNumber: "+8616666666666",
40 | Sex: "0",
41 | UserType: 1,
42 | Enabled: true,
43 | },
44 | {
45 | UID: "U1",
46 | Name: "Lisa",
47 | Account: "lisa",
48 | Password: "H6aLDNr",
49 | PhoneNumber: "+8616666666666",
50 | Sex: "0",
51 | UserType: 1,
52 | Enabled: true,
53 | },
54 | {
55 | UID: "U2",
56 | Name: "Daniela",
57 | Account: "daniela",
58 | Password: "Si7l1sRIC79",
59 | PhoneNumber: "+8619999999999",
60 | Sex: "1",
61 | UserType: 1,
62 | Enabled: true,
63 | },
64 | }
65 | t.Run("MergeCreate", func(t *testing.T) {
66 | tx := db.Create(&data)
67 | if err = tx.Error; err != nil {
68 | t.Fatal(err)
69 | }
70 | dataJsonBytes, _ := json.MarshalIndent(data, "", " ")
71 | t.Logf("result: %s", dataJsonBytes)
72 | })
73 | }
74 |
// TestTableUserUnique is a test model with an auto-increment "id" primary key
// and a UNIQUE "uid" column. TestMergeCreateUnique uses it to exercise
// creates against a unique constraint that is not the primary key.
// Column comments are declared through the gorm "comment" tag settings.
type TestTableUserUnique struct {
	ID          uint64     `gorm:"column:id;size:64;not null;autoIncrement:true;autoIncrementIncrement:1;primaryKey;comment:自增 ID" json:"id"`
	UID         string     `gorm:"column:uid;type:varchar(50);comment:用户身份标识;unique" json:"uid"`
	Name        string     `gorm:"column:name;size:50;comment:用户姓名" json:"name"`
	Account     string     `gorm:"column:account;type:varchar(50);comment:登录账号" json:"account"`
	Password    string     `gorm:"column:password;type:varchar(512);comment:登录密码(密文)" json:"password"`
	Email       string     `gorm:"column:email;type:varchar(128);comment:邮箱地址" json:"email"`
	PhoneNumber string     `gorm:"column:phone_number;type:varchar(15);comment:E.164" json:"phoneNumber"`
	Sex         string     `gorm:"column:sex;type:char(1);comment:性别" json:"sex"`
	Birthday    *time.Time `gorm:"column:birthday;->:false;<-:create;comment:生日" json:"birthday,omitempty"`
	UserType    int        `gorm:"column:user_type;size:8;comment:用户类型" json:"userType"`
	Enabled     bool       `gorm:"column:enabled;comment:是否可用" json:"enabled"`
	Remark      string     `gorm:"column:remark;size:1024;comment:备注信息" json:"remark"`
}

// TableName maps TestTableUserUnique to the "test_user_unique" table.
func (TestTableUserUnique) TableName() string { return "test_user_unique" }
93 |
94 | func TestMergeCreateUnique(t *testing.T) {
95 | db, err := dbNamingCase, dbErrors[0]
96 | if err != nil {
97 | t.Fatal(err)
98 | }
99 | if db == nil {
100 | t.Log("db is nil!")
101 | return
102 | }
103 |
104 | model := TestTableUserUnique{}
105 | migrator := db.Set("gorm:table_comments", "用户信息表").Migrator()
106 | if migrator.HasTable(model) {
107 | if err = migrator.DropTable(model); err != nil {
108 | t.Fatalf("DropTable() error = %v", err)
109 | }
110 | }
111 | if err = migrator.AutoMigrate(model); err != nil {
112 | t.Fatalf("AutoMigrate() error = %v", err)
113 | } else {
114 | t.Log("AutoMigrate() success!")
115 | }
116 |
117 | data := []TestTableUserUnique{
118 | {
119 | UID: "U1",
120 | Name: "Lisa",
121 | Account: "lisa",
122 | Password: "H6aLDNr",
123 | PhoneNumber: "+8616666666666",
124 | Sex: "0",
125 | UserType: 1,
126 | Enabled: true,
127 | },
128 | {
129 | UID: "U2",
130 | Name: "Daniela",
131 | Account: "daniela",
132 | Password: "Si7l1sRIC79",
133 | PhoneNumber: "+8619999999999",
134 | Sex: "1",
135 | UserType: 1,
136 | Enabled: true,
137 | },
138 | {
139 | UID: "U2",
140 | Name: "Daniela",
141 | Account: "daniela",
142 | Password: "Si7l1sRIC79",
143 | PhoneNumber: "+8619999999999",
144 | Sex: "1",
145 | UserType: 1,
146 | Enabled: true,
147 | },
148 | }
149 | t.Run("MergeCreateUnique", func(t *testing.T) {
150 | tx := db.Create(&data)
151 | if err = tx.Error; err != nil {
152 | if strings.Contains(err.Error(), "ORA-00001") {
153 | t.Log(err) // ORA-00001: 违反唯一约束条件
154 | var gotData []TestTableUserUnique
155 | tx = db.Where(map[string]interface{}{"uid": []string{"U1", "U2"}}).Find(&gotData)
156 | if err = tx.Error; err != nil {
157 | t.Fatal(err)
158 | } else {
159 | if len(gotData) > 0 {
160 | t.Error("Unique constraint violation, but some data was inserted!")
161 | } else {
162 | t.Log("Unique constraint violation, rolled back!")
163 | }
164 | }
165 | } else {
166 | t.Fatal(err)
167 | }
168 | return
169 | }
170 | dataJsonBytes, _ := json.MarshalIndent(data, "", " ")
171 | t.Logf("result: %s", dataJsonBytes)
172 | })
173 | }
174 |
// testModelOra03146TTC reproduces the table from an ORA-03146 report
// ("invalid buffer length for TTC field"): a manually assigned numeric id,
// several VARCHAR2 columns of up to 4000 characters, and a DATE column.
// Field order determines column order in the created table.
type testModelOra03146TTC struct {
	// Primary key supplied by the caller (autoIncrement disabled).
	Id int64 `gorm:"primaryKey;autoIncrement:false;column:SL_ID;type:uint;size:20;default:0;comment:id" json:"SL_ID"`
	// API name.
	ApiName string `gorm:"column:SL_API_NAME;type:VARCHAR2;size:100;default:null;comment:接口名称" json:"SL_API_NAME"`
	// Raw request payload.
	RawReceive string `gorm:"column:SL_RAW_RECEIVE_JSON;type:VARCHAR2;size:4000;default:null;comment:原始请求参数" json:"SL_RAW_RECEIVE_JSON"`
	// Raw response payload.
	RawSend string `gorm:"column:SL_RAW_SEND_JSON;type:VARCHAR2;size:4000;default:null;comment:原始响应参数" json:"SL_RAW_SEND_JSON"`
	// Processed request payload.
	DealReceive string `gorm:"column:SL_DEAL_RECEIVE_JSON;type:VARCHAR2;size:4000;default:null;comment:处理请求参数" json:"SL_DEAL_RECEIVE_JSON"`
	// Processed response payload.
	DealSend string `gorm:"column:SL_DEAL_SEND_JSON;type:VARCHAR2;size:4000;default:null;comment:处理响应参数" json:"SL_DEAL_SEND_JSON"`
	// HTTP status.
	Code string `gorm:"column:SL_CODE;type:VARCHAR2;size:16;default:null;comment:http状态" json:"SL_CODE"`
	// Creation time, stored as an Oracle DATE.
	CreatedTime time.Time `gorm:"column:SL_CREATED_TIME;type:date;default:null;comment:创建时间" json:"SL_CREATED_TIME"`
}
185 |
186 | func TestOra03146TTC(t *testing.T) {
187 | db, err := dbNamingCase, dbErrors[0]
188 | if err != nil {
189 | t.Fatal(err)
190 | }
191 | if db == nil {
192 | t.Log("db is nil!")
193 | return
194 | }
195 |
196 | model := testModelOra03146TTC{}
197 | migrator := db.Set("gorm:table_comments", "TTC 字段的缓冲区长度无效问题测试表").Migrator()
198 | if migrator.HasTable(model) {
199 | if err = migrator.DropTable(model); err != nil {
200 | t.Fatalf("DropTable() error = %v", err)
201 | }
202 | }
203 | if err = migrator.AutoMigrate(model); err != nil {
204 | t.Fatalf("AutoMigrate() error = %v", err)
205 | } else {
206 | t.Log("AutoMigrate() success!")
207 | }
208 |
209 | // INSERT INTO "T100_SCPTOAPI_LOG" ("SL_ID","SL_API_NAME","SL_RAW_RECEIVE_JSON","SL_RAW_SEND_JSON","SL_DEAL_RECEIVE_JSON","SL_DEAL_SEND_JSON","SL_CODE","SL_CREATED_TIME")
210 | // VALUES (9578529926701056,'/v1/t100/packingNum','11111','11111','11111','11111','111','2024-08-27 18:21:39.495')
211 | data := testModelOra03146TTC{
212 | Id: 9578529926701056,
213 | ApiName: "/v1/t100/packingNum",
214 | RawReceive: "11111",
215 | RawSend: "11111",
216 | DealReceive: "11111",
217 | DealSend: "11111",
218 | Code: "111",
219 | CreatedTime: time.Now(),
220 | }
221 | result := db.Create(&data)
222 | if err = result.Error; err != nil {
223 | t.Fatalf("执行失败:%v", err)
224 | }
225 | t.Log("执行成功,影响行数:", result.RowsAffected)
226 | }
227 |
228 | func TestCreateInBatches(t *testing.T) {
229 | db, err := dbNamingCase, dbErrors[0]
230 | if err != nil {
231 | t.Fatal(err)
232 | }
233 | if db == nil {
234 | t.Log("db is nil!")
235 | return
236 | }
237 |
238 | model := TestTableUser{}
239 | migrator := db.Set("gorm:table_comments", "用户信息表").Migrator()
240 | if migrator.HasTable(model) {
241 | if err = migrator.DropTable(model); err != nil {
242 | t.Fatalf("DropTable() error = %v", err)
243 | }
244 | }
245 | if err = migrator.AutoMigrate(model); err != nil {
246 | t.Fatalf("AutoMigrate() error = %v", err)
247 | } else {
248 | t.Log("AutoMigrate() success!")
249 | }
250 |
251 | data := []TestTableUser{
252 | {UID: "U1", Name: "Lisa", Account: "lisa", Password: "H6aLDNr", PhoneNumber: "+8616666666666", Sex: "0", UserType: 1, Enabled: true},
253 | {UID: "U2", Name: "Daniela", Account: "daniela", Password: "Si7l1sRIC79", PhoneNumber: "+8619999999999", Sex: "1", UserType: 1, Enabled: true},
254 | {UID: "U3", Name: "Tom", Account: "tom", Password: "********", PhoneNumber: "+8618888888888", Sex: "1", UserType: 1, Enabled: true},
255 | {UID: "U4", Name: "James", Account: "james", Password: "********", PhoneNumber: "+8617777777777", Sex: "1", UserType: 2, Enabled: true},
256 | {UID: "U5", Name: "John", Account: "john", Password: "********", PhoneNumber: "+8615555555555", Sex: "1", UserType: 1, Enabled: true},
257 | }
258 | t.Run("CreateInBatches", func(t *testing.T) {
259 | tx := db.CreateInBatches(&data, 2)
260 | if err = tx.Error; err != nil {
261 | t.Fatal(err)
262 | }
263 | dataJsonBytes, _ := json.MarshalIndent(data, "", " ")
264 | t.Logf("result: %s", dataJsonBytes)
265 | })
266 | }
267 |
--------------------------------------------------------------------------------
/datatypes_json_map_test.go:
--------------------------------------------------------------------------------
1 | package oracle
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "database/sql/driver"
7 | "encoding/json"
8 | "errors"
9 | "fmt"
10 | "strings"
11 |
12 | "gorm.io/gorm"
13 | "gorm.io/gorm/clause"
14 | "gorm.io/gorm/schema"
15 | )
16 |
17 | // JSONMap defined JSON data type, need to implement driver.Valuer, sql.Scanner interface
18 | type JSONMap map[string]interface{}
19 |
20 | // Value return json value, implement driver.Valuer interface
21 | //
22 | //goland:noinspection GoMixedReceiverTypes
23 | func (m JSONMap) Value() (driver.Value, error) {
24 | if m == nil {
25 | return nil, nil
26 | }
27 | ba, err := m.MarshalJSON()
28 | return string(ba), err
29 | }
30 |
31 | // Scan value into Jsonb, implements sql.Scanner interface
32 | //
33 | //goland:noinspection GoMixedReceiverTypes
34 | func (m *JSONMap) Scan(val interface{}) error {
35 | if val == nil {
36 | *m = make(JSONMap)
37 | return nil
38 | }
39 | var ba []byte
40 | switch v := val.(type) {
41 | case []byte:
42 | ba = v
43 | case string:
44 | ba = []byte(v)
45 | default:
46 | return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", val))
47 | }
48 | t := map[string]interface{}{}
49 | rd := bytes.NewReader(ba)
50 | decoder := json.NewDecoder(rd)
51 | decoder.UseNumber()
52 | err := decoder.Decode(&t)
53 | *m = t
54 | return err
55 | }
56 |
57 | // MarshalJSON to output non base64 encoded []byte
58 | //
59 | //goland:noinspection GoMixedReceiverTypes
60 | func (m JSONMap) MarshalJSON() ([]byte, error) {
61 | if m == nil {
62 | return []byte("null"), nil
63 | }
64 | t := (map[string]interface{})(m)
65 | return json.Marshal(t)
66 | }
67 |
68 | // UnmarshalJSON to deserialize []byte
69 | //
70 | //goland:noinspection GoMixedReceiverTypes
71 | func (m *JSONMap) UnmarshalJSON(b []byte) error {
72 | t := map[string]interface{}{}
73 | err := json.Unmarshal(b, &t)
74 | *m = t
75 | return err
76 | }
77 |
78 | // GormDataType gorm common data type
79 | //
80 | //goland:noinspection GoMixedReceiverTypes
81 | func (m JSONMap) GormDataType() string {
82 | return "jsonmap"
83 | }
84 |
85 | // GormDBDataType gorm db data type
86 | //
87 | //goland:noinspection GoMixedReceiverTypes
88 | func (JSONMap) GormDBDataType(db *gorm.DB, field *schema.Field) string {
89 | switch db.Dialector.Name() {
90 | case "sqlite":
91 | return "JSON"
92 | case "mysql":
93 | return "JSON"
94 | case "postgres":
95 | return "JSONB"
96 | case "sqlserver":
97 | return "NVARCHAR(MAX)"
98 | case "oracle":
99 | //return "BLOB"
100 | // BLOB is only supported in Oracle databases version 12c r12.2.0.1.0 and above.
101 | // to support lower versions of Oracle databases, it is recommended to use CLOB.
102 | // see also:
103 | // https://stackoverflow.com/questions/43603905/oracle-12c-error-getting-while-create-blob-column-table-with-json-type
104 | return "CLOB"
105 | default:
106 | return getGormTypeFromTag(field)
107 | }
108 | }
109 |
110 | func getGormTypeFromTag(field *schema.Field) (dataType string) {
111 | if field != nil {
112 | if val, ok := field.TagSettings["TYPE"]; ok {
113 | dataType = strings.ToLower(val)
114 | }
115 | }
116 | return
117 | }
118 |
119 | //goland:noinspection GoMixedReceiverTypes
120 | func (m JSONMap) GormValue(_ context.Context, _ *gorm.DB) clause.Expr {
121 | data, _ := m.MarshalJSON()
122 | return gorm.Expr("?", string(data))
123 | }
124 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/godoes/gorm-oracle
2 |
3 | go 1.18
4 |
5 | require (
6 | github.com/emirpasic/gods v1.18.1
7 | github.com/sijms/go-ora/v2 v2.9.0
8 | gorm.io/gorm v1.30.0
9 | )
10 |
11 | require (
12 | github.com/jinzhu/inflection v1.0.0 // indirect
13 | github.com/jinzhu/now v1.1.5 // indirect
14 | golang.org/x/text v0.20.0 // indirect
15 | )
16 |
17 | exclude (
18 | github.com/sijms/go-ora/v2 v2.8.8 // ORA-03137: 来自客户机的格式错误的 TTC 包被拒绝: [opiexe: protocol violation]
19 | github.com/sijms/go-ora/v2 v2.8.9 // has bug
20 | )
21 |
22 | retract (
23 | v1.5.12
24 | v1.5.1
25 | v1.5.0
26 | )
27 |
--------------------------------------------------------------------------------
/migrator.go:
--------------------------------------------------------------------------------
1 | package oracle
2 |
3 | import (
4 | "database/sql"
5 | "fmt"
6 | "strconv"
7 | "strings"
8 |
9 | "gorm.io/gorm"
10 | "gorm.io/gorm/clause"
11 | "gorm.io/gorm/migrator"
12 | "gorm.io/gorm/schema"
13 | )
14 |
// Migrator implements gorm's migrator interface for Oracle. It embeds gorm's
// generic migrator.Migrator and overrides the methods below that need
// Oracle-specific SQL or dictionary views.
type Migrator struct {
	migrator.Migrator
}
19 |
20 | // AutoMigrate 自动迁移模型为表结构
21 | //
22 | // // 迁移并设置单个表注释
23 | // db.Set("gorm:table_comments", "用户信息表").AutoMigrate(&User{})
24 | //
25 | // // 迁移并设置多个表注释
26 | // db.Set("gorm:table_comments", []string{"用户信息表", "公司信息表"}).AutoMigrate(&User{}, &Company{})
27 | func (m Migrator) AutoMigrate(dst ...interface{}) error {
28 | if err := m.Migrator.AutoMigrate(dst...); err != nil {
29 | return err
30 | }
31 | // set table comment
32 | if tableComments, ok := m.DB.Get("gorm:table_comments"); ok {
33 | var comments []string
34 | switch c := tableComments.(type) {
35 | case string:
36 | comments = append(comments, c)
37 | case []string:
38 | comments = c
39 | default:
40 | return nil
41 | }
42 | for i := 0; i < len(dst) && i < len(comments); i++ {
43 | value := dst[i]
44 | tx := m.DB.Session(&gorm.Session{})
45 | if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
46 | return tx.Exec("COMMENT ON TABLE ? IS '?'", m.CurrentTable(stmt), GetStringExpr(comments[i])).Error
47 | }); err != nil {
48 | return err
49 | }
50 | }
51 | }
52 | return nil
53 | }
54 |
55 | // FullDataTypeOf returns field's db full data type
56 | func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
57 | expr.SQL = m.DataTypeOf(field)
58 |
59 | if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") {
60 | if field.DefaultValueInterface != nil {
61 | defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}}
62 | m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface)
63 | expr.SQL += " DEFAULT " + m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface)
64 | } else if field.DefaultValue != "(-)" {
65 | expr.SQL += " DEFAULT " + field.DefaultValue
66 | }
67 | }
68 |
69 | if field.NotNull {
70 | expr.SQL += " NOT NULL"
71 | }
72 |
73 | // see https://github.com/go-gorm/gorm/pull/6822
74 | //if field.Unique {
75 | // expr.SQL += " UNIQUE"
76 | //}
77 |
78 | return
79 | }
80 |
81 | // CurrentDatabase returns current database name
82 | func (m Migrator) CurrentDatabase() (name string) {
83 | _ = m.DB.Raw(
84 | fmt.Sprintf(`SELECT ORA_DATABASE_NAME as "Current Database" FROM %s`, m.Dialector.(Dialector).DummyTableName()),
85 | ).Row().Scan(&name)
86 | return
87 | }
88 |
89 | // GetTypeAliases return database type aliases
90 | func (m Migrator) GetTypeAliases(databaseTypeName string) (types []string) {
91 | switch databaseTypeName {
92 | case "blob", "raw", "longraw", "ocibloblocator", "ocifilelocator":
93 | types = append(types, "blob", "raw", "longraw", "ocibloblocator", "ocifilelocator")
94 | case "clob", "nclob", "longvarchar", "ocicloblocator":
95 | types = append(types, "clob", "nclob", "longvarchar", "ocicloblocator")
96 | case "char", "nchar", "varchar", "varchar2", "nvarchar2":
97 | types = append(types, "char", "nchar", "varchar", "varchar2", "nvarchar2")
98 | case "number", "integer", "smallint":
99 | types = append(types, "number", "integer", "smallint")
100 | case "decimal", "numeric", "ibfloat", "ibdouble":
101 | types = append(types, "decimal", "numeric", "ibfloat", "ibdouble")
102 | case "timestampdty", "timestamp", "date":
103 | types = append(types, "timestampdty", "timestamp", "date")
104 | case "timestamptz_dty", "timestamp with time zone":
105 | types = append(types, "timestamptz_dty", "timestamp with time zone")
106 | case "timestampltz_dty", "timestampeltz", "timestamp with local time zone":
107 | types = append(types, "timestampltz_dty", "timestampeltz", "timestamp with local time zone")
108 | default:
109 | return
110 | }
111 | return
112 | }
113 |
114 | // CreateTable create table in database for values
115 | func (m Migrator) CreateTable(values ...interface{}) (err error) {
116 | ignoreCase := !m.Dialector.(Dialector).NamingCaseSensitive
117 | for _, value := range values {
118 | if ignoreCase {
119 | _ = m.TryQuotifyReservedWords(value)
120 | }
121 | _ = m.TryRemoveOnUpdate(value)
122 | }
123 | if err = m.Migrator.CreateTable(values...); err != nil {
124 | return
125 | }
126 | // set column comment
127 | for _, value := range m.ReorderModels(values, false) {
128 | if err = m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
129 | if stmt.Schema != nil {
130 | for _, fieldName := range stmt.Schema.DBNames {
131 | field := stmt.Schema.FieldsByDBName[fieldName]
132 | if err = m.setCommentForColumn(field, stmt); err != nil {
133 | return
134 | }
135 | }
136 | }
137 | return
138 | }); err != nil {
139 | return
140 | }
141 | }
142 | return
143 | }
144 |
145 | func (m Migrator) setCommentForColumn(field *schema.Field, stmt *gorm.Statement) (err error) {
146 | if field == nil || stmt == nil || field.Comment == "" {
147 | return
148 | }
149 | table := m.CurrentTable(stmt)
150 | column := clause.Column{Name: field.DBName}
151 | comment := GetStringExpr(field.Comment)
152 | err = m.DB.Exec("COMMENT ON COLUMN ?.? IS '?'", table, column, comment).Error
153 | return
154 | }
155 |
156 | // DropTable drop table for values
157 | //
158 | //goland:noinspection SqlNoDataSourceInspection
159 | func (m Migrator) DropTable(values ...interface{}) error {
160 | values = m.ReorderModels(values, false)
161 | for i := len(values) - 1; i >= 0; i-- {
162 | value := values[i]
163 | tx := m.DB.Session(&gorm.Session{})
164 | if m.HasTable(value) {
165 | if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
166 | return tx.Exec("DROP TABLE ? CASCADE CONSTRAINTS", clause.Table{Name: stmt.Table}).Error
167 | }); err != nil {
168 | return err
169 | }
170 | }
171 | }
172 | return nil
173 | }
174 |
175 | // HasTable returns table exists or not for value, value could be a struct or string
176 | func (m Migrator) HasTable(value interface{}) bool {
177 | var count int64
178 |
179 | _ = m.RunWithValue(value, func(stmt *gorm.Statement) error {
180 | if ownerName, tableName := m.getSchemaTable(stmt); ownerName != "" {
181 | return m.DB.Raw("SELECT COUNT(*) FROM ALL_TABLES WHERE OWNER = ? and TABLE_NAME = ?", ownerName, tableName).Row().Scan(&count)
182 | } else {
183 | return m.DB.Raw("SELECT COUNT(*) FROM USER_TABLES WHERE TABLE_NAME = ?", tableName).Row().Scan(&count)
184 | }
185 | })
186 |
187 | return count > 0
188 | }
189 |
190 | func (m Migrator) getSchemaTable(stmt *gorm.Statement) (ownerName, tableName string) {
191 | if stmt == nil {
192 | return
193 | }
194 | if stmt.Schema == nil {
195 | tableName = stmt.Table
196 | } else {
197 | tableName = stmt.Schema.Table
198 | if strings.Contains(tableName, ".") {
199 | ownerTable := strings.Split(tableName, ".")
200 | ownerName, tableName = ownerTable[0], ownerTable[1]
201 | }
202 | }
203 | return
204 | }
205 |
206 | // ColumnTypes return columnTypes []gorm.ColumnType and execErr error
207 | func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
208 | columnTypes := make([]gorm.ColumnType, 0)
209 | execErr := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
210 | _, tableName := m.getSchemaTable(stmt)
211 | rows, err := m.DB.Session(&gorm.Session{}).Table(tableName).Where("ROWNUM = 1").Rows()
212 | if err != nil {
213 | return err
214 | }
215 |
216 | defer func() {
217 | err = rows.Close()
218 | }()
219 |
220 | var rawColumnTypes []*sql.ColumnType
221 | rawColumnTypes, err = rows.ColumnTypes()
222 | if err != nil {
223 | return err
224 | }
225 |
226 | ignoreCase := !m.Dialector.(Dialector).NamingCaseSensitive
227 | for _, c := range rawColumnTypes {
228 | columnType := migrator.ColumnType{SQLColumnType: c}
229 | if ignoreCase && IsReservedWord(c.Name()) {
230 | columnType.NameValue = sql.NullString{
231 | String: strconv.Quote(c.Name()),
232 | Valid: true,
233 | }
234 | }
235 | columnTypes = append(columnTypes, columnType)
236 | }
237 |
238 | return
239 | })
240 |
241 | return columnTypes, execErr
242 | }
243 |
244 | // RenameTable rename table from oldName to newName
245 | func (m Migrator) RenameTable(oldName, newName interface{}) (err error) {
246 | resolveTable := func(name interface{}) (result string, err error) {
247 | if v, ok := name.(string); ok {
248 | result = v
249 | } else {
250 | stmt := &gorm.Statement{DB: m.DB}
251 | if err = stmt.Parse(name); err == nil {
252 | result = stmt.Table
253 | }
254 | }
255 | return
256 | }
257 |
258 | var oldTable, newTable string
259 |
260 | if oldTable, err = resolveTable(oldName); err != nil {
261 | return
262 | }
263 |
264 | if newTable, err = resolveTable(newName); err != nil {
265 | return
266 | }
267 |
268 | if !m.HasTable(oldTable) {
269 | return
270 | }
271 |
272 | return m.DB.Exec("RENAME TABLE ? TO ?",
273 | clause.Table{Name: oldTable},
274 | clause.Table{Name: newTable},
275 | ).Error
276 | }
277 |
278 | // GetTables returns tables under the current user database
279 | func (m Migrator) GetTables() (tableList []string, err error) {
280 | err = m.DB.Raw(`SELECT TABLE_NAME FROM USER_TABLES
281 | WHERE TABLESPACE_NAME IS NOT NULL AND TABLESPACE_NAME <> 'SYSAUX'
282 | AND TABLE_NAME NOT LIKE 'AQ$%' AND TABLE_NAME NOT LIKE 'MVIEW$%' AND TABLE_NAME NOT LIKE 'ROLLING$%'
283 | AND TABLE_NAME NOT IN ('HELP', 'SQLPLUS_PRODUCT_PROFILE', 'LOGSTDBY$PARAMETERS', 'LOGMNRGGC_GTCS', 'LOGMNRGGC_GTLO', 'LOGMNR_PARAMETER$', 'LOGMNR_SESSION$', 'SCHEDULER_JOB_ARGS_TBL', 'SCHEDULER_PROGRAM_ARGS_TBL')
284 | `).Scan(&tableList).Error
285 | return
286 | }
287 |
288 | // AddColumn create "name" column for value
289 | func (m Migrator) AddColumn(value interface{}, name string) (err error) {
290 | if err = m.Migrator.AddColumn(value, name); err != nil {
291 | return err
292 | }
293 | // set column comment
294 | err = m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
295 | if field := stmt.Schema.LookUpField(name); field != nil {
296 | if err = m.setCommentForColumn(field, stmt); err != nil {
297 | return
298 | }
299 | }
300 | return
301 | })
302 | return
303 | }
304 |
305 | // DropColumn drop value's "name" column
306 | func (m Migrator) DropColumn(value interface{}, name string) error {
307 | return m.Migrator.DropColumn(value, name)
308 | }
309 |
310 | // AlterColumn alter value's "field" column's type based on schema definition
311 | //
312 | //goland:noinspection SqlNoDataSourceInspection
313 | func (m Migrator) AlterColumn(value interface{}, field string) error {
314 | if !m.HasColumn(value, field) {
315 | return nil
316 | }
317 |
318 | return m.RunWithValue(value, func(stmt *gorm.Statement) error {
319 | if field := stmt.Schema.LookUpField(field); field != nil {
320 | _, tableName := m.getSchemaTable(stmt)
321 | return m.DB.Exec(
322 | "ALTER TABLE ? MODIFY ? ?",
323 | clause.Table{Name: tableName},
324 | clause.Column{Name: field.DBName},
325 | m.AlterDataTypeOf(stmt, field),
326 | ).Error
327 | }
328 | return fmt.Errorf("failed to look up field with name: %s", field)
329 | })
330 | }
331 |
332 | // HasColumn check has column "field" for value or not
333 | func (m Migrator) HasColumn(value interface{}, field string) bool {
334 | var count int64
335 | return m.RunWithValue(value, func(stmt *gorm.Statement) error {
336 | if ownerName, tableName := m.getSchemaTable(stmt); ownerName != "" {
337 | return m.DB.Raw("SELECT COUNT(*) FROM ALL_TAB_COLUMNS WHERE OWNER = ? and TABLE_NAME = ? AND COLUMN_NAME = ?", ownerName, tableName, field).Row().Scan(&count)
338 | } else {
339 | return m.DB.Raw("SELECT COUNT(*) FROM USER_TAB_COLUMNS WHERE TABLE_NAME = ? AND COLUMN_NAME = ?", tableName, field).Row().Scan(&count)
340 | }
341 |
342 | }) == nil && count > 0
343 | }
344 |
345 | // MigrateColumn migrate column
346 | func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) (err error) {
347 | if err = m.Migrator.MigrateColumn(value, field, columnType); err != nil {
348 | return
349 | }
350 |
351 | return m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
352 | var description string
353 | if ownerName, tableName := m.getSchemaTable(stmt); ownerName != "" {
354 | _ = m.DB.Raw(
355 | "SELECT COMMENTS FROM ALL_COL_COMMENTS WHERE OWNER = ? AND TABLE_NAME = ? AND COLUMN_NAME = ?",
356 | ownerName, tableName, field.DBName,
357 | ).Row().Scan(&description)
358 | } else {
359 | _ = m.DB.Raw(
360 | "SELECT COMMENTS FROM USER_COL_COMMENTS WHERE TABLE_NAME = ? AND COLUMN_NAME = ?",
361 | tableName, field.DBName,
362 | ).Row().Scan(&description)
363 | }
364 | if comment := field.Comment; comment != "" && comment != description {
365 | if err = m.setCommentForColumn(field, stmt); err != nil {
366 | return
367 | }
368 | }
369 | return
370 | })
371 | }
372 |
373 | func (m Migrator) AlterDataTypeOf(stmt *gorm.Statement, field *schema.Field) (expr clause.Expr) {
374 | expr.SQL = m.DataTypeOf(field)
375 |
376 | var nullable = ""
377 | if ownerName, tableName := m.getSchemaTable(stmt); ownerName != "" {
378 | _ = m.DB.Raw("SELECT NULLABLE FROM ALL_TAB_COLUMNS WHERE OWNER = ? and TABLE_NAME = ? AND COLUMN_NAME = ?", ownerName, tableName, field.DBName).Row().Scan(&nullable)
379 | } else {
380 | _ = m.DB.Raw("SELECT NULLABLE FROM USER_TAB_COLUMNS WHERE TABLE_NAME = ? AND COLUMN_NAME = ?", tableName, field.DBName).Row().Scan(&nullable)
381 | }
382 |
383 | if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") {
384 | if field.DefaultValueInterface != nil {
385 | defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}}
386 | m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface)
387 | expr.SQL += " DEFAULT " + m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface)
388 | } else if field.DefaultValue != "(-)" {
389 | expr.SQL += " DEFAULT " + field.DefaultValue
390 | }
391 | }
392 |
393 | if field.NotNull && nullable == "Y" {
394 | expr.SQL += " NOT NULL"
395 | }
396 | if field.Unique {
397 | expr.SQL += " UNIQUE"
398 | }
399 | return
400 | }
401 |
402 | // CreateConstraint create constraint
403 | func (m Migrator) CreateConstraint(value interface{}, name string) error {
404 | _ = m.TryRemoveOnUpdate(value)
405 | return m.Migrator.CreateConstraint(value, name)
406 | }
407 |
408 | // DropConstraint drop constraint
409 | //
410 | //goland:noinspection SqlNoDataSourceInspection
411 | func (m Migrator) DropConstraint(value interface{}, name string) error {
412 | return m.RunWithValue(value, func(stmt *gorm.Statement) error {
413 | _, tableName := m.getSchemaTable(stmt)
414 | for _, chk := range stmt.Schema.ParseCheckConstraints() {
415 | if chk.Name == name {
416 | return m.DB.Exec(
417 | "ALTER TABLE ? DROP CHECK ?",
418 | clause.Table{Name: tableName}, clause.Column{Name: name},
419 | ).Error
420 | }
421 | }
422 |
423 | return m.DB.Exec(
424 | "ALTER TABLE ? DROP CONSTRAINT ?",
425 | clause.Table{Name: tableName}, clause.Column{Name: name},
426 | ).Error
427 | })
428 | }
429 |
430 | // HasConstraint check has constraint or not
431 | func (m Migrator) HasConstraint(value interface{}, name string) bool {
432 | var count int64
433 | return m.RunWithValue(value, func(stmt *gorm.Statement) error {
434 | return m.DB.Raw(
435 | "SELECT COUNT(*) FROM USER_CONSTRAINTS WHERE TABLE_NAME = ? AND CONSTRAINT_NAME = ?", stmt.Table, name,
436 | ).Row().Scan(&count)
437 | }) == nil && count > 0
438 | }
439 |
440 | // DropIndex drop index "name"
441 | func (m Migrator) DropIndex(value interface{}, name string) error {
442 | return m.RunWithValue(value, func(stmt *gorm.Statement) error {
443 | if idx := stmt.Schema.LookIndex(name); idx != nil {
444 | name = idx.Name
445 | }
446 | _, tableName := m.getSchemaTable(stmt)
447 |
448 | return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}, clause.Table{Name: tableName}).Error
449 | })
450 | }
451 |
452 | // HasIndex check has index "name" or not
453 | func (m Migrator) HasIndex(value interface{}, name string) bool {
454 | var count int64
455 | _ = m.RunWithValue(value, func(stmt *gorm.Statement) error {
456 | if idx := stmt.Schema.LookIndex(name); idx != nil {
457 | name = idx.Name
458 | }
459 |
460 | return m.DB.Raw(
461 | "SELECT COUNT(*) FROM USER_INDEXES WHERE TABLE_NAME = ? AND INDEX_NAME = ?",
462 | m.Migrator.DB.NamingStrategy.TableName(stmt.Table),
463 | name,
464 | ).Row().Scan(&count)
465 | })
466 |
467 | return count > 0
468 | }
469 |
470 | // RenameIndex rename index from oldName to newName
471 | //
472 | // see also:
473 | // https://docs.oracle.com/database/121/SPATL/alter-index-rename.htm
474 | func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
475 | return m.RunWithValue(value, func(stmt *gorm.Statement) error {
476 | return m.DB.Exec(
477 | "ALTER INDEX ? RENAME TO ?",
478 | clause.Column{Name: oldName}, clause.Column{Name: newName},
479 | ).Error
480 | })
481 | }
482 |
483 | func (m Migrator) TryRemoveOnUpdate(values ...interface{}) error {
484 | for _, value := range values {
485 | if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
486 | for _, rel := range stmt.Schema.Relationships.Relations {
487 | constraint := rel.ParseConstraint()
488 | if constraint != nil {
489 | rel.Field.TagSettings["CONSTRAINT"] = strings.ReplaceAll(rel.Field.TagSettings["CONSTRAINT"], fmt.Sprintf("ON UPDATE %s", constraint.OnUpdate), "")
490 | }
491 | }
492 | return nil
493 | }); err != nil {
494 | return err
495 | }
496 | }
497 | return nil
498 | }
499 |
500 | func (m Migrator) TryQuotifyReservedWords(values ...interface{}) error {
501 | for _, value := range values {
502 | if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
503 | ignoreCase := !m.Dialector.(Dialector).NamingCaseSensitive
504 | for idx, v := range stmt.Schema.DBNames {
505 | if ignoreCase {
506 | v = strings.ToUpper(v)
507 | }
508 | if IsReservedWord(v) {
509 | v = strconv.Quote(v)
510 | }
511 | stmt.Schema.DBNames[idx] = v
512 | }
513 |
514 | for _, v := range stmt.Schema.Fields {
515 | fieldDBName := v.DBName
516 | if ignoreCase {
517 | v.DBName = strings.ToUpper(v.DBName)
518 | }
519 | if IsReservedWord(v.DBName) {
520 | v.DBName = strconv.Quote(v.DBName)
521 | }
522 | delete(stmt.Schema.FieldsByDBName, fieldDBName)
523 | stmt.Schema.FieldsByDBName[v.DBName] = v
524 | }
525 | return nil
526 | }); err != nil {
527 | return err
528 | }
529 | }
530 | return nil
531 | }
532 |
--------------------------------------------------------------------------------
/migrator_test.go:
--------------------------------------------------------------------------------
1 | package oracle
2 |
3 | import (
4 | "encoding/json"
5 | "reflect"
6 | "testing"
7 | "time"
8 |
9 | "gorm.io/gorm"
10 | )
11 |
12 | func TestMigrator_AutoMigrate(t *testing.T) {
13 | db, err := dbNamingCase, dbErrors[0]
14 | if err != nil {
15 | t.Fatal(err)
16 | }
17 | if db == nil {
18 | t.Log("db is nil!")
19 | return
20 | }
21 |
22 | type args struct {
23 | drop bool
24 | models []interface{}
25 | comments []string
26 | }
27 | tests := []struct {
28 | name string
29 | args args
30 | wantErr bool
31 | }{
32 | {name: "TestTableUser", args: args{models: []interface{}{TestTableUser{}}, comments: []string{"用户信息表"}}},
33 | {name: "TestTableUserDrop", args: args{drop: true, models: []interface{}{TestTableUser{}}, comments: []string{"用户信息表"}}},
34 | {name: "TestTableUserNoComments", args: args{drop: true, models: []interface{}{TestTableUserNoComments{}}, comments: []string{"用户信息表"}}},
35 | {name: "TestTableUserAddColumn", args: args{models: []interface{}{TestTableUserAddColumn{}}, comments: []string{"用户信息表"}}},
36 | {name: "TestTableUserMigrateColumn", args: args{models: []interface{}{TestTableUserMigrateColumn{}}, comments: []string{"用户信息表"}}},
37 | }
38 | for idx, tt := range tests {
39 | t.Run(tt.name, func(t *testing.T) {
40 | if len(tt.args.models) == 0 {
41 | t.Fatal("models is nil")
42 | }
43 | migrator := db.Set("gorm:table_comments", tt.args.comments).Migrator()
44 |
45 | if tt.args.drop {
46 | for _, model := range tt.args.models {
47 | if !migrator.HasTable(model) {
48 | continue
49 | }
50 | if err = migrator.DropTable(model); err != nil {
51 | t.Fatalf("DropTable() error = %v", err)
52 | }
53 | }
54 | }
55 |
56 | if err = migrator.AutoMigrate(tt.args.models...); (err != nil) != tt.wantErr {
57 | t.Errorf("AutoMigrate() error = %v, wantErr %v", err, tt.wantErr)
58 | } else if err == nil {
59 | t.Log("AutoMigrate() success!")
60 | }
61 |
62 | if idx == len(tests)-1 {
63 | wantUser := TestTableUserMigrateColumn{
64 | TestTableUser: TestTableUser{
65 | UID: "U0",
66 | Name: "someone",
67 | Account: "guest",
68 | Password: "MAkOvrJ8JV",
69 | Email: "",
70 | PhoneNumber: "+8618888888888",
71 | Sex: "1",
72 | UserType: 1,
73 | Enabled: true,
74 | Remark: "Ahmad",
75 | },
76 | AddNewColumn: "AddNewColumnValue",
77 | CommentSingleQuote: "CommentSingleQuoteValue",
78 | }
79 |
80 | result := db.Create(&wantUser)
81 | if err = result.Error; err != nil {
82 | t.Fatal(err)
83 | }
84 |
85 | var gotUser TestTableUserMigrateColumn
86 | result.Where(&TestTableUser{UID: "U0"}).Find(&gotUser)
87 | if err = result.Error; err != nil {
88 | t.Fatal(err)
89 | }
90 | gotUserBytes, _ := json.Marshal(gotUser)
91 | t.Logf("gotUser Result: %s", gotUserBytes)
92 | if !reflect.DeepEqual(gotUser, wantUser) {
93 | wantUserBytes, _ := json.Marshal(wantUser)
94 | t.Errorf("wantUser Info: %s", wantUserBytes)
95 | }
96 | }
97 | })
98 | }
99 | }
100 |
// TestTableUser is the test model for the user-information table ("test_user").
// The comment: settings in the gorm tags become column comments when migrated.
type TestTableUser struct {
	ID  uint64 `gorm:"column:id;size:64;not null;autoIncrement:true;autoIncrementIncrement:1;primaryKey;comment:自增 ID" json:"id"`
	UID string `gorm:"column:uid;type:varchar(50);comment:用户身份标识" json:"uid"`
	Name string `gorm:"column:name;size:50;comment:用户姓名" json:"name"`

	Account  string `gorm:"column:account;type:varchar(50);comment:登录账号" json:"account"`
	Password string `gorm:"column:password;type:varchar(512);comment:登录密码(密文)" json:"password"`

	Email       string `gorm:"column:email;type:varchar(128);comment:邮箱地址" json:"email"`
	PhoneNumber string `gorm:"column:phone_number;type:varchar(15);comment:E.164" json:"phoneNumber"`

	Sex string `gorm:"column:sex;type:char(1);comment:性别" json:"sex"`
	// The tag disables reads (->:false) and allows writes only on create (<-:create).
	Birthday *time.Time `gorm:"column:birthday;->:false;<-:create;comment:生日" json:"birthday,omitempty"`

	UserType int `gorm:"column:user_type;size:8;comment:用户类型" json:"userType"`

	Enabled bool   `gorm:"column:enabled;comment:是否可用" json:"enabled"`
	Remark  string `gorm:"column:remark;size:1024;comment:备注信息" json:"remark"`
}

// TableName maps TestTableUser to the "test_user" table.
func (TestTableUser) TableName() string {
	return "test_user"
}
125 |
// TestTableUserNoComments mirrors the columns of TestTableUser, but its gorm tags carry
// no comment: settings. It maps to the same "test_user" table, so migrating it after a
// drop produces the table without column comments.
type TestTableUserNoComments struct {
	ID uint64 `gorm:"column:id;size:64;not null;autoIncrement:true;autoIncrementIncrement:1;primaryKey" json:"id"`
	// Mapped to uid like TestTableUser.UID; it was previously tagged column:name,
	// which collided with the Name field below.
	UID  string `gorm:"column:uid;type:varchar(50)" json:"uid"`
	Name string `gorm:"column:name;size:50" json:"name"`

	Account  string `gorm:"column:account;type:varchar(50)" json:"account"`
	Password string `gorm:"column:password;type:varchar(512)" json:"password"`

	Email       string `gorm:"column:email;type:varchar(128)" json:"email"`
	PhoneNumber string `gorm:"column:phone_number;type:varchar(15)" json:"phoneNumber"`

	Sex      string    `gorm:"column:sex;type:char(1)" json:"sex"`
	Birthday time.Time `gorm:"column:birthday" json:"birthday"`

	UserType int `gorm:"column:user_type;size:8" json:"userType"`

	Enabled bool   `gorm:"column:enabled" json:"enabled"`
	Remark  string `gorm:"column:remark;size:1024" json:"remark"`
}

// TableName maps TestTableUserNoComments to the "test_user" table.
func (TestTableUserNoComments) TableName() string {
	return "test_user"
}
149 |
// TestTableUserAddColumn extends TestTableUser with one extra column, add_new_column,
// so that migrating it over an existing "test_user" table adds that column.
type TestTableUserAddColumn struct {
	TestTableUser

	AddNewColumn string `gorm:"column:add_new_column;type:varchar(100);comment:添加新字段"`
}

// TableName maps TestTableUserAddColumn to the "test_user" table.
func (TestTableUserAddColumn) TableName() string {
	return "test_user"
}
159 |
// TestTableUserMigrateColumn changes the comment of add_new_column (compared with
// TestTableUserAddColumn) and adds comment_single_quote, whose comment contains single
// quotes, presumably to exercise quote escaping in generated column-comment DDL.
type TestTableUserMigrateColumn struct {
	TestTableUser

	AddNewColumn       string `gorm:"column:add_new_column;type:varchar(100);comment:测试添加新字段"`
	CommentSingleQuote string `gorm:"column:comment_single_quote;comment:注释中存在单引号'[']'"`
}

// TableName maps TestTableUserMigrateColumn to the "test_user" table.
func (TestTableUserMigrateColumn) TableName() string {
	return "test_user"
}
170 |
// testTableColumnTypeModel covers a spread of Go field types (integers, string,
// []byte, float with an explicit decimal type, unlimited-size string, bool, times and
// gorm.DeletedAt) so that their DDL column types can be checked via AutoMigrate.
type testTableColumnTypeModel struct {
	ID   int64  `gorm:"column:id;size:64;not null;autoIncrement:true;autoIncrementIncrement:1;primaryKey"`
	Name string `gorm:"column:name;size:50"`
	Age  uint8  `gorm:"column:age;size:8"`

	Avatar []byte `gorm:"column:avatar;"`

	Balance float64 `gorm:"column:balance;type:decimal(18, 2)"`
	Remark  string  `gorm:"column:remark;size:-1"`
	Enabled bool    `gorm:"column:enabled;"`

	CreatedAt time.Time
	UpdatedAt time.Time
	DeletedAt gorm.DeletedAt
}

// TableName maps testTableColumnTypeModel to the "test_table_column_type" table.
func (t testTableColumnTypeModel) TableName() string {
	return "test_table_column_type"
}
190 |
191 | func TestMigrator_TableColumnType(t *testing.T) {
192 | db, err := dbNamingCase, dbErrors[0]
193 | if err != nil {
194 | t.Fatal(err)
195 | }
196 | if db == nil {
197 | t.Log("db is nil!")
198 | return
199 | }
200 | testModel := new(testTableColumnTypeModel)
201 |
202 | type args struct {
203 | model interface{}
204 | drop bool
205 | }
206 | tests := []struct {
207 | name string
208 | args args
209 | }{
210 | {name: "create", args: args{model: testModel}},
211 | {name: "alter", args: args{model: testModel, drop: true}},
212 | }
213 | for _, tt := range tests {
214 | t.Run(tt.name, func(t *testing.T) {
215 | if err = db.AutoMigrate(tt.args.model); err != nil {
216 | t.Errorf("AutoMigrate failed:%v", err)
217 | }
218 | if tt.args.drop {
219 | _ = db.Migrator().DropTable(tt.args.model)
220 | }
221 | })
222 | }
223 | }
224 |
// testFieldNameIsReservedWord has fields whose names are SQL keywords (FLOAT, DESC, ON,
// ORDER, CREATE, UPDATE, DELETE, ...) to check that migrations quote or otherwise
// handle such column names.
type testFieldNameIsReservedWord struct {
	ID int64 `gorm:"column:id;size:64;not null;autoIncrement:true;autoIncrementIncrement:1;primaryKey"`

	FLOAT float64 `gorm:"type:decimal(18, 2)"`
	DESC  string  `gorm:"size:-1"`
	ON    bool

	Order int
	Sort  int

	CREATE time.Time
	UPDATE time.Time
	DELETE gorm.DeletedAt
}

// TableName maps testFieldNameIsReservedWord to the "test_name_is_reserved_word" table.
func (t testFieldNameIsReservedWord) TableName() string {
	return "test_name_is_reserved_word"
}
243 |
244 | func TestMigrator_FieldNameIsReservedWord(t *testing.T) {
245 | if err := dbErrors[0]; err != nil {
246 | t.Fatal(err)
247 | }
248 | if dbNamingCase == nil {
249 | t.Log("dbNamingCase is nil!")
250 | return
251 | }
252 | if err := dbErrors[1]; err != nil {
253 | t.Fatal(err)
254 | }
255 | if dbIgnoreCase == nil {
256 | t.Log("dbNamingCase is nil!")
257 | return
258 | }
259 |
260 | testModel := new(testFieldNameIsReservedWord)
261 | _ = dbNamingCase.Migrator().DropTable(testModel)
262 | _ = dbIgnoreCase.Migrator().DropTable(testModel)
263 |
264 | type args struct {
265 | db *gorm.DB
266 | model interface{}
267 | drop bool
268 | }
269 | tests := []struct {
270 | name string
271 | args args
272 | }{
273 | {name: "createNamingCase", args: args{db: dbNamingCase, model: testModel}},
274 | {name: "alterNamingCase", args: args{db: dbNamingCase, model: testModel, drop: true}},
275 | {name: "createIgnoreCase", args: args{db: dbIgnoreCase, model: testModel}},
276 | {name: "alterIgnoreCase", args: args{db: dbIgnoreCase, model: testModel, drop: true}},
277 | }
278 | for _, tt := range tests {
279 | t.Run(tt.name, func(t *testing.T) {
280 | db := tt.args.db
281 | if err := db.AutoMigrate(tt.args.model); err != nil {
282 | t.Errorf("AutoMigrate failed:%v", err)
283 | }
284 | if tt.args.drop {
285 | _ = db.Migrator().DropTable(tt.args.model)
286 | }
287 | })
288 | }
289 | }
290 |
291 | func TestMigrator_DatatypesJsonMapNamingCase(t *testing.T) {
292 | if err := dbErrors[0]; err != nil {
293 | t.Fatal(err)
294 | }
295 | if dbNamingCase == nil {
296 | t.Log("dbNamingCase is nil!")
297 | return
298 | }
299 |
300 | type testJsonMapNamingCase struct {
301 | gorm.Model
302 |
303 | Extras JSONMap `gorm:"check:\"extras\" IS JSON"`
304 | }
305 | testModel := new(testJsonMapNamingCase)
306 | _ = dbNamingCase.Migrator().DropTable(testModel)
307 |
308 | type args struct {
309 | db *gorm.DB
310 | model interface{}
311 | drop bool
312 | }
313 | tests := []struct {
314 | name string
315 | args args
316 | }{
317 | {name: "createDatatypesJsonMapNamingCase", args: args{db: dbNamingCase, model: testModel}},
318 | {name: "alterDatatypesJsonMapNamingCase", args: args{db: dbNamingCase, model: testModel, drop: true}},
319 | }
320 | for _, tt := range tests {
321 | t.Run(tt.name, func(t *testing.T) {
322 | db := tt.args.db
323 | if err := db.AutoMigrate(tt.args.model); err != nil {
324 | t.Errorf("AutoMigrate failed:%v", err)
325 | }
326 | if tt.args.drop {
327 | _ = db.Migrator().DropTable(tt.args.model)
328 | }
329 | })
330 | }
331 | }
332 |
333 | func TestMigrator_DatatypesJsonMapIgnoreCase(t *testing.T) {
334 | if err := dbErrors[1]; err != nil {
335 | t.Fatal(err)
336 | }
337 | if dbIgnoreCase == nil {
338 | t.Log("dbNamingCase is nil!")
339 | return
340 | }
341 |
342 | type tesJsonMapIgnoreCase struct {
343 | gorm.Model
344 |
345 | Extras JSONMap `gorm:"check:extras IS JSON"`
346 | }
347 | testModel := new(tesJsonMapIgnoreCase)
348 | _ = dbIgnoreCase.Migrator().DropTable(testModel)
349 |
350 | type args struct {
351 | db *gorm.DB
352 | model interface{}
353 | drop bool
354 | }
355 | tests := []struct {
356 | name string
357 | args args
358 | }{
359 | {name: "createDatatypesJsonMapIgnoreCase", args: args{db: dbIgnoreCase, model: testModel}},
360 | {name: "alterDatatypesJsonMapIgnoreCase", args: args{db: dbIgnoreCase, model: testModel, drop: true}},
361 | }
362 | for _, tt := range tests {
363 | t.Run(tt.name, func(t *testing.T) {
364 | db := tt.args.db
365 | if err := db.AutoMigrate(tt.args.model); err != nil {
366 | t.Errorf("AutoMigrate failed:%v", err)
367 | }
368 | if tt.args.drop {
369 | _ = db.Migrator().DropTable(tt.args.model)
370 | }
371 | })
372 | }
373 | }
374 |
--------------------------------------------------------------------------------
/namer.go:
--------------------------------------------------------------------------------
1 | package oracle
2 |
3 | import (
4 | "strings"
5 |
6 | "gorm.io/gorm/schema"
7 | )
8 |
// Namer implements GORM's schema.Namer interface by delegating every name to an
// underlying naming strategy and then adjusting the letter case of the result.
type Namer struct {
	// NamingStrategy is the wrapped strategy that produces names before case
	// conversion (the one configured in gorm.Config at initialization).
	NamingStrategy schema.Namer
	// CaseSensitive determines whether naming is case-sensitive: when false, names
	// produced by NamingStrategy are upper-cased; when true they are kept as-is.
	CaseSensitive bool
}
16 |
17 | // Deprecated: As of v1.5.0, use the Namer.ConvertNameToFormat instead.
18 | //
19 | //goland:noinspection GoUnusedExportedFunction
20 | func ConvertNameToFormat(x string) string {
21 | return (Namer{}).ConvertNameToFormat(x)
22 | }
23 |
24 | // ConvertNameToFormat return appropriate capitalization name based on CaseSensitive
25 | func (n Namer) ConvertNameToFormat(x string) string {
26 | if n.CaseSensitive {
27 | return x
28 | }
29 | return strings.ToUpper(x)
30 | }
31 |
32 | // TableName convert string to table name
33 | func (n Namer) TableName(table string) (name string) {
34 | return n.ConvertNameToFormat(n.NamingStrategy.TableName(table))
35 | }
36 |
37 | // SchemaName generate schema name from table name, don't guarantee it is the reverse value of TableName
38 | func (n Namer) SchemaName(table string) string {
39 | return n.ConvertNameToFormat(n.NamingStrategy.SchemaName(table))
40 | }
41 |
42 | // ColumnName convert string to column name
43 | func (n Namer) ColumnName(table, column string) (name string) {
44 | return n.ConvertNameToFormat(n.NamingStrategy.ColumnName(table, column))
45 | }
46 |
47 | // JoinTableName convert string to join table name
48 | func (n Namer) JoinTableName(table string) (name string) {
49 | return n.ConvertNameToFormat(n.NamingStrategy.JoinTableName(table))
50 | }
51 |
52 | // RelationshipFKName generate fk name for relation
53 | func (n Namer) RelationshipFKName(relationship schema.Relationship) (name string) {
54 | return n.ConvertNameToFormat(n.NamingStrategy.RelationshipFKName(relationship))
55 | }
56 |
57 | // CheckerName generate checker name
58 | func (n Namer) CheckerName(table, column string) (name string) {
59 | return n.ConvertNameToFormat(n.NamingStrategy.CheckerName(table, column))
60 | }
61 |
62 | // IndexName generate index name
63 | func (n Namer) IndexName(table, column string) (name string) {
64 | return n.ConvertNameToFormat(n.NamingStrategy.IndexName(table, column))
65 | }
66 |
67 | // UniqueName generate unique constraint name
68 | func (n Namer) UniqueName(table, column string) string {
69 | return n.ConvertNameToFormat(n.NamingStrategy.UniqueName(table, column))
70 | }
71 |
--------------------------------------------------------------------------------
/oracle.go:
--------------------------------------------------------------------------------
1 | package oracle
2 |
3 | import (
4 | "context"
5 | "database/sql"
6 | "fmt"
7 | "reflect"
8 | "regexp"
9 | "strconv"
10 | "strings"
11 | "time"
12 |
13 | "github.com/sijms/go-ora/v2"
14 | "gorm.io/gorm"
15 | "gorm.io/gorm/callbacks"
16 | "gorm.io/gorm/clause"
17 | "gorm.io/gorm/logger"
18 | "gorm.io/gorm/migrator"
19 | "gorm.io/gorm/schema"
20 | )
21 |
// Config holds the options used to build an Oracle Dialector.
type Config struct {
	// DriverName is the database/sql driver name used by Initialize when it opens DSN
	// itself (i.e. when Conn is nil).
	DriverName string
	// DSN is the data source name passed to sql.Open when Conn is nil.
	DSN string
	// Conn is an existing connection pool (e.g. *sql.DB) to use instead of opening DSN.
	Conn gorm.ConnPool //*sql.DB
	// DefaultStringSize is the VARCHAR2 size DataTypeOf uses for string fields that
	// declare no explicit size.
	DefaultStringSize uint
	// DBVer is the server version string; Initialize reads it from
	// product_component_version and ClauseBuilders uses its major number.
	DBVer string

	IgnoreCase          bool // warning: may cause performance issues
	NamingCaseSensitive bool // whether naming is case-sensitive
	// whether VARCHAR type size is character length, defaulting to byte length
	VarcharSizeIsCharLength bool

	// RowNumberAliasForOracle11 is the alias for ROW_NUMBER() in Oracle 11g, defaulting to ROW_NUM
	RowNumberAliasForOracle11 string
}

// Dialector implements the GORM database dialector for Oracle. It embeds *Config, so
// copies of a Dialector share (and write through to) the same Config.
type Dialector struct {
	*Config
}
42 |
43 | //goland:noinspection GoUnusedExportedFunction
44 | func Open(dsn string) gorm.Dialector {
45 | return &Dialector{Config: &Config{DSN: dsn}}
46 | }
47 |
48 | //goland:noinspection GoUnusedExportedFunction
49 | func New(config Config) gorm.Dialector {
50 | return &Dialector{Config: &config}
51 | }
52 |
53 | // BuildUrl create databaseURL from server, port, service, user, password, urlOptions
54 | // this function help build a will formed databaseURL and accept any character as it
55 | // convert special charters to corresponding values in URL
56 | //
57 | //goland:noinspection GoUnusedExportedFunction
58 | func BuildUrl(server string, port int, service, user, password string, options map[string]string) string {
59 | return go_ora.BuildUrl(server, port, service, user, password, options)
60 | }
61 |
62 | // GetStringExpr replace single quotes in the string with two single quotes
63 | // and return the expression for the string value
64 | //
65 | // quotes : if the SQL placeholder is ? then pass true, if it is '?' then do not pass or pass false.
66 | func GetStringExpr(value string, quotes ...bool) clause.Expr {
67 | if len(quotes) > 0 && quotes[0] {
68 | if strings.Contains(value, "'") {
69 | // escape single quotes
70 | if !strings.Contains(value, "]'") {
71 | value = fmt.Sprintf("q'[%s]'", value)
72 | } else if !strings.Contains(value, "}'") {
73 | value = fmt.Sprintf("q'{%s}'", value)
74 | } else if !strings.Contains(value, ">'") {
75 | value = fmt.Sprintf("q'<%s>'", value)
76 | } else if !strings.Contains(value, ")'") {
77 | value = fmt.Sprintf("q'(%s)'", value)
78 | } else {
79 | value = fmt.Sprintf("'%s'", strings.ReplaceAll(value, "'", "''"))
80 | }
81 | } else {
82 | value = fmt.Sprintf("'%s'", value)
83 | }
84 | } else {
85 | value = strings.ReplaceAll(value, "'", "''")
86 | }
87 | return gorm.Expr(value)
88 | }
89 |
90 | // AddSessionParams setting database connection session parameters,
91 | // the value is wrapped in single quotes.
92 | //
93 | // If the value doesn't need to be wrapped in single quotes,
94 | // please use the go_ora.AddSessionParam function directly,
95 | // or pass the originals parameter as true.
96 | func AddSessionParams(db *sql.DB, params map[string]string, originals ...bool) (keys []string, err error) {
97 | if db == nil {
98 | return
99 | }
100 | if _, ok := db.Driver().(*go_ora.OracleDriver); !ok {
101 | return
102 | }
103 | var original bool
104 | if len(originals) > 0 {
105 | original = originals[0]
106 | }
107 |
108 | for key, value := range params {
109 | if key == "" || value == "" {
110 | continue
111 | }
112 | if !original {
113 | value = GetStringExpr(value, true).SQL
114 | }
115 | if err = go_ora.AddSessionParam(db, key, value); err != nil {
116 | return
117 | }
118 | keys = append(keys, key)
119 | }
120 | return
121 | }
122 |
123 | // DelSessionParams remove session parameters
124 | func DelSessionParams(db *sql.DB, keys []string) {
125 | if db == nil {
126 | return
127 | }
128 | if _, ok := db.Driver().(*go_ora.OracleDriver); !ok {
129 | return
130 | }
131 |
132 | for _, key := range keys {
133 | if key == "" {
134 | continue
135 | }
136 | go_ora.DelSessionParam(db, key)
137 | }
138 | }
139 |
140 | func convertCustomType(val interface{}) interface{} {
141 | rv := reflect.ValueOf(val)
142 | if !rv.IsValid() || rv.IsZero() {
143 | return val
144 | }
145 | ri := rv.Interface()
146 | typeName := reflect.TypeOf(ri).Name()
147 | if reflect.TypeOf(val).Kind() == reflect.Ptr {
148 | if rv.IsNil() {
149 | typeName = rv.Type().Elem().Name()
150 | } else {
151 | for rv.Kind() == reflect.Ptr {
152 | rv = rv.Elem()
153 | }
154 | ri = rv.Interface()
155 | typeName = reflect.TypeOf(ri).Name()
156 | }
157 | }
158 | if typeName == "DeletedAt" {
159 | // gorm.DeletedAt
160 | if rv.IsZero() {
161 | val = sql.NullTime{}
162 | } else {
163 | val = getTimeValue(ri.(gorm.DeletedAt).Time)
164 | }
165 | } else if m := rv.MethodByName("Time"); m.IsValid() && m.Type().NumIn() == 0 {
166 | // custom time type
167 | for _, result := range m.Call([]reflect.Value{}) {
168 | if reflect.TypeOf(result.Interface()).Name() == "Time" {
169 | val = getTimeValue(result.Interface().(time.Time))
170 | }
171 | }
172 | }
173 | return val
174 | }
175 |
// ptrDereference follows a chain of pointers and returns the value at its end.
// Non-pointer values (and nil) are returned unchanged. If any pointer along the chain
// is nil, the original obj is returned.
func ptrDereference(obj interface{}) (value interface{}) {
	if obj == nil || reflect.TypeOf(obj).Kind() != reflect.Ptr {
		return obj
	}
	current := reflect.ValueOf(obj)
	for current.Kind() == reflect.Ptr {
		if current.IsNil() {
			return obj
		}
		current = current.Elem()
	}
	return current.Interface()
}
194 |
195 | func getTimeValue(t time.Time) interface{} {
196 | if t.IsZero() {
197 | return sql.NullTime{}
198 | }
199 | return t
200 | }
201 |
202 | func (d Dialector) DummyTableName() string {
203 | return "DUAL"
204 | }
205 |
206 | func (d Dialector) Name() string {
207 | return "oracle"
208 | }
209 |
210 | func (d Dialector) Initialize(db *gorm.DB) (err error) {
211 | db.NamingStrategy = Namer{
212 | NamingStrategy: db.NamingStrategy,
213 | CaseSensitive: d.NamingCaseSensitive,
214 | }
215 | d.DefaultStringSize = 1024
216 |
217 | // register callbacks
218 | config := &callbacks.Config{
219 | CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"},
220 | UpdateClauses: []string{"UPDATE", "SET", "WHERE", "RETURNING"},
221 | DeleteClauses: []string{"DELETE", "FROM", "WHERE", "RETURNING"},
222 | }
223 | callbacks.RegisterDefaultCallbacks(db, config)
224 |
225 | d.DriverName = "oracle"
226 |
227 | if d.Conn != nil {
228 | db.ConnPool = d.Conn
229 | } else {
230 | db.ConnPool, err = sql.Open(d.DriverName, d.DSN)
231 | if err != nil {
232 | return
233 | }
234 | }
235 | if d.IgnoreCase {
236 | if sqlDB, ok := db.ConnPool.(*sql.DB); ok {
237 | // warning: may cause performance issues
238 | _ = go_ora.AddSessionParam(sqlDB, "NLS_COMP", "LINGUISTIC")
239 | _ = go_ora.AddSessionParam(sqlDB, "NLS_SORT", "BINARY_CI")
240 | }
241 | }
242 | err = db.ConnPool.QueryRowContext(context.Background(), "select version from product_component_version where rownum = 1").Scan(&d.DBVer)
243 | if err != nil {
244 | return err
245 | }
246 | //log.Println("DBVer:" + d.DBVer)
247 | if err = db.Callback().Create().Replace("gorm:create", Create); err != nil {
248 | return
249 | }
250 | if err = db.Callback().Update().Replace("gorm:update", Update(config)); err != nil {
251 | return
252 | }
253 |
254 | for k, v := range d.ClauseBuilders() {
255 | db.ClauseBuilders[k] = v
256 | }
257 | return
258 | }
259 |
260 | func (d Dialector) ClauseBuilders() (clauseBuilders map[string]clause.ClauseBuilder) {
261 | clauseBuilders = make(map[string]clause.ClauseBuilder)
262 |
263 | if dbVer, _ := strconv.Atoi(strings.Split(d.DBVer, ".")[0]); dbVer > 11 {
264 | clauseBuilders["LIMIT"] = d.RewriteLimit
265 | } else {
266 | clauseBuilders["LIMIT"] = d.RewriteLimit11
267 | }
268 |
269 | clauseBuilders["RETURNING"] = func(c clause.Clause, builder clause.Builder) {
270 | if returning, ok := c.Expression.(clause.Returning); ok {
271 | _, _ = builder.WriteString("/*- -*/")
272 | _, _ = builder.WriteString("RETURNING ")
273 |
274 | if len(returning.Columns) > 0 {
275 | for idx, column := range returning.Columns {
276 | if idx > 0 {
277 | _ = builder.WriteByte(',')
278 | }
279 |
280 | builder.WriteQuoted(column)
281 | }
282 | } else {
283 | _ = builder.WriteByte('*')
284 | }
285 | }
286 | }
287 | return
288 | }
289 |
290 | func (d Dialector) getLimitRows(limit clause.Limit) (limitRows int, hasLimit bool) {
291 | if l := limit.Limit; l != nil {
292 | limitRows = *l
293 | hasLimit = limitRows > 0
294 | }
295 | return
296 | }
297 |
298 | func (d Dialector) RewriteLimit(c clause.Clause, builder clause.Builder) {
299 | if limit, ok := c.Expression.(clause.Limit); ok {
300 | limitRows, hasLimit := d.getLimitRows(limit)
301 |
302 | if stmt, ok := builder.(*gorm.Statement); ok {
303 | if _, hasOrderBy := stmt.Clauses["ORDER BY"]; !hasOrderBy && hasLimit {
304 | s := stmt.Schema
305 | _, _ = builder.WriteString("ORDER BY ")
306 | if s != nil && s.PrioritizedPrimaryField != nil {
307 | builder.WriteQuoted(s.PrioritizedPrimaryField.DBName)
308 | _ = builder.WriteByte(' ')
309 | } else {
310 | _, _ = builder.WriteString("(SELECT NULL FROM ")
311 | _, _ = builder.WriteString(d.DummyTableName())
312 | _, _ = builder.WriteString(")")
313 | }
314 | }
315 | }
316 |
317 | if offset := limit.Offset; offset > 0 {
318 | _, _ = builder.WriteString(" OFFSET ")
319 | builder.AddVar(builder, offset)
320 | _, _ = builder.WriteString(" ROWS")
321 | }
322 | if hasLimit {
323 | _, _ = builder.WriteString(" FETCH NEXT ")
324 | builder.AddVar(builder, limitRows)
325 | _, _ = builder.WriteString(" ROWS ONLY")
326 | }
327 | }
328 | }
329 |
330 | // RewriteLimit11 rewrite the LIMIT clause in the query to accommodate pagination requirements for Oracle 11g and lower database versions
331 | //
332 | // # Limit and Offset
333 | //
334 | // SELECT * FROM (SELECT T.*, ROW_NUMBER() OVER (ORDER BY column) AS ROW_NUM FROM table_name T)
335 | // WHERE ROW_NUM BETWEEN offset+1 AND offset+limit
336 | //
337 | // # Only Limit
338 | //
339 | // SELECT * FROM table_name WHERE ROWNUM <= limit ORDER BY column
340 | //
341 | // # Only Offset
342 | //
343 | // SELECT * FROM table_name WHERE ROWNUM > offset ORDER BY column
344 | func (d Dialector) RewriteLimit11(c clause.Clause, builder clause.Builder) {
345 | limit, ok := c.Expression.(clause.Limit)
346 | if !ok {
347 | return
348 | }
349 | offsetRows := limit.Offset
350 | hasOffset := offsetRows > 0
351 | limitRows, hasLimit := d.getLimitRows(limit)
352 | if !hasOffset && !hasLimit {
353 | return
354 | }
355 |
356 | var stmt *gorm.Statement
357 | if stmt, ok = builder.(*gorm.Statement); !ok {
358 | return
359 | }
360 |
361 | if hasLimit && hasOffset {
362 | // 使用 ROW_NUMBER() 和子查询实现分页查询
363 | if d.RowNumberAliasForOracle11 == "" {
364 | d.RowNumberAliasForOracle11 = "ROW_NUM"
365 | }
366 | subQuerySQL := fmt.Sprintf(
367 | "SELECT * FROM (SELECT T.*, ROW_NUMBER() OVER (ORDER BY %s) AS %s FROM (%s) T) WHERE %s BETWEEN %d AND %d",
368 | d.getOrderByColumns(stmt),
369 | d.RowNumberAliasForOracle11,
370 | strings.TrimSpace(stmt.SQL.String()),
371 | d.RowNumberAliasForOracle11,
372 | offsetRows+1,
373 | offsetRows+limitRows,
374 | )
375 | stmt.SQL.Reset()
376 | stmt.SQL.WriteString(subQuerySQL)
377 | } else if hasLimit {
378 | // 只有 Limit 的情况
379 | d.rewriteRownumStmt(stmt, builder, " <= ", limitRows)
380 | } else {
381 | // 只有 Offset 的情况
382 | d.rewriteRownumStmt(stmt, builder, " > ", offsetRows)
383 | }
384 | }
385 |
386 | func (d Dialector) rewriteRownumStmt(stmt *gorm.Statement, builder clause.Builder, operator string, rows int) {
387 | limitSql := strings.Builder{}
388 | if _, ok := stmt.Clauses["WHERE"]; !ok {
389 | limitSql.WriteString(" WHERE ")
390 | } else {
391 | limitSql.WriteString(" AND ")
392 | }
393 | limitSql.WriteString("ROWNUM")
394 | limitSql.WriteString(operator)
395 | limitSql.WriteString(strconv.Itoa(rows))
396 |
397 | if _, hasOrderBy := stmt.Clauses["ORDER BY"]; !hasOrderBy {
398 | _, _ = builder.WriteString(limitSql.String())
399 | } else {
400 | // "ORDER BY" before insert
401 | sqlTmp := strings.Builder{}
402 | sqlOld := stmt.SQL.String()
403 | orderIndex := strings.Index(sqlOld, "ORDER BY") - 1
404 | sqlTmp.WriteString(sqlOld[:orderIndex])
405 | sqlTmp.WriteString(limitSql.String())
406 | sqlTmp.WriteString(sqlOld[orderIndex:])
407 | stmt.SQL = sqlTmp
408 | }
409 | }
410 |
411 | func (d Dialector) getOrderByColumns(stmt *gorm.Statement) string {
412 | if orderByClause, ok := stmt.Clauses["ORDER BY"]; ok {
413 | var orderBy clause.OrderBy
414 | if orderBy, ok = orderByClause.Expression.(clause.OrderBy); ok && len(orderBy.Columns) > 0 {
415 | orderByBuilder := strings.Builder{}
416 | for i, column := range orderBy.Columns {
417 | if i > 0 {
418 | orderByBuilder.WriteString(", ")
419 | }
420 | orderByBuilder.WriteString(column.Column.Name)
421 | if column.Desc {
422 | orderByBuilder.WriteString(" DESC")
423 | }
424 | }
425 | return orderByBuilder.String()
426 | }
427 | }
428 | return "NULL"
429 | }
430 |
431 | func (d Dialector) DefaultValueOf(*schema.Field) clause.Expression {
432 | return clause.Expr{SQL: "VALUES (DEFAULT)"}
433 | }
434 |
435 | func (d Dialector) Migrator(db *gorm.DB) gorm.Migrator {
436 | return Migrator{
437 | Migrator: migrator.Migrator{
438 | Config: migrator.Config{
439 | DB: db,
440 | Dialector: d,
441 | CreateIndexAfterCreateTable: true,
442 | },
443 | },
444 | }
445 | }
446 |
447 | func (d Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, _ interface{}) {
448 | _, _ = writer.WriteString(":")
449 | _, _ = writer.WriteString(strconv.Itoa(len(stmt.Vars)))
450 | }
451 |
// QuoteTo writes str to writer as an identifier.
//
// When NamingCaseSensitive is false, or str is empty, str is written unchanged
// (unquoted). Otherwise every '.'-separated part is wrapped in double quotes, so a.b
// is written as "a"."b". Parts that are already quoted in str are not quoted again,
// and other double quotes are escaped by doubling them.
func (d Dialector) QuoteTo(writer clause.Writer, str string) {
	if d.NamingCaseSensitive && str != "" {
		var (
			// underQuoted: an opening quote has been written for the current part.
			// selfQuoted: the current part started with a '"' in str.
			underQuoted, selfQuoted bool
			// continuousBacktick: consecutive '"' characters seen but not yet written.
			continuousBacktick int8
			// shiftDelimiter: bytes consumed since the start of the current part.
			shiftDelimiter int8
		)

		for _, v := range []byte(str) {
			switch v {
			case '"':
				// Count quotes; each complete pair is emitted as an escaped quote ("").
				continuousBacktick++
				if continuousBacktick == 2 {
					_, _ = writer.WriteString(`""`)
					continuousBacktick = 0
				}
			case '.':
				// A '.' ends the current part, unless we are inside a part quoted in str
				// whose closing quote has not been seen yet (then it is written literally).
				if continuousBacktick > 0 || !selfQuoted {
					// Close the current part and reset the per-part state.
					shiftDelimiter = 0
					underQuoted = false
					continuousBacktick = 0
					_ = writer.WriteByte('"')
				}
				_ = writer.WriteByte(v)
				continue
			default:
				// At the start of a part, write the opening quote. If the part began with
				// '"' in str, one pending quote is consumed as that opening quote.
				if shiftDelimiter-continuousBacktick <= 0 && !underQuoted {
					_ = writer.WriteByte('"')
					underQuoted = true
					if selfQuoted = continuousBacktick > 0; selfQuoted {
						continuousBacktick -= 1
					}
				}

				// Any remaining pending quotes are embedded quotes: escape each one.
				for ; continuousBacktick > 0; continuousBacktick -= 1 {
					_, _ = writer.WriteString(`""`)
				}

				_ = writer.WriteByte(v)
			}
			shiftDelimiter++
		}

		// A trailing pending quote that is not the closing quote of a self-quoted part
		// is escaped; then the last part is closed.
		if continuousBacktick > 0 && !selfQuoted {
			_, _ = writer.WriteString(`""`)
		}
		_ = writer.WriteByte('"')
	} else {
		_, _ = writer.WriteString(str)
	}
}
503 |
504 | var numericPlaceholder = regexp.MustCompile(`:(\d+)`)
505 |
506 | func (d Dialector) Explain(sql string, vars ...interface{}) string {
507 | for idx, val := range vars {
508 | switch v := ptrDereference(val).(type) {
509 | case bool:
510 | if v {
511 | vars[idx] = 1
512 | } else {
513 | vars[idx] = 0
514 | }
515 | case go_ora.Clob:
516 | vars[idx] = v.String
517 | }
518 | }
519 | return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...)
520 | }
521 |
522 | func (d Dialector) DataTypeOf(field *schema.Field) string {
523 | delete(field.TagSettings, "RESTRICT")
524 |
525 | var sqlType string
526 | switch field.DataType {
527 | case schema.Bool:
528 | sqlType = "NUMBER(1)"
529 | case schema.Int, schema.Uint:
530 | sqlType = "INTEGER"
531 | if field.Size > 0 && field.Size <= 8 {
532 | sqlType = "SMALLINT"
533 | }
534 |
535 | if field.AutoIncrement {
536 | sqlType += " GENERATED BY DEFAULT AS IDENTITY"
537 | }
538 | case schema.Float:
539 | sqlType = "FLOAT"
540 | case schema.String, "VARCHAR2":
541 | size := field.Size
542 | defaultSize := d.DefaultStringSize
543 |
544 | if size == 0 {
545 | if defaultSize > 0 {
546 | size = int(defaultSize)
547 | } else {
548 | hasIndex := field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUE"] != ""
549 | // TEXT, GEOMETRY or JSON column can't have a default value
550 | if field.PrimaryKey || field.HasDefaultValue || hasIndex {
551 | size = 191 // utf8mb4
552 | }
553 | }
554 | }
555 |
556 | if size > 0 && size <= 4000 {
557 | // 默认情况下 VARCHAR2 可以指定一个不超过 4000 的正整数作为字节长度
558 | if d.VarcharSizeIsCharLength {
559 | if size*3 > 4000 {
560 | sqlType = "CLOB"
561 | } else {
562 | sqlType = fmt.Sprintf("VARCHAR2(%d CHAR)", size) // 字符长度(size * 3)
563 | }
564 | } else {
565 | sqlType = fmt.Sprintf("VARCHAR2(%d)", size)
566 | }
567 | } else {
568 | sqlType = "CLOB"
569 | }
570 | case schema.Time:
571 | sqlType = "TIMESTAMP WITH TIME ZONE"
572 | case schema.Bytes:
573 | sqlType = "BLOB"
574 | default:
575 | sqlType = string(field.DataType)
576 |
577 | if strings.EqualFold(sqlType, "text") {
578 | sqlType = "CLOB"
579 | }
580 |
581 | if sqlType == "" {
582 | panic(fmt.Sprintf("invalid sql type %s (%s) for oracle", field.FieldType.Name(), field.FieldType.String()))
583 | }
584 | }
585 |
586 | return sqlType
587 | }
588 |
589 | func (d Dialector) SavePoint(tx *gorm.DB, name string) error {
590 | tx.Exec("SAVEPOINT " + name)
591 | return tx.Error
592 | }
593 |
594 | func (d Dialector) RollbackTo(tx *gorm.DB, name string) error {
595 | tx.Exec("ROLLBACK TO SAVEPOINT " + name)
596 | return tx.Error
597 | }
598 |
--------------------------------------------------------------------------------
/oracle_ora.go:
--------------------------------------------------------------------------------
1 | package oracle
2 |
3 | import "github.com/sijms/go-ora/v2"
4 |
type (
	// RefCursor wraps go_ora.RefCursor. It can be bound as the destination of a
	// SYS_REFCURSOR OUT parameter (via its embedded RefCursor field) and read with Query.
	RefCursor struct {
		go_ora.RefCursor
	}

	// DataSet wraps go_ora.DataSet; it is the result set returned by RefCursor.Query.
	DataSet struct {
		go_ora.DataSet
	}

	// Out wraps go_ora.Out (presumably go-ora's OUT-parameter helper — confirm usage).
	Out struct {
		go_ora.Out
	}
)
18 |
19 | func (cursor *RefCursor) Query() (dataset *DataSet, err error) {
20 | var d *go_ora.DataSet
21 | if d, err = cursor.RefCursor.Query(); err != nil {
22 | return
23 | }
24 | dataset = &DataSet{DataSet: *d}
25 | return
26 | }
27 |
--------------------------------------------------------------------------------
/oracle_ora_test.go:
--------------------------------------------------------------------------------
1 | package oracle
2 |
3 | import (
4 | "database/sql"
5 | "database/sql/driver"
6 | "encoding/json"
7 | "errors"
8 | "fmt"
9 | "io"
10 | "log"
11 | "testing"
12 | )
13 |
const (
	// procCreateExamplePagingQuery creates (or replaces) the example stored procedure
	// PRO_EXAMPLE_PAGING_QUERY. Given a base query, an ORDER BY expression, a 1-based page
	// number and a page size, it returns the base query's total row count in TOTAL_NUM
	// and opens RES_CURSOR over the requested page using ROW_NUMBER().
	// The PL/SQL comments inside the literal are Chinese; they are part of the text sent
	// to the database and are left unchanged. In order they read: base query SQL, sort
	// field, current page number, rows per page, total rows returned, result set
	// returned, paging query SQL, total-count query SQL, paging offset, count total rows,
	// paging query.
	procCreateExamplePagingQuery = `-- example procedure
create or replace PROCEDURE PRO_EXAMPLE_PAGING_QUERY (
BASIC_SQL IN VARCHAR2, -- 基础查询 SQL
ORDER_FIELD IN VARCHAR2, -- 排序字段
PAGE_NUM IN NUMBER, -- 当前页码
PAGE_SIZE IN NUMBER, -- 每页条数

TOTAL_NUM OUT NUMBER, -- 返回总条数
RES_CURSOR OUT SYS_REFCURSOR -- 返回结果集
)
AS
BEGIN
DECLARE
PAGING_SQL VARCHAR2(4000) := ''; -- 分页查询 SQL
TOTAL_SQL VARCHAR2(4000) := ''; -- 总条数查询 SQL
OFFSET NUMBER(10); -- 分页查询偏移量
BEGIN
-- 查询总条数
TOTAL_SQL := 'SELECT TO_NUMBER(COUNT(*)) FROM (' || BASIC_SQL || ') TB';
EXECUTE IMMEDIATE TOTAL_SQL INTO TOTAL_NUM;

-- 分页查询
OFFSET := (PAGE_NUM - 1) * PAGE_SIZE;
PAGING_SQL := 'SELECT * FROM (SELECT T.*, ROW_NUMBER() OVER (ORDER BY ' || ORDER_FIELD ||
') AS ROW_NUM FROM (' || BASIC_SQL || ') T) WHERE ROW_NUM BETWEEN ' ||
TO_CHAR(OFFSET+1) || ' AND ' || TO_CHAR(OFFSET+PAGE_SIZE);

OPEN RES_CURSOR FOR PAGING_SQL;
END;
END PRO_EXAMPLE_PAGING_QUERY;`
)
46 |
47 | func ExampleRefCursor_Query() {
48 | db, err := dbNamingCase, dbErrors[0]
49 | if err != nil || db == nil {
50 | log.Fatal(err)
51 | }
52 | if err = db.Exec(procCreateExamplePagingQuery).Error; err != nil {
53 | log.Fatal(err)
54 | }
55 | var (
56 | totalNum uint
57 | resCursor RefCursor
58 |
59 | values = []any{
60 | "SELECT * FROM USER_TABLES",
61 | "TABLE_NAME",
62 | 1,
63 | 10,
64 | sql.Out{Dest: &totalNum},
65 | sql.Out{Dest: &resCursor.RefCursor},
66 | }
67 | )
68 | // 执行存储过程
69 | if err = db.Exec(`
70 | BEGIN
71 | PRO_EXAMPLE_PAGING_QUERY(:BASIC_SQL, :ORDER_FIELD, :PAGE_NUM, :PAGE_SIZE, :TOTAL_NUM, :RES_CURSOR);
72 | END;`, values...).Error; err != nil {
73 | log.Fatal(err)
74 | }
75 | defer func(cursor *RefCursor) {
76 | _ = cursor.Close()
77 | }(&resCursor)
78 |
79 | // 读取游标
80 | var dataset *DataSet
81 | if dataset, err = resCursor.Query(); err != nil {
82 | log.Fatal(err)
83 | }
84 | defer func(dataset *DataSet) {
85 | _ = dataset.Close()
86 | }(dataset)
87 |
88 | var dataRows []map[string]any
89 | columns := dataset.Columns()
90 | dest := make([]driver.Value, len(columns))
91 | for {
92 | if err = dataset.Next(dest); err != nil {
93 | if errors.Is(err, io.EOF) {
94 | err = nil
95 | }
96 | break
97 | }
98 | dataRow := make(map[string]any, len(columns))
99 | for i, v := range dest {
100 | dataRow[columns[i]] = v
101 | }
102 | dataRows = append(dataRows, dataRow)
103 | }
104 | if err != nil {
105 | log.Fatal(err)
106 | }
107 | fmt.Println(len(dataRows) > 0)
108 | //Output: true
109 | }
110 |
111 | func TestExecProcedure(t *testing.T) {
112 | db, err := dbNamingCase, dbErrors[0]
113 | if err != nil {
114 | t.Fatal(err)
115 | }
116 | if db == nil {
117 | t.Log("db is nil!")
118 | return
119 | }
120 | if err = db.Exec(procCreateExamplePagingQuery).Error; err != nil {
121 | t.Fatal(err)
122 | }
123 |
124 | var (
125 | totalNum uint
126 | resCursor RefCursor
127 |
128 | values = []any{
129 | "SELECT * FROM USER_TABLES", // sql.Named("BASIC_SQL", "SELECT * FROM USER_TABLES"),
130 | "TABLE_NAME", // sql.Named("ORDER_FIELD", "TABLE_NAME"),
131 | 1, // sql.Named("PAGE_NUM", 1),
132 | 10, // sql.Named("PAGE_SIZE", 10),
133 | sql.Out{Dest: &totalNum}, // sql.Named("TOTAL_NUM", sql.Out{Dest: &totalNum}),
134 | sql.Out{Dest: &resCursor.RefCursor}, // sql.Named("RES_CURSOR", sql.Out{Dest: &resCursor.RefCursor}),
135 | }
136 | )
137 | // 执行存储过程
138 | if err = db.Exec(`
139 | BEGIN
140 | PRO_EXAMPLE_PAGING_QUERY(:BASIC_SQL, :ORDER_FIELD, :PAGE_NUM, :PAGE_SIZE, :TOTAL_NUM, :RES_CURSOR);
141 | END;`, values...).Error; err != nil {
142 | t.Fatal(err)
143 | }
144 | defer func(cursor *RefCursor) {
145 | _ = cursor.Close()
146 | }(&resCursor)
147 |
148 | // 读取游标
149 | var dataset *DataSet
150 | if dataset, err = resCursor.Query(); err != nil {
151 | t.Fatal(err)
152 | }
153 | defer func(dataset *DataSet) {
154 | _ = dataset.Close()
155 | }(dataset)
156 |
157 | var dataRows []map[string]any
158 | columns := dataset.Columns()
159 | dest := make([]driver.Value, len(columns))
160 | for {
161 | if err = dataset.Next(dest); err != nil {
162 | if errors.Is(err, io.EOF) {
163 | err = nil
164 | }
165 | break
166 | }
167 | dataRow := make(map[string]any, len(columns))
168 | for i, v := range dest {
169 | dataRow[columns[i]] = v
170 | }
171 | dataRows = append(dataRows, dataRow)
172 | }
173 | if err != nil {
174 | t.Fatal(err)
175 | }
176 | got, _ := json.Marshal(dataRows)
177 | t.Logf("got total: %d, got size: %d, got data:\n%s", totalNum, len(dataRows), got)
178 | }
179 |
--------------------------------------------------------------------------------
/oracle_test.go:
--------------------------------------------------------------------------------
1 | package oracle
2 |
3 | import (
4 | "database/sql"
5 | "encoding/json"
6 | "log"
7 | "os"
8 | "reflect"
9 | "strconv"
10 | "strings"
11 | "testing"
12 | "time"
13 |
14 | "gorm.io/gorm"
15 | "gorm.io/gorm/logger"
16 | "gorm.io/gorm/schema"
17 | )
18 |
19 | var (
20 | dbNamingCase *gorm.DB
21 | dbIgnoreCase *gorm.DB
22 |
23 | dbErrors = make([]error, 2)
24 | )
25 |
26 | func init() {
27 | if wait := os.Getenv("GORM_ORA_WAIT_MIN"); wait != "" {
28 | if min, e := strconv.Atoi(wait); e == nil {
29 | log.Println("wait for oracle database initialization to complete...")
30 | time.Sleep(time.Duration(min) * time.Minute)
31 | }
32 | }
33 | var err error
34 | if dbNamingCase, err = openTestConnection(true, true); err != nil {
35 | dbErrors[0] = err
36 | }
37 | if dbIgnoreCase, err = openTestConnection(true, false); err != nil {
38 | dbErrors[1] = err
39 | }
40 | }
41 |
42 | func openTestConnection(ignoreCase, namingCase bool) (db *gorm.DB, err error) {
43 | dsn := getTestDSN()
44 |
45 | db, err = gorm.Open(New(Config{
46 | DSN: dsn,
47 | IgnoreCase: ignoreCase,
48 | NamingCaseSensitive: namingCase,
49 | }), getTestGormConfig())
50 | if db != nil && err == nil {
51 | log.Println("open oracle database connection success!")
52 | }
53 | return
54 | }
55 |
56 | func getTestDSN() (dsn string) {
57 | dsn = os.Getenv("GORM_ORA_DSN")
58 | if dsn == "" {
59 | server := os.Getenv("GORM_ORA_SERVER")
60 | port, _ := strconv.Atoi(os.Getenv("GORM_ORA_PORT"))
61 | if server == "" || port < 1 {
62 | return
63 | }
64 |
65 | language := os.Getenv("GORM_ORA_LANG")
66 | if language == "" {
67 | language = "SIMPLIFIED CHINESE"
68 | }
69 | territory := os.Getenv("GORM_ORA_TERRITORY")
70 | if territory == "" {
71 | territory = "CHINA"
72 | }
73 |
74 | dsn = BuildUrl(server, port,
75 | os.Getenv("GORM_ORA_SID"),
76 | os.Getenv("GORM_ORA_USER"),
77 | os.Getenv("GORM_ORA_PASS"),
78 | map[string]string{
79 | "CONNECTION TIMEOUT": "90",
80 | "LANGUAGE": language,
81 | "TERRITORY": territory,
82 | "SSL": "false",
83 | })
84 | }
85 | return
86 | }
87 |
88 | func getTestGormConfig() *gorm.Config {
89 | logWriter := new(log.Logger)
90 | logWriter.SetOutput(os.Stdout)
91 |
92 | return &gorm.Config{
93 | Logger: logger.New(
94 | logWriter,
95 | logger.Config{LogLevel: logger.Info},
96 | ),
97 | DisableForeignKeyConstraintWhenMigrating: false,
98 | IgnoreRelationshipsWhenMigrating: false,
99 | NamingStrategy: schema.NamingStrategy{
100 | IdentifierMaxLength: 30,
101 | },
102 | }
103 | }
104 |
105 | func TestCountLimit0(t *testing.T) {
106 | db, err := dbNamingCase, dbErrors[0]
107 | if err != nil {
108 | t.Fatal(err)
109 | }
110 | if db == nil {
111 | t.Log("db is nil!")
112 | return
113 | }
114 |
115 | model := TestTableUser{}
116 | migrator := db.Set("gorm:table_comments", "用户信息表").Migrator()
117 | if migrator.HasTable(model) {
118 | if err = migrator.DropTable(model); err != nil {
119 | t.Fatalf("DropTable() error = %v", err)
120 | }
121 | }
122 | if err = migrator.AutoMigrate(model); err != nil {
123 | t.Fatalf("AutoMigrate() error = %v", err)
124 | }
125 | t.Log("AutoMigrate() success!")
126 |
127 | var count int64
128 | result := db.Model(&model).Limit(-1).Count(&count)
129 | if err = result.Error; err != nil {
130 | t.Fatal(err)
131 | }
132 | t.Logf("Limit(-1) count = %d", count)
133 |
134 | if countSql := db.ToSQL(func(tx *gorm.DB) *gorm.DB {
135 | return tx.Model(&model).Limit(-1).Count(&count)
136 | }); strings.Contains(countSql, "ORDER BY") {
137 | t.Error(`The "count(*)" statement contains the "ORDER BY" clause!`)
138 | }
139 | }
140 |
141 | func TestLimit(t *testing.T) {
142 | db, err := dbNamingCase, dbErrors[0]
143 | if err != nil {
144 | t.Fatal(err)
145 | }
146 | if db == nil {
147 | t.Log("db is nil!")
148 | return
149 | }
150 | TestMergeCreate(t)
151 |
152 | type args struct {
153 | offset, limit int
154 | order string
155 | }
156 | tests := []struct {
157 | name string
158 | args args
159 | }{
160 | {name: "OffsetLimit0", args: args{offset: 0, limit: 0}},
161 | {name: "Offset10", args: args{offset: 10, limit: 0}},
162 | {name: "Limit10", args: args{offset: 0, limit: 10}},
163 | {name: "Offset10Limit10", args: args{offset: 10, limit: 10}},
164 | {name: "Offset10Limit10Order", args: args{offset: 10, limit: 10, order: `"id"`}},
165 | {name: "Offset10Limit10OrderDESC", args: args{offset: 10, limit: 10, order: `"id" DESC`}},
166 | }
167 |
168 | for _, tt := range tests {
169 | t.Run(tt.name, func(t *testing.T) {
170 | var data []TestTableUser
171 | result := db.Model(&TestTableUser{}).
172 | Offset(tt.args.offset).
173 | Limit(tt.args.limit).
174 | Order(tt.args.order).
175 | Find(&data)
176 | if err = result.Error; err != nil {
177 | t.Fatal(err)
178 | }
179 | dataBytes, _ := json.MarshalIndent(data, "", " ")
180 | t.Logf("Offset(%d) Limit(%d) got size = %d, data = %s",
181 | tt.args.offset, tt.args.limit, len(data), dataBytes)
182 | })
183 | }
184 | }
185 |
186 | func TestAddSessionParams(t *testing.T) {
187 | db, err := dbIgnoreCase, dbErrors[1]
188 | if err != nil {
189 | t.Fatal(err)
190 | }
191 | if db == nil {
192 | t.Log("db is nil!")
193 | return
194 | }
195 | var sqlDB *sql.DB
196 | if sqlDB, err = db.DB(); err != nil {
197 | t.Fatal(err)
198 | }
199 | type args struct {
200 | params map[string]string
201 | }
202 | tests := []struct {
203 | name string
204 | args args
205 | }{
206 | {name: "TimeParams", args: args{params: map[string]string{
207 | "TIME_ZONE": "+08:00", // alter session set TIME_ZONE = '+08:00';
208 | "NLS_DATE_FORMAT": "YYYY-MM-DD", // alter session set NLS_DATE_FORMAT = 'YYYY-MM-DD';
209 | "NLS_TIME_FORMAT": "HH24:MI:SSXFF", // alter session set NLS_TIME_FORMAT = 'HH24:MI:SS.FF3';
210 | "NLS_TIMESTAMP_FORMAT": "YYYY-MM-DD HH24:MI:SSXFF", // alter session set NLS_TIMESTAMP_FORMAT = 'YYYY-MM-DD HH24:MI:SS.FF3';
211 | "NLS_TIME_TZ_FORMAT": "HH24:MI:SS.FF TZR", // alter session set NLS_TIME_TZ_FORMAT = 'HH24:MI:SS.FF3 TZR';
212 | "NLS_TIMESTAMP_TZ_FORMAT": "YYYY-MM-DD HH24:MI:SSXFF TZR", // alter session set NLS_TIMESTAMP_TZ_FORMAT = 'YYYY-MM-DD HH24:MI:SS.FF3 TZR';
213 | }}},
214 | }
215 | for _, tt := range tests {
216 | t.Run(tt.name, func(t *testing.T) {
217 | //queryTime := `SELECT SYSDATE FROM DUAL`
218 | queryTime := `SELECT CAST(SYSDATE AS VARCHAR(30)) AS D FROM DUAL`
219 | var timeStr string
220 | if err = db.Raw(queryTime).Row().Scan(&timeStr); err != nil {
221 | t.Fatal(err)
222 | }
223 | t.Logf("SYSDATE 1: %s", timeStr)
224 |
225 | var keys []string
226 | if keys, err = AddSessionParams(sqlDB, tt.args.params); err != nil {
227 | t.Fatalf("AddSessionParams() error = %v", err)
228 | }
229 | if err = db.Raw(queryTime).Row().Scan(&timeStr); err != nil {
230 | t.Fatal(err)
231 | }
232 | defer DelSessionParams(sqlDB, keys)
233 | t.Logf("SYSDATE 2: %s", timeStr)
234 | t.Logf("keys: %#v", keys)
235 | })
236 | }
237 | }
238 |
239 | func TestGetStringExpr(t *testing.T) {
240 | db, err := dbNamingCase, dbErrors[0]
241 | if err != nil {
242 | t.Fatal(err)
243 | }
244 | if db == nil {
245 | t.Log("db is nil!")
246 | return
247 | }
248 |
249 | type args struct {
250 | prepareSQL string
251 | value string
252 | quote bool
253 | }
254 | tests := []struct {
255 | name string
256 | args args
257 | wantSQL string
258 | }{
259 | {"1", args{`SELECT ? AS HELLO FROM DUAL`, "Hi!", true}, `SELECT 'Hi!' AS HELLO FROM DUAL`},
260 | {"2", args{`SELECT '?' AS HELLO FROM DUAL`, "Hi!", false}, `SELECT 'Hi!' AS HELLO FROM DUAL`},
261 | {"3", args{`SELECT ? AS HELLO FROM DUAL`, "What's your name?", true}, `SELECT q'[What's your name?]' AS HELLO FROM DUAL`},
262 | {"4", args{`SELECT '?' AS HELLO FROM DUAL`, "What's your name?", false}, `SELECT 'What''s your name?' AS HELLO FROM DUAL`},
263 | {"5", args{`SELECT ? AS HELLO FROM DUAL`, "What's up]'?", true}, `SELECT q'{What's up]'?}' AS HELLO FROM DUAL`},
264 | {"6", args{`SELECT ? AS HELLO FROM DUAL`, "What's up]'}'?", true}, `SELECT q'' AS HELLO FROM DUAL`},
265 | {"7", args{`SELECT ? AS HELLO FROM DUAL`, "What's up]'}'>'?", true}, `SELECT q'(What's up]'}'>'?)' AS HELLO FROM DUAL`},
266 | {"8", args{`SELECT ? AS HELLO FROM DUAL`, "What's up)'}'>'?", true}, `SELECT q'[What's up)'}'>'?]' AS HELLO FROM DUAL`},
267 | }
268 | for _, tt := range tests {
269 | t.Run(tt.name, func(t *testing.T) {
270 | gotSQL := db.ToSQL(func(tx *gorm.DB) *gorm.DB {
271 | return tx.Raw(tt.args.prepareSQL, GetStringExpr(tt.args.value, tt.args.quote))
272 | })
273 | if !reflect.DeepEqual(gotSQL, tt.wantSQL) {
274 | t.Fatalf("ToSQL = %v, want %v", gotSQL, tt.wantSQL)
275 | }
276 | var results []map[string]interface{}
277 | if err = db.Raw(gotSQL).Find(&results).Error; err != nil {
278 | t.Fatalf("finds all records from raw sql got error: %v", err)
279 | }
280 | t.Log("result:", results)
281 | })
282 | }
283 | }
284 |
285 | func TestVarcharSizeIsCharLength(t *testing.T) {
286 | dsn := getTestDSN()
287 |
288 | db, err := gorm.Open(New(Config{
289 | DSN: dsn,
290 | IgnoreCase: true,
291 | NamingCaseSensitive: true,
292 | VarcharSizeIsCharLength: true,
293 | }), getTestGormConfig())
294 | if db != nil && err == nil {
295 | log.Println("open oracle database connection success!")
296 | } else {
297 | t.Fatal(err)
298 | }
299 |
300 | model := TestTableUserVarcharSize{}
301 | migrator := db.Set("gorm:table_comments", "TestVarcharSizeIsCharLength").Migrator()
302 | if migrator.HasTable(model) {
303 | if err = migrator.DropTable(model); err != nil {
304 | t.Fatalf("DropTable() error = %v", err)
305 | }
306 | }
307 | if err = migrator.AutoMigrate(model); err != nil {
308 | t.Fatalf("AutoMigrate() error = %v", err)
309 | }
310 | t.Log("AutoMigrate() success!")
311 |
312 | type args struct {
313 | value string
314 | }
315 | tests := []struct {
316 | name string
317 | args args
318 | wantErr bool
319 | }{
320 | {"50", args{strings.Repeat("姓名", 25)}, false},
321 | {"60", args{strings.Repeat("姓名", 30)}, true},
322 | }
323 | for _, tt := range tests {
324 | t.Run(tt.name, func(t *testing.T) {
325 | gotErr := db.Create(&TestTableUserVarcharSize{TestTableUser{Name: tt.args.value}}).Error
326 | if (gotErr != nil) != tt.wantErr {
327 | t.Error(gotErr)
328 | } else if gotErr != nil {
329 | t.Log(gotErr)
330 | }
331 | })
332 | }
333 | }
334 |
335 | type TestTableUserVarcharSize struct {
336 | TestTableUser
337 | }
338 |
339 | func (TestTableUserVarcharSize) TableName() string {
340 | return "test_user_varchar_size"
341 | }
342 |
--------------------------------------------------------------------------------
/reserved.go:
--------------------------------------------------------------------------------
1 | package oracle
2 |
3 | import (
4 | "strings"
5 |
6 | "github.com/emirpasic/gods/sets/hashset"
7 | )
8 |
9 | var ReservedWords = hashset.New(func() []interface{} {
10 | reservedWords := make([]interface{}, len(ReservedWordsList))
11 | for i, word := range ReservedWordsList {
12 | reservedWords[i] = word
13 | }
14 | return reservedWords
15 | }()...)
16 |
17 | func IsReservedWord(v string) bool {
18 | return ReservedWords.Contains(strings.ToUpper(v))
19 | }
20 |
21 | var ReservedWordsList = []string{
22 | "AGGREGATE", "AGGREGATES", "ALL", "ALLOW", "ANALYZE", "ANCESTOR", "AND", "ANY", "AS", "ASC", "AT", "AVG", "BETWEEN",
23 | "BINARY_DOUBLE", "BINARY_FLOAT", "BLOB", "BRANCH", "BUILD", "BY", "BYTE", "CASE", "CAST", "CHAR", "CHILD", "CLEAR",
24 | "CLOB", "COMMIT", "COMPILE", "CONSIDER", "COUNT", "CREATE", "DATATYPE", "DATE", "DATE_MEASURE", "DAY", "DECIMAL",
25 | "DELETE", "DESC", "DESCENDANT", "DIMENSION", "DISALLOW", "DIVISION", "DML", "ELSE", "END", "ESCAPE", "EXECUTE",
26 | "FIRST", "FLOAT", "FOR", "FROM", "HIERARCHIES", "HIERARCHY", "HOUR", "IGNORE", "IN", "INFINITE", "INSERT",
27 | "INTEGER", "INTERVAL", "INTO", "IS", "LAST", "LEAF_DESCENDANT", "LEAVES", "LEVEL", "LIKE", "LIKEC", "LIKE2",
28 | "LIKE4", "LOAD", "LOCAL", "LOG_SPEC", "LONG", "MAINTAIN", "MAX", "MEASURE", "MEASURES", "MEMBER", "MEMBERS",
29 | "MERGE", "MLSLABEL", "MIN", "MINUTE", "MODEL", "MONTH", "NAN", "NCHAR", "NCLOB", "NO", "NONE", "NOT", "NULL",
30 | "NULLS", "NUMBER", "NVARCHAR2", "OF", "OLAP", "OLAP_DML_EXPRESSION", "ON", "ONLY", "OPERATOR", "OR", "ORDER",
31 | "OVER", "OVERFLOW", "PARALLEL", "PARENT", "PLSQL", "PRUNE", "RAW", "RELATIVE", "ROOT_ANCESTOR", "ROWID", "SCN",
32 | "SECOND", "SELF", "SERIAL", "SET", "SOLVE", "SOME", "SORT", "SPEC", "SUM", "SYNCH", "TEXT_MEASURE", "THEN", "TIME",
33 | "TIMESTAMP", "TO", "UNBRANCH", "UPDATE", "USING", "VALIDATE", "VALUES", "VARCHAR2", "WHEN", "WHERE", "WITHIN",
34 | "WITH", "YEAR", "ZERO", "ZONE",
35 | }
36 |
--------------------------------------------------------------------------------
/update.go:
--------------------------------------------------------------------------------
1 | package oracle
2 |
3 | import (
4 | "reflect"
5 | "sort"
6 |
7 | "gorm.io/gorm"
8 | "gorm.io/gorm/callbacks"
9 | "gorm.io/gorm/clause"
10 | "gorm.io/gorm/schema"
11 | "gorm.io/gorm/utils"
12 | )
13 |
14 | func Update(config *callbacks.Config) func(db *gorm.DB) {
15 | supportReturning := utils.Contains(config.UpdateClauses, "RETURNING")
16 |
17 | return func(db *gorm.DB) {
18 | if db.Error != nil {
19 | return
20 | }
21 |
22 | stmt := db.Statement
23 | if stmt == nil {
24 | return
25 | }
26 |
27 | if stmtSchema := stmt.Schema; stmtSchema != nil {
28 | for _, c := range stmtSchema.UpdateClauses {
29 | stmt.AddClause(c)
30 | }
31 | }
32 |
33 | if stmt.SQL.Len() == 0 {
34 | stmt.SQL.Grow(180)
35 | stmt.AddClauseIfNotExists(clause.Update{})
36 | if _, ok := stmt.Clauses["SET"]; !ok {
37 | if set := ConvertToAssignments(stmt); len(set) != 0 {
38 | defer delete(stmt.Clauses, "SET")
39 | stmt.AddClause(set)
40 | } else {
41 | return
42 | }
43 | }
44 |
45 | stmt.Build(stmt.BuildClauses...)
46 | }
47 |
48 | checkMissingWhereConditions(db)
49 |
50 | if !db.DryRun && db.Error == nil {
51 | for i, val := range stmt.Vars {
52 | // HACK: replace values one by one, assuming its value layout will be the same all the time, i.e. aligned
53 | stmt.Vars[i] = convertValue(val)
54 | }
55 | if ok, mode := hasReturning(db, supportReturning); ok {
56 | if rows, err := stmt.ConnPool.QueryContext(stmt.Context, stmt.SQL.String(), stmt.Vars...); db.AddError(err) == nil {
57 | dest := stmt.Dest
58 | stmt.Dest = stmt.ReflectValue.Addr().Interface()
59 | gorm.Scan(rows, db, mode)
60 | stmt.Dest = dest
61 | _ = db.AddError(rows.Close())
62 | }
63 | } else {
64 | result, err := stmt.ConnPool.ExecContext(stmt.Context, stmt.SQL.String(), stmt.Vars...)
65 |
66 | if db.AddError(err) == nil {
67 | db.RowsAffected, _ = result.RowsAffected()
68 | }
69 | }
70 | }
71 | }
72 | }
73 |
74 | func checkMissingWhereConditions(db *gorm.DB) {
75 | if !db.AllowGlobalUpdate && db.Error == nil {
76 | where, withCondition := db.Statement.Clauses["WHERE"]
77 | if withCondition {
78 | if _, withSoftDelete := db.Statement.Clauses["soft_delete_enabled"]; withSoftDelete {
79 | whereClause, _ := where.Expression.(clause.Where)
80 | withCondition = len(whereClause.Exprs) > 1
81 | }
82 | }
83 | if !withCondition {
84 | _ = db.AddError(gorm.ErrMissingWhereClause)
85 | }
86 | return
87 | }
88 | }
89 |
90 | func hasReturning(tx *gorm.DB, supportReturning bool) (bool, gorm.ScanMode) {
91 | if supportReturning {
92 | if c, ok := tx.Statement.Clauses["RETURNING"]; ok {
93 | returning, _ := c.Expression.(clause.Returning)
94 | if len(returning.Columns) == 0 || (len(returning.Columns) == 1 && returning.Columns[0].Name == "*") {
95 | return true, 0
96 | }
97 | return true, gorm.ScanUpdate
98 | }
99 | }
100 | return false, 0
101 | }
102 |
103 | // ConvertToAssignments convert to update assignments
104 | func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
105 | var (
106 | selectColumns, restricted = stmt.SelectAndOmitColumns(false, true)
107 | assignValue func(field *schema.Field, value interface{})
108 | )
109 |
110 | switch stmt.ReflectValue.Kind() {
111 | case reflect.Slice, reflect.Array:
112 | assignValue = func(field *schema.Field, value interface{}) {
113 | for i := 0; i < stmt.ReflectValue.Len(); i++ {
114 | if stmt.ReflectValue.CanAddr() {
115 | _ = field.Set(stmt.Context, stmt.ReflectValue.Index(i), value)
116 | }
117 | }
118 | }
119 | case reflect.Struct:
120 | assignValue = func(field *schema.Field, value interface{}) {
121 | if stmt.ReflectValue.CanAddr() {
122 | _ = field.Set(stmt.Context, stmt.ReflectValue, value)
123 | }
124 | }
125 | default:
126 | assignValue = func(field *schema.Field, value interface{}) {
127 | }
128 | }
129 |
130 | updatingValue := reflect.ValueOf(stmt.Dest)
131 | for updatingValue.Kind() == reflect.Ptr {
132 | updatingValue = updatingValue.Elem()
133 | }
134 |
135 | if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
136 | switch stmt.ReflectValue.Kind() {
137 | case reflect.Slice, reflect.Array:
138 | if size := stmt.ReflectValue.Len(); size > 0 {
139 | var isZero bool
140 | for i := 0; i < size; i++ {
141 | for _, field := range stmt.Schema.PrimaryFields {
142 | _, isZero = field.ValueOf(stmt.Context, stmt.ReflectValue.Index(i))
143 | if !isZero {
144 | break
145 | }
146 | }
147 | }
148 |
149 | if !isZero {
150 | _, primaryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields)
151 | column, values := schema.ToQueryValues("", stmt.Schema.PrimaryFieldDBNames, primaryValues)
152 | stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
153 | }
154 | }
155 | case reflect.Struct:
156 | for _, field := range stmt.Schema.PrimaryFields {
157 | if value, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero {
158 | stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
159 | }
160 | }
161 | default:
162 | }
163 | }
164 |
165 | switch value := updatingValue.Interface().(type) {
166 | case map[string]interface{}:
167 | set = make([]clause.Assignment, 0, len(value))
168 |
169 | keys := make([]string, 0, len(value))
170 | for k := range value {
171 | keys = append(keys, k)
172 | }
173 | sort.Strings(keys)
174 |
175 | for _, k := range keys {
176 | kv := value[k]
177 | if _, ok := kv.(*gorm.DB); ok {
178 | kv = []interface{}{kv}
179 | }
180 |
181 | if stmt.Schema != nil {
182 | if field := stmt.Schema.LookUpField(k); field != nil {
183 | if field.DBName != "" {
184 | if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
185 | set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: kv})
186 | assignValue(field, value[k])
187 | }
188 | } else if v, ok := selectColumns[field.Name]; (ok && v) || (!ok && !restricted) {
189 | assignValue(field, value[k])
190 | }
191 | continue
192 | }
193 | }
194 |
195 | if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
196 | set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: kv})
197 | }
198 | }
199 |
200 | if !stmt.SkipHooks && stmt.Schema != nil {
201 | for _, dbName := range stmt.Schema.DBNames {
202 | field := stmt.Schema.LookUpField(dbName)
203 | if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil {
204 | if v, ok := selectColumns[field.DBName]; (ok && v) || !ok {
205 | now := stmt.DB.NowFunc()
206 | assignValue(field, now)
207 |
208 | if field.AutoUpdateTime == schema.UnixNanosecond {
209 | set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()})
210 | } else if field.AutoUpdateTime == schema.UnixMillisecond {
211 | set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6})
212 | } else if field.AutoUpdateTime == schema.UnixSecond {
213 | set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()})
214 | } else {
215 | set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now})
216 | }
217 | }
218 | }
219 | }
220 | }
221 | default:
222 | updatingSchema := stmt.Schema
223 | var isDiffSchema bool
224 | if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
225 | // different schema
226 | updatingStmt := &gorm.Statement{DB: stmt.DB}
227 | if err := updatingStmt.Parse(stmt.Dest); err == nil {
228 | updatingSchema = updatingStmt.Schema
229 | isDiffSchema = true
230 | }
231 | }
232 |
233 | switch updatingValue.Kind() {
234 | case reflect.Struct:
235 | set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName))
236 | for _, dbName := range stmt.Schema.DBNames {
237 | if field := updatingSchema.LookUpField(dbName); field != nil {
238 | if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
239 | if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) {
240 | value, isZero := field.ValueOf(stmt.Context, updatingValue)
241 | if !stmt.SkipHooks && field.AutoUpdateTime > 0 {
242 | if field.AutoUpdateTime == schema.UnixNanosecond {
243 | value = stmt.DB.NowFunc().UnixNano()
244 | } else if field.AutoUpdateTime == schema.UnixMillisecond {
245 | value = stmt.DB.NowFunc().UnixNano() / 1e6
246 | } else if field.AutoUpdateTime == schema.UnixSecond {
247 | value = stmt.DB.NowFunc().Unix()
248 | } else {
249 | value = stmt.DB.NowFunc()
250 | }
251 | isZero = false
252 | }
253 |
254 | if (ok || !isZero) && field.Updatable {
255 | value = convertCustomType(value)
256 | set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value})
257 | assignField := field
258 | if isDiffSchema {
259 | if originField := stmt.Schema.LookUpField(dbName); originField != nil {
260 | assignField = originField
261 | }
262 | }
263 | assignValue(assignField, value)
264 | }
265 | }
266 | } else {
267 | if value, isZero := field.ValueOf(stmt.Context, updatingValue); !isZero {
268 | stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
269 | }
270 | }
271 | }
272 | }
273 | default:
274 | _ = stmt.AddError(gorm.ErrInvalidData)
275 | }
276 | }
277 |
278 | return
279 | }
280 |
--------------------------------------------------------------------------------