├── .editorconfig ├── .github └── workflows │ ├── codeql-analysis.yml │ ├── mariadb.yml │ ├── mysql.yml │ ├── postgres.yml │ ├── sqlite.yml │ └── sqlite3.yml ├── .gitignore ├── LICENSE ├── README.md ├── bench_test.go ├── core ├── builder.go ├── builder_test.go ├── column.go ├── column_test.go ├── core.go ├── core_test.go ├── model.go ├── model_test.go ├── primitive.go ├── primitive_test.go ├── stmt.go └── stmt_test.go ├── db.go ├── db_test.go ├── dialect ├── base.go ├── base_test.go ├── common_test.go ├── dialect.go ├── dialect_test.go ├── mysql.go ├── mysql_test.go ├── postgres.go ├── postgres_test.go ├── sqlite3.go ├── sqlite3_test.go └── testdata │ ├── .gitignore │ └── .gitkeep ├── doc.go ├── docs ├── advance.md ├── curd.md ├── dialect.md ├── index.md ├── model.md ├── quick-start.md ├── sqlbuilder.md └── upgrade.md ├── fetch ├── bench_test.go ├── column.go ├── column_test.go ├── fetch.go ├── fetch_test.go ├── map.go ├── map_test.go ├── object.go └── object_test.go ├── go.mod ├── go.sum ├── internal ├── createtable │ ├── createtable.go │ ├── createtable_test.go │ ├── sqlite3.go │ └── sqlite3_test.go ├── model │ ├── bench_test.go │ ├── column.go │ ├── column_test.go │ ├── engine.go │ ├── model.go │ ├── model_test.go │ ├── models.go │ ├── models_test.go │ └── testdata │ │ └── user.go ├── sqltest │ ├── sqltest.go │ └── sqltest_test.go ├── tags │ ├── flag_test.go │ ├── tags.go │ └── tags_test.go └── test │ ├── flag.go │ ├── test.go │ └── test_test.go ├── model_test.go ├── sqlbuilder.go ├── sqlbuilder ├── base.go ├── base_test.go ├── column.go ├── column_test.go ├── constraint.go ├── constraint_test.go ├── delete.go ├── delete_test.go ├── index.go ├── index_test.go ├── insert.go ├── insert_test.go ├── parser.go ├── parser_test.go ├── select.go ├── select_test.go ├── sqlbuilder.go ├── sqlbuilder_test.go ├── table.go ├── table_test.go ├── update.go ├── update_test.go ├── version.go ├── version_test.go ├── view.go ├── view_test.go ├── where.go └── where_test.go ├── tx.go ├── tx_test.go ├── types.go ├── types ├── decimal.go ├── decimal_test.go ├── rat.go ├── rat_test.go ├── slices.go ├── slices_test.go ├── types.go ├── types_test.go ├── unix.go └── unix_test.go ├── types_test.go ├── upgrade.go ├── upgrade_test.go ├── where.go └── where_test.go /.editorconfig: -------------------------------------------------------------------------------- 1 | # EditorConfig is awesome: http://EditorConfig.org 2 | 3 | # top-most EditorConfig file 4 | root = true 5 | 6 | # Unix-style newlines with a newline ending every file 7 | [*] 8 | end_of_line = lf 9 | insert_final_newline = true 10 | charset = utf-8 11 | 12 | # html 13 | [*.{htm,html,js,css}] 14 | indent_style = space 15 | indent_size = 4 16 | 17 | # 配置文件 18 | [*.{yml,yaml,json}] 19 | indent_style = space 20 | indent_size = 2 21 | -------------------------------------------------------------------------------- /.github/workflows/codeql-analysis.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | name: "CodeQL" 7 | 8 | on: 9 | push: 10 | branches: [master] 11 | pull_request: 12 | # The branches below must be a subset of the branches above 13 | branches: [master] 14 | schedule: 15 | - cron: '0 4 * * 1' 16 | 17 | jobs: 18 | analyze: 19 | name: Analyze 20 | runs-on: ubuntu-latest 21 | 22 | strategy: 23 | fail-fast: false 24 | matrix: 25 | # Override automatic language detection by changing the below list 26 | # Supported options are ['csharp', 'cpp', 'go', 'java', 'javascript', 'python'] 27 | language: ['go'] 28 | # Learn more... 29 | # https://docs.github.com/en/github/finding-security-vulnerabilities-and-errors-in-your-code/configuring-code-scanning#overriding-automatic-language-detection 30 | 31 | steps: 32 | - name: Checkout repository 33 | uses: actions/checkout@v4 34 | with: 35 | # We must fetch at least the immediate parents so that if this is 36 | # a pull request then we can checkout the head. 37 | fetch-depth: 2 38 | 39 | # Initializes the CodeQL tools for scanning. 40 | - name: Initialize CodeQL 41 | uses: github/codeql-action/init@v3 42 | with: 43 | languages: ${{ matrix.language }} 44 | # If you wish to specify custom queries, you can do so here or in a config file. 45 | # By default, queries listed here will override any specified in a config file. 46 | # Prefix the list here with "+" to use these queries and those in the config file. 47 | # queries: ./path/to/local/query, your-org/your-repo/queries@main 48 | 49 | # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). 50 | # If this step fails, then you should remove it and run the build manually (see below) 51 | - name: Autobuild 52 | uses: github/codeql-action/autobuild@v3 53 | 54 | # ℹ️ Command-line programs to run using the OS shell. 55 | # 📚 https://git.io/JvXDl 56 | 57 | # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines 58 | # and modify them (or add more) to build your code if your project 59 | # uses a compiled language 60 | 61 | #- run: | 62 | # make bootstrap 63 | # make release 64 | 65 | - name: Perform CodeQL Analysis 66 | uses: github/codeql-action/analyze@v3 67 | -------------------------------------------------------------------------------- /.github/workflows/mariadb.yml: -------------------------------------------------------------------------------- 1 | name: Mariadb 2 | on: [push, pull_request] 3 | jobs: 4 | test: 5 | name: Test 6 | runs-on: ${{ matrix.os }} 7 | 8 | strategy: 9 | matrix: 10 | os: [ubuntu-latest] # action 不支持非 linux 下的容器, windows-latest, macOS-latest 11 | go: ["1.23.x", "1.24.x"] 12 | 13 | services: 14 | mariadb: 15 | image: mariadb:latest 16 | env: 17 | MYSQL_ROOT_PASSWORD: root 18 | ports: 19 | - 3306:3306 20 | options: >- 21 | --health-cmd="healthcheck.sh --connect --innodb_initialized" 22 | --health-interval=10s 23 | --health-timeout=5s 24 | --health-retries=3 25 | 26 | steps: 27 | - name: 创建数据库 28 | run: | 29 | mysql -u root -proot -h 127.0.0.1 -e 'CREATE DATABASE IF NOT EXISTS orm_test;' 30 | 31 | - name: 安装 Go ${{ matrix.go }} 32 | uses: actions/setup-go@v5 33 | with: 34 | go-version: ${{ matrix.go }} 35 | id: go 36 | 37 | - name: Check out code into the Go module directory 38 | uses: actions/checkout@v4 39 | 40 | - name: Vet 41 | run: go vet -v ./... 42 | 43 | - name: Test 44 | run: go test ./... -test.coverprofile=coverage.txt -covermode=atomic -dbs=mariadb,mysql -p=1 -parallel=1 45 | 46 | - name: Upload coverage to Codecov 47 | uses: codecov/codecov-action@v5 48 | with: 49 | token: ${{secrets.CODECOV_TOKEN}} 50 | files: ./coverage.txt 51 | -------------------------------------------------------------------------------- /.github/workflows/mysql.yml: -------------------------------------------------------------------------------- 1 | name: Mysql 2 | on: [push, pull_request] 3 | jobs: 4 | test: 5 | name: Test 6 | runs-on: ${{ matrix.os }} 7 | 8 | strategy: 9 | matrix: 10 | os: [ubuntu-latest] # action 不支持非 linux 下的容器, windows-latest, macOS-latest 11 | go: ["1.23.x", "1.24.x"] 12 | 13 | services: 14 | mysql: 15 | image: mysql:latest 16 | env: 17 | MYSQL_ROOT_PASSWORD: root 18 | ports: 19 | - 3306:3306 20 | options: >- 21 | --health-cmd="mysqladmin ping" 22 | --health-interval=10s 23 | --health-timeout=5s 24 | --health-retries=3 25 | 26 | steps: 27 | - name: 创建数据库 28 | run: | 29 | mysql -u root -proot -h 127.0.0.1 -e 'CREATE DATABASE IF NOT EXISTS orm_test;' 30 | 31 | - name: 安装 Go ${{ matrix.go }} 32 | uses: actions/setup-go@v5 33 | with: 34 | go-version: ${{ matrix.go }} 35 | id: go 36 | 37 | - name: Check out code into the Go module directory 38 | uses: actions/checkout@v4 39 | 40 | - name: Vet 41 | run: go vet -v ./... 42 | 43 | - name: Test 44 | run: go test ./... -test.coverprofile=coverage.txt -covermode=atomic -dbs=mysql,mysql -p=1 -parallel=1 45 | 46 | - name: Upload coverage to Codecov 47 | uses: codecov/codecov-action@v5 48 | with: 49 | token: ${{secrets.CODECOV_TOKEN}} 50 | files: ./coverage.txt 51 | -------------------------------------------------------------------------------- /.github/workflows/postgres.yml: -------------------------------------------------------------------------------- 1 | name: Postgres 2 | on: [push, pull_request] 3 | 4 | jobs: 5 | test: 6 | name: Test 7 | runs-on: ${{ matrix.os }} 8 | 9 | strategy: 10 | matrix: 11 | os: [ubuntu-latest] # action 不支持非 linux 下的容器, windows-latest, macOS-latest 12 | go: ["1.23.x", "1.24.x"] 13 | 14 | services: 15 | postgres: 16 | image: postgres:17 17 | env: 18 | POSTGRES_USER: postgres 19 | POSTGRES_PASSWORD: postgres 20 | ports: 21 | - 5432:5432 22 | options: >- 23 | --health-cmd pg_isready 24 | --health-interval=10s 25 | --health-timeout=5s 26 | --health-retries=3 27 | 28 | steps: 29 | - name: 安装客户端 30 | run: | 31 | sudo apt-get update 32 | sudo apt-get install -y wget ca-certificates 33 | sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" > /etc/apt/sources.list.d/pgdg.list' 34 | wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add - 35 | sudo apt-get update 36 | sudo apt-get install -y postgresql-client-17 37 | 38 | - name: 创建数据库 39 | run: | 40 | PGPASSWORD=postgres psql -U postgres -h 127.0.0.1 -c 'CREATE DATABASE orm_test;' 41 | 42 | - name: 安装 Go ${{ matrix.go }} 43 | uses: actions/setup-go@v5 44 | with: 45 | go-version: ${{ matrix.go }} 46 | id: go 47 | 48 | - name: Check out code into the Go module directory 49 | uses: actions/checkout@v4 50 | 51 | - name: Vet 52 | run: go vet -v ./... 53 | 54 | - name: Test 55 | run: go test ./... -test.coverprofile=coverage.txt -covermode=atomic -dbs=postgres,postgres -p=1 -parallel=1 56 | 57 | - name: Upload coverage to Codecov 58 | uses: codecov/codecov-action@v5 59 | with: 60 | token: ${{secrets.CODECOV_TOKEN}} 61 | files: ./coverage.txt 62 | -------------------------------------------------------------------------------- /.github/workflows/sqlite.yml: -------------------------------------------------------------------------------- 1 | name: Sqlite 2 | on: [push, pull_request] 3 | jobs: 4 | test: 5 | name: Test 6 | runs-on: ${{ matrix.os }} 7 | 8 | strategy: 9 | matrix: 10 | os: [ubuntu-latest] # action 不支持非 linux 下的容器, windows-latest, macOS-latest 11 | go: ["1.23.x", "1.24.x"] 12 | 13 | steps: 14 | - name: 安装 Go ${{ matrix.go }} 15 | uses: actions/setup-go@v5 16 | with: 17 | go-version: ${{ matrix.go }} 18 | id: go 19 | 20 | - name: Check out code into the Go module directory 21 | uses: actions/checkout@v4 22 | 23 | - name: Vet 24 | run: go vet -v ./... 25 | 26 | - name: Test 27 | run: go test ./... -test.coverprofile=coverage.txt -covermode=atomic -dbs=sqlite3,sqlite -p=1 -parallel=1 28 | 29 | - name: Upload coverage to Codecov 30 | uses: codecov/codecov-action@v5 31 | with: 32 | token: ${{secrets.CODECOV_TOKEN}} 33 | files: ./coverage.txt 34 | -------------------------------------------------------------------------------- /.github/workflows/sqlite3.yml: -------------------------------------------------------------------------------- 1 | name: Sqlite3 2 | on: [push, pull_request] 3 | jobs: 4 | test: 5 | name: Test 6 | runs-on: ${{ matrix.os }} 7 | 8 | strategy: 9 | matrix: 10 | os: [ubuntu-latest] # action 不支持非 linux 下的容器, windows-latest, macOS-latest 11 | go: ["1.23.x", "1.24.x"] 12 | 13 | steps: 14 | - name: 安装 Go ${{ matrix.go }} 15 | uses: actions/setup-go@v5 16 | with: 17 | go-version: ${{ matrix.go }} 18 | id: go 19 | 20 | - name: Check out code into the Go module directory 21 | uses: actions/checkout@v4 22 | 23 | - name: Vet 24 | run: go vet -v ./... 25 | 26 | - name: Test 27 | run: go test ./... -test.coverprofile=coverage.txt -covermode=atomic -dbs=sqlite3,sqlite3 -p=1 -parallel=1 28 | 29 | - name: Upload coverage to Codecov 30 | uses: codecov/codecov-action@v5 31 | with: 32 | token: ${{secrets.CODECOV_TOKEN}} 33 | files: ./coverage.txt 34 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | 6 | *.exe 7 | *.test 8 | *.prof 9 | 10 | # 测试文件 11 | orm_test.db 12 | coverage.txt 13 | 14 | # mac os x 15 | .DS_Store 16 | 17 | # ide 18 | .idea 19 | .vscode 20 | *.swp 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2014 caixw 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | -------------------------------------------------------------------------------- /bench_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package orm_test 6 | 7 | import ( 8 | "database/sql" 9 | "testing" 10 | "time" 11 | 12 | "github.com/issue9/assert/v4" 13 | 14 | "github.com/issue9/orm/v6/internal/test" 15 | "github.com/issue9/orm/v6/sqlbuilder" 16 | ) 17 | 18 | func BenchmarkDB_Insert(b *testing.B) { 19 | a := assert.New(b, false) 20 | 21 | m := &Group{ 22 | Name: "name", 23 | Created: time.Now(), 24 | Any: 5, 25 | } 26 | 27 | suite := test.NewSuite(a, "") 28 | 29 | suite.Run(func(t *test.Driver) { 30 | t.NotError(t.DB.Create(&Group{})) 31 | defer func() { 32 | t.NotError(t.DB.Drop(&Group{})) 33 | }() 34 | 35 | for i := 0; i < b.N; i++ { 36 | _, err := t.DB.Insert(m) 37 | t.NotError(err) 38 | } 39 | }) 40 | } 41 | 42 | func BenchmarkDB_Update(b *testing.B) { 43 | a := assert.New(b, false) 44 | 45 | m := &Group{ 46 | Name: "name", 47 | Created: time.Now(), 48 | Any: 5, 49 | } 50 | 51 | suite := test.NewSuite(a, "") 52 | 53 | suite.Run(func(t *test.Driver) { 54 | t.NotError(t.DB.Create(&Group{})) 55 | defer func() { 56 | t.NotError(t.DB.Drop(&Group{})) 57 | }() 58 | 59 | // 构造数据 60 | for i := 0; i < 10000; i++ { 61 | _, err := t.DB.Insert(m) 62 | t.NotError(err) 63 | } 64 | 65 | m.ID = sql.NullInt64{Int64: 1, Valid: true} // 自增,从 1 开始 66 | for i := 0; i < b.N; i++ { 67 | _, err := t.DB.Update(m) 68 | t.NotError(err) 69 | } 70 | }) 71 | } 72 | 73 | func BenchmarkDB_Select(b *testing.B) { 74 | a := assert.New(b, false) 75 | 76 | m := &Group{ 77 | Name: "name", 78 | Created: time.Now(), 79 | Any: 5, 80 | } 81 | 82 | suite := test.NewSuite(a, "") 83 | 84 | suite.Run(func(t *test.Driver) { 85 | t.NotError(t.DB.Create(&Group{})) 86 | defer func() { 87 | t.NotError(t.DB.Drop(&Group{})) 88 | }() 89 | 90 | _, err := t.DB.Insert(m) 91 | t.NotError(err) 92 | 93 | m.ID = sql.NullInt64{Int64: 1, Valid: true} 94 | for i := 0; i < b.N; i++ { 95 | found, err := t.DB.Select(m) 96 | t.NotError(err).True(found) 97 | } 98 | }) 99 | } 100 | 101 | func BenchmarkDB_WhereUpdate(b *testing.B) { 102 | a := assert.New(b, false) 103 | 104 | m := &Group{ 105 | Name: "name", 106 | Created: time.Now(), 107 | Any: 5, 108 | } 109 | 110 | suite := test.NewSuite(a, "") 111 | 112 | suite.Run(func(t *test.Driver) { 113 | t.NotError(t.DB.Create(&Group{})) 114 | defer func() { 115 | t.NotError(t.DB.Drop(&Group{})) 116 | }() 117 | 118 | // 构造数据 119 | for i := 0; i < 10000; i++ { 120 | _, err := t.DB.Insert(m) 121 | t.NotError(err) 122 | } 123 | 124 | for i := 0; i < b.N; i++ { 125 | _, err := sqlbuilder. 126 | Update(t.DB).Table("{groups}"). 127 | Set("name", "n1"). 128 | Increase("created", 1). 129 | Where("{id}=?", i+1). 130 | Exec() 131 | t.NotError(err) 132 | } 133 | }) 134 | } 135 | -------------------------------------------------------------------------------- /core/builder.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package core 6 | 7 | import ( 8 | "strings" 9 | 10 | "github.com/issue9/errwrap" 11 | ) 12 | 13 | // 作用于表名,列名等非关键字上的引号占位符。 14 | // 在执行会自动替换成该数据库相应的符号。 15 | const ( 16 | QuoteLeft = '{' 17 | QuoteRight = '}' 18 | ) 19 | 20 | // Builder 用于构建 SQL 语句 21 | // 22 | // 出错时,错误信息会缓存,并在 [Builder.String] 或 [Builder.Bytes] 时返回, 23 | // 或是通过 [Builder.Err] 查看是否存在错误。 24 | type Builder struct { 25 | buffer errwrap.Buffer 26 | } 27 | 28 | // NewBuilder 声明一个新的 [Builder] 实例 29 | func NewBuilder(str ...string) *Builder { 30 | b := &Builder{} 31 | 32 | for _, s := range str { 33 | b.WString(s) 34 | } 35 | 36 | return b 37 | } 38 | 39 | // WString 写入一字符串 40 | func (b *Builder) WString(str string) *Builder { 41 | b.buffer.WString(str) 42 | return b 43 | } 44 | 45 | // WBytes 写入多个字符 46 | func (b *Builder) WBytes(c ...byte) *Builder { 47 | b.buffer.WBytes(c) 48 | return b 49 | } 50 | 51 | // WRunes 写入多个字符 52 | func (b *Builder) WRunes(r ...rune) *Builder { 53 | b.buffer.WRunes(r) 54 | return b 55 | } 56 | 57 | // Quote 给 str 左右添加 l 和 r 两个字符 58 | func (b *Builder) Quote(str string, l, r byte) *Builder { return b.WBytes(l).WString(str).WBytes(r) } 59 | 60 | // QuoteKey 给 str 左右添加 [QuoteLeft] 和 [QuoteRight] 两个字符 61 | func (b *Builder) QuoteKey(str string) *Builder { return b.Quote(str, QuoteLeft, QuoteRight) } 62 | 63 | // QuoteColumn 为列名添加 [QuoteLeft] 和 [QuoteRight] 两个字符 64 | // 65 | // NOTE: 列名可能包含表名或是表名别名:table.col 66 | func (b *Builder) QuoteColumn(col string) *Builder { 67 | if index := strings.IndexByte(col, '.'); index > 0 { 68 | return b.QuoteKey(col[:index]).WBytes('.').QuoteKey(col[index+1:]) 69 | } 70 | return b.Quote(col, QuoteLeft, QuoteRight) 71 | } 72 | 73 | // Reset 重置内容,同时也会将 err 设置为 nil 74 | func (b *Builder) Reset() *Builder { 75 | b.buffer.Reset() 76 | return b 77 | } 78 | 79 | // TruncateLast 去掉最后几个字符 80 | func (b *Builder) TruncateLast(n int) *Builder { 81 | b.buffer.Truncate(b.Len() - n) 82 | return b 83 | } 84 | 85 | // Err 返回错误内容 86 | func (b *Builder) Err() error { return b.buffer.Err } 87 | 88 | // String 获取表示的字符串 89 | func (b *Builder) String() (string, error) { 90 | if b.Err() != nil { 91 | return "", b.Err() 92 | } 93 | return b.buffer.String(), nil 94 | } 95 | 96 | // Bytes 获取表示的字符串 97 | func (b *Builder) Bytes() ([]byte, error) { 98 | if b.Err() != nil { 99 | return nil, b.Err() 100 | } 101 | return b.buffer.Bytes(), nil 102 | } 103 | 104 | // Len 获取长度 105 | func (b *Builder) Len() int { return b.buffer.Len() } 106 | 107 | // Append 追加加一个 [Builder] 的内容 108 | func (b *Builder) Append(v *Builder) *Builder { 109 | if b.Err() != nil { 110 | return b 111 | } 112 | 113 | str, err := v.String() 114 | if err == nil { 115 | return b.WString(str) 116 | } 117 | b.buffer.Err = err 118 | 119 | return b 120 | } 121 | -------------------------------------------------------------------------------- /core/builder_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package core 6 | 7 | import ( 8 | "errors" 9 | "testing" 10 | 11 | "github.com/issue9/assert/v4" 12 | ) 13 | 14 | func TestSQLBuilder(t *testing.T) { 15 | a := assert.New(t, false) 16 | 17 | b := NewBuilder() 18 | b.WBytes('1') 19 | b.WString("23") 20 | 21 | str, err := b.String() 22 | a.NotError(err).Equal("123", str) 23 | a.Equal(3, b.Len()) 24 | 25 | b.Reset() 26 | str, err = b.String() 27 | a.NotError(err).Equal(str, "") 28 | a.Equal(b.Len(), 0) 29 | 30 | b.WBytes('B', 'y', 't', 'e'). 31 | WRunes('R', 'u', 'n', 'e'). 32 | WString("String") 33 | str, err = b.String() 34 | a.NotError(err).Equal(str, "ByteRuneString") 35 | 36 | b.WBytes('1', '2') 37 | b.TruncateLast(2) 38 | str, err = b.String() 39 | a.NotError(err).Equal(str, "ByteRuneString").Equal(14, b.Len()) 40 | 41 | b.Reset() 42 | b.QuoteKey("key") 43 | bs, err := b.Bytes() 44 | a.NotError(err).Equal(bs, []byte("{key}")) 45 | 46 | buf := NewBuilder("buf-") 47 | buf.Append(b) 48 | bs, err = buf.Bytes() 49 | a.NotError(err).Equal(bs, "buf-{key}") 50 | 51 | // 带错误信息 52 | buf = NewBuilder() 53 | buf.buffer.Err = errors.New("test") 54 | buf.WString("str1") 55 | str, err = buf.String() 56 | a.ErrorString(err, "test"). 57 | Empty(str) 58 | } 59 | -------------------------------------------------------------------------------- /core/column.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package core 6 | 7 | import ( 8 | "errors" 9 | "fmt" 10 | ) 11 | 12 | var errInvalidColumnType = errors.New("无效的列类型") 13 | 14 | // ErrInvalidColumnType 无效的列类型 15 | // 16 | // 当一个数据无法表达为数据库中的对应的字段类型时返回此错误。 17 | func ErrInvalidColumnType() error { return errInvalidColumnType } 18 | 19 | // ErrColumnNotFound 返回列不存在的错误 20 | func ErrColumnNotFound(s string) error { return fmt.Errorf("列 %s 未找到", s) } 21 | 22 | // Column 列结构 23 | type Column struct { 24 | Name string // 数据库的字段名 25 | AI bool 26 | Nullable bool 27 | HasDefault bool 28 | Default any 29 | Length []int 30 | 31 | PrimitiveType PrimitiveType 32 | GoName string // Go 中的字段名 33 | } 34 | 35 | // NewColumn 从 Go 类型中生成 [Column] 36 | func NewColumn(p PrimitiveType) (*Column, error) { 37 | if p <= Auto || p >= maxPrimitiveType { 38 | return nil, ErrInvalidColumnType() 39 | } 40 | 41 | return &Column{ 42 | PrimitiveType: p, 43 | }, nil 44 | } 45 | 46 | // Clone 复制 Column 47 | func (c *Column) Clone() *Column { 48 | cc := &Column{} 49 | *cc = *c 50 | return cc 51 | } 52 | 53 | // Check 检测 [Column] 内容是否合法 54 | func (c *Column) Check() error { 55 | if c.AI && c.HasDefault { 56 | return fmt.Errorf("AutoIncrement 列 %s 不能同时包含默认值", c.Name) 57 | } 58 | 59 | if c.AI && c.Nullable { 60 | return fmt.Errorf("AutoIncrement 列 %s 不能同时带 NULL 约束", c.Name) 61 | } 62 | 63 | if c.PrimitiveType == String || c.PrimitiveType == Bytes { 64 | if len(c.Length) > 0 && (c.Length[0] < -1 || c.Length[0] == 0) { 65 | return fmt.Errorf("列 %s 的长度只能是 -1 或是 >0", c.Name) 66 | } 67 | } else { 68 | for _, v := range c.Length { 69 | if v < 0 { 70 | return fmt.Errorf("列 %s 的长度不能小于 0", c.Name) 71 | } 72 | } 73 | } 74 | 75 | return nil 76 | } 77 | 78 | // AddColumns 添加新列 79 | func (m *Model) AddColumns(col ...*Column) error { 80 | for _, c := range col { 81 | if err := m.AddColumn(c); err != nil { 82 | return err 83 | } 84 | } 85 | 86 | return nil 87 | } 88 | 89 | // AddColumn 添加新列 90 | // 91 | // 按添加顺序确定位置,越早添加的越在前。 92 | func (m *Model) AddColumn(col *Column) error { 93 | if col.Name == "" { 94 | return errors.New("列必须存在名称") 95 | } 96 | 97 | if m.FindColumn(col.Name) != nil { 98 | return fmt.Errorf("列 %s 已经存在", col.Name) 99 | } 100 | 101 | m.Columns = append(m.Columns, col) 102 | 103 | if col.AI { 104 | return m.SetAutoIncrement(col) 105 | } 106 | return nil 107 | } 108 | 109 | // FindColumn 查找指定名称的列 110 | // 111 | // 不存在该列则返回 nil 112 | func (m *Model) FindColumn(name string) *Column { 113 | for _, col := range m.Columns { 114 | if col.Name == name { 115 | return col 116 | } 117 | } 118 | return nil 119 | } 120 | 121 | func (m *Model) columnExists(col *Column) bool { 122 | for _, c := range m.Columns { 123 | if c == col { 124 | return true 125 | } 126 | } 127 | 128 | return false 129 | } 130 | -------------------------------------------------------------------------------- /core/column_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package core 6 | 7 | import ( 8 | "testing" 9 | 10 | "github.com/issue9/assert/v4" 11 | ) 12 | 13 | func TestNewColumn(t *testing.T) { 14 | a := assert.New(t, false) 15 | 16 | col, err := NewColumn(Int) 17 | a.NotError(err).NotNil(col).Equal(col.PrimitiveType, Int) 18 | 19 | col, err = NewColumn(Bool) 20 | a.NotError(err).NotNil(col).Equal(col.PrimitiveType, Bool) 21 | 22 | col, err = NewColumn(Auto) 23 | a.ErrorIs(err, ErrInvalidColumnType()).Nil(col) 24 | 25 | col, err = NewColumn(maxPrimitiveType) 26 | a.ErrorIs(err, ErrInvalidColumnType()).Nil(col) 27 | } 28 | 29 | func TestModel_AddColumns(t *testing.T) { 30 | a := assert.New(t, false) 31 | m := NewModel(Table, "m1", 10) 32 | a.NotNil(m) 33 | 34 | ai, err := NewColumn(Int) 35 | a.NotError(err).NotNil(ai) 36 | ai.AI = true 37 | a.Error(m.AddColumns(ai)) // 没有名称 38 | 39 | col, err := NewColumn(Int) 40 | a.NotError(err).NotNil(col) 41 | 42 | // 同名 43 | ai.Name = "ai" 44 | col.Name = "ai" 45 | a.Error(m.AddColumns(ai, col)) 46 | 47 | // 正常 48 | m.Reset() 49 | col.Name = "col" 50 | a.NotError(m.AddColumns(ai, col)) 51 | } 52 | 53 | func TestColumn_Clone(t *testing.T) { 54 | a := assert.New(t, false) 55 | 56 | col, err := NewColumn(Int) 57 | a.NotError(err).NotNil(col) 58 | col.Nullable = true 59 | 60 | cc := col.Clone() 61 | a.Equal(cc, col) // 值相同 62 | a.True(cc != col) // 但不是同一个实例 63 | } 64 | 65 | func TestColumn_Check(t *testing.T) { 66 | a := assert.New(t, false) 67 | 68 | col, err := NewColumn(String) 69 | a.NotError(err).NotNil(col) 70 | col.Length = []int{-1} 71 | 72 | a.NotError(col.Check()) 73 | 74 | col.Length[0] = 0 75 | a.Error(col.Check()) 76 | 77 | col.Length[0] = -2 78 | a.Error(col.Check()) 79 | 80 | col, err = NewColumn(Int) 81 | a.NotError(err).NotNil(col) 82 | col.Length = []int{-2} 83 | a.Error(col.Check()) 84 | 85 | col.Length[0] = -1 86 | a.Error(col.Check()) 87 | 88 | col.Length[0] = 0 89 | a.NotError(col.Check()) 90 | 91 | col.AI = true 92 | col.HasDefault = true 93 | a.Error(col.Check()) 94 | 95 | col.AI = true 96 | col.HasDefault = false 97 | col.Nullable = true 98 | a.Error(col.Check()) 99 | } 100 | -------------------------------------------------------------------------------- /core/core.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | // Package core 核心功能 6 | package core 7 | 8 | import ( 9 | "context" 10 | "database/sql" 11 | "fmt" 12 | ) 13 | 14 | // 索引的类型 15 | const ( 16 | IndexDefault IndexType = iota // 普通的索引 17 | IndexUnique // 唯一索引 18 | ) 19 | 20 | // 约束类型 21 | // 22 | // 以下定义了一些常用的约束类型,但是并不是所有的数据都支持这些约束类型, 23 | // 比如 mysql<8.0.16 和 mariadb<10.2.1 不支持 check 约束。 24 | const ( 25 | ConstraintNone ConstraintType = iota 26 | ConstraintUnique // 唯一约束 27 | ConstraintFK // 外键约束 28 | ConstraintCheck // Check 约束 29 | ConstraintPK // 主键约束 30 | ) 31 | 32 | type IndexType int8 33 | 34 | type ConstraintType int8 35 | 36 | // Engine 数据库执行的基本接口 37 | // 38 | // Engine 对查询语句作了以下处理: 39 | // - {} 符号会被替换为 [Dialect.Quotes] 对应的符号; 40 | // - # 会被替换为 [Engine.TablePrefix] 的返回值; 41 | type Engine interface { 42 | Query(query string, args ...any) (*sql.Rows, error) 43 | QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) 44 | 45 | QueryRow(query string, args ...any) *sql.Row 46 | QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row 47 | 48 | Exec(query string, args ...any) (sql.Result, error) 49 | ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) 50 | 51 | Prepare(query string) (*Stmt, error) 52 | PrepareContext(ctx context.Context, query string) (*Stmt, error) 53 | 54 | Dialect() Dialect 55 | 56 | // Debug 启用调试输出 57 | // 58 | // 如果传递了一个非空值,那么会将生成的 SQL 输出到 l。 59 | Debug(l func(string)) 60 | } 61 | 62 | // Dialect 用于描述与数据库和驱动相关的一些特性 63 | // 64 | // Dialect 的实现者除了要实现 Dialect 之外, 65 | // 还需要根据数据库的支持情况实现 sqlbuilder 下的部分 *Hooker 接口。 66 | type Dialect interface { 67 | // Name 当前关联实例的名称 68 | // 69 | // 一般为该数据库的官方名称。 70 | // 实例名称和驱动名未必相同。比如 mysql 和 mariadb 可能采用相同的驱动名; 71 | Name() string 72 | 73 | // DriverName 与当前实例关联的驱动名称 74 | // 75 | // 原则上驱动名和 [Dialect] 应该是一一对应的,但是也会有例外,比如: 76 | // github.com/lib/pq 和 github.com/jackc/pgx/v5/stdlib 功能上是相同的, 77 | // 仅注册的名称的不同。 78 | DriverName() string 79 | 80 | Quotes() (left, right byte) 81 | 82 | // SQLType 将列转换成数据支持的类型表达式 83 | // 84 | // 必须实现对所有 [PrimitiveType] 类型的转换。 85 | SQLType(*Column) (string, error) 86 | 87 | // TransactionalDDL 是否允许在事务中执行 DDL 88 | // 89 | // 比如在 postgresql 中,如果创建一个带索引的表,会采用在事务中, 90 | // 分多条语句创建表。 91 | // 而像 mysql 等不支持事务内 DDL 的数据库,则会采用普通的方式, 92 | // 依次提交语句。 93 | TransactionalDDL() bool 94 | 95 | // VersionSQL 查询服务器版本号的 SQL 语句 96 | VersionSQL() string 97 | 98 | // ExistsSQL 查询数据库中是否存在指定名称的表或是视图 SQL 语句 99 | // 100 | // 返回的 SQL语句中,其执行结果如果存在,则应该返回 name 字段表示表名,否则返回空。 101 | ExistsSQL(name string, view bool) (string, []any) 102 | 103 | // LimitSQL 生成 `LIMIT N OFFSET M` 或是相同的语意的语句片段 104 | // 105 | // offset 值为一个可选参数,若不指定,则表示 `LIMIT N` 语句。 106 | // 返回的是对应数据库的 limit 语句以及语句中占位符对应的值。 107 | // 108 | // limit 和 offset 可以是 SQL.NamedArg 类型。 109 | LimitSQL(limit any, offset ...any) (string, []any) 110 | 111 | // LastInsertIDSQL 自定义获取 LastInsertID 的获取方式 112 | // 113 | // 类似于 postgresql 等都需要额外定义。 114 | // 115 | // sql 表示额外的语句,如果为空,则执行的是标准的 SQL 插入语句; 116 | // append 表示在 sql 不为空的情况下,sql 与现有的插入语句的结合方式, 117 | // 如果为 true 表示直接添加在插入语句之后,否则为一条新的语句。 118 | LastInsertIDSQL(table, col string) (sql string, append bool) 119 | 120 | // CreateTableOptionsSQL 创建表时根据附加信息返回的部分 SQL 语句 121 | CreateTableOptionsSQL(sql *Builder, options map[string][]string) error 122 | 123 | // TruncateTableSQL 生成清空数据表并重置自增列的语句 124 | // 125 | // ai 表示自增列的名称,可以为空,表示没有自去列。 126 | TruncateTableSQL(table, ai string) ([]string, error) 127 | 128 | // CreateViewSQL 生成创建视图的 SQL 语句 129 | CreateViewSQL(replace, temporary bool, name, selectQuery string, cols []string) ([]string, error) 130 | 131 | // DropIndexSQL 生成删除索引的语句 132 | // 133 | // table 为表名,部分数据库需要; 134 | // index 表示索引名; 135 | DropIndexSQL(table, index string) (string, error) 136 | 137 | // Fix 对 sql 语句作调整 138 | // 139 | // 比如处理 [sql.NamedArgs],postgresql 需要将 ? 改成 $1 等形式。 140 | // 以及对 args 的参数作校正,比如 lib/pq 对 [time.Time] 处理有问题,也可以在此处作调整。 141 | // 142 | // NOTE: query 中不能同时存在 ? 和命名参数。因为如果是命名参数,则 args 的顺序可以是随意的。 143 | Fix(query string, args []any) (string, []any, error) 144 | 145 | // Prepare 对预编译的内容进行处理 146 | // 147 | // 目前大部分驱动都不支持 [sql.NamedArgs],为了支持该功能, 148 | // 需要在预编译之前,对语句进行如下处理: 149 | // 1. 将 sql 中的 @xx 替换成 ? 150 | // 2. 将 sql 中的 @xx 在 sql 中的位置进行记录,并通过 orders 返回。 151 | // query 为处理后的 SQL 语句; 152 | // orders 为参数名在 query 中对应的位置,第一个位置为 0,依次增加。 153 | // 154 | // NOTE: query 中不能同时存在 ? 和命名参数。因为如果是命名参数,则 Exec 等的参数顺序可以是随意的。 155 | Prepare(sql string) (query string, orders map[string]int, err error) 156 | 157 | // Backup 备份数据库 158 | // 159 | // dsn 初始化数据库的参数,主要从其中获取数据库名称等参数; 160 | // dest 备份的文件名,格式由实现者决定; 161 | Backup(dsn, dest string) error 162 | } 163 | 164 | // ErrConstraintExists 返回约束名已经存在的错误 165 | func ErrConstraintExists(c string) error { return fmt.Errorf("约束 %s 已经存在", c) } 166 | -------------------------------------------------------------------------------- /core/core_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package core_test 6 | 7 | import ( 8 | "testing" 9 | 10 | "github.com/issue9/orm/v6/internal/test" 11 | ) 12 | 13 | func TestMain(m *testing.M) { 14 | test.Main(m) 15 | } 16 | -------------------------------------------------------------------------------- /core/primitive.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2025 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package core 6 | 7 | import ( 8 | "database/sql" 9 | "reflect" 10 | "time" 11 | ) 12 | 13 | // TimeFormatLayout 时间如果需要转换成字符串采用此格式 14 | const TimeFormatLayout = time.RFC3339 15 | 16 | // 当前支持的 [PrimitiveType] 值 17 | // 18 | // 其中的 [String] 被设计成可以保存部分类型为 [reflect.Interface] 的数据结构, 19 | // 但是一个有限的集合,比如将一个 any 字段赋予 slice 类型,在保存时可能不被支持。 20 | // 且在读取时,各个数据库略有不同,比如 mysql 返回 []byte,而其它数据一般返回 string。 21 | const ( 22 | Auto PrimitiveType = iota 23 | Bool 24 | Int 25 | Int8 26 | Int16 27 | Int32 28 | Int64 29 | Uint 30 | Uint8 31 | Uint16 32 | Uint32 33 | Uint64 34 | Float32 35 | Float64 36 | String 37 | Bytes 38 | Time 39 | Decimal 40 | maxPrimitiveType 41 | ) 42 | 43 | var ( 44 | typeStrings = map[PrimitiveType]string{ 45 | Auto: "auto", 46 | Bool: "bool", 47 | Int: "int", 48 | Int8: "int8", 49 | Int16: "int16", 50 | Int32: "int32", 51 | Int64: "int64", 52 | Uint: "uint", 53 | Uint8: "uint8", 54 | Uint16: "uint16", 55 | Uint32: "uint32", 56 | Uint64: "uint64", 57 | Float32: "float32", 58 | Float64: "float64", 59 | String: "string", 60 | Bytes: "bytes", 61 | Time: "time", 62 | Decimal: "decimal", 63 | } 64 | 65 | types = map[reflect.Type]PrimitiveType{ 66 | reflect.TypeFor[bool](): Bool, 67 | reflect.TypeFor[int](): Int, 68 | reflect.TypeFor[int8](): Int8, 69 | reflect.TypeFor[int16](): Int16, 70 | reflect.TypeFor[int32](): Int32, 71 | reflect.TypeFor[int64](): Int64, 72 | reflect.TypeFor[uint](): Uint, 73 | reflect.TypeFor[uint8](): Uint8, 74 | reflect.TypeFor[uint16](): Uint16, 75 | reflect.TypeFor[uint32](): Uint32, 76 | reflect.TypeFor[uint64](): Uint64, 77 | reflect.TypeFor[float32](): Float32, 78 | reflect.TypeFor[float64](): Float64, 79 | reflect.TypeFor[string](): String, 80 | reflect.TypeFor[[]byte](): Bytes, 81 | reflect.TypeFor[sql.RawBytes](): Bytes, 82 | reflect.TypeFor[time.Time](): Time, 83 | 84 | reflect.TypeFor[sql.NullString](): String, 85 | reflect.TypeFor[sql.NullByte](): Bytes, 86 | reflect.TypeFor[sql.NullInt64](): Int64, 87 | reflect.TypeFor[sql.NullInt32](): Int32, 88 | reflect.TypeFor[sql.NullInt16](): Int16, 89 | reflect.TypeFor[sql.NullBool](): Bool, 90 | reflect.TypeFor[sql.NullFloat64](): Float64, 91 | reflect.TypeFor[sql.NullTime](): Time, 92 | } 93 | 94 | kinds = map[reflect.Kind]PrimitiveType{ 95 | reflect.Bool: Bool, 96 | reflect.Int: Int, 97 | reflect.Int8: Int8, 98 | reflect.Int16: Int16, 99 | reflect.Int32: Int32, 100 | reflect.Int64: Int64, 101 | reflect.Uint: Uint, 102 | reflect.Uint8: Uint8, 103 | reflect.Uint16: Uint16, 104 | reflect.Uint32: Uint32, 105 | reflect.Uint64: Uint64, 106 | reflect.Float32: Float32, 107 | reflect.Float64: Float64, 108 | reflect.String: String, 109 | reflect.Interface: String, 110 | } 111 | 112 | primitiveTyperType = reflect.TypeFor[PrimitiveTyper]() 113 | ) 114 | 115 | type PrimitiveTyper interface { 116 | // NOTE: 最简单的方法是复用 [driver.Valuer] 接口,从其返回值中获取类型信息, 117 | // 但是该接口有可能返回 nil 值,无法确定类型。 118 | 119 | // PrimitiveType 返回当前对象所表示的 [PrimitiveType] 值 120 | // 121 | // NOTE: 每个对象在任何时间返回的值应该都是固定的。 122 | PrimitiveType() PrimitiveType 123 | } 124 | 125 | // PrimitiveType 表示 Go 对象在数据库中实际的存储方式 126 | // 127 | // PrimitiveType 由 [Dialect.SQLType] 转换成相应数据的实际类型。 128 | type PrimitiveType int 129 | 130 | // GetPrimitiveType 获取 t 所关联的 [PrimitiveType] 值 131 | // 132 | // t.Kind 不能为 [reflect.Ptr] 否则将返回 [Auto]。 133 | func GetPrimitiveType(t reflect.Type) PrimitiveType { 134 | primitiveType, found := kinds[t.Kind()] 135 | if found { 136 | return primitiveType 137 | } 138 | 139 | primitiveType, found = types[t] 140 | if !found { 141 | v := reflect.New(t).Elem() 142 | if t.Implements(primitiveTyperType) { 143 | primitiveType = v.Interface().(PrimitiveTyper).PrimitiveType() 144 | } else if v.Addr().Type().Implements(primitiveTyperType) { 145 | primitiveType = v.Addr().Interface().(PrimitiveTyper).PrimitiveType() 146 | } 147 | } 148 | 149 | return primitiveType 150 | } 151 | 152 | func (t PrimitiveType) String() string { return typeStrings[t] } 153 | -------------------------------------------------------------------------------- /core/primitive_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package core 6 | 7 | import ( 8 | "reflect" 9 | "testing" 10 | 11 | "github.com/issue9/assert/v4" 12 | ) 13 | 14 | func TestPrimitiveType(t *testing.T) { 15 | a := assert.New(t, false) 16 | 17 | // 保证 PrimitiveType.String() 拥有所有的值。 18 | for i := Auto; i < maxPrimitiveType; i++ { 19 | a.NotEmpty(i.String()) 20 | } 21 | a.Length(typeStrings, int(maxPrimitiveType)) 22 | } 23 | 24 | func TestGetPrimitiveType(t *testing.T) { 25 | a := assert.New(t, false) 26 | 27 | a.Equal(GetPrimitiveType(reflect.TypeOf(1)), Int) 28 | a.Equal(GetPrimitiveType(reflect.TypeOf([]byte{1, 2})), Bytes) 29 | a.Equal(GetPrimitiveType(reflect.TypeOf("string")), String) 30 | a.Equal(GetPrimitiveType(reflect.TypeOf(any(5))), Int) 31 | 32 | // 指针的 PrimitiveType 33 | x := 5 34 | a.Equal(GetPrimitiveType(reflect.TypeOf(&x)), Auto) 35 | 36 | // 自定义类型,但是未实现 PrimitiveTyper 接口 37 | type T int16 38 | a.Equal(GetPrimitiveType(reflect.TypeOf(T(1))), Int16) 39 | 40 | type obj struct{} 41 | a.Equal(GetPrimitiveType(reflect.TypeOf(obj{})), Auto) 42 | 43 | type obj2 struct { 44 | Any any 45 | } 46 | o2 := obj2{} 47 | field, _ := reflect.ValueOf(o2).Type().FieldByName("Any") 48 | a.Equal(GetPrimitiveType(field.Type), String) 49 | } 50 | -------------------------------------------------------------------------------- /core/stmt.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package core 6 | 7 | import ( 8 | "context" 9 | "database/sql" 10 | "fmt" 11 | "slices" 12 | ) 13 | 14 | // Stmt 实现自定义的 Stmt 实例 15 | // 16 | // 功能与 [sql.Stmt] 完全相同,但是实现了对 [sql.NamedArgs] 的支持。 17 | type Stmt struct { 18 | *sql.Stmt 19 | orders map[string]int 20 | } 21 | 22 | // NewStmt 声明 [Stmt] 实例 23 | // 24 | // 如果 orders 为空,则 Stmt 的表现和 [sql.Stmt] 是完全相同的, 25 | // 如果不为空,则可以处理 [sql.NamedArg] 类型的参数。 26 | func NewStmt(stmt *sql.Stmt, orders map[string]int) *Stmt { 27 | if len(orders) > 0 { 28 | vals := make([]int, 0, len(orders)) 29 | for _, v := range orders { 30 | vals = append(vals, v) 31 | } 32 | slices.Sort(vals) 33 | 34 | for k, v := range vals { 35 | if k != v { 36 | panic(fmt.Sprintf("orders 并不是连续的参数,缺少了 %d", k)) 37 | } 38 | } 39 | } 40 | 41 | return &Stmt{ 42 | Stmt: stmt, 43 | orders: orders, 44 | } 45 | } 46 | 47 | // Close 关闭 Stmt 实例 48 | func (stmt *Stmt) Close() error { 49 | stmt.orders = nil 50 | return stmt.Stmt.Close() 51 | } 52 | 53 | // Exec 以指定的参数执行预编译的语句 54 | func (stmt *Stmt) Exec(args ...any) (sql.Result, error) { 55 | return stmt.ExecContext(context.Background(), args...) 56 | } 57 | 58 | // ExecContext 以指定的参数执行预编译的语句 59 | func (stmt *Stmt) ExecContext(ctx context.Context, args ...any) (sql.Result, error) { 60 | args, err := stmt.buildArgs(args) 61 | if err != nil { 62 | return nil, err 63 | } 64 | return stmt.Stmt.ExecContext(ctx, args...) 65 | } 66 | 67 | // Query 以指定的参数执行预编译的语句 68 | func (stmt *Stmt) Query(args ...any) (*sql.Rows, error) { 69 | return stmt.QueryContext(context.Background(), args...) 70 | } 71 | 72 | // QueryContext 以指定的参数执行预编译的语句 73 | func (stmt *Stmt) QueryContext(ctx context.Context, args ...any) (*sql.Rows, error) { 74 | args, err := stmt.buildArgs(args) 75 | if err != nil { 76 | return nil, err 77 | } 78 | return stmt.Stmt.QueryContext(ctx, args...) 79 | } 80 | 81 | // QueryRow 以指定的参数执行预编译的语句 82 | func (stmt *Stmt) QueryRow(args ...any) *sql.Row { 83 | return stmt.QueryRowContext(context.Background(), args...) 84 | } 85 | 86 | // QueryRowContext 以指定的参数执行预编译的语句 87 | func (stmt *Stmt) QueryRowContext(ctx context.Context, args ...any) *sql.Row { 88 | args, err := stmt.buildArgs(args) 89 | if err != nil { 90 | panic(err) 91 | } 92 | return stmt.Stmt.QueryRowContext(ctx, args...) 93 | } 94 | 95 | func (stmt *Stmt) buildArgs(args []any) ([]any, error) { 96 | if len(stmt.orders) == 0 { 97 | return args, nil 98 | } 99 | 100 | if len(args) != len(stmt.orders) { 101 | return nil, fmt.Errorf("给定的参数数量 %d 与预编译的参数数量 %d 不相等", len(args), len(stmt.orders)) 102 | } 103 | 104 | ret := make([]any, len(args)) 105 | 106 | for index, arg := range args { 107 | named, ok := arg.(sql.NamedArg) 108 | if !ok { 109 | return nil, fmt.Errorf("第 %d 个参数并非是 sql.NamedArg 类型", index) 110 | } 111 | 112 | i, found := stmt.orders[named.Name] 113 | if !found { 114 | return nil, fmt.Errorf("参数 %s 并不存在于预编译内容中", named.Name) 115 | } 116 | ret[i] = named.Value 117 | } 118 | 119 | return ret, nil 120 | } 121 | -------------------------------------------------------------------------------- /core/stmt_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package core 6 | 7 | import ( 8 | "database/sql" 9 | "testing" 10 | 11 | "github.com/issue9/assert/v4" 12 | ) 13 | 14 | func TestNewStmt(t *testing.T) { 15 | a := assert.New(t, false) 16 | 17 | a.NotPanic(func() { 18 | stmt := NewStmt(nil, nil) 19 | a.NotNil(stmt) 20 | }) 21 | 22 | a.NotPanic(func() { 23 | stmt := NewStmt(nil, map[string]int{"id": 0, "name": 1, "test": 2}) 24 | a.NotNil(stmt) 25 | }) 26 | 27 | a.NotPanic(func() { 28 | stmt := NewStmt(nil, map[string]int{"id": 2, "name": 1, "test": 0}) 29 | a.NotNil(stmt) 30 | }) 31 | 32 | a.Panic(func() { 33 | NewStmt(nil, map[string]int{"id": 1}) 34 | }) 35 | 36 | a.Panic(func() { 37 | NewStmt(nil, map[string]int{"id": 1, "name": 0, "test": 10}) 38 | }) 39 | } 40 | 41 | func TestStmt_buildArgs(t *testing.T) { 42 | a := assert.New(t, false) 43 | 44 | data := []*struct { 45 | orders map[string]int 46 | input []any 47 | output []any 48 | err bool 49 | }{ 50 | {}, 51 | { // orders 为空,则原样返回内容 52 | input: []any{1, 2, 3}, 53 | output: []any{1, 2, 3}, 54 | }, 55 | { // orders 为空,则原样返回内容 56 | orders: map[string]int{}, 57 | input: []any{1, 2, 3}, 58 | output: []any{1, 2, 3}, 59 | }, 60 | { // 参数数量不匹配 61 | orders: map[string]int{"id": 0}, 62 | input: []any{sql.Named("id", 1), 1, 2}, 63 | err: true, 64 | }, 65 | { // 输入参数有非 sql.Named 类型 66 | orders: map[string]int{"id": 0, "name": 1}, 67 | input: []any{sql.Named("id", 1), 1}, 68 | err: true, 69 | }, 70 | { // 参数并不在 orders 中 71 | orders: map[string]int{"id": 0, "name": 1}, 72 | input: []any{sql.Named("id", 1), sql.Named("not-exists-arg", "test")}, 73 | err: true, 74 | }, 75 | { 76 | orders: map[string]int{"name": 1, "id": 0}, 77 | input: []any{sql.Named("id", 1), sql.Named("name", "test")}, 78 | output: []any{1, "test"}, 79 | }, 80 | } 81 | 82 | for k, v := range data { 83 | stmt := NewStmt(nil, v.orders) 84 | output, err := stmt.buildArgs(v.input) 85 | if v.err { 86 | a.Error(err, "not error @ %d", k). 87 | Nil(output) 88 | } else { 89 | a.Equal(output, v.output, "not equal @%d,v1:%s,v2:%s", k, output, v.output) 90 | } 91 | } 92 | } 93 | -------------------------------------------------------------------------------- /dialect/base.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package dialect 6 | 7 | import ( 8 | "os/exec" 9 | "strings" 10 | 11 | "github.com/issue9/sliceutil" 12 | ) 13 | 14 | var ( 15 | quoteApostrophe = strings.NewReplacer("'", "''") // 标准 SQL 用法 16 | escapeApostrophe = strings.NewReplacer("'", "\\'") // mysql 用法 17 | ) 18 | 19 | type base struct { 20 | driverName string 21 | name string 22 | quoteL, quoteR byte 23 | } 24 | 25 | func newBase(name, driverName string, quoteLeft, quoteRight byte) base { 26 | return base{ 27 | name: name, 28 | driverName: driverName, 29 | quoteL: quoteLeft, 30 | quoteR: quoteRight, 31 | } 32 | } 33 | 34 | func (b *base) Name() string { return b.name } 35 | 36 | func (b *base) DriverName() string { return b.driverName } 37 | 38 | func (b *base) Quotes() (byte, byte) { return b.quoteL, b.quoteR } 39 | 40 | func buildCmdArgs(k, v string) string { 41 | if v == "" { 42 | return "" 43 | } 44 | return k + "=" + v 45 | } 46 | 47 | func newCommand(name string, env, kv []string) *exec.Cmd { 48 | env = sliceutil.Filter(env, func(i string, _ int) bool { return i != "" }) 49 | kv = sliceutil.Filter(kv, func(i string, _ int) bool { return i != "" }) 50 | cmd := exec.Command(name, kv...) 51 | cmd.Env = append(cmd.Env, env...) 52 | return cmd 53 | } 54 | -------------------------------------------------------------------------------- /dialect/base_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package dialect 6 | 7 | import ( 8 | "testing" 9 | 10 | "github.com/issue9/assert/v4" 11 | ) 12 | 13 | func TestBuildCmdArg(t *testing.T) { 14 | a := assert.New(t, false) 15 | a.Equal(buildCmdArgs("-p", ""), ""). 16 | Equal(buildCmdArgs("-p", "123"), "-p=123") 17 | } 18 | -------------------------------------------------------------------------------- /dialect/dialect.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | // Package dialect 提供了部分数据库对 [core.Dialect] 接口的实现 6 | package dialect 7 | 8 | import ( 9 | "database/sql" 10 | "errors" 11 | "fmt" 12 | "unicode" 13 | 14 | "github.com/issue9/orm/v6/core" 15 | "github.com/issue9/orm/v6/sqlbuilder" 16 | ) 17 | 18 | var ( 19 | errColIsNil = errors.New("参数 col 参数是个空值") 20 | 21 | datetimeLayouts = []string{ 22 | "2006-01-02 15:04:05", 23 | "2006-01-02 15:04:05.9", 24 | "2006-01-02 15:04:05.99", 25 | "2006-01-02 15:04:05.999", 26 | "2006-01-02 15:04:05.9999", 27 | "2006-01-02 15:04:05.99999", 28 | "2006-01-02 15:04:05.999999", 29 | } 30 | ) 31 | 32 | func missLength(col *core.Column) error { 33 | return fmt.Errorf("列 %s 缺少长度数据", col.Name) 34 | } 35 | 36 | func invalidTimeFractional(col *core.Column) error { 37 | return fmt.Errorf("列 %s 时间精度只能介于 [0,6] 之间", col.Name) 38 | } 39 | 40 | func errUncovert(col *core.Column) error { 41 | return fmt.Errorf("不支持的列类型: %s", col.Name) 42 | } 43 | 44 | // mysqlLimitSQL mysql 系列数据库分页语法的实现 45 | // 46 | // 支持以下数据库:MySQL, H2, HSQLDB, Postgres, SQLite3 47 | func mysqlLimitSQL(limit any, offset ...any) (string, []any) { 48 | query := " LIMIT " 49 | 50 | if named, ok := limit.(sql.NamedArg); ok && named.Name != "" { 51 | query += "@" + named.Name 52 | } else { 53 | query += "?" 54 | } 55 | 56 | if len(offset) == 0 { 57 | return query + " ", []any{limit} 58 | } 59 | 60 | query += " OFFSET " 61 | o := offset[0] 62 | if named, ok := o.(sql.NamedArg); ok && named.Name != "" { 63 | query += "@" + named.Name 64 | } else { 65 | query += "?" 66 | } 67 | 68 | return query + " ", []any{limit, offset[0]} 69 | } 70 | 71 | // oracleLimitSQL oracle 系列数据库分页语法的实现 72 | // 73 | // 支持以下数据库:Derby, SQL Server 2012, Oracle 12c, the SQL 2008 standard 74 | func oracleLimitSQL(limit any, offset ...any) (string, []any) { 75 | query := "FETCH NEXT " 76 | 77 | if named, ok := limit.(sql.NamedArg); ok && named.Name != "" { 78 | query += "@" + named.Name 79 | } else { 80 | query += "?" 81 | } 82 | query += " ROWS ONLY " 83 | 84 | if len(offset) == 0 { 85 | return query, []any{limit} 86 | } 87 | 88 | o := offset[0] 89 | if named, ok := o.(sql.NamedArg); ok && named.Name != "" { 90 | query = "OFFSET @" + named.Name + " ROWS " + query 91 | } else { 92 | query = "OFFSET ? ROWS " + query 93 | } 94 | 95 | return query, []any{offset[0], limit} 96 | } 97 | 98 | // PrepareNamedArgs 对命名参数进行预处理 99 | // 100 | // 命名参数替换成 ?,并返回参数名称对应在语句的位置。 101 | // query 中不能同时包含命名参数和 ?,否则将 panic。 102 | func PrepareNamedArgs(query string) (string, map[string]int, error) { 103 | orders := map[string]int{} 104 | builder := core.NewBuilder("") 105 | start := -1 106 | cnt := 0 107 | 108 | write := func(name string) { 109 | if _, found := orders[name]; found { 110 | panic("存在相同的参数名:" + name) 111 | } 112 | 113 | builder.WString(" ? ") 114 | orders[name] = cnt 115 | } 116 | 117 | for index, c := range query { 118 | switch { 119 | case c == '@': 120 | start = index + 1 121 | case start != -1 && !(unicode.IsLetter(c) || unicode.IsDigit(c)): 122 | write(query[start:index]) 123 | builder.WRunes(c) // 当前的字符不能丢 124 | cnt++ 125 | start = -1 126 | case start == -1: 127 | builder.WRunes(c) 128 | if c == '?' && cnt > 0 { 129 | panic("不能同时存在 ? 和命名参数") 130 | } 131 | } 132 | } 133 | 134 | if start > -1 { 135 | write(query[start:]) 136 | } 137 | 138 | q, err := builder.String() 139 | if err != nil { 140 | return "", nil, err 141 | } 142 | return q, orders, nil 143 | } 144 | 145 | func stdDropIndex(index string) (string, error) { 146 | if index == "" { 147 | return "", sqlbuilder.SyntaxError("DROP INDEX", "未指定列") 148 | } 149 | 150 | return core.NewBuilder("DROP INDEX ").QuoteKey(index).String() 151 | } 152 | 153 | func appendViewBody(builder *core.Builder, name, selectQuery string, cols []string) (string, error) { 154 | builder.WString(" VIEW ").QuoteKey(name) 155 | 156 | if len(cols) > 0 { 157 | builder.WBytes('(') 158 | for _, col := range cols { 159 | builder.QuoteKey(col).WBytes(',') 160 | } 161 | builder.TruncateLast(1).WBytes(')') 162 | } 163 | 164 | return builder.WString(" AS ").WString(selectQuery).String() 165 | } 166 | 167 | // 修正查询语句和查询参数的位置 168 | func fixQueryAndArgs(query string, args []any) (string, []any, error) { 169 | query, orders, err := PrepareNamedArgs(query) 170 | if err != nil { 171 | return "", nil, err 172 | } 173 | 174 | // 整理返回参数 175 | named := make(map[int]any, len(orders)) 176 | for _, arg := range args { 177 | if n, ok := arg.(sql.NamedArg); ok { 178 | i, found := orders[n.Name] 179 | if !found { 180 | panic(fmt.Sprintf("不存在指定名称的参数 %s", n.Name)) 181 | } 182 | delete(orders, n.Name) 183 | named[i] = n.Value 184 | continue 185 | } 186 | } 187 | 188 | if len(orders) > 0 { 189 | panic("占位符与命名参数的数量不相同") 190 | } 191 | 192 | for index, val := range named { 193 | args[index] = val 194 | } 195 | 196 | return query, args, nil 197 | } 198 | -------------------------------------------------------------------------------- /dialect/dialect_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package dialect 6 | 7 | import ( 8 | "database/sql" 9 | "testing" 10 | 11 | "github.com/issue9/assert/v4" 12 | 13 | "github.com/issue9/orm/v6/internal/sqltest" 14 | ) 15 | 16 | func TestMysqlLimitSQL(t *testing.T) { 17 | a := assert.New(t, false) 18 | 19 | query, ret := mysqlLimitSQL(5, 0) 20 | a.Equal(ret, []int{5, 0}) 21 | sqltest.Equal(a, query, " LIMIT ? OFFSET ? ") 22 | 23 | query, ret = mysqlLimitSQL(5) 24 | a.Equal(ret, []int{5}) 25 | sqltest.Equal(a, query, "LIMIT ?") 26 | 27 | // 带 sql.namedArg 28 | query, ret = mysqlLimitSQL(sql.Named("limit", 1), 2) 29 | a.Equal(ret, []any{sql.Named("limit", 1), 2}) 30 | sqltest.Equal(a, query, "LIMIT @limit offset ?") 31 | } 32 | 33 | func TestOracleLimitSQL(t *testing.T) { 34 | a := assert.New(t, false) 35 | 36 | query, ret := oracleLimitSQL(5, 0) 37 | a.Equal(ret, []int{0, 5}) 38 | sqltest.Equal(a, query, " OFFSET ? ROWS FETCH NEXT ? ROWS ONLY ") 39 | 40 | query, ret = oracleLimitSQL(5) 41 | a.Equal(ret, []int{5}) 42 | sqltest.Equal(a, query, "FETCH NEXT ? ROWS ONLY ") 43 | 44 | // 带 sql.namedArg 45 | query, ret = oracleLimitSQL(sql.Named("limit", 1), 2) 46 | a.Equal(ret, []any{2, sql.Named("limit", 1)}) 47 | sqltest.Equal(a, query, "offset ? rows fetch next @limit rows only") 48 | } 49 | 50 | func TestPrepareNamedArgs(t *testing.T) { 51 | a := assert.New(t, false) 52 | 53 | var data = []*struct { 54 | input string 55 | query string 56 | orders map[string]int 57 | err bool 58 | }{ 59 | { 60 | input: "select * from table", 61 | query: "select * from table", 62 | orders: map[string]int{}, 63 | }, 64 | { 65 | input: "select * from table where id=@id", 66 | query: "select * from table where id=?", 67 | orders: map[string]int{"id": 0}, 68 | }, 69 | { 70 | input: "select * from table where id=@id and name like @name", 71 | query: "select * from table where id=? and name like ?", 72 | orders: map[string]int{"id": 0, "name": 1}, 73 | }, 74 | { 75 | input: "select * from table where {id}=@id and {name} like @name", 76 | query: "select * from table where {id}=? and {name} like ?", 77 | orders: map[string]int{"id": 0, "name": 1}, 78 | }, 79 | { 80 | input: "select * from table where {编号}=@编号 and {name} like @name", 81 | query: "select * from table where {编号}=? and {name} like ?", 82 | orders: map[string]int{"编号": 0, "name": 1}, 83 | }, 84 | { 85 | input: "INSERT INTO users({id},{name}) VALUES (@id,@name)", 86 | query: "INSERT INTO users({id},{name}) VALUES (?,?)", 87 | orders: map[string]int{"id": 0, "name": 1}, 88 | }, 89 | { // 没有命名参数 90 | input: "INSERT INTO users({id},{name}) VALUES (?,?)", 91 | query: "INSERT INTO users({id},{name}) VALUES (?,?)", 92 | orders: map[string]int{}, 93 | }, 94 | { // 参数名称是另一个参数名称的一部分 95 | input: "select * from table where id=@id and id=1 and id=@id2", 96 | query: "select * from table where id=? and id=1 and id=?", 97 | orders: map[string]int{"id": 0, "id2": 1}, 98 | }, 99 | } 100 | 101 | for _, item := range data { 102 | q, o, err := PrepareNamedArgs(item.input) 103 | 104 | if item.err { 105 | a.Error(err).Nil(o).Empty(q) 106 | continue 107 | } 108 | 109 | a.NotError(err). 110 | Equal(o, item.orders) 111 | sqltest.Equal(a, q, item.query) 112 | } 113 | 114 | a.PanicString(func() { 115 | PrepareNamedArgs("INSERT INTO users({id},{name}) VALUES (@id,@id)") 116 | }, "存在相同的参数名") 117 | 118 | a.PanicString(func() { 119 | PrepareNamedArgs("INSERT INTO users({id},{name}) VALUES (@id,?)") 120 | }, "不能同时存在 ? 和命名参数") 121 | } 122 | 123 | func TestFixQueryAndArgs(t *testing.T) { 124 | a := assert.New(t, false) 125 | 126 | data := []*struct { 127 | query string 128 | args []any 129 | err bool 130 | outputQuery string 131 | outputArgs []any 132 | }{ 133 | { 134 | query: "select * from table where id=1 and id=@id", 135 | args: []any{sql.Named("id", 2)}, 136 | outputQuery: "select * from table where id=1 and id=?", 137 | outputArgs: []any{2}, 138 | }, 139 | { 140 | query: "select * from table where id=@id and id=1 and id=@id2", 141 | args: []any{sql.Named("id2", 1), sql.Named("id", 2)}, 142 | outputQuery: "select * from table where id=? and id=1 and id=?", 143 | outputArgs: []any{2, 1}, 144 | }, 145 | } 146 | 147 | for _, item := range data { 148 | query, args, err := fixQueryAndArgs(item.query, item.args) 149 | a.NotError(err). 150 | Equal(args, item.outputArgs) 151 | sqltest.Equal(a, query, item.outputQuery) 152 | } 153 | 154 | a.Panic(func() { 155 | fixQueryAndArgs("select * from table where id=@id and id=@id2", []any{sql.Named("id2", 1), sql.Named("id3", 2)}) 156 | }) 157 | 158 | a.Panic(func() { 159 | fixQueryAndArgs("select * from table where id=@id and id=@id2", []any{sql.Named("id2", 1), sql.Named("not-exists", 2)}) 160 | }) 161 | 162 | a.Panic(func() { 163 | fixQueryAndArgs("select * from table where id=@id and id=@id", []any{sql.Named("id", 1), sql.Named("id2", 2)}) 164 | }) 165 | 166 | a.Panic(func() { 167 | fixQueryAndArgs("select * from table where id=@id and id=?", []any{sql.Named("id", 1)}) 168 | }) 169 | } 170 | -------------------------------------------------------------------------------- /dialect/testdata/.gitignore: -------------------------------------------------------------------------------- 1 | *.db 2 | *.sql 3 | -------------------------------------------------------------------------------- /dialect/testdata/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/issue9/orm/96127c63a57daa06d2a001c1970ec362af781cae/dialect/testdata/.gitkeep -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | // Package orm 一个简单小巧的 orm 实现方案 6 | // 7 | // 目前内置了对以下数据库的支持: 8 | // - sqlite3: github.com/mattn/go-sqlite3 9 | // - mysql/mariadb: github.com/go-sql-driver/mysql 10 | // - postgres: github.com/lib/pq 11 | // - sqlite modernc.org/sqlite 纯 Go 的 sqlite3,无须 CGO。 12 | // 13 | // 其它数据库,用户可以通过实现 [Dialect] 接口,来实现相应的支持。 14 | // 15 | // 初始化: 16 | // 17 | // 默认情况下,orm 包并不会加载任何数据库的实例。所以想要用哪个数据库,需要手动初始化: 18 | // 19 | // import ( 20 | // _ github.com/mattn/go-sqlite3 // 加载数据库驱动 21 | // _ github.com/issue9/orm/v6/dialect // sqlite3 的 dialect 声明在此处 22 | // ) 23 | // 24 | // // 初始化一个 DB 25 | // db1 := orm.NewDB("./db1", dialect.Sqlite3("sqlite3")) 26 | // 27 | // // 另一个 DB 实例 28 | // db2 := orm.NewDB("./db2", dialect.Sqlite3("sqlite3")) 29 | // 30 | // 占位符 31 | // 32 | // SQL 中可以使用以下占位符: 33 | // - {} 包含一个关键字,使其它成为普通列名; 34 | // - # 表名前缀,为采用 core.Engine.TablePrefix 进行替换; 35 | // 36 | // 如: 37 | // 38 | // select * from user where {#group}=1 39 | // 40 | // 在实际执行时,相关的占位符就会被替换成与当前环境想容的实例, 41 | // 如在数据库为 mysql 时,会被替换成以下语句,然后再执行: 42 | // 43 | // select * from p_user where `group`=1 44 | // 45 | // [DB.Query]、[DB.Exec]、[DB.Prepare]、[DB.Where] 及 [Tx] 与之对应的函数都可以使用占位符。 46 | // 47 | // Model 不能指定占位符,它们默认总会使用占位符,且无法取消。 48 | // 49 | // Model: 50 | // 51 | // orm 包通过 struct tag 来描述 model 在数据库中的结构。大概格式如下: 52 | // 53 | // type User struct { 54 | // Id int64 `orm:"name(id);ai;"` 55 | // FirstName string `orm:"name(first_name);index(index_name)"` 56 | // LastName string `orm:"name(first_name);index(index_name)"` 57 | // 58 | // // 此处group会自动加上引号,无须担心是否为关键字 59 | // Group string `orm:"name(group)"` 60 | // } 61 | // 62 | // // 通过 [ApplyModeler] 接口,指定表的额外数据。若不需要,可不用实现该接口 63 | // func(u *User) ApplyModel(m *core.Model) error { 64 | // m.Name = "user" 65 | // m.Options["engine"] = "innodb" 66 | // m.Options["charset"] = "utf8" 67 | // return nil 68 | // } 69 | // 70 | // 目前支持以下的 struct tag: 71 | // 72 | // name(fieldName): 指定当前字段在数据表中的名称,如果未指定, 73 | // 则和字段名相同。只有可导出的字段才有效果。 74 | // 75 | // len(l1, l2): 指定字段的长度。比如 mysql 中的int(5),varchar(255),double(1,2), 76 | // 不支持该特性的数据,将会忽略该标签的内容,比如 sqlite3。 77 | // NOTE:字符串类型必须指定长度,若长度过大或是将长度设置了 -1, 78 | // 想使用类似于 TEXT 等不定长的形式表达。 79 | // 如果是日期类型,则第一个可选参数表示日期精度。 80 | // 81 | // nullable(true|false): 相当于定义表结构时的 NULL。 82 | // 83 | // pk: 主键,支持联合主键,给多个字段加上 pk 的 struct tag 即可。 84 | // 85 | // ai: 自增,若指定了自增列,则将自动取消其它的 pk 设置。无法指定起始值和步长。 86 | // 可手动设置一个非零值来更改某条数据的 AI 行为。 87 | // 88 | // unique(index_name): 唯一索引,支持联合索引,index_name 为约束名, 89 | // 会将 index_name 为一样的字段定义为一个联合索引。 90 | // 91 | // index(index_name): 普通的关键字索引,同 unique 一样会将名称相同的索引定义为一个联合索引。 92 | // 93 | // occ(true|false) 当前列作为乐观锁字段。 94 | // 95 | // 作为乐观锁的字段,其值表示的是线上数据的值,在更新时,会自动给线上的值加 1。 96 | // 97 | // default(value): 指定默认值。相当于定义表结构时的 DEFAULT。 98 | // 当一个字段如果是个零值(reflect.Zero())时,将会使用它的默认值, 99 | // 但是系统无法判断该零值是人为指定,还是未指定被默认初始化零值的, 100 | // 所以在需要用到零值的字段,最好不要用 default 的 struct tag。 101 | // 102 | // fk(fk_name,refTable,refColName,updateRule,deleteRule): 103 | // 定义物理外键,最少需要指定 fk_name,refTable,refColName 三个值。 104 | // 分别对应约束名,引用的表和引用的字段,updateRule,deleteRule, 105 | // 在不指定的情况下,使用数据库的默认值。 106 | // 107 | // ApplyModeler: 108 | // 109 | // 用于将一个对象转换成 Model 对象时执行的函数,给予用户修改 Model 的机会, 110 | // 在 [ApplyModeler] 中可以修改任意模型的内容,所以也可以由 ApplyModeler 代替 struct tag 的操作。 111 | // 112 | // 约束名: 113 | // 114 | // index、unique、check 和 fk 都是可以指定约束名的,在表中,约束名必须是唯一的, 115 | // 即便是不同类型的约束,比如已经有一个 unique 的约束名叫作 name,那么其它类 116 | // 型的约束,就不能再取这个名称了。 117 | // 部分数据库(比如 postgres)可能要求约束名是整个数据库唯一的, 118 | // 为了统一,在 ORM 中所有数据都强制约束名全局(数据库)唯一。 119 | // 120 | // 如何使用: 121 | // 122 | // Create: 123 | // 可以通过 DB.Create() 或是 Tx.Create() 创建一张表。 124 | // 125 | // // 创建表 126 | // db.Create(&User{}) 127 | // 128 | // Update: 129 | // 130 | // // 将 id 为 1 的记录的 FirstName 更改为 abc;对象中的零值不会被提交。 131 | // db.Update(&User{Id:1,FirstName:"abc"}) 132 | // sqlbuilder.Update(db).Table("table").Where("id=?",1).Set("FirstName", "abc").Exec() 133 | // 134 | // Delete: 135 | // 136 | // // 删除 id 为 1 的记录 137 | // e.Delete(&User{Id:1}) 138 | // sqlbuilder.Delete(e).Table("table").Where("id=?",1).Exec() 139 | // 140 | // Insert: 141 | // 142 | // // 插入一条数据 143 | // db.Insert(&User{Id:1,FirstName:"abc"}) 144 | // // 一次性插入多条数据 145 | // tx.InsertMany(&User{Id:1,FirstName:"abc"},&User{Id:1,FirstName:"abc"}) 146 | // 147 | // Select: 148 | // 149 | // // 导出 id=1 的数据 150 | // _,err := sqlbuilder.Select(e, e.Dialect()).Select("*").From("{table}").Where("id=1").QueryObj(obj) 151 | // // 导出 id 为 1 的数据,并回填到 user 实例中 152 | // user := &User{Id:1} 153 | // err := e.Select(u) 154 | // 155 | // Query/Exec: 156 | // 157 | // // Query 返回参数与 sql.Query 是相同的 158 | // sql := "select * from tbl_name where id=?" 159 | // rows, err := e.Query(sql, []interface{}{5}) 160 | // // Exec 返回参数与 sql.Exec 是相同的 161 | // sql = "update tbl_name set name=? where id=?" 162 | // r, err := e.Exec(sql, []interface{}{"name1", 5}) 163 | // 164 | // 事务: 165 | // 166 | // 默认的 [DB] 是不支持事务的,若需要事务支持,则需要调用 [DB.Begin] 167 | // 返回事务对象 [Tx],当然并不是所有的数据库都支持事务操作的。 168 | // [Tx] 拥有一组与 [DB] 相同的接口。 169 | package orm 170 | -------------------------------------------------------------------------------- /docs/advance.md: -------------------------------------------------------------------------------- 1 | ### 时区 2 | 3 | sqlite3 可以通过 _loc 的方式指定时区; 4 | mysql 可以通过 loc 参数指定时区; 5 | postgres 无法指定时间,都直接当作时区 0 进行了处理; 6 | 7 | 如果你的代码后期需要在不同的数据库之间迁移,那么建议将时区统一设置为 UTC。 8 | 9 | ### 时间精度 10 | 11 | 各个数据库驱动对精度处理方式并不相同,mysql 将未设置精度等同于精度为 0,而 12 | postgres 和 sqlite3 则会将其赞同于最大精度 6。 13 | 14 | ### 自定义类型 15 | 16 | ORM 支持对自定义类型的存储和读取,需要实现以下几个接口: 17 | 18 | - sql.Scan/driver.Valuer 这两个接口为标准库本身要求必须实现的; 19 | - core.PrimitiveTyper 指定了底层的 Go 类型,该值会在创建表时用于判断应该创建的数据库类型; 20 | - core.TableNamer 指定表名; 21 | 22 | 可以参考 types 下的各个自定义类型的实现。 23 | -------------------------------------------------------------------------------- /docs/curd.md: -------------------------------------------------------------------------------- 1 | DB 和 Tx 对象都提供了一套基于数据模型的基本操作, 2 | 功能上比较单一,如果需要复杂的 SQL 操作,则需要使用 sqlbuilder 3 | 下的内容。 4 | 5 | 像 Update、Delete 和 Select 等操作,需要指定查询条件的, 6 | 在 DB 和 Tx 中会从当前提交的对象中查找可用的查询条件。 7 | 可用的查询条件是指 AI、PK 和唯一约束中,所有值都不为零值的那一个约束。 8 | 所以 Update、Delete 和 Select 操作都是单一对象。 9 | 10 | 以下展示了 CRUD 的一些基本操作。假设操作对象为以下结构: 11 | 12 | ```go 13 | type User struct { 14 | ID int64 `orm:"name(id);ai"` 15 | Name string `orm:"name(name);len(20);index(i_user_name)"` 16 | Age int `orm:"name(age)"` 17 | Username string `orm:"name(username);unique(u_unique_username)"` 18 | } 19 | ``` 20 | 21 | ### TransactionalDDL 22 | 23 | `core.Dialect.TransactionalDDL()` 指定了当前数据是否支持在事务中执行 DDL 语句。 24 | 25 | 像 `db.Create()` 可能存在执行多条语句,比如: 26 | 27 | ```sql 28 | CREATE TABLE users ( 29 | id INT NOT NULL, 30 | name VARCHAR(20) NOT NULL, 31 | ); 32 | CREATE INDEX i_user_index ON users (name); 33 | ``` 34 | 35 | 两条 create 才组成一个完整的创建表的操作。 36 | 37 | 如果不支持 TransactionalDDL 的,那么这些语句会分开执行,中断出错了,也没法回滚; 38 | 而支持 TransactionalDDL 的,这些步骤只出错,都会被撤消。 39 | 40 | 所在以 TransactionalDDL 值不同的数据库中,执行某些操作,其行为可能会有稍微的差别。 41 | 42 | ### create 43 | 44 | 用于创建数据表。 45 | 46 | ```go 47 | err := db.Create(&User{}) 48 | ``` 49 | 50 | 会创建一张 User 表,表名由 User.TableName() 方法指定。 51 | 52 | 创建表属于 DDL,如果数据不支持事务中执行 DDL,那么即使在事务中, 53 | 创建表也依然是逐条提交的。 54 | 55 | ### insert 56 | 57 | ```go 58 | result, err := db.Insert(&User{ 59 | Name: "name", 60 | }) 61 | ``` 62 | 63 | 插入 User{} 对象到数据库,不需要指定自增列的值,会自动生成。 64 | 其 name 字段的值为 name,其它字段都采用默认值。 65 | 66 | #### lastInsertID 67 | 68 | ```go 69 | // id 为当前插入数据的自增 ID 70 | id, err := db.LastInsertID(&User{ 71 | Name: "name", 72 | }) 73 | ``` 74 | 75 | 如果需要获得 Last Insert ID 的值,建议采用 `LastInsertID` 方法获取, 76 | 而不是通过 `sql.Result.LastInsertID`,部分数据库(比如 postgres) 77 | 无法通过 `sql.Result.LastInertID` 获取 ID,但是 `db.LastInsertID` 78 | 会处理这种情况。 79 | 80 | 必须要有自增列,否则会出错! 81 | 82 | ### update 83 | 84 | ```go 85 | result, err := db.Update(&User{ 86 | ID: 1, 87 | Name: "test", 88 | Age: 0, // 零值,不会更新到数据库 89 | }) 90 | 91 | result, err := db.Update(&User{ 92 | ID: 1, 93 | Name: "test", 94 | Age: 0, 95 | }, "age") // 指定了 age 必须更新,即使是零值 96 | ``` 97 | 98 | update 会根据当前传递对象的非零值字段中查找 AI、PK 和唯一约束, 99 | 只要找到了就符根据这些约束作为查询条件,其它值作为更新内容进行更新。 100 | 101 | 默认情况下,零值不会被更新,这在大对象中,会节省不少操作。 102 | 当然如果需要更新零值到数据库,则需要在 `Update()` 的第二个参中指定列名。 103 | 104 | 如果需要更新 AI、PK 和唯一约束本身的内容,可以通过 sqlbuilder 105 | 进行一些高级的操作。 106 | 107 | ### delete 108 | 109 | delete 和 update 一样,通过唯一查询条件确定需要删除的列,并执行删除操作。 110 | 111 | ```go 112 | // 删除 ID 为 1 的行。 113 | result, err := db.Delete(&User{ID: 1}) 114 | 115 | // 删除 username 值为 example 的行 116 | result, err = db.Delete(&User{Username: "example"}) 117 | 118 | 119 | // 同时指这了 AI 和唯一约束,则优先 AI 作查询。 120 | result, err = db.Delete(&User{ 121 | ID: 1, 122 | Username: "example", 123 | }) 124 | 125 | // 返回错误,查询条件必须要有表达唯一性。 126 | result, err = db.Delete(&User{ 127 | Age: 18, 128 | }) 129 | ``` 130 | 131 | ### truncate 132 | 133 | truncate 会清空表内容,同时将该的自增计数重置为从 1 开始。 134 | 135 | ```go 136 | err :=db.Truncate(&User{}) 137 | ``` 138 | 139 | ### select 140 | 141 | ```go 142 | // 查找 ID 为 1 的 User 数据。会将 u 的其它字段填上。 143 | u := &User{ID: 1} 144 | err := db.Select(u) 145 | ``` 146 | 147 | ### count 148 | 149 | count 用于统计符合指定条件的所有数据。所有非零值都参与计算, 150 | 以 `AND` 作为各个查询条件的连接。 151 | 152 | ```go 153 | // 相当于 SELECT count(*) FROM users WHERE name='name' AND age=18 154 | count, err := db.Count(&User{ 155 | Name: "name", 156 | Age: 18, 157 | }) 158 | ``` 159 | -------------------------------------------------------------------------------- /docs/dialect.md: -------------------------------------------------------------------------------- 1 | 2 | 目前 orm 包本身定义了 Postgres、Sqlite3 和 Mysql 三个类型数据库的支持。 3 | 如果用户需要其它类型的数据库操作,可以自己实现 `core.Dialect` 接口。 4 | 5 | Dialect 需要实现两个部分的内容:其中 core.Dialect 是必须要实现的接口, 6 | 另外,在 sqlbuilder 包中,还提供了一部分 **Hooker 的接定义, 7 | 如果你当前的数据库实现与 sqlbuilder 中的默认实现不一样,可需要自行实现该接口。 8 | 9 | 比如 `sqlbuilder.InsertDefaultValueHooker` 接口,默认实现,采用了比较常用的方法: 10 | 11 | ```sql 12 | INSERT INTO table DEFAULT VALUES; 13 | ``` 14 | 15 | 但是 mysql 没有对应的实现,需要自定义该口,而 postgres 和 sqlite3 不需要。 16 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | [![Build Status](https://travis-ci.org/issue9/orm.svg?branch=master)](https://travis-ci.org/issue9/orm) 2 | [![Go version](https://img.shields.io/badge/Go-1.10-brightgreen.svg?style=flat)](https://golang.org) 3 | [![Go Report Card](https://goreportcard.com/badge/github.com/issue9/orm)](https://goreportcard.com/report/github.com/issue9/orm) 4 | [![codecov](https://codecov.io/gh/issue9/orm/branch/master/graph/badge.svg)](https://codecov.io/gh/issue9/orm) 5 | [![GoDoc](https://godoc.org/github.com/issue9/orm?status.svg)](https://godoc.org/github.com/issue9/orm) 6 | [![license](https://img.shields.io/badge/license-MIT-brightgreen.svg?style=flat)](https://opensource.org/licenses/MIT) 7 | 8 | 9 | ## [快速开始](quick-start.md) 10 | 11 | ## [模型](model.md) 12 | 13 | ## [CURD](curd.md) 14 | 15 | ## [SQL Builder](sqlbuilder.md) 16 | 17 | ## [升级](upgrade.md) 18 | 19 | ## [开发](dialect.md) 20 | 21 | ## [高级应用](advance.md) 22 | 23 | -------------------------------------------------------------------------------- /docs/model.md: -------------------------------------------------------------------------------- 1 | 每一个数据模型都可以定义为 Go 结构体。当通过 DB 实例第一次接触到该对象时 2 | (比如 `Insert`、`Create` 等),会生成模型数据。 3 | 4 | ```go 5 | type User struct { 6 | ID int64 `orm:"name(id);ai"` 7 | Name string `orm:"name(name);len(20);index(user_index_name)"` 8 | Username string `orm:"name(username);len(20);unique(user_unique_username)"` 9 | Nickname sql.NullString `orm:"name(nickname);len(20);nullable"` 10 | Last *Last `orm:"name(last);len(-1);default(192.168.1.1,2019-07-29T17:11:01)"` 11 | } 12 | 13 | func(u *User) TableName() string { return "users" } 14 | 15 | func(u *User) ApplyModel(m*core.Model) error { 16 | m.Options["mysql_charset"] = []string{"utf8"} 17 | return m.NewCheck("id_great_zero", "id>0") 18 | } 19 | ``` 20 | 21 | 结构体中的字段与数据表中列的关联通过名为 orm 的 struct tag 进行设置。 22 | struct tag 中的格式为 `key(val);key(v1,v2)`,其中 key 属性名,val 等为该属性对应的值列表。 23 | 24 | 目前支持在 struct tag 支持以下属性: 25 | 26 | ### 属性 27 | 28 | #### name(fieldName) 29 | 30 | 指定当前字段在数据表中的名称,如果未指定,则和字段名相同。 31 | 只有可导出的字段才有效果。 32 | 33 | #### len(l1, l2) 34 | 35 | 指定字段的长度。比如 mysql 中的int(5),varchar(255),double(1,2), 36 | 不支持该特性的数据,将会忽略该标签的内容,比如 sqlite3。 37 | 38 | NOTE:字符串类型必须指定长度,若长度过大或是将长度设置了 -1, 39 | 会使用类似于 TEXT 等不定长的形式表达。 40 | 41 | 如果是日期类型,则第一个可选参数表示日期精度。 42 | 43 | #### nullable(true|false) 44 | 45 | 相当于定义表结构时的 NULL,建议尽量少用该属性。 46 | 47 | #### pk 48 | 49 | 主键,支持联合主键,给多个字段加上 pk 的 struct tag 即可。 50 | 51 | 主键约束不能自定义约束名。 52 | 53 | #### ai 54 | 55 | 自增,若指定了自增列,则将自动取消其它的 pk 设置。无法指定起始值和步长。 56 | 可手动设置一个非零值来更改某条数据的 AI 行为。 57 | 58 | #### unique(index_name) 59 | 60 | 唯一约束,支持联合索引,index_name 为约束名,会将 index_name 61 | 一样的字段定义为一个联合唯一约束。 62 | 63 | #### index(index_name) 64 | 65 | 普通的关键字索引,同 unique 一样会将名称相同的索引定义为一个联合索引。 66 | 67 | #### occ(true|false) 68 | 69 | 当前列作为乐观锁字段。 70 | 71 | 作为乐观锁的字段,其值表示的是线上数据的值,在更新时,会自动给线上的值加 1。 72 | 73 | #### default(value) 74 | 75 | 指定默认值。相当于定义表结构时的 DEFAULT。 76 | 77 | 内置类型的格式,Bool 为 true 和 false,time 为 time.RFC3339 78 | 79 | 自定义类型会尝试采用 sql.Scanner 作为解析方式。 80 | 81 | #### fk(fk_name,refTable,refColName,updateRule,deleteRule) 82 | 83 | 定义物理外键,最少需要指定 fk_name、refTable 和 refColName 三个值。分别对应约束名, 84 | 引用的表和引用的字段,updateRule,deleteRule,在不指定的情况下,使用数据库的默认值。 85 | 86 | ### 接口 87 | 88 | #### TableNamer 89 | 90 | 指定表名,视图和数据表都需要实现此接口。 91 | 92 | #### ApplyModeler 93 | 94 | 通过 ApplyModeler 接口可以指定一些表级别的属性值。 95 | 96 | #### Viewer 97 | 98 | 如果需要将模型定义为视图,则需要实现此接口, 99 | Viewer 接口返回一条 `SELECT` 语句,用于指定创建视图时的 `SELECT` 部分语句。 100 | 实现都需要保证接口中返回的列与模型中列的定义要对应。 101 | 102 | 在视图模式下,部分功能会不可用,比如 check 约束、索引等。 103 | 但是 AI、PK 和唯一索引,仍然在查询时,被用来当作唯一查询条件。 104 | 105 | #### BeforeUpdater/BeforeInserter/AfterFetcher 106 | 107 | 分别用于在更新和插入数据之前和从数据库获取数据之后被执行的方法。 108 | 一般用于特定内容的生成,比如: 109 | 110 | ```go 111 | type User struct { 112 | Created time.Time `orm:"name(created)"` // 创建时间 113 | Modified time.Time `orm:"name(modified)"` // 修改时间 114 | Avatar string `orm:"name(avatar);len(1024)"` 115 | } 116 | 117 | // 每次插入数据,都将 created 和 modified 设置为当前时间 118 | func(u *User) BeforeCreate() error { 119 | u.Created = time.Now() 120 | u.Modified = u.Created 121 | return nil 122 | } 123 | 124 | // 每次更新前,都修改 modified 的值为当前时间 125 | func(u *User) BeforeUpdate() error { 126 | u.Modified = time.Now() 127 | return nil 128 | } 129 | 130 | // 如果不存在头像信息,则给定一个默认图片地址 131 | func(u *User) AfterFetch() error { 132 | if u.Avatar == "" { 133 | u.Avatar = "/assets/default-avatar.png" 134 | } 135 | 136 | return nil 137 | } 138 | 139 | ``` 140 | -------------------------------------------------------------------------------- /docs/quick-start.md: -------------------------------------------------------------------------------- 1 | ### 快速开始 2 | 3 | ```go 4 | import ( 5 | "github.com/issue9/orm/v6" 6 | "github.com/issue9/orm/v6/dialect" 7 | 8 | _ "github.com/mattn/go-sqlite3" 9 | ) 10 | 11 | type User struct { 12 | ID int64 `orm:"name(id);ai" json:"id"` 13 | Name string `orm:"name(name);len(500)" json:"name"` 14 | Age int `orm:"name(age)" json:"age"` 15 | } 16 | 17 | func (u *User) TableName() string { return "users" } 18 | 19 | // 指定了表名,以及其它一些表属性 20 | func (u *User) ApplyModel(m*core.Model) error { 21 | m.Options["mysql_charset"] = []string{"utf8"} 22 | return nil 23 | } 24 | 25 | func main() { 26 | db, err := orm.NewDB("./test.db", dialect.Sqlite3("sqlite3", "test_")) 27 | if err !=nil { 28 | panic(err) 29 | } 30 | defer db.Close() 31 | 32 | // 创建表 33 | if err = db.Create(&User{});err != nil { 34 | panic(err) 35 | } 36 | 37 | // 插入一条数据,ID 自增为 1 38 | rslt, err := db.Insert(&User{ 39 | Name: "test", 40 | Age: 18, 41 | }) 42 | 43 | // 读取 ID 值为 1 的数据到 u 中 44 | u := &User{ID: 1} 45 | err = db.Select(u) 46 | 47 | // 更新,根据自增列 ID 查找需要更新列 48 | u = &User{ID: 1, Name: "name", Age: 100} 49 | rslt, err = db.Update(u) 50 | 51 | // 删除,根据自增 ID 查找唯一数据删除 52 | u = &User{ ID: 1} 53 | rslt, err = db.Delete(u) 54 | 55 | // 删除表 56 | db.Drop(&User{}) 57 | } 58 | ``` 59 | 60 | 61 | ### 安装 62 | 63 | 在项目的 go.mod 中引用项目即可,当前版本为 v6: 64 | 65 | ```go.mod 66 | require ( 67 | github.com/issue9/orm/v6 v6.x.x 68 | ) 69 | 70 | go 1.18 71 | ``` 72 | 73 | ### 数据库 74 | 75 | 目前支持以下数据库以及对应的驱动: 76 | 77 | 1. sqlite3: github.com/mattn/go-sqlite3 78 | 1. sqlite: modernc.org/sqlite 纯 go 写的驱动 79 | 1. mysql: github.com/go-sql-driver/mysql 80 | 1. mariadb: github.com/go-sql-driver/mysql 81 | 1. postgres: github.com/lib/pq 82 | 83 | 在初始化时,需要用到什么数据库,只需要引入该驱动即可。 84 | 如果用到了 check 约束,则需要 mysql > 8.0.16、mariadb > 10.2.1。 85 | 86 | ```go 87 | import ( 88 | "github.com/issue9/orm/v6" 89 | "github.com/issue9/orm/v6/dialect" 90 | 91 | _ "github.com/mattn/go-sql-driver/mysql" 92 | _ "github.com/mattn/mattn/go-sqlite3" 93 | ) 94 | ``` 95 | 96 | 之后就可以直接使用 `orm.NewDB` 初始化实例。 97 | 98 | 或者在已经初始化 `sql.DB` 的情况下,直接使用 `sql.DB` 实例初始化: 99 | 100 | ```go 101 | // 初始化 sqlite3 的实例 102 | sqlite, err := orm.NewDB("" ,"./orm.db", dialect.Sqlite3()) 103 | 104 | // 初始化 mysql 的实例 105 | my, err := orm.NewDB("", "root@/orm", dialect.Mysql()) 106 | ``` 107 | 108 | 后续代码中可以同时使用 my 和 sqlite 两个实例操纵不同的数据库数据。 109 | 110 | ### 调试 111 | 112 | 系统提供了 `DB.Debug` 用于调试,调用时,如果指定了非空参数, 113 | 那么在下次 SQL 调用时,会将 SQL 输出到 l 指定的日志上;如果指定了空值, 114 | 则表示关闭调试输出。 115 | 116 | ```go 117 | l := log.New(os.Stdout, "[SQL]", 0) 118 | db.Debug(l) 119 | db.Query("select * from users") // 输出到日志 120 | db.Debug(nil) 121 | db.Query("select * from groups") // 不会输出到日志 122 | ``` 123 | -------------------------------------------------------------------------------- /docs/upgrade.md: -------------------------------------------------------------------------------- 1 | 2 | upgrade 可以当作是数据库升级的小助手,在大部分情况下,都能胜任。 3 | 如果数据库本身是支持事务中执行 DDL 的,那么在执行失败时,会回滚。 4 | 5 | ```go 6 | type User { 7 | ID int64 `orm:"name(id)"` 8 | Name string `orm:"name(name);len(20)"` // 新加的列 9 | } 10 | err := db.Upgrade(&User{}). 11 | AddColumn("name"). // 将该列添加到数据库 12 | DropColumn("username"). // 删除数据库中的 username 列 13 | Do() // 执行以上操作 14 | ``` 15 | -------------------------------------------------------------------------------- /fetch/bench_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package fetch_test 6 | 7 | import ( 8 | "testing" 9 | 10 | "github.com/issue9/assert/v4" 11 | 12 | "github.com/issue9/orm/v6/fetch" 13 | "github.com/issue9/orm/v6/internal/test" 14 | ) 15 | 16 | func BenchmarkObject(b *testing.B) { 17 | a := assert.New(b, false) 18 | suite := test.NewSuite(a, "") 19 | 20 | suite.Run(func(t *test.Driver) { 21 | initDB(t) 22 | defer clearDB(t) 23 | 24 | sql := `SELECT id,Email FROM fetch_users WHERE id<2 ORDER BY id` 25 | objs := []*FetchUser{ 26 | {}, 27 | {}, 28 | } 29 | 30 | for i := 0; i < b.N; i++ { 31 | rows, err := t.DB.Query(sql) 32 | t.NotError(err) 33 | 34 | cnt, err := fetch.Object(true, rows, &objs) 35 | t.NotError(err).NotEmpty(cnt) 36 | t.NotError(rows.Close()) 37 | } 38 | }) 39 | } 40 | 41 | func BenchmarkMap(b *testing.B) { 42 | a := assert.New(b, false) 43 | suite := test.NewSuite(a, "") 44 | 45 | suite.Run(func(t *test.Driver) { 46 | initDB(t) 47 | defer clearDB(t) 48 | 49 | // 正常匹配数据,读取多行 50 | sql := `SELECT id,Email FROM fetch_users WHERE id<2 ORDER BY id` 51 | 52 | for i := 0; i < b.N; i++ { 53 | rows, err := t.DB.Query(sql) 54 | t.NotError(err) 55 | 56 | mapped, err := fetch.Map(false, rows) 57 | t.NotError(err).NotNil(mapped) 58 | t.NotError(rows.Close()) 59 | } 60 | }) 61 | } 62 | -------------------------------------------------------------------------------- /fetch/column.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package fetch 6 | 7 | import ( 8 | "database/sql" 9 | 10 | "github.com/issue9/orm/v6/core" 11 | ) 12 | 13 | // Column 导出 rows 中某列的所有或一行数据 14 | // 15 | // once 若为 true,则只导出第一条数据。 16 | // colName 指定需要导出的列名,若指定了不存在的名称,返回 error。 17 | // 18 | // NOTE: 要求 T 的类型必须符合 [sql.Row.Scan] 的参数要求; 19 | func Column[T any](once bool, colName string, rows *sql.Rows) ([]T, error) { 20 | // TODO: 应该约束 T 为 sql.Rows.Scan 允许的类型,但是以目前 Go 的语法无法做到。 21 | 22 | cols, err := rows.Columns() 23 | if err != nil { 24 | return nil, err 25 | } 26 | 27 | index := -1 // colName 列在 rows.Columns() 中的索引号 28 | buff := make([]any, len(cols)) 29 | for i, v := range cols { 30 | if colName == v { // 获取 index 的值 31 | index = i 32 | var zero T 33 | buff[i] = &zero 34 | } else { 35 | var value any 36 | buff[i] = &value 37 | } 38 | } 39 | 40 | if index == -1 { 41 | return nil, core.ErrColumnNotFound(colName) 42 | } 43 | 44 | var data []T 45 | for rows.Next() { 46 | if err := rows.Scan(buff...); err != nil { 47 | return nil, err 48 | } 49 | data = append(data, *buff[index].(*T)) 50 | if once { 51 | return data, nil 52 | } 53 | } 54 | 55 | return data, nil 56 | } 57 | -------------------------------------------------------------------------------- /fetch/column_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package fetch_test 6 | 7 | import ( 8 | "testing" 9 | 10 | "github.com/issue9/assert/v4" 11 | 12 | "github.com/issue9/orm/v6/fetch" 13 | "github.com/issue9/orm/v6/internal/test" 14 | ) 15 | 16 | func TestColumn(t *testing.T) { 17 | a := assert.New(t, false) 18 | suite := test.NewSuite(a, "") 19 | 20 | suite.Run(func(t *test.Driver) { 21 | initDB(t) 22 | defer clearDB(t) 23 | db := t.DB 24 | 25 | // 正常数据匹配,读取多行 26 | sql := `SELECT id,email FROM fetch_users WHERE id<3 ORDER BY id ASC` 27 | rows, err := db.Query(sql) 28 | t.NotError(err).NotNil(rows) 29 | 30 | cols, err := fetch.Column[int64](false, "id", rows) 31 | t.NotError(err).NotNil(cols) 32 | 33 | t.Equal(cols, []int64{int64(1), int64(2)}) 34 | t.NotError(rows.Close()) 35 | 36 | // 正常数据匹配,读取一行 37 | rows, err = db.Query(sql) 38 | t.NotError(err).NotNil(rows) 39 | 40 | cols, err = fetch.Column[int64](true, "id", rows) 41 | t.NotError(err).NotNil(cols) 42 | 43 | t.Equal(cols, []int64{int64(1)}) 44 | t.NotError(rows.Close()) 45 | 46 | // 没有数据匹配,读取多行 47 | sql = `SELECT id,email FROM fetch_users WHERE id<0 ORDER BY id ASC` 48 | rows, err = db.Query(sql) 49 | t.NotError(err).NotNil(rows) 50 | 51 | cols, err = fetch.Column[int64](false, "id", rows) 52 | t.NotError(err) 53 | 54 | t.Empty(cols) 55 | t.NotError(rows.Close()) 56 | 57 | // 没有数据匹配,读取一行 58 | rows, err = db.Query(sql) 59 | t.NotError(err).NotNil(rows) 60 | 61 | cols, err = fetch.Column[int64](true, "id", rows) 62 | t.NotError(err) 63 | 64 | t.Empty(cols) 65 | t.NotError(rows.Close()) 66 | 67 | // 指定错误的列名 68 | rows, err = db.Query(sql) 69 | t.NotError(err).NotNil(rows) 70 | 71 | cols, err = fetch.Column[int64](true, "not-exists", rows) 72 | t.Error(err) 73 | 74 | t.Empty(cols) 75 | t.NotError(rows.Close()) 76 | }) 77 | } 78 | 79 | func TestColumnString(t *testing.T) { 80 | a := assert.New(t, false) 81 | suite := test.NewSuite(a, "") 82 | 83 | suite.Run(func(t *test.Driver) { 84 | initDB(t) 85 | defer clearDB(t) 86 | db := t.DB 87 | 88 | // 正常数据匹配,读取多行 89 | sql := `SELECT id,email FROM fetch_users WHERE id<3 ORDER BY id` 90 | rows, err := db.Query(sql) 91 | t.NotError(err).NotNil(rows) 92 | 93 | cols, err := fetch.Column[string](false, "id", rows) 94 | t.NotError(err).NotNil(cols) 95 | 96 | t.Equal([]string{"1", "2"}, cols) 97 | t.NotError(rows.Close()) 98 | 99 | // 正常数据匹配,读取一行 100 | rows, err = db.Query(sql) 101 | t.NotError(err).NotNil(rows) 102 | 103 | cols, err = fetch.Column[string](true, "id", rows) 104 | t.NotError(err).NotNil(cols) 105 | 106 | t.Equal([]string{"1"}, cols) 107 | t.NotError(rows.Close()) 108 | 109 | // 没有数据匹配,读取多行 110 | sql = `SELECT id FROM fetch_users WHERE id<0 ORDER BY id` 111 | rows, err = db.Query(sql) 112 | t.NotError(err).NotNil(rows) 113 | 114 | cols, err = fetch.Column[string](false, "id", rows) 115 | t.NotError(err) 116 | 117 | t.Empty(cols) 118 | t.NotError(rows.Close()) 119 | 120 | // 没有数据匹配,读取一行 121 | rows, err = db.Query(sql) 122 | t.NotError(err).NotNil(rows) 123 | 124 | cols, err = fetch.Column[string](true, "id", rows) 125 | t.NotError(err) 126 | 127 | t.Empty(cols) 128 | t.NotError(rows.Close()) 129 | 130 | // 指定错误的列名 131 | rows, err = db.Query(sql) 132 | t.NotError(err).NotNil(rows) 133 | 134 | cols, err = fetch.Column[string](true, "not-exists", rows) 135 | t.Error(err) 136 | 137 | t.Empty(cols) 138 | t.NotError(rows.Close()) 139 | }) 140 | } 141 | -------------------------------------------------------------------------------- /fetch/fetch.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | // Package fetch 提供了将 [sql.Rows] 导出为几种常用数据格式的方法 6 | package fetch 7 | 8 | // Tag 表示 struct tag 的名称 9 | const Tag = "orm" 10 | -------------------------------------------------------------------------------- /fetch/fetch_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package fetch 6 | 7 | import ( 8 | "reflect" 9 | "testing" 10 | 11 | "github.com/issue9/assert/v4" 12 | ) 13 | 14 | type FetchEmail struct { 15 | Email string `orm:"unique(unique_index);nullable;pk;len(100)"` 16 | 17 | Regdate int64 `orm:"-"` 18 | } 19 | 20 | type FetchUser struct { 21 | FetchEmail 22 | ID int `orm:"name(id);ai;"` 23 | Username string `orm:"index(username_index);len(20)"` 24 | Group int `orm:"name(group);fk(fk_group,group,id)"` 25 | } 26 | 27 | type Log struct { 28 | ID int `orm:"name(id);ai"` 29 | Content string `orm:"name(content);len(1024)"` 30 | Created int `orm:"name(created)"` 31 | UID int `orm:"name(uid)"` 32 | User *FetchUser `orm:"name(user)"` 33 | } 34 | 35 | func TestParseObject(t *testing.T) { 36 | a := assert.New(t, false) 37 | obj := &Log{ID: 5} 38 | mapped := map[string]reflect.Value{} 39 | 40 | v := reflect.ValueOf(obj).Elem() 41 | a.True(v.IsValid()) 42 | 43 | a.NotError(parseObject(v, &mapped)) 44 | a.Equal(8, len(mapped), "长度不相等,导出元素为:[%v]", mapped) 45 | 46 | // 忽略的字段 47 | _, found := mapped["user.Regdate"] 48 | a.False(found) 49 | 50 | // 判断字段是否存在 51 | vi, found := mapped["id"] 52 | a.True(found).True(vi.IsValid()) 53 | 54 | // 设置字段的值 55 | mapped["user.id"].Set(reflect.ValueOf(36)) 56 | a.Equal(36, obj.User.ID) 57 | mapped["user.Email"].SetString("email") 58 | a.Equal("email", obj.User.Email) 59 | mapped["user.Username"].SetString("username") 60 | a.Equal("username", obj.User.Username) 61 | mapped["user.group"].SetInt(1) 62 | a.Equal(1, obj.User.Group) 63 | 64 | type m struct { 65 | *FetchEmail 66 | ID int 67 | } 68 | o := &m{ID: 5} 69 | mapped = map[string]reflect.Value{} 70 | v = reflect.ValueOf(o).Elem() 71 | a.NotError(parseObject(v, &mapped)) 72 | a.Equal(2, len(mapped), "长度不相等,导出元素为:[%v]", mapped) 73 | 74 | type mm struct { 75 | FetchEmail 76 | ID int 77 | } 78 | oo := &mm{ID: 5} 79 | mapped = map[string]reflect.Value{} 80 | v = reflect.ValueOf(oo).Elem() 81 | a.NotError(parseObject(v, &mapped)) 82 | a.Equal(2, len(mapped), "长度不相等,导出元素为:[%v]", mapped) 83 | } 84 | 85 | func TestGetColumns(t *testing.T) { 86 | a := assert.New(t, false) 87 | obj := &FetchUser{} 88 | 89 | cols, err := getColumns(reflect.ValueOf(obj), []string{"id"}) 90 | a.NotError(err).NotNil(cols) 91 | a.Equal(len(cols), 1) 92 | 93 | // 当列不存在数据模型时 94 | cols, err = getColumns(reflect.ValueOf(obj), []string{"id", "not-exists"}) 95 | a.NotError(err).NotNil(cols) 96 | a.Equal(len(cols), 2) 97 | } 98 | -------------------------------------------------------------------------------- /fetch/map.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package fetch 6 | 7 | import ( 8 | "database/sql" 9 | "reflect" 10 | ) 11 | 12 | // Map 将 rows 中的所有或一行数据导出到 map[string]any 中 13 | // 14 | // 若 once 值为 true,则只导出第一条数据。 15 | // 16 | // NOTE: 17 | // 每个数据库对数据的处理方式是不一样的,比如如下语句 18 | // 19 | // SELECT x FROM tbl1 20 | // 21 | // 使用 [Map] 导出到 []map[string]any 中时, 22 | // 在 mysql 中,x 如果是字符串,有可能被处理成一个 []byte (打印输出时,像一个数组,容易造成困惑), 23 | // 而在 sqlite3 就有可能是个 int。 24 | func Map(once bool, rows *sql.Rows) ([]map[string]any, error) { 25 | cols, err := rows.Columns() 26 | if err != nil { 27 | return nil, err 28 | } 29 | 30 | // 临时缓存,用于保存从 rows 中读取出来的一行。 31 | buff := make([]any, len(cols)) 32 | for i := range cols { 33 | var value any 34 | buff[i] = &value 35 | } 36 | 37 | var data []map[string]any 38 | for rows.Next() { 39 | if err := rows.Scan(buff...); err != nil { 40 | return nil, err 41 | } 42 | 43 | line := make(map[string]any, len(cols)) 44 | for i, v := range cols { 45 | if buff[i] == nil { 46 | continue 47 | } 48 | value := reflect.Indirect(reflect.ValueOf(buff[i])) 49 | line[v] = value.Interface() 50 | } 51 | 52 | data = append(data, line) 53 | if once { 54 | return data, nil 55 | } 56 | } 57 | 58 | return data, nil 59 | } 60 | 61 | // MapString 将 rows 中的数据导出到一个 map[string]string 中 62 | // 63 | // 功能上与 [Map] 上一样,但 map 的键值固定为 string。 64 | func MapString(once bool, rows *sql.Rows) (data []map[string]string, err error) { 65 | cols, err := rows.Columns() 66 | if err != nil { 67 | return nil, err 68 | } 69 | 70 | buf := make([]any, len(cols)) 71 | for k := range buf { 72 | var val string 73 | buf[k] = &val 74 | } 75 | 76 | for rows.Next() { 77 | if err = rows.Scan(buf...); err != nil { 78 | return nil, err 79 | } 80 | 81 | line := make(map[string]string, len(cols)) 82 | for i, v := range cols { 83 | if buf[i] == nil { 84 | continue 85 | } 86 | line[v] = *(buf[i].(*string)) 87 | } 88 | 89 | data = append(data, line) 90 | 91 | if once { 92 | return data, nil 93 | } 94 | } 95 | 96 | return data, nil 97 | } 98 | -------------------------------------------------------------------------------- /fetch/map_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package fetch_test 6 | 7 | import ( 8 | "reflect" 9 | "testing" 10 | 11 | "github.com/issue9/assert/v4" 12 | 13 | "github.com/issue9/orm/v6/fetch" 14 | "github.com/issue9/orm/v6/internal/test" 15 | ) 16 | 17 | func TestMap(t *testing.T) { 18 | a := assert.New(t, false) 19 | suite := test.NewSuite(a, "") 20 | 21 | eq := func(m1, m2 []map[string]any) bool { 22 | if len(m1) != len(m2) { 23 | return false 24 | } 25 | 26 | for i, s1 := range m1 { 27 | s2 := m2[i] 28 | if len(s2) != len(s1) { 29 | return false 30 | } 31 | 32 | for k, v := range s1 { 33 | if !reflect.DeepEqual(s2[k], v) { 34 | return false 35 | } 36 | } 37 | } 38 | return true 39 | } 40 | 41 | suite.Run(func(t *test.Driver) { 42 | initDB(t) 43 | defer clearDB(t) 44 | 45 | db := t.DB 46 | 47 | // 正常匹配数据,读取多行 48 | sql := `SELECT id,email FROM fetch_users WHERE id<3 ORDER BY id` 49 | rows, err := db.Query(sql) 50 | t.NotError(err).NotNil(rows) 51 | 52 | mapped, err := fetch.Map(false, rows) 53 | t.NotError(err).NotNil(mapped) 54 | 55 | ok := eq([]map[string]any{ 56 | {"id": int64(1), "email": "email-1"}, 57 | {"id": int64(2), "email": "email-2"}, 58 | }, mapped) || 59 | eq([]map[string]any{ 60 | {"id": int64(1), "email": []byte("email-1")}, 61 | {"id": int64(2), "email": []byte("email-2")}, 62 | }, mapped) 63 | t.True(ok, "%+v") 64 | t.NotError(rows.Close()) 65 | 66 | // 正常匹配数据,读取一行 67 | rows, err = db.Query(sql) 68 | t.NotError(err).NotNil(rows) 69 | 70 | mapped, err = fetch.Map(true, rows) 71 | t.NotError(err).NotNil(mapped) 72 | 73 | ok = eq([]map[string]any{ 74 | {"id": int64(1), "email": "email-1"}, 75 | }, mapped) || 76 | eq([]map[string]any{ 77 | {"id": int64(1), "email": []byte("email-1")}, 78 | }, mapped) 79 | t.True(ok) 80 | t.NotError(rows.Close()) 81 | 82 | // 没有匹配的数据,读取多行 83 | sql = `SELECT id,email FROM fetch_users WHERE id<0 ORDER BY id` 84 | rows, err = db.Query(sql) 85 | t.NotError(err).NotNil(rows) 86 | 87 | mapped, err = fetch.Map(false, rows) 88 | t.NotError(err) 89 | 90 | t.Equal([]map[string]any{}, mapped) 91 | t.NotError(rows.Close()) 92 | 93 | // 没有匹配的数据,读取一行 94 | rows, err = db.Query(sql) 95 | t.NotError(err).NotNil(rows) 96 | 97 | mapped, err = fetch.Map(true, rows) 98 | t.NotError(err) 99 | 100 | t.Equal([]map[string]any{}, mapped) 101 | t.NotError(rows.Close()) 102 | }) 103 | } 104 | 105 | func TestMapString(t *testing.T) { 106 | a := assert.New(t, false) 107 | suite := test.NewSuite(a, "") 108 | 109 | suite.Run(func(t *test.Driver) { 110 | initDB(t) 111 | defer clearDB(t) 112 | 113 | db := t.DB 114 | 115 | // 正常数据匹配,读取多行 116 | sql := `SELECT id,email FROM fetch_users WHERE id<3 ORDER BY id` 117 | rows, err := db.Query(sql) 118 | t.NotError(err).NotNil(rows) 119 | 120 | mapped, err := fetch.MapString(false, rows) 121 | t.NotError(err).NotNil(mapped) 122 | 123 | t.Equal(mapped, []map[string]string{ 124 | {"id": "1", "email": "email-1"}, 125 | {"id": "2", "email": "email-2"}, 126 | }) 127 | t.NotError(rows.Close()) 128 | 129 | // 正常数据匹配,读取一行 130 | rows, err = db.Query(sql) 131 | t.NotError(err).NotNil(rows) 132 | 133 | mapped, err = fetch.MapString(true, rows) 134 | t.NotError(err).NotNil(mapped) 135 | 136 | t.Equal(mapped, []map[string]string{ 137 | {"id": "1", "email": "email-1"}, 138 | }) 139 | t.NotError(rows.Close()) 140 | 141 | // 没有数据匹配,读取多行 142 | sql = `SELECT id,email FROM fetch_users WHERE id<0 ORDER BY id` 143 | rows, err = db.Query(sql) 144 | t.NotError(err).NotNil(rows) 145 | 146 | mapped, err = fetch.MapString(false, rows) 147 | t.NotError(err) 148 | 149 | t.Equal(mapped, []map[string]string{}) 150 | t.NotError(rows.Close()) 151 | 152 | // 没有数据匹配,读取一行 153 | rows, err = db.Query(sql) 154 | t.NotError(err).NotNil(rows) 155 | 156 | mapped, err = fetch.MapString(true, rows) 157 | t.NotError(err) 158 | 159 | t.Equal(mapped, []map[string]string{}) 160 | t.NotError(rows.Close()) 161 | }) 162 | } 163 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/issue9/orm/v6 2 | 3 | require ( 4 | github.com/go-sql-driver/mysql v1.9.2 5 | github.com/issue9/assert/v4 v4.3.1 6 | github.com/issue9/conv v1.3.5 7 | github.com/issue9/errwrap v0.3.2 8 | github.com/issue9/sliceutil v0.17.0 9 | github.com/lib/pq v1.10.9 10 | github.com/mattn/go-sqlite3 v1.14.27 11 | github.com/shopspring/decimal v1.4.0 12 | modernc.org/sqlite v1.37.0 13 | ) 14 | 15 | require ( 16 | filippo.io/edwards25519 v1.1.0 // indirect 17 | github.com/dustin/go-humanize v1.0.1 // indirect 18 | github.com/google/uuid v1.6.0 // indirect 19 | github.com/mattn/go-isatty v0.0.20 // indirect 20 | github.com/ncruces/go-strftime v0.1.9 // indirect 21 | github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect 22 | golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 // indirect 23 | golang.org/x/sys v0.31.0 // indirect 24 | modernc.org/libc v1.62.1 // indirect 25 | modernc.org/mathutil v1.7.1 // indirect 26 | modernc.org/memory v1.9.1 // indirect 27 | ) 28 | 29 | go 1.23.0 30 | -------------------------------------------------------------------------------- /internal/createtable/createtable.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | // Package createtable 分析 create table 语句的内容 6 | package createtable 7 | 8 | import ( 9 | "errors" 10 | "fmt" 11 | "strings" 12 | "unicode" 13 | 14 | "github.com/issue9/orm/v6/core" 15 | ) 16 | 17 | var backQuoteReplacer = strings.NewReplacer("`", "") 18 | 19 | func lines(sql string) []string { 20 | sql = backQuoteReplacer.Replace(sql) 21 | var deep, start int 22 | var lines []string 23 | 24 | LOOP: 25 | for index, c := range sql { 26 | switch c { 27 | case ',': 28 | if deep == 1 && index > start { 29 | lines = append(lines, strings.TrimSpace(sql[start:index])) 30 | start = index + 1 // 不包含 ( 本身 31 | } 32 | case '(': 33 | deep++ 34 | if deep == 1 { 35 | start = index + 1 // 不包含 ( 本身 36 | } 37 | case ')': 38 | deep-- 39 | if deep == 0 { // 不需要 create table xx() 之后的内容 40 | if start != index { 41 | lines = append(lines, strings.TrimSpace(sql[start:index])) 42 | } 43 | break LOOP 44 | } 45 | } // end switch 46 | } // end for 47 | 48 | return lines 49 | } 50 | 51 | func fields(line string) []string { 52 | return strings.FieldsFunc(line, func(r rune) bool { 53 | return unicode.IsSpace(r) || r == '(' || r == ')' 54 | }) 55 | } 56 | 57 | // 获取 create table 的内容 58 | // 59 | // query 查询 create table 的语句; 60 | // val 从查询语句中获取的值。 61 | func scanCreateTable(engine core.Engine, table, query string, val ...any) error { 62 | rows, err := engine.Query(query) 63 | if err != nil { 64 | return err 65 | } 66 | 67 | defer func() { err = errors.Join(err, rows.Close()) }() 68 | 69 | if !rows.Next() { 70 | err = fmt.Errorf("未找到任何与 %s 相关的 CREATE TABLE 数据", table) 71 | return err 72 | } 73 | 74 | err = rows.Scan(val...) 75 | return err 76 | } 77 | -------------------------------------------------------------------------------- /internal/createtable/createtable_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package createtable 6 | 7 | import ( 8 | "testing" 9 | 10 | "github.com/issue9/assert/v4" 11 | ) 12 | 13 | func TestLines(t *testing.T) { 14 | a := assert.New(t, false) 15 | query := `create table tb1( 16 | id int not null primary key, 17 | name string not null, 18 | constraint fk foreign key (name) references tab2(col1) 19 | );charset=utf-8` 20 | a.Equal(lines(query), []string{ 21 | "id int not null primary key", 22 | "name string not null", 23 | "constraint fk foreign key (name) references tab2(col1)", 24 | }) 25 | 26 | query = "create table `tb1`(`id` int,`name` string,unique `fk`(`id`,`name`))" 27 | a.Equal(lines(query), []string{ 28 | "id int", 29 | "name string", 30 | "unique fk(id,name)", 31 | }) 32 | } 33 | -------------------------------------------------------------------------------- /internal/createtable/sqlite3.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package createtable 6 | 7 | import ( 8 | "errors" 9 | "fmt" 10 | "strings" 11 | 12 | "github.com/issue9/orm/v6/core" 13 | ) 14 | 15 | // Sqlite3Table 包含从 sqlite_master 中获取的与当前表相关的信息 16 | type Sqlite3Table struct { 17 | Columns map[string]string // 列信息,名称 => SQL 语句 18 | Constraints map[string]*Sqlite3Constraint 19 | Indexes map[string]*Sqlite3Index 20 | } 21 | 22 | // Sqlite3Index 表的索引信息 23 | // 24 | // 在 sqlite 中,索引是在创建表之后,另外提交的。 25 | // 在修改表结构时,需要保存索引,方便之后重建。 26 | type Sqlite3Index struct { 27 | Type core.IndexType 28 | SQL string // 创建索引的语句 29 | } 30 | 31 | // Sqlite3Constraint 从 create table 语句解析出来的约束信息 32 | type Sqlite3Constraint struct { 33 | Type core.ConstraintType 34 | SQL string // 在 Create Sqlite3Table 中的语句 35 | } 36 | 37 | // CreateTableSQL 生成 create table 语句 38 | func (t Sqlite3Table) CreateTableSQL(name string) (string, error) { 39 | builder := core.NewBuilder("CREATE TABLE "). 40 | WString(name). 41 | WBytes('(') 42 | 43 | for _, col := range t.Columns { 44 | builder.WString(col).WBytes(',') 45 | } 46 | 47 | for _, cont := range t.Constraints { 48 | builder.WString(cont.SQL).WBytes(',') 49 | } 50 | 51 | builder.TruncateLast(1).WBytes(')') 52 | 53 | return builder.String() 54 | } 55 | 56 | // ParseSqlite3CreateTable 从 sqlite_master 中获取 create table 并分析其内容 57 | func ParseSqlite3CreateTable(table string, engine core.Engine) (*Sqlite3Table, error) { 58 | tbl := &Sqlite3Table{ 59 | Columns: make(map[string]string, 10), 60 | Constraints: make(map[string]*Sqlite3Constraint, 5), 61 | Indexes: make(map[string]*Sqlite3Index, 2), 62 | } 63 | 64 | if err := parseSqlite3CreateTable(tbl, table, engine); err != nil { 65 | return nil, err 66 | } 67 | 68 | if err := parseSqlite3Indexes(tbl, table, engine); err != nil { 69 | return nil, err 70 | } 71 | 72 | return tbl, nil 73 | } 74 | 75 | // https://www.sqlite.org/draft/lang_createtable.html 76 | func parseSqlite3CreateTable(table *Sqlite3Table, tableName string, engine core.Engine) error { 77 | query := "SELECT sql FROM sqlite_master WHERE `type`='table' and tbl_name='" + tableName + "'" 78 | var sql string 79 | if err := scanCreateTable(engine, tableName, query, &sql); err != nil { 80 | return err 81 | } 82 | 83 | lines := lines(sql) 84 | for _, line := range lines { 85 | index := strings.IndexByte(line, ' ') 86 | if index <= 0 { 87 | return fmt.Errorf("语法错误:%s", line) 88 | } 89 | first := line[:index] 90 | 91 | switch strings.ToUpper(first) { 92 | case "CONSTRAINT": // 约束 93 | words := fields(line[index+1:]) 94 | if len(words) < 2 { 95 | return fmt.Errorf("语法错误:%s", line) 96 | } 97 | 98 | cont := &Sqlite3Constraint{SQL: line} 99 | switch words[1] { 100 | case "PRIMARY": 101 | cont.Type = core.ConstraintPK 102 | case "UNIQUE": 103 | cont.Type = core.ConstraintUnique 104 | case "CHECK": 105 | cont.Type = core.ConstraintCheck 106 | case "FOREIGN": 107 | cont.Type = core.ConstraintFK 108 | default: 109 | return fmt.Errorf("未知的约束名:%s", line) 110 | } 111 | 112 | table.Constraints[words[0]] = cont 113 | default: // 普通列定义,第一个字符串即为列名 114 | table.Columns[first] = line 115 | } 116 | } 117 | 118 | return nil 119 | } 120 | 121 | func parseSqlite3Indexes(table *Sqlite3Table, tableName string, engine core.Engine) error { 122 | // 通过 sql IS NOT NULL 过滤掉自动生成的索引值 123 | query := "SELECT name,sql FROM sqlite_master WHERE `type`='index' AND sql IS NOT NULL AND tbl_name='" + tableName + "'" 124 | rows, err := engine.Query(query) 125 | if err != nil { 126 | return err 127 | } 128 | 129 | defer func() { err = errors.Join(err, rows.Close()) }() 130 | 131 | for rows.Next() { 132 | var name, sql string 133 | if err = rows.Scan(&name, &sql); err != nil { 134 | return err 135 | } 136 | table.Indexes[name] = &Sqlite3Index{ 137 | SQL: sql, 138 | Type: core.IndexDefault, 139 | } 140 | } 141 | 142 | return nil 143 | } 144 | -------------------------------------------------------------------------------- /internal/createtable/sqlite3_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package createtable_test 6 | 7 | import ( 8 | "testing" 9 | 10 | "github.com/issue9/assert/v4" 11 | 12 | "github.com/issue9/orm/v6/core" 13 | "github.com/issue9/orm/v6/internal/createtable" 14 | "github.com/issue9/orm/v6/internal/sqltest" 15 | "github.com/issue9/orm/v6/internal/test" 16 | ) 17 | 18 | func TestMain(m *testing.M) { 19 | test.Main(m) 20 | } 21 | 22 | var sqlite3CreateTable = []string{`CREATE TABLE fk_table( 23 | id integer NOT NULL, 24 | PRIMARY KEY(id) 25 | )`, 26 | `CREATE TABLE usr ( 27 | id integer NOT NULL, 28 | created integer NOT NULL, 29 | nickname text NOT NULL, 30 | state integer NOT NULL, 31 | username text NOT NULL, 32 | mobile text NOT NULL, 33 | email text NOT NULL, 34 | pwd text NOT NULL, 35 | CONSTRAINT users_pk PRIMARY KEY (id), 36 | CONSTRAINT u_user_xx1 UNIQUE (mobile,username), 37 | CONSTRAINT u_user_email1 UNIQUE (email,username), 38 | CONSTRAINT unique_id UNIQUE (id), 39 | CONSTRAINT xxx_fk FOREIGN KEY (id) REFERENCES fk_table (id), 40 | CONSTRAINT xxx CHECK(created > 0) 41 | )`, 42 | `create index index_user_mobile on usr(mobile)`, 43 | `create unique index index_user_unique_email_id on usr(email,id)`, 44 | } 45 | 46 | func TestTable_CreateTableSQL(t *testing.T) { 47 | a := assert.New(t, false) 48 | 49 | tbl := &createtable.Sqlite3Table{ 50 | Columns: map[string]string{ 51 | "id": "id integer not null", 52 | }, 53 | Constraints: map[string]*createtable.Sqlite3Constraint{ 54 | "users_pk": { 55 | Type: core.ConstraintPK, 56 | SQL: "constraint users_pk primary key(id)", 57 | }, 58 | }, 59 | } 60 | 61 | query, err := tbl.CreateTableSQL("test") 62 | a.NotError(err) 63 | sqltest.Equal(a, query, `create table test( id integer not null,constraint users_pk primary key(id))`) 64 | } 65 | 66 | func TestParseSqlite3CreateTable(t *testing.T) { 67 | a := assert.New(t, false) 68 | 69 | suite := test.NewSuite(a, "", test.Sqlite3) 70 | 71 | suite.Run(func(t *test.Driver) { 72 | db := t.DB 73 | 74 | for _, query := range sqlite3CreateTable { 75 | _, err := db.Exec(query) 76 | t.NotError(err) 77 | } 78 | 79 | defer func() { 80 | _, err := db.Exec("DROP TABLE `usr`") 81 | t.NotError(err) 82 | 83 | _, err = db.Exec("DROP TABLE `fk_table`") 84 | t.NotError(err) 85 | }() 86 | 87 | table, err := createtable.ParseSqlite3CreateTable("usr", db) 88 | t.NotError(err).NotNil(table) 89 | 90 | t.Equal(len(table.Columns), 8) 91 | sqltest.Equal(a, table.Columns["id"], "id integer NOT NULL") 92 | sqltest.Equal(a, table.Columns["created"], "created integer NOT NULL") 93 | sqltest.Equal(a, table.Columns["nickname"], "nickname text NOT NULL") 94 | sqltest.Equal(a, table.Columns["state"], "state integer NOT NULL") 95 | sqltest.Equal(a, table.Columns["username"], "username text NOT NULL") 96 | sqltest.Equal(a, table.Columns["mobile"], "mobile text NOT NULL") 97 | sqltest.Equal(a, table.Columns["email"], "email text NOT NULL") 98 | sqltest.Equal(a, table.Columns["pwd"], "pwd text NOT NULL") 99 | t.Equal(len(table.Constraints), 6). 100 | Equal(table.Constraints["u_user_xx1"], &createtable.Sqlite3Constraint{ 101 | Type: core.ConstraintUnique, 102 | SQL: "CONSTRAINT u_user_xx1 UNIQUE (mobile,username)", 103 | }). 104 | Equal(table.Constraints["u_user_email1"], &createtable.Sqlite3Constraint{ 105 | Type: core.ConstraintUnique, 106 | SQL: "CONSTRAINT u_user_email1 UNIQUE (email,username)", 107 | }). 108 | Equal(table.Constraints["unique_id"], &createtable.Sqlite3Constraint{ 109 | Type: core.ConstraintUnique, 110 | SQL: "CONSTRAINT unique_id UNIQUE (id)", 111 | }). 112 | Equal(table.Constraints["xxx_fk"], &createtable.Sqlite3Constraint{ 113 | Type: core.ConstraintFK, 114 | SQL: "CONSTRAINT xxx_fk FOREIGN KEY (id) REFERENCES fk_table (id)", 115 | }). 116 | Equal(table.Constraints["xxx"], &createtable.Sqlite3Constraint{ 117 | Type: core.ConstraintCheck, 118 | SQL: "CONSTRAINT xxx CHECK(created > 0)", 119 | }). 120 | Equal(table.Constraints["users_pk"], &createtable.Sqlite3Constraint{ 121 | Type: core.ConstraintPK, 122 | SQL: "CONSTRAINT users_pk PRIMARY KEY (id)", 123 | }) // 主键约束名为固定值 124 | t.Equal(len(table.Indexes), 2). 125 | Equal(table.Indexes["index_user_mobile"], &createtable.Sqlite3Index{ 126 | Type: core.IndexDefault, 127 | SQL: "CREATE INDEX index_user_mobile on usr(mobile)", 128 | }). 129 | Equal(table.Indexes["index_user_unique_email_id"], &createtable.Sqlite3Index{ 130 | Type: core.IndexDefault, 131 | SQL: "CREATE UNIQUE INDEX index_user_unique_email_id on usr(email,id)", 132 | }) // sqlite 没有 unique 133 | }) 134 | } 135 | -------------------------------------------------------------------------------- /internal/model/bench_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package model_test 6 | 7 | import ( 8 | "testing" 9 | 10 | "github.com/issue9/assert/v4" 11 | ) 12 | 13 | func BenchmarkNewModelNoCached(b *testing.B) { 14 | a := assert.New(b, false) 15 | ms := newModules(a) 16 | 17 | for i := 0; i < b.N; i++ { 18 | m, err := ms.New(&User{}) 19 | a.NotError(err).NotNil(m) 20 | ms.Close() 21 | } 22 | } 23 | 24 | func BenchmarkNewModelCached(b *testing.B) { 25 | a := assert.New(b, false) 26 | ms := newModules(a) 27 | 28 | for i := 0; i < b.N; i++ { 29 | m, err := ms.New(&User{}) 30 | a.NotError(err).NotNil(m) 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /internal/model/column.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package model 6 | 7 | import ( 8 | "database/sql" 9 | "reflect" 10 | "strconv" 11 | "time" 12 | 13 | "github.com/issue9/conv" 14 | 15 | "github.com/issue9/orm/v6/core" 16 | "github.com/issue9/orm/v6/internal/tags" 17 | ) 18 | 19 | type Column struct { 20 | *core.Column 21 | GoType reflect.Type 22 | } 23 | 24 | func NewColumn(field reflect.StructField) (*Column, error) { 25 | t := field.Type 26 | for t.Kind() == reflect.Ptr { 27 | t = t.Elem() 28 | } 29 | 30 | col, err := core.NewColumn(core.GetPrimitiveType(t)) 31 | if err != nil { 32 | return nil, err 33 | } 34 | 35 | col.Name = field.Name 36 | col.GoName = field.Name 37 | return &Column{ 38 | Column: col, 39 | GoType: t, 40 | }, nil 41 | } 42 | 43 | func (col *Column) ParseTags(m *core.Model, tag string) (err error) { 44 | if err = m.AddColumn(col.Column); err != nil { 45 | return err 46 | } 47 | 48 | ts := tags.Parse(tag) 49 | for _, tag := range ts { 50 | switch tag.Name { 51 | case "name": // name(col) 52 | if len(tag.Args) != 1 { 53 | return propertyError(col.Name, "name", "过多的参数值") 54 | } 55 | col.Name = tag.Args[0] 56 | case "index": 57 | err = setIndex(m, col, tag.Args) 58 | case "pk": 59 | err = SetPK(m, col, tag.Args) 60 | case "unique": 61 | err = setUnique(m, col, tag.Args) 62 | case "nullable": 63 | err = col.SetNullable(tag.Args) 64 | case "ai": 65 | err = col.SetAI(m, tag.Args) 66 | case "len": 67 | err = col.SetLen(tag.Args) 68 | case "fk": 69 | err = setFK(m, col, tag.Args) 70 | case "default": 71 | err = col.SetDefault(tag.Args) 72 | case "occ": 73 | err = SetOCC(m, col, tag.Args) 74 | default: 75 | err = propertyError(col.Name, tag.Name, "未知的属性") 76 | } 77 | 78 | if err != nil { 79 | return err 80 | } 81 | } 82 | 83 | return nil 84 | } 85 | 86 | // 从参数中获取 Column 的 len1 和 len2 变量 87 | // 88 | // len(len1,len2) 89 | func (col *Column) SetLen(vals []string) (err error) { 90 | l := len(vals) 91 | switch l { 92 | case 1: 93 | case 2: 94 | case 0: 95 | return nil 96 | default: 97 | return propertyError(col.Name, "len", "过多的参数") 98 | } 99 | 100 | col.Length = make([]int, 0, l) 101 | for _, val := range vals { 102 | v, err := strconv.Atoi(val) 103 | if err != nil { 104 | return err 105 | } 106 | col.Length = append(col.Length, v) 107 | } 108 | 109 | return nil 110 | } 111 | 112 | // nullable 113 | func (col *Column) SetNullable(vals []string) (err error) { 114 | if len(vals) > 0 { 115 | return propertyError(col.Name, "nullable", "指定了太多的值") 116 | } 117 | 118 | if col.AI { 119 | return propertyError(col.Name, "nullable", "自增列不能设置此值") 120 | } 121 | 122 | col.Nullable = true 123 | return nil 124 | } 125 | 126 | // default(5) 127 | func (col *Column) SetDefault(vals []string) error { 128 | if len(vals) != 1 { 129 | return propertyError(col.Name, "default", "太多的值") 130 | } 131 | col.HasDefault = true 132 | 133 | rval := reflect.New(col.GoType) 134 | 135 | v := rval.Interface() 136 | if p, ok := v.(sql.Scanner); ok { 137 | if err := p.Scan(vals[0]); err != nil { 138 | return err 139 | } 140 | col.Default = v 141 | return nil 142 | } 143 | v = rval.Elem().Interface() 144 | if p, ok := v.(sql.Scanner); ok { 145 | if err := p.Scan(vals[0]); err != nil { 146 | return err 147 | } 148 | col.Default = v 149 | return nil 150 | } 151 | 152 | switch col.PrimitiveType { 153 | case core.Time: 154 | v, err := time.Parse(core.TimeFormatLayout, vals[0]) 155 | if err != nil { 156 | return err 157 | } 158 | col.Default = v 159 | default: 160 | for rval.Kind() == reflect.Ptr { 161 | rval = rval.Elem() 162 | } 163 | 164 | if err := conv.Value(vals[0], rval); err != nil { 165 | return err 166 | } 167 | col.Default = rval.Interface() 168 | } 169 | 170 | return nil 171 | } 172 | 173 | // ai 174 | func (col *Column) SetAI(m *core.Model, vals []string) error { 175 | if len(vals) != 0 { 176 | return propertyError(col.Name, "ai", "太多的值") 177 | } 178 | 179 | return m.SetAutoIncrement(col.Column) 180 | } 181 | -------------------------------------------------------------------------------- /internal/model/engine.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package model 6 | 7 | import ( 8 | "context" 9 | "database/sql" 10 | "strings" 11 | 12 | "github.com/issue9/orm/v6/core" 13 | ) 14 | 15 | type coreEngine struct { 16 | ms *Models 17 | engine stdEngine 18 | replacer *strings.Replacer 19 | sqlLogger func(string) 20 | } 21 | 22 | // [sql.DB] 与 [sql.Tx] 的最小接口 23 | type stdEngine interface { 24 | ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) 25 | PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) 26 | QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) 27 | QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row 28 | } 29 | 30 | func defaultSQLLogger(string) {} 31 | 32 | // NewEngine 声明实现 [core.Engine] 接口的实例 33 | func (ms *Models) NewEngine(e stdEngine, tablePrefix string) core.Engine { 34 | l, r := ms.dialect.Quotes() 35 | 36 | return &coreEngine{ 37 | ms: ms, 38 | engine: e, 39 | sqlLogger: defaultSQLLogger, 40 | replacer: strings.NewReplacer( 41 | string(core.QuoteLeft), string(l), 42 | string(core.QuoteRight), string(r), 43 | "#", tablePrefix, 44 | ), 45 | } 46 | } 47 | 48 | // Debug 指定调输出调试内容通道 49 | // 50 | // 如果 l 不为 nil,则每次 SQL 调用都会输出 SQL 语句,预编译的语句,仅在预编译时输出; 51 | // 如果为 nil,则表示关闭调试。 52 | func (db *coreEngine) Debug(l func(string)) { 53 | if l == nil { 54 | l = defaultSQLLogger 55 | } 56 | db.sqlLogger = l 57 | } 58 | 59 | func (db *coreEngine) Dialect() core.Dialect { return db.ms.dialect } 60 | 61 | func (db *coreEngine) QueryRow(query string, args ...any) *sql.Row { 62 | return db.QueryRowContext(context.Background(), query, args...) 63 | } 64 | 65 | func (db *coreEngine) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row { 66 | db.sqlLogger(query) 67 | query, args, err := db.Dialect().Fix(query, args) 68 | if err != nil { 69 | panic(err) 70 | } 71 | 72 | query = db.replacer.Replace(query) 73 | return db.engine.QueryRowContext(ctx, query, args...) 74 | } 75 | 76 | func (db *coreEngine) Query(query string, args ...any) (*sql.Rows, error) { 77 | return db.QueryContext(context.Background(), query, args...) 78 | } 79 | 80 | func (db *coreEngine) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) { 81 | db.sqlLogger(query) 82 | query, args, err := db.Dialect().Fix(query, args) 83 | if err != nil { 84 | return nil, err 85 | } 86 | 87 | query = db.replacer.Replace(query) 88 | return db.engine.QueryContext(ctx, query, args...) 89 | } 90 | 91 | func (db *coreEngine) Exec(query string, args ...any) (sql.Result, error) { 92 | return db.ExecContext(context.Background(), query, args...) 93 | } 94 | 95 | func (db *coreEngine) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { 96 | db.sqlLogger(query) 97 | query, args, err := db.Dialect().Fix(query, args) 98 | if err != nil { 99 | return nil, err 100 | } 101 | 102 | query = db.replacer.Replace(query) 103 | return db.engine.ExecContext(ctx, query, args...) 104 | } 105 | 106 | func (db *coreEngine) Prepare(query string) (*core.Stmt, error) { 107 | return db.PrepareContext(context.Background(), query) 108 | } 109 | 110 | func (db *coreEngine) PrepareContext(ctx context.Context, query string) (*core.Stmt, error) { 111 | db.sqlLogger(query) 112 | query, orders, err := db.Dialect().Prepare(query) 113 | if err != nil { 114 | return nil, err 115 | } 116 | 117 | query = db.replacer.Replace(query) 118 | s, err := db.engine.PrepareContext(ctx, query) 119 | if err != nil { 120 | return nil, err 121 | } 122 | return core.NewStmt(s, orders), nil 123 | } 124 | -------------------------------------------------------------------------------- /internal/model/model.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | // Package model 管理数据模型 6 | package model 7 | 8 | import ( 9 | "fmt" 10 | "reflect" 11 | "strings" 12 | "unicode" 13 | 14 | "github.com/issue9/orm/v6/core" 15 | "github.com/issue9/orm/v6/fetch" 16 | ) 17 | 18 | func propertyError(field, name, message string) error { 19 | return fmt.Errorf("%s 的 %s 属性发生以下错误: %s", field, name, message) 20 | } 21 | 22 | // New 从一个 obj 声明 [core.Model] 实例 23 | // 24 | // obj 可以是一个结构体或是指针。 25 | func (ms *Models) New(obj core.TableNamer) (*core.Model, error) { 26 | rtype := reflect.TypeOf(obj) 27 | for rtype.Kind() == reflect.Ptr { 28 | rtype = rtype.Elem() 29 | } 30 | 31 | if rtype.Kind() != reflect.Struct { 32 | return nil, fetch.ErrUnsupportedKind() 33 | } 34 | 35 | if m, found := ms.models.Load(rtype); found { 36 | return m.(*core.Model), nil 37 | } 38 | 39 | m := core.NewModel(core.Table, "#"+obj.TableName(), rtype.NumField()) 40 | m.GoType = rtype 41 | 42 | if err := parseColumns(m, rtype); err != nil { 43 | return nil, err 44 | } 45 | 46 | if am, ok := obj.(core.ApplyModeler); ok { 47 | if err := am.ApplyModel(m); err != nil { 48 | return nil, err 49 | } 50 | } 51 | 52 | if view, ok := obj.(core.Viewer); ok { 53 | m.Type = core.View 54 | sql, err := view.ViewAs() 55 | if err != nil { 56 | return nil, err 57 | } 58 | m.ViewAs = sql 59 | } 60 | 61 | if err := m.Sanitize(); err != nil { 62 | return nil, err 63 | } 64 | 65 | // 在构建完 core.Model 时在其它地方写入了相同名称的 core.Model, 66 | // 相当于在函数的开始阶段判断是否存在同名的对象,返回已经存在的对象。 67 | if m, found := ms.models.Load(rtype); found { 68 | return m.(*core.Model), nil 69 | } 70 | ms.models.Store(rtype, m) 71 | return m, nil 72 | } 73 | 74 | // 将 rval 中的结构解析到 m 中,支持匿名字段。 75 | func parseColumns(m *core.Model, rtype reflect.Type) error { 76 | num := rtype.NumField() 77 | for i := 0; i < num; i++ { 78 | field := rtype.Field(i) 79 | 80 | if field.Anonymous { 81 | if err := parseColumns(m, field.Type); err != nil { 82 | return err 83 | } 84 | continue 85 | } 86 | 87 | if unicode.IsLower(rune(field.Name[0])) { // 忽略以小写字母开头的字段 88 | continue 89 | } 90 | 91 | tag := field.Tag.Get(fetch.Tag) 92 | if tag == "-" { 93 | continue 94 | } 95 | 96 | col, err := NewColumn(field) 97 | if err != nil { 98 | return err 99 | } 100 | 101 | // 这属于代码级别的错误,直接 panic 了。 102 | if err := col.ParseTags(m, tag); err != nil { 103 | panic(err) 104 | } 105 | } 106 | 107 | return nil 108 | } 109 | 110 | // occ 111 | func SetOCC(m *core.Model, c *Column, vals []string) error { 112 | if len(vals) > 0 { 113 | return propertyError(c.Name, "occ", "指定了太多的值") 114 | } 115 | return m.SetOCC(c.Column) 116 | } 117 | 118 | // index(idx_name) 119 | func setIndex(m *core.Model, col *Column, vals []string) error { 120 | if len(vals) != 1 { 121 | return propertyError(col.Name, "index", "太多的值") 122 | } 123 | return m.AddIndex(core.IndexDefault, strings.ToLower(vals[0]), col.Column) 124 | } 125 | 126 | // pk 127 | func SetPK(m *core.Model, col *Column, vals []string) error { 128 | if len(vals) != 0 { 129 | return propertyError(col.Name, "pk", "太多的值") 130 | } 131 | return m.AddPrimaryKey(col.Column) 132 | } 133 | 134 | // unique(unique_name) 135 | func setUnique(m *core.Model, col *Column, vals []string) error { 136 | if len(vals) != 1 { 137 | return propertyError(col.Name, "unique", "只能带一个参数") 138 | } 139 | return m.AddUnique(strings.ToLower(vals[0]), col.Column) 140 | } 141 | 142 | // fk(fk_name,refTable,refColName,updateRule,deleteRule) 143 | func setFK(m *core.Model, col *Column, vals []string) error { 144 | if len(vals) < 3 || len(vals) > 5 { 145 | return propertyError(col.Name, "fk", "参数数量不正确") 146 | } 147 | 148 | fk := &core.ForeignKey{ 149 | Name: strings.ToLower(vals[0]), 150 | Column: col.Column, 151 | RefTableName: "#" + vals[1], 152 | RefColName: vals[2], 153 | } 154 | 155 | if len(vals) > 3 { // 存在 updateRule 156 | fk.UpdateRule = vals[3] 157 | } 158 | if len(vals) > 4 { // 存在 deleteRule 159 | fk.DeleteRule = vals[4] 160 | } 161 | 162 | return m.NewForeignKey(fk) 163 | } 164 | -------------------------------------------------------------------------------- /internal/model/models.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package model 6 | 7 | import ( 8 | "database/sql" 9 | "sync" 10 | 11 | "github.com/issue9/orm/v6/core" 12 | ) 13 | 14 | // Models 数据模型管理 15 | type Models struct { 16 | db *sql.DB 17 | dialect core.Dialect 18 | models *sync.Map 19 | version string 20 | } 21 | 22 | // NewModels 声明 [Models] 变量 23 | // 24 | // 返回对象中除了 [Models] 之外,还包含了一个 [core.Engine] 对象, 25 | // 该对象的表名前缀由参数 tablePrefix 指定。 26 | func NewModels(db *sql.DB, d core.Dialect, tablePrefix string) (*Models, core.Engine, error) { 27 | ms := &Models{ 28 | db: db, 29 | dialect: d, 30 | models: &sync.Map{}, 31 | } 32 | 33 | e := ms.NewEngine(db, tablePrefix) 34 | if err := e.QueryRow(d.VersionSQL()).Scan(&ms.version); err != nil { 35 | return nil, nil, err 36 | } 37 | return ms, e, nil 38 | } 39 | 40 | // Close 清除所有的 [core.Model] 缓存 41 | func (ms *Models) Close() error { 42 | ms.models.Range(func(key, _ any) bool { 43 | ms.models.Delete(key) 44 | return true 45 | }) 46 | 47 | return ms.DB().Close() 48 | } 49 | 50 | func (ms *Models) DB() *sql.DB { return ms.db } 51 | 52 | func (ms *Models) Version() string { return ms.version } 53 | 54 | func (ms *Models) Length() (cnt int) { 55 | ms.models.Range(func(key, value any) bool { 56 | cnt++ 57 | return true 58 | }) 59 | return 60 | } 61 | -------------------------------------------------------------------------------- /internal/model/models_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package model_test 6 | 7 | import ( 8 | "database/sql" 9 | "os" 10 | "testing" 11 | 12 | "github.com/issue9/assert/v4" 13 | 14 | "github.com/issue9/orm/v6/dialect" 15 | "github.com/issue9/orm/v6/internal/model" 16 | "github.com/issue9/orm/v6/internal/model/testdata" 17 | "github.com/issue9/orm/v6/internal/test" 18 | ) 19 | 20 | func TestMain(m *testing.M) { test.Main(m) } 21 | 22 | func newModules(a *assert.Assertion) *model.Models { 23 | const testDB = "./test.db" 24 | 25 | db, err := sql.Open("sqlite3", testDB) 26 | a.NotError(err).NotNil(db) 27 | 28 | ms, e, err := model.NewModels(db, dialect.Sqlite3("sqlite"), "") 29 | a.NotError(err). 30 | NotNil(ms). 31 | NotNil(e). 32 | NotEmpty(ms.Version()). 33 | NotNil(ms.DB()) 34 | 35 | a.TB().Cleanup(func() { 36 | a.NotError(os.Remove(testDB)) 37 | }) 38 | 39 | return ms 40 | } 41 | 42 | func TestModels(t *testing.T) { 43 | a := assert.New(t, false) 44 | ms := newModules(a) 45 | 46 | m, err := ms.New(&User{}) 47 | a.NotError(err). 48 | NotNil(m). 49 | Equal(1, ms.Length()) 50 | 51 | // 相同的 model 实例,不会增加数量 52 | m, err = ms.New(&User{}) 53 | a.NotError(err). 54 | NotNil(m). 55 | Equal(1, ms.Length()) 56 | 57 | // 相同的表名,但是类型不同 58 | m, err = ms.New(&testdata.User{}) 59 | a.NotError(err). 60 | NotNil(m). 61 | Equal(2, ms.Length()) 62 | 63 | // 添加新的 model 64 | m, err = ms.New(&Admin{}) 65 | a.NotError(err). 66 | NotNil(m). 67 | Equal(3, ms.Length()) 68 | 69 | a.NotError(ms.Close()) 70 | a.Equal(0, ms.Length()) 71 | } 72 | -------------------------------------------------------------------------------- /internal/model/testdata/user.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package testdata 6 | 7 | import "github.com/issue9/orm/v6/core" 8 | 9 | // User 与 model.User 完全相同,包括表名 10 | // 11 | // 但在不同的目录下,类型不同,应该会被创建不同的 core.Model 对象。 12 | type User struct { 13 | ID int `orm:"name(id);ai;"` 14 | Username string `orm:"unique(unique_user_username);index(index_user_name);len(50)"` 15 | Password string `orm:"name(password);len(20)"` 16 | Regdate int `orm:"-"` 17 | } 18 | 19 | func (u *User) TableName() string { return "users" } 20 | 21 | func (u *User) ApplyModel(m *core.Model) error { 22 | m.Options["mysql_engine"] = []string{"innodb"} 23 | m.Options["mysql_charset"] = []string{"utf8"} 24 | return nil 25 | } 26 | -------------------------------------------------------------------------------- /internal/sqltest/sqltest.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | // Package sqltest 提供对 SQL 内容测试的工具 6 | package sqltest 7 | 8 | import ( 9 | "regexp" 10 | "strings" 11 | 12 | "github.com/issue9/assert/v4" 13 | ) 14 | 15 | var replacer = strings.NewReplacer( 16 | ")", " ) ", 17 | "(", " ( ", 18 | ",", " , ", 19 | "=", " = ", 20 | ) 21 | 22 | var spaceReplaceRegexp = regexp.MustCompile("\\s+") 23 | 24 | // Equal 检测两条 SQL 语句是否相等 25 | // 26 | // 忽略大小写与多余的空格。 27 | func Equal(a *assert.Assertion, s1, s2 string) { 28 | // 将'(', ')', ',' 等字符的前后空格标准化 29 | s1 = replacer.Replace(s1) 30 | s2 = replacer.Replace(s2) 31 | 32 | // 转换成小写,去掉首尾空格 33 | s1 = strings.TrimSpace(strings.ToLower(s1)) 34 | s2 = strings.TrimSpace(strings.ToLower(s2)) 35 | 36 | // 去掉多余的空格。 37 | s1 = spaceReplaceRegexp.ReplaceAllString(s1, " ") 38 | s2 = spaceReplaceRegexp.ReplaceAllString(s2, " ") 39 | 40 | a.TB().Helper() 41 | 42 | a.Equal(s1, s2) 43 | } 44 | -------------------------------------------------------------------------------- /internal/sqltest/sqltest_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package sqltest_test 6 | 7 | import ( 8 | "testing" 9 | 10 | "github.com/issue9/assert/v4" 11 | 12 | "github.com/issue9/orm/v6/internal/sqltest" 13 | "github.com/issue9/orm/v6/internal/test" 14 | ) 15 | 16 | func TestMain(m *testing.M) { 17 | test.Main(m) 18 | } 19 | 20 | func TestEqual(t *testing.T) { 21 | a := assert.New(t, false) 22 | sqltest.Equal(a, "insert INTO tb2 (c1, c2) values (?, ?) , (? ,@c2)", "insert into tb2 (c1,c2) values (?,?),(?,@c2)") 23 | } 24 | -------------------------------------------------------------------------------- /internal/tags/flag_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package tags_test 6 | 7 | import ( 8 | "testing" 9 | 10 | "github.com/issue9/orm/v6/internal/test" 11 | ) 12 | 13 | func TestMain(m *testing.M) { 14 | test.Main(m) 15 | } 16 | -------------------------------------------------------------------------------- /internal/tags/tags.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | // Package tags 包实现对特定格式的 struct tag 字符串的分析 6 | // 7 | // 1. 以分号分隔的字符串,每个子串又以逗号分隔, 8 | // 第一个字符串为键名,之后的字符串组成的数组为键值。如: 9 | // 10 | // "id,1;unique;fun,add,1,2;" 11 | // // 以下将会被解析成: 12 | // [ 13 | // "id" :["1"], 14 | // "unique":nil, 15 | // "fun" :["add","1","2"] 16 | // ] 17 | // 18 | // 2.以分号分隔的字符串,每个子串括号前的字符串为健名, 19 | // 括号中的字符串以逗号分隔组成数组为键值。如: 20 | // 21 | // "id(1);unique;fun(add,1,2)" 22 | // // 以下将会被解析成: 23 | // [ 24 | // "id" :["1"], 25 | // "unique":nil, 26 | // "fun" :["add","1","2"] 27 | // ] 28 | package tags 29 | 30 | import "strings" 31 | 32 | // 将第二种风格的 struct tag 转换成第一种风格的。 33 | var styleReplace = strings.NewReplacer("(", ",", ")", "") 34 | 35 | // Tag 解析后的单个标签标签内容 36 | type Tag struct { 37 | Name string 38 | Args []string 39 | } 40 | 41 | // Parse 分析 tag 的内容并以 map 的形式返回 42 | func Parse(tag string) []*Tag { 43 | ret := make([]*Tag, 0, 10) 44 | 45 | if len(tag) == 0 { 46 | return nil 47 | } 48 | 49 | if strings.IndexByte(tag, '(') > -1 { 50 | tag = styleReplace.Replace(tag) 51 | } 52 | 53 | parts := strings.Split(tag, ";") 54 | for _, part := range parts { 55 | if len(part) == 0 { 56 | continue 57 | } 58 | part = strings.Trim(part, ",") 59 | items := strings.Split(part, ",") 60 | ret = append(ret, &Tag{ 61 | Name: items[0], 62 | Args: items[1:], 63 | }) 64 | } 65 | 66 | return ret 67 | } 68 | 69 | // Get 从 tag 中查找名称为 name 的内容 70 | // 71 | // 第二个参数用于判断该项是否存在。若存在多个同外的,则只返回第一个。 72 | func Get(tag, name string) ([]string, bool) { 73 | if len(tag) == 0 { 74 | return nil, false 75 | } 76 | 77 | if strings.IndexByte(tag, '(') > -1 { 78 | tag = styleReplace.Replace(tag) 79 | } 80 | 81 | parts := strings.Split(tag, ";") 82 | for _, part := range parts { 83 | if len(part) == 0 { 84 | continue 85 | } 86 | 87 | part = strings.Trim(part, ",") 88 | items := strings.Split(part, ",") 89 | if items[0] == name { 90 | return items[1:], true 91 | } 92 | } 93 | 94 | return nil, false 95 | } 96 | 97 | // MustGet 功能同 Get() 函数,但在无法找到的情况下,会返回 defVal 做为默认值。 98 | func MustGet(tag, name string, defVal ...string) []string { 99 | if ret, found := Get(tag, name); found { 100 | return ret 101 | } 102 | 103 | return defVal 104 | } 105 | 106 | // Has 查询指定名称的项是否存在 107 | // 108 | // 若只是查找是否存在该项,使用 Has() 会比 Get() 要快上许多。 109 | func Has(tag, name string) bool { 110 | if len(tag) == 0 { 111 | return false 112 | } 113 | 114 | if strings.IndexByte(tag, '(') > -1 { 115 | tag = styleReplace.Replace(tag) 116 | } 117 | 118 | parts := strings.Split(tag, ";") 119 | for _, part := range parts { 120 | if len(part) == 0 { 121 | continue 122 | } 123 | 124 | part = strings.Trim(part, ",") 125 | items := strings.SplitN(part, ",", 2) 126 | if items[0] == name { 127 | return true 128 | } 129 | } 130 | 131 | return false 132 | } 133 | -------------------------------------------------------------------------------- /internal/tags/tags_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package tags 6 | 7 | import ( 8 | "testing" 9 | 10 | "github.com/issue9/assert/v4" 11 | ) 12 | 13 | type testData struct { // 测试数据结构 14 | tag string // 待分析字符串 15 | data []*Tag // 分析后数据 16 | } 17 | 18 | var tests = []*testData{ 19 | { 20 | tag: "name,abc;name2,;;name3,n1,n2;name3(n3,n4)", 21 | data: []*Tag{ 22 | { 23 | Name: "name", 24 | Args: []string{"abc"}, 25 | }, 26 | { 27 | Name: "name2", 28 | Args: []string{}, 29 | }, 30 | { 31 | Name: "name3", 32 | Args: []string{"n1", "n2"}, 33 | }, 34 | { 35 | Name: "name3", 36 | Args: []string{"n3", "n4"}, 37 | }, 38 | }, 39 | }, 40 | { 41 | tag: "name(abc);name2,;;name3(n1,n2);name3(n3,n4)", 42 | data: []*Tag{ 43 | { 44 | Name: "name", 45 | Args: []string{"abc"}, 46 | }, 47 | { 48 | Name: "name2", 49 | Args: []string{}, 50 | }, 51 | { 52 | Name: "name3", 53 | Args: []string{"n1", "n2"}, 54 | }, 55 | { 56 | Name: "name3", 57 | Args: []string{"n3", "n4"}, 58 | }, 59 | }, 60 | }, 61 | { 62 | tag: "", 63 | data: nil, 64 | }, 65 | { 66 | tag: "", 67 | data: []*Tag{}, 68 | }, 69 | } 70 | 71 | func TestReplace(t *testing.T) { 72 | a := assert.New(t, false) 73 | tag1 := "name,abc;name2,;;name3,n1,n2;name3,n1,n2" 74 | tag2 := "name(abc);name2,;;name3(n1,n2);name3(n1,n2)" 75 | tag := styleReplace.Replace(tag2) 76 | a.Equal(tag, tag1) 77 | } 78 | 79 | func TestParse(t *testing.T) { 80 | a := assert.New(t, false) 81 | 82 | for _, test := range tests { 83 | m := Parse(test.tag) 84 | if m != nil { 85 | for index, item := range m { 86 | a.Equal(item, test.data[index]) 87 | } 88 | } 89 | } 90 | } 91 | 92 | func TestGet(t *testing.T) { 93 | a := assert.New(t, false) 94 | 95 | for _, test := range tests { 96 | for _, items := range test.data { 97 | t.Log(test.tag) 98 | val, found := Get(test.tag, items.Name) 99 | a.True(found) 100 | if items.Name == "name3" { 101 | a.Equal(val, []string{"n1", "n2"}) // 多个重名的,只返回第一个数据 102 | } else { 103 | a.Equal(val, items.Args) 104 | } 105 | 106 | val, found = Get(test.tag, items.Name+"-temp") 107 | a.False(found).Nil(val) 108 | } 109 | } 110 | } 111 | 112 | func TestMustGet(t *testing.T) { 113 | a := assert.New(t, false) 114 | 115 | for _, test := range tests { 116 | for _, items := range test.data { 117 | val := MustGet(test.tag, items.Name, "default") 118 | if items.Name == "name3" { 119 | a.Equal(val, []string{"n1", "n2"}) // 多个重名的,只返回第一个数据 120 | } else { 121 | a.Equal(val, items.Args) 122 | } 123 | 124 | val = MustGet(test.tag, items.Name+"-temp", "def1", "def2") 125 | a.Equal(val, []string{"def1", "def2"}) 126 | } 127 | } 128 | } 129 | 130 | func TestHas(t *testing.T) { 131 | a := assert.New(t, false) 132 | 133 | for _, test := range tests { 134 | for _, item := range test.data { 135 | a.True(Has(test.tag, item.Name)) 136 | 137 | a.False(Has(test.tag, item.Name+"-temp")) 138 | } 139 | } 140 | } 141 | -------------------------------------------------------------------------------- /internal/test/flag.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package test 6 | 7 | import ( 8 | "flag" 9 | "fmt" 10 | "os" 11 | "strings" 12 | "testing" 13 | ) 14 | 15 | var flags []*flagVar 16 | 17 | type flagVar struct { 18 | Name, DriverName string 19 | } 20 | 21 | // Main 供测试的 TestMain 调用 22 | func Main(m *testing.M) { 23 | dbString := flag.String("dbs", "sqlite3,sqlite3", "指定需要测试的数据库,格式为 name,driverName:name,driverName") 24 | 25 | flag.Parse() 26 | 27 | if *dbString == "" || flags != nil { 28 | return 29 | } 30 | 31 | flags = make([]*flagVar, 0, 10) 32 | 33 | items := strings.Split(*dbString, ":") 34 | for _, item := range items { 35 | i := strings.Split(item, ",") 36 | if len(i) != 2 { 37 | panic(fmt.Sprintf("格式错误:%v", *dbString)) 38 | } 39 | flags = append(flags, &flagVar{Name: i[0], DriverName: i[1]}) 40 | } 41 | 42 | os.Exit(m.Run()) 43 | } 44 | -------------------------------------------------------------------------------- /internal/test/test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | // Package test 提供了整个包的基本测试数据 6 | package test 7 | 8 | import ( 9 | "os" 10 | "slices" 11 | 12 | "github.com/issue9/assert/v4" 13 | 14 | "github.com/issue9/orm/v6" 15 | "github.com/issue9/orm/v6/core" 16 | "github.com/issue9/orm/v6/dialect" 17 | 18 | // 测试入口,数据库也在此初始化 19 | _ "github.com/go-sql-driver/mysql" 20 | _ "github.com/lib/pq" 21 | _ "github.com/mattn/go-sqlite3" 22 | _ "modernc.org/sqlite" 23 | ) 24 | 25 | const sqlite3DBFile = "orm_test.db" 26 | 27 | var ( 28 | // Sqlite3 Dialect 实例 29 | Sqlite3 = dialect.Sqlite3("sqlite3") 30 | 31 | // Sqlite Dialect 实例 32 | Sqlite = dialect.Sqlite3("sqlite") 33 | 34 | // Mysql Dialect 实例 35 | Mysql = dialect.Mysql("mysql") 36 | 37 | // Mariadb Dialect 实例 38 | Mariadb = dialect.Mariadb("mysql") 39 | 40 | // Postgres Dialect 实例 41 | Postgres = dialect.Postgres("postgres") 42 | ) 43 | 44 | // 以驱动为单的测试用例 45 | // 46 | // 部分设置项需要与 action 中的设置相同才能正常启动,比如端口号等。 47 | var cases = []struct { 48 | dsn string 49 | dialect orm.Dialect 50 | }{ 51 | { 52 | dsn: "./" + sqlite3DBFile + "?_fk=true&_loc=UTC", 53 | dialect: Sqlite3, 54 | }, 55 | { 56 | dsn: "./" + sqlite3DBFile + "?_fk=true&_loc=UTC", 57 | dialect: Sqlite, 58 | }, 59 | { 60 | dsn: "user=postgres host=127.0.0.1 password=postgres dbname=orm_test sslmode=disable", 61 | dialect: Postgres, 62 | }, 63 | { 64 | dsn: "root:root@/orm_test?charset=utf8&parseTime=true", 65 | dialect: Mysql, 66 | }, 67 | { 68 | dsn: "root:root@/orm_test?charset=utf8&parseTime=true", 69 | dialect: Mariadb, 70 | }, 71 | } 72 | 73 | // Driver 单个测试用例 74 | type Driver struct { 75 | *assert.Assertion 76 | DB *orm.DB 77 | DriverName string 78 | Name string 79 | dsn string 80 | } 81 | 82 | // Suite 测试用例管理 83 | type Suite struct { 84 | a *assert.Assertion 85 | drivers []*Driver 86 | } 87 | 88 | // NewSuite 初始化测试内容 89 | // 90 | // dialect 指定了当前需要测试的驱动,若未指定表示测试 flags 中的所有内容。 91 | func NewSuite(a *assert.Assertion, tablePrefix string, dialect ...core.Dialect) *Suite { 92 | s := &Suite{a: a} 93 | a.TB().Cleanup(func() { s.close() }) 94 | 95 | for _, c := range cases { 96 | name := c.dialect.Name() 97 | driver := c.dialect.DriverName() 98 | 99 | if len(dialect) > 0 && slices.IndexFunc(dialect, func(i core.Dialect) bool { return i.Name() == name && i.DriverName() == driver }) < 0 { 100 | continue 101 | } 102 | 103 | if len(flags) > 0 && slices.IndexFunc(flags, func(i *flagVar) bool { return i.Name == name && i.DriverName == driver }) < 0 { 104 | continue 105 | } 106 | 107 | a.TB().Logf("开始测试 %s:%s:%s\n", c.dialect.Name(), c.dialect.DriverName(), c.dsn) 108 | db, err := orm.NewDB(tablePrefix, c.dsn, c.dialect) 109 | a.NotError(err).NotNil(db) 110 | 111 | s.drivers = append(s.drivers, &Driver{ 112 | Assertion: a, 113 | DB: db, 114 | DriverName: driver, 115 | Name: name, 116 | dsn: c.dsn, 117 | }) 118 | } 119 | 120 | return s 121 | } 122 | 123 | func (s Suite) close() { 124 | for _, t := range s.drivers { 125 | t.NotError(t.DB.Close()) 126 | 127 | dn := t.DB.Dialect().DriverName() 128 | if dn != Sqlite3.DriverName() && dn != Sqlite.DriverName() { 129 | return 130 | } 131 | 132 | // sqlite3 删除数据库文件 133 | if _, err := os.Stat(sqlite3DBFile); err == nil || os.IsExist(err) { 134 | t.NotError(os.Remove(sqlite3DBFile)) 135 | } 136 | } 137 | } 138 | 139 | // Run 为每个数据库测试用例调用 f 进行测试 140 | func (s Suite) Run(f func(*Driver)) { 141 | for _, test := range s.drivers { 142 | f(test) 143 | } 144 | } 145 | -------------------------------------------------------------------------------- /internal/test/test_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package test 6 | 7 | import ( 8 | "testing" 9 | 10 | "github.com/issue9/assert/v4" 11 | "github.com/issue9/sliceutil" 12 | 13 | "github.com/issue9/orm/v6/core" 14 | ) 15 | 16 | func TestMain(m *testing.M) { 17 | Main(m) 18 | } 19 | 20 | func TestSuite_Run(t *testing.T) { 21 | a := assert.New(t, false) 22 | 23 | s := NewSuite(a, "") 24 | 25 | var size int 26 | s.Run(func(t *Driver) { 27 | a.NotNil(t). 28 | NotNil(t.DB). 29 | NotNil(t.DB.DB()). 30 | NotNil(t.DB.Dialect()). 31 | Equal(t.Assertion, a) 32 | size++ 33 | }) 34 | a.Equal(size, len(flags)) 35 | } 36 | 37 | func TestSuite_Run_withDialect(t *testing.T) { 38 | a := assert.New(t, false) 39 | 40 | // 不再限定 flags 41 | flags = []*flagVar{ 42 | {Name: "mysql", DriverName: "mysql"}, 43 | {Name: "sqlite3", DriverName: "sqlite3"}, 44 | {Name: "sqlite3", DriverName: "sqlite"}, 45 | {Name: "mariadb", DriverName: "mysql"}, 46 | {Name: "postgres", DriverName: "postgres"}, 47 | } 48 | 49 | // 通过参数限定了 dialect 50 | 51 | dialects := []core.Dialect{Sqlite3} 52 | s := NewSuite(a, "", dialects...) 53 | 54 | size := 0 55 | s.Run(func(t *Driver) { 56 | a.NotNil(t). 57 | NotNil(t.DB). 58 | NotNil(t.DB.DB()). 59 | NotNil(t.DB.Dialect()). 60 | Equal(t.Assertion, a) 61 | 62 | d := t.DB.Dialect() 63 | a.Equal(sliceutil.Count(dialects, func(i core.Dialect, _ int) bool { 64 | return i.Name() == d.Name() && i.DriverName() == d.DriverName() 65 | }), 1) 66 | 67 | size++ 68 | }) 69 | a.Equal(size, len(dialects)) 70 | } 71 | -------------------------------------------------------------------------------- /sqlbuilder/base.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package sqlbuilder 6 | 7 | import ( 8 | "context" 9 | "database/sql" 10 | 11 | "github.com/issue9/orm/v6/core" 12 | ) 13 | 14 | type ( 15 | baseStmt struct { 16 | engine core.Engine 17 | 18 | // err 用于保存在生成语句中的错误信息 19 | // 20 | // 一旦有错误生成,那么后续的调用需要保证该 err 值不会被覆盖, 21 | // 即所有可能改变 err 的方法中,都要先判断 err 是否为空, 22 | // 如果不为空,则应该立即退出函数。 23 | err error 24 | } 25 | 26 | queryStmt struct { 27 | SQLer 28 | baseStmt 29 | } 30 | 31 | execStmt struct { 32 | SQLer 33 | baseStmt 34 | } 35 | 36 | ddlStmt struct { 37 | DDLSQLer 38 | baseStmt 39 | } 40 | 41 | multipleDDLStmt []DDLSQLer 42 | ) 43 | 44 | func newQueryStmt(e core.Engine, sql SQLer) *queryStmt { 45 | return &queryStmt{ 46 | SQLer: sql, 47 | baseStmt: baseStmt{ 48 | engine: e, 49 | }, 50 | } 51 | } 52 | 53 | func newExecStmt(e core.Engine, sql SQLer) *execStmt { 54 | return &execStmt{ 55 | SQLer: sql, 56 | baseStmt: baseStmt{ 57 | engine: e, 58 | }, 59 | } 60 | } 61 | 62 | func newDDLStmt(e core.Engine, sql DDLSQLer) *ddlStmt { 63 | return &ddlStmt{ 64 | DDLSQLer: sql, 65 | baseStmt: baseStmt{ 66 | engine: e, 67 | }, 68 | } 69 | } 70 | 71 | func (stmt *baseStmt) Dialect() core.Dialect { return stmt.engine.Dialect() } 72 | 73 | func (stmt *baseStmt) Engine() core.Engine { return stmt.engine } 74 | 75 | func (stmt *baseStmt) Err() error { return stmt.err } 76 | 77 | func (stmt *baseStmt) Reset() { stmt.err = nil } 78 | 79 | func (stmt ddlStmt) Exec() error { return stmt.ExecContext(context.Background()) } 80 | 81 | func (stmt *ddlStmt) ExecContext(ctx context.Context) error { 82 | qs, err := stmt.DDLSQL() 83 | if err != nil { 84 | return err 85 | } 86 | 87 | for _, query := range qs { 88 | if _, err = stmt.Engine().ExecContext(ctx, query); err != nil { 89 | return err 90 | } 91 | } 92 | 93 | return nil 94 | } 95 | 96 | // CombineSQL 合并 [SQLer.SQL] 返回的 query 和 args 参数 97 | func (stmt *execStmt) CombineSQL() (query string, err error) { 98 | query, args, err := stmt.SQL() 99 | if err != nil { 100 | return "", err 101 | } 102 | 103 | return fillArgs(query, args) 104 | } 105 | 106 | func (stmt *execStmt) Exec() (sql.Result, error) { return stmt.ExecContext(context.Background()) } 107 | 108 | func (stmt *execStmt) ExecContext(ctx context.Context) (sql.Result, error) { 109 | query, args, err := stmt.SQL() 110 | if err != nil { 111 | return nil, err 112 | } 113 | 114 | return stmt.Engine().ExecContext(ctx, query, args...) 115 | } 116 | 117 | // Prepare 预编译语句 118 | // 119 | // 预编译语句,参数最好采用 [sql.NamedArg] 类型。 120 | // 在生成语句时,参数顺序会发生变化,如果采用 ? 的形式, 121 | // 用户需要自己处理参数顺序问题,而 [sql.NamedArg] 没有这些问题。 122 | func (stmt *execStmt) Prepare() (*core.Stmt, error) { return stmt.PrepareContext(context.Background()) } 123 | 124 | func (stmt *execStmt) PrepareContext(ctx context.Context) (*core.Stmt, error) { 125 | query, _, err := stmt.SQL() 126 | if err != nil { 127 | return nil, err 128 | } 129 | 130 | return stmt.Engine().PrepareContext(ctx, query) 131 | } 132 | 133 | func (stmt *queryStmt) Prepare() (*core.Stmt, error) { 134 | return stmt.PrepareContext(context.Background()) 135 | } 136 | 137 | // CombineSQL 将 [SQLer.SQL] 中返回的参数替换掉 query 中的占位符, 138 | // 形成一条完整的查询语句。 139 | func (stmt *queryStmt) CombineSQL() (query string, err error) { 140 | query, args, err := stmt.SQL() 141 | if err != nil { 142 | return "", err 143 | } 144 | 145 | return fillArgs(query, args) 146 | } 147 | 148 | func (stmt *queryStmt) PrepareContext(ctx context.Context) (*core.Stmt, error) { 149 | query, _, err := stmt.SQL() 150 | if err != nil { 151 | return nil, err 152 | } 153 | 154 | return stmt.Engine().PrepareContext(ctx, query) 155 | } 156 | 157 | func (stmt queryStmt) Query() (*sql.Rows, error) { return stmt.QueryContext(context.Background()) } 158 | 159 | func (stmt *queryStmt) QueryContext(ctx context.Context) (*sql.Rows, error) { 160 | query, args, err := stmt.SQL() 161 | if err != nil { 162 | return nil, err 163 | } 164 | 165 | return stmt.Engine().QueryContext(ctx, query, args...) 166 | } 167 | 168 | // MergeDDL 合并多个 [DDLSQLer] 对象 169 | func MergeDDL(ddl ...DDLSQLer) DDLSQLer { return multipleDDLStmt(ddl) } 170 | 171 | func (stmt multipleDDLStmt) DDLSQL() ([]string, error) { 172 | queries := make([]string, 0, len(stmt)) 173 | 174 | for _, d := range stmt { 175 | q, e := d.DDLSQL() 176 | if e != nil { 177 | return nil, e 178 | } 179 | queries = append(queries, q...) 180 | } 181 | 182 | return queries, nil 183 | } 184 | -------------------------------------------------------------------------------- /sqlbuilder/base_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package sqlbuilder_test 6 | 7 | import ( 8 | "testing" 9 | 10 | "github.com/issue9/assert/v4" 11 | 12 | "github.com/issue9/orm/v6/core" 13 | "github.com/issue9/orm/v6/internal/test" 14 | "github.com/issue9/orm/v6/sqlbuilder" 15 | ) 16 | 17 | func TestMergeDDL(t *testing.T) { 18 | a := assert.New(t, false) 19 | 20 | suite := test.NewSuite(a, "") 21 | 22 | suite.Run(func(t *test.Driver) { 23 | initDB(t) 24 | defer clearDB(t) 25 | 26 | ddl1 := sqlbuilder.CreateIndex(t.DB). 27 | Table("users"). 28 | Name("index_key"). 29 | Columns("id", "name") 30 | 31 | ddl2 := sqlbuilder.AddColumn(t.DB). 32 | Table("users"). 33 | Column("id", core.Int, true, false, false, nil) 34 | 35 | ddl3 := sqlbuilder.AddColumn(t.DB). 36 | Table("users"). 37 | Column("name", core.String, false, true, false, nil) 38 | 39 | ddl := sqlbuilder.MergeDDL(ddl1, ddl2, ddl3) 40 | queries, err := ddl.DDLSQL() 41 | a.NotError(err).NotEmpty(queries) 42 | }) 43 | } 44 | -------------------------------------------------------------------------------- /sqlbuilder/column.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package sqlbuilder 6 | 7 | import "github.com/issue9/orm/v6/core" 8 | 9 | // AddColumnStmt 添加列 10 | type AddColumnStmt struct { 11 | *ddlStmt 12 | 13 | table string 14 | column *core.Column 15 | } 16 | 17 | // AddColumn 声明一条添加列的语句 18 | func (sql *SQLBuilder) AddColumn() *AddColumnStmt { return AddColumn(sql.engine) } 19 | 20 | // AddColumn 声明一条添加列的语句 21 | func AddColumn(e core.Engine) *AddColumnStmt { 22 | stmt := &AddColumnStmt{} 23 | stmt.ddlStmt = newDDLStmt(e, stmt) 24 | return stmt 25 | } 26 | 27 | // Table 指定表名 28 | // 29 | // NOTE: 重复指定,会覆盖之前的。 30 | func (stmt *AddColumnStmt) Table(table string) *AddColumnStmt { 31 | stmt.table = table 32 | return stmt 33 | } 34 | 35 | // Column 添加列 36 | // 37 | // 参数信息可参考 [CreateTableStmt.Column] 38 | func (stmt *AddColumnStmt) Column(name string, p core.PrimitiveType, ai, nullable, hasDefault bool, def any, length ...int) *AddColumnStmt { 39 | if stmt.err != nil { 40 | return stmt 41 | } 42 | 43 | stmt.column, stmt.err = newColumn(name, p, ai, nullable, hasDefault, def, length...) 44 | return stmt 45 | } 46 | 47 | // DDLSQL 获取 SQL 语句以及对应的参数 48 | func (stmt *AddColumnStmt) DDLSQL() ([]string, error) { 49 | if stmt.err != nil { 50 | return nil, stmt.Err() 51 | } 52 | 53 | if stmt.table == "" { 54 | return nil, SyntaxError("ALTER TABLE ADD", "未指定表名") 55 | } 56 | 57 | if stmt.column == nil { 58 | return nil, SyntaxError("ALTER TABLE ADD", "未指定列") 59 | } 60 | 61 | if err := stmt.column.Check(); err != nil { 62 | return nil, err 63 | } 64 | 65 | typ, err := stmt.Dialect().SQLType(stmt.column) 66 | if err != nil { 67 | return nil, err 68 | } 69 | 70 | buf := core.NewBuilder("ALTER TABLE "). 71 | QuoteKey(stmt.table). 72 | WString(" ADD "). 73 | QuoteKey(stmt.column.Name). 74 | WBytes(' '). 75 | WString(typ) 76 | 77 | query, err := buf.String() 78 | if err != nil { 79 | return nil, err 80 | } 81 | return []string{query}, nil 82 | } 83 | 84 | // Reset 重置 85 | func (stmt *AddColumnStmt) Reset() *AddColumnStmt { 86 | stmt.baseStmt.Reset() 87 | stmt.table = "" 88 | stmt.column = nil 89 | return stmt 90 | } 91 | 92 | // DropColumnStmtHooker DropColumnStmt.DDLSQL 的钩子函数 93 | type DropColumnStmtHooker interface { 94 | DropColumnStmtHook(*DropColumnStmt) ([]string, error) 95 | } 96 | 97 | // DropColumnStmt 删除列 98 | type DropColumnStmt struct { 99 | *ddlStmt 100 | 101 | TableName string 102 | ColumnName string 103 | } 104 | 105 | // DropColumn 声明一条删除列的语句 106 | func (sql *SQLBuilder) DropColumn() *DropColumnStmt { 107 | return DropColumn(sql.engine) 108 | } 109 | 110 | // DropColumn 声明一条删除列的语句 111 | func DropColumn(e core.Engine) *DropColumnStmt { 112 | stmt := &DropColumnStmt{} 113 | stmt.ddlStmt = newDDLStmt(e, stmt) 114 | return stmt 115 | } 116 | 117 | // Table 指定表名。 118 | // 重复指定,会覆盖之前的。 119 | func (stmt *DropColumnStmt) Table(table string) *DropColumnStmt { 120 | stmt.TableName = table 121 | return stmt 122 | } 123 | 124 | // Column 指定需要删除的列 125 | // 重复指定,会覆盖之前的。 126 | func (stmt *DropColumnStmt) Column(col string) *DropColumnStmt { 127 | stmt.ColumnName = col 128 | return stmt 129 | } 130 | 131 | // DDLSQL 获取 SQL 语句以及对应的参数 132 | func (stmt *DropColumnStmt) DDLSQL() ([]string, error) { 133 | if stmt.err != nil { 134 | return nil, stmt.Err() 135 | } 136 | 137 | if stmt.TableName == "" { 138 | return nil, SyntaxError("DROP COLUMN", "未指定表名") 139 | } 140 | 141 | if hook, ok := stmt.Dialect().(DropColumnStmtHooker); ok { 142 | return hook.DropColumnStmtHook(stmt) 143 | } 144 | 145 | buf := core.NewBuilder("ALTER TABLE "). 146 | QuoteKey(stmt.TableName). 147 | WString(" DROP COLUMN "). 148 | QuoteKey(stmt.ColumnName) 149 | 150 | query, err := buf.String() 151 | if err != nil { 152 | return nil, err 153 | } 154 | return []string{query}, nil 155 | } 156 | 157 | // Reset 重置 158 | func (stmt *DropColumnStmt) Reset() *DropColumnStmt { 159 | stmt.baseStmt.Reset() 160 | stmt.TableName = "" 161 | stmt.ColumnName = "" 162 | return stmt 163 | } 164 | -------------------------------------------------------------------------------- /sqlbuilder/column_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package sqlbuilder_test 6 | 7 | import ( 8 | "testing" 9 | 10 | "github.com/issue9/assert/v4" 11 | 12 | "github.com/issue9/orm/v6/core" 13 | "github.com/issue9/orm/v6/internal/test" 14 | "github.com/issue9/orm/v6/sqlbuilder" 15 | ) 16 | 17 | var ( 18 | _ sqlbuilder.DDLStmt = &sqlbuilder.AddColumnStmt{} 19 | _ sqlbuilder.DDLStmt = &sqlbuilder.DropColumnStmt{} 20 | ) 21 | 22 | func TestColumn(t *testing.T) { 23 | a := assert.New(t, false) 24 | suite := test.NewSuite(a, "") 25 | 26 | suite.Run(func(t *test.Driver) { 27 | db := t.DB 28 | 29 | err := sqlbuilder.CreateTable(db). 30 | Table("users"). 31 | AutoIncrement("id", core.Int64). 32 | Exec() 33 | a.NotError(err) 34 | defer func() { 35 | err = sqlbuilder.DropTable(db).Table("users").Exec() 36 | a.NotError(err) 37 | }() 38 | 39 | addStmt := sqlbuilder.AddColumn(db) 40 | err = addStmt.Table("users"). 41 | Column("col1", core.Int, false, true, false, nil). 42 | Exec() 43 | a.NotError(err, "%s@%s", err, t.DriverName) 44 | 45 | dropStmt := sqlbuilder.DropColumn(db) 46 | err = dropStmt.Table("users"). 47 | Column("col1"). 48 | Exec() 49 | t.NotError(err, "%s@%s", err, t.DriverName) 50 | 51 | err = addStmt.Reset().Exec() 52 | a.ErrorString(err, "未指定表名") 53 | 54 | err = addStmt.Reset().Table("users").Exec() 55 | a.ErrorString(err, "未指定列") 56 | 57 | err = dropStmt.Reset().Exec() 58 | a.ErrorString(err, "未指定表名") 59 | }) 60 | 61 | // 添加主键 62 | suite.Run(func(t *test.Driver) { 63 | db := t.DB 64 | 65 | err := sqlbuilder.CreateTable(db). 66 | Table("users"). 67 | AutoIncrement("id", core.Int64). 68 | Column("name", core.String, false, false, false, nil). 69 | Exec() 70 | a.NotError(err) 71 | defer func() { 72 | err = sqlbuilder.DropTable(db).Table("users").Exec() 73 | a.NotError(err) 74 | }() 75 | 76 | // 已存在 77 | addStmt := sqlbuilder.AddColumn(db) 78 | err = addStmt.Table("users"). 79 | Column("id", core.Int, false, true, false, nil). 80 | Exec() 81 | a.Error(err, "%s@%s", err, t.DriverName) 82 | 83 | dropStmt := sqlbuilder.DropColumn(db) 84 | err = dropStmt.Table("users"). 85 | Column("id"). 86 | Exec() 87 | t.NotError(err, "%s@%s", err, t.DriverName) 88 | 89 | err = addStmt.Reset(). 90 | Table("users"). 91 | Column("id", core.Int, false, true, false, nil). 92 | Exec() 93 | a.NotError(err) 94 | }) 95 | } 96 | -------------------------------------------------------------------------------- /sqlbuilder/constraint_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package sqlbuilder_test 6 | 7 | import ( 8 | "testing" 9 | 10 | "github.com/issue9/assert/v4" 11 | 12 | "github.com/issue9/orm/v6/internal/test" 13 | "github.com/issue9/orm/v6/sqlbuilder" 14 | ) 15 | 16 | var ( 17 | _ sqlbuilder.DDLStmt = &sqlbuilder.DropConstraintStmt{} 18 | _ sqlbuilder.DDLStmt = &sqlbuilder.AddConstraintStmt{} 19 | ) 20 | 21 | func TestConstraint(t *testing.T) { 22 | a := assert.New(t, false) 23 | 24 | suite := test.NewSuite(a, "") 25 | 26 | suite.Run(func(t *test.Driver) { 27 | initDB(t) 28 | defer clearDB(t) 29 | 30 | addStmt := sqlbuilder.AddConstraint(t.DB) 31 | err := addStmt.Table("users"). 32 | Unique("u_user_name", "name"). 33 | Exec() 34 | t.NotError(err, "%s@%s", err, t.DriverName) 35 | 36 | // 删除约束 37 | dropStmt := sqlbuilder.DropConstraint(t.DB). 38 | Table("users"). 39 | Constraint("u_user_name") 40 | err = dropStmt.Exec() 41 | a.NotError(err, "%s@%s", err, t.DriverName) 42 | 43 | // 删除不存在的约束名 44 | err = dropStmt.Reset(). 45 | Table("users"). 46 | Constraint("u_user_name_not_exists___"). 47 | Exec() 48 | a.Error(err, "并未出错 @%s", t.DriverName) 49 | 50 | err = dropStmt.Reset().Exec() 51 | a.ErrorString(err, "未指定表名") 52 | 53 | err = dropStmt.Reset().Table("tbl").Exec() 54 | a.ErrorString(err, "未指定名称"). 55 | ErrorString(addStmt.Reset().Unique("", "name").Exec(), "未指定表名"). 56 | ErrorString(addStmt.Reset().Table("users").Unique("", "name").Exec(), "未指定名称") 57 | }) 58 | } 59 | 60 | func TestConstraint_Check(t *testing.T) { 61 | a := assert.New(t, false) 62 | suite := test.NewSuite(a, "") 63 | 64 | suite.Run(func(t *test.Driver) { 65 | initDB(t) 66 | defer clearDB(t) 67 | sb := t.DB.SQLBuilder() 68 | 69 | err := sb.AddConstraint(). 70 | Table("info"). 71 | Check("nick_not_null", "nickname IS NOT NULL"). 72 | Exec() 73 | t.NotError(err) 74 | 75 | err = sb.DropConstraint(). 76 | Table("info"). 77 | Constraint("nick_not_null"). 78 | Exec() 79 | a.NotError(err) 80 | }) 81 | } 82 | 83 | func TestConstraint_PK(t *testing.T) { 84 | a := assert.New(t, false) 85 | suite := test.NewSuite(a, "") 86 | 87 | suite.Run(func(t *test.Driver) { 88 | initDB(t) 89 | defer clearDB(t) 90 | sb := t.DB.SQLBuilder() 91 | 92 | // 已经存在主键,出错 93 | addStmt := sb.AddConstraint() 94 | err := addStmt.Table("info"). 95 | PK("info_pk", "tel"). 96 | Exec() 97 | t.Error(err) 98 | 99 | err = sb.DropConstraint(). 100 | Table("info"). 101 | PK("info_pk"). 102 | Exec() 103 | a.NotError(err) 104 | 105 | err = addStmt.Reset().Table("info"). 106 | PK("info_pk", "tel", "nickname"). 107 | Exec() 108 | t.NotError(err) 109 | }) 110 | 111 | suite.Run(func(t *test.Driver) { 112 | query := "CREATE TABLE info (uid BIGINT NOT NULL,CONSTRAINT test_pk PRIMARY KEY(uid))" 113 | _, err := t.DB.Exec(query) 114 | t.NotError(err) 115 | 116 | defer func() { 117 | err := sqlbuilder.DropTable(t.DB).Table("info").Exec() 118 | t.NotError(err) 119 | }() 120 | 121 | // 已经存在主键,出错 122 | addStmt := sqlbuilder.AddConstraint(t.DB) 123 | err = addStmt.Table("info"). 124 | PK("info_pk", "uid"). 125 | Exec() 126 | t.Error(err) 127 | 128 | err = sqlbuilder.DropConstraint(t.DB). 129 | Table("info"). 130 | PK("test_pk"). 131 | Exec() 132 | a.NotError(err) 133 | 134 | err = addStmt.Reset().Table("info"). 135 | PK("info_pk", "uid"). 136 | Exec() 137 | t.NotError(err) 138 | }) 139 | } 140 | 141 | func TestConstraint_FK(t *testing.T) { 142 | a := assert.New(t, false) 143 | suite := test.NewSuite(a, "") 144 | 145 | suite.Run(func(t *test.Driver) { 146 | initDB(t) 147 | defer clearDB(t) 148 | 149 | // 已经存在主键,出错 150 | addStmt := sqlbuilder.AddConstraint(t.DB) 151 | err := addStmt.Table("info"). 152 | FK("info_fk", "uid", "users", "id", "CASCADE", "CASCADE"). 153 | Exec() 154 | t.Error(err) 155 | 156 | err = sqlbuilder.DropConstraint(t.DB). 157 | Table("info"). 158 | Constraint("info_fk"). 159 | Exec() 160 | a.NotError(err) 161 | 162 | err = addStmt.Reset().Table("info"). 163 | FK("info_fk", "uid", "users", "id", "CASCADE", "CASCADE"). 164 | Exec() 165 | t.NotError(err) 166 | }) 167 | } 168 | -------------------------------------------------------------------------------- /sqlbuilder/delete.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package sqlbuilder 6 | 7 | import "github.com/issue9/orm/v6/core" 8 | 9 | // DeleteStmt 表示删除操作的 SQL 语句 10 | type DeleteStmt struct { 11 | *execStmt 12 | *deleteWhere 13 | 14 | table string 15 | } 16 | 17 | type deleteWhere = WhereStmtOf[*DeleteStmt] 18 | 19 | // Delete 生成删除语句 20 | func (sql *SQLBuilder) Delete() *DeleteStmt { return Delete(sql.engine) } 21 | 22 | // Delete 声明一条删除语句 23 | func Delete(e core.Engine) *DeleteStmt { 24 | stmt := &DeleteStmt{} 25 | stmt.execStmt = newExecStmt(e, stmt) 26 | stmt.deleteWhere = NewWhereStmtOf(stmt) 27 | 28 | return stmt 29 | } 30 | 31 | // Table 指定表名 32 | func (stmt *DeleteStmt) Table(table string) *DeleteStmt { 33 | stmt.table = table 34 | return stmt 35 | } 36 | 37 | // SQL 获取 SQL 语句,以及其参数对应的具体值 38 | func (stmt *DeleteStmt) SQL() (string, []any, error) { 39 | if stmt.err != nil { 40 | return "", nil, stmt.Err() 41 | } 42 | 43 | if stmt.table == "" { 44 | return "", nil, SyntaxError("DELETE", "未指定表名") 45 | } 46 | 47 | query, args, err := stmt.WhereStmt().SQL() 48 | if err != nil { 49 | return "", nil, err 50 | } 51 | 52 | q, err := core.NewBuilder("DELETE FROM "). 53 | QuoteKey(stmt.table). 54 | WString(" WHERE "). 55 | WString(query). 56 | String() 57 | if err != nil { 58 | return "", nil, err 59 | } 60 | return q, args, nil 61 | } 62 | 63 | // Reset 重置语句 64 | func (stmt *DeleteStmt) Reset() *DeleteStmt { 65 | stmt.baseStmt.Reset() 66 | stmt.table = "" 67 | stmt.WhereStmt().Reset() 68 | return stmt 69 | } 70 | 71 | // Delete 删除指定条件的内容 72 | func (stmt *WhereStmt) Delete(e core.Engine) *DeleteStmt { 73 | del := Delete(e) 74 | del.deleteWhere.w = stmt 75 | return del 76 | } 77 | -------------------------------------------------------------------------------- /sqlbuilder/delete_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package sqlbuilder_test 6 | 7 | import ( 8 | "testing" 9 | 10 | "github.com/issue9/assert/v4" 11 | 12 | "github.com/issue9/orm/v6/internal/sqltest" 13 | "github.com/issue9/orm/v6/internal/test" 14 | "github.com/issue9/orm/v6/sqlbuilder" 15 | ) 16 | 17 | var _ sqlbuilder.ExecStmt = &sqlbuilder.DeleteStmt{} 18 | 19 | func TestDelete_Exec(t *testing.T) { 20 | a := assert.New(t, false) 21 | suite := test.NewSuite(a, "") 22 | 23 | suite.Run(func(t *test.Driver) { 24 | initDB(t) 25 | defer clearDB(t) 26 | 27 | sql := sqlbuilder.Delete(t.DB). 28 | Table("users"). 29 | Where("id=?", 1) 30 | _, err := sql.Exec() 31 | a.NotError(err) 32 | 33 | sql.Reset() 34 | sql.Table("users"). 35 | Where("id=?"). 36 | Or("name=?", "xx") 37 | _, err = sql.Exec() 38 | a.ErrorString(err, "列与值不匹配") 39 | 40 | sql.Reset() 41 | _, err = sql.Exec() 42 | a.ErrorString(err, "未指定表名") 43 | }) 44 | } 45 | 46 | func TestWhereStmt_Delete(t *testing.T) { 47 | a := assert.New(t, false) 48 | suite := test.NewSuite(a, "") 49 | 50 | suite.Run(func(t *test.Driver) { 51 | initDB(t) 52 | defer clearDB(t) 53 | 54 | sql := sqlbuilder.Where().And("id=?", 1). 55 | Delete(t.DB). 56 | Table("users") 57 | _, err := sql.Exec() 58 | a.NotError(err) 59 | 60 | query, args, err := sql.SQL() 61 | a.NotError(err). 62 | Equal(args, []any{1}) 63 | sqltest.Equal(a, query, "DELETE FROM {users} WHERE id=?") 64 | }) 65 | } 66 | -------------------------------------------------------------------------------- /sqlbuilder/index.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package sqlbuilder 6 | 7 | import "github.com/issue9/orm/v6/core" 8 | 9 | // CreateIndexStmt 创建索引的语句 10 | type CreateIndexStmt struct { 11 | *ddlStmt 12 | table string 13 | name string // 索引名称 14 | cols []string // 索引列 15 | typ core.IndexType 16 | } 17 | 18 | // CreateIndex 生成创建索引的语句 19 | func (sql *SQLBuilder) CreateIndex() *CreateIndexStmt { 20 | return CreateIndex(sql.engine) 21 | } 22 | 23 | // CreateIndex 声明一条 CreateIndexStmt 语句 24 | func CreateIndex(e core.Engine) *CreateIndexStmt { 25 | stmt := &CreateIndexStmt{typ: core.IndexDefault} 26 | stmt.ddlStmt = newDDLStmt(e, stmt) 27 | 28 | return stmt 29 | } 30 | 31 | // Table 指定表名 32 | func (stmt *CreateIndexStmt) Table(tbl string) *CreateIndexStmt { 33 | stmt.table = tbl 34 | return stmt 35 | } 36 | 37 | // Name 指定索引名 38 | func (stmt *CreateIndexStmt) Name(index string) *CreateIndexStmt { 39 | stmt.name = index 40 | return stmt 41 | } 42 | 43 | // Type 指定索引类型 44 | func (stmt *CreateIndexStmt) Type(t core.IndexType) *CreateIndexStmt { 45 | stmt.typ = t 46 | return stmt 47 | } 48 | 49 | // Columns 列名 50 | func (stmt *CreateIndexStmt) Columns(col ...string) *CreateIndexStmt { 51 | if stmt.err != nil { 52 | return stmt 53 | } 54 | 55 | if stmt.cols == nil { 56 | stmt.cols = col 57 | return stmt 58 | } 59 | 60 | stmt.cols = append(stmt.cols, col...) 61 | return stmt 62 | } 63 | 64 | // DDLSQL 生成 SQL 语句 65 | func (stmt *CreateIndexStmt) DDLSQL() ([]string, error) { 66 | if stmt.err != nil { 67 | return nil, stmt.Err() 68 | } 69 | 70 | if stmt.table == "" { 71 | return nil, SyntaxError("CREATE INDEX", "未指定表名") 72 | } 73 | 74 | if len(stmt.cols) == 0 { 75 | return nil, SyntaxError("CREATE INDEX", "未指定列") 76 | } 77 | 78 | var builder *core.Builder 79 | 80 | if stmt.typ == core.IndexDefault { 81 | builder = core.NewBuilder("CREATE INDEX ") 82 | } else { 83 | builder = core.NewBuilder("CREATE UNIQUE INDEX ") 84 | } 85 | 86 | builder.WString(stmt.name). 87 | WString(" ON "). 88 | QuoteKey(stmt.table). 89 | WBytes('(') 90 | for _, col := range stmt.cols { 91 | builder.QuoteKey(col). 92 | WBytes(',') 93 | } 94 | builder.TruncateLast(1).WBytes(')') 95 | 96 | query, err := builder.String() 97 | if err != nil { 98 | return nil, err 99 | } 100 | return []string{query}, nil 101 | } 102 | 103 | // Reset 重置 104 | func (stmt *CreateIndexStmt) Reset() *CreateIndexStmt { 105 | stmt.baseStmt.Reset() 106 | stmt.table = "" 107 | stmt.cols = stmt.cols[:0] 108 | stmt.name = "" 109 | stmt.typ = core.IndexDefault 110 | 111 | return stmt 112 | } 113 | 114 | // DropIndexStmt 删除索引 115 | type DropIndexStmt struct { 116 | *ddlStmt 117 | tableName string 118 | indexName string 119 | } 120 | 121 | // DropIndex 生成删除索引的语句 122 | func (sql *SQLBuilder) DropIndex() *DropIndexStmt { return DropIndex(sql.engine) } 123 | 124 | // DropIndex 声明一条 DropIndexStmt 语句 125 | func DropIndex(e core.Engine) *DropIndexStmt { 126 | stmt := &DropIndexStmt{} 127 | stmt.ddlStmt = newDDLStmt(e, stmt) 128 | return stmt 129 | } 130 | 131 | // Table 指定表名 132 | func (stmt *DropIndexStmt) Table(tbl string) *DropIndexStmt { 133 | stmt.tableName = tbl 134 | return stmt 135 | } 136 | 137 | // Name 指定索引名 138 | func (stmt *DropIndexStmt) Name(col string) *DropIndexStmt { 139 | stmt.indexName = col 140 | return stmt 141 | } 142 | 143 | // DDLSQL 生成 SQL 语句 144 | func (stmt *DropIndexStmt) DDLSQL() ([]string, error) { 145 | q, err := stmt.Dialect().DropIndexSQL(stmt.tableName, stmt.indexName) 146 | if err != nil { 147 | return nil, err 148 | } 149 | return []string{q}, nil 150 | } 151 | 152 | // Reset 重置 153 | func (stmt *DropIndexStmt) Reset() *DropIndexStmt { 154 | stmt.baseStmt.Reset() 155 | stmt.tableName = "" 156 | stmt.indexName = "" 157 | return stmt 158 | } 159 | -------------------------------------------------------------------------------- /sqlbuilder/index_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package sqlbuilder_test 6 | 7 | import ( 8 | "testing" 9 | 10 | "github.com/issue9/assert/v4" 11 | 12 | "github.com/issue9/orm/v6/core" 13 | "github.com/issue9/orm/v6/internal/test" 14 | "github.com/issue9/orm/v6/sqlbuilder" 15 | ) 16 | 17 | var ( 18 | _ sqlbuilder.DDLStmt = &sqlbuilder.CreateIndexStmt{} 19 | _ sqlbuilder.DDLStmt = &sqlbuilder.DropIndexStmt{} 20 | ) 21 | 22 | func TestIndex(t *testing.T) { 23 | a := assert.New(t, false) 24 | suite := test.NewSuite(a, "") 25 | 26 | suite.Run(func(t *test.Driver) { 27 | initDB(t) 28 | defer clearDB(t) 29 | 30 | createStmt := sqlbuilder.CreateIndex(t.DB). 31 | Table("users"). 32 | Name("index_key"). 33 | Columns("id", "name") 34 | err := createStmt.Exec() 35 | t.NotError(err) 36 | 37 | // 同名约束名,应该会出错 38 | createStmt.Reset() 39 | err = createStmt.Table("users"). 40 | Name("index_key"). 41 | Columns("id", "name"). 42 | Exec() 43 | t.Error(err) 44 | 45 | // 唯一约束 46 | createStmt.Reset() 47 | err = createStmt.Table("users"). 48 | Name("index_unique_key"). 49 | Type(core.IndexUnique). 50 | Columns("id", "name"). 51 | Exec() 52 | t.NotError(err) 53 | 54 | dropStmt := sqlbuilder.DropIndex(t.DB). 55 | Table("users"). 56 | Name("index_key") 57 | err = dropStmt.Exec() 58 | t.NotError(err) 59 | 60 | // 不存在的索引 61 | dropStmt.Reset() 62 | err = dropStmt.Table("users"). 63 | Name("index_key"). 64 | Exec() 65 | a.Error(err) 66 | 67 | dropStmt.Reset() 68 | err = dropStmt.Table("users"). 69 | Name("index_unique_key"). 70 | Exec() 71 | t.NotError(err, "cc") 72 | 73 | createStmt.Reset() 74 | a.ErrorString(createStmt.Exec(), "未指定表名") 75 | 76 | createStmt.Reset() 77 | createStmt.Table("test") 78 | a.ErrorString(createStmt.Exec(), "未指定列") 79 | 80 | dropStmt.Reset() 81 | dropStmt.Table("test") 82 | a.ErrorString(dropStmt.Exec(), "未指定列") 83 | }) 84 | } 85 | -------------------------------------------------------------------------------- /sqlbuilder/insert_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package sqlbuilder_test 6 | 7 | import ( 8 | "database/sql" 9 | "testing" 10 | 11 | "github.com/issue9/assert/v4" 12 | 13 | "github.com/issue9/orm/v6/core" 14 | "github.com/issue9/orm/v6/internal/test" 15 | "github.com/issue9/orm/v6/sqlbuilder" 16 | ) 17 | 18 | var _ sqlbuilder.SQLer = &sqlbuilder.InsertStmt{} 19 | 20 | func TestInsert(t *testing.T) { 21 | a := assert.New(t, false) 22 | s := test.NewSuite(a, "") 23 | tableName := "users" 24 | 25 | s.Run(func(t *test.Driver) { 26 | err := sqlbuilder.CreateTable(t.DB). 27 | Table(tableName). 28 | AutoIncrement("id", core.Int64). 29 | Column("name", core.String, false, false, true, "def-name", 20). 30 | Exec() 31 | a.NotError(err) 32 | defer func() { 33 | err := sqlbuilder.DropTable(t.DB). 34 | Table(tableName). 35 | Exec() 36 | a.NotError(err) 37 | }() 38 | 39 | i := sqlbuilder.Insert(t.DB).Table(tableName) 40 | a.NotNil(i) 41 | 42 | i.Columns("id", "name").Values(10, "name10").Values(11, "name11") 43 | _, err = i.Exec() 44 | a.NotError(err) 45 | 46 | i.Reset().Table("tb1"). 47 | Table(tableName). 48 | KeyValue("id", 20). 49 | KeyValue("name", "name20") 50 | _, err = i.Exec() 51 | a.NotError(err) 52 | 53 | i.Reset().Columns("id", "name") 54 | _, err = i.Exec() 55 | a.ErrorString(err, "未指定表名") 56 | 57 | i.Reset().Table(tableName).Columns("id", "name").Values("100") 58 | _, err = i.Exec() 59 | a.ErrorString(err, "列与值不匹配") 60 | 61 | // default value 62 | _, err = i.Reset().Table(tableName).Exec() 63 | a.NotError(err) 64 | }) 65 | } 66 | 67 | func TestInsert_NamedArgs(t *testing.T) { 68 | a := assert.New(t, false) 69 | s := test.NewSuite(a, "") 70 | tableName := "users" 71 | 72 | s.Run(func(t *test.Driver) { 73 | err := sqlbuilder.CreateTable(t.DB). 74 | Table(tableName). 75 | AutoIncrement("id", core.Int64). 76 | Column("name", core.String, false, false, false, nil, 20). 77 | Exec() 78 | a.NotError(err) 79 | defer func() { 80 | err := sqlbuilder.DropTable(t.DB). 81 | Table(tableName). 82 | Exec() 83 | a.NotError(err) 84 | }() 85 | 86 | i := sqlbuilder.Insert(t.DB).Table(tableName) 87 | i.Reset().Table(tableName). 88 | Columns("id", "name"). 89 | Values(sql.Named("id", 1), sql.Named("name", "name1")) 90 | _, err = i.Exec() 91 | t.NotError(err) 92 | 93 | // 预编译 94 | stmt, err := i.Prepare() 95 | a.NotError(err).NotNil(stmt) 96 | _, err = stmt.Exec(sql.Named("id", 2), sql.Named("name", "name2")) 97 | a.NotError(err) 98 | _, err = stmt.Exec(sql.Named("id", 3), sql.Named("name", "name3")) 99 | a.NotError(err) 100 | 101 | // 部分参数类型不正确 102 | _, err = stmt.Exec(sql.Named("id", 4), "name4") 103 | a.Error(err) 104 | 105 | // 参数类型不正确 106 | _, err = stmt.Exec(5, "name5") 107 | a.Error(err) 108 | }) 109 | } 110 | -------------------------------------------------------------------------------- /sqlbuilder/parser.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package sqlbuilder 6 | 7 | import ( 8 | "bufio" 9 | "database/sql" 10 | "fmt" 11 | "strings" 12 | "unicode" 13 | 14 | "github.com/issue9/orm/v6/core" 15 | ) 16 | 17 | var quoteReplacer = strings.NewReplacer("{", "", "}", "") 18 | 19 | // 将参数替换成实际的值 20 | func fillArgs(query string, args []any) (string, error) { 21 | // 获取所有命名参数列表 22 | named := make(map[string]any, len(args)) 23 | for _, arg := range args { 24 | if n, ok := arg.(sql.NamedArg); ok { 25 | named[n.Name] = n.Value 26 | continue 27 | } 28 | } 29 | 30 | w := func(builder *core.Builder, name string) error { 31 | v, found := named[name] 32 | if !found { 33 | return fmt.Errorf("不存在该名称的参数:%s", name) 34 | } 35 | builder.Quote(fmt.Sprint(v), '\'', '\'') 36 | return nil 37 | } 38 | 39 | builder := core.NewBuilder("") 40 | var index int 41 | start := -1 42 | for i, c := range query { 43 | switch { 44 | case c == '@': 45 | start = i + 1 46 | case start != -1 && !unicode.IsLetter(c): 47 | if err := w(builder, query[start:i]); err != nil { 48 | return "", err 49 | } 50 | builder.WRunes(c) // 当前的字符不能丢 51 | start = -1 52 | index++ 53 | case start == -1: 54 | if c == '?' { 55 | builder.Quote(fmt.Sprint(args[index]), '\'', '\'') 56 | index++ 57 | } else { 58 | builder.WRunes(c) 59 | } 60 | } 61 | } 62 | 63 | if start > -1 { 64 | if err := w(builder, query[start:]); err != nil { 65 | return "", err 66 | } 67 | } 68 | 69 | return builder.String() 70 | } 71 | 72 | // 从表达式中获取列的名称 73 | // 74 | // 如果不存在别名,则取其列名或是整个表达式作为别名。 75 | // - => * 76 | // table.* => * 77 | // table.col => {col} 78 | // table.col as col => {col} 79 | // sum(table.count) as cnt ==> {cnt} 80 | // func1(func2(table.col1),table.col2) as fn1 ==> {fn1} 81 | // count({table.*}) => {count(table.*)} 82 | func getColumnName(expr string) string { 83 | if expr == "*" { 84 | return expr 85 | } 86 | 87 | s := bufio.NewScanner(strings.NewReader(expr)) 88 | s.Split(splitWithAS) 89 | 90 | var name string 91 | for s.Scan() { 92 | name = s.Text() 93 | } 94 | 95 | if len(name) == 0 || name == "*" { 96 | return name 97 | } 98 | 99 | // 尽量取列名部分作为别名,如果包含了函数信息, 100 | // 则将整个表达式作为别名。 101 | var deep, start int 102 | for i, b := range name { 103 | switch { 104 | case b == '{': 105 | deep++ 106 | case b == '}': 107 | deep-- 108 | case b == '.' && deep == 0: 109 | start = i 110 | case b == '(': // 包含函数信息,则将整个表达式作为别名 111 | return "{" + quoteReplacer.Replace(name) + "}" 112 | } 113 | } 114 | 115 | if start > 0 { 116 | name = name[start+1:] 117 | } 118 | 119 | if name == "*" || name[0] == '{' { 120 | return name 121 | } 122 | 123 | return "{" + name + "}" 124 | } 125 | 126 | func splitWithAS(data []byte, atEOF bool) (advance int, token []byte, err error) { 127 | var start, deep int 128 | var b byte 129 | 130 | // 去掉行首的空格 131 | for start, b = range data { 132 | if !unicode.IsSpace(rune(b)) { 133 | break 134 | } 135 | } 136 | 137 | // 找到第一个 AS 字符串 138 | for i, b := range data { 139 | if b == '{' { 140 | deep++ 141 | continue 142 | } 143 | 144 | if b == '}' { 145 | deep-- 146 | continue 147 | } 148 | 149 | if deep != 0 { 150 | continue 151 | } 152 | 153 | if !unicode.IsSpace(rune(b)) { 154 | continue 155 | } 156 | 157 | if len(data) <= i+3 { 158 | break 159 | } 160 | 161 | b1 := data[i+1] 162 | b2 := data[i+2] 163 | b3 := data[i+3] 164 | if (b1 == 'a' || b1 == 'A') && 165 | (b2 == 's' || b2 == 'S') && 166 | unicode.IsSpace(rune(b3)) { 167 | return i + 4, data[start:i], nil 168 | } 169 | } 170 | 171 | if atEOF && len(data) > start { 172 | return len(data), data[start:], nil 173 | } 174 | 175 | return start, nil, nil 176 | } 177 | -------------------------------------------------------------------------------- /sqlbuilder/parser_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package sqlbuilder 6 | 7 | import ( 8 | "bufio" 9 | "database/sql" 10 | "testing" 11 | 12 | "github.com/issue9/assert/v4" 13 | 14 | "github.com/issue9/orm/v6/internal/sqltest" 15 | ) 16 | 17 | var _ bufio.SplitFunc = splitWithAS 18 | 19 | func TestFillArgs(t *testing.T) { 20 | a := assert.New(t, false) 21 | 22 | var data = []*struct { 23 | query string 24 | args []any 25 | output string 26 | err bool 27 | }{ 28 | { 29 | query: "select * from tbl", 30 | args: []any{}, 31 | output: "select * from tbl", 32 | }, 33 | { 34 | query: "select * from tbl where id=?", 35 | args: []any{1}, 36 | output: "select * from tbl where id='1'", 37 | }, 38 | { 39 | query: "select * from tbl where id=? and name=?", 40 | args: []any{1, "n"}, 41 | output: "select * from tbl where id='1' and name='n'", 42 | }, 43 | { 44 | query: "select * from tbl where id=? and name=@name", 45 | args: []any{1, sql.Named("name", "n")}, 46 | output: "select * from tbl where id='1' and name='n'", 47 | }, 48 | { 49 | query: "select * from tbl where id=? and name=@name and age>?", 50 | args: []any{1, sql.Named("name", "n"), 18}, 51 | output: "select * from tbl where id='1' and name='n' and age>'18'", 52 | }, 53 | { // 类型不匹配 54 | query: "select * from tbl where id=? and name=@name", 55 | args: []any{1, "n"}, 56 | err: true, 57 | }, 58 | { // 类型不匹配 59 | query: "select * from tbl where id=? and name=@name and age>@age", 60 | args: []any{1, "n", sql.Named("age", 18)}, 61 | err: true, 62 | }, 63 | { // 名称不存在 64 | query: "select * from tbl where id=? and name=@name", 65 | args: []any{1, sql.Named("not-exists", "n")}, 66 | err: true, 67 | }, 68 | } 69 | 70 | for index, item := range data { 71 | output, err := fillArgs(item.query, item.args) 72 | if item.err { 73 | a.Error(err, "%s@%d", err, index). 74 | Empty(output) 75 | continue 76 | } 77 | 78 | a.NotError(err, "%s@%d", err, index) 79 | sqltest.Equal(a, output, item.output) 80 | } 81 | } 82 | 83 | func TestGetColumnName(t *testing.T) { 84 | a := assert.New(t, false) 85 | 86 | var data = []*struct { 87 | input string 88 | output string 89 | }{ 90 | { 91 | input: "", 92 | output: "", 93 | }, 94 | { 95 | input: "table.*", 96 | output: "*", 97 | }, 98 | { 99 | input: "{table}.*", 100 | output: "*", 101 | }, 102 | { 103 | input: "{table}.{as}", 104 | output: "{as}", 105 | }, 106 | { // 多个 as 107 | input: "table.{as} as {as}", 108 | output: "{as}", 109 | }, 110 | { 111 | input: "count({table}.*) as cnt", 112 | output: "{cnt}", 113 | }, 114 | { // 别名中包含 AS 115 | input: "count({table}.*) as {col as name}", 116 | output: "{col as name}", 117 | }, 118 | { 119 | input: "count({table}.*) as {count\t name}", 120 | output: "{count\t name}", 121 | }, 122 | { // 采用 \t 分隔 123 | input: "count({table}.*)\tas\tcnt", 124 | output: "{cnt}", 125 | }, 126 | { // 采用 \t、\n 混合 127 | input: "count({table}.*)\tas\ncnt", 128 | output: "{cnt}", 129 | }, 130 | { // 采用 \t 与空格混合 131 | input: "count({table}.*) \tas\t cnt", 132 | output: "{cnt}", 133 | }, 134 | { 135 | input: "sum(count({table}.*)) as cnt", 136 | output: "{cnt}", 137 | }, 138 | { // 整个内容作为列名 139 | input: "count({table}.*)", 140 | output: "{count(table.*)}", 141 | }, 142 | { 143 | input: "sum(count({table}.*)) as cnt", 144 | output: "{cnt}", 145 | }, 146 | { 147 | input: "sum(count({table}.as)) as {as}", 148 | output: "{as}", 149 | }, 150 | { 151 | input: "{table}.{as} as {as}", 152 | output: "{as}", 153 | }, 154 | { 155 | input: "{table}.{as} as 列名1", 156 | output: "{列名1}", 157 | }, 158 | { 159 | input: "{table}.{as} as {列名1}", 160 | output: "{列名1}", 161 | }, 162 | } 163 | 164 | for index, item := range data { 165 | col := getColumnName(item.input) 166 | a.Equal(col, item.output, "not equal @%d v1:%v,v2:%v", index, col, item.output) 167 | } 168 | } 169 | -------------------------------------------------------------------------------- /sqlbuilder/sqlbuilder.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | // Package sqlbuilder 提供一套通过字符串拼接来构成 SQL 语句的工具 6 | // 7 | // sqlbuilder 提供了部分 *Hooker 的接口, 8 | // 用于处理大部分数据都有标准实现而只有某个数据库采用了非标准模式的。 9 | package sqlbuilder 10 | 11 | import ( 12 | "context" 13 | "database/sql" 14 | "fmt" 15 | 16 | "github.com/issue9/orm/v6/core" 17 | ) 18 | 19 | type ( 20 | // SQLBuilder 提供了 sqlbuilder 下的各类语句的创建方法 21 | SQLBuilder struct { 22 | engine core.Engine 23 | } 24 | 25 | // SQLer 定义 SQL 语句的基本接口 26 | SQLer interface { 27 | // SQL 将当前实例转换成 SQL 语句返回 28 | // 29 | // query 表示 SQL 语句,而 args 表示语句各个参数占位符对应的参数值。 30 | SQL() (query string, args []any, err error) 31 | } 32 | 33 | // DDLSQLer SQL 中 DDL 语句的基本接口 34 | // 35 | // 大部分数据的 DDL 操作是有多条语句组成,比如 CREATE TABLE 36 | // 可能包含了额外的定义信息。 37 | DDLSQLer interface { 38 | DDLSQL() ([]string, error) 39 | } 40 | 41 | ExecStmt interface { 42 | SQLer 43 | Prepare() (*core.Stmt, error) 44 | PrepareContext(ctx context.Context) (*core.Stmt, error) 45 | Exec() (sql.Result, error) 46 | ExecContext(ctx context.Context) (sql.Result, error) 47 | } 48 | 49 | DDLStmt interface { 50 | DDLSQLer 51 | Exec() error 52 | ExecContext(ctx context.Context) error 53 | } 54 | ) 55 | 56 | // New 声明 SQLBuilder 实例 57 | // 58 | // tablePrefix 表名前缀; 59 | func New(e core.Engine) *SQLBuilder { 60 | return &SQLBuilder{engine: e} 61 | } 62 | 63 | // SyntaxError 返回语法错误的信息 64 | // 65 | // typ 表示语句的类型,比如 SELECT、UPDATE 等; 66 | // msg 为具体的错误信息; 67 | func SyntaxError(typ string, msg any) error { 68 | return fmt.Errorf("在 %s 语句中存在语法错误 %s", typ, msg) 69 | } 70 | -------------------------------------------------------------------------------- /sqlbuilder/sqlbuilder_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package sqlbuilder_test 6 | 7 | import ( 8 | "testing" 9 | "time" 10 | 11 | "github.com/issue9/orm/v6/core" 12 | "github.com/issue9/orm/v6/internal/test" 13 | "github.com/issue9/orm/v6/sqlbuilder" 14 | ) 15 | 16 | func TestMain(m *testing.M) { 17 | test.Main(m) 18 | } 19 | 20 | // user 需要与 initDB 中的 users 表中的字段相同 21 | type user struct { 22 | ID int64 `orm:"name(id);ai"` 23 | Name string `orm:"name(name);len(20)"` 24 | Age int `orm:"name(age)"` 25 | Version int64 `orm:"name(version);default(0)"` 26 | } 27 | 28 | func (u *user) ApplyModel(m *core.Model) error { 29 | m.Name = "users" 30 | return nil 31 | } 32 | 33 | func initDB(t *test.Driver) { 34 | t.Assertion.TB().Helper() 35 | 36 | creator := sqlbuilder.CreateTable(t.DB). 37 | Table("users"). 38 | AutoIncrement("id", core.Int64). 39 | Column("name", core.String, false, false, false, nil, 20). 40 | Column("age", core.Int, false, true, false, nil). 41 | Column("version", core.Int64, false, false, true, 0). 42 | Unique("unique_users_id", "id") 43 | err := creator.Exec() 44 | t.NotError(err, "%s@%s", err, t.DriverName) 45 | 46 | creator.Reset().Table("info"). 47 | Column("uid", core.Int64, false, false, false, nil). 48 | Column("tel", core.String, false, false, false, nil, 11). 49 | Column("nickname", core.String, false, false, false, nil, 20). 50 | Column("address", core.String, false, false, false, nil, 1024). 51 | Column("birthday", core.Time, false, false, true, time.Time{}). 52 | PK("info_pk", "tel", "nickname"). 53 | ForeignKey("info_fk", "uid", "users", "id", "CASCADE", "CASCADE") 54 | err = creator.Exec() 55 | t.NotError(err) 56 | 57 | sql := sqlbuilder.Insert(t.DB). 58 | Columns("name", "age"). 59 | Table("users"). 60 | Values("1", 1). 61 | Values("2", 2) 62 | _, err = sql.Exec() 63 | t.NotError(err, "%s@%s", err, t.DriverName) 64 | 65 | stmt, err := sql.Prepare() 66 | t.NotError(err, "%s@%s", err, t.DriverName). 67 | NotNil(stmt, "not nil @%s", t.DriverName) 68 | 69 | _, err = stmt.Exec("3", 3, "4", 4) 70 | t.NotError(err, "%s@%s", err, t.DriverName) 71 | _, err = stmt.Exec("5", 6, "6", 6) 72 | t.NotError(err, "%s@%s", err, t.DriverName) 73 | 74 | sql.Reset() 75 | sql.Table("users"). 76 | Columns("name"). 77 | Values("7") 78 | id, err := sql.LastInsertID("id") 79 | t.NotError(err, "%s@%s", err, t.DriverName). 80 | Equal(id, 7, "%d != %d @ %s", id, 7, t.DriverName) 81 | 82 | // 多行插入,不能拿到 lastInsertID 83 | sql.Table("users"). 84 | Columns("name"). 85 | Values("8"). 86 | Values("9") 87 | id, err = sql.LastInsertID("id") 88 | t.Error(err, "%s@%s", err, t.DriverName). 89 | Empty(id, "not empty @%s", t.DriverName) 90 | } 91 | 92 | func clearDB(t *test.Driver) { 93 | err := sqlbuilder.DropTable(t.DB). 94 | Table("info"). // 需要先删除 info,info 的外键依赖 users 95 | Table("users"). 96 | Exec() 97 | t.NotError(err) 98 | } 99 | -------------------------------------------------------------------------------- /sqlbuilder/table_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package sqlbuilder_test 6 | 7 | import ( 8 | "testing" 9 | 10 | "github.com/issue9/assert/v4" 11 | 12 | "github.com/issue9/orm/v6/core" 13 | "github.com/issue9/orm/v6/internal/test" 14 | "github.com/issue9/orm/v6/sqlbuilder" 15 | ) 16 | 17 | var ( 18 | _ sqlbuilder.DDLStmt = &sqlbuilder.CreateTableStmt{} 19 | _ sqlbuilder.DDLStmt = &sqlbuilder.DropTableStmt{} 20 | _ sqlbuilder.DDLStmt = &sqlbuilder.TruncateTableStmt{} 21 | ) 22 | 23 | func TestCreateTableStmt(t *testing.T) { 24 | a := assert.New(t, false) 25 | table := "create_table_test" 26 | suite := test.NewSuite(a, "") 27 | 28 | suite.Run(func(t *test.Driver) { 29 | stmt := sqlbuilder.CreateTable(t.DB). 30 | Table(table). 31 | AutoIncrement("id", core.Int). 32 | Column("age", core.Int, false, false, false, nil). 33 | Column("name", core.String, false, true, true, "", 100). 34 | Column("address", core.String, false, false, false, nil, 100). 35 | Index(core.IndexDefault, "index_index", "name", "address"). 36 | Unique("u_age", "name", "address"). 37 | Check("age_gt_0", "age>0") 38 | err := stmt.Exec() 39 | a.NotError(err) 40 | 41 | exists, err := sqlbuilder.TableExists(t.DB).Table(table).Exists() 42 | a.NotError(err).True(exists) 43 | exists, err = sqlbuilder.TableExists(t.DB).Table("not-exists").Exists() 44 | a.NotError(err).False(exists) 45 | 46 | defer func() { 47 | err = sqlbuilder.DropTable(t.DB). 48 | Table(table). 49 | Exec() 50 | a.NotError(err) 51 | }() 52 | 53 | // AI 和 PK 同时指定为 ID 54 | err = stmt.Reset(). 55 | Table("users"). 56 | AutoIncrement("id", core.Int). 57 | PK("users_pk", "id"). 58 | Err() 59 | t.Error(err) 60 | 61 | // 约束名重和昨 62 | err = stmt.Reset().Table("users"). 63 | Column("name", core.String, false, false, false, nil). 64 | Unique("c1", "name"). 65 | Check("c1", "name IS NOT NULL"). 66 | Exec() 67 | a.Error(err) 68 | 69 | a.ErrorString(stmt.Reset().Exec(), "缺少模型名称") 70 | 71 | a.ErrorString(stmt.Reset().Table("users").Exec(), "未指定列") 72 | 73 | insert := sqlbuilder.Insert(t.DB). 74 | Table(table). 75 | KeyValue("age", 1). 76 | KeyValue("name", "name1"). 77 | KeyValue("address", "address1") 78 | rslt, err := insert.Exec() 79 | a.NotError(err).NotNil(rslt) 80 | 81 | prepare, err := insert.Prepare() 82 | a.NotError(err).NotNil(prepare) 83 | rslt, err = prepare.Exec(2, "name2", "address2") 84 | a.NotError(err).NotNil(rslt) 85 | rslt, err = prepare.Exec(3, "name3", "address3") 86 | a.NotError(err).NotNil(rslt) 87 | 88 | cnt, err := sqlbuilder.Select(t.DB). 89 | Count("count(*) as cnt"). 90 | From(table). 91 | QueryInt("cnt") 92 | a.NotError(err).Equal(cnt, 3) 93 | }) 94 | } 95 | 96 | func TestTruncateTable(t *testing.T) { 97 | a := assert.New(t, false) 98 | suite := test.NewSuite(a, "") 99 | 100 | suite.Run(func(t *test.Driver) { 101 | initDB(t) 102 | defer clearDB(t) 103 | 104 | _, err := sqlbuilder.Insert(t.DB). 105 | Table("info"). 106 | KeyValue("uid", 1). 107 | KeyValue("tel", "18011112222"). 108 | KeyValue("nickname", "nickname1"). 109 | KeyValue("address", "address1"). 110 | Exec() 111 | a.NotError(err) 112 | 113 | truncate := sqlbuilder.TruncateTable(t.DB) 114 | err = truncate.Table("info", "").Exec() 115 | t.NotError(err) 116 | 117 | // 可重复调用 118 | err = truncate.Reset().Table("info", "").Exec() 119 | t.NotError(err) 120 | 121 | sel := sqlbuilder.Select(t.DB). 122 | Count("count(*) as cnt"). 123 | From("info") 124 | rows, err := sel.Query() 125 | t.NotError(err).NotNil(rows) 126 | t.True(rows.Next()) 127 | var val int 128 | t.NotError(rows.Scan(&val)) 129 | t.NotError(rows.Close()) 130 | t.Equal(val, 0) 131 | }) 132 | } 133 | 134 | func TestDropTable(t *testing.T) { 135 | a := assert.New(t, false) 136 | suite := test.NewSuite(a, "") 137 | 138 | suite.Run(func(t *test.Driver) { 139 | initDB(t) 140 | defer clearDB(t) 141 | 142 | drop := sqlbuilder.DropTable(t.DB) 143 | a.Error(drop.Exec()) 144 | 145 | a.NotError(drop.Reset().Table("info").Exec()) 146 | 147 | // 删除不存在的表 148 | a.NotError(drop.Reset().Table("not-exists").Exec()) 149 | }) 150 | } 151 | -------------------------------------------------------------------------------- /sqlbuilder/update.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package sqlbuilder 6 | 7 | import ( 8 | "database/sql" 9 | 10 | "github.com/issue9/sliceutil" 11 | 12 | "github.com/issue9/orm/v6/core" 13 | ) 14 | 15 | // UpdateStmt 更新语句 16 | type UpdateStmt struct { 17 | *execStmt 18 | *updateWhere 19 | 20 | table string 21 | values []*updateSet 22 | 23 | occColumn string // 乐观锁的列名 24 | occValue any // 乐观锁的当前值 25 | } 26 | 27 | type updateWhere = WhereStmtOf[*UpdateStmt] 28 | 29 | // 表示一条 SET 语句。比如 set key=val 30 | type updateSet struct { 31 | column string 32 | value any 33 | typ byte // 类型,可以是 + 自增类型,- 自减类型,或是空值表示正常表达式 34 | } 35 | 36 | // Update 生成更新语句 37 | func (sql *SQLBuilder) Update() *UpdateStmt { return Update(sql.engine) } 38 | 39 | // Update 声明一条 UPDATE 的 SQL 语句 40 | func Update(e core.Engine) *UpdateStmt { 41 | stmt := &UpdateStmt{values: []*updateSet{}} 42 | stmt.execStmt = newExecStmt(e, stmt) 43 | stmt.updateWhere = NewWhereStmtOf(stmt) 44 | 45 | return stmt 46 | } 47 | 48 | // Table 指定表名 49 | func (stmt *UpdateStmt) Table(table string) *UpdateStmt { 50 | stmt.table = table 51 | return stmt 52 | } 53 | 54 | // Set 设置值,若 col 相同,则会覆盖 55 | // 56 | // val 可以是 sql.NamedArg 类型 57 | func (stmt *UpdateStmt) Set(col string, val any) *UpdateStmt { 58 | stmt.values = append(stmt.values, &updateSet{ 59 | column: col, 60 | value: val, 61 | typ: 0, 62 | }) 63 | return stmt 64 | } 65 | 66 | // Increase 给列增加值 67 | func (stmt *UpdateStmt) Increase(col string, val any) *UpdateStmt { 68 | stmt.values = append(stmt.values, &updateSet{ 69 | column: col, 70 | value: val, 71 | typ: '+', 72 | }) 73 | return stmt 74 | } 75 | 76 | // Decrease 给列减少值 77 | func (stmt *UpdateStmt) Decrease(col string, val any) *UpdateStmt { 78 | stmt.values = append(stmt.values, &updateSet{ 79 | column: col, 80 | value: val, 81 | typ: '-', 82 | }) 83 | return stmt 84 | } 85 | 86 | // OCC 指定一个用于乐观锁的字段 87 | // 88 | // val 表示乐观锁原始的值,更新时如果值不等于 val,将更新失败。 89 | func (stmt *UpdateStmt) OCC(col string, val any) *UpdateStmt { 90 | stmt.occColumn = col 91 | stmt.occValue = val 92 | stmt.Increase(col, 1) 93 | return stmt 94 | } 95 | 96 | // Reset 重置语句 97 | func (stmt *UpdateStmt) Reset() *UpdateStmt { 98 | stmt.baseStmt.Reset() 99 | 100 | stmt.table = "" 101 | stmt.WhereStmt().Reset() 102 | stmt.values = stmt.values[:0] 103 | 104 | stmt.occColumn = "" 105 | stmt.occValue = nil 106 | 107 | return stmt 108 | } 109 | 110 | // SQL 获取 SQL 语句以及对应的参数 111 | func (stmt *UpdateStmt) SQL() (string, []any, error) { 112 | if stmt.err != nil { 113 | return "", nil, stmt.Err() 114 | } 115 | 116 | if err := stmt.checkErrors(); err != nil { 117 | return "", nil, err 118 | } 119 | 120 | buf := core.NewBuilder("UPDATE "). 121 | QuoteKey(stmt.table). 122 | WString(" SET ") 123 | 124 | args := make([]any, 0, len(stmt.values)) 125 | 126 | for _, val := range stmt.values { 127 | buf.QuoteKey(val.column).WBytes('=') 128 | 129 | if val.typ != 0 { 130 | buf.QuoteKey(val.column).WBytes(val.typ) 131 | } 132 | 133 | if named, ok := val.value.(sql.NamedArg); ok && named.Name != "" { 134 | buf.WBytes('@').WString(named.Name) 135 | } else { 136 | buf.WBytes('?') 137 | } 138 | buf.WBytes(',') 139 | args = append(args, val.value) 140 | } 141 | buf.TruncateLast(1) 142 | 143 | wq, wa, err := stmt.getWhereSQL() 144 | if err != nil { 145 | return "", nil, err 146 | } 147 | 148 | if wq != "" { 149 | buf.WString(" WHERE ").WString(wq) 150 | args = append(args, wa...) 151 | } 152 | 153 | query, err := buf.String() 154 | if err != nil { 155 | return "", nil, err 156 | } 157 | return query, args, nil 158 | } 159 | 160 | func (stmt *UpdateStmt) getWhereSQL() (string, []any, error) { 161 | if stmt.occColumn == "" { 162 | return stmt.WhereStmt().SQL() 163 | } 164 | 165 | w := Where() 166 | w.appendGroup(true, stmt.WhereStmt()) 167 | 168 | if named, ok := stmt.occValue.(sql.NamedArg); ok && named.Name != "" { 169 | w.AndGroup(func(occ *WhereStmt) { 170 | occ.And(stmt.occColumn+"=@"+named.Name, stmt.occValue) 171 | }) 172 | } else { 173 | w.AndGroup(func(occ *WhereStmt) { 174 | occ.And(stmt.occColumn+"=?", stmt.occValue) 175 | }) 176 | } 177 | 178 | return w.SQL() 179 | } 180 | 181 | // 检测列名是否存在重复,先排序,再与后一元素比较。 182 | func (stmt *UpdateStmt) checkErrors() error { 183 | if stmt.table == "" { 184 | return SyntaxError("UPDATE", "未指定表名") 185 | } 186 | 187 | if len(stmt.values) == 0 { 188 | return SyntaxError("UPDATE", "未指定任何更新的值") 189 | } 190 | 191 | if len(sliceutil.Dup(stmt.values, func(i, j *updateSet) bool { return i.column == j.column })) > 0 { 192 | return SyntaxError("UPDATE", "存在重复的列名") 193 | } 194 | 195 | return nil 196 | } 197 | 198 | // Update 更新指定条件内容 199 | func (stmt *WhereStmt) Update(e core.Engine) *UpdateStmt { 200 | upd := Update(e) 201 | upd.updateWhere.w = stmt 202 | return upd 203 | } 204 | -------------------------------------------------------------------------------- /sqlbuilder/update_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package sqlbuilder_test 6 | 7 | import ( 8 | "testing" 9 | 10 | "github.com/issue9/assert/v4" 11 | 12 | "github.com/issue9/orm/v6/internal/test" 13 | "github.com/issue9/orm/v6/sqlbuilder" 14 | ) 15 | 16 | var _ sqlbuilder.ExecStmt = &sqlbuilder.UpdateStmt{} 17 | 18 | func TestUpdate_columnsHasDup(t *testing.T) { 19 | a := assert.New(t, false) 20 | suite := test.NewSuite(a, "") 21 | 22 | suite.Run(func(t *test.Driver) { 23 | u := sqlbuilder.Update(t.DB). 24 | Table("users"). 25 | Set("c1", "v1"). 26 | Set("c1", "v2") 27 | _, err := u.Exec() 28 | a.ErrorString(err, "存在重复的列名") 29 | }) 30 | } 31 | 32 | func TestUpdate(t *testing.T) { 33 | a := assert.New(t, false) 34 | suite := test.NewSuite(a, "") 35 | 36 | suite.Run(func(t *test.Driver) { 37 | initDB(t) 38 | defer clearDB(t) 39 | 40 | u := sqlbuilder.Update(t.DB).Table("users") 41 | t.NotNil(u) 42 | 43 | u.Set("name", "name222").Where("id=?", 2) 44 | _, err := u.Exec() 45 | t.NotError(err) 46 | 47 | sel := sqlbuilder.Select(t.DB). 48 | Column("name"). 49 | From("users"). 50 | Where("id=?", 2) 51 | rows, err := sel.Query() 52 | t.NotError(err).NotNil(rows) 53 | t.True(rows.Next()) 54 | var name string 55 | t.NotError(rows.Scan(&name)) 56 | t.NotError(rows.Close()) 57 | t.Equal(name, "name222") 58 | }) 59 | } 60 | 61 | func TestUpdateStmt_Increase(t *testing.T) { 62 | a := assert.New(t, false) 63 | suite := test.NewSuite(a, "") 64 | 65 | suite.Run(func(t *test.Driver) { 66 | initDB(t) 67 | defer clearDB(t) 68 | 69 | u := sqlbuilder.Update(t.DB). 70 | Table("users"). 71 | Increase("age", 5). 72 | Where("id=?", 1) 73 | t.NotNil(u) 74 | _, err := u.Exec() 75 | t.NotError(err) 76 | 77 | sel := sqlbuilder.Select(t.DB). 78 | Column("age"). 79 | From("users"). 80 | Where("id=?", 1) 81 | rows, err := sel.Query() 82 | t.NotError(err).NotNil(rows) 83 | t.True(rows.Next()) 84 | var val int 85 | t.NotError(rows.Scan(&val)) 86 | t.NotError(rows.Close()) 87 | t.Equal(val, 6) 88 | 89 | // decrease 90 | u.Reset() 91 | u.Table("users"). 92 | Decrease("age", 3). 93 | Where("id=?", 1) 94 | t.NotNil(u) 95 | _, err = u.Exec() 96 | t.NotError(err) 97 | sel.Reset(). 98 | Column("age"). 99 | From("users"). 100 | Where("id=?", 1) 101 | rows, err = sel.Query() 102 | t.NotError(err).NotNil(rows) 103 | t.True(rows.Next()) 104 | t.NotError(rows.Scan(&val)) 105 | t.NotError(rows.Close()) 106 | t.Equal(val, 3) 107 | }) 108 | } 109 | 110 | func TestUpdateStmt_OCC(t *testing.T) { 111 | a := assert.New(t, false) 112 | suite := test.NewSuite(a, "") 113 | 114 | suite.Run(func(t *test.Driver) { 115 | initDB(t) 116 | defer clearDB(t) 117 | 118 | u := sqlbuilder.Update(t.DB). 119 | Table("users"). 120 | Set("age", 100). 121 | Where("id=?", 1). 122 | OCC("version", 0) 123 | r, err := u.Exec() 124 | a.NotError(err).NotNil(r) 125 | 126 | sel := sqlbuilder.Select(t.DB). 127 | Column("age"). 128 | From("users"). 129 | Where("id=?", 1) 130 | rows, err := sel.Query() 131 | t.NotError(err).NotNil(rows) 132 | t.True(rows.Next()) 133 | var val int 134 | t.NotError(rows.Scan(&val)) 135 | t.NotError(rows.Close()) 136 | t.Equal(val, 100) 137 | 138 | // 乐观锁判断失败主 139 | u.Reset() 140 | u.Table("users"). 141 | Set("age", 111). 142 | Where("id=?", 1). 143 | OCC("version", 0) 144 | r, err = u.Exec() 145 | a.NotError(err).NotNil(r) 146 | 147 | sel.Reset() 148 | sel.Column("age"). 149 | From("users"). 150 | Where("id=?", 1) 151 | rows, err = sel.Query() 152 | t.NotError(err).NotNil(rows) 153 | t.True(rows.Next()) 154 | t.NotError(rows.Scan(&val)) 155 | t.NotError(rows.Close()) 156 | t.Equal(val, 100) 157 | }) 158 | } 159 | -------------------------------------------------------------------------------- /sqlbuilder/version.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package sqlbuilder 6 | 7 | import "github.com/issue9/orm/v6/core" 8 | 9 | // Version 查询数据库服务器的版本信息 10 | func Version(e core.Engine) (version string, err error) { 11 | err = e.QueryRow(e.Dialect().VersionSQL()).Scan(&version) 12 | return 13 | } 14 | -------------------------------------------------------------------------------- /sqlbuilder/version_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package sqlbuilder_test 6 | 7 | import ( 8 | "testing" 9 | 10 | "github.com/issue9/assert/v4" 11 | 12 | "github.com/issue9/orm/v6/internal/test" 13 | "github.com/issue9/orm/v6/sqlbuilder" 14 | ) 15 | 16 | func TestVersion(t *testing.T) { 17 | a := assert.New(t, false) 18 | s := test.NewSuite(a, "") 19 | 20 | s.Run(func(t *test.Driver) { 21 | ver, err := sqlbuilder.Version(t.DB) 22 | t.NotError(err). 23 | NotEmpty(ver) 24 | }) 25 | } 26 | -------------------------------------------------------------------------------- /sqlbuilder/view.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package sqlbuilder 6 | 7 | import ( 8 | "errors" 9 | 10 | "github.com/issue9/orm/v6/core" 11 | ) 12 | 13 | // CreateViewStmt 创建视图的语句 14 | type CreateViewStmt struct { 15 | *ddlStmt 16 | 17 | selectQuery string 18 | name string 19 | columns []string 20 | temporary bool 21 | replace bool 22 | } 23 | 24 | // CreateView 创建视图 25 | func (sql *SQLBuilder) CreateView() *CreateViewStmt { return CreateView(sql.engine) } 26 | 27 | // CreateView 创建视图 28 | func CreateView(e core.Engine) *CreateViewStmt { 29 | stmt := &CreateViewStmt{} 30 | stmt.ddlStmt = newDDLStmt(e, stmt) 31 | 32 | return stmt 33 | } 34 | 35 | // View 将当前查询语句转换为视图 36 | // 37 | // name 为视图名称。 38 | func (stmt *SelectStmt) View(name string) *CreateViewStmt { 39 | return CreateView(stmt.Engine()).From(stmt).Name(name) 40 | } 41 | 42 | // Reset 重置对象 43 | func (stmt *CreateViewStmt) Reset() *CreateViewStmt { 44 | stmt.baseStmt.Reset() 45 | stmt.name = "" 46 | stmt.selectQuery = "" 47 | stmt.columns = stmt.columns[:0] 48 | stmt.temporary = false 49 | stmt.replace = false 50 | 51 | return stmt 52 | } 53 | 54 | // Column 指定视图的列,如果未指定,则会直接采用 Select 中的列信息 55 | func (stmt *CreateViewStmt) Column(col ...string) *CreateViewStmt { 56 | if stmt.columns == nil { 57 | stmt.columns = col 58 | return stmt 59 | } 60 | 61 | stmt.columns = append(stmt.columns, col...) 62 | return stmt 63 | } 64 | 65 | // Name 指定视图名称 66 | func (stmt *CreateViewStmt) Name(name string) *CreateViewStmt { 67 | stmt.name = name 68 | return stmt 69 | } 70 | 71 | // Temporary 临时视图 72 | func (stmt *CreateViewStmt) Temporary() *CreateViewStmt { 73 | stmt.temporary = true 74 | return stmt 75 | } 76 | 77 | // Replace 如果已经存在,则更新视图内容 78 | func (stmt *CreateViewStmt) Replace() *CreateViewStmt { 79 | stmt.replace = true 80 | return stmt 81 | } 82 | 83 | // From 指定 Select 语句 84 | func (stmt *CreateViewStmt) From(sel *SelectStmt) *CreateViewStmt { 85 | if stmt.err != nil { 86 | return stmt 87 | } 88 | 89 | query, err := sel.CombineSQL() 90 | if err != nil { 91 | stmt.err = err 92 | return stmt 93 | } 94 | 95 | stmt.selectQuery = query 96 | return stmt 97 | } 98 | 99 | // FromQuery 指定查询语句 100 | // 101 | // FromQuery 和 From 会相互覆盖。 102 | func (stmt *CreateViewStmt) FromQuery(query string) *CreateViewStmt { 103 | stmt.selectQuery = query 104 | return stmt 105 | } 106 | 107 | // DDLSQL 返回创建视图的 SQL 语句 108 | func (stmt *CreateViewStmt) DDLSQL() ([]string, error) { 109 | return stmt.Dialect().CreateViewSQL(stmt.replace, stmt.temporary, stmt.name, stmt.selectQuery, stmt.columns) 110 | } 111 | 112 | // DropViewStmt 删除视图 113 | type DropViewStmt struct { 114 | *ddlStmt 115 | name string 116 | } 117 | 118 | func (sql *SQLBuilder) DropView() *DropViewStmt { return DropView(sql.engine) } 119 | 120 | // DropView 创建视图 121 | func DropView(e core.Engine) *DropViewStmt { 122 | stmt := &DropViewStmt{} 123 | stmt.ddlStmt = newDDLStmt(e, stmt) 124 | 125 | return stmt 126 | } 127 | 128 | // Name 指定需要删除的视图名称 129 | func (stmt *DropViewStmt) Name(name string) *DropViewStmt { 130 | stmt.name = name 131 | return stmt 132 | } 133 | 134 | // DDLSQL 返回删除视图的 SQL 语句 135 | func (stmt *DropViewStmt) DDLSQL() ([]string, error) { 136 | if len(stmt.name) == 0 { 137 | return nil, SyntaxError("DROP VIEW", "未指定表名") 138 | } 139 | 140 | query, err := core.NewBuilder("DROP VIEW IF EXISTS "). 141 | QuoteKey(stmt.name). 142 | String() 143 | if err != nil { 144 | return nil, err 145 | } 146 | 147 | return []string{query}, nil 148 | } 149 | 150 | // Reset 重置对象 151 | func (stmt *DropViewStmt) Reset() *DropViewStmt { 152 | stmt.baseStmt.Reset() 153 | stmt.name = "" 154 | 155 | return stmt 156 | } 157 | 158 | type ViewExistsStmt struct { 159 | *queryStmt 160 | name string 161 | } 162 | 163 | func ViewExists(e core.Engine) *ViewExistsStmt { 164 | stmt := &ViewExistsStmt{} 165 | stmt.queryStmt = newQueryStmt(e, stmt) 166 | return stmt 167 | } 168 | 169 | func (sql *SQLBuilder) ViewExists() *ViewExistsStmt { 170 | return ViewExists(sql.engine) 171 | } 172 | 173 | func (stmt *ViewExistsStmt) View(table string) *ViewExistsStmt { 174 | stmt.name = table 175 | return stmt 176 | } 177 | 178 | func (stmt *ViewExistsStmt) Reset() *ViewExistsStmt { 179 | stmt.name = "" 180 | return stmt 181 | } 182 | 183 | func (stmt *ViewExistsStmt) SQL() (string, []any, error) { 184 | if stmt.name == "" { 185 | return "", nil, SyntaxError("VIEW EXISTS", "未指定表名") 186 | } 187 | 188 | sql, args := stmt.Dialect().ExistsSQL(stmt.name, true) 189 | return sql, args, nil 190 | } 191 | 192 | func (stmt *ViewExistsStmt) Exists() (bool, error) { 193 | rows, err := stmt.Query() 194 | if err != nil { 195 | return false, err 196 | } 197 | 198 | name, err := fetchColumn[string](rows, "name") 199 | switch { 200 | case errors.Is(err, ErrNoData): 201 | return false, nil 202 | case err != nil: 203 | return false, err 204 | default: 205 | return name == stmt.name, nil 206 | } 207 | } 208 | -------------------------------------------------------------------------------- /sqlbuilder/view_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package sqlbuilder_test 6 | 7 | import ( 8 | "testing" 9 | 10 | "github.com/issue9/assert/v4" 11 | 12 | "github.com/issue9/orm/v6/internal/test" 13 | "github.com/issue9/orm/v6/sqlbuilder" 14 | ) 15 | 16 | var _ sqlbuilder.DDLStmt = &sqlbuilder.CreateTableStmt{} 17 | 18 | func TestCreateView(t *testing.T) { 19 | a := assert.New(t, false) 20 | suite := test.NewSuite(a, "") 21 | 22 | suite.Run(func(t *test.Driver) { 23 | testCreateView(t) 24 | }) 25 | } 26 | 27 | func testCreateView(d *test.Driver) { 28 | initDB(d) 29 | defer clearDB(d) 30 | sb := d.DB.SQLBuilder() 31 | 32 | viewName := "user_view" 33 | 34 | sel := sb.Select(). 35 | Column("u.id as uid"). 36 | Column("u.name"). 37 | Column("i.address"). 38 | Join("LEFT", "info", "i", "u.id=i.uid"). 39 | From("users", "u") 40 | view := sel.View(viewName). 41 | Column("uid"). 42 | Column("name", "address") 43 | d.NotError(view.Exec()) 44 | 45 | exists, err := sqlbuilder.ViewExists(d.DB).View(viewName).Exists() 46 | d.NotError(err).True(exists) 47 | exists, err = sqlbuilder.TableExists(d.DB).Table("not-exists").Exists() 48 | d.NotError(err).False(exists) 49 | 50 | // 删除 51 | defer func() { 52 | dropView := sb.DropView().Name(viewName) 53 | d.NotError(dropView.Exec()) 54 | }() 55 | 56 | // 创建同名视图 57 | view.Reset().Name(viewName).From(sel) 58 | d.Error(view.Exec(), "not err @%s", d.DriverName) 59 | 60 | // 以 replace 的方式创建 61 | view.Reset().Name(viewName).From(sel).Replace() 62 | d.NotError(view.Exec()) 63 | } 64 | -------------------------------------------------------------------------------- /types.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package orm 6 | 7 | import ( 8 | "context" 9 | "database/sql" 10 | "time" 11 | 12 | "github.com/issue9/orm/v6/core" 13 | "github.com/issue9/orm/v6/fetch" 14 | "github.com/issue9/orm/v6/sqlbuilder" 15 | "github.com/issue9/orm/v6/types" 16 | ) 17 | 18 | type ( 19 | ApplyModeler = core.ApplyModeler 20 | 21 | // Column 列结构 22 | Column = core.Column 23 | 24 | TableNamer = core.TableNamer 25 | 26 | // Unix 表示 Unix 时间戳的数据样式 27 | // 28 | // 表现为 [time.Time],但是保存数据库时,以 unix 时间戳的形式保存。 29 | Unix = types.Unix 30 | 31 | Rat = types.Rat 32 | 33 | Decimal = types.Decimal 34 | 35 | // AfterFetcher 从数据库查询到数据之后需要执行的操作 36 | AfterFetcher = fetch.AfterFetcher 37 | 38 | // Dialect 数据库驱动特有的语言特性实现 39 | Dialect = core.Dialect 40 | 41 | // BeforeUpdater 在更新之前调用的函数 42 | BeforeUpdater interface { 43 | BeforeUpdate() error 44 | } 45 | 46 | // BeforeInserter 在插入之前调用的函数 47 | BeforeInserter interface { 48 | BeforeInsert() error 49 | } 50 | 51 | // Engine 数据操作引擎 52 | // 53 | // 相对于 [core.Engine],添加了针对 [TableNamer] 的操作。 54 | // 所有针对 [TableNamer] 的操作与 sqlbuilder 的拼接方式有以下区别: 55 | // - 针对 [TableNamer] 的操作会自动为表名加上 # 表名前缀; 56 | // - 针对 [TableNamer] 的操作会为约束名加上表名,以确保约束名的唯一性; 57 | // 58 | // [DB] 和 [Tx] 均实现了此接口。 59 | Engine interface { 60 | core.Engine 61 | 62 | // LastInsertIDContext 插入一条数据并返回其自增 ID 63 | // 64 | // 理论上功能等同于以下两步操作: 65 | // rslt, err := engine.Insert(obj) 66 | // id, err := rslt.LastInsertId() 67 | // 但是实际上部分数据库不支持直接在 [sql.Result] 中获取 LastInsertId, 68 | // 比如 postgresql,所以使用此方法比 [sql.Result] 更有效。 69 | // 70 | // NOTE: 要求 v 有定义自增列。 71 | LastInsertIDContext(ctx context.Context, v TableNamer) (int64, error) 72 | LastInsertID(v TableNamer) (int64, error) 73 | 74 | // InsertContext 插入数据 75 | // 76 | // NOTE: 若需一次性插入多条数据,请使用 [Engine.InsertMany] 。 77 | InsertContext(ctx context.Context, v TableNamer) (sql.Result, error) 78 | Insert(v TableNamer) (sql.Result, error) 79 | 80 | // Delete 删除符合条件的数据 81 | // 82 | // 查找条件以结构体定义的主键或是唯一约束(在没有主键的情况下)来查找, 83 | // 若两者都不存在,则将返回 error 84 | DeleteContext(ctx context.Context, v TableNamer) (sql.Result, error) 85 | Delete(v TableNamer) (sql.Result, error) 86 | 87 | // UpdateContext 更新数据 88 | // 89 | // 零值不会被提交,cols 指定的列,即使是零值也会被更新。 90 | // 91 | // 查找条件以结构体定义的主键或是唯一约束(在没有主键的情况下)来查找, 92 | // 若两者都不存在,则将返回 error 93 | UpdateContext(ctx context.Context, v TableNamer, cols ...string) (sql.Result, error) 94 | Update(v TableNamer, cols ...string) (sql.Result, error) 95 | 96 | // SaveContext 更新或是插入数据 97 | // 98 | // 根据 v 中的唯一约束或是自增列是否要在表中找到值来确定是采用 [Engine.UpdateContext] 还是 [Engine.InsertContext]。 99 | // 100 | // isnew 表示是否是 insert 的数据,如果是 insert 模式,那么 lastid 表示插入项的 id,否则 lastid 无意义。 101 | SaveContext(ctx context.Context, v TableNamer, cols ...string) (lastid int64, isnew bool, err error) 102 | Save(v TableNamer, cols ...string) (lastid int64, isnew bool, err error) 103 | 104 | // SelectContext 查询一个符合条件的数据 105 | // 106 | // 查找条件以结构体定义的主键或是唯一约束(在没有主键的情况下 ) 来查找, 107 | // 若两者都不存在,则将返回 error 108 | // 若没有符合条件的数据,将不会对参数 v 做任何变动。 109 | // 110 | // 查找条件的查找顺序是为 自增 > 主键 > 唯一约束, 111 | // 如果同时存在多个唯一约束满足条件(可能每个唯一约束查询至的结果是不一样的),则返回错误信息。 112 | SelectContext(context.Context, TableNamer) (found bool, err error) 113 | Select(TableNamer) (found bool, err error) 114 | 115 | CreateContext(context.Context, ...TableNamer) error 116 | Create(...TableNamer) error 117 | 118 | DropContext(context.Context, ...TableNamer) error 119 | Drop(...TableNamer) error 120 | 121 | // TruncateContext 清空表并重置 ai 但保留表结构 122 | TruncateContext(context.Context, ...TableNamer) error 123 | Truncate(...TableNamer) error 124 | 125 | // InsertManyContext 插入多条相同的数据 126 | // 127 | // 若需要向某张表中插入多条记录,此方法会比 [Engine.Insert] 性能上好很多。 128 | // 129 | // max 表示一次最多插入的数量,如果超过此值,会分批执行,但是依然在一个事务中完成。 130 | InsertManyContext(ctx context.Context, max int, v ...TableNamer) error 131 | InsertMany(max int, v ...TableNamer) error 132 | 133 | // Where 生成 [WhereStmt] 语句 134 | Where(cond string, args ...any) *WhereStmt 135 | 136 | SQLBuilder() *sqlbuilder.SQLBuilder 137 | 138 | // newModel 获取 v 的 [core.Model] 实例 139 | // 140 | // 内部使用不公开,[Engine] 也不会有外部的实现。 141 | newModel(v TableNamer) (*core.Model, error) 142 | } 143 | ) 144 | 145 | // NowUnix 返回当前时间 146 | func NowUnix() Unix { return Unix{Time: time.Now()} } 147 | 148 | // NowNullTime 返回当前时间 149 | func NowNullTime() sql.NullTime { return sql.NullTime{Time: time.Now(), Valid: true} } 150 | 151 | func TableName(t TableNamer) string { return "#" + t.TableName() } 152 | -------------------------------------------------------------------------------- /types/decimal.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package types 6 | 7 | import ( 8 | "bytes" 9 | "database/sql/driver" 10 | "encoding/json" 11 | "strings" 12 | 13 | "github.com/shopspring/decimal" 14 | 15 | "github.com/issue9/orm/v6/core" 16 | ) 17 | 18 | type Decimal struct { 19 | Decimal decimal.Decimal 20 | Precision int32 21 | Valid bool 22 | } 23 | 24 | // FloatDecimal 从浮点数还原 [Decimal] 对象 25 | // 26 | // precision 表示输出的精度。 27 | func FloatDecimal(f float64, precision int32) Decimal { 28 | return Decimal{Decimal: decimal.NewFromFloat(f), Precision: precision, Valid: true} 29 | } 30 | 31 | // StringDecimal 从字符串还原 [Decimal] 对象 32 | // 33 | // precision 表示输出的精度。 34 | func StringDecimal(s string, precision int32) (Decimal, error) { 35 | d, err := decimal.NewFromString(s) 36 | if err != nil { 37 | return Decimal{}, err 38 | } 39 | return Decimal{Decimal: d, Precision: precision, Valid: true}, nil 40 | } 41 | 42 | // StringDecimalWithPrecision 从字符串还原 [Decimal] 对象 43 | // 44 | // 输出精度从 s 获取,如果 s 不包含小数位,则小数长度为 0 45 | func StringDecimalWithPrecision(s string) (Decimal, error) { 46 | var p int32 47 | index := strings.IndexByte(s, '.') 48 | if index >= 0 { 49 | p = int32(len(s) - index - 1) 50 | } 51 | 52 | return StringDecimal(s, p) 53 | } 54 | 55 | // Scan implements the Scanner.Scan 56 | func (n *Decimal) Scan(src any) (err error) { 57 | if src == nil { 58 | n.Valid = false 59 | return nil 60 | } 61 | 62 | switch v := src.(type) { 63 | case []byte: 64 | if bytes.Equal(v, nullBytes) { 65 | n.Valid = false 66 | return nil 67 | } 68 | case string: 69 | if v == null { 70 | n.Valid = false 71 | return nil 72 | } 73 | } 74 | return n.Decimal.Scan(src) 75 | } 76 | 77 | func (n Decimal) Value() (driver.Value, error) { 78 | if !n.Valid { 79 | return nil, nil 80 | } 81 | return n.Decimal.StringFixed(n.Precision), nil 82 | } 83 | 84 | func (n Decimal) PrimitiveType() core.PrimitiveType { return core.String } 85 | 86 | func (n Decimal) MarshalText() ([]byte, error) { 87 | if !n.Valid { 88 | return nil, nil 89 | } 90 | return []byte(n.Decimal.StringFixed(n.Precision)), nil 91 | } 92 | 93 | func (n Decimal) MarshalJSON() ([]byte, error) { 94 | if !n.Valid { 95 | return json.Marshal(nil) 96 | } 97 | return json.Marshal(n.Decimal.StringFixed(n.Precision)) 98 | } 99 | 100 | func (n *Decimal) UnmarshalText(data []byte) error { 101 | if n.Valid = len(data) > 0; !n.Valid { 102 | return nil 103 | } 104 | return n.Decimal.UnmarshalText(data) 105 | } 106 | 107 | func (n *Decimal) UnmarshalJSON(data []byte) error { 108 | if n.Valid = len(data) > 0; !n.Valid { 109 | return nil 110 | } 111 | return n.Decimal.UnmarshalJSON(data) 112 | } 113 | -------------------------------------------------------------------------------- /types/decimal_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package types 6 | 7 | import ( 8 | "database/sql" 9 | "database/sql/driver" 10 | "encoding" 11 | "encoding/json" 12 | "testing" 13 | 14 | "github.com/issue9/assert/v4" 15 | 16 | "github.com/issue9/orm/v6/core" 17 | ) 18 | 19 | var ( 20 | _ sql.Scanner = &Decimal{} 21 | _ driver.Valuer = Decimal{} 22 | _ core.PrimitiveTyper = &Decimal{} 23 | 24 | _ encoding.TextMarshaler = Decimal{} 25 | _ json.Marshaler = Decimal{} 26 | _ encoding.TextUnmarshaler = &Decimal{} 27 | _ json.Unmarshaler = &Decimal{} 28 | ) 29 | 30 | func TestStringDecimalWithPrecision(t *testing.T) { 31 | a := assert.New(t, false) 32 | 33 | d, err := StringDecimalWithPrecision("3.222") 34 | a.NotError(err).Equal(d.Precision, 3).True(d.Valid) 35 | 36 | d, err = StringDecimalWithPrecision(".222") 37 | a.NotError(err).Equal(d.Precision, 3).True(d.Valid) 38 | 39 | d, err = StringDecimalWithPrecision("222") 40 | a.NotError(err).Equal(d.Precision, 0).True(d.Valid) 41 | 42 | d, err = StringDecimalWithPrecision("") 43 | a.Error(err).False(d.Valid) 44 | } 45 | 46 | func TestSQL(t *testing.T) { 47 | a := assert.New(t, false) 48 | 49 | d := FloatDecimal(2.22, 3) 50 | a.NotError(d.Scan([]byte("3.3333"))) 51 | v, err := d.Value() 52 | a.NotError(err).Equal(v, "3.333") 53 | 54 | d = FloatDecimal(2.22, 3) 55 | a.NotError(d.Scan("3")) 56 | v, err = d.Value() 57 | a.NotError(err).Equal(v, "3.000") 58 | 59 | d = FloatDecimal(2.22, 3) 60 | a.Error(d.Scan("")) 61 | } 62 | -------------------------------------------------------------------------------- /types/rat.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package types 6 | 7 | import ( 8 | "bytes" 9 | "database/sql/driver" 10 | "math/big" 11 | 12 | "github.com/issue9/orm/v6/core" 13 | ) 14 | 15 | // Rat 有理数 16 | // 17 | // 这是对 [big.Rat] 的扩展,提供了 orm 需要的接口支持。 18 | // 19 | // 在数据库中以分数的形式保存至字符串类型的列,所以需要指定长度。 20 | type Rat struct { 21 | rat *big.Rat 22 | } 23 | 24 | func Rational(a, b int64) Rat { return Rat{rat: big.NewRat(a, b)} } 25 | 26 | // Scan implements the sql.Scanner 27 | func (n *Rat) Scan(src any) (err error) { 28 | // The src value will be of one of the following types: 29 | // 30 | // int64 31 | // float64 32 | // bool 33 | // []byte 34 | // string 35 | // time.Time 36 | // nil - for NULL values 37 | if src == nil { 38 | n.rat = nil 39 | return nil 40 | } 41 | 42 | switch v := src.(type) { 43 | case []byte: 44 | if bytes.Equal(v, nullBytes) { 45 | n.rat = nil 46 | return nil 47 | } 48 | return n.UnmarshalText(v) 49 | case string: 50 | if v == null { 51 | n.rat = nil 52 | return nil 53 | } 54 | return n.UnmarshalText([]byte(v)) 55 | default: 56 | return core.ErrInvalidColumnType() 57 | } 58 | } 59 | 60 | func (n Rat) Value() (driver.Value, error) { 61 | if n.IsNull() { 62 | return nil, nil 63 | } 64 | return n.Rat().String(), nil 65 | } 66 | 67 | // Rat 返回标准库中 [big.Rat] 的实例 68 | func (n Rat) Rat() *big.Rat { return n.rat } 69 | 70 | func (n Rat) PrimitiveType() core.PrimitiveType { return core.String } 71 | 72 | func (n Rat) MarshalText() ([]byte, error) { 73 | if n.IsNull() { 74 | return []byte{}, nil 75 | } 76 | return n.Rat().MarshalText() 77 | } 78 | 79 | func (n *Rat) UnmarshalText(data []byte) error { 80 | if len(data) == 0 { 81 | n.rat = nil 82 | return nil 83 | } 84 | 85 | if n.IsNull() { 86 | n.rat = new(big.Rat) 87 | } 88 | return n.Rat().UnmarshalText(data) 89 | } 90 | 91 | func (n Rat) IsNull() bool { return n.Rat() == nil } 92 | -------------------------------------------------------------------------------- /types/rat_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package types 6 | 7 | import ( 8 | "database/sql" 9 | "database/sql/driver" 10 | "encoding" 11 | "testing" 12 | 13 | "github.com/issue9/assert/v4" 14 | 15 | "github.com/issue9/orm/v6/core" 16 | ) 17 | 18 | var ( 19 | _ core.PrimitiveTyper = &Rat{} 20 | _ driver.Valuer = &Rat{} 21 | _ sql.Scanner = &Rat{} 22 | 23 | _ encoding.TextMarshaler = &Rat{} 24 | _ encoding.TextUnmarshaler = &Rat{} 25 | ) 26 | 27 | func TestRational(t *testing.T) { 28 | a := assert.New(t, false) 29 | 30 | r := Rational(3, 4) 31 | a.False(r.IsNull()) 32 | val, err := r.Value() 33 | a.NotError(err).Equal(val, "3/4") 34 | 35 | r = Rat{} 36 | a.True(r.IsNull()) 37 | } 38 | 39 | func TestRat_SQL(t *testing.T) { 40 | a := assert.New(t, false) 41 | 42 | r := &Rat{} 43 | a.NotError(r.Scan("1/3")) 44 | a.Equal(r.Rat().String(), "1/3") 45 | val, err := r.Value() 46 | a.Equal(val, "1/3").NotError(err) 47 | 48 | r = &Rat{} 49 | a.NotError(r.Scan(nil)) 50 | a.Nil(r.Rat()) 51 | val, err = r.Value() 52 | a.Nil(val).NotError(err) 53 | 54 | r = &Rat{} 55 | a.ErrorIs(r.Scan(1), core.ErrInvalidColumnType()) 56 | val, err = r.Value() 57 | a.Nil(val).NotError(err) 58 | 59 | r2 := Rational(3, 4) 60 | a.NotError(r2.Scan("1/3")) 61 | a.Equal(r2.Rat().String(), "1/3") 62 | val, err = r2.Value() 63 | a.Equal(val, "1/3").NotError(err) 64 | } 65 | -------------------------------------------------------------------------------- /types/slices.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package types 6 | 7 | import ( 8 | "database/sql/driver" 9 | "encoding/json" 10 | 11 | "github.com/issue9/orm/v6/core" 12 | ) 13 | 14 | // SliceOf 针对数组存组方式 15 | // 16 | // 最终是以 json 的方式保存在数据库。 17 | type SliceOf[T any] []T 18 | 19 | func (n *SliceOf[T]) Scan(value any) (err error) { 20 | if value == nil { 21 | return nil 22 | } 23 | 24 | var j []byte 25 | switch v := value.(type) { 26 | case string: 27 | j = []byte(v) 28 | case []byte: 29 | j = v 30 | default: 31 | return core.ErrInvalidColumnType() 32 | } 33 | 34 | return json.Unmarshal(j, n) 35 | } 36 | 37 | func (n SliceOf[T]) Value() (driver.Value, error) { return json.Marshal(n) } 38 | 39 | func (n SliceOf[T]) PrimitiveType() core.PrimitiveType { return core.String } 40 | -------------------------------------------------------------------------------- /types/slices_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package types 6 | 7 | import ( 8 | "database/sql" 9 | "database/sql/driver" 10 | "testing" 11 | 12 | "github.com/issue9/assert/v4" 13 | 14 | "github.com/issue9/orm/v6/core" 15 | ) 16 | 17 | type ( 18 | ints = SliceOf[int64] 19 | strs = SliceOf[string] 20 | ) 21 | 22 | var ( 23 | _ sql.Scanner = &ints{} 24 | _ driver.Valuer = ints{} 25 | _ core.PrimitiveTyper = &ints{} 26 | 27 | _ sql.Scanner = &strs{} 28 | _ driver.Valuer = strs{} 29 | _ core.PrimitiveTyper = &strs{} 30 | ) 31 | 32 | func TestSlices_Scan(t *testing.T) { 33 | a := assert.New(t, false) 34 | 35 | // ints 36 | 37 | u := &ints{} 38 | a.NotError(u.Scan("[1,2,3]")). 39 | Equal(u, &ints{1, 2, 3}) 40 | 41 | // 无效的类型 42 | u = &ints{} 43 | a.Error(u.Scan(1)) 44 | 45 | u = &ints{} 46 | a.Error(u.Scan("2020")) 47 | 48 | u = &ints{} 49 | a.Error(u.Scan(map[string]string{})) 50 | 51 | u = &ints{} 52 | a.NotError(u.Scan(nil)) 53 | 54 | // strs 55 | 56 | s := &strs{} 57 | a.NotError(s.Scan(`["1","2","3\""]`)). 58 | Equal(s, &strs{"1", "2", "3\""}) 59 | 60 | // 无效的类型 61 | s = &strs{} 62 | a.Error(s.Scan(1)) 63 | 64 | s = &strs{} 65 | a.Error(s.Scan("2020")) 66 | 67 | s = &strs{} 68 | a.Error(s.Scan(map[string]string{})) 69 | 70 | s = &strs{} 71 | a.NotError(s.Scan(nil)) 72 | } 73 | -------------------------------------------------------------------------------- /types/types.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | // Package types 提供部分存取数据库的类型 6 | package types 7 | 8 | const null = "NULL" 9 | 10 | var nullBytes = []byte(null) 11 | -------------------------------------------------------------------------------- /types/types_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package types_test 6 | 7 | import ( 8 | "testing" 9 | 10 | "github.com/issue9/orm/v6/internal/test" 11 | ) 12 | 13 | func TestMain(m *testing.M) { 14 | test.Main(m) 15 | } 16 | -------------------------------------------------------------------------------- /types/unix.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package types 6 | 7 | import ( 8 | "bytes" 9 | "database/sql/driver" 10 | "encoding/json" 11 | "strconv" 12 | "time" 13 | 14 | "github.com/issue9/orm/v6/core" 15 | ) 16 | 17 | // Unix 以 unix 时间戳保存的 [time.Time] 数据格式 18 | type Unix struct { 19 | time.Time 20 | Valid bool 21 | } 22 | 23 | func (n *Unix) Scan(src any) (err error) { 24 | // The src value will be of one of the following types: 25 | // 26 | // int64 27 | // float64 28 | // bool 29 | // []byte 30 | // string 31 | // time.Time 32 | // nil - for NULL values 33 | if src == nil { 34 | n.Valid = false 35 | return nil 36 | } 37 | 38 | unix := int64(0) 39 | switch v := src.(type) { 40 | case int64: 41 | unix = v 42 | case []byte: 43 | if bytes.Equal(v, nullBytes) { 44 | n.Valid = false 45 | return nil 46 | } 47 | if unix, err = strconv.ParseInt(string(v), 10, 64); err != nil { 48 | return err 49 | } 50 | case string: 51 | if v == null { 52 | n.Valid = false 53 | return nil 54 | } 55 | if unix, err = strconv.ParseInt(v, 10, 64); err != nil { 56 | return err 57 | } 58 | default: 59 | return core.ErrInvalidColumnType() 60 | } 61 | 62 | n.Time = time.Unix(unix, 0) 63 | n.Valid = true 64 | return nil 65 | } 66 | 67 | func (n Unix) Value() (driver.Value, error) { 68 | if !n.Valid { 69 | return nil, nil 70 | } 71 | 72 | return n.Time.Unix(), nil 73 | } 74 | 75 | // FromTime 从 time.Time 转换而来 76 | func (n *Unix) FromTime(t time.Time) { 77 | n.Valid = true 78 | n.Time = t 79 | } 80 | 81 | func (n Unix) PrimitiveType() core.PrimitiveType { return core.Int64 } 82 | 83 | func (n Unix) MarshalText() ([]byte, error) { 84 | if !n.Valid { 85 | return nil, nil 86 | } 87 | return n.Time.MarshalText() 88 | } 89 | 90 | func (n Unix) MarshalJSON() ([]byte, error) { 91 | if !n.Valid { 92 | return json.Marshal(nil) 93 | } 94 | return n.Time.MarshalJSON() 95 | } 96 | 97 | func (n *Unix) UnmarshalText(data []byte) error { 98 | if n.Valid = len(data) > 0; !n.Valid { 99 | return nil 100 | } 101 | return n.Time.UnmarshalText(data) 102 | } 103 | 104 | func (n *Unix) UnmarshalJSON(data []byte) error { 105 | if n.Valid = len(data) > 0; !n.Valid { 106 | return nil 107 | } 108 | return n.Time.UnmarshalJSON(data) 109 | } 110 | -------------------------------------------------------------------------------- /types/unix_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package types 6 | 7 | import ( 8 | "database/sql" 9 | "database/sql/driver" 10 | "encoding" 11 | "encoding/json" 12 | "testing" 13 | "time" 14 | 15 | "github.com/issue9/assert/v4" 16 | 17 | "github.com/issue9/orm/v6/core" 18 | ) 19 | 20 | var ( 21 | _ sql.Scanner = &Unix{} 22 | _ driver.Valuer = Unix{} 23 | _ core.PrimitiveTyper = &Unix{} 24 | 25 | _ encoding.BinaryMarshaler = Unix{} 26 | _ encoding.TextMarshaler = Unix{} 27 | _ json.Marshaler = Unix{} 28 | _ encoding.BinaryUnmarshaler = &Unix{} 29 | _ encoding.TextUnmarshaler = &Unix{} 30 | _ json.Unmarshaler = &Unix{} 31 | ) 32 | 33 | func TestUnix_Scan(t *testing.T) { 34 | a := assert.New(t, false) 35 | 36 | u := &Unix{} 37 | a.NotError(u.Scan(int64(1))). 38 | Equal(1, u.Time.Unix()) 39 | 40 | u = &Unix{} 41 | a.NotError(u.Scan("123")). 42 | Equal(123, u.Time.Unix()) 43 | 44 | u = &Unix{} 45 | a.NotError(u.Scan([]byte("123"))). 46 | Equal(123, u.Time.Unix()). 47 | True(u.Valid) 48 | 49 | u = &Unix{} 50 | a.NotError(u.Scan(nil)). 51 | False(u.Valid) 52 | 53 | // 无法解析的值 54 | u = &Unix{} 55 | a.Error(u.Scan(int32(1))) 56 | u = &Unix{} 57 | a.Error(u.Scan("123x")) 58 | 59 | // 无效的类型 60 | u = &Unix{} 61 | a.Error(u.Scan(int32(1))) 62 | u = &Unix{} 63 | a.Error(u.Scan(&struct{ X int }{X: 5})) 64 | 65 | u = &Unix{} 66 | a.NotError(u.Scan(nil)) 67 | } 68 | 69 | func TestUnix_Unmarshal(t *testing.T) { 70 | a := assert.New(t, false) 71 | 72 | now := time.Now() 73 | format := now.Format(time.RFC3339) 74 | j := `{"u":"` + format + `"}` 75 | 76 | obj := struct { 77 | U Unix `json:"u"` 78 | }{} 79 | a.NotError(json.Unmarshal([]byte(j), &obj)) 80 | a.Equal(now.Unix(), obj.U.Unix()) 81 | 82 | jj, err := json.Marshal(obj) 83 | a.NotError(err).Equal(string(jj), j) 84 | } 85 | -------------------------------------------------------------------------------- /types_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package orm_test 6 | 7 | import ( 8 | "testing" 9 | "time" 10 | 11 | "github.com/issue9/assert/v4" 12 | 13 | "github.com/issue9/orm/v6" 14 | "github.com/issue9/orm/v6/internal/test" 15 | ) 16 | 17 | var ( 18 | _ orm.Engine = &orm.DB{} 19 | _ orm.Engine = &orm.Tx{} 20 | ) 21 | 22 | type beforeObject1 struct { 23 | ID int64 `orm:"name(id);ai"` 24 | Name string `orm:"name(name);len(24)"` 25 | } 26 | 27 | type beforeObject2 struct { 28 | ID int64 `orm:"name(id);ai"` 29 | Name string `orm:"name(name);len(24)"` 30 | } 31 | 32 | var ( 33 | _ orm.BeforeInserter = &beforeObject1{} 34 | _ orm.BeforeUpdater = &beforeObject1{} 35 | ) 36 | 37 | func (o *beforeObject1) TableName() string { return "objects1" } 38 | 39 | func (o *beforeObject1) BeforeInsert() error { 40 | o.Name = "insert-" + o.Name 41 | return nil 42 | } 43 | 44 | func (o *beforeObject1) BeforeUpdate() error { 45 | o.Name = "update-" + o.Name 46 | return nil 47 | } 48 | 49 | var ( 50 | _ orm.BeforeInserter = &beforeObject1{} 51 | _ orm.BeforeUpdater = &beforeObject1{} 52 | ) 53 | 54 | func (o *beforeObject2) TableName() string { return "objects2" } 55 | 56 | func (o *beforeObject2) BeforeInsert() error { 57 | o.Name = "insert-" + o.Name 58 | return nil 59 | } 60 | 61 | func (o *beforeObject2) BeforeUpdate() error { 62 | o.Name = "update-" + o.Name 63 | return nil 64 | } 65 | 66 | func TestBeforeCreateUpdate(t *testing.T) { 67 | a := assert.New(t, false) 68 | suite := test.NewSuite(a, "") 69 | 70 | suite.Run(func(t *test.Driver) { 71 | // create 72 | t.NotError(t.DB.Create(&beforeObject1{})) 73 | defer func() { 74 | t.NotError(t.DB.Drop(&beforeObject1{})) 75 | }() 76 | 77 | // insert 78 | o := &beforeObject1{Name: "name1"} 79 | _, err := t.DB.Insert(o) 80 | t.NotError(err) 81 | o = &beforeObject1{ID: 1} 82 | found, err := t.DB.Select(o) 83 | t.NotError(err).True(found) 84 | t.Equal(o.Name, "insert-name1") 85 | 86 | // update 87 | o = &beforeObject1{ID: 1, Name: "name11"} 88 | _, err = t.DB.Update(o) 89 | t.NotError(err) 90 | o = &beforeObject1{ID: 1} 91 | found, err = t.DB.Select(o) 92 | t.NotError(err).True(found) 93 | t.Equal(o.Name, "update-name11") 94 | }) 95 | } 96 | 97 | func TestNow(t *testing.T) { 98 | a := assert.New(t, false) 99 | now := time.Now() 100 | 101 | n1 := orm.NowUnix() 102 | a.True(n1.Time.After(now)). 103 | False(n1.Valid) 104 | 105 | n2 := orm.NowNullTime() 106 | a.True(n2.Time.After(now)). 107 | True(n2.Valid) 108 | } 109 | -------------------------------------------------------------------------------- /upgrade_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package orm_test 6 | 7 | import ( 8 | "testing" 9 | 10 | "github.com/issue9/assert/v4" 11 | 12 | "github.com/issue9/orm/v6/core" 13 | "github.com/issue9/orm/v6/internal/test" 14 | ) 15 | 16 | type u2 struct { 17 | ID int64 `orm:"name(id);unique(u_id)"` 18 | Name string `orm:"name(name);len(50);index(index_name)"` 19 | UserName string `orm:"name(username);len(50);pk"` 20 | Modified int64 `orm:"name(modified);default(0)"` 21 | Created string `orm:"name(created);nullable"` 22 | } 23 | 24 | func (u *u2) TableName() string { return "upgrades" } 25 | 26 | func (u *u2) ApplyModel(m *core.Model) error { 27 | return m.NewCheck("chk_username", "{username} IS NOT NULL") 28 | } 29 | 30 | func TestUpgrader(t *testing.T) { 31 | a := assert.New(t, false) 32 | suite := test.NewSuite(a, "") 33 | 34 | suite.Run(func(t *test.Driver) { 35 | sql := t.DB.SQLBuilder().CreateTable(). 36 | Column("id", core.Int64, false, false, false, nil). 37 | Column("name", core.String, false, false, false, nil, 50). 38 | Column("username", core.String, false, false, false, nil, 50). 39 | Column("created", core.Int64, false, false, false, nil). 40 | Index(core.IndexDefault, "i_name", "name"). 41 | Unique("u_username", "username"). 42 | Check("chk_id", "id>0"). 43 | Table((&u2{}).TableName()) 44 | t.NotError(sql.Exec()) 45 | 46 | defer func(n string) { 47 | t.NotError(t.DB.Drop(&u2{})) 48 | }(t.DriverName) 49 | 50 | u, err := t.DB.Upgrade(&u2{}) 51 | t.NotError(err, "%s@%s", err, t.DriverName). 52 | NotNil(u) 53 | 54 | err = u.DropConstraint("u_username", "chk_id"). 55 | DropIndex("i_name"). 56 | DropColumn("created"). 57 | AddColumn("modified"). 58 | AddConstraint("u_id", "chk_username", "u2_pk"). 59 | AddIndex("index_name"). 60 | AddColumn("created"). // 同名不同类型 61 | Do() 62 | t.NotError(err, "%s@%s", err, t.DriverName) 63 | }) 64 | } 65 | -------------------------------------------------------------------------------- /where.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package orm 6 | 7 | import ( 8 | "database/sql" 9 | "fmt" 10 | "reflect" 11 | 12 | "github.com/issue9/orm/v6/core" 13 | "github.com/issue9/orm/v6/sqlbuilder" 14 | ) 15 | 16 | type whereWhere = sqlbuilder.WhereStmtOf[*WhereStmt] 17 | 18 | type WhereStmt struct { 19 | *whereWhere 20 | engine Engine 21 | } 22 | 23 | func (db *DB) Where(cond string, args ...any) *WhereStmt { 24 | w := &WhereStmt{engine: db} 25 | w.whereWhere = sqlbuilder.NewWhereStmtOf(w) 26 | return w.Where(cond, args...) 27 | } 28 | 29 | func (tx *Tx) Where(cond string, args ...any) *WhereStmt { 30 | w := &WhereStmt{engine: tx} 31 | w.whereWhere = sqlbuilder.NewWhereStmtOf(w) 32 | return w.Where(cond, args...) 33 | } 34 | 35 | func (e *txEngine) Where(cond string, args ...any) *WhereStmt { 36 | w := &WhereStmt{engine: e} 37 | w.whereWhere = sqlbuilder.NewWhereStmtOf(w) 38 | return w.Where(cond, args...) 39 | } 40 | 41 | // Delete 从 v 表中删除符合条件的内容 42 | func (stmt *WhereStmt) Delete(v TableNamer) (sql.Result, error) { 43 | m, err := stmt.engine.newModel(v) 44 | if err != nil { 45 | return nil, err 46 | } 47 | 48 | if m.Type == core.View { 49 | return nil, fmt.Errorf("模型 %s 的类型是视图,无法从其中删除数据", m.Name) 50 | } 51 | 52 | return stmt.WhereStmt().Delete(stmt.engine).Table(m.Name).Exec() 53 | } 54 | 55 | // Update 将 v 中内容更新到符合条件的行中 56 | // 57 | // 不会更新零值,除非通过 cols 指定了该列。 58 | // 表名来自 v,列名为 v 的所有列或是 cols 指定的列。 59 | func (stmt *WhereStmt) Update(v TableNamer, cols ...string) (sql.Result, error) { 60 | upd := stmt.WhereStmt().Update(stmt.engine) 61 | 62 | if _, _, err := getUpdateColumns(stmt.engine, v, upd, cols...); err != nil { 63 | return nil, err 64 | } 65 | 66 | return upd.Exec() 67 | } 68 | 69 | // Select 获取所有符合条件的数据 70 | // 71 | // v 可能是某个对象的指针,或是一组相同对象指针数组。表名来自 v,列名为 v 的所有列。 72 | func (stmt *WhereStmt) Select(strict bool, v any) (int, error) { 73 | t := reflect.TypeOf(v) 74 | for t.Kind() == reflect.Ptr || t.Kind() == reflect.Slice || t.Kind() == reflect.Array { 75 | t = t.Elem() 76 | } 77 | 78 | tn, ok := reflect.New(t).Interface().(TableNamer) 79 | if !ok { 80 | return 0, fmt.Errorf("v 不是 TableNamer 类型") 81 | } 82 | m, err := stmt.engine.newModel(tn) 83 | if err != nil { 84 | return 0, err 85 | } 86 | 87 | return stmt.WhereStmt().Select(stmt.engine). 88 | Column("*"). 89 | From(m.Name). 90 | QueryObject(strict, v) 91 | } 92 | 93 | // Count 返回符合条件数量 94 | // 95 | // 表名来自 v。 96 | func (stmt *WhereStmt) Count(v TableNamer) (int64, error) { 97 | m, _, err := getModel(stmt.engine, v) 98 | if err != nil { 99 | return 0, err 100 | } 101 | 102 | return stmt.WhereStmt().Select(stmt.engine). 103 | Count("count(*) as cnt"). 104 | From(m.Name). 105 | QueryInt("cnt") 106 | } 107 | -------------------------------------------------------------------------------- /where_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package orm_test 6 | 7 | import ( 8 | "testing" 9 | 10 | "github.com/issue9/assert/v4" 11 | 12 | "github.com/issue9/orm/v6/internal/test" 13 | ) 14 | 15 | func TestWhereStmt_Delete(t *testing.T) { 16 | a := assert.New(t, false) 17 | suite := test.NewSuite(a, "") 18 | 19 | suite.Run(func(t *test.Driver) { 20 | initData(t) 21 | defer clearData(t) 22 | 23 | // delete 24 | r, err := t.DB.Where("uid=?", 1).Delete(&UserInfo{}) 25 | t.NotError(err) 26 | cnt, err := r.RowsAffected() 27 | t.NotError(err). 28 | Equal(cnt, 1) 29 | 30 | r, err = t.DB.Where("last_name=?", "l2").And("first_name=?", "f2").Delete(&UserInfo{}) 31 | t.NotError(err) 32 | cnt, err = r.RowsAffected() 33 | t.NotError(err). 34 | Equal(cnt, 1) 35 | 36 | r, err = t.DB.Where("email=?", "email1").Delete(&Admin{Email: "e"}) 37 | t.NotError(err) 38 | cnt, err = r.RowsAffected() 39 | t.NotError(err). 40 | Equal(cnt, 1) 41 | 42 | hasCount(t.DB, t.Assertion, "user_info", 0) 43 | hasCount(t.DB, t.Assertion, "administrators", 0) 44 | }) 45 | } 46 | 47 | func TestWhereStmt_Update(t *testing.T) { 48 | a := assert.New(t, false) 49 | suite := test.NewSuite(a, "") 50 | 51 | suite.Run(func(t *test.Driver) { 52 | initData(t) 53 | defer clearData(t) 54 | 55 | r, err := t.DB.Where("last_name=?", "l2"). 56 | And("first_name=?", "f2"). 57 | Update(&UserInfo{ 58 | FirstName: "firstName2", 59 | LastName: "lastName2", 60 | Sex: "sex2", 61 | }) 62 | t.NotError(err) 63 | cnt, err := r.RowsAffected() 64 | t.NotError(err). 65 | Equal(cnt, 1) 66 | 67 | r, err = t.DB.Where("email=?", "email1").Update(&Admin{Email: "email1111"}) 68 | t.NotError(err) 69 | cnt, err = r.RowsAffected() 70 | t.NotError(err). 71 | Equal(cnt, 1) 72 | 73 | u2 := &UserInfo{LastName: "lastName2", FirstName: "firstName2"} 74 | found, err := t.DB.Select(u2) 75 | t.NotError(err).True(found) 76 | t.Equal(u2, &UserInfo{UID: 2, FirstName: "firstName2", LastName: "lastName2", Sex: "sex2"}) 77 | 78 | admin := &Admin{Email: "email1111"} 79 | found, err = t.DB.Select(admin) 80 | t.NotError(err).True(found) 81 | t.Equal(admin, &Admin{User: User{ID: 1, Username: "username1", Password: "password1"}, Email: "email1111", Group: 1}) 82 | }) 83 | } 84 | 85 | func TestWhereStmt_Select(t *testing.T) { 86 | a := assert.New(t, false) 87 | suite := test.NewSuite(a, "") 88 | 89 | suite.Run(func(t *test.Driver) { 90 | initData(t) 91 | defer clearData(t) 92 | 93 | u := &UserInfo{} 94 | cnt, err := t.DB.Where("uid>=?", 1).Select(true, u) 95 | a.NotError(err). 96 | Equal(cnt, 1). 97 | Equal(u.UID, 1). 98 | Equal(u.FirstName, "f1") 99 | 100 | us := make([]*UserInfo, 0) 101 | cnt, err = t.DB.Where("uid>=?", 1).Select(true, &us) 102 | a.NotError(err). 103 | Equal(cnt, 2). 104 | Equal(us[0].UID, 1). 105 | Equal(us[0].FirstName, "f1"). 106 | Equal(us[1].UID, 2). 107 | Equal(us[1].FirstName, "f2") 108 | 109 | type ui struct { 110 | UserInfo 111 | } 112 | 113 | items := []any{ 114 | &UserInfo{}, 115 | &ui{}, 116 | } 117 | cnt, err = t.DB.Where("uid>=?", 1).Select(true, &items) 118 | a.Error(err).Empty(cnt) 119 | }) 120 | } 121 | 122 | func TestWhereStmt_Count(t *testing.T) { 123 | a := assert.New(t, false) 124 | suite := test.NewSuite(a, "") 125 | 126 | suite.Run(func(t *test.Driver) { 127 | initData(t) 128 | defer clearData(t) 129 | 130 | // 单条件 131 | count, err := t.DB.Where("uid=?", 1).Count(&UserInfo{}) 132 | t.NotError(err). 133 | Equal(1, count) 134 | 135 | // 无条件 136 | count, err = t.DB.Where("1=1").Count(&UserInfo{}) 137 | t.NotError(err). 138 | Equal(2, count) 139 | 140 | // 条件不存在 141 | count, err = t.DB.Where("email=?", "email1-not-exists").Count(&Admin{}) 142 | t.NotError(err). 143 | Equal(0, count) 144 | }) 145 | } 146 | --------------------------------------------------------------------------------