├── .github ├── dependabot.yml └── workflows │ └── tests.yml ├── .gitignore ├── LICENSE ├── README.md ├── conn_pool.go ├── dialector.go ├── examples └── order.go ├── go.mod ├── go.sum ├── primary_key.go ├── primary_key_test.go ├── sharding.go └── sharding_test.go /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | --- 2 | version: 2 3 | updates: 4 | - package-ecosystem: gomod 5 | directory: / 6 | schedule: 7 | interval: weekly 8 | - package-ecosystem: github-actions 9 | directory: / 10 | schedule: 11 | interval: weekly 12 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | on: 3 | push: 4 | branches: 5 | - "main" 6 | tags: 7 | - "v*" 8 | pull_request: 9 | jobs: 10 | postgres: 11 | strategy: 12 | matrix: 13 | dbversion: 14 | [ 15 | "postgres:latest", 16 | "postgres:13", 17 | "postgres:12", 18 | "postgres:11", 19 | "postgres:10", 20 | ] 21 | platform: [ubuntu-latest] # can not run in macOS and Windows 22 | runs-on: ${{ matrix.platform }} 23 | 24 | services: 25 | postgres: 26 | image: ${{ matrix.dbversion }} 27 | env: 28 | POSTGRES_DB: sharding-test 29 | POSTGRES_USER: gorm 30 | POSTGRES_PASSWORD: gorm 31 | TZ: Asia/Shanghai 32 | ports: 33 | - 5432:5432 34 | # Set health checks to wait until postgres has started 35 | options: >- 36 | --health-cmd pg_isready 37 | --health-interval 10s 38 | --health-timeout 5s 39 | --health-retries 5 40 | 41 | env: 42 | DIALECTOR: postgres 43 | DB_URL: postgres://gorm:gorm@localhost:5432/sharding-test 44 | DB_NOID_URL: postgres://gorm:gorm@localhost:5432/sharding-noid-test 45 | DB_READ_URL: postgres://gorm:gorm@localhost:5432/sharding-read-test 46 | DB_WRITE_URL: postgres://gorm:gorm@localhost:5432/sharding-write-test 47 | steps: 48 | - name: Set up Go 49 | uses: actions/setup-go@v5 50 | with: 51 | go-version: "1.20" 52 
| id: go 53 | 54 | - name: Create No ID Database 55 | run: PGPASSWORD=gorm psql -h localhost -U gorm -d sharding-test -c 'CREATE DATABASE "sharding-noid-test";' 56 | 57 | - name: Create Read Database 58 | run: PGPASSWORD=gorm psql -h localhost -U gorm -d sharding-test -c 'CREATE DATABASE "sharding-read-test";' 59 | 60 | - name: Create Write Databases 61 | run: PGPASSWORD=gorm psql -h localhost -U gorm -d sharding-test -c 'CREATE DATABASE "sharding-write-test";' 62 | 63 | - name: Check out code into the Go module directory 64 | uses: actions/checkout@v4 65 | 66 | - name: Get dependencies 67 | run: | 68 | go get -v -t -d ./... 69 | 70 | - name: Test 71 | run: go test 72 | mysql: 73 | name: MySQL 74 | 75 | strategy: 76 | matrix: 77 | dbversion: ["mysql:latest", "mysql:5.7"] 78 | platform: [ubuntu-latest] 79 | runs-on: ${{ matrix.platform }} 80 | 81 | services: 82 | mysql: 83 | image: ${{ matrix.dbversion }} 84 | env: 85 | MYSQL_DATABASE: sharding-test 86 | MYSQL_USER: gorm 87 | MYSQL_PASSWORD: gorm 88 | MYSQL_ROOT_PASSWORD: gorm 89 | ports: 90 | - 3306:3306 91 | options: >- 92 | --health-cmd "mysqladmin ping -ugorm -pgorm" 93 | --health-interval 10s 94 | --health-start-period 10s 95 | --health-timeout 5s 96 | --health-retries 10 97 | 98 | env: 99 | DIALECTOR: mysql 100 | DB_URL: gorm:gorm@tcp(127.0.0.1:3306)/sharding-test?charset=utf8mb4&parseTime=True&loc=Local 101 | DB_NOID_URL: root:gorm@tcp(127.0.0.1:3306)/sharding-noid-test?charset=utf8mb4&parseTime=True&loc=Local 102 | DB_READ_URL: root:gorm@tcp(127.0.0.1:3306)/sharding-read-test?charset=utf8mb4&parseTime=True&loc=Local 103 | DB_WRITE_URL: root:gorm@tcp(127.0.0.1:3306)/sharding-write-test?charset=utf8mb4&parseTime=True&loc=Local 104 | steps: 105 | - name: Set up Go 106 | uses: actions/setup-go@v5 107 | with: 108 | go-version: "1.20" 109 | id: go 110 | 111 | - name: Create No ID Database 112 | run: mysqladmin -h 127.0.0.1 -uroot -pgorm create sharding-noid-test 113 | 114 | - name: Create Read Database 115 | run: 
mysqladmin -h 127.0.0.1 -uroot -pgorm create sharding-read-test 116 | 117 | - name: Create Write Database 118 | run: mysqladmin -h 127.0.0.1 -uroot -pgorm create sharding-write-test 119 | 120 | - name: Check out code into the Go module directory 121 | uses: actions/checkout@v4 122 | 123 | - name: Get dependencies 124 | run: | 125 | go get -v -t -d ./... 126 | 127 | - name: Test 128 | run: go test 129 | mariadb: 130 | name: MariaDB 131 | 132 | strategy: 133 | matrix: 134 | dbversion: ["mariadb:10.11"] 135 | platform: [ubuntu-latest] 136 | runs-on: ${{ matrix.platform }} 137 | 138 | services: 139 | mariadb: 140 | image: ${{ matrix.dbversion }} 141 | env: 142 | MYSQL_DATABASE: sharding-test 143 | MYSQL_USER: gorm 144 | MYSQL_PASSWORD: gorm 145 | MYSQL_ROOT_PASSWORD: gorm 146 | ports: 147 | - 3306:3306 148 | options: >- 149 | --health-cmd "mysqladmin ping -ugorm -pgorm" 150 | --health-interval 10s 151 | --health-start-period 10s 152 | --health-timeout 5s 153 | --health-retries 10 154 | 155 | env: 156 | DIALECTOR: mariadb 157 | DB_URL: gorm:gorm@tcp(127.0.0.1:3306)/sharding-test?charset=utf8mb4&parseTime=True&loc=Local 158 | DB_NOID_URL: root:gorm@tcp(127.0.0.1:3306)/sharding-noid-test?charset=utf8mb4&parseTime=True&loc=Local 159 | DB_READ_URL: root:gorm@tcp(127.0.0.1:3306)/sharding-read-test?charset=utf8mb4&parseTime=True&loc=Local 160 | DB_WRITE_URL: root:gorm@tcp(127.0.0.1:3306)/sharding-write-test?charset=utf8mb4&parseTime=True&loc=Local 161 | steps: 162 | - name: Set up Go 163 | uses: actions/setup-go@v5 164 | with: 165 | go-version: "1.20" 166 | id: go 167 | 168 | - name: Create No ID Database 169 | run: mysqladmin -h 127.0.0.1 -uroot -pgorm create sharding-noid-test 170 | 171 | - name: Create Read Database 172 | run: mysqladmin -h 127.0.0.1 -uroot -pgorm create sharding-read-test 173 | 174 | - name: Create Write Database 175 | run: mysqladmin -h 127.0.0.1 -uroot -pgorm create sharding-write-test 176 | 177 | - name: Check out code into the Go module directory 178 
| uses: actions/checkout@v4 179 | 180 | - name: Get dependencies 181 | run: | 182 | go get -v -t -d ./... 183 | 184 | - name: Test 185 | run: go test 186 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .idea -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Longbridge 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Gorm Sharding 2 | 3 | [![Go](https://github.com/go-gorm/sharding/actions/workflows/tests.yml/badge.svg)](https://github.com/go-gorm/sharding/actions/workflows/tests.yml) 4 | 5 | Gorm Sharding is a plugin that uses an SQL parser and query rewriting to split large tables into smaller ones, redirecting queries to the sharding tables. It gives you high-performance database access. 6 | 7 | Gorm Sharding 是一个高性能的数据库分表中间件。 8 | 9 | 它基于 Conn 层做 SQL 拦截、AST 解析、分表路由、自增主键填充,带来的额外开销极小。对开发者友好、透明,使用上与普通 SQL、Gorm 查询无差别,只需要额外注意一下分表键条件。 10 | 11 | ## Features 12 | 13 | - Non-intrusive design. Load the plugin, specify the config, and you're done. 14 | - Lightning-fast. No network-based middleware; as fast as Go. 15 | - Multiple database (PostgreSQL, MySQL) support. 16 | - Integrated primary key generators (Snowflake, PostgreSQL Sequence, Custom, ...). 17 | 18 | ## Install 19 | 20 | ```bash 21 | go get -u gorm.io/sharding 22 | ``` 23 | 24 | ## Usage 25 | 26 | Configure the sharding middleware and register the tables you want to shard. 27 | 28 | ```go 29 | import ( 30 | "fmt" 31 | 32 | "gorm.io/driver/postgres" 33 | "gorm.io/gorm" 34 | "gorm.io/sharding" 35 | ) 36 | 37 | db, err := gorm.Open(postgres.New(postgres.Config{DSN: "postgres://localhost:5432/sharding-db?sslmode=disable"})) 38 | 39 | db.Use(sharding.Register(sharding.Config{ 40 | ShardingKey: "user_id", 41 | NumberOfShards: 64, 42 | PrimaryKeyGenerator: sharding.PKSnowflake, 43 | }, "orders", Notification{}, AuditLog{})) 44 | // In this case, the notifications and audit_logs tables use the same sharding rule as orders. 45 | ``` 46 | 47 | Use the db session as usual. Just note that queries on sharded tables should include the `Sharding Key`.
48 | 49 | ```go 50 | // Gorm Create example; this will insert into orders_02 51 | db.Create(&Order{UserID: 2}) 52 | // sql: INSERT INTO orders_02 ... 53 | 54 | // Raw SQL inserts are also supported; this will insert into orders_03 55 | db.Exec("INSERT INTO orders(user_id) VALUES(?)", int64(3)) 56 | 57 | // This will return an ErrMissingShardingKey error, because no sharding key is present. 58 | err = db.Create(&Order{Amount: 10, ProductID: 100}).Error 59 | fmt.Println(err) 60 | 61 | // Find; this will redirect the query to orders_02 62 | var orders []Order 63 | db.Model(&Order{}).Where("user_id", int64(2)).Find(&orders) 64 | fmt.Printf("%#v\n", orders) 65 | 66 | // Raw SQL queries are also supported 67 | db.Raw("SELECT * FROM orders WHERE user_id = ?", int64(3)).Scan(&orders) 68 | fmt.Printf("%#v\n", orders) 69 | 70 | // This will return an ErrMissingShardingKey error, because the WHERE conditions do not include the sharding key 71 | err = db.Model(&Order{}).Where("product_id", "1").Find(&orders).Error 72 | fmt.Println(err) 73 | 74 | // Update and Delete are similar to create and query 75 | db.Exec("UPDATE orders SET product_id = ? WHERE user_id = ?", 2, int64(3)) 76 | err = db.Exec("DELETE FROM orders WHERE product_id = 3").Error 77 | fmt.Println(err) // ErrMissingShardingKey 78 | ``` 79 | 80 | The full example is [here](./examples/order.go). 81 | 82 | > 🚨 NOTE: Gorm config `PrepareStmt: true` is not supported for now. 83 | > 84 | > 🚨 NOTE: The default Snowflake generator may produce conflicting primary keys when running on multiple nodes; use a custom primary key generator, or regenerate the primary key when a conflict occurs. 85 | 86 | ## Primary Key 87 | 88 | When you shard tables, you need to consider how primary keys are generated. 89 | 90 | Recommended options: 91 | 92 | - [Snowflake](https://github.com/bwmarrin/snowflake) 93 | - [Database sequence, created manually](https://www.postgresql.org/docs/current/sql-createsequence.html) 94 | 95 | ### Use Snowflake 96 | 97 | Built-in Snowflake primary key generator.
98 | 99 | ```go 100 | db.Use(sharding.Register(sharding.Config{ 101 | ShardingKey: "user_id", 102 | NumberOfShards: 64, 103 | PrimaryKeyGenerator: sharding.PKSnowflake, 104 | }, "orders")) 105 | ``` 106 | 107 | ### Use PostgreSQL Sequence 108 | 109 | Gorm Sharding has a built-in PostgreSQL sequence primary key implementation; just configure `PrimaryKeyGenerator: sharding.PKPGSequence` to use it. 110 | 111 | You don't need to create the sequence manually; Gorm Sharding checks for it and creates it if the PostgreSQL sequence does not exist. 112 | 113 | The sequence name follows the pattern `gorm_sharding_${table_name}_id_seq`; for example, for the `orders` table the sequence name is `gorm_sharding_orders_id_seq`. 114 | 115 | ```go 116 | db.Use(sharding.Register(sharding.Config{ 117 | ShardingKey: "user_id", 118 | NumberOfShards: 64, 119 | PrimaryKeyGenerator: sharding.PKPGSequence, 120 | }, "orders")) 121 | ``` 122 | 123 | ### Use MySQL Sequence 124 | 125 | Gorm Sharding has a built-in MySQL sequence primary key implementation; just configure `PrimaryKeyGenerator: sharding.PKMySQLSequence` to use it. 126 | 127 | You don't need to create the sequence manually; Gorm Sharding checks for it and creates it if the MySQL sequence does not exist. 128 | 129 | The sequence name follows the pattern `gorm_sharding_${table_name}_id_seq`; for example, for the `orders` table the sequence name is `gorm_sharding_orders_id_seq`. 130 | 131 | ```go 132 | db.Use(sharding.Register(sharding.Config{ 133 | ShardingKey: "user_id", 134 | NumberOfShards: 64, 135 | PrimaryKeyGenerator: sharding.PKMySQLSequence, 136 | }, "orders")) 137 | ``` 138 | 139 | ### No primary key 140 | 141 | If your table doesn't have a primary key, or its primary key isn't called `id` (in short, whenever you don't want the `id` field auto-filled), set `PrimaryKeyGenerator` to `PKCustom` and have `PrimaryKeyGeneratorFn` return `0`. 142 | 143 | ## Combining with dbresolver 144 | 145 | > 🚨 NOTE: Register dbresolver before the sharding plugin.
146 | 147 | ```go 148 | dsn := "host=localhost user=gorm password=gorm dbname=gorm port=5432 sslmode=disable" 149 | dsnRead := "host=localhost user=gorm password=gorm dbname=gorm-slave port=5432 sslmode=disable" 150 | 151 | conn := postgres.Open(dsn) 152 | connRead := postgres.Open(dsnRead) 153 | 154 | db, err := gorm.Open(conn, &gorm.Config{}) 155 | dbRead, err := gorm.Open(connRead, &gorm.Config{}) 156 | 157 | db.Use(dbresolver.Register(dbresolver.Config{ 158 | Replicas: []gorm.Dialector{dbRead.Dialector}, 159 | })) 160 | 161 | db.Use(sharding.Register(sharding.Config{ 162 | ShardingKey: "user_id", 163 | NumberOfShards: 64, 164 | PrimaryKeyGenerator: sharding.PKSnowflake, 165 | })) 166 | ``` 167 | 168 | ## Sharding process 169 | 170 | This diagram shows how Gorm Sharding works. 171 | 172 | ```mermaid 173 | graph TD 174 | first("SELECT * FROM orders WHERE user_id = ? AND status = ? 175 | args = [100, 1]") 176 | 177 | first--->gorm(["Gorm Query"]) 178 | 179 | subgraph "Gorm" 180 | gorm--->gorm_query 181 | gorm--->gorm_exec 182 | gorm--->gorm_queryrow 183 | gorm_query["connPool.QueryContext(sql, args)"] 184 | gorm_exec[/"connPool.ExecContext"/] 185 | gorm_queryrow[/"connPool.QueryRowContext"/] 186 | end 187 | 188 | subgraph "database/sql" 189 | gorm_query-->conn(["Conn"]) 190 | gorm_exec-->conn(["Conn"]) 191 | gorm_queryrow-->conn(["Conn"]) 192 | ExecContext[/"ExecContext"/] 193 | QueryContext[/"QueryContext"/] 194 | QueryRowContext[/"QueryRowContext"/] 195 | 196 | 197 | conn-->ExecContext 198 | conn-->QueryRowContext 199 | conn-->QueryContext 200 | end 201 | 202 | subgraph sharding ["Sharding"] 203 | QueryContext-->router-->| Format to get full SQL string |format_sql-->| Parser to AST |parse-->check_table 204 | router[["router(sql, args)
"]] 205 | format_sql>"sql = SELECT * FROM orders WHERE user_id = 100 AND status = 1"] 206 | 207 | check_table{"Check sharding rules
by table name"} 208 | check_table-->| Exist |process_ast 209 | check_table_1{{"Return Raw SQL"}} 210 | not_match_error[/"Return Error
SQL query must have a sharding key"\] 211 | 212 | parse[["ast = sqlparser.Parse(sql)"]] 213 | 214 | check_table-.->| Not exist |check_table_1 215 | process_ast(("Sharding rules")) 216 | get_new_table_name[["Use value in WhereValue (100) to get the sharding table index
orders + (100 % 16)
Sharding Table = orders_4"]] 217 | new_sql{{"SELECT * FROM orders_4 WHERE user_id = 100 AND status = 1"}} 218 | 219 | process_ast-.->| Not match ShardingKey |not_match_error 220 | process_ast-->| Match ShardingKey |match_sharding_key-->| Get table name |get_new_table_name-->| Replace TableName to get new SQL |new_sql 221 | end 222 | 223 | 224 | subgraph database [Database] 225 | orders_other[("orders_0, orders_1 ... orders_3")] 226 | orders_4[(orders_4)] 227 | orders_last[("orders_5 ... orders_15")] 228 | other_tables[(Other non-sharding tables
users, stocks, topics ...)] 229 | 230 | new_sql-->| Sharding Query | orders_4 231 | check_table_1-.->| None sharding Query |other_tables 232 | end 233 | 234 | orders_4-->result 235 | other_tables-.->result 236 | result[/Query results\] 237 | ``` 238 | 239 | ## License 240 | 241 | MIT license. 242 | 243 | Original fork from [Longbridge](https://github.com/longbridgeapp/gorm-sharding). 244 |
-------------------------------------------------------------------------------- /conn_pool.go: -------------------------------------------------------------------------------- 1 | package sharding 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "time" 7 | 8 | "gorm.io/gorm" 9 | ) 10 | 11 | // ConnPool implements gorm.ConnPool. It replaces db.Statement.ConnPool in Gorm 12 | // and rewrites each query to its sharding table before delegating to the wrapped pool. 13 | type ConnPool struct { 14 | // sharding is the plugin instance used to resolve and rewrite queries. 15 | sharding *Sharding 16 | gorm.ConnPool 17 | } 18 | 19 | // String returns a fixed identifier for this connection pool. 20 | func (pool *ConnPool) String() string { 21 | return "gorm:sharding:conn_pool" 22 | } 23 | 24 | // PrepareContext delegates to the wrapped pool without rewriting query, so 25 | // prepared statements are not routed to sharding tables. 26 | func (pool ConnPool) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { 27 | return pool.ConnPool.PrepareContext(ctx, query) 28 | } 29 | 30 | // ExecContext rewrites query to its sharding table and executes it. When the 31 | // table's config enables DoubleWrite, the ftQuery returned by resolve (presumably 32 | // the statement for the original, non-sharded table) is executed first; its 33 | // error is traced through the logger but not returned. 34 | func (pool ConnPool) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { 35 | curTime := time.Now() 36 | 37 | ftQuery, stQuery, table, err := pool.sharding.resolve(query, args...) 38 | if err != nil { 39 | return nil, err 40 | } 41 | 42 | pool.sharding.querys.Store("last_query", stQuery) 43 | 44 | if table != "" { 45 | if r, ok := pool.sharding.configs[table]; ok && r.DoubleWrite { 46 | // Run the double write outside the logger callback: gorm loggers only 47 | // invoke the callback when they emit a line, so the write must not depend on it. 48 | ftResult, ftErr := pool.ConnPool.ExecContext(ctx, ftQuery, args...) 49 | pool.sharding.Logger.Trace(ctx, curTime, func() (sql string, rowsAffected int64) { 50 | if ftResult != nil { 51 | rowsAffected, _ = ftResult.RowsAffected() 52 | } 53 | return pool.sharding.Explain(ftQuery, args...), rowsAffected 54 | }, ftErr) 55 | } 56 | } 57 | 58 | result, err := pool.ConnPool.ExecContext(ctx, stQuery, args...) 59 | pool.sharding.Logger.Trace(ctx, curTime, func() (sql string, rowsAffected int64) { 60 | // result is nil when the driver reports an error. 61 | if result != nil { 62 | rowsAffected, _ = result.RowsAffected() 63 | } 64 | return pool.sharding.Explain(stQuery, args...), rowsAffected 65 | }, err) 66 | 67 | return result, err 68 | } 69 | 70 | // QueryContext rewrites query to its sharding table and runs it. 71 | // https://github.com/go-gorm/gorm/blob/v1.21.11/callbacks/query.go#L18 72 | func (pool ConnPool) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) { 73 | curTime := time.Now() 74 | 75 | _, stQuery, _, err := pool.sharding.resolve(query, args...) 76 | if err != nil { 77 | return nil, err 78 | } 79 | 80 | pool.sharding.querys.Store("last_query", stQuery) 81 | 82 | rows, err := pool.ConnPool.QueryContext(ctx, stQuery, args...) 83 | pool.sharding.Logger.Trace(ctx, curTime, func() (sql string, rowsAffected int64) { 84 | // The row count is not known until the rows are scanned. 85 | return pool.sharding.Explain(stQuery, args...), 0 86 | }, err) 87 | 88 | return rows, err 89 | } 90 | 91 | // QueryRowContext rewrites query to its sharding table and runs it. *sql.Row 92 | // offers no way to carry an error created outside database/sql, so an error 93 | // from resolve is discarded and query is replaced by whatever resolve returned. 94 | // NOTE(review): confirm what resolve returns as the query when it fails. 95 | func (pool ConnPool) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row { 96 | _, query, _, _ = pool.sharding.resolve(query, args...) 97 | pool.sharding.querys.Store("last_query", query) 98 | 99 | return pool.ConnPool.QueryRowContext(ctx, query, args...) 100 | } 101 | 102 | // BeginTx implements gorm.ConnPoolBeginner. When the wrapped pool can begin a 103 | // transaction, the resulting pool is returned unwrapped; otherwise pool itself is returned. 104 | func (pool *ConnPool) BeginTx(ctx context.Context, opt *sql.TxOptions) (gorm.ConnPool, error) { 105 | if basePool, ok := pool.ConnPool.(gorm.ConnPoolBeginner); ok { 106 | return basePool.BeginTx(ctx, opt) 107 | } 108 | 109 | return pool, nil 110 | } 111 | 112 | // Commit implements gorm.TxCommitter. It is a no-op when the wrapped pool is a 113 | // *sql.Tx and otherwise delegates to the wrapped pool if it is a TxCommitter. 114 | // NOTE(review): the reason a wrapped *sql.Tx is skipped is not visible in this file. 115 | func (pool *ConnPool) Commit() error { 116 | if _, ok := pool.ConnPool.(*sql.Tx); ok { 117 | return nil 118 | } 119 | 120 | if basePool, ok := pool.ConnPool.(gorm.TxCommitter); ok { 121 | return basePool.Commit() 122 | } 123 | 124 | return nil 125 | } 126 | 127 | // Rollback implements gorm.TxCommitter. Like Commit, it is a no-op when the 128 | // wrapped pool is a *sql.Tx and otherwise delegates to a TxCommitter pool. 129 | func (pool *ConnPool) Rollback() error { 130 | if _, ok := pool.ConnPool.(*sql.Tx); ok { 131 | return nil 132 | } 133 | 134 | if basePool, ok := pool.ConnPool.(gorm.TxCommitter); ok { 135 | return basePool.Rollback() 136 | } 137 | 138 | return nil 139 | } 140 | 141 | // Ping always returns nil without contacting the wrapped pool. 142 | func (pool *ConnPool) Ping() error { 143 | return nil 144 | } 145 |
-------------------------------------------------------------------------------- /dialector.go: -------------------------------------------------------------------------------- 1 | package sharding 2 | 3 | import ( 4 | "fmt" 5 | 6 | "gorm.io/gorm" 7 | "gorm.io/gorm/migrator" 8 | "gorm.io/gorm/schema" 9 | ) 10 | 11 | // ShardingDialector wraps a gorm.Dialector so that its Migrator is sharding-aware. 12 | type ShardingDialector struct { 13 | gorm.Dialector 14 | sharding *Sharding 15 | } 16 | 17 | // ShardingMigrator wraps a gorm.Migrator and expands each sharded model into 18 | // one migration per sharding table. 19 | type ShardingMigrator struct { 20 | gorm.Migrator 21 | sharding *Sharding 22 | dialector gorm.Dialector 23 | } 24 | 25 | // NewShardingDialector returns a ShardingDialector that wraps d for sharding s. 26 | func NewShardingDialector(d gorm.Dialector, s *Sharding) ShardingDialector { 27 | return ShardingDialector{ 28 | Dialector: d, 29 | sharding: s, 30 | } 31 | } 32 | 33 | // Migrator returns a ShardingMigrator around the wrapped dialector's migrator for db. 34 | func (d ShardingDialector) Migrator(db *gorm.DB) gorm.Migrator { 35 | m := d.Dialector.Migrator(db) 36 | return ShardingMigrator{ 37 | Migrator: m, 38 | sharding: d.sharding, 39 | dialector: d.Dialector, 40 | } 41 | } 42 | 43 | // AutoMigrate migrates every sharding table of each sharded model in dst, then 44 | // migrates the remaining models (plus sharded models configured with DoubleWrite) 45 | // on a session that has ShardingIgnoreStoreKey set. 46 | func (m ShardingMigrator) AutoMigrate(dst ...any) error { 47 | shardingDsts, noShardingDsts, err := m.splitShardingDsts(dst...) 48 | if err != nil { 49 | return err 50 | } 51 | 52 | stmt := &gorm.Statement{DB: m.sharding.DB} 53 | for _, sd := range shardingDsts { 54 | tx := stmt.DB.Session(&gorm.Session{}).Table(sd.table) 55 | if err := m.dialector.Migrator(tx).AutoMigrate(sd.dst); err != nil { 56 | return err 57 | } 58 | 59 | } 60 | 61 | if len(noShardingDsts) > 0 { 62 | tx := stmt.DB.Session(&gorm.Session{}) 63 | tx.Statement.Settings.Store(ShardingIgnoreStoreKey, nil) 64 | defer tx.Statement.Settings.Delete(ShardingIgnoreStoreKey) 65 | 66 | if err := m.dialector.Migrator(tx).AutoMigrate(noShardingDsts...); err != nil { 67 | return err 68 | } 69 | } 70 | 71 | return nil 72 | } 73 | 74 | // BuildIndexOptions delegates to the wrapped migrator; it panics if that migrator 75 | // does not implement migrator.BuildIndexOptionsInterface. 76 | func (m ShardingMigrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { 77 | return m.Migrator.(migrator.BuildIndexOptionsInterface).BuildIndexOptions(opts, stmt) 78 | } 79 | 80 | // DropTable drops every sharding table of each sharded model in dst, then drops 81 | // the remaining models (plus sharded models configured with DoubleWrite). 82 | func (m ShardingMigrator) DropTable(dst ...any) error { 83 | shardingDsts, noShardingDsts, err := m.splitShardingDsts(dst...) 84 | if err != nil { 85 | return err 86 | } 87 | 88 | for _, sd := range shardingDsts { 89 | if err := m.Migrator.DropTable(sd.table); err != nil { 90 | return err 91 | } 92 | } 93 | 94 | if len(noShardingDsts) > 0 { 95 | if err := m.Migrator.DropTable(noShardingDsts...); err != nil { 96 | return err 97 | } 98 | } 99 | 100 | return nil 101 | } 102 | 103 | // shardingDst pairs a sharding table name with the model whose schema applies to it. 104 | type shardingDst struct { 105 | table string 106 | dst any 107 | } 108 | 109 | // splitShardingDsts splits dsts into one entry per sharding table for models 110 | // registered in m.sharding.configs, and a list of models handled without 111 | // sharding. A sharded model with DoubleWrite appears in both results. 112 | func (m ShardingMigrator) splitShardingDsts(dsts ...any) (shardingDsts []shardingDst, 113 | noShardingDsts []any, err error) { 114 | 115 | shardingDsts = make([]shardingDst, 0) 116 | noShardingDsts = make([]any, 0) 117 | for _, model := range dsts { 118 | stmt := &gorm.Statement{DB: m.sharding.DB} 119 | err = stmt.Parse(model) 120 | if err != nil { 121 | return 122 | } 123 | 124 | if cfg, ok := m.sharding.configs[stmt.Table]; ok { 125 | // support sharding table 126 | suffixs := cfg.ShardingSuffixs() 127 | if len(suffixs) == 0 { 128 | err = fmt.Errorf("sharding table:%s suffixs is empty", stmt.Table) 129 | return 130 | } 131 | 132 | for _, suffix := range suffixs { 133 | shardingTable := stmt.Table + suffix 134 | shardingDsts = append(shardingDsts, shardingDst{ 135 | table: shardingTable, 136 | dst: model, 137 | }) 138 | } 139 | 140 | if cfg.DoubleWrite { 141 | noShardingDsts = append(noShardingDsts, model) 142 | } 143 | } else { 144 | noShardingDsts = append(noShardingDsts, model) 145 | } 146 | } 147 | return 148 | } 149 |
-------------------------------------------------------------------------------- /examples/order.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | 6 | "gorm.io/driver/postgres" 7 | "gorm.io/gorm" 8 | "gorm.io/sharding" 9 | ) 10 | 11 | type Order struct { 12 | ID int64 `gorm:"primarykey"` 13 | UserID int64 14 | ProductID int64 15 | } 16 | 17 | func main() { 18 | dsn :=
"postgres://localhost:5432/sharding-db?sslmode=disable" 19 | db, err := gorm.Open(postgres.New(postgres.Config{DSN: dsn})) 20 | if err != nil { 21 | panic(err) 22 | } 23 | 24 | for i := 0; i < 64; i += 1 { 25 | table := fmt.Sprintf("orders_%02d", i) 26 | db.Exec(`DROP TABLE IF EXISTS ` + table) 27 | db.Exec(`CREATE TABLE ` + table + ` ( 28 | id BIGSERIAL PRIMARY KEY, 29 | user_id bigint, 30 | product_id bigint 31 | )`) 32 | } 33 | 34 | middleware := sharding.Register(sharding.Config{ 35 | ShardingKey: "user_id", 36 | NumberOfShards: 64, 37 | PrimaryKeyGenerator: sharding.PKSnowflake, 38 | }, "orders") 39 | db.Use(middleware) 40 | 41 | // this record will insert to orders_02 42 | err = db.Create(&Order{UserID: 2}).Error 43 | if err != nil { 44 | fmt.Println(err) 45 | } 46 | 47 | // this record will insert to orders_03 48 | err = db.Exec("INSERT INTO orders(user_id) VALUES(?)", int64(3)).Error 49 | if err != nil { 50 | fmt.Println(err) 51 | } 52 | 53 | // this will throw ErrMissingShardingKey error 54 | err = db.Exec("INSERT INTO orders(product_id) VALUES(1)").Error 55 | fmt.Println(err) 56 | 57 | // this will redirect query to orders_02 58 | var orders []Order 59 | err = db.Model(&Order{}).Where("user_id", int64(2)).Find(&orders).Error 60 | if err != nil { 61 | fmt.Println(err) 62 | } 63 | fmt.Printf("%#v\n", orders) 64 | 65 | // Raw SQL also supported 66 | db.Raw("SELECT * FROM orders WHERE user_id = ?", int64(3)).Scan(&orders) 67 | fmt.Printf("%#v\n", orders) 68 | 69 | // this will throw ErrMissingShardingKey error 70 | err = db.Model(&Order{}).Where("product_id", "1").Find(&orders).Error 71 | fmt.Println(err) 72 | 73 | // Update and Delete are similar to create and query 74 | err = db.Exec("UPDATE orders SET product_id = ?
WHERE user_id = ?", 2, int64(3)).Error 75 | fmt.Println(err) // nil 76 | err = db.Exec("DELETE FROM orders WHERE product_id = 3").Error 77 | fmt.Println(err) // ErrMissingShardingKey 78 | } 79 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module gorm.io/sharding 2 | 3 | go 1.21 4 | 5 | require ( 6 | github.com/bwmarrin/snowflake v0.3.0 7 | github.com/longbridgeapp/assert v1.1.0 8 | github.com/longbridgeapp/sqlparser v0.3.1 9 | golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 10 | gorm.io/driver/mysql v1.5.1 11 | gorm.io/driver/postgres v1.5.2 12 | gorm.io/gorm v1.25.4 13 | gorm.io/hints v1.1.2 14 | gorm.io/plugin/dbresolver v1.5.1 15 | ) 16 | 17 | require ( 18 | github.com/davecgh/go-spew v1.1.1 // indirect 19 | github.com/go-sql-driver/mysql v1.7.0 // indirect 20 | github.com/jackc/pgpassfile v1.0.0 // indirect 21 | github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect 22 | github.com/jackc/pgx/v5 v5.3.1 // indirect 23 | github.com/jinzhu/inflection v1.0.0 // indirect 24 | github.com/jinzhu/now v1.1.5 // indirect 25 | github.com/kr/text v0.2.0 // indirect 26 | github.com/pmezard/go-difflib v1.0.0 // indirect 27 | github.com/rogpeppe/go-internal v1.12.0 // indirect 28 | github.com/stretchr/testify v1.8.1 // indirect 29 | golang.org/x/crypto v0.8.0 // indirect 30 | golang.org/x/text v0.9.0 // indirect 31 | gopkg.in/yaml.v3 v3.0.1 // indirect 32 | ) 33 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/bwmarrin/snowflake v0.3.0 h1:xm67bEhkKh6ij1790JB83OujPR5CzNe8QuQqAgISZN0= 2 | github.com/bwmarrin/snowflake v0.3.0/go.mod h1:NdZxfVWX+oR6y2K0o6qAYv6gIOP9rjG0/E9WsDpxqwE= 3 | github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= 4 |
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 5 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 6 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 7 | github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= 8 | github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc= 9 | github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= 10 | github.com/go-test/deep v1.0.7 h1:/VSMRlnY/JSyqxQUzQLKVMAskpY/NZKFA5j2P+0pP2M= 11 | github.com/go-test/deep v1.0.7/go.mod h1:QV8Hv/iy04NyLBxAdO9njL0iVPN1S4d/A3NVv1V36o8= 12 | github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= 13 | github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= 14 | github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= 15 | github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= 16 | github.com/jackc/pgx/v5 v5.3.1 h1:Fcr8QJ1ZeLi5zsPZqQeUZhNhxfkkKBOgJuYkJHoBOtU= 17 | github.com/jackc/pgx/v5 v5.3.1/go.mod h1:t3JDKnCBlYIc0ewLF0Q7B8MXmoIaBOZj/ic7iHozM/8= 18 | github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= 19 | github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= 20 | github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= 21 | github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= 22 | github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= 23 | github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= 24 | github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= 25 | github.com/kr/text v0.2.0 
h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 26 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 27 | github.com/longbridgeapp/assert v1.1.0 h1:L+/HISOhuGbNAAmJNXgk3+Tm5QmSB70kwdktJXgjL+I= 28 | github.com/longbridgeapp/assert v1.1.0/go.mod h1:UOI7O3rzlzlz715lQm0atWs6JbrYGuIJUEeOekutL6o= 29 | github.com/longbridgeapp/sqlparser v0.3.1 h1:iWOZWGIFgQrJRgobLXUNJdvqGRpbVXkyKUKUA5CNJBE= 30 | github.com/longbridgeapp/sqlparser v0.3.1/go.mod h1:GIHaUq8zvYyHLCLMJJykx1CdM6LHtkUih/QaJXySSx4= 31 | github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= 32 | github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y= 33 | github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= 34 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 35 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 36 | github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= 37 | github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= 38 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 39 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= 40 | github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= 41 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 42 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 43 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 44 | github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= 45 | github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= 46 | github.com/stretchr/testify v1.8.1/go.mod 
h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= 47 | golang.org/x/crypto v0.8.0 h1:pd9TJtTueMTVQXzk8E2XESSMQDj/U7OUu0PqJqPXQjQ= 48 | golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= 49 | golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 h1:m64FZMko/V45gv0bNmrNYoDEq8U5YUhetc9cBWKS1TQ= 50 | golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63/go.mod h1:0v4NqG35kSWCMzLaMeX+IQrlSnVE/bqGSyC2cz/9Le8= 51 | golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= 52 | golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= 53 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 54 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= 55 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= 56 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 57 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 58 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 59 | gorm.io/driver/mysql v1.4.3/go.mod h1:sSIebwZAVPiT+27jK9HIwvsqOGKx3YMPmrA3mBJR10c= 60 | gorm.io/driver/mysql v1.5.1 h1:WUEH5VF9obL/lTtzjmML/5e6VfFR/788coz2uaVCAZw= 61 | gorm.io/driver/mysql v1.5.1/go.mod h1:Jo3Xu7mMhCyj8dlrb3WoCaRd1FhsVh+yMXb1jUInf5o= 62 | gorm.io/driver/postgres v1.5.2 h1:ytTDxxEv+MplXOfFe3Lzm7SjG09fcdb3Z/c056DTBx0= 63 | gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBpCgl8= 64 | gorm.io/driver/sqlite v1.5.0 h1:zKYbzRCpBrT1bNijRnxLDJWPjVfImGEn0lSnUY5gZ+c= 65 | gorm.io/driver/sqlite v1.5.0/go.mod h1:kDMDfntV9u/vuMmz8APHtHF0b4nyBB7sfCieC6G8k8I= 66 | gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= 67 | gorm.io/gorm v1.24.7-0.20230306060331-85eaf9eeda11/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= 68 | 
gorm.io/gorm v1.25.0/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= 69 | gorm.io/gorm v1.25.1/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= 70 | gorm.io/gorm v1.25.2/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= 71 | gorm.io/gorm v1.25.4 h1:iyNd8fNAe8W9dvtlgeRI5zSVZPsq3OpcTu37cYcpCmw= 72 | gorm.io/gorm v1.25.4/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= 73 | gorm.io/hints v1.1.2 h1:b5j0kwk5p4+3BtDtYqqfY+ATSxjj+6ptPgVveuynn9o= 74 | gorm.io/hints v1.1.2/go.mod h1:/ARdpUHAtyEMCh5NNi3tI7FsGh+Cj/MIUlvNxCNCFWg= 75 | gorm.io/plugin/dbresolver v1.5.1 h1:s9Dj9f7r+1rE3nx/Ywzc85nXptUEaeOO0pt27xdopM8= 76 | gorm.io/plugin/dbresolver v1.5.1/go.mod h1:l4Cn87EHLEYuqUncpEeTC2tTJQkjngPSD+lo8hIvcT0= 77 | -------------------------------------------------------------------------------- /primary_key.go: -------------------------------------------------------------------------------- 1 | package sharding 2 | 3 | import "fmt" 4 | 5 | const ( 6 | // Use Snowflake primary key generator 7 | PKSnowflake = iota 8 | // Use PostgreSQL sequence primary key generator 9 | PKPGSequence 10 | // Use MySQL sequence primary key generator 11 | PKMySQLSequence 12 | // Use custom primary key generator 13 | PKCustom 14 | ) 15 | 16 | func (s *Sharding) genSnowflakeKey(index int64) int64 { 17 | return s.snowflakeNodes[index].Generate().Int64() 18 | } 19 | 20 | // PostgreSQL sequence 21 | 22 | func (s *Sharding) genPostgreSQLSequenceKey(tableName string, index int64) int64 { 23 | var id int64 24 | // Quote the name inside the regclass literal so it matches the quoted CREATE SEQUENCE. 25 | if err := s.DB.Raw(`SELECT nextval('"` + pgSeqName(tableName) + `"')`).Scan(&id).Error; err != nil { 26 | panic(err) 27 | } 28 | return id 29 | } 30 | 31 | func (s *Sharding) createPostgreSQLSequenceKeyIfNotExist(tableName string) error { 32 | return s.DB.Exec(`CREATE SEQUENCE IF NOT EXISTS "` + pgSeqName(tableName) + `" START 1`).Error 33 | } 34 | 35 | func pgSeqName(table string) string { 36 | return fmt.Sprintf("gorm_sharding_%s_id_seq", table) 37 | } 38 | 39 | // MySQL
Sequence 40 | 41 | // genMySQLSequenceKey advances the sequence table of tableName and returns the 42 | // new value. LAST_INSERT_ID is tracked per connection, so the UPDATE and 43 | // the SELECT run in one transaction to keep them on the same connection. 44 | // ORDER BY ... LIMIT 1 advances only the largest row, so ids stay unique even 45 | // if the table ended up with several rows. It panics on error because the 46 | // PrimaryKeyGeneratorFn signature has no error result. 47 | func (s *Sharding) genMySQLSequenceKey(tableName string, index int64) int64 { 48 | tx := s.DB.Begin() 49 | if tx.Error != nil { 50 | panic(tx.Error) 51 | } 52 | committed := false 53 | defer func() { 54 | if !committed { 55 | tx.Rollback() 56 | } 57 | }() 58 | err := tx.Exec("UPDATE `" + mySQLSeqName(tableName) + "` SET id = LAST_INSERT_ID(id + 1) ORDER BY id DESC LIMIT 1").Error 59 | if err != nil { 60 | panic(err) 61 | } 62 | var id int64 63 | err = tx.Raw("SELECT LAST_INSERT_ID()").Scan(&id).Error 64 | if err != nil { 65 | panic(err) 66 | } 67 | if err = tx.Commit().Error; err != nil { 68 | panic(err) 69 | } 70 | committed = true 71 | return id 72 | } 73 | 74 | // createMySQLSequenceKeyIfNotExist creates the sequence table of tableName and 75 | // seeds it with a single 0 row. The seed is skipped when the table already 76 | // has a row, so running this on every start-up neither adds rows nor resets 77 | // the counter. 78 | func (s *Sharding) createMySQLSequenceKeyIfNotExist(tableName string) error { 79 | seq := "`" + mySQLSeqName(tableName) + "`" 80 | stmt := s.DB.Exec("CREATE TABLE IF NOT EXISTS " + seq + " (id BIGINT NOT NULL)") 81 | if stmt.Error != nil { 82 | return fmt.Errorf("failed to create sequence table: %w", stmt.Error) 83 | } 84 | stmt = s.DB.Exec("INSERT INTO " + seq + " (id) SELECT * FROM (SELECT 0 AS id) AS seed WHERE NOT EXISTS (SELECT id FROM " + seq + ")") 85 | if stmt.Error != nil { 86 | return fmt.Errorf("failed to insert into sequence table: %w", stmt.Error) 87 | } 88 | return nil 89 | } 90 | 91 | // mySQLSeqName returns the name of the MySQL table backing table's sequence. 92 | func mySQLSeqName(table string) string { 93 | return fmt.Sprintf("gorm_sharding_%s_id_seq", table) 94 | } 95 | -------------------------------------------------------------------------------- /primary_key_test.go: -------------------------------------------------------------------------------- 1 | package sharding 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/longbridgeapp/assert" 7 | ) 8 | 9 | func Test_pgSeqName(t *testing.T) { 10 | assert.Equal(t, "gorm_sharding_users_id_seq", pgSeqName("users")) 11 | } 12 | -------------------------------------------------------------------------------- /sharding.go: -------------------------------------------------------------------------------- 1 | package sharding 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "hash/crc32" 7 | "strconv" 8 | "strings" 9 | "sync" 10 | 11 | "github.com/bwmarrin/snowflake" 12 | "github.com/longbridgeapp/sqlparser" 13 | "golang.org/x/exp/slices" 14 | "gorm.io/gorm" 15 | ) 16 | 17 | var ( 18 | ErrMissingShardingKey = errors.New("sharding key or id required, and use operator =") 19 |
ErrInvalidID = errors.New("invalid id format") 20 | ErrInsertDiffSuffix = errors.New("can not insert different suffix table in one query ") 21 | ) 22 | 23 | var ( 24 | ShardingIgnoreStoreKey = "sharding_ignore" 25 | ) 26 | 27 | type Sharding struct { 28 | *gorm.DB 29 | ConnPool *ConnPool 30 | configs map[string]Config 31 | querys sync.Map 32 | snowflakeNodes []*snowflake.Node 33 | 34 | _config Config 35 | _tables []any 36 | 37 | mutex sync.RWMutex 38 | } 39 | 40 | // Config specifies the configuration for sharding. 41 | type Config struct { 42 | // When DoubleWrite enabled, data will double write to both main table and sharding table. 43 | DoubleWrite bool 44 | 45 | // ShardingKey specifies the table column you want to used for sharding the table rows. 46 | // For example, for a product order table, you may want to split the rows by `user_id`. 47 | ShardingKey string 48 | 49 | // NumberOfShards specifies how many tables you want to sharding. 50 | NumberOfShards uint 51 | 52 | // tableFormat specifies the sharding table suffix format. 53 | tableFormat string 54 | 55 | // ShardingAlgorithm specifies a function to generate the sharding 56 | // table's suffix by the column value. 57 | // For example, this function implements a mod sharding algorithm. 58 | // 59 | // func(value any) (suffix string, err error) { 60 | // if uid, ok := value.(int64);ok { 61 | // return fmt.Sprintf("_%02d", user_id % 64), nil 62 | // } 63 | // return "", errors.New("invalid user_id") 64 | // } 65 | ShardingAlgorithm func(columnValue any) (suffix string, err error) 66 | 67 | // ShardingSuffixs specifies a function to generate all table's suffix. 68 | // Used to support Migrator and generate PrimaryKey. 69 | // For example, this function get a mod all sharding suffixs. 
70 | // 71 | // func () (suffixs []string) { 72 | // numberOfShards := 5 73 | // for i := 0; i < numberOfShards; i++ { 74 | // suffixs = append(suffixs, fmt.Sprintf("_%02d", i%numberOfShards)) 75 | // } 76 | // return 77 | // } 78 | ShardingSuffixs func() (suffixs []string) 79 | 80 | // ShardingAlgorithmByPrimaryKey specifies a function to generate the sharding 81 | // table's suffix by the primary key. Used when no sharding key specified. 82 | // For example, this function use the Snowflake library to generate the suffix. 83 | // 84 | // func(id int64) (suffix string) { 85 | // return fmt.Sprintf("_%02d", snowflake.ParseInt64(id).Node()) 86 | // } 87 | ShardingAlgorithmByPrimaryKey func(id int64) (suffix string) 88 | 89 | // PrimaryKeyGenerator specifies the primary key generate algorithm. 90 | // Used only when insert and the record does not contains an id field. 91 | // Options are PKSnowflake, PKPGSequence and PKCustom. 92 | // When use PKCustom, you should also specify PrimaryKeyGeneratorFn. 93 | PrimaryKeyGenerator int 94 | 95 | // PrimaryKeyGeneratorFn specifies a function to generate the primary key. 96 | // When use auto-increment like generator, the tableIdx argument could ignored. 97 | // For example, this function use the Snowflake library to generate the primary key. 98 | // If you don't want to auto-fill the `id` or use a primary key that isn't called `id`, just return 0. 
99 | // 100 | // func(tableIdx int64) int64 { 101 | // return nodes[tableIdx].Generate().Int64() 102 | // } 103 | PrimaryKeyGeneratorFn func(tableIdx int64) int64 104 | } 105 | 106 | func Register(config Config, tables ...any) *Sharding { 107 | return &Sharding{ 108 | _config: config, 109 | _tables: tables, 110 | } 111 | } 112 | 113 | func (s *Sharding) compile() error { 114 | if s.configs == nil { 115 | s.configs = make(map[string]Config) 116 | } 117 | for _, table := range s._tables { 118 | if t, ok := table.(string); ok { 119 | s.configs[t] = s._config 120 | } else { 121 | stmt := &gorm.Statement{DB: s.DB} 122 | if err := stmt.Parse(table); err == nil { 123 | s.configs[stmt.Table] = s._config 124 | } else { 125 | return err 126 | } 127 | } 128 | } 129 | 130 | for t, c := range s.configs { 131 | if c.NumberOfShards > 1024 && c.PrimaryKeyGenerator == PKSnowflake { 132 | panic("Snowflake NumberOfShards should less than 1024") 133 | } 134 | 135 | if c.PrimaryKeyGenerator == PKSnowflake { 136 | c.PrimaryKeyGeneratorFn = s.genSnowflakeKey 137 | } else if c.PrimaryKeyGenerator == PKPGSequence { 138 | 139 | // Execute SQL to CREATE SEQUENCE for this table if not exist 140 | err := s.createPostgreSQLSequenceKeyIfNotExist(t) 141 | if err != nil { 142 | return err 143 | } 144 | 145 | c.PrimaryKeyGeneratorFn = func(index int64) int64 { 146 | return s.genPostgreSQLSequenceKey(t, index) 147 | } 148 | } else if c.PrimaryKeyGenerator == PKMySQLSequence { 149 | err := s.createMySQLSequenceKeyIfNotExist(t) 150 | if err != nil { 151 | return err 152 | } 153 | 154 | c.PrimaryKeyGeneratorFn = func(index int64) int64 { 155 | return s.genMySQLSequenceKey(t, index) 156 | } 157 | } else if c.PrimaryKeyGenerator == PKCustom { 158 | if c.PrimaryKeyGeneratorFn == nil { 159 | return errors.New("PrimaryKeyGeneratorFn is required when use PKCustom") 160 | } 161 | } else { 162 | return errors.New("PrimaryKeyGenerator can only be one of PKSnowflake, PKPGSequence, PKMySQLSequence and PKCustom") 
163 | } 164 | 165 | if c.ShardingAlgorithm == nil { 166 | if c.NumberOfShards == 0 { 167 | return errors.New("specify NumberOfShards or ShardingAlgorithm") 168 | } 169 | if c.NumberOfShards < 10 { 170 | c.tableFormat = "_%01d" 171 | } else if c.NumberOfShards < 100 { 172 | c.tableFormat = "_%02d" 173 | } else if c.NumberOfShards < 1000 { 174 | c.tableFormat = "_%03d" 175 | } else if c.NumberOfShards < 10000 { 176 | c.tableFormat = "_%04d" 177 | } 178 | c.ShardingAlgorithm = func(value any) (suffix string, err error) { 179 | id := 0 180 | switch value := value.(type) { 181 | case int: 182 | id = value 183 | case int64: 184 | id = int(value) 185 | case string: 186 | id, err = strconv.Atoi(value) 187 | if err != nil { 188 | id = int(crc32.ChecksumIEEE([]byte(value))) 189 | } 190 | default: 191 | return "", fmt.Errorf("default algorithm only support integer and string column," + 192 | "if you use other type, specify you own ShardingAlgorithm") 193 | } 194 | 195 | return fmt.Sprintf(c.tableFormat, id%int(c.NumberOfShards)), nil 196 | } 197 | } 198 | 199 | if c.ShardingSuffixs == nil { 200 | c.ShardingSuffixs = func() (suffixs []string) { 201 | for i := 0; i < int(c.NumberOfShards); i++ { 202 | suffix, err := c.ShardingAlgorithm(i) 203 | if err != nil { 204 | return nil 205 | } 206 | suffixs = append(suffixs, suffix) 207 | } 208 | return 209 | } 210 | } 211 | 212 | if c.ShardingAlgorithmByPrimaryKey == nil { 213 | if c.PrimaryKeyGenerator == PKSnowflake { 214 | c.ShardingAlgorithmByPrimaryKey = func(id int64) (suffix string) { 215 | return fmt.Sprintf(c.tableFormat, snowflake.ParseInt64(id).Node()) 216 | } 217 | } 218 | } 219 | s.configs[t] = c 220 | } 221 | 222 | return nil 223 | } 224 | 225 | // Name plugin name for Gorm plugin interface 226 | func (s *Sharding) Name() string { 227 | return "gorm:sharding" 228 | } 229 | 230 | // LastQuery get last SQL query 231 | func (s *Sharding) LastQuery() string { 232 | if query, ok := s.querys.Load("last_query"); ok { 233 | 
return query.(string) 234 | } 235 | 236 | return "" 237 | } 238 | 239 | // Initialize implement for Gorm plugin interface 240 | func (s *Sharding) Initialize(db *gorm.DB) error { 241 | db.Dialector = NewShardingDialector(db.Dialector, s) 242 | s.DB = db 243 | s.registerCallbacks(db) 244 | 245 | for t, c := range s.configs { 246 | if c.PrimaryKeyGenerator == PKPGSequence { 247 | err := s.DB.Exec("CREATE SEQUENCE IF NOT EXISTS " + pgSeqName(t)).Error 248 | if err != nil { 249 | return fmt.Errorf("init postgresql sequence error, %w", err) 250 | } 251 | } 252 | if c.PrimaryKeyGenerator == PKMySQLSequence { 253 | err := s.DB.Exec("CREATE TABLE IF NOT EXISTS " + mySQLSeqName(t) + " (id INT NOT NULL)").Error 254 | if err != nil { 255 | return fmt.Errorf("init mysql create sequence error, %w", err) 256 | } 257 | err = s.DB.Exec("INSERT INTO " + mySQLSeqName(t) + " VALUES (0)").Error 258 | if err != nil { 259 | return fmt.Errorf("init mysql insert sequence error, %w", err) 260 | } 261 | } 262 | } 263 | 264 | s.snowflakeNodes = make([]*snowflake.Node, 1024) 265 | for i := int64(0); i < 1024; i++ { 266 | n, err := snowflake.NewNode(i) 267 | if err != nil { 268 | return fmt.Errorf("init snowflake node error, %w", err) 269 | } 270 | s.snowflakeNodes[i] = n 271 | } 272 | 273 | return s.compile() 274 | } 275 | 276 | func (s *Sharding) registerCallbacks(db *gorm.DB) { 277 | s.Callback().Create().Before("*").Register("gorm:sharding", s.switchConn) 278 | s.Callback().Query().Before("*").Register("gorm:sharding", s.switchConn) 279 | s.Callback().Update().Before("*").Register("gorm:sharding", s.switchConn) 280 | s.Callback().Delete().Before("*").Register("gorm:sharding", s.switchConn) 281 | s.Callback().Row().Before("*").Register("gorm:sharding", s.switchConn) 282 | s.Callback().Raw().Before("*").Register("gorm:sharding", s.switchConn) 283 | } 284 | 285 | func (s *Sharding) switchConn(db *gorm.DB) { 286 | // Support ignore sharding in some case, like: 287 | // When DoubleWrite is 
enabled, we need to query database schema 288 | // information by table name during the migration. 289 | if _, ok := db.Get(ShardingIgnoreStoreKey); !ok { 290 | s.mutex.Lock() 291 | if db.Statement.ConnPool != nil { 292 | s.ConnPool = &ConnPool{ConnPool: db.Statement.ConnPool, sharding: s} 293 | db.Statement.ConnPool = s.ConnPool 294 | } 295 | s.mutex.Unlock() 296 | } 297 | } 298 | 299 | // resolve split the old query to full table query and sharding table query 300 | func (s *Sharding) resolve(query string, args ...any) (ftQuery, stQuery, tableName string, err error) { 301 | ftQuery = query 302 | stQuery = query 303 | if len(s.configs) == 0 { 304 | return 305 | } 306 | 307 | expr, err := sqlparser.NewParser(strings.NewReader(query)).ParseStatement() 308 | if err != nil { 309 | return ftQuery, stQuery, tableName, nil 310 | } 311 | 312 | var table *sqlparser.TableName 313 | var condition sqlparser.Expr 314 | var isInsert bool 315 | var insertNames []*sqlparser.Ident 316 | var insertExpressions []*sqlparser.Exprs 317 | var insertStmt *sqlparser.InsertStatement 318 | 319 | switch stmt := expr.(type) { 320 | case *sqlparser.SelectStatement: 321 | tbl, ok := stmt.FromItems.(*sqlparser.TableName) 322 | if !ok { 323 | return 324 | } 325 | if stmt.Hint != nil && stmt.Hint.Value == "nosharding" { 326 | return 327 | } 328 | table = tbl 329 | condition = stmt.Condition 330 | case *sqlparser.InsertStatement: 331 | table = stmt.TableName 332 | isInsert = true 333 | insertNames = stmt.ColumnNames 334 | insertExpressions = stmt.Expressions 335 | insertStmt = stmt 336 | case *sqlparser.UpdateStatement: 337 | condition = stmt.Condition 338 | table = stmt.TableName 339 | case *sqlparser.DeleteStatement: 340 | condition = stmt.Condition 341 | table = stmt.TableName 342 | default: 343 | return ftQuery, stQuery, "", sqlparser.ErrNotImplemented 344 | } 345 | 346 | tableName = table.Name.Name 347 | r, ok := s.configs[tableName] 348 | if !ok { 349 | return 350 | } 351 | 352 | var suffix 
string 353 | if isInsert { 354 | var newTable *sqlparser.TableName 355 | for _, insertExpression := range insertExpressions { 356 | var value any 357 | var id int64 358 | var keyFind bool 359 | columnNames := insertNames 360 | insertValues := insertExpression.Exprs 361 | value, id, keyFind, err = s.insertValue(r.ShardingKey, insertNames, insertValues, args...) 362 | if err != nil { 363 | return 364 | } 365 | 366 | var subSuffix string 367 | subSuffix, err = getSuffix(value, id, keyFind, r) 368 | if err != nil { 369 | return 370 | } 371 | 372 | if suffix != "" && suffix != subSuffix { 373 | err = ErrInsertDiffSuffix 374 | return 375 | } 376 | 377 | suffix = subSuffix 378 | 379 | newTable = &sqlparser.TableName{Name: &sqlparser.Ident{Name: tableName + suffix}} 380 | 381 | // Generate a primary key only when the statement does not already set 382 | // "id": generators may hit the database (sequences) and consume a value, 383 | // and a caller-supplied id needs no table index at all. 384 | fillID := slices.IndexFunc(insertNames, func(name *sqlparser.Ident) bool { return name.Name == "id" }) == -1 385 | if fillID { 386 | // The table index is the numeric suffix ("_03" -> 3) or, failing that, 387 | // the suffix's position in ShardingSuffixs. 388 | suffixWord := strings.Replace(suffix, "_", "", 1) 389 | tblIdx, err := strconv.Atoi(suffixWord) 390 | if err != nil { 391 | tblIdx = slices.Index(r.ShardingSuffixs(), suffix) 392 | if tblIdx == -1 { 393 | return ftQuery, stQuery, tableName, errors.New("table suffix '" + suffix + "' is not in ShardingSuffixs.
In order to generate the primary key, ShardingSuffixs should include all table suffixes") 394 | } 395 | } 396 | 397 | // A generator that returns 0 (e.g. a PKCustom one for tables without an 398 | // id column) opts out of filling the id column. 399 | id := r.PrimaryKeyGeneratorFn(int64(tblIdx)) 400 | if id == 0 { 401 | fillID = false 402 | } 403 | 404 | if fillID { 405 | columnNames = append(insertNames, &sqlparser.Ident{Name: "id"}) 406 | insertValues = append(insertValues, &sqlparser.NumberLit{Value: strconv.FormatInt(id, 10)}) 407 | } 408 | } 409 | 410 | if fillID { 411 | insertStmt.ColumnNames = columnNames 412 | insertExpression.Exprs = insertValues 413 | } 414 | } 415 | 416 | ftQuery = insertStmt.String() 417 | insertStmt.TableName = newTable 418 | stQuery = insertStmt.String() 419 | 420 | } else { 421 | var value any 422 | var id int64 423 | var keyFind bool 424 | value, id, keyFind, err = s.nonInsertValue(r.ShardingKey, condition, args...) 425 | if err != nil { 426 | return 427 | } 428 | 429 | suffix, err = getSuffix(value, id, keyFind, r) 430 | if err != nil { 431 | return 432 | } 433 | 434 | newTable := &sqlparser.TableName{Name: &sqlparser.Ident{Name: tableName + suffix}} 435 | 436 | switch stmt := expr.(type) { 437 | case *sqlparser.SelectStatement: 438 | ftQuery = stmt.String() 439 | stmt.FromItems = newTable 440 | stmt.OrderBy = replaceOrderByTableName(stmt.OrderBy, tableName, newTable.Name.Name) 441 | replaceTableNameInCondition(stmt.Condition, tableName, newTable.Name.Name) 442 | stQuery = stmt.String() 443 | case *sqlparser.UpdateStatement: 444 | ftQuery = stmt.String() 445 | stmt.TableName = newTable 446 | replaceTableNameInCondition(stmt.Condition, tableName, newTable.Name.Name) 447 | stQuery = stmt.String() 448 | case *sqlparser.DeleteStatement: 449 | ftQuery = stmt.String() 450 | stmt.TableName = newTable 451 | replaceTableNameInCondition(stmt.Condition, tableName, newTable.Name.Name) 452 | stQuery = stmt.String() 453 | } 454 | } 455 | 456 | return 457 | } 458 | 459 | func getSuffix(value any, id int64, keyFind bool, r Config)
(suffix string, err error) { 460 | if keyFind { 461 | suffix, err = r.ShardingAlgorithm(value) 462 | if err != nil { 463 | return 464 | } 465 | } else { 466 | if r.ShardingAlgorithmByPrimaryKey == nil { 467 | err = fmt.Errorf("there is not sharding key and ShardingAlgorithmByPrimaryKey is not configured") 468 | return 469 | } 470 | suffix = r.ShardingAlgorithmByPrimaryKey(id) 471 | } 472 | 473 | return 474 | } 475 | 476 | func (s *Sharding) insertValue(key string, names []*sqlparser.Ident, exprs []sqlparser.Expr, args ...any) (value any, id int64, keyFind bool, err error) { 477 | if len(names) != len(exprs) { 478 | return nil, 0, keyFind, errors.New("column names and expressions mismatch") 479 | } 480 | 481 | for i, name := range names { 482 | if name.Name == key { 483 | switch expr := exprs[i].(type) { 484 | case *sqlparser.BindExpr: 485 | value = args[expr.Pos] 486 | case *sqlparser.StringLit: 487 | value = expr.Value 488 | case *sqlparser.NumberLit: 489 | value = expr.Value 490 | default: 491 | return nil, 0, keyFind, sqlparser.ErrNotImplemented 492 | } 493 | keyFind = true 494 | break 495 | } 496 | } 497 | if !keyFind { 498 | return nil, 0, keyFind, ErrMissingShardingKey 499 | } 500 | 501 | return 502 | } 503 | 504 | func (s *Sharding) nonInsertValue(key string, condition sqlparser.Expr, args ...any) (value any, id int64, keyFind bool, err error) { 505 | err = sqlparser.Walk(sqlparser.VisitFunc(func(node sqlparser.Node) error { 506 | if n, ok := node.(*sqlparser.BinaryExpr); ok { 507 | x, ok := n.X.(*sqlparser.Ident) 508 | if !ok { 509 | if q, ok2 := n.X.(*sqlparser.QualifiedRef); ok2 { 510 | x = q.Column 511 | ok = true 512 | } 513 | } 514 | if ok { 515 | if x.Name == key && n.Op == sqlparser.EQ { 516 | keyFind = true 517 | switch expr := n.Y.(type) { 518 | case *sqlparser.BindExpr: 519 | value = args[expr.Pos] 520 | case *sqlparser.StringLit: 521 | value = expr.Value 522 | case *sqlparser.NumberLit: 523 | value = expr.Value 524 | default: 525 | return 
sqlparser.ErrNotImplemented 526 | } 527 | return nil 528 | } else if x.Name == "id" && n.Op == sqlparser.EQ { 529 | switch expr := n.Y.(type) { 530 | case *sqlparser.BindExpr: 531 | v := args[expr.Pos] 532 | var ok bool 533 | if id, ok = v.(int64); !ok { 534 | return fmt.Errorf("ID should be int64 type") 535 | } 536 | case *sqlparser.NumberLit: 537 | id, err = strconv.ParseInt(expr.Value, 10, 64) 538 | if err != nil { 539 | return err 540 | } 541 | default: 542 | return ErrInvalidID 543 | } 544 | return nil 545 | } 546 | } 547 | } 548 | return nil 549 | }), condition) 550 | if err != nil { 551 | return 552 | } 553 | 554 | if !keyFind && id == 0 { 555 | return nil, 0, keyFind, ErrMissingShardingKey 556 | } 557 | 558 | return 559 | } 560 | 561 | func replaceOrderByTableName(orderBy []*sqlparser.OrderingTerm, oldName, newName string) []*sqlparser.OrderingTerm { 562 | for i, term := range orderBy { 563 | if x, ok := term.X.(*sqlparser.QualifiedRef); ok { 564 | if x.Table.Name == oldName { 565 | x.Table.Name = newName 566 | orderBy[i].X = x 567 | } 568 | } 569 | } 570 | 571 | return orderBy 572 | } 573 | 574 | // replaceTableNameInCondition walks the WHERE expression tree 575 | // and renames any qualified column references matching oldName → newName. 
576 | func replaceTableNameInCondition(expr sqlparser.Expr, oldName, newName string) { 577 | if expr == nil { 578 | return 579 | } 580 | 581 | _ = sqlparser.Walk(sqlparser.VisitFunc(func(node sqlparser.Node) error { 582 | if qr, ok := node.(*sqlparser.QualifiedRef); ok && qr.Table.Name == oldName { 583 | qr.Table.Name = newName 584 | } 585 | 586 | return nil 587 | }), expr) 588 | } 589 | -------------------------------------------------------------------------------- /sharding_test.go: -------------------------------------------------------------------------------- 1 | package sharding 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "os" 7 | "regexp" 8 | "sort" 9 | "strconv" 10 | "strings" 11 | "testing" 12 | "time" 13 | 14 | "github.com/bwmarrin/snowflake" 15 | "github.com/longbridgeapp/assert" 16 | "github.com/longbridgeapp/sqlparser" 17 | "gorm.io/driver/mysql" 18 | "gorm.io/driver/postgres" 19 | "gorm.io/gorm" 20 | "gorm.io/hints" 21 | "gorm.io/plugin/dbresolver" 22 | ) 23 | 24 | type Order struct { 25 | ID int64 `gorm:"primarykey"` 26 | UserID int64 27 | Product string 28 | Deleted gorm.DeletedAt 29 | } 30 | 31 | type Category struct { 32 | ID int64 `gorm:"primarykey"` 33 | Name string 34 | } 35 | 36 | func dbURL() string { 37 | dbURL := os.Getenv("DB_URL") 38 | if len(dbURL) == 0 { 39 | dbURL = "postgres://localhost:5432/sharding-test?sslmode=disable" 40 | if mysqlDialector() { 41 | dbURL = "root@tcp(127.0.0.1:3306)/sharding-test?charset=utf8mb4" 42 | } 43 | } 44 | return dbURL 45 | } 46 | 47 | func dbNoIDURL() string { 48 | dbURL := os.Getenv("DB_NOID_URL") 49 | if len(dbURL) == 0 { 50 | dbURL = "postgres://localhost:5432/sharding-noid-test?sslmode=disable" 51 | if mysqlDialector() { 52 | dbURL = "root@tcp(127.0.0.1:3306)/sharding-noid-test?charset=utf8mb4" 53 | } 54 | } 55 | return dbURL 56 | } 57 | 58 | func dbReadURL() string { 59 | dbURL := os.Getenv("DB_READ_URL") 60 | if len(dbURL) == 0 { 61 | dbURL = 
"postgres://localhost:5432/sharding-read-test?sslmode=disable" 62 | if mysqlDialector() { 63 | dbURL = "root@tcp(127.0.0.1:3306)/sharding-read-test?charset=utf8mb4" 64 | } 65 | } 66 | return dbURL 67 | } 68 | 69 | func dbWriteURL() string { 70 | dbURL := os.Getenv("DB_WRITE_URL") 71 | if len(dbURL) == 0 { 72 | dbURL = "postgres://localhost:5432/sharding-write-test?sslmode=disable" 73 | if mysqlDialector() { 74 | dbURL = "root@tcp(127.0.0.1:3306)/sharding-write-test?charset=utf8mb4" 75 | } 76 | } 77 | return dbURL 78 | } 79 | 80 | var ( 81 | dbConfig = postgres.Config{ 82 | DSN: dbURL(), 83 | PreferSimpleProtocol: true, 84 | } 85 | dbNoIDConfig = postgres.Config{ 86 | DSN: dbNoIDURL(), 87 | PreferSimpleProtocol: true, 88 | } 89 | dbReadConfig = postgres.Config{ 90 | DSN: dbReadURL(), 91 | PreferSimpleProtocol: true, 92 | } 93 | dbWriteConfig = postgres.Config{ 94 | DSN: dbWriteURL(), 95 | PreferSimpleProtocol: true, 96 | } 97 | db, dbNoID, dbRead, dbWrite *gorm.DB 98 | 99 | shardingConfig, shardingConfigNoID Config 100 | middleware, middlewareNoID *Sharding 101 | node, _ = snowflake.NewNode(1) 102 | ) 103 | 104 | func init() { 105 | if mysqlDialector() { 106 | db, _ = gorm.Open(mysql.Open(dbURL()), &gorm.Config{ 107 | DisableForeignKeyConstraintWhenMigrating: true, 108 | }) 109 | dbNoID, _ = gorm.Open(mysql.Open(dbNoIDURL()), &gorm.Config{ 110 | DisableForeignKeyConstraintWhenMigrating: true, 111 | }) 112 | dbRead, _ = gorm.Open(mysql.Open(dbReadURL()), &gorm.Config{ 113 | DisableForeignKeyConstraintWhenMigrating: true, 114 | }) 115 | dbWrite, _ = gorm.Open(mysql.Open(dbWriteURL()), &gorm.Config{ 116 | DisableForeignKeyConstraintWhenMigrating: true, 117 | }) 118 | } else { 119 | db, _ = gorm.Open(postgres.New(dbConfig), &gorm.Config{ 120 | DisableForeignKeyConstraintWhenMigrating: true, 121 | }) 122 | dbNoID, _ = gorm.Open(postgres.New(dbNoIDConfig), &gorm.Config{ 123 | DisableForeignKeyConstraintWhenMigrating: true, 124 | }) 125 | dbRead, _ = 
gorm.Open(postgres.New(dbReadConfig), &gorm.Config{ 126 | DisableForeignKeyConstraintWhenMigrating: true, 127 | }) 128 | dbWrite, _ = gorm.Open(postgres.New(dbWriteConfig), &gorm.Config{ 129 | DisableForeignKeyConstraintWhenMigrating: true, 130 | }) 131 | } 132 | 133 | shardingConfig = Config{ 134 | DoubleWrite: true, 135 | ShardingKey: "user_id", 136 | NumberOfShards: 4, 137 | PrimaryKeyGenerator: PKSnowflake, 138 | } 139 | 140 | shardingConfigNoID = Config{ 141 | DoubleWrite: true, 142 | ShardingKey: "user_id", 143 | NumberOfShards: 4, 144 | PrimaryKeyGenerator: PKCustom, 145 | PrimaryKeyGeneratorFn: func(_ int64) int64 { 146 | return 0 147 | }, 148 | } 149 | 150 | middleware = Register(shardingConfig, &Order{}) 151 | middlewareNoID = Register(shardingConfigNoID, &Order{}) 152 | 153 | fmt.Println("Clean only tables ...") 154 | dropTables() 155 | fmt.Println("AutoMigrate tables ...") 156 | err := db.AutoMigrate(&Order{}, &Category{}) 157 | if err != nil { 158 | panic(err) 159 | } 160 | stables := []string{"orders_0", "orders_1", "orders_2", "orders_3"} 161 | for _, table := range stables { 162 | db.Exec(`CREATE TABLE ` + table + ` ( 163 | id bigint PRIMARY KEY, 164 | user_id bigint, 165 | product text, 166 | deleted timestamp NULL 167 | )`) 168 | dbNoID.Exec(`CREATE TABLE ` + table + ` ( 169 | user_id bigint, 170 | product text, 171 | deleted timestamp NULL 172 | )`) 173 | dbRead.Exec(`CREATE TABLE ` + table + ` ( 174 | id bigint PRIMARY KEY, 175 | user_id bigint, 176 | product text, 177 | deleted timestamp NULL 178 | )`) 179 | dbWrite.Exec(`CREATE TABLE ` + table + ` ( 180 | id bigint PRIMARY KEY, 181 | user_id bigint, 182 | product text, 183 | deleted timestamp NULL 184 | )`) 185 | } 186 | 187 | db.Use(middleware) 188 | dbNoID.Use(middlewareNoID) 189 | } 190 | 191 | func dropTables() { 192 | tables := []string{"orders", "orders_0", "orders_1", "orders_2", "orders_3", "categories"} 193 | for _, table := range tables { 194 | db.Exec("DROP TABLE IF EXISTS " + 
table)
		dbNoID.Exec("DROP TABLE IF EXISTS " + table)
		dbRead.Exec("DROP TABLE IF EXISTS " + table)
		dbWrite.Exec("DROP TABLE IF EXISTS " + table)
		if mysqlDialector() {
			db.Exec(("DROP TABLE IF EXISTS gorm_sharding_" + table + "_id_seq"))
		} else {
			db.Exec(("DROP SEQUENCE IF EXISTS gorm_sharding_" + table + "_id_seq"))
		}
	}
}

// TestMigrate checks the set of tables before and after DropTable, then that
// AutoMigrate recreates exactly the expected tables and can be run twice.
func TestMigrate(t *testing.T) {
	targetTables := []string{"orders", "orders_0", "orders_1", "orders_2", "orders_3", "categories"}
	sort.Strings(targetTables)

	// origin tables
	tables, err := db.Migrator().GetTables()
	assert.Equal[error](t, nil, err)
	sort.Strings(tables)
	assert.Equal(t, targetTables, tables)

	// drop table
	err = db.Migrator().DropTable(Order{}, &Category{})
	assert.Equal[error](t, nil, err)
	tables, err = db.Migrator().GetTables()
	assert.Equal[error](t, nil, err)
	assert.Equal(t, 0, len(tables))

	// auto migrate
	err = db.AutoMigrate(&Order{}, &Category{})
	assert.Equal[error](t, nil, err)
	tables, err = db.Migrator().GetTables()
	assert.Equal[error](t, nil, err)
	sort.Strings(tables)
	assert.Equal(t, targetTables, tables)

	// auto migrate again
	err = db.AutoMigrate(&Order{}, &Category{})
	assert.Equal[error](t, nil, err)
}

func TestInsert(t *testing.T) {
	tx := db.Create(&Order{ID: 100, UserID: 100, Product: "iPhone"})
	assertQueryResult(t, `INSERT INTO orders_0 ("user_id", "product", "deleted", "id") VALUES ($1, $2, $3, $4) RETURNING "id"`, tx)
}

func TestInsertNoID(t *testing.T) {
	dbNoID.Create(&Order{UserID: 100, Product: "iPhone"})
	expected := `INSERT INTO orders_0 ("user_id", "product", "deleted") VALUES ($1, $2, $3) RETURNING "id"`
	assert.Equal(t, toDialect(expected), middlewareNoID.LastQuery())
}

// TestFillID checks that the middleware appends an id column to an insert
// that did not set one. Only the statement prefix is compared because the
// generated id value differs from run to run.
func TestFillID(t *testing.T) {
	db.Create(&Order{UserID: 100, Product: "iPhone"})
	expected := toDialect(`INSERT INTO orders_0 ("user_id", "product", "deleted", id) VALUES`)
	assert.Equal(t, expected, prefixOfQuery(middleware.LastQuery(), len(expected)))
}

func TestInsertManyWithFillID(t *testing.T) {
	err := db.Create([]Order{{UserID: 100, Product: "Mac"}, {UserID: 100, Product: "Mac Pro"}}).Error
	assert.Equal[error](t, nil, err)

	expected := `INSERT INTO orders_0 ("user_id", "product", "deleted", id) VALUES ($1, $2, $3, $sfid), ($4, $5, $6, $sfid) RETURNING "id"`
	lastQuery := middleware.LastQuery()
	assertSfidQueryResult(t, toDialect(expected), lastQuery)
}

func TestInsertDiffSuffix(t *testing.T) {
	err := db.Create([]Order{{UserID: 100, Product: "Mac"}, {UserID: 101, Product: "Mac Pro"}}).Error
	assert.Equal(t, ErrInsertDiffSuffix, err)
}

func TestSelect1(t *testing.T) {
	tx := db.Model(&Order{}).Where("user_id", 101).Where("id", node.Generate().Int64()).Find(&[]Order{})
	assertQueryResult(t, `SELECT * FROM orders_1 WHERE "user_id" = $1 AND "id" = $2 AND "orders_1"."deleted" IS NULL`, tx)
}

func TestSelect2(t *testing.T) {
	tx := db.Model(&Order{}).Where("id", node.Generate().Int64()).Where("user_id", 101).Find(&[]Order{})
	assertQueryResult(t, `SELECT * FROM orders_1 WHERE "id" = $1 AND "user_id" = $2 AND "orders_1"."deleted" IS NULL`, tx)
}

func TestSelect3(t *testing.T) {
	tx := db.Model(&Order{}).Where("id", node.Generate().Int64()).Where("user_id = 101").Find(&[]Order{})
	assertQueryResult(t, `SELECT * FROM orders_1 WHERE "id" = $1 AND user_id = 101 AND "orders_1"."deleted" IS NULL`, tx)
}

func TestSelect4(t *testing.T) {
	tx := db.Model(&Order{}).Where("product", "iPad").Where("user_id", 100).Find(&[]Order{})
	assertQueryResult(t, `SELECT * FROM orders_0 WHERE "product" = $1 AND "user_id" = $2 AND "orders_0"."deleted" IS NULL`, tx)
}

func TestSelect5(t *testing.T) {
	tx := db.Model(&Order{}).Where("user_id = 101").Find(&[]Order{})
	assertQueryResult(t, `SELECT * FROM orders_1 WHERE user_id = 101 AND "orders_1"."deleted" IS NULL`, tx)
}

func TestSelect6(t *testing.T) {
	tx := db.Model(&Order{}).Where("id", node.Generate().Int64()).Find(&[]Order{})
	assertQueryResult(t, `SELECT * FROM orders_1 WHERE "id" = $1 AND "orders_1"."deleted" IS NULL`, tx)
}

func TestSelect7(t *testing.T) {
	tx := db.Model(&Order{}).Where("user_id", 101).Where("id > ?", node.Generate().Int64()).Find(&[]Order{})
	assertQueryResult(t, `SELECT * FROM orders_1 WHERE "user_id" = $1 AND id > $2 AND "orders_1"."deleted" IS NULL`, tx)
}

func TestSelect8(t *testing.T) {
	tx := db.Model(&Order{}).Where("id > ?", node.Generate().Int64()).Where("user_id", 101).Find(&[]Order{})
	assertQueryResult(t, `SELECT * FROM orders_1 WHERE id > $1 AND "user_id" = $2 AND "orders_1"."deleted" IS NULL`, tx)
}

func TestSelect9(t *testing.T) {
	tx := db.Model(&Order{}).Where("user_id = 101").First(&[]Order{})
	assertQueryResult(t, `SELECT * FROM orders_1 WHERE user_id = 101 AND "orders_1"."deleted" IS NULL ORDER BY "orders_1"."id" LIMIT 1`, tx)
}

func TestSelect10(t *testing.T) {
	tx := db.Clauses(hints.Comment("select", "nosharding")).Model(&Order{}).Find(&[]Order{})
	assertQueryResult(t, `SELECT /* nosharding */ * FROM "orders" WHERE "orders"."deleted" IS NULL`, tx)
}

func TestSelect11(t *testing.T) {
	tx := db.Clauses(hints.Comment("select", "nosharding")).Model(&Order{}).Where("user_id", 101).Find(&[]Order{})
	assertQueryResult(t, `SELECT /* nosharding */ * FROM "orders" WHERE "user_id" = $1 AND "orders"."deleted" IS NULL`, tx)
}

func TestSelect12(t *testing.T) {
	sql := toDialect(`SELECT * FROM "public"."orders" WHERE user_id = 101`)
	tx := db.Raw(sql).Find(&[]Order{})
	assertQueryResult(t, sql, tx)
}

func TestSelect13(t *testing.T) {
	var n int
	tx := db.Raw("SELECT 1").Find(&n)
	assertQueryResult(t, `SELECT 1`, tx)
}

func TestSelect14(t *testing.T) {
	dbNoID.Model(&Order{}).Where("user_id = 101").Find(&[]Order{})
	expected := `SELECT * FROM orders_1 WHERE user_id = 101 AND "orders_1"."deleted" IS NULL`
	assert.Equal(t, toDialect(expected), middlewareNoID.LastQuery())
}

func TestUpdate(t *testing.T) {
	tx := db.Model(&Order{}).Where("user_id = ?", 100).Update("product", "new title")
	assertQueryResult(t, `UPDATE orders_0 SET "product" = $1 WHERE user_id = $2 AND "orders_0"."deleted" IS NULL`, tx)
}

func TestDelete(t *testing.T) {
	tx := db.Unscoped().Where("user_id = ?", 100).Delete(&Order{ID: 1})
	assertQueryResult(t, `DELETE FROM orders_0 WHERE user_id = $1 AND "orders_0"."id" = $2`, tx)
}

func TestInsertMissingShardingKey(t *testing.T) {
	err := db.Exec(`INSERT INTO "orders" ("id", "product") VALUES(1, 'iPad')`).Error
	assert.Equal(t, ErrMissingShardingKey, err)
}

func TestSelectMissingShardingKey(t *testing.T) {
	err := db.Exec(`SELECT * FROM "orders" WHERE "product" = 'iPad'`).Error
	assert.Equal(t, ErrMissingShardingKey, err)
}

func TestSelectNoSharding(t *testing.T) {
	sql := toDialect(`SELECT /* nosharding */ * FROM "orders" WHERE "product" = 'iPad'`)
	err := db.Exec(sql).Error
	assert.Equal[error](t, nil, err)
}

// TestNoEq checks that a non-equality condition on the sharding key does not
// select a shard.
func TestNoEq(t *testing.T) {
	// Find needs a pointer destination, like every other query in this file.
	err := db.Model(&Order{}).Where("user_id <> ?", 101).Find(&[]Order{}).Error
	assert.Equal(t, ErrMissingShardingKey, err)
}

func TestShardingKeyOK(t *testing.T) {
	err := db.Model(&Order{}).Where("user_id = ? and id > ?", 101, int64(100)).Find(&[]Order{}).Error
	assert.Equal[error](t, nil, err)
}

func TestShardingKeyNotOK(t *testing.T) {
	err := db.Model(&Order{}).Where("user_id > ? and id > ?", 101, int64(100)).Find(&[]Order{}).Error
	assert.Equal(t, ErrMissingShardingKey, err)
}

func TestShardingIdOK(t *testing.T) {
	err := db.Model(&Order{}).Where("id = ? and user_id > ?", int64(101), 100).Find(&[]Order{}).Error
	assert.Equal[error](t, nil, err)
}

func TestNoSharding(t *testing.T) {
	categories := []Category{}
	tx := db.Model(&Category{}).Where("id = ?", 1).Find(&categories)
	assertQueryResult(t, `SELECT * FROM "categories" WHERE id = $1`, tx)
}

// TestPKSnowflake registers a middleware using the PKSnowflake generator on a
// fresh connection and checks the id it inlines into an INSERT.
func TestPKSnowflake(t *testing.T) {
	var db *gorm.DB
	var err error
	if mysqlDialector() {
		db, err = gorm.Open(mysql.Open(dbURL()), &gorm.Config{
			DisableForeignKeyConstraintWhenMigrating: true,
		})
	} else {
		db, err = gorm.Open(postgres.New(dbConfig), &gorm.Config{
			DisableForeignKeyConstraintWhenMigrating: true,
		})
	}
	if err != nil {
		t.Fatal(err)
	}

	// shardingConfig is shared package state: restore it after this test.
	prevGenerator := shardingConfig.PrimaryKeyGenerator
	t.Cleanup(func() { shardingConfig.PrimaryKeyGenerator = prevGenerator })
	shardingConfig.PrimaryKeyGenerator = PKSnowflake
	middleware := Register(shardingConfig, &Order{})
	db.Use(middleware)

	genNode, _ := snowflake.NewNode(0)
	sfid := genNode.Generate().Int64()
	// Only the first 4 digits of the id are compared; they are expected to
	// agree with the id the middleware generates moments later, while the
	// remaining digits differ.
	prefix := `INSERT INTO orders_0 ("user_id", "product", "deleted", id) VALUES ($1, $2, $3, `
	expected := fmt.Sprintf(prefix+"%d", sfid)[0 : len(prefix)+4]
	expected = toDialect(expected)

	db.Create(&Order{UserID: 100, Product: "iPhone"})
	assert.Equal(t, expected, prefixOfQuery(middleware.LastQuery(), len(expected)))
}

// TestPKPGSequence checks that the PKPGSequence generator takes the next value
// of the PostgreSQL sequence (set to 42 here, so 43 is inlined).
func TestPKPGSequence(t *testing.T) {
	if mysqlDialector() {
		t.Skip("PostgreSQL-only test")
	}

	db, err := gorm.Open(postgres.New(dbConfig), &gorm.Config{
		DisableForeignKeyConstraintWhenMigrating: true,
	})
	if err != nil {
		t.Fatal(err)
	}

	// shardingConfig is shared package state: restore it after this test.
	prevGenerator := shardingConfig.PrimaryKeyGenerator
	t.Cleanup(func() { shardingConfig.PrimaryKeyGenerator = prevGenerator })
	shardingConfig.PrimaryKeyGenerator = PKPGSequence
	middleware := Register(shardingConfig, &Order{})
	db.Use(middleware)

	db.Exec("SELECT setval('" + pgSeqName("orders") + "', 42)")
	db.Create(&Order{UserID: 100, Product: "iPhone"})
	expected := `INSERT INTO orders_0 ("user_id", "product", "deleted", id) VALUES ($1, $2, $3, 43) RETURNING "id"`
	assert.Equal(t, expected, middleware.LastQuery())
}

// TestPKMySQLSequence checks that the PKMySQLSequence generator takes the next
// value from the MySQL sequence table (set to 42 here, so 43 is inlined).
func TestPKMySQLSequence(t *testing.T) {
	if !mysqlDialector() {
		t.Skip("MySQL/MariaDB-only test")
	}

	db, err := gorm.Open(mysql.Open(dbURL()), &gorm.Config{
		DisableForeignKeyConstraintWhenMigrating: true,
	})
	if err != nil {
		t.Fatal(err)
	}

	// shardingConfig is shared package state: restore it after this test.
	prevGenerator := shardingConfig.PrimaryKeyGenerator
	t.Cleanup(func() { shardingConfig.PrimaryKeyGenerator = prevGenerator })
	shardingConfig.PrimaryKeyGenerator = PKMySQLSequence
	middleware := Register(shardingConfig, &Order{})
	db.Use(middleware)

	db.Exec("UPDATE `" + mySQLSeqName("orders") + "` SET id = 42")
	db.Create(&Order{UserID: 100, Product: "iPhone"})
	expected := "INSERT INTO orders_0 (`user_id`, `product`, `deleted`, id) VALUES (?, ?, ?, 43)"
	if mariadbDialector() {
		expected = expected + " RETURNING `id`"
	}
	assert.Equal(t, expected, middleware.LastQuery())
}

// TestReadWriteSplitting checks that the sharding middleware coexists with
// dbresolver: the update goes to the write database, while the explicit read
// still sees the replica's original row.
func TestReadWriteSplitting(t *testing.T) {
	dbRead.Exec("INSERT INTO orders_0 (id, product, user_id) VALUES(1, 'iPad', 100)")
	dbWrite.Exec("INSERT INTO orders_0 (id, product, user_id) VALUES(1, 'iPad', 100)")

	var db *gorm.DB
	var err error
	if mysqlDialector() {
		db, err = gorm.Open(mysql.Open(dbWriteURL()), &gorm.Config{
			DisableForeignKeyConstraintWhenMigrating: true,
		})
	} else {
		db, err = gorm.Open(postgres.New(dbWriteConfig), &gorm.Config{
			DisableForeignKeyConstraintWhenMigrating: true,
		})
	}
	if err != nil {
		t.Fatal(err)
	}

	db.Use(dbresolver.Register(dbresolver.Config{
		Sources:  []gorm.Dialector{dbWrite.Dialector},
		Replicas: []gorm.Dialector{dbRead.Dialector},
	}))
	db.Use(middleware)

	var order Order
	db.Model(&Order{}).Where("user_id", 100).Find(&order)
	assert.Equal(t, "iPad", order.Product)

	db.Model(&Order{}).Where("user_id", 100).Update("product", "iPhone")
	db.Clauses(dbresolver.Read).Table("orders_0").Where("user_id", 100).Find(&order)
	assert.Equal(t, "iPad", order.Product)

	dbWrite.Table("orders_0").Where("user_id", 100).Find(&order)
	assert.Equal(t, "iPhone", order.Product)
}

// TestDataRace runs the same sharded query from two goroutines for about 50ms.
// It presumably targets the race detector (go test -race); any query error
// fails the test. All workers are stopped and waited for before returning.
func TestDataRace(t *testing.T) {
	const workers = 2

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Buffered so a failing worker can always hand over its error and exit,
	// even after the test has stopped listening.
	errs := make(chan error, workers)
	// Each worker signals here on exit so the test can wait for all of them.
	done := make(chan struct{}, workers)

	for i := 0; i < workers; i++ {
		go func() {
			defer func() { done <- struct{}{} }()
			for ctx.Err() == nil {
				err := db.Model(&Order{}).Where("user_id", 100).Find(&[]Order{}).Error
				if err != nil {
					errs <- err
					return
				}
			}
		}()
	}

	var err error
	select {
	case <-time.After(time.Millisecond * 50):
	case err = <-errs:
	}
	cancel()
	for i := 0; i < workers; i++ {
		<-done
	}
	if err != nil {
		t.Fatal(err)
	}
}

// assertQueryResult compares the dialect-adjusted expected SQL with the last
// query recorded by the shared middleware. tx is not inspected: only the
// generated SQL is checked, not whether executing it succeeded.
func assertQueryResult(t *testing.T, expected string, tx *gorm.DB) {
	t.Helper()
	assert.Equal(t, toDialect(expected), middleware.LastQuery())
}

// prefixOfQuery returns at most the first n bytes of query, so that prefix
// comparisons fail cleanly instead of panicking when query is shorter.
func prefixOfQuery(query string, n int) string {
	if len(query) < n {
		return query
	}
	return query[:n]
}

// pgPlaceholderRe matches PostgreSQL positional placeholders such as $1.
var pgPlaceholderRe = regexp.MustCompile(`\$([0-9]+)`)

// toDialect rewrites SQL written in PostgreSQL form for the dialect selected
// by the DIALECTOR environment variable. For "mysql" and "mariadb" it:
//   - swaps double quotes for backticks;
//   - replaces $N placeholders with "?".
//
// For "mysql" only, it also strips the RETURNING `id` suffix. Any other
// dialect gets sql back unchanged.
func toDialect(sql string) string {
	dialect := os.Getenv("DIALECTOR")
	if dialect != "mysql" && dialect != "mariadb" {
		return sql
	}
	sql = strings.ReplaceAll(sql, `"`, "`")
	sql = pgPlaceholderRe.ReplaceAllString(sql, "?")
	if dialect == "mysql" {
		sql = strings.ReplaceAll(sql, " RETURNING `id`", "")
	}
	return sql
}

// assertSfidQueryResult compares lastQuery with expected, where each "$sfid"
// token in expected stands for a snowflake id whose value is not known in
// advance. Each token is replaced by the substring of lastQuery at the same
// position, with the width of a snowflake id, before comparing. If lastQuery is
// too short, the token is left in place so the comparison fails visibly.
func assertSfidQueryResult(t *testing.T, expected, lastQuery string) {
	t.Helper()

	const token = "$sfid"
	genNode, _ := snowflake.NewNode(0)
	sfidLen := len(strconv.FormatInt(genNode.Generate().Int64(), 10))

	// Search resumes after each substitution so substituted text is never rescanned.
	for from := 0; ; {
		i := strings.Index(expected[from:], token)
		if i < 0 {
			break
		}
		start := from + i
		end := start + sfidLen
		if len(lastQuery) < end {
			break
		}
		expected = expected[:start] + lastQuery[start:end] + expected[start+len(token):]
		from = end
	}

	assert.Equal(t, toDialect(expected), lastQuery)
}

func mysqlDialector() bool {
	return os.Getenv("DIALECTOR") == "mysql" || os.Getenv("DIALECTOR") == "mariadb"
}

func mariadbDialector() bool {
	return os.Getenv("DIALECTOR") == "mariadb"
}

// parseExpr parses a SQL boolean expression into an AST Expr node.
// It fails the test immediately if src does not parse.
func parseExpr(t *testing.T, src string) sqlparser.Expr {
	t.Helper()
	p := sqlparser.NewParser(strings.NewReader(src))
	expr, err := p.ParseExpr()
	if err != nil {
		t.Fatalf("failed to parse expr %q: %v", src, err)
	}
	return expr
}

// TestReplaceTableNameInCondition checks that table qualifiers inside WHERE
// expressions are rewritten from the logical table to the shard table.
func TestReplaceTableNameInCondition(t *testing.T) {
	oldName := "orders"
	newName := "orders_1"

	tests := []struct {
		name     string
		input    string
		expected string
	}{
		{
			name:     "simple equality",
			input:    "orders.user_id = 5",
			expected: "orders_1.user_id = 5",
		},
		{
			name:     "LIKE operator",
			input:    "orders.product LIKE '%foo%'",
			expected: "orders_1.product LIKE '%foo%'",
		},
		{
			name:     "IN list",
			input:    "orders.id IN (1, 2, 3)",
			expected: "orders_1.id IN (1, 2, 3)",
		},
		{
			name:     "NOT IN list",
			input:    "orders.id NOT IN (4,5)",
			expected: "orders_1.id NOT IN (4, 5)",
		},
		{
			name:     "BETWEEN",
			input:    "orders.age BETWEEN 18 AND 30",
			expected: "orders_1.age BETWEEN 18 AND 30",
		},
		{
			name:     "nested AND/OR",
			input:    "(orders.age > 18 AND orders.age < 30) OR orders.id = 7",
			expected: "(orders_1.age > 18 AND orders_1.age < 30) OR orders_1.id = 7",
		},
		{
			name:     "function call",
			input:    "UPPER(orders.product) = 'FOO'",
			expected: "UPPER(orders_1.product) = 'FOO'",
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			expr := parseExpr(t, tc.input)
			replaceTableNameInCondition(expr, oldName, newName)
			assert.Equal(t, tc.expected, expr.String())
		})
	}
}

--------------------------------------------------------------------------------