├── .github
├── dependabot.yml
└── workflows
│ └── tests.yml
├── .gitignore
├── LICENSE
├── README.md
├── conn_pool.go
├── dialector.go
├── examples
└── order.go
├── go.mod
├── go.sum
├── primary_key.go
├── primary_key_test.go
├── sharding.go
└── sharding_test.go
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | ---
2 | version: 2
3 | updates:
4 | - package-ecosystem: gomod
5 | directory: /
6 | schedule:
7 | interval: weekly
8 | - package-ecosystem: github-actions
9 | directory: /
10 | schedule:
11 | interval: weekly
12 |
--------------------------------------------------------------------------------
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
1 | name: Tests
2 | on:
3 | push:
4 | branches:
5 | - "main"
6 | tags:
7 | - "v*"
8 | pull_request:
9 | jobs:
10 | postgres:
11 | strategy:
12 | matrix:
13 | dbversion:
14 | [
15 | "postgres:latest",
16 | "postgres:13",
17 | "postgres:12",
18 | "postgres:11",
19 | "postgres:10",
20 | ]
21 | platform: [ubuntu-latest] # cannot run on macOS or Windows
22 | runs-on: ${{ matrix.platform }}
23 |
24 | services:
25 | postgres:
26 | image: ${{ matrix.dbversion }}
27 | env:
28 | POSTGRES_DB: sharding-test
29 | POSTGRES_USER: gorm
30 | POSTGRES_PASSWORD: gorm
31 | TZ: Asia/Shanghai
32 | ports:
33 | - 5432:5432
34 | # Set health checks to wait until postgres has started
35 | options: >-
36 | --health-cmd pg_isready
37 | --health-interval 10s
38 | --health-timeout 5s
39 | --health-retries 5
40 |
41 | env:
42 | DIALECTOR: postgres
43 | DB_URL: postgres://gorm:gorm@localhost:5432/sharding-test
44 | DB_NOID_URL: postgres://gorm:gorm@localhost:5432/sharding-noid-test
45 | DB_READ_URL: postgres://gorm:gorm@localhost:5432/sharding-read-test
46 | DB_WRITE_URL: postgres://gorm:gorm@localhost:5432/sharding-write-test
47 | steps:
48 | - name: Set up Go
49 | uses: actions/setup-go@v5
50 | with:
51 | go-version: "1.20"
52 | id: go
53 |
54 | - name: Create No ID Database
55 | run: PGPASSWORD=gorm psql -h localhost -U gorm -d sharding-test -c 'CREATE DATABASE "sharding-noid-test";'
56 |
57 | - name: Create Read Database
58 | run: PGPASSWORD=gorm psql -h localhost -U gorm -d sharding-test -c 'CREATE DATABASE "sharding-read-test";'
59 |
60 | - name: Create Write Databases
61 | run: PGPASSWORD=gorm psql -h localhost -U gorm -d sharding-test -c 'CREATE DATABASE "sharding-write-test";'
62 |
63 | - name: Check out code into the Go module directory
64 | uses: actions/checkout@v4
65 |
66 | - name: Get dependencies
67 | run: |
68 | go get -v -t -d ./...
69 |
70 | - name: Test
71 | run: go test
72 | mysql:
73 | name: MySQL
74 |
75 | strategy:
76 | matrix:
77 | dbversion: ["mysql:latest", "mysql:5.7"]
78 | platform: [ubuntu-latest]
79 | runs-on: ${{ matrix.platform }}
80 |
81 | services:
82 | mysql:
83 | image: ${{ matrix.dbversion }}
84 | env:
85 | MYSQL_DATABASE: sharding-test
86 | MYSQL_USER: gorm
87 | MYSQL_PASSWORD: gorm
88 | MYSQL_ROOT_PASSWORD: gorm
89 | ports:
90 | - 3306:3306
91 | options: >-
92 | --health-cmd "mysqladmin ping -ugorm -pgorm"
93 | --health-interval 10s
94 | --health-start-period 10s
95 | --health-timeout 5s
96 | --health-retries 10
97 |
98 | env:
99 | DIALECTOR: mysql
100 | DB_URL: gorm:gorm@tcp(127.0.0.1:3306)/sharding-test?charset=utf8mb4&parseTime=True&loc=Local
101 | DB_NOID_URL: root:gorm@tcp(127.0.0.1:3306)/sharding-noid-test?charset=utf8mb4&parseTime=True&loc=Local
102 | DB_READ_URL: root:gorm@tcp(127.0.0.1:3306)/sharding-read-test?charset=utf8mb4&parseTime=True&loc=Local
103 | DB_WRITE_URL: root:gorm@tcp(127.0.0.1:3306)/sharding-write-test?charset=utf8mb4&parseTime=True&loc=Local
104 | steps:
105 | - name: Set up Go
106 | uses: actions/setup-go@v5
107 | with:
108 | go-version: "1.20"
109 | id: go
110 |
111 | - name: Create No ID Database
112 | run: mysqladmin -h 127.0.0.1 -uroot -pgorm create sharding-noid-test
113 |
114 | - name: Create Read Database
115 | run: mysqladmin -h 127.0.0.1 -uroot -pgorm create sharding-read-test
116 |
117 | - name: Create Write Database
118 | run: mysqladmin -h 127.0.0.1 -uroot -pgorm create sharding-write-test
119 |
120 | - name: Check out code into the Go module directory
121 | uses: actions/checkout@v4
122 |
123 | - name: Get dependencies
124 | run: |
125 | go get -v -t -d ./...
126 |
127 | - name: Test
128 | run: go test
129 | mariadb:
130 | name: MariaDB
131 |
132 | strategy:
133 | matrix:
134 | dbversion: ["mariadb:10.11"]
135 | platform: [ubuntu-latest]
136 | runs-on: ${{ matrix.platform }}
137 |
138 | services:
139 | mariadb:
140 | image: ${{ matrix.dbversion }}
141 | env:
142 | MYSQL_DATABASE: sharding-test
143 | MYSQL_USER: gorm
144 | MYSQL_PASSWORD: gorm
145 | MYSQL_ROOT_PASSWORD: gorm
146 | ports:
147 | - 3306:3306
148 | options: >-
149 | --health-cmd "mysqladmin ping -ugorm -pgorm"
150 | --health-interval 10s
151 | --health-start-period 10s
152 | --health-timeout 5s
153 | --health-retries 10
154 |
155 | env:
156 | DIALECTOR: mariadb
157 | DB_URL: gorm:gorm@tcp(127.0.0.1:3306)/sharding-test?charset=utf8mb4&parseTime=True&loc=Local
158 | DB_NOID_URL: root:gorm@tcp(127.0.0.1:3306)/sharding-noid-test?charset=utf8mb4&parseTime=True&loc=Local
159 | DB_READ_URL: root:gorm@tcp(127.0.0.1:3306)/sharding-read-test?charset=utf8mb4&parseTime=True&loc=Local
160 | DB_WRITE_URL: root:gorm@tcp(127.0.0.1:3306)/sharding-write-test?charset=utf8mb4&parseTime=True&loc=Local
161 | steps:
162 | - name: Set up Go
163 | uses: actions/setup-go@v5
164 | with:
165 | go-version: "1.20"
166 | id: go
167 |
168 | - name: Create No ID Database
169 | run: mysqladmin -h 127.0.0.1 -uroot -pgorm create sharding-noid-test
170 |
171 | - name: Create Read Database
172 | run: mysqladmin -h 127.0.0.1 -uroot -pgorm create sharding-read-test
173 |
174 | - name: Create Write Database
175 | run: mysqladmin -h 127.0.0.1 -uroot -pgorm create sharding-write-test
176 |
177 | - name: Check out code into the Go module directory
178 | uses: actions/checkout@v4
179 |
180 | - name: Get dependencies
181 | run: |
182 | go get -v -t -d ./...
183 |
184 | - name: Test
185 | run: go test
186 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | .idea
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Longbridge
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Gorm Sharding
2 |
3 | [![Test status](https://github.com/go-gorm/sharding/actions/workflows/tests.yml/badge.svg)](https://github.com/go-gorm/sharding/actions/workflows/tests.yml)
4 |
5 | Gorm Sharding is a plugin that uses an SQL parser and query rewriting to split large tables into smaller ones and redirect queries to the sharded tables, giving you high-performance database access.
6 |
7 | Gorm Sharding 是一个高性能的数据库分表中间件。
8 |
9 | 它基于 Conn 层做 SQL 拦截、AST 解析、分表路由、自增主键填充,带来的额外开销极小。对开发者友好、透明,使用上与普通 SQL、Gorm 查询无差别,只需要额外注意一下分表键条件。
10 |
11 | ## Features
12 |
13 | - Non-intrusive design. Load the plugin, specify the config, and all done.
14 | - Lightning-fast. No network-based middleware; as fast as Go.
15 | - Multiple database (PostgreSQL, MySQL) support.
16 | - Integrated primary key generator (Snowflake, PostgreSQL Sequence, Custom, ...).
17 |
18 | ## Install
19 |
20 | ```bash
21 | go get -u gorm.io/sharding
22 | ```
23 |
24 | ## Usage
25 |
26 | Configure the sharding middleware and register the tables you want to shard.
27 |
28 | ```go
29 | import (
30 | "fmt"
31 |
32 | "gorm.io/driver/postgres"
33 | "gorm.io/gorm"
34 | "gorm.io/sharding"
35 | )
36 |
37 | db, err := gorm.Open(postgres.New(postgres.Config{DSN: "postgres://localhost:5432/sharding-db?sslmode=disable"}))
38 |
39 | db.Use(sharding.Register(sharding.Config{
40 | ShardingKey: "user_id",
41 | NumberOfShards: 64,
42 | PrimaryKeyGenerator: sharding.PKSnowflake,
43 | }, "orders", Notification{}, AuditLog{}))
44 | // This also applies the same sharding rule to the notifications and audit_logs tables.
45 | ```
46 |
47 | Use the db session as usual. Just note that queries on sharded tables must include the `Sharding Key`.
48 |
49 | ```go
50 | // Gorm create example, this will insert to orders_02
51 | db.Create(&Order{UserID: 2})
52 | // sql: INSERT INTO orders_02 ...
53 |
54 | // Raw SQL inserts are supported too; this will insert into orders_03
55 | db.Exec("INSERT INTO orders(user_id) VALUES(?)", int64(3))
56 |
57 | // This will return an ErrMissingShardingKey error, because no sharding key is present.
58 | err = db.Create(&Order{Amount: 10, ProductID: 100}).Error
59 | fmt.Println(err)
60 |
61 | // Find, this will redirect query to orders_02
62 | var orders []Order
63 | db.Model(&Order{}).Where("user_id", int64(2)).Find(&orders)
64 | fmt.Printf("%#v\n", orders)
65 |
66 | // Raw SQL also supported
67 | db.Raw("SELECT * FROM orders WHERE user_id = ?", int64(3)).Scan(&orders)
68 | fmt.Printf("%#v\n", orders)
69 |
70 | // This will return an ErrMissingShardingKey error, because the WHERE conditions do not include the sharding key
71 | err = db.Model(&Order{}).Where("product_id", "1").Find(&orders).Error
72 | fmt.Println(err)
73 |
74 | // Update and Delete are similar to create and query
75 | db.Exec("UPDATE orders SET product_id = ? WHERE user_id = ?", 2, int64(3))
76 | err = db.Exec("DELETE FROM orders WHERE product_id = 3").Error
77 | fmt.Println(err) // ErrMissingShardingKey
78 | ```
79 |
80 | The full example is [here](./examples/order.go).
81 |
82 | > 🚨 NOTE: Gorm config `PrepareStmt: true` is not supported for now.
83 | >
84 | > 🚨 NOTE: The default snowflake generator may produce conflicting primary keys when running on multiple nodes; use your own custom primary key generator, or regenerate a primary key when a conflict occurs.
85 |
86 | ## Primary Key
87 |
88 | When you shard tables, you need to consider how primary keys are generated.
89 |
90 | Recommended options:
91 |
92 | - [Snowflake](https://github.com/bwmarrin/snowflake)
93 | - [Database sequence (managed manually)](https://www.postgresql.org/docs/current/sql-createsequence.html)
94 |
95 | ### Use Snowflake
96 |
97 | Built-in Snowflake primary key generator.
98 |
99 | ```go
100 | db.Use(sharding.Register(sharding.Config{
101 | ShardingKey: "user_id",
102 | NumberOfShards: 64,
103 | PrimaryKeyGenerator: sharding.PKSnowflake,
104 | }, "orders"))
105 | ```
106 |
107 | ### Use PostgreSQL Sequence
108 |
109 | Gorm Sharding has a built-in PostgreSQL sequence primary key implementation; just configure `PrimaryKeyGenerator: sharding.PKPGSequence` to use it.
110 |
111 | You don't need to create the sequence manually; Gorm Sharding checks whether the PostgreSQL sequence exists and creates it if it does not.
112 |
113 | The sequence name follows the pattern `gorm_sharding_${table_name}_id_seq`; for example, for the `orders` table the sequence name is `gorm_sharding_orders_id_seq`.
114 |
115 | ```go
116 | db.Use(sharding.Register(sharding.Config{
117 | ShardingKey: "user_id",
118 | NumberOfShards: 64,
119 | PrimaryKeyGenerator: sharding.PKPGSequence,
120 | }, "orders"))
121 | ```
122 |
123 | ### Use MySQL Sequence
124 |
125 | Gorm Sharding has a built-in MySQL sequence primary key implementation; just configure `PrimaryKeyGenerator: sharding.PKMySQLSequence` to use it.
126 |
127 | You don't need to create the sequence manually; Gorm Sharding checks whether the MySQL sequence exists and creates it if it does not.
128 |
129 | The sequence name follows the pattern `gorm_sharding_${table_name}_id_seq`; for example, for the `orders` table the sequence name is `gorm_sharding_orders_id_seq`.
130 |
131 | ```go
132 | db.Use(sharding.Register(sharding.Config{
133 | ShardingKey: "user_id",
134 | NumberOfShards: 64,
135 | PrimaryKeyGenerator: sharding.PKMySQLSequence,
136 | }, "orders"))
137 | ```
138 |
139 | ### No primary key
140 |
141 | If your table doesn't have a primary key, has a primary key that isn't called `id`, or you otherwise don't want the `id` field to be auto-filled, set `PrimaryKeyGenerator` to `PKCustom` and have `PrimaryKeyGeneratorFn` return `0`.
142 |
143 | ## Combining with dbresolver
144 |
145 | > 🚨 NOTE: Register dbresolver before the sharding plugin.
146 |
147 | ```go
148 | dsn := "host=localhost user=gorm password=gorm dbname=gorm port=5432 sslmode=disable"
149 | dsnRead := "host=localhost user=gorm password=gorm dbname=gorm-slave port=5432 sslmode=disable"
150 |
151 | conn := postgres.Open(dsn)
152 | connRead := postgres.Open(dsnRead)
153 |
154 | db, err := gorm.Open(conn, &gorm.Config{})
155 | dbRead, err := gorm.Open(connRead, &gorm.Config{})
156 |
157 | db.Use(dbresolver.Register(dbresolver.Config{
158 | Replicas: []gorm.Dialector{dbRead.Dialector},
159 | }))
160 |
161 | db.Use(sharding.Register(sharding.Config{
162 | ShardingKey: "user_id",
163 | NumberOfShards: 64,
164 | PrimaryKeyGenerator: sharding.PKSnowflake,
165 | }))
166 | ```
167 |
168 | ## Sharding process
169 |
170 | This graph shows how Gorm Sharding works.
171 |
172 | ```mermaid
173 | graph TD
174 | first("SELECT * FROM orders WHERE user_id = ? AND status = ?
175 | args = [100, 1]")
176 |
177 | first--->gorm(["Gorm Query"])
178 |
179 | subgraph "Gorm"
180 | gorm--->gorm_query
181 | gorm--->gorm_exec
182 | gorm--->gorm_queryrow
183 | gorm_query["connPool.QueryContext(sql, args)"]
184 | gorm_exec[/"connPool.ExecContext"/]
185 | gorm_queryrow[/"connPool.QueryRowContext"/]
186 | end
187 |
188 | subgraph "database/sql"
189 | gorm_query-->conn(["Conn"])
190 | gorm_exec-->conn(["Conn"])
191 | gorm_queryrow-->conn(["Conn"])
192 | ExecContext[/"ExecContext"/]
193 | QueryContext[/"QueryContext"/]
194 | QueryRowContext[/"QueryRowContext"/]
195 |
196 |
197 | conn-->ExecContext
198 | conn-->QueryRowContext
199 | conn-->QueryContext
200 | end
201 |
202 | subgraph sharding ["Sharding"]
203 | QueryContext-->router-->| Format to get full SQL string |format_sql-->| Parser to AST |parse-->check_table
204 | router[["router(sql, args)
"]]
205 | format_sql>"sql = SELECT * FROM orders WHERE user_id = 100 AND status = 1"]
206 |
207 | check_table{"Check sharding rules
by table name"}
208 | check_table-->| Exist |process_ast
209 | check_table_1{{"Return Raw SQL"}}
210 | not_match_error[/"Return Error
SQL query must has sharding key"\]
211 |
212 | parse[["ast = sqlparser.Parse(sql)"]]
213 |
214 | check_table-.->| Not exist |check_table_1
215 | process_ast(("Sharding rules"))
216 | get_new_table_name[["Use value in WhereValue (100) for get sharding table index
orders + (100 % 16)
Sharding Table = orders_4"]]
217 | new_sql{{"SELECT * FROM orders_4 WHERE user_id = 100 AND status = 1"}}
218 |
219 | process_ast-.->| Not match ShardingKey |not_match_error
220 | process_ast-->| Match ShardingKey |match_sharding_key-->| Get table name |get_new_table_name-->| Replace TableName to get new SQL |new_sql
221 | end
222 |
223 |
224 | subgraph database [Database]
225 | orders_other[("orders_0, orders_1 ... orders_3")]
226 | orders_4[(orders_4)]
227 | orders_last[("orders_5 ... orders_15")]
228 | other_tables[(Other non-sharding tables
users, stocks, topics ...)]
229 |
230 | new_sql-->| Sharding Query | orders_4
231 | check_table_1-.->| None sharding Query |other_tables
232 | end
233 |
234 | orders_4-->result
235 | other_tables-.->result
236 | result[/Query results\]
237 | ```
238 |
239 | ## License
240 |
241 | MIT license.
242 |
243 | Originally forked from [Longbridge](https://github.com/longbridgeapp/gorm-sharding).
244 |
--------------------------------------------------------------------------------
/conn_pool.go:
--------------------------------------------------------------------------------
1 | package sharding
2 |
3 | import (
4 | "context"
5 | "database/sql"
6 | "time"
7 |
8 | "gorm.io/gorm"
9 | )
10 |
11 | // ConnPool Implement a ConnPool for replace db.Statement.ConnPool in Gorm
12 | type ConnPool struct {
13 | // db, This is global db instance
14 | sharding *Sharding
15 | gorm.ConnPool
16 | }
17 |
18 | func (pool *ConnPool) String() string {
19 | return "gorm:sharding:conn_pool"
20 | }
21 |
22 | func (pool ConnPool) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
23 | return pool.ConnPool.PrepareContext(ctx, query)
24 | }
25 |
26 | func (pool ConnPool) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
27 | var (
28 | curTime = time.Now()
29 | )
30 |
31 | ftQuery, stQuery, table, err := pool.sharding.resolve(query, args...)
32 | if err != nil {
33 | return nil, err
34 | }
35 |
36 | pool.sharding.querys.Store("last_query", stQuery)
37 |
38 | if table != "" {
39 | if r, ok := pool.sharding.configs[table]; ok {
40 | if r.DoubleWrite {
41 | pool.sharding.Logger.Trace(ctx, curTime, func() (sql string, rowsAffected int64) {
42 | result, _ := pool.ConnPool.ExecContext(ctx, ftQuery, args...)
43 | rowsAffected, _ = result.RowsAffected()
44 | return pool.sharding.Explain(ftQuery, args...), rowsAffected
45 | }, pool.sharding.Error)
46 | }
47 | }
48 | }
49 |
50 | var result sql.Result
51 | result, err = pool.ConnPool.ExecContext(ctx, stQuery, args...)
52 | pool.sharding.Logger.Trace(ctx, curTime, func() (sql string, rowsAffected int64) {
53 | rowsAffected, _ = result.RowsAffected()
54 | return pool.sharding.Explain(stQuery, args...), rowsAffected
55 | }, pool.sharding.Error)
56 |
57 | return result, err
58 | }
59 |
60 | // https://github.com/go-gorm/gorm/blob/v1.21.11/callbacks/query.go#L18
61 | func (pool ConnPool) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
62 | var (
63 | curTime = time.Now()
64 | )
65 |
66 | _, stQuery, _, err := pool.sharding.resolve(query, args...)
67 | if err != nil {
68 | return nil, err
69 | }
70 |
71 | pool.sharding.querys.Store("last_query", stQuery)
72 |
73 | var rows *sql.Rows
74 | rows, err = pool.ConnPool.QueryContext(ctx, stQuery, args...)
75 | pool.sharding.Logger.Trace(ctx, curTime, func() (sql string, rowsAffected int64) {
76 | return pool.sharding.Explain(stQuery, args...), 0
77 | }, pool.sharding.Error)
78 |
79 | return rows, err
80 | }
81 |
82 | func (pool ConnPool) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
83 | _, query, _, _ = pool.sharding.resolve(query, args...)
84 | pool.sharding.querys.Store("last_query", query)
85 |
86 | return pool.ConnPool.QueryRowContext(ctx, query, args...)
87 | }
88 |
89 | // BeginTx Implement ConnPoolBeginner.BeginTx
90 | func (pool *ConnPool) BeginTx(ctx context.Context, opt *sql.TxOptions) (gorm.ConnPool, error) {
91 | if basePool, ok := pool.ConnPool.(gorm.ConnPoolBeginner); ok {
92 | return basePool.BeginTx(ctx, opt)
93 | }
94 |
95 | return pool, nil
96 | }
97 |
98 | // Implement TxCommitter.Commit
99 | func (pool *ConnPool) Commit() error {
100 | if _, ok := pool.ConnPool.(*sql.Tx); ok {
101 | return nil
102 | }
103 |
104 | if basePool, ok := pool.ConnPool.(gorm.TxCommitter); ok {
105 | return basePool.Commit()
106 | }
107 |
108 | return nil
109 | }
110 |
111 | // Implement TxCommitter.Rollback
112 | func (pool *ConnPool) Rollback() error {
113 | if _, ok := pool.ConnPool.(*sql.Tx); ok {
114 | return nil
115 | }
116 |
117 | if basePool, ok := pool.ConnPool.(gorm.TxCommitter); ok {
118 | return basePool.Rollback()
119 | }
120 |
121 | return nil
122 | }
123 |
124 | func (pool *ConnPool) Ping() error {
125 | return nil
126 | }
127 |
--------------------------------------------------------------------------------
/dialector.go:
--------------------------------------------------------------------------------
1 | package sharding
2 |
3 | import (
4 | "fmt"
5 | "gorm.io/gorm/migrator"
6 | "gorm.io/gorm/schema"
7 |
8 | "gorm.io/gorm"
9 | )
10 |
11 | type ShardingDialector struct {
12 | gorm.Dialector
13 | sharding *Sharding
14 | }
15 |
16 | type ShardingMigrator struct {
17 | gorm.Migrator
18 | sharding *Sharding
19 | dialector gorm.Dialector
20 | }
21 |
22 | func NewShardingDialector(d gorm.Dialector, s *Sharding) ShardingDialector {
23 | return ShardingDialector{
24 | Dialector: d,
25 | sharding: s,
26 | }
27 | }
28 |
29 | func (d ShardingDialector) Migrator(db *gorm.DB) gorm.Migrator {
30 | m := d.Dialector.Migrator(db)
31 | return ShardingMigrator{
32 | Migrator: m,
33 | sharding: d.sharding,
34 | dialector: d.Dialector,
35 | }
36 | }
37 |
38 | func (m ShardingMigrator) AutoMigrate(dst ...any) error {
39 | shardingDsts, noShardingDsts, err := m.splitShardingDsts(dst...)
40 | if err != nil {
41 | return err
42 | }
43 |
44 | stmt := &gorm.Statement{DB: m.sharding.DB}
45 | for _, sd := range shardingDsts {
46 | tx := stmt.DB.Session(&gorm.Session{}).Table(sd.table)
47 | if err := m.dialector.Migrator(tx).AutoMigrate(sd.dst); err != nil {
48 | return err
49 | }
50 |
51 | }
52 |
53 | if len(noShardingDsts) > 0 {
54 | tx := stmt.DB.Session(&gorm.Session{})
55 | tx.Statement.Settings.Store(ShardingIgnoreStoreKey, nil)
56 | defer tx.Statement.Settings.Delete(ShardingIgnoreStoreKey)
57 |
58 | if err := m.dialector.Migrator(tx).AutoMigrate(noShardingDsts...); err != nil {
59 | return err
60 | }
61 | }
62 |
63 | return nil
64 | }
65 |
66 | // BuildIndexOptions build index options
67 | func (m ShardingMigrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
68 | return m.Migrator.(migrator.BuildIndexOptionsInterface).BuildIndexOptions(opts, stmt)
69 | }
70 |
71 | func (m ShardingMigrator) DropTable(dst ...any) error {
72 | shardingDsts, noShardingDsts, err := m.splitShardingDsts(dst...)
73 | if err != nil {
74 | return err
75 | }
76 |
77 | for _, sd := range shardingDsts {
78 | if err := m.Migrator.DropTable(sd.table); err != nil {
79 | return err
80 | }
81 | }
82 |
83 | if len(noShardingDsts) > 0 {
84 | if err := m.Migrator.DropTable(noShardingDsts...); err != nil {
85 | return err
86 | }
87 | }
88 |
89 | return nil
90 | }
91 |
92 | type shardingDst struct {
93 | table string
94 | dst any
95 | }
96 |
97 | // splite sharding or normal dsts
98 | func (m ShardingMigrator) splitShardingDsts(dsts ...any) (shardingDsts []shardingDst,
99 | noShardingDsts []any, err error) {
100 |
101 | shardingDsts = make([]shardingDst, 0)
102 | noShardingDsts = make([]any, 0)
103 | for _, model := range dsts {
104 | stmt := &gorm.Statement{DB: m.sharding.DB}
105 | err = stmt.Parse(model)
106 | if err != nil {
107 | return
108 | }
109 |
110 | if cfg, ok := m.sharding.configs[stmt.Table]; ok {
111 | // support sharding table
112 | suffixs := cfg.ShardingSuffixs()
113 | if len(suffixs) == 0 {
114 | err = fmt.Errorf("sharding table:%s suffixs is empty", stmt.Table)
115 | return
116 | }
117 |
118 | for _, suffix := range suffixs {
119 | shardingTable := stmt.Table + suffix
120 | shardingDsts = append(shardingDsts, shardingDst{
121 | table: shardingTable,
122 | dst: model,
123 | })
124 | }
125 |
126 | if cfg.DoubleWrite {
127 | noShardingDsts = append(noShardingDsts, model)
128 | }
129 | } else {
130 | noShardingDsts = append(noShardingDsts, model)
131 | }
132 | }
133 | return
134 | }
135 |
--------------------------------------------------------------------------------
/examples/order.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "fmt"
5 |
6 | "gorm.io/driver/postgres"
7 | "gorm.io/gorm"
8 | "gorm.io/sharding"
9 | )
10 |
11 | type Order struct {
12 | ID int64 `gorm:"primarykey"`
13 | UserID int64
14 | ProductID int64
15 | }
16 |
17 | func main() {
18 | dsn := "postgres://localhost:5432/sharding-db?sslmode=disable"
19 | db, err := gorm.Open(postgres.New(postgres.Config{DSN: dsn}))
20 | if err != nil {
21 | panic(err)
22 | }
23 |
24 | for i := 0; i < 64; i += 1 {
25 | table := fmt.Sprintf("orders_%02d", i)
26 | db.Exec(`DROP TABLE IF EXISTS ` + table)
27 | db.Exec(`CREATE TABLE ` + table + ` (
28 | id BIGSERIAL PRIMARY KEY,
29 | user_id bigint,
30 | product_id bigint
31 | )`)
32 | }
33 |
34 | middleware := sharding.Register(sharding.Config{
35 | ShardingKey: "user_id",
36 | NumberOfShards: 64,
37 | PrimaryKeyGenerator: sharding.PKSnowflake,
38 | }, "orders")
39 | db.Use(middleware)
40 |
41 | // this record will insert to orders_02
42 | err = db.Create(&Order{UserID: 2}).Error
43 | if err != nil {
44 | fmt.Println(err)
45 | }
46 |
47 | // this record will insert to orders_03
48 | err = db.Exec("INSERT INTO orders(user_id) VALUES(?)", int64(3)).Error
49 | if err != nil {
50 | fmt.Println(err)
51 | }
52 |
53 | // this will throw ErrMissingShardingKey error
54 | err = db.Exec("INSERT INTO orders(product_id) VALUES(1)").Error
55 | fmt.Println(err)
56 |
57 | // this will redirect query to orders_02
58 | var orders []Order
59 | err = db.Model(&Order{}).Where("user_id", int64(2)).Find(&orders).Error
60 | if err != nil {
61 | fmt.Println(err)
62 | }
63 | fmt.Printf("%#v\n", orders)
64 |
65 | // Raw SQL also supported
66 | db.Raw("SELECT * FROM orders WHERE user_id = ?", int64(3)).Scan(&orders)
67 | fmt.Printf("%#v\n", orders)
68 |
69 | // this will throw ErrMissingShardingKey error
70 | err = db.Model(&Order{}).Where("product_id", "1").Find(&orders).Error
71 | fmt.Println(err)
72 |
73 | // Update and Delete are similar to create and query
74 | err = db.Exec("UPDATE orders SET product_id = ? WHERE user_id = ?", 2, int64(3)).Error
75 | fmt.Println(err) // nil
76 | err = db.Exec("DELETE FROM orders WHERE product_id = 3").Error
77 | fmt.Println(err) // ErrMissingShardingKey
78 | }
79 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module gorm.io/sharding
2 |
3 | go 1.21
4 |
5 | require (
6 | github.com/bwmarrin/snowflake v0.3.0
7 | github.com/longbridgeapp/assert v1.1.0
8 | github.com/longbridgeapp/sqlparser v0.3.1
9 | golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63
10 | gorm.io/driver/mysql v1.5.1
11 | gorm.io/driver/postgres v1.5.2
12 | gorm.io/gorm v1.25.4
13 | gorm.io/hints v1.1.2
14 | gorm.io/plugin/dbresolver v1.5.1
15 | )
16 |
17 | require (
18 | github.com/davecgh/go-spew v1.1.1 // indirect
19 | github.com/go-sql-driver/mysql v1.7.0 // indirect
20 | github.com/jackc/pgpassfile v1.0.0 // indirect
21 | github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
22 | github.com/jackc/pgx/v5 v5.3.1 // indirect
23 | github.com/jinzhu/inflection v1.0.0 // indirect
24 | github.com/jinzhu/now v1.1.5 // indirect
25 | github.com/kr/text v0.2.0 // indirect
26 | github.com/pmezard/go-difflib v1.0.0 // indirect
27 | github.com/rogpeppe/go-internal v1.12.0 // indirect
28 | github.com/stretchr/testify v1.8.1 // indirect
29 | golang.org/x/crypto v0.8.0 // indirect
30 | golang.org/x/text v0.9.0 // indirect
31 | gopkg.in/yaml.v3 v3.0.1 // indirect
32 | )
33 |
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | github.com/bwmarrin/snowflake v0.3.0 h1:xm67bEhkKh6ij1790JB83OujPR5CzNe8QuQqAgISZN0=
2 | github.com/bwmarrin/snowflake v0.3.0/go.mod h1:NdZxfVWX+oR6y2K0o6qAYv6gIOP9rjG0/E9WsDpxqwE=
3 | github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
4 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
5 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
6 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
7 | github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
8 | github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
9 | github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
10 | github.com/go-test/deep v1.0.7 h1:/VSMRlnY/JSyqxQUzQLKVMAskpY/NZKFA5j2P+0pP2M=
11 | github.com/go-test/deep v1.0.7/go.mod h1:QV8Hv/iy04NyLBxAdO9njL0iVPN1S4d/A3NVv1V36o8=
12 | github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
13 | github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
14 | github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
15 | github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
16 | github.com/jackc/pgx/v5 v5.3.1 h1:Fcr8QJ1ZeLi5zsPZqQeUZhNhxfkkKBOgJuYkJHoBOtU=
17 | github.com/jackc/pgx/v5 v5.3.1/go.mod h1:t3JDKnCBlYIc0ewLF0Q7B8MXmoIaBOZj/ic7iHozM/8=
18 | github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
19 | github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
20 | github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
21 | github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
22 | github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
23 | github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
24 | github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
25 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
26 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
27 | github.com/longbridgeapp/assert v1.1.0 h1:L+/HISOhuGbNAAmJNXgk3+Tm5QmSB70kwdktJXgjL+I=
28 | github.com/longbridgeapp/assert v1.1.0/go.mod h1:UOI7O3rzlzlz715lQm0atWs6JbrYGuIJUEeOekutL6o=
29 | github.com/longbridgeapp/sqlparser v0.3.1 h1:iWOZWGIFgQrJRgobLXUNJdvqGRpbVXkyKUKUA5CNJBE=
30 | github.com/longbridgeapp/sqlparser v0.3.1/go.mod h1:GIHaUq8zvYyHLCLMJJykx1CdM6LHtkUih/QaJXySSx4=
31 | github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
32 | github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y=
33 | github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
34 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
35 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
36 | github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
37 | github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
38 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
39 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
40 | github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
41 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
42 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
43 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
44 | github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
45 | github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
46 | github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
47 | golang.org/x/crypto v0.8.0 h1:pd9TJtTueMTVQXzk8E2XESSMQDj/U7OUu0PqJqPXQjQ=
48 | golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE=
49 | golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 h1:m64FZMko/V45gv0bNmrNYoDEq8U5YUhetc9cBWKS1TQ=
50 | golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63/go.mod h1:0v4NqG35kSWCMzLaMeX+IQrlSnVE/bqGSyC2cz/9Le8=
51 | golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
52 | golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
53 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
54 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
55 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
56 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
57 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
58 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
59 | gorm.io/driver/mysql v1.4.3/go.mod h1:sSIebwZAVPiT+27jK9HIwvsqOGKx3YMPmrA3mBJR10c=
60 | gorm.io/driver/mysql v1.5.1 h1:WUEH5VF9obL/lTtzjmML/5e6VfFR/788coz2uaVCAZw=
61 | gorm.io/driver/mysql v1.5.1/go.mod h1:Jo3Xu7mMhCyj8dlrb3WoCaRd1FhsVh+yMXb1jUInf5o=
62 | gorm.io/driver/postgres v1.5.2 h1:ytTDxxEv+MplXOfFe3Lzm7SjG09fcdb3Z/c056DTBx0=
63 | gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBpCgl8=
64 | gorm.io/driver/sqlite v1.5.0 h1:zKYbzRCpBrT1bNijRnxLDJWPjVfImGEn0lSnUY5gZ+c=
65 | gorm.io/driver/sqlite v1.5.0/go.mod h1:kDMDfntV9u/vuMmz8APHtHF0b4nyBB7sfCieC6G8k8I=
66 | gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk=
67 | gorm.io/gorm v1.24.7-0.20230306060331-85eaf9eeda11/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
68 | gorm.io/gorm v1.25.0/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
69 | gorm.io/gorm v1.25.1/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
70 | gorm.io/gorm v1.25.2/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
71 | gorm.io/gorm v1.25.4 h1:iyNd8fNAe8W9dvtlgeRI5zSVZPsq3OpcTu37cYcpCmw=
72 | gorm.io/gorm v1.25.4/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
73 | gorm.io/hints v1.1.2 h1:b5j0kwk5p4+3BtDtYqqfY+ATSxjj+6ptPgVveuynn9o=
74 | gorm.io/hints v1.1.2/go.mod h1:/ARdpUHAtyEMCh5NNi3tI7FsGh+Cj/MIUlvNxCNCFWg=
75 | gorm.io/plugin/dbresolver v1.5.1 h1:s9Dj9f7r+1rE3nx/Ywzc85nXptUEaeOO0pt27xdopM8=
76 | gorm.io/plugin/dbresolver v1.5.1/go.mod h1:l4Cn87EHLEYuqUncpEeTC2tTJQkjngPSD+lo8hIvcT0=
77 |
--------------------------------------------------------------------------------
/primary_key.go:
--------------------------------------------------------------------------------
1 | package sharding
2 |
3 | import "fmt"
4 |
// Primary key generators selectable through Config.PrimaryKeyGenerator.
// The numeric values are part of the public API and must not be renumbered.
const (
	// PKSnowflake generates primary keys with a Snowflake node per table index.
	PKSnowflake = 0
	// PKPGSequence draws primary keys from a PostgreSQL sequence per table.
	PKPGSequence = 1
	// PKMySQLSequence draws primary keys from a MySQL sequence table per table.
	PKMySQLSequence = 2
	// PKCustom calls the user-supplied Config.PrimaryKeyGeneratorFn.
	PKCustom = 3
)
15 |
16 | func (s *Sharding) genSnowflakeKey(index int64) int64 {
17 | return s.snowflakeNodes[index].Generate().Int64()
18 | }
19 |
20 | // PostgreSQL sequence
21 |
22 | func (s *Sharding) genPostgreSQLSequenceKey(tableName string, index int64) int64 {
23 | var id int64
24 | err := s.DB.Raw("SELECT nextval('" + pgSeqName(tableName) + "')").Scan(&id).Error
25 | if err != nil {
26 | panic(err)
27 | }
28 | return id
29 | }
30 |
31 | func (s *Sharding) createPostgreSQLSequenceKeyIfNotExist(tableName string) error {
32 | return s.DB.Exec(`CREATE SEQUENCE IF NOT EXISTS "` + pgSeqName(tableName) + `" START 1`).Error
33 | }
34 |
// pgSeqName returns the name of the PostgreSQL sequence used to generate ids
// for table, e.g. "users" -> "gorm_sharding_users_id_seq".
func pgSeqName(table string) string {
	// All operands are strings, so Sprint inserts no separators.
	return fmt.Sprint("gorm_sharding_", table, "_id_seq")
}
38 |
39 | // MySQL Sequence
40 |
41 | func (s *Sharding) genMySQLSequenceKey(tableName string, index int64) int64 {
42 | var id int64
43 | err := s.DB.Exec("UPDATE `" + mySQLSeqName(tableName) + "` SET id = LAST_INSERT_ID(id + 1)").Error
44 | if err != nil {
45 | panic(err)
46 | }
47 | err = s.DB.Raw("SELECT LAST_INSERT_ID()").Scan(&id).Error
48 | if err != nil {
49 | panic(err)
50 | }
51 | return id
52 | }
53 |
54 | func (s *Sharding) createMySQLSequenceKeyIfNotExist(tableName string) error {
55 | stmt := s.DB.Exec("CREATE TABLE IF NOT EXISTS `" + mySQLSeqName(tableName) + "` (id INT NOT NULL)")
56 | if stmt.Error != nil {
57 | return fmt.Errorf("failed to create sequence table: %w", stmt.Error)
58 | }
59 | stmt = s.DB.Exec("INSERT INTO `" + mySQLSeqName(tableName) + "` VALUES (0)")
60 | if stmt.Error != nil {
61 | return fmt.Errorf("failed to insert into sequence table: %w", stmt.Error)
62 | }
63 | return nil
64 | }
65 |
66 | func mySQLSeqName(table string) string {
67 | return fmt.Sprintf("gorm_sharding_%s_id_seq", table)
68 | }
69 |
--------------------------------------------------------------------------------
/primary_key_test.go:
--------------------------------------------------------------------------------
1 | package sharding
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/longbridgeapp/assert"
7 | )
8 |
9 | func Test_pgSeqName(t *testing.T) {
10 | assert.Equal(t, "gorm_sharding_users_id_seq", pgSeqName("users"))
11 | }
12 |
--------------------------------------------------------------------------------
/sharding.go:
--------------------------------------------------------------------------------
1 | package sharding
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "hash/crc32"
7 | "strconv"
8 | "strings"
9 | "sync"
10 |
11 | "github.com/bwmarrin/snowflake"
12 | "github.com/longbridgeapp/sqlparser"
13 | "golang.org/x/exp/slices"
14 | "gorm.io/gorm"
15 | )
16 |
17 | var (
18 | ErrMissingShardingKey = errors.New("sharding key or id required, and use operator =")
19 | ErrInvalidID = errors.New("invalid id format")
20 | ErrInsertDiffSuffix = errors.New("can not insert different suffix table in one query ")
21 | )
22 |
// ShardingIgnoreStoreKey is the statement-setting key checked by switchConn
// (via db.Get). When the key is present, the sharding connection pool is not
// installed for that statement, so it runs against the plain table.
var ShardingIgnoreStoreKey = "sharding_ignore"
26 |
// Sharding is the GORM plugin that rewrites statements on registered tables so
// they target the matching sharding table. Create it with Register and install
// it with db.Use.
type Sharding struct {
	// DB is the database the plugin was initialized with (set in Initialize).
	*gorm.DB
	// ConnPool is the sharding-aware pool most recently installed on a
	// statement by switchConn.
	ConnPool *ConnPool
	// configs maps a table name to its compiled Config; filled by compile.
	configs map[string]Config
	// querys holds the "last_query" entry returned by LastQuery.
	// NOTE(review): the writer is outside this file (presumably ConnPool).
	querys sync.Map
	// snowflakeNodes holds the 1024 Snowflake nodes created in Initialize,
	// indexed by table index.
	snowflakeNodes []*snowflake.Node

	// _config and _tables are the arguments passed to Register; compile
	// expands them into configs.
	_config Config
	_tables []any

	// mutex serializes the ConnPool swap performed in switchConn.
	mutex sync.RWMutex
}
39 |
// Config specifies the configuration for sharding.
// Register applies one Config to every table it is given; compile then fills
// in the defaults described on the individual fields.
type Config struct {
	// When DoubleWrite enabled, data will double write to both main table and sharding table.
	DoubleWrite bool

	// ShardingKey specifies the table column you want to used for sharding the table rows.
	// For example, for a product order table, you may want to split the rows by `user_id`.
	// Outside INSERT, the key is only recognized when compared with "=".
	ShardingKey string

	// NumberOfShards specifies how many tables you want to sharding.
	// Required when ShardingAlgorithm is nil. With PKSnowflake it must not
	// exceed 1024 (compile panics otherwise).
	NumberOfShards uint

	// tableFormat specifies the sharding table suffix format.
	// compile derives it from NumberOfShards (e.g. "_%02d" for 10..99 shards)
	// when ShardingAlgorithm is nil; otherwise it stays empty.
	tableFormat string

	// ShardingAlgorithm specifies a function to generate the sharding
	// table's suffix by the column value.
	// For example, this function implements a mod sharding algorithm.
	//
	// 	func(value any) (suffix string, err error) {
	// 		if uid, ok := value.(int64); ok {
	// 			return fmt.Sprintf("_%02d", uid%64), nil
	// 		}
	// 		return "", errors.New("invalid user_id")
	// 	}
	//
	// When nil, compile installs a modulo algorithm over NumberOfShards that
	// accepts int, int64 and string values (non-numeric strings are hashed
	// with CRC-32 first).
	// NOTE(review): when the key appears as a literal in the SQL text, the
	// value passed here is the literal's text (a string), even for integer
	// columns; only bind parameters keep their Go type.
	ShardingAlgorithm func(columnValue any) (suffix string, err error)

	// ShardingSuffixs specifies a function to generate all table's suffix.
	// Used to support Migrator and generate PrimaryKey.
	// For example, this function get a mod all sharding suffixs.
	//
	// func () (suffixs []string) {
	// 	numberOfShards := 5
	// 	for i := 0; i < numberOfShards; i++ {
	// 		suffixs = append(suffixs, fmt.Sprintf("_%02d", i%numberOfShards))
	// 	}
	// 	return
	// }
	//
	// When nil, compile builds the list from ShardingAlgorithm(i) for i in
	// [0, NumberOfShards); the default returns nil if any call fails.
	ShardingSuffixs func() (suffixs []string)

	// ShardingAlgorithmByPrimaryKey specifies a function to generate the sharding
	// table's suffix by the primary key. Used when no sharding key specified.
	// For example, this function use the Snowflake library to generate the suffix.
	//
	// 	func(id int64) (suffix string) {
	// 		return fmt.Sprintf("_%02d", snowflake.ParseInt64(id).Node())
	// 	}
	//
	// When nil and PrimaryKeyGenerator is PKSnowflake, compile formats the
	// id's Snowflake node number with tableFormat. For other generators it
	// stays nil, and queries without the sharding key then fail.
	ShardingAlgorithmByPrimaryKey func(id int64) (suffix string)

	// PrimaryKeyGenerator specifies the primary key generate algorithm.
	// Used only when insert and the record does not contains an id field.
	// Options are PKSnowflake, PKPGSequence, PKMySQLSequence and PKCustom.
	// When use PKCustom, you should also specify PrimaryKeyGeneratorFn.
	PrimaryKeyGenerator int

	// PrimaryKeyGeneratorFn specifies a function to generate the primary key.
	// When use auto-increment like generator, the tableIdx argument could ignored.
	// For example, this function use the Snowflake library to generate the primary key.
	// If you don't want to auto-fill the `id` or use a primary key that isn't called `id`, just return 0.
	//
	// 	func(tableIdx int64) int64 {
	// 		return nodes[tableIdx].Generate().Int64()
	// 	}
	//
	// tableIdx is the numeric part of the table suffix, or the suffix's
	// position in ShardingSuffixs when the suffix is not numeric.
	// compile overwrites this field for every generator except PKCustom.
	PrimaryKeyGeneratorFn func(tableIdx int64) int64
}
105 |
106 | func Register(config Config, tables ...any) *Sharding {
107 | return &Sharding{
108 | _config: config,
109 | _tables: tables,
110 | }
111 | }
112 |
113 | func (s *Sharding) compile() error {
114 | if s.configs == nil {
115 | s.configs = make(map[string]Config)
116 | }
117 | for _, table := range s._tables {
118 | if t, ok := table.(string); ok {
119 | s.configs[t] = s._config
120 | } else {
121 | stmt := &gorm.Statement{DB: s.DB}
122 | if err := stmt.Parse(table); err == nil {
123 | s.configs[stmt.Table] = s._config
124 | } else {
125 | return err
126 | }
127 | }
128 | }
129 |
130 | for t, c := range s.configs {
131 | if c.NumberOfShards > 1024 && c.PrimaryKeyGenerator == PKSnowflake {
132 | panic("Snowflake NumberOfShards should less than 1024")
133 | }
134 |
135 | if c.PrimaryKeyGenerator == PKSnowflake {
136 | c.PrimaryKeyGeneratorFn = s.genSnowflakeKey
137 | } else if c.PrimaryKeyGenerator == PKPGSequence {
138 |
139 | // Execute SQL to CREATE SEQUENCE for this table if not exist
140 | err := s.createPostgreSQLSequenceKeyIfNotExist(t)
141 | if err != nil {
142 | return err
143 | }
144 |
145 | c.PrimaryKeyGeneratorFn = func(index int64) int64 {
146 | return s.genPostgreSQLSequenceKey(t, index)
147 | }
148 | } else if c.PrimaryKeyGenerator == PKMySQLSequence {
149 | err := s.createMySQLSequenceKeyIfNotExist(t)
150 | if err != nil {
151 | return err
152 | }
153 |
154 | c.PrimaryKeyGeneratorFn = func(index int64) int64 {
155 | return s.genMySQLSequenceKey(t, index)
156 | }
157 | } else if c.PrimaryKeyGenerator == PKCustom {
158 | if c.PrimaryKeyGeneratorFn == nil {
159 | return errors.New("PrimaryKeyGeneratorFn is required when use PKCustom")
160 | }
161 | } else {
162 | return errors.New("PrimaryKeyGenerator can only be one of PKSnowflake, PKPGSequence, PKMySQLSequence and PKCustom")
163 | }
164 |
165 | if c.ShardingAlgorithm == nil {
166 | if c.NumberOfShards == 0 {
167 | return errors.New("specify NumberOfShards or ShardingAlgorithm")
168 | }
169 | if c.NumberOfShards < 10 {
170 | c.tableFormat = "_%01d"
171 | } else if c.NumberOfShards < 100 {
172 | c.tableFormat = "_%02d"
173 | } else if c.NumberOfShards < 1000 {
174 | c.tableFormat = "_%03d"
175 | } else if c.NumberOfShards < 10000 {
176 | c.tableFormat = "_%04d"
177 | }
178 | c.ShardingAlgorithm = func(value any) (suffix string, err error) {
179 | id := 0
180 | switch value := value.(type) {
181 | case int:
182 | id = value
183 | case int64:
184 | id = int(value)
185 | case string:
186 | id, err = strconv.Atoi(value)
187 | if err != nil {
188 | id = int(crc32.ChecksumIEEE([]byte(value)))
189 | }
190 | default:
191 | return "", fmt.Errorf("default algorithm only support integer and string column," +
192 | "if you use other type, specify you own ShardingAlgorithm")
193 | }
194 |
195 | return fmt.Sprintf(c.tableFormat, id%int(c.NumberOfShards)), nil
196 | }
197 | }
198 |
199 | if c.ShardingSuffixs == nil {
200 | c.ShardingSuffixs = func() (suffixs []string) {
201 | for i := 0; i < int(c.NumberOfShards); i++ {
202 | suffix, err := c.ShardingAlgorithm(i)
203 | if err != nil {
204 | return nil
205 | }
206 | suffixs = append(suffixs, suffix)
207 | }
208 | return
209 | }
210 | }
211 |
212 | if c.ShardingAlgorithmByPrimaryKey == nil {
213 | if c.PrimaryKeyGenerator == PKSnowflake {
214 | c.ShardingAlgorithmByPrimaryKey = func(id int64) (suffix string) {
215 | return fmt.Sprintf(c.tableFormat, snowflake.ParseInt64(id).Node())
216 | }
217 | }
218 | }
219 | s.configs[t] = c
220 | }
221 |
222 | return nil
223 | }
224 |
225 | // Name plugin name for Gorm plugin interface
226 | func (s *Sharding) Name() string {
227 | return "gorm:sharding"
228 | }
229 |
230 | // LastQuery get last SQL query
231 | func (s *Sharding) LastQuery() string {
232 | if query, ok := s.querys.Load("last_query"); ok {
233 | return query.(string)
234 | }
235 |
236 | return ""
237 | }
238 |
239 | // Initialize implement for Gorm plugin interface
240 | func (s *Sharding) Initialize(db *gorm.DB) error {
241 | db.Dialector = NewShardingDialector(db.Dialector, s)
242 | s.DB = db
243 | s.registerCallbacks(db)
244 |
245 | for t, c := range s.configs {
246 | if c.PrimaryKeyGenerator == PKPGSequence {
247 | err := s.DB.Exec("CREATE SEQUENCE IF NOT EXISTS " + pgSeqName(t)).Error
248 | if err != nil {
249 | return fmt.Errorf("init postgresql sequence error, %w", err)
250 | }
251 | }
252 | if c.PrimaryKeyGenerator == PKMySQLSequence {
253 | err := s.DB.Exec("CREATE TABLE IF NOT EXISTS " + mySQLSeqName(t) + " (id INT NOT NULL)").Error
254 | if err != nil {
255 | return fmt.Errorf("init mysql create sequence error, %w", err)
256 | }
257 | err = s.DB.Exec("INSERT INTO " + mySQLSeqName(t) + " VALUES (0)").Error
258 | if err != nil {
259 | return fmt.Errorf("init mysql insert sequence error, %w", err)
260 | }
261 | }
262 | }
263 |
264 | s.snowflakeNodes = make([]*snowflake.Node, 1024)
265 | for i := int64(0); i < 1024; i++ {
266 | n, err := snowflake.NewNode(i)
267 | if err != nil {
268 | return fmt.Errorf("init snowflake node error, %w", err)
269 | }
270 | s.snowflakeNodes[i] = n
271 | }
272 |
273 | return s.compile()
274 | }
275 |
276 | func (s *Sharding) registerCallbacks(db *gorm.DB) {
277 | s.Callback().Create().Before("*").Register("gorm:sharding", s.switchConn)
278 | s.Callback().Query().Before("*").Register("gorm:sharding", s.switchConn)
279 | s.Callback().Update().Before("*").Register("gorm:sharding", s.switchConn)
280 | s.Callback().Delete().Before("*").Register("gorm:sharding", s.switchConn)
281 | s.Callback().Row().Before("*").Register("gorm:sharding", s.switchConn)
282 | s.Callback().Raw().Before("*").Register("gorm:sharding", s.switchConn)
283 | }
284 |
285 | func (s *Sharding) switchConn(db *gorm.DB) {
286 | // Support ignore sharding in some case, like:
287 | // When DoubleWrite is enabled, we need to query database schema
288 | // information by table name during the migration.
289 | if _, ok := db.Get(ShardingIgnoreStoreKey); !ok {
290 | s.mutex.Lock()
291 | if db.Statement.ConnPool != nil {
292 | s.ConnPool = &ConnPool{ConnPool: db.Statement.ConnPool, sharding: s}
293 | db.Statement.ConnPool = s.ConnPool
294 | }
295 | s.mutex.Unlock()
296 | }
297 | }
298 |
299 | // resolve split the old query to full table query and sharding table query
300 | func (s *Sharding) resolve(query string, args ...any) (ftQuery, stQuery, tableName string, err error) {
301 | ftQuery = query
302 | stQuery = query
303 | if len(s.configs) == 0 {
304 | return
305 | }
306 |
307 | expr, err := sqlparser.NewParser(strings.NewReader(query)).ParseStatement()
308 | if err != nil {
309 | return ftQuery, stQuery, tableName, nil
310 | }
311 |
312 | var table *sqlparser.TableName
313 | var condition sqlparser.Expr
314 | var isInsert bool
315 | var insertNames []*sqlparser.Ident
316 | var insertExpressions []*sqlparser.Exprs
317 | var insertStmt *sqlparser.InsertStatement
318 |
319 | switch stmt := expr.(type) {
320 | case *sqlparser.SelectStatement:
321 | tbl, ok := stmt.FromItems.(*sqlparser.TableName)
322 | if !ok {
323 | return
324 | }
325 | if stmt.Hint != nil && stmt.Hint.Value == "nosharding" {
326 | return
327 | }
328 | table = tbl
329 | condition = stmt.Condition
330 | case *sqlparser.InsertStatement:
331 | table = stmt.TableName
332 | isInsert = true
333 | insertNames = stmt.ColumnNames
334 | insertExpressions = stmt.Expressions
335 | insertStmt = stmt
336 | case *sqlparser.UpdateStatement:
337 | condition = stmt.Condition
338 | table = stmt.TableName
339 | case *sqlparser.DeleteStatement:
340 | condition = stmt.Condition
341 | table = stmt.TableName
342 | default:
343 | return ftQuery, stQuery, "", sqlparser.ErrNotImplemented
344 | }
345 |
346 | tableName = table.Name.Name
347 | r, ok := s.configs[tableName]
348 | if !ok {
349 | return
350 | }
351 |
352 | var suffix string
353 | if isInsert {
354 | var newTable *sqlparser.TableName
355 | for _, insertExpression := range insertExpressions {
356 | var value any
357 | var id int64
358 | var keyFind bool
359 | columnNames := insertNames
360 | insertValues := insertExpression.Exprs
361 | value, id, keyFind, err = s.insertValue(r.ShardingKey, insertNames, insertValues, args...)
362 | if err != nil {
363 | return
364 | }
365 |
366 | var subSuffix string
367 | subSuffix, err = getSuffix(value, id, keyFind, r)
368 | if err != nil {
369 | return
370 | }
371 |
372 | if suffix != "" && suffix != subSuffix {
373 | err = ErrInsertDiffSuffix
374 | return
375 | }
376 |
377 | suffix = subSuffix
378 |
379 | newTable = &sqlparser.TableName{Name: &sqlparser.Ident{Name: tableName + suffix}}
380 |
381 | fillID := true
382 | if isInsert {
383 | for _, name := range insertNames {
384 | if name.Name == "id" {
385 | fillID = false
386 | break
387 | }
388 | }
389 | suffixWord := strings.Replace(suffix, "_", "", 1)
390 | tblIdx, err := strconv.Atoi(suffixWord)
391 | if err != nil {
392 | tblIdx = slices.Index(r.ShardingSuffixs(), suffix)
393 | if tblIdx == -1 {
394 | return ftQuery, stQuery, tableName, errors.New("table suffix '" + suffix + "' is not in ShardingSuffixs. In order to generate the primary key, ShardingSuffixs should include all table suffixes")
395 | }
396 | // return ftQuery, stQuery, tableName, err
397 | }
398 |
399 | id := r.PrimaryKeyGeneratorFn(int64(tblIdx))
400 | if id == 0 {
401 | fillID = false
402 | }
403 |
404 | if fillID {
405 | columnNames = append(insertNames, &sqlparser.Ident{Name: "id"})
406 | insertValues = append(insertValues, &sqlparser.NumberLit{Value: strconv.FormatInt(id, 10)})
407 | }
408 | }
409 |
410 | if fillID {
411 | insertStmt.ColumnNames = columnNames
412 | insertExpression.Exprs = insertValues
413 | }
414 | }
415 |
416 | ftQuery = insertStmt.String()
417 | insertStmt.TableName = newTable
418 | stQuery = insertStmt.String()
419 |
420 | } else {
421 | var value any
422 | var id int64
423 | var keyFind bool
424 | value, id, keyFind, err = s.nonInsertValue(r.ShardingKey, condition, args...)
425 | if err != nil {
426 | return
427 | }
428 |
429 | suffix, err = getSuffix(value, id, keyFind, r)
430 | if err != nil {
431 | return
432 | }
433 |
434 | newTable := &sqlparser.TableName{Name: &sqlparser.Ident{Name: tableName + suffix}}
435 |
436 | switch stmt := expr.(type) {
437 | case *sqlparser.SelectStatement:
438 | ftQuery = stmt.String()
439 | stmt.FromItems = newTable
440 | stmt.OrderBy = replaceOrderByTableName(stmt.OrderBy, tableName, newTable.Name.Name)
441 | replaceTableNameInCondition(stmt.Condition, tableName, newTable.Name.Name)
442 | stQuery = stmt.String()
443 | case *sqlparser.UpdateStatement:
444 | ftQuery = stmt.String()
445 | stmt.TableName = newTable
446 | replaceTableNameInCondition(stmt.Condition, tableName, newTable.Name.Name)
447 | stQuery = stmt.String()
448 | case *sqlparser.DeleteStatement:
449 | ftQuery = stmt.String()
450 | stmt.TableName = newTable
451 | replaceTableNameInCondition(stmt.Condition, tableName, newTable.Name.Name)
452 | stQuery = stmt.String()
453 | }
454 | }
455 |
456 | return
457 | }
458 |
// getSuffix picks the sharding table suffix for a statement. It uses the
// sharding key value when one was found (keyFind) and falls back to the
// primary key id otherwise. An error is returned when the key is missing and no
// ShardingAlgorithmByPrimaryKey is configured.
func getSuffix(value any, id int64, keyFind bool, r Config) (suffix string, err error) {
	if keyFind {
		return r.ShardingAlgorithm(value)
	}
	if r.ShardingAlgorithmByPrimaryKey == nil {
		return "", fmt.Errorf("there is not sharding key and ShardingAlgorithmByPrimaryKey is not configured")
	}
	return r.ShardingAlgorithmByPrimaryKey(id), nil
}
475 |
476 | func (s *Sharding) insertValue(key string, names []*sqlparser.Ident, exprs []sqlparser.Expr, args ...any) (value any, id int64, keyFind bool, err error) {
477 | if len(names) != len(exprs) {
478 | return nil, 0, keyFind, errors.New("column names and expressions mismatch")
479 | }
480 |
481 | for i, name := range names {
482 | if name.Name == key {
483 | switch expr := exprs[i].(type) {
484 | case *sqlparser.BindExpr:
485 | value = args[expr.Pos]
486 | case *sqlparser.StringLit:
487 | value = expr.Value
488 | case *sqlparser.NumberLit:
489 | value = expr.Value
490 | default:
491 | return nil, 0, keyFind, sqlparser.ErrNotImplemented
492 | }
493 | keyFind = true
494 | break
495 | }
496 | }
497 | if !keyFind {
498 | return nil, 0, keyFind, ErrMissingShardingKey
499 | }
500 |
501 | return
502 | }
503 |
504 | func (s *Sharding) nonInsertValue(key string, condition sqlparser.Expr, args ...any) (value any, id int64, keyFind bool, err error) {
505 | err = sqlparser.Walk(sqlparser.VisitFunc(func(node sqlparser.Node) error {
506 | if n, ok := node.(*sqlparser.BinaryExpr); ok {
507 | x, ok := n.X.(*sqlparser.Ident)
508 | if !ok {
509 | if q, ok2 := n.X.(*sqlparser.QualifiedRef); ok2 {
510 | x = q.Column
511 | ok = true
512 | }
513 | }
514 | if ok {
515 | if x.Name == key && n.Op == sqlparser.EQ {
516 | keyFind = true
517 | switch expr := n.Y.(type) {
518 | case *sqlparser.BindExpr:
519 | value = args[expr.Pos]
520 | case *sqlparser.StringLit:
521 | value = expr.Value
522 | case *sqlparser.NumberLit:
523 | value = expr.Value
524 | default:
525 | return sqlparser.ErrNotImplemented
526 | }
527 | return nil
528 | } else if x.Name == "id" && n.Op == sqlparser.EQ {
529 | switch expr := n.Y.(type) {
530 | case *sqlparser.BindExpr:
531 | v := args[expr.Pos]
532 | var ok bool
533 | if id, ok = v.(int64); !ok {
534 | return fmt.Errorf("ID should be int64 type")
535 | }
536 | case *sqlparser.NumberLit:
537 | id, err = strconv.ParseInt(expr.Value, 10, 64)
538 | if err != nil {
539 | return err
540 | }
541 | default:
542 | return ErrInvalidID
543 | }
544 | return nil
545 | }
546 | }
547 | }
548 | return nil
549 | }), condition)
550 | if err != nil {
551 | return
552 | }
553 |
554 | if !keyFind && id == 0 {
555 | return nil, 0, keyFind, ErrMissingShardingKey
556 | }
557 |
558 | return
559 | }
560 |
561 | func replaceOrderByTableName(orderBy []*sqlparser.OrderingTerm, oldName, newName string) []*sqlparser.OrderingTerm {
562 | for i, term := range orderBy {
563 | if x, ok := term.X.(*sqlparser.QualifiedRef); ok {
564 | if x.Table.Name == oldName {
565 | x.Table.Name = newName
566 | orderBy[i].X = x
567 | }
568 | }
569 | }
570 |
571 | return orderBy
572 | }
573 |
574 | // replaceTableNameInCondition walks the WHERE expression tree
575 | // and renames any qualified column references matching oldName → newName.
576 | func replaceTableNameInCondition(expr sqlparser.Expr, oldName, newName string) {
577 | if expr == nil {
578 | return
579 | }
580 |
581 | _ = sqlparser.Walk(sqlparser.VisitFunc(func(node sqlparser.Node) error {
582 | if qr, ok := node.(*sqlparser.QualifiedRef); ok && qr.Table.Name == oldName {
583 | qr.Table.Name = newName
584 | }
585 |
586 | return nil
587 | }), expr)
588 | }
589 |
--------------------------------------------------------------------------------
/sharding_test.go:
--------------------------------------------------------------------------------
1 | package sharding
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "os"
7 | "regexp"
8 | "sort"
9 | "strconv"
10 | "strings"
11 | "testing"
12 | "time"
13 |
14 | "github.com/bwmarrin/snowflake"
15 | "github.com/longbridgeapp/assert"
16 | "github.com/longbridgeapp/sqlparser"
17 | "gorm.io/driver/mysql"
18 | "gorm.io/driver/postgres"
19 | "gorm.io/gorm"
20 | "gorm.io/hints"
21 | "gorm.io/plugin/dbresolver"
22 | )
23 |
// Order is the sharded model under test; the test configs shard it by user_id.
// Field order determines the column order in the expected INSERT statements.
// Deleted enables GORM soft delete, which is why the expected queries end with
// a "deleted" IS NULL filter.
type Order struct {
	ID      int64 `gorm:"primarykey"`
	UserID  int64
	Product string
	Deleted gorm.DeletedAt
}
30 |
// Category is not registered with the sharding plugin. It is migrated next to
// Order to check that plain tables coexist with sharded ones (see TestMigrate).
type Category struct {
	ID   int64 `gorm:"primarykey"`
	Name string
}
35 |
// envOrDefaultDBURL returns the DSN from environment variable envKey or, when
// it is empty, a local default DSN for database dbName. The default targets
// MySQL when mysqlDialector() reports true, and PostgreSQL otherwise.
func envOrDefaultDBURL(envKey, dbName string) string {
	if u := os.Getenv(envKey); len(u) != 0 {
		return u
	}
	if mysqlDialector() {
		return "root@tcp(127.0.0.1:3306)/" + dbName + "?charset=utf8mb4"
	}
	return "postgres://localhost:5432/" + dbName + "?sslmode=disable"
}

// dbURL returns the DSN of the main sharding test database (DB_URL).
func dbURL() string {
	return envOrDefaultDBURL("DB_URL", "sharding-test")
}

// dbNoIDURL returns the DSN of the database whose sharding tables have no id
// column (DB_NOID_URL).
func dbNoIDURL() string {
	return envOrDefaultDBURL("DB_NOID_URL", "sharding-noid-test")
}

// dbReadURL returns the DSN of the read-replica test database (DB_READ_URL).
func dbReadURL() string {
	return envOrDefaultDBURL("DB_READ_URL", "sharding-read-test")
}

// dbWriteURL returns the DSN of the write test database (DB_WRITE_URL).
func dbWriteURL() string {
	return envOrDefaultDBURL("DB_WRITE_URL", "sharding-write-test")
}
79 |
// Package-level test fixtures. The DSNs are resolved when the package is
// initialized, before init() runs; init then opens the connections and builds
// the sharding configs and plugins.
var (
	// PostgreSQL connection configs for the four test databases; the
	// MySQL path in init uses the DSN helpers directly instead.
	dbConfig = postgres.Config{
		DSN:                  dbURL(),
		PreferSimpleProtocol: true,
	}
	dbNoIDConfig = postgres.Config{
		DSN:                  dbNoIDURL(),
		PreferSimpleProtocol: true,
	}
	dbReadConfig = postgres.Config{
		DSN:                  dbReadURL(),
		PreferSimpleProtocol: true,
	}
	dbWriteConfig = postgres.Config{
		DSN:                  dbWriteURL(),
		PreferSimpleProtocol: true,
	}
	db, dbNoID, dbRead, dbWrite *gorm.DB

	shardingConfig, shardingConfigNoID Config
	middleware, middlewareNoID         *Sharding
	// node generates Snowflake ids for the id conditions in the SELECT tests.
	// Node 1 is within the 0..1023 range Initialize also uses, so the ignored
	// error is not expected to be set.
	node, _ = snowflake.NewNode(1)
)
103 |
// init prepares the test environment. It opens the four test databases (MySQL
// when mysqlDialector() is true, PostgreSQL otherwise) and builds the two
// sharding configs and plugins. It then drops leftovers from earlier runs,
// migrates the base tables, and creates orders_0..orders_3 by hand in every
// database. Finally it installs the plugins on db and dbNoID.
//
// Errors from gorm.Open, the CREATE TABLE statements and db.Use are ignored;
// only an AutoMigrate failure panics.
func init() {
	if mysqlDialector() {
		db, _ = gorm.Open(mysql.Open(dbURL()), &gorm.Config{
			DisableForeignKeyConstraintWhenMigrating: true,
		})
		dbNoID, _ = gorm.Open(mysql.Open(dbNoIDURL()), &gorm.Config{
			DisableForeignKeyConstraintWhenMigrating: true,
		})
		dbRead, _ = gorm.Open(mysql.Open(dbReadURL()), &gorm.Config{
			DisableForeignKeyConstraintWhenMigrating: true,
		})
		dbWrite, _ = gorm.Open(mysql.Open(dbWriteURL()), &gorm.Config{
			DisableForeignKeyConstraintWhenMigrating: true,
		})
	} else {
		db, _ = gorm.Open(postgres.New(dbConfig), &gorm.Config{
			DisableForeignKeyConstraintWhenMigrating: true,
		})
		dbNoID, _ = gorm.Open(postgres.New(dbNoIDConfig), &gorm.Config{
			DisableForeignKeyConstraintWhenMigrating: true,
		})
		dbRead, _ = gorm.Open(postgres.New(dbReadConfig), &gorm.Config{
			DisableForeignKeyConstraintWhenMigrating: true,
		})
		dbWrite, _ = gorm.Open(postgres.New(dbWriteConfig), &gorm.Config{
			DisableForeignKeyConstraintWhenMigrating: true,
		})
	}

	// Four shards keyed by user_id, with Snowflake primary keys.
	shardingConfig = Config{
		DoubleWrite:         true,
		ShardingKey:         "user_id",
		NumberOfShards:      4,
		PrimaryKeyGenerator: PKSnowflake,
	}

	// Same sharding, but the generator returns 0, so resolve never appends an
	// id column (the dbNoID sharding tables below have none).
	shardingConfigNoID = Config{
		DoubleWrite:         true,
		ShardingKey:         "user_id",
		NumberOfShards:      4,
		PrimaryKeyGenerator: PKCustom,
		PrimaryKeyGeneratorFn: func(_ int64) int64 {
			return 0
		},
	}

	middleware = Register(shardingConfig, &Order{})
	middlewareNoID = Register(shardingConfigNoID, &Order{})

	fmt.Println("Clean only tables ...")
	dropTables()
	fmt.Println("AutoMigrate tables ...")
	// Runs before the plugins are installed, so presumably only the base
	// tables (orders, categories) are created here.
	err := db.AutoMigrate(&Order{}, &Category{})
	if err != nil {
		panic(err)
	}
	// Sharding tables are created by hand; dbNoID's variant has no id column.
	stables := []string{"orders_0", "orders_1", "orders_2", "orders_3"}
	for _, table := range stables {
		db.Exec(`CREATE TABLE ` + table + ` (
id bigint PRIMARY KEY,
user_id bigint,
product text,
deleted timestamp NULL
)`)
		dbNoID.Exec(`CREATE TABLE ` + table + ` (
user_id bigint,
product text,
deleted timestamp NULL
)`)
		dbRead.Exec(`CREATE TABLE ` + table + ` (
id bigint PRIMARY KEY,
user_id bigint,
product text,
deleted timestamp NULL
)`)
		dbWrite.Exec(`CREATE TABLE ` + table + ` (
id bigint PRIMARY KEY,
user_id bigint,
product text,
deleted timestamp NULL
)`)
	}

	db.Use(middleware)
	dbNoID.Use(middlewareNoID)
}
190 |
191 | func dropTables() {
192 | tables := []string{"orders", "orders_0", "orders_1", "orders_2", "orders_3", "categories"}
193 | for _, table := range tables {
194 | db.Exec("DROP TABLE IF EXISTS " + table)
195 | dbNoID.Exec("DROP TABLE IF EXISTS " + table)
196 | dbRead.Exec("DROP TABLE IF EXISTS " + table)
197 | dbWrite.Exec("DROP TABLE IF EXISTS " + table)
198 | if mysqlDialector() {
199 | db.Exec(("DROP TABLE IF EXISTS gorm_sharding_" + table + "_id_seq"))
200 | } else {
201 | db.Exec(("DROP SEQUENCE IF EXISTS gorm_sharding_" + table + "_id_seq"))
202 | }
203 | }
204 | }
205 |
206 | func TestMigrate(t *testing.T) {
207 | targetTables := []string{"orders", "orders_0", "orders_1", "orders_2", "orders_3", "categories"}
208 | sort.Strings(targetTables)
209 |
210 | // origin tables
211 | tables, _ := db.Migrator().GetTables()
212 | sort.Strings(tables)
213 | assert.Equal(t, tables, targetTables)
214 |
215 | // drop table
216 | db.Migrator().DropTable(Order{}, &Category{})
217 | tables, _ = db.Migrator().GetTables()
218 | assert.Equal(t, len(tables), 0)
219 |
220 | // auto migrate
221 | db.AutoMigrate(&Order{}, &Category{})
222 | tables, _ = db.Migrator().GetTables()
223 | sort.Strings(tables)
224 | assert.Equal(t, tables, targetTables)
225 |
226 | // auto migrate again
227 | err := db.AutoMigrate(&Order{}, &Category{})
228 | assert.Equal[error, error](t, err, nil)
229 | }
230 |
231 | func TestInsert(t *testing.T) {
232 | tx := db.Create(&Order{ID: 100, UserID: 100, Product: "iPhone"})
233 | assertQueryResult(t, `INSERT INTO orders_0 ("user_id", "product", "deleted", "id") VALUES ($1, $2, $3, $4) RETURNING "id"`, tx)
234 | }
235 |
236 | func TestInsertNoID(t *testing.T) {
237 | dbNoID.Create(&Order{UserID: 100, Product: "iPhone"})
238 | expected := `INSERT INTO orders_0 ("user_id", "product", "deleted") VALUES ($1, $2, $3) RETURNING "id"`
239 | assert.Equal(t, toDialect(expected), middlewareNoID.LastQuery())
240 | }
241 |
242 | func TestFillID(t *testing.T) {
243 | db.Create(&Order{UserID: 100, Product: "iPhone"})
244 | expected := `INSERT INTO orders_0 ("user_id", "product", "deleted", id) VALUES`
245 | lastQuery := middleware.LastQuery()
246 | assert.Equal(t, toDialect(expected), lastQuery[0:len(expected)])
247 | }
248 |
249 | func TestInsertManyWithFillID(t *testing.T) {
250 | err := db.Create([]Order{{UserID: 100, Product: "Mac"}, {UserID: 100, Product: "Mac Pro"}}).Error
251 | assert.Equal[error, error](t, err, nil)
252 |
253 | expected := `INSERT INTO orders_0 ("user_id", "product", "deleted", id) VALUES ($1, $2, $3, $sfid), ($4, $5, $6, $sfid) RETURNING "id"`
254 | lastQuery := middleware.LastQuery()
255 | assertSfidQueryResult(t, toDialect(expected), lastQuery)
256 | }
257 |
258 | func TestInsertDiffSuffix(t *testing.T) {
259 | err := db.Create([]Order{{UserID: 100, Product: "Mac"}, {UserID: 101, Product: "Mac Pro"}}).Error
260 | assert.Equal(t, ErrInsertDiffSuffix, err)
261 | }
262 |
263 | func TestSelect1(t *testing.T) {
264 | tx := db.Model(&Order{}).Where("user_id", 101).Where("id", node.Generate().Int64()).Find(&[]Order{})
265 | assertQueryResult(t, `SELECT * FROM orders_1 WHERE "user_id" = $1 AND "id" = $2 AND "orders_1"."deleted" IS NULL`, tx)
266 | }
267 |
268 | func TestSelect2(t *testing.T) {
269 | tx := db.Model(&Order{}).Where("id", node.Generate().Int64()).Where("user_id", 101).Find(&[]Order{})
270 | assertQueryResult(t, `SELECT * FROM orders_1 WHERE "id" = $1 AND "user_id" = $2 AND "orders_1"."deleted" IS NULL`, tx)
271 | }
272 |
273 | func TestSelect3(t *testing.T) {
274 | tx := db.Model(&Order{}).Where("id", node.Generate().Int64()).Where("user_id = 101").Find(&[]Order{})
275 | assertQueryResult(t, `SELECT * FROM orders_1 WHERE "id" = $1 AND user_id = 101 AND "orders_1"."deleted" IS NULL`, tx)
276 | }
277 |
278 | func TestSelect4(t *testing.T) {
279 | tx := db.Model(&Order{}).Where("product", "iPad").Where("user_id", 100).Find(&[]Order{})
280 | assertQueryResult(t, `SELECT * FROM orders_0 WHERE "product" = $1 AND "user_id" = $2 AND "orders_0"."deleted" IS NULL`, tx)
281 | }
282 |
283 | func TestSelect5(t *testing.T) {
284 | tx := db.Model(&Order{}).Where("user_id = 101").Find(&[]Order{})
285 | assertQueryResult(t, `SELECT * FROM orders_1 WHERE user_id = 101 AND "orders_1"."deleted" IS NULL`, tx)
286 | }
287 |
288 | func TestSelect6(t *testing.T) {
289 | tx := db.Model(&Order{}).Where("id", node.Generate().Int64()).Find(&[]Order{})
290 | assertQueryResult(t, `SELECT * FROM orders_1 WHERE "id" = $1 AND "orders_1"."deleted" IS NULL`, tx)
291 | }
292 |
293 | func TestSelect7(t *testing.T) {
294 | tx := db.Model(&Order{}).Where("user_id", 101).Where("id > ?", node.Generate().Int64()).Find(&[]Order{})
295 | assertQueryResult(t, `SELECT * FROM orders_1 WHERE "user_id" = $1 AND id > $2 AND "orders_1"."deleted" IS NULL`, tx)
296 | }
297 |
298 | func TestSelect8(t *testing.T) {
299 | tx := db.Model(&Order{}).Where("id > ?", node.Generate().Int64()).Where("user_id", 101).Find(&[]Order{})
300 | assertQueryResult(t, `SELECT * FROM orders_1 WHERE id > $1 AND "user_id" = $2 AND "orders_1"."deleted" IS NULL`, tx)
301 | }
302 |
303 | func TestSelect9(t *testing.T) {
304 | tx := db.Model(&Order{}).Where("user_id = 101").First(&[]Order{})
305 | assertQueryResult(t, `SELECT * FROM orders_1 WHERE user_id = 101 AND "orders_1"."deleted" IS NULL ORDER BY "orders_1"."id" LIMIT 1`, tx)
306 | }
307 |
308 | func TestSelect10(t *testing.T) {
309 | tx := db.Clauses(hints.Comment("select", "nosharding")).Model(&Order{}).Find(&[]Order{})
310 | assertQueryResult(t, `SELECT /* nosharding */ * FROM "orders" WHERE "orders"."deleted" IS NULL`, tx)
311 | }
312 |
313 | func TestSelect11(t *testing.T) {
314 | tx := db.Clauses(hints.Comment("select", "nosharding")).Model(&Order{}).Where("user_id", 101).Find(&[]Order{})
315 | assertQueryResult(t, `SELECT /* nosharding */ * FROM "orders" WHERE "user_id" = $1 AND "orders"."deleted" IS NULL`, tx)
316 | }
317 |
318 | func TestSelect12(t *testing.T) {
319 | sql := toDialect(`SELECT * FROM "public"."orders" WHERE user_id = 101`)
320 | tx := db.Raw(sql).Find(&[]Order{})
321 | assertQueryResult(t, sql, tx)
322 | }
323 |
324 | func TestSelect13(t *testing.T) {
325 | var n int
326 | tx := db.Raw("SELECT 1").Find(&n)
327 | assertQueryResult(t, `SELECT 1`, tx)
328 | }
329 |
330 | func TestSelect14(t *testing.T) {
331 | dbNoID.Model(&Order{}).Where("user_id = 101").Find(&[]Order{})
332 | expected := `SELECT * FROM orders_1 WHERE user_id = 101 AND "orders_1"."deleted" IS NULL`
333 | assert.Equal(t, toDialect(expected), middlewareNoID.LastQuery())
334 | }
335 |
336 | func TestUpdate(t *testing.T) {
337 | tx := db.Model(&Order{}).Where("user_id = ?", 100).Update("product", "new title")
338 | assertQueryResult(t, `UPDATE orders_0 SET "product" = $1 WHERE user_id = $2 AND "orders_0"."deleted" IS NULL`, tx)
339 | }
340 |
341 | func TestDelete(t *testing.T) {
342 | tx := db.Unscoped().Where("user_id = ?", 100).Delete(&Order{ID: 1})
343 | assertQueryResult(t, `DELETE FROM orders_0 WHERE user_id = $1 AND "orders_0"."id" = $2`, tx)
344 | }
345 |
346 | func TestInsertMissingShardingKey(t *testing.T) {
347 | err := db.Exec(`INSERT INTO "orders" ("id", "product") VALUES(1, 'iPad')`).Error
348 | assert.Equal(t, ErrMissingShardingKey, err)
349 | }
350 |
351 | func TestSelectMissingShardingKey(t *testing.T) {
352 | err := db.Exec(`SELECT * FROM "orders" WHERE "product" = 'iPad'`).Error
353 | assert.Equal(t, ErrMissingShardingKey, err)
354 | }
355 |
356 | func TestSelectNoSharding(t *testing.T) {
357 | sql := toDialect(`SELECT /* nosharding */ * FROM "orders" WHERE "product" = 'iPad'`)
358 | err := db.Exec(sql).Error
359 | assert.Equal[error](t, nil, err)
360 | }
361 |
362 | func TestNoEq(t *testing.T) {
363 | err := db.Model(&Order{}).Where("user_id <> ?", 101).Find([]Order{}).Error
364 | assert.Equal(t, ErrMissingShardingKey, err)
365 | }
366 |
367 | func TestShardingKeyOK(t *testing.T) {
368 | err := db.Model(&Order{}).Where("user_id = ? and id > ?", 101, int64(100)).Find(&[]Order{}).Error
369 | assert.Equal[error](t, nil, err)
370 | }
371 |
372 | func TestShardingKeyNotOK(t *testing.T) {
373 | err := db.Model(&Order{}).Where("user_id > ? and id > ?", 101, int64(100)).Find(&[]Order{}).Error
374 | assert.Equal(t, ErrMissingShardingKey, err)
375 | }
376 |
377 | func TestShardingIdOK(t *testing.T) {
378 | err := db.Model(&Order{}).Where("id = ? and user_id > ?", int64(101), 100).Find(&[]Order{}).Error
379 | assert.Equal[error](t, nil, err)
380 | }
381 |
382 | func TestNoSharding(t *testing.T) {
383 | categories := []Category{}
384 | tx := db.Model(&Category{}).Where("id = ?", 1).Find(&categories)
385 | assertQueryResult(t, `SELECT * FROM "categories" WHERE id = $1`, tx)
386 | }
387 |
388 | func TestPKSnowflake(t *testing.T) {
389 | var db *gorm.DB
390 | if mysqlDialector() {
391 | db, _ = gorm.Open(mysql.Open(dbURL()), &gorm.Config{
392 | DisableForeignKeyConstraintWhenMigrating: true,
393 | })
394 | } else {
395 | db, _ = gorm.Open(postgres.New(dbConfig), &gorm.Config{
396 | DisableForeignKeyConstraintWhenMigrating: true,
397 | })
398 | }
399 | shardingConfig.PrimaryKeyGenerator = PKSnowflake
400 | middleware := Register(shardingConfig, &Order{})
401 | db.Use(middleware)
402 |
403 | node, _ := snowflake.NewNode(0)
404 | sfid := node.Generate().Int64()
405 | expected := fmt.Sprintf(`INSERT INTO orders_0 ("user_id", "product", "deleted", id) VALUES ($1, $2, $3, %d`, sfid)[0:83]
406 | expected = toDialect(expected)
407 |
408 | db.Create(&Order{UserID: 100, Product: "iPhone"})
409 | assert.Equal(t, expected, middleware.LastQuery()[0:len(expected)])
410 | }
411 |
412 | func TestPKPGSequence(t *testing.T) {
413 | if mysqlDialector() {
414 | return
415 | }
416 |
417 | db, _ := gorm.Open(postgres.New(dbConfig), &gorm.Config{
418 | DisableForeignKeyConstraintWhenMigrating: true,
419 | })
420 | shardingConfig.PrimaryKeyGenerator = PKPGSequence
421 | middleware := Register(shardingConfig, &Order{})
422 | db.Use(middleware)
423 |
424 | db.Exec("SELECT setval('" + pgSeqName("orders") + "', 42)")
425 | db.Create(&Order{UserID: 100, Product: "iPhone"})
426 | expected := `INSERT INTO orders_0 ("user_id", "product", "deleted", id) VALUES ($1, $2, $3, 43) RETURNING "id"`
427 | assert.Equal(t, expected, middleware.LastQuery())
428 | }
429 |
430 | func TestPKMySQLSequence(t *testing.T) {
431 | if !mysqlDialector() {
432 | return
433 | }
434 |
435 | db, _ := gorm.Open(mysql.Open(dbURL()), &gorm.Config{
436 | DisableForeignKeyConstraintWhenMigrating: true,
437 | })
438 | shardingConfig.PrimaryKeyGenerator = PKMySQLSequence
439 | middleware := Register(shardingConfig, &Order{})
440 | db.Use(middleware)
441 |
442 | db.Exec("UPDATE `" + mySQLSeqName("orders") + "` SET id = 42")
443 | db.Create(&Order{UserID: 100, Product: "iPhone"})
444 | expected := "INSERT INTO orders_0 (`user_id`, `product`, `deleted`, id) VALUES (?, ?, ?, 43)"
445 | if mariadbDialector() {
446 | expected = expected + " RETURNING `id`"
447 | }
448 | assert.Equal(t, expected, middleware.LastQuery())
449 | }
450 |
451 | func TestReadWriteSplitting(t *testing.T) {
452 | dbRead.Exec("INSERT INTO orders_0 (id, product, user_id) VALUES(1, 'iPad', 100)")
453 | dbWrite.Exec("INSERT INTO orders_0 (id, product, user_id) VALUES(1, 'iPad', 100)")
454 |
455 | var db *gorm.DB
456 | if mysqlDialector() {
457 | db, _ = gorm.Open(mysql.Open(dbWriteURL()), &gorm.Config{
458 | DisableForeignKeyConstraintWhenMigrating: true,
459 | })
460 | } else {
461 | db, _ = gorm.Open(postgres.New(dbWriteConfig), &gorm.Config{
462 | DisableForeignKeyConstraintWhenMigrating: true,
463 | })
464 | }
465 |
466 | db.Use(dbresolver.Register(dbresolver.Config{
467 | Sources: []gorm.Dialector{dbWrite.Dialector},
468 | Replicas: []gorm.Dialector{dbRead.Dialector},
469 | }))
470 | db.Use(middleware)
471 |
472 | var order Order
473 | db.Model(&Order{}).Where("user_id", 100).Find(&order)
474 | assert.Equal(t, "iPad", order.Product)
475 |
476 | db.Model(&Order{}).Where("user_id", 100).Update("product", "iPhone")
477 | db.Clauses(dbresolver.Read).Table("orders_0").Where("user_id", 100).Find(&order)
478 | assert.Equal(t, "iPad", order.Product)
479 |
480 | dbWrite.Table("orders_0").Where("user_id", 100).Find(&order)
481 | assert.Equal(t, "iPhone", order.Product)
482 | }
483 |
484 | func TestDataRace(t *testing.T) {
485 | ctx, cancel := context.WithCancel(context.Background())
486 | ch := make(chan error)
487 |
488 | for i := 0; i < 2; i++ {
489 | go func() {
490 | for {
491 | select {
492 | case <-ctx.Done():
493 | return
494 | default:
495 | err := db.Model(&Order{}).Where("user_id", 100).Find(&[]Order{}).Error
496 | if err != nil {
497 | ch <- err
498 | return
499 | }
500 | }
501 | }
502 | }()
503 | }
504 |
505 | select {
506 | case <-time.After(time.Millisecond * 50):
507 | cancel()
508 | case err := <-ch:
509 | cancel()
510 | t.Fatal(err)
511 | }
512 | }
513 |
514 | func assertQueryResult(t *testing.T, expected string, tx *gorm.DB) {
515 | t.Helper()
516 | assert.Equal(t, toDialect(expected), middleware.LastQuery())
517 | }
518 |
519 | func toDialect(sql string) string {
520 | if os.Getenv("DIALECTOR") == "mysql" {
521 | sql = strings.ReplaceAll(sql, `"`, "`")
522 | r := regexp.MustCompile(`\$([0-9]+)`)
523 | sql = r.ReplaceAllString(sql, "?")
524 | sql = strings.ReplaceAll(sql, " RETURNING `id`", "")
525 | } else if os.Getenv("DIALECTOR") == "mariadb" {
526 | sql = strings.ReplaceAll(sql, `"`, "`")
527 | r := regexp.MustCompile(`\$([0-9]+)`)
528 | sql = r.ReplaceAllString(sql, "?")
529 | }
530 | return sql
531 | }
532 |
533 | // skip $sfid compare
534 | func assertSfidQueryResult(t *testing.T, expected, lastQuery string) {
535 | t.Helper()
536 |
537 | node, _ := snowflake.NewNode(0)
538 | sfid := node.Generate().Int64()
539 | sfidLen := len(strconv.Itoa(int(sfid)))
540 | re := regexp.MustCompile(`\$sfid`)
541 |
542 | for {
543 | match := re.FindStringIndex(expected)
544 | if len(match) == 0 {
545 | break
546 | }
547 |
548 | start := match[0]
549 | end := match[1]
550 |
551 | if len(lastQuery) < start+sfidLen {
552 | break
553 | }
554 |
555 | sfid := lastQuery[start : start+sfidLen]
556 | expected = expected[:start] + sfid + expected[end:]
557 | }
558 |
559 | assert.Equal(t, toDialect(expected), lastQuery)
560 | }
561 |
562 | func mysqlDialector() bool {
563 | return os.Getenv("DIALECTOR") == "mysql" || os.Getenv("DIALECTOR") == "mariadb"
564 | }
565 |
566 | func mariadbDialector() bool {
567 | return os.Getenv("DIALECTOR") == "mariadb"
568 | }
569 |
570 | // parseExpr parses a SQL boolean expression into an AST Expr node.
571 | func parseExpr(t *testing.T, src string) sqlparser.Expr {
572 | p := sqlparser.NewParser(strings.NewReader(src))
573 | expr, err := p.ParseExpr()
574 | if err != nil {
575 | t.Fatalf("failed to parse expr %q: %v", src, err)
576 | }
577 | return expr
578 | }
579 |
580 | func TestReplaceTableNameInCondition(t *testing.T) {
581 | oldName := "orders"
582 | newName := "orders_1"
583 |
584 | tests := []struct {
585 | name string
586 | input string
587 | expected string
588 | }{
589 | {
590 | name: "simple equality",
591 | input: "orders.user_id = 5",
592 | expected: "orders_1.user_id = 5",
593 | },
594 | {
595 | name: "LIKE operator",
596 | input: "orders.product LIKE '%foo%'",
597 | expected: "orders_1.product LIKE '%foo%'",
598 | },
599 | {
600 | name: "IN list",
601 | input: "orders.id IN (1, 2, 3)",
602 | expected: "orders_1.id IN (1, 2, 3)",
603 | },
604 | {
605 | name: "NOT IN list",
606 | input: "orders.id NOT IN (4,5)",
607 | expected: "orders_1.id NOT IN (4, 5)",
608 | },
609 | {
610 | name: "BETWEEN",
611 | input: "orders.age BETWEEN 18 AND 30",
612 | expected: "orders_1.age BETWEEN 18 AND 30",
613 | },
614 | {
615 | name: "nested AND/OR",
616 | input: "(orders.age > 18 AND orders.age < 30) OR orders.id = 7",
617 | expected: "(orders_1.age > 18 AND orders_1.age < 30) OR orders_1.id = 7",
618 | },
619 | {
620 | name: "function call",
621 | input: "UPPER(orders.product) = 'FOO'",
622 | expected: "UPPER(orders_1.product) = 'FOO'",
623 | },
624 | }
625 |
626 | for _, tc := range tests {
627 | t.Run(tc.name, func(t *testing.T) {
628 | expr := parseExpr(t, tc.input)
629 | replaceTableNameInCondition(expr, oldName, newName)
630 | assert.Equal(t, tc.expected, expr.String())
631 | })
632 | }
633 | }
634 |
--------------------------------------------------------------------------------