├── .github └── dependabot.yml ├── .gitignore ├── License ├── README.md ├── callbacks.go ├── clauses.go ├── database.go ├── dbresolver.go ├── dbresolver_test.go ├── docker-compose.yml ├── go.mod ├── go.sum ├── logger.go ├── policy.go ├── policy_test.go ├── resolver.go ├── utils.go └── utils_test.go /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # To get started with Dependabot version updates, you'll need to specify which 2 | # package ecosystems to update and where the package manifests are located. 3 | # Please see the documentation for all configuration options: 4 | # https://help.github.com/github/administering-a-repository/configuration-options-for-dependency-updates 5 | 6 | version: 2 7 | updates: 8 | - package-ecosystem: "github-actions" 9 | directory: "/" 10 | schedule: 11 | interval: "daily" 12 | - package-ecosystem: "gomod" 13 | directory: "/" 14 | schedule: 15 | interval: "daily" 16 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea -------------------------------------------------------------------------------- /License: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2013-NOW Jinzhu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DBResolver 2 | 3 | DBResolver adds multiple-database support to GORM. The following features are supported: 4 | 5 | * Multiple sources, replicas 6 | * Read/Write Splitting 7 | * Automatic connection switching based on the working table/struct 8 | * Manual connection switching 9 | * Sources/Replicas load balancing 10 | * Works for raw SQL 11 | * Transaction 12 | 13 | ## Quick Start 14 | 15 | ```go 16 | import ( 17 | "gorm.io/gorm" 18 | "gorm.io/plugin/dbresolver" 19 | "gorm.io/driver/mysql" 20 | ) 21 | 22 | DB, err := gorm.Open(mysql.Open("db1_dsn"), &gorm.Config{}) 23 | 24 | DB.Use(dbresolver.Register(dbresolver.Config{ 25 | // use `db2` as sources, `db3`, `db4` as replicas 26 | Sources: []gorm.Dialector{mysql.Open("db2_dsn")}, 27 | Replicas: []gorm.Dialector{mysql.Open("db3_dsn"), mysql.Open("db4_dsn")}, 28 | // sources/replicas load balancing policy 29 | Policy: dbresolver.RandomPolicy{}, 30 | // prefix logged SQL with the resolver mode ([source]/[replica]) 31 | TraceResolverMode: true, 32 | }).Register(dbresolver.Config{ 33 | // use `db1` as sources (DB's default connection), `db5` as replicas for `User`, `Address` 34 | Replicas: []gorm.Dialector{mysql.Open("db5_dsn")}, 35 | }, &User{}, &Address{}).Register(dbresolver.Config{ 36 | // use `db6`, `db7` as sources, `db8` as replicas for `orders`, `Product`
Sources: []gorm.Dialector{mysql.Open("db6_dsn"), mysql.Open("db7_dsn")}, 38 | Replicas: []gorm.Dialector{mysql.Open("db8_dsn")}, 39 | }, "orders", &Product{}, "secondary")) 40 | ``` 41 | 42 | ### Automatic connection switching 43 | 44 | DBResolver automatically switches connections based on the table/struct being worked on. 45 | 46 | For raw SQL, DBResolver extracts the table name from the SQL to match a resolver, and uses `sources` unless the SQL begins with `SELECT` (locking reads such as `SELECT ... FOR UPDATE` also use `sources`), for example: 47 | 48 | ```go 49 | // `User` Resolver Examples 50 | DB.Table("users").Rows() // replicas `db5` 51 | DB.Model(&User{}).Find(&AdvancedUser{}) // replicas `db5` 52 | DB.Exec("update users set name = ?", "jinzhu") // sources `db1` 53 | DB.Raw("select name from users").Row().Scan(&name) // replicas `db5` 54 | DB.Create(&user) // sources `db1` 55 | DB.Delete(&User{}, "name = ?", "jinzhu") // sources `db1` 56 | DB.Table("users").Update("name", "jinzhu") // sources `db1` 57 | 58 | // Global Resolver Examples 59 | DB.Find(&Pet{}) // replicas `db3`/`db4` 60 | DB.Save(&Pet{}) // sources `db2` 61 | 62 | // Orders Resolver Examples 63 | DB.Find(&Order{}) // replicas `db8` 64 | DB.Table("orders").Find(&Report{}) // replicas `db8` 65 | ``` 66 | 67 | ### Read/Write Splitting 68 | 69 | DBResolver chooses between read and write connections based on the [GORM callback](https://gorm.io/docs/write_plugins.html) currently being executed.
70 | 71 | For the `Query` and `Row` callbacks, `replicas` are used unless `Write` mode is specified or the query carries a locking (`FOR ...`) clause. 72 | For the `Raw` callback, a statement is treated as read-only, and therefore uses `replicas`, only if its SQL starts with `SELECT` and is not a locking read such as `SELECT ... FOR UPDATE`. 73 | 74 | ### Manual connection switching 75 | 76 | ```go 77 | // Use Write Mode: read user from sources `db1` 78 | DB.Clauses(dbresolver.Write).First(&user) 79 | 80 | // Specify Resolver: read user from `secondary`'s replicas: db8 81 | DB.Clauses(dbresolver.Use("secondary")).First(&user) 82 | 83 | // Specify Resolver and Write Mode: read user from `secondary`'s sources: db6 or db7 84 | DB.Clauses(dbresolver.Use("secondary"), dbresolver.Write).First(&user) 85 | ``` 86 | 87 | ### Transaction 88 | 89 | When using a transaction, DBResolver keeps using that transaction and won't switch to sources/replicas based on the configuration. 90 | 91 | But you can specify which DB to use before starting a transaction, for example: 92 | 93 | ```go 94 | // Start transaction based on default replicas db 95 | tx := DB.Clauses(dbresolver.Read).Begin() 96 | 97 | // Start transaction based on default sources db 98 | tx := DB.Clauses(dbresolver.Write).Begin() 99 | 100 | // Start transaction based on `secondary`'s sources 101 | tx := DB.Clauses(dbresolver.Use("secondary"), dbresolver.Write).Begin() 102 | ``` 103 | 104 | ### Load Balancing 105 | 106 | DBResolver supports load balancing across sources/replicas based on a policy; a policy is any type that implements the following interface: 107 | 108 | ```go 109 | type Policy interface { 110 | Resolve([]gorm.ConnPool) gorm.ConnPool 111 | } 112 | ``` 113 | 114 | Currently only `RandomPolicy` is implemented, and it is the default if no policy is specified. 115 | 116 | ### Connection Pool 117 | 118 | ```go 119 | DB.Use( 120 | dbresolver.Register(dbresolver.Config{ /* xxx */ }). 121 | SetConnMaxIdleTime(time.Hour). 122 | SetConnMaxLifetime(24 * time.Hour). 123 | SetMaxIdleConns(100).
124 | SetMaxOpenConns(200), 125 | ) 126 | ``` 127 | -------------------------------------------------------------------------------- /callbacks.go: -------------------------------------------------------------------------------- 1 | package dbresolver 2 | 3 | import ( 4 | "strings" 5 | 6 | "gorm.io/gorm" 7 | ) 8 | // registerCallbacks installs the resolver as the first callback ("gorm:db_resolver", Before("*")) of the create, query, update, delete, row and raw processors. 9 | func (dr *DBResolver) registerCallbacks(db *gorm.DB) { 10 | dr.Callback().Create().Before("*").Register("gorm:db_resolver", dr.switchSource) 11 | dr.Callback().Query().Before("*").Register("gorm:db_resolver", dr.switchReplica) 12 | dr.Callback().Update().Before("*").Register("gorm:db_resolver", dr.switchSource) 13 | dr.Callback().Delete().Before("*").Register("gorm:db_resolver", dr.switchSource) 14 | dr.Callback().Row().Before("*").Register("gorm:db_resolver", dr.switchReplica) 15 | dr.Callback().Raw().Before("*").Register("gorm:db_resolver", dr.switchGuess) 16 | } 17 | // switchSource routes the statement to a source (write) connection unless it already runs inside a transaction. 18 | func (dr *DBResolver) switchSource(db *gorm.DB) { 19 | if !isTransaction(db.Statement.ConnPool) { 20 | db.Statement.ConnPool = dr.resolve(db.Statement, Write) 21 | } 22 | } 23 | // switchReplica routes reads outside transactions: raw SQL is delegated to switchGuess; otherwise a replica is used unless Write mode is set or the query has a locking FOR clause. 24 | func (dr *DBResolver) switchReplica(db *gorm.DB) { 25 | if !isTransaction(db.Statement.ConnPool) { 26 | if rawSQL := db.Statement.SQL.String(); len(rawSQL) > 0 { 27 | dr.switchGuess(db) 28 | } else { 29 | _, locking := db.Statement.Clauses["FOR"] 30 | if _, ok := db.Statement.Settings.Load(writeName); ok || locking { 31 | db.Statement.ConnPool = dr.resolve(db.Statement, Write) 32 | } else { 33 | db.Statement.ConnPool = dr.resolve(db.Statement, Read) 34 | } 35 | } 36 | } 37 | } 38 | // switchGuess routes raw SQL outside transactions: an explicit Write or Read setting wins; otherwise SELECT statements that do not look like locking reads go to a replica and everything else to a source. 39 | func (dr *DBResolver) switchGuess(db *gorm.DB) { 40 | if !isTransaction(db.Statement.ConnPool) { 41 | if _, ok := db.Statement.Settings.Load(writeName); ok { 42 | db.Statement.ConnPool = dr.resolve(db.Statement, Write) 43 | } else if _, ok := db.Statement.Settings.Load(readName); ok { 44 | db.Statement.ConnPool = dr.resolve(db.Statement, Read) 45 | } else if rawSQL := strings.TrimSpace(db.Statement.SQL.String());
len(rawSQL) > 10 && strings.EqualFold(rawSQL[:6], "select") && !strings.EqualFold(rawSQL[len(rawSQL)-10:], "for update") { 46 | db.Statement.ConnPool = dr.resolve(db.Statement, Read) 47 | } else { 48 | db.Statement.ConnPool = dr.resolve(db.Statement, Write) 49 | } 50 | } 51 | } 52 | 53 | func isTransaction(connPool gorm.ConnPool) bool { 54 | _, ok := connPool.(gorm.TxCommitter) 55 | return ok 56 | } 57 | -------------------------------------------------------------------------------- /clauses.go: -------------------------------------------------------------------------------- 1 | package dbresolver 2 | 3 | import ( 4 | "gorm.io/gorm" 5 | "gorm.io/gorm/clause" 6 | ) 7 | 8 | // Operation specifies dbresolver mode 9 | type Operation string 10 | 11 | const ( 12 | writeName = "gorm:db_resolver:write" 13 | readName = "gorm:db_resolver:read" 14 | ) 15 | 16 | // ModifyStatement modify operation mode 17 | func (op Operation) ModifyStatement(stmt *gorm.Statement) { 18 | var optName string 19 | if op == Write { 20 | optName = writeName 21 | stmt.Settings.Delete(readName) 22 | } else if op == Read { 23 | optName = readName 24 | stmt.Settings.Delete(writeName) 25 | } 26 | 27 | if optName != "" { 28 | stmt.Settings.Store(optName, struct{}{}) 29 | if fc := stmt.DB.Callback().Query().Get("gorm:db_resolver"); fc != nil { 30 | fc(stmt.DB) 31 | } 32 | } 33 | } 34 | 35 | // Build implements clause.Expression interface 36 | func (op Operation) Build(clause.Builder) { 37 | } 38 | 39 | // Use specifies configuration 40 | func Use(str string) clause.Expression { 41 | return using{Use: str} 42 | } 43 | 44 | type using struct { 45 | Use string 46 | } 47 | 48 | const usingName = "gorm:db_resolver:using" 49 | 50 | // ModifyStatement modify operation mode 51 | func (u using) ModifyStatement(stmt *gorm.Statement) { 52 | stmt.Clauses[usingName] = clause.Clause{Expression: u} 53 | if fc := stmt.DB.Callback().Query().Get("gorm:db_resolver"); fc != nil { 54 | fc(stmt.DB) 55 | } 56 | } 57 | 58 | // 
Build implements clause.Expression; using writes no SQL. 59 | func (u using) Build(clause.Builder) { 60 | } 61 | -------------------------------------------------------------------------------- /database.go: -------------------------------------------------------------------------------- 1 | package dbresolver 2 | 3 | import ( 4 | "context" 5 | "time" 6 | 7 | "gorm.io/gorm" 8 | ) 9 | // SetConnMaxIdleTime calls SetConnMaxIdleTime(d) on every connection pool reached through Call that supports it, logs the pools that do not, and returns dr for chaining. 10 | func (dr *DBResolver) SetConnMaxIdleTime(d time.Duration) *DBResolver { 11 | dr.Call(func(connPool gorm.ConnPool) error { 12 | if conn, ok := connPool.(interface{ SetConnMaxIdleTime(time.Duration) }); ok { 13 | conn.SetConnMaxIdleTime(d) 14 | } else { // conn is nil here (failed assertion), so log the pool itself 15 | dr.DB.Logger.Error(context.Background(), "SetConnMaxIdleTime not implemented for %#v, please use golang v1.15+", connPool) 16 | } 17 | return nil 18 | }) 19 | 20 | return dr 21 | } 22 | // SetConnMaxLifetime calls SetConnMaxLifetime(d) on every connection pool reached through Call that supports it, logs the pools that do not, and returns dr for chaining. 23 | func (dr *DBResolver) SetConnMaxLifetime(d time.Duration) *DBResolver { 24 | dr.Call(func(connPool gorm.ConnPool) error { 25 | if conn, ok := connPool.(interface{ SetConnMaxLifetime(time.Duration) }); ok { 26 | conn.SetConnMaxLifetime(d) 27 | } else { // conn is nil here (failed assertion), so log the pool itself 28 | dr.DB.Logger.Error(context.Background(), "SetConnMaxLifetime not implemented for %#v", connPool) 29 | } 30 | return nil 31 | }) 32 | 33 | return dr 34 | } 35 | // SetMaxIdleConns calls SetMaxIdleConns(n) on every connection pool reached through Call that supports it, logs the pools that do not, and returns dr for chaining. 36 | func (dr *DBResolver) SetMaxIdleConns(n int) *DBResolver { 37 | dr.Call(func(connPool gorm.ConnPool) error { 38 | if conn, ok := connPool.(interface{ SetMaxIdleConns(int) }); ok { 39 | conn.SetMaxIdleConns(n) 40 | } else { // conn is nil here (failed assertion), so log the pool itself 41 | dr.DB.Logger.Error(context.Background(), "SetMaxIdleConns not implemented for %#v", connPool) 42 | } 43 | return nil 44 | }) 45 | 46 | return dr 47 | } 48 | // SetMaxOpenConns calls SetMaxOpenConns(n) on every connection pool reached through Call that supports it, logs the pools that do not, and returns dr for chaining. 49 | func (dr *DBResolver) SetMaxOpenConns(n int) *DBResolver { 50 | dr.Call(func(connPool gorm.ConnPool) error { 51 | 52 | if conn, ok := connPool.(interface{ SetMaxOpenConns(int) }); ok { 53 | conn.SetMaxOpenConns(n) 54 | } else { // conn is nil here (failed assertion), so log the pool itself 55 | dr.DB.Logger.Error(context.Background(), "SetMaxOpenConns not implemented for %#v", connPool) 56 | } 57 | return nil 58 | }) 59 | 60 | return
dr 61 | } 62 | 63 | func (dr *DBResolver) Call(fc func(connPool gorm.ConnPool) error) error { 64 | if dr.DB != nil { 65 | for _, r := range dr.resolvers { 66 | if err := r.call(fc); err != nil { 67 | return err 68 | } 69 | } 70 | 71 | if dr.global != nil { 72 | if err := dr.global.call(fc); err != nil { 73 | return err 74 | } 75 | } 76 | } else { 77 | dr.compileCallbacks = append(dr.compileCallbacks, fc) 78 | } 79 | 80 | return nil 81 | } 82 | -------------------------------------------------------------------------------- /dbresolver.go: -------------------------------------------------------------------------------- 1 | package dbresolver 2 | 3 | import ( 4 | "errors" 5 | 6 | "gorm.io/gorm" 7 | ) 8 | 9 | const ( 10 | Write Operation = "write" 11 | Read Operation = "read" 12 | ) 13 | 14 | type DBResolver struct { 15 | *gorm.DB 16 | configs []Config 17 | resolvers map[string]*resolver 18 | global *resolver 19 | prepareStmtStore map[gorm.ConnPool]*gorm.PreparedStmtDB 20 | compileCallbacks []func(gorm.ConnPool) error 21 | } 22 | 23 | type Config struct { 24 | Sources []gorm.Dialector 25 | Replicas []gorm.Dialector 26 | Policy Policy 27 | datas []interface{} 28 | TraceResolverMode bool 29 | } 30 | 31 | func Register(config Config, datas ...interface{}) *DBResolver { 32 | return (&DBResolver{}).Register(config, datas...) 
33 | } 34 | 35 | func (dr *DBResolver) Register(config Config, datas ...interface{}) *DBResolver { 36 | if dr.prepareStmtStore == nil { 37 | dr.prepareStmtStore = map[gorm.ConnPool]*gorm.PreparedStmtDB{} 38 | } 39 | 40 | if dr.resolvers == nil { 41 | dr.resolvers = map[string]*resolver{} 42 | } 43 | 44 | if config.Policy == nil { 45 | config.Policy = RandomPolicy{} 46 | } 47 | 48 | config.datas = datas 49 | 50 | dr.configs = append(dr.configs, config) 51 | if dr.DB != nil { 52 | dr.compileConfig(config) 53 | } 54 | return dr 55 | } 56 | 57 | func (dr *DBResolver) Name() string { 58 | return "gorm:db_resolver" 59 | } 60 | 61 | func (dr *DBResolver) Initialize(db *gorm.DB) error { 62 | dr.DB = db 63 | dr.registerCallbacks(db) 64 | return dr.compile() 65 | } 66 | 67 | func (dr *DBResolver) compile() error { 68 | for _, config := range dr.configs { 69 | if err := dr.compileConfig(config); err != nil { 70 | return err 71 | } 72 | } 73 | return nil 74 | } 75 | 76 | func (dr *DBResolver) compileConfig(config Config) (err error) { 77 | var ( 78 | connPool = dr.DB.Config.ConnPool 79 | r = resolver{ 80 | dbResolver: dr, 81 | policy: config.Policy, 82 | traceResolverMode: config.TraceResolverMode, 83 | } 84 | ) 85 | 86 | if preparedStmtDB, ok := connPool.(*gorm.PreparedStmtDB); ok { 87 | connPool = preparedStmtDB.ConnPool 88 | } 89 | 90 | if len(config.Sources) == 0 { 91 | r.sources = []gorm.ConnPool{connPool} 92 | } else if r.sources, err = dr.convertToConnPool(config.Sources); err != nil { 93 | return err 94 | } 95 | 96 | if len(config.Replicas) == 0 { 97 | r.replicas = r.sources 98 | } else if r.replicas, err = dr.convertToConnPool(config.Replicas); err != nil { 99 | return err 100 | } 101 | 102 | if len(config.datas) > 0 { 103 | for _, data := range config.datas { 104 | if t, ok := data.(string); ok { 105 | dr.resolvers[t] = &r 106 | } else { 107 | stmt := &gorm.Statement{DB: dr.DB} 108 | if err := stmt.Parse(data); err == nil { 109 | dr.resolvers[stmt.Table] = &r 110 | } 
else { 111 | return err 112 | } 113 | } 114 | } 115 | } else if dr.global == nil { 116 | dr.global = &r 117 | } else { 118 | return errors.New("conflicted global resolver") 119 | } 120 | 121 | for _, fc := range dr.compileCallbacks { 122 | if err = r.call(fc); err != nil { 123 | return err 124 | } 125 | } 126 | 127 | if config.TraceResolverMode { 128 | dr.Logger = NewResolverModeLogger(dr.Logger) 129 | } 130 | 131 | return nil 132 | } 133 | 134 | func (dr *DBResolver) convertToConnPool(dialectors []gorm.Dialector) (connPools []gorm.ConnPool, err error) { 135 | config := *dr.DB.Config 136 | for _, dialector := range dialectors { 137 | if db, err := gorm.Open(dialector, &config); err == nil { 138 | connPool := db.ConnPool 139 | if preparedStmtDB, ok := connPool.(*gorm.PreparedStmtDB); ok { 140 | connPool = preparedStmtDB.ConnPool 141 | } 142 | 143 | dr.prepareStmtStore[connPool] = gorm.NewPreparedStmtDB(db.ConnPool, dr.PrepareStmtMaxSize, dr.PrepareStmtTTL) 144 | 145 | connPools = append(connPools, connPool) 146 | } else { 147 | return nil, err 148 | } 149 | } 150 | 151 | return connPools, err 152 | } 153 | 154 | func (dr *DBResolver) resolve(stmt *gorm.Statement, op Operation) gorm.ConnPool { 155 | if r := dr.getResolver(stmt); r != nil { 156 | return r.resolve(stmt, op) 157 | } 158 | return stmt.ConnPool 159 | } 160 | 161 | func (dr *DBResolver) getResolver(stmt *gorm.Statement) *resolver { 162 | if len(dr.resolvers) > 0 { 163 | if u, ok := stmt.Clauses[usingName].Expression.(using); ok && u.Use != "" { 164 | if r, ok := dr.resolvers[u.Use]; ok { 165 | return r 166 | } 167 | } 168 | 169 | if stmt.Table != "" { 170 | if r, ok := dr.resolvers[stmt.Table]; ok { 171 | return r 172 | } 173 | } 174 | 175 | if stmt.Model != nil { 176 | if err := stmt.Parse(stmt.Model); err == nil { 177 | if r, ok := dr.resolvers[stmt.Table]; ok { 178 | return r 179 | } 180 | } 181 | } 182 | 183 | if stmt.Schema != nil { 184 | if r, ok := dr.resolvers[stmt.Schema.Table]; ok { 185 | return 
r 186 | } 187 | } 188 | 189 | if rawSQL := stmt.SQL.String(); rawSQL != "" { 190 | if r, ok := dr.resolvers[getTableFromRawSQL(rawSQL)]; ok { 191 | return r 192 | } 193 | } 194 | } 195 | 196 | return dr.global 197 | } 198 | -------------------------------------------------------------------------------- /dbresolver_test.go: -------------------------------------------------------------------------------- 1 | package dbresolver_test 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "testing" 7 | 8 | "gorm.io/driver/mysql" 9 | "gorm.io/gorm" 10 | "gorm.io/gorm/logger" 11 | "gorm.io/plugin/dbresolver" 12 | ) 13 | 14 | type User struct { 15 | ID uint 16 | Name string 17 | Orders []Order 18 | } 19 | 20 | type Product struct { 21 | ID uint 22 | Name string 23 | } 24 | 25 | type Order struct { 26 | ID uint 27 | OrderNo string 28 | UserID uint 29 | } 30 | 31 | func GetDB(port int) *gorm.DB { 32 | DB, err := gorm.Open(mysql.Open(fmt.Sprintf("gorm:gorm@tcp(localhost:%v)/gorm?charset=utf8&parseTime=True&loc=Local", port)), &gorm.Config{}) 33 | if err != nil { 34 | panic(fmt.Sprintf("failed to connect db, got error: %v, port: %v", err, port)) 35 | } 36 | return DB 37 | } 38 | 39 | func init() { 40 | for _, port := range []int{9911, 9912, 9913, 9914} { 41 | DB := GetDB(port) 42 | DB.Migrator().DropTable(&User{}, &Product{}, &Order{}) 43 | DB.AutoMigrate(&User{}, &Product{}, &Order{}) 44 | 45 | user := User{Name: fmt.Sprintf("%v", port)} 46 | DB.Create(&user) 47 | DB.Create(&Product{Name: fmt.Sprintf("%v", port)}) 48 | DB.Create(&Order{OrderNo: fmt.Sprintf("%v", port), UserID: user.ID}) 49 | } 50 | } 51 | 52 | func TestDBResolver(t *testing.T) { 53 | for i := 0; i < 2; i++ { 54 | DB, err := gorm.Open(mysql.Open("gorm:gorm@tcp(localhost:9911)/gorm?charset=utf8&parseTime=True&loc=Local"), &gorm.Config{PrepareStmt: i%2 == 0}) 55 | if err != nil { 56 | t.Fatalf("failed to connect db, got error: %v", err) 57 | } 58 | if debug := os.Getenv("DEBUG"); debug == "true" { 59 | DB.Logger = 
DB.Logger.LogMode(logger.Info) 60 | } else if debug == "false" { 61 | DB.Logger = DB.Logger.LogMode(logger.Silent) 62 | } 63 | 64 | if err := DB.Use(dbresolver.Register(dbresolver.Config{ 65 | Sources: []gorm.Dialector{mysql.Open("gorm:gorm@tcp(localhost:9911)/gorm?charset=utf8&parseTime=True&loc=Local")}, 66 | Replicas: []gorm.Dialector{ 67 | mysql.Open("gorm:gorm@tcp(localhost:9912)/gorm?charset=utf8&parseTime=True&loc=Local"), 68 | mysql.Open("gorm:gorm@tcp(localhost:9913)/gorm?charset=utf8&parseTime=True&loc=Local"), 69 | }, 70 | TraceResolverMode: true, 71 | }).Register(dbresolver.Config{ 72 | Sources: []gorm.Dialector{mysql.Open("gorm:gorm@tcp(localhost:9914)/gorm?charset=utf8&parseTime=True&loc=Local")}, 73 | Replicas: []gorm.Dialector{mysql.Open("gorm:gorm@tcp(localhost:9913)/gorm?charset=utf8&parseTime=True&loc=Local")}, 74 | TraceResolverMode: true, 75 | }, "users", &Product{}).SetMaxOpenConns(5)); err != nil { 76 | t.Fatalf("failed to use plugin, got error: %v", err) 77 | } 78 | 79 | for j := 0; j < 20; j++ { 80 | var order Order 81 | // test transaction 82 | tx := DB.Begin() 83 | tx.Find(&order) 84 | if order.OrderNo != "9911" { 85 | t.Fatalf("idx: %v: order should comes from default db, but got order %v", j, order.OrderNo) 86 | } 87 | tx.Rollback() 88 | 89 | tx = DB.Clauses(dbresolver.Read).Begin() 90 | tx.Find(&order) 91 | if order.OrderNo != "9912" && order.OrderNo != "9913" { 92 | t.Fatalf("idx: %v: order should comes from read db, but got order %v", j, order.OrderNo) 93 | } 94 | tx.Rollback() 95 | 96 | tx = DB.Clauses(dbresolver.Write).Begin() 97 | tx.Find(&order) 98 | if order.OrderNo != "9911" { 99 | t.Fatalf("idx: %v: order should comes from write db, but got order %v", j, order.OrderNo) 100 | } 101 | tx.Rollback() 102 | 103 | tx = DB.Clauses(dbresolver.Use("users"), dbresolver.Write).Begin() 104 | tx.Find(&order) 105 | if order.OrderNo != "9914" { 106 | t.Fatalf("idx: %v: order should comes from users, write db, but got order %v", j, 
order.OrderNo) 107 | } 108 | tx.Rollback() 109 | 110 | tx = DB.Clauses(dbresolver.Write, dbresolver.Use("users")).Begin() 111 | tx.Find(&order) 112 | if order.OrderNo != "9914" { 113 | t.Fatalf("idx: %v: order should comes from users, write db, but got order %v", j, order.OrderNo) 114 | } 115 | tx.Rollback() 116 | 117 | // test query 118 | DB.First(&order) 119 | if order.OrderNo != "9912" && order.OrderNo != "9913" { 120 | t.Fatalf("idx: %v: order should comes from read db, but got order %v", j, order.OrderNo) 121 | } 122 | 123 | DB.Clauses(dbresolver.Write).First(&order) 124 | if order.OrderNo != "9911" { 125 | t.Fatalf("idx: %v: order should comes from write db, but got order %v", j, order.OrderNo) 126 | } 127 | 128 | DB.Clauses(dbresolver.Use("users")).First(&order) 129 | if order.OrderNo != "9913" { 130 | t.Fatalf("idx: %v: order should comes from write db @ users, but got order %v", j, order.OrderNo) 131 | } 132 | 133 | DB.Clauses(dbresolver.Use("users"), dbresolver.Write).First(&order) 134 | if order.OrderNo != "9914" { 135 | t.Fatalf("idx: %v: order should comes from write db @ users, but got order %v", j, order.OrderNo) 136 | } 137 | 138 | var user User 139 | DB.First(&user) 140 | if user.Name != "9913" { 141 | t.Fatalf("idx: %v: user should comes from read db, but got %v", j, user.Name) 142 | } 143 | 144 | DB.Clauses(dbresolver.Write).First(&user) 145 | if user.Name != "9914" { 146 | t.Fatalf("idx: %v: user should comes from read db, but got %v", j, user.Name) 147 | } 148 | 149 | var product Product 150 | DB.First(&product) 151 | if product.Name != "9913" { 152 | t.Fatalf("idx: %v: product should comes from read db, but got %v", j, product.Name) 153 | } 154 | 155 | DB.Clauses(dbresolver.Write).First(&product) 156 | if product.Name != "9914" { 157 | t.Fatalf("idx: %v: product should comes from write db, but got %v", j, product.Name) 158 | } 159 | 160 | // test preload 161 | if err := DB.Clauses(dbresolver.Write).Preload("Orders").First(&user).Error; err != 
nil || len(user.Orders) != 1 { 162 | t.Fatalf("failed to preload orders, count: %v, got error: %v", len(user.Orders), err) 163 | } 164 | 165 | // order source 9911, user source: 9914 166 | if user.Orders[0].OrderNo != "9911" || user.Name != "9914" { 167 | t.Fatalf("incorrect order info: userName: %v, orderNo: %v", user.Name, user.Orders[0].OrderNo) 168 | } 169 | 170 | if err := DB.Preload("Orders", func(tx *gorm.DB) *gorm.DB { 171 | return tx.Clauses(dbresolver.Write) 172 | }).First(&user).Error; err != nil || len(user.Orders) != 1 { 173 | t.Fatalf("failed to preload orders, count: %v, got error: %v", len(user.Orders), err) 174 | } 175 | 176 | // order source 9911, user replica: 9913 177 | if user.Orders[0].OrderNo != "9911" || user.Name != "9913" { 178 | t.Fatalf("incorrect order info: userName: %v, orderNo: %v", user.Name, user.Orders[0].OrderNo) 179 | } 180 | 181 | // test create 182 | DB.Create(&User{Name: "create"}) 183 | if err := DB.First(&User{}, "name = ?", "create").Error; err == nil { 184 | t.Fatalf("can't read user from read db, got no error happened") 185 | } 186 | 187 | if err := DB.Clauses(dbresolver.Write).First(&User{}, "name = ?", "create").Error; err != nil { 188 | t.Fatalf("read user from write db, got error: %v", err) 189 | } 190 | 191 | DB9914 := GetDB(9914) 192 | 193 | if err := DB9914.First(&User{}, "name = ?", "create").Error; err != nil { 194 | t.Fatalf("read user from write db, got error: %v", err) 195 | } 196 | 197 | var name string 198 | if err := DB.Raw("select name from users").Row().Scan(&name); err != nil || name != "9913" { 199 | t.Fatalf("read users from read db, name %v", name) 200 | } 201 | 202 | if err := DB.Debug().Raw("select name from users where name = ? 
for update", "create").Row().Scan(&name); err != nil || name != "create" { 203 | t.Fatalf("read users from write db, name %v, err %v", name, err) 204 | } 205 | 206 | // test update 207 | if err := DB.Model(&User{}).Where("name = ?", "create").Update("name", "update").Error; err != nil { 208 | t.Fatalf("failed to update users, got error: %v", err) 209 | } 210 | 211 | if err := DB9914.First(&User{}, "name = ?", "update").Error; err != nil { 212 | t.Fatalf("read user from write db, got error: %v", err) 213 | } 214 | 215 | // test raw sql 216 | name = "" 217 | if err := DB.Raw("select name from users where name = ?", "update").Row().Scan(&name); err == nil || name != "" { 218 | t.Fatalf("can't read users from read db, name %v", name) 219 | } 220 | 221 | if err := DB.Raw(" select name from users where name = ?", "9913").Row().Scan(&name); err != nil { 222 | t.Fatalf("(raw sql has leading space) should go to read db, got error: %v", err) 223 | } 224 | 225 | if err := DB.Raw(` 226 | select name 227 | from users where name = ?`, "9913").Row().Scan(&name); err != nil { 228 | t.Fatalf("(raw sql has leading newline) should go to read db, got error: %v", err) 229 | } 230 | 231 | if err := DB.Clauses(dbresolver.Write).Raw("select name from users where name = ?", "update").Row().Scan(&name); err != nil || name != "update" { 232 | t.Fatalf("read users from write db, error %v, name %v", err, name) 233 | } 234 | 235 | // test delete 236 | if err := DB.Where("name = ?", "update").Delete(&User{}).Error; err != nil { 237 | t.Fatalf("failed to delete users, got error: %v", err) 238 | } 239 | 240 | if err := DB9914.First(&User{}, "name = ?", "update").Error; err != gorm.ErrRecordNotFound { 241 | t.Fatalf("read user from write db after delete, got error: %v", err) 242 | } 243 | } 244 | } 245 | } 246 | 247 | func TestConnPool(t *testing.T) { 248 | for i := 0; i < 2; i++ { 249 | DB, err := gorm.Open(mysql.Open("gorm:gorm@tcp(localhost:9911)/gorm?charset=utf8&parseTime=True&loc=Local"), 
&gorm.Config{PrepareStmt: i%2 == 0}) 250 | if err != nil { 251 | t.Fatalf("failed to connect db, got error: %v", err) 252 | } 253 | if debug := os.Getenv("DEBUG"); debug == "true" { 254 | DB.Logger = DB.Logger.LogMode(logger.Info) 255 | } else if debug == "false" { 256 | DB.Logger = DB.Logger.LogMode(logger.Silent) 257 | } 258 | 259 | if err := DB.Use(dbresolver.Register(dbresolver.Config{ 260 | Sources: []gorm.Dialector{mysql.Open("gorm:gorm@tcp(localhost:9912)/gorm?charset=utf8&parseTime=True&loc=Local")}, 261 | Replicas: []gorm.Dialector{mysql.Open("gorm:gorm@tcp(localhost:9913)/gorm?charset=utf8&parseTime=True&loc=Local")}, 262 | TraceResolverMode: true, 263 | }).Register(dbresolver.Config{ 264 | Sources: []gorm.Dialector{mysql.Open("gorm:gorm@tcp(localhost:9914)/gorm?charset=utf8&parseTime=True&loc=Local")}, 265 | Replicas: []gorm.Dialector{mysql.Open("gorm:gorm@tcp(localhost:9913)/gorm?charset=utf8&parseTime=True&loc=Local")}, 266 | TraceResolverMode: true, 267 | }, "users", &Product{}).SetMaxOpenConns(5)); err != nil { 268 | t.Fatalf("failed to use plugin, got error: %v", err) 269 | } 270 | 271 | tests := []struct { 272 | name string 273 | db *gorm.DB 274 | want int 275 | }{ 276 | {"global", DB, 9911}, 277 | {"source", DB.Clauses(dbresolver.Write), 9912}, 278 | {"replica", DB.Clauses(dbresolver.Read), 9913}, 279 | {"table global", DB.Table("users"), 9911}, 280 | {"table source", DB.Table("users").Clauses(dbresolver.Write), 9914}, 281 | {"table replica", DB.Table("users").Clauses(dbresolver.Read), 9913}, 282 | {"model global", DB.Model(&Product{}), 9911}, 283 | {"model source", DB.Model(&Product{}).Clauses(dbresolver.Write), 9914}, 284 | {"model replica", DB.Model(&Product{}).Clauses(dbresolver.Read), 9913}, 285 | } 286 | 287 | for _, tt := range tests { 288 | t.Run(tt.name, func(t *testing.T) { 289 | db, err := tt.db.DB() 290 | if err != nil { 291 | t.Fatalf("failed to get *sql.DB, got error: %v", err) 292 | } 293 | 294 | var got int 295 | if err := 
db.QueryRow("SELECT order_no FROM orders LIMIT 1").Scan(&got); err != nil { 296 | t.Fatalf("failed to get order_no, got error: %v", err) 297 | } 298 | 299 | if got != tt.want { 300 | t.Errorf("got %v, want %v", got, tt.want) 301 | } 302 | }) 303 | } 304 | } 305 | } 306 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3' 2 | 3 | services: 4 | mysql1: 5 | image: 'mysql/mysql-server:latest' 6 | ports: 7 | - 9911:3306 8 | environment: 9 | - MYSQL_DATABASE=gorm 10 | - MYSQL_USER=gorm 11 | - MYSQL_PASSWORD=gorm 12 | - MYSQL_RANDOM_ROOT_PASSWORD="yes" 13 | mysql2: 14 | image: 'mysql/mysql-server:latest' 15 | ports: 16 | - 9912:3306 17 | environment: 18 | - MYSQL_DATABASE=gorm 19 | - MYSQL_USER=gorm 20 | - MYSQL_PASSWORD=gorm 21 | - MYSQL_RANDOM_ROOT_PASSWORD="yes" 22 | mysql3: 23 | image: 'mysql/mysql-server:latest' 24 | ports: 25 | - 9913:3306 26 | environment: 27 | - MYSQL_DATABASE=gorm 28 | - MYSQL_USER=gorm 29 | - MYSQL_PASSWORD=gorm 30 | - MYSQL_RANDOM_ROOT_PASSWORD="yes" 31 | mysql4: 32 | image: 'mysql/mysql-server:latest' 33 | ports: 34 | - 9914:3306 35 | environment: 36 | - MYSQL_DATABASE=gorm 37 | - MYSQL_USER=gorm 38 | - MYSQL_PASSWORD=gorm 39 | - MYSQL_RANDOM_ROOT_PASSWORD="yes" 40 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module gorm.io/plugin/dbresolver 2 | 3 | go 1.18 4 | 5 | require ( 6 | gorm.io/driver/mysql v1.5.7 7 | gorm.io/gorm v1.26.0 8 | ) 9 | 10 | require ( 11 | github.com/go-sql-driver/mysql v1.7.0 // indirect 12 | github.com/jinzhu/inflection v1.0.0 // indirect 13 | github.com/jinzhu/now v1.1.5 // indirect 14 | golang.org/x/text v0.20.0 // indirect 15 | ) 16 | -------------------------------------------------------------------------------- /go.sum: 
--------------------------------------------------------------------------------
1 | github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
2 | github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
3 | github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
4 | github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
5 | github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
6 | github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
7 | golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug=
8 | golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4=
9 | gorm.io/driver/mysql v1.5.7 h1:MndhOPYOfEp2rHKgkZIhJ16eVUIRf2HmzgoPmh7FCWo=
10 | gorm.io/driver/mysql v1.5.7/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM=
11 | gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
12 | gorm.io/gorm v1.25.13-0.20250415034200-3ae5fdee0ce4 h1:RToN9Sg/d32U0vU2Vqhagb5MD9k85+uf9z9eG3cwOuw=
13 | gorm.io/gorm v1.25.13-0.20250415034200-3ae5fdee0ce4/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE=
14 | gorm.io/gorm v1.26.0 h1:9lqQVPG5aNNS6AyHdRiwScAVnXHg/L/Srzx55G5fOgs=
15 | gorm.io/gorm v1.26.0/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE=
16 |
--------------------------------------------------------------------------------
/logger.go:
--------------------------------------------------------------------------------
1 | package dbresolver
2 |
3 | import (
4 | 	"context"
5 | 	"fmt"
6 | 	"time"
7 |
8 | 	"gorm.io/gorm"
9 | 	"gorm.io/gorm/logger"
10 | )
11 |
12 | // ResolverModeKey is the type of the context key under which a statement's
13 | // resolver mode is stored.
14 | type ResolverModeKey string
15 |
16 | // ResolverMode names the kind of connection pool a statement was routed to.
17 | type ResolverMode string
18 |
19 | // resolverModeKey is the context key written by markStmtResolverMode and read
20 | // by resolverModeLogger.Trace.
21 | const resolverModeKey ResolverModeKey = "dbresolver:resolver_mode_key"
22 |
23 | const (
24 | 	// ResolverModeSource marks statements routed to a source pool.
25 | 	ResolverModeSource ResolverMode = "source"
26 | 	// ResolverModeReplica marks statements routed to a replica pool.
27 | 	ResolverModeReplica ResolverMode = "replica"
28 | )
29 |
30 | // resolverModeLogger wraps a logger.Interface and prefixes each traced SQL
31 | // statement with the resolver mode stored in its context, e.g.
32 | // "[replica] SELECT ...".
33 | type resolverModeLogger struct {
34 | 	logger.Interface
35 | }
36 |
37 | // ParamsFilter forwards to the wrapped logger when it implements
38 | // gorm.ParamsFilter, so wrapping does not hide that capability. Otherwise sql
39 | // and params are returned unchanged.
40 | func (l resolverModeLogger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
41 | 	if filter, ok := l.Interface.(gorm.ParamsFilter); ok {
42 | 		return filter.ParamsFilter(ctx, sql, params...)
43 | 	}
44 | 	return sql, params
45 | }
46 |
47 | // LogMode applies level to the wrapped logger and returns a wrapper around
48 | // the result, so the resolver-mode prefix survives level changes.
49 | func (l resolverModeLogger) LogMode(level logger.LogLevel) logger.Interface {
50 | 	l.Interface = l.Interface.LogMode(level)
51 | 	return l
52 | }
53 |
54 | // Trace delegates to the wrapped logger, prefixing the SQL with "[<mode>] "
55 | // when ctx carries a resolver mode. Statements dbresolver did not mark (such
56 | // as those in transactions, or from resolvers that do not enable resolver
57 | // mode marking) are logged unchanged.
58 | func (l resolverModeLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
59 | 	l.Interface.Trace(ctx, begin, func() (string, int64) {
60 | 		sql, rowsAffected := fc()
61 | 		if mode := ctx.Value(resolverModeKey); mode != nil {
62 | 			sql = fmt.Sprintf("[%s] %s", mode, sql)
63 | 		}
64 | 		return sql, rowsAffected
65 | 	}, err)
66 | }
67 |
68 | // NewResolverModeLogger wraps l so that traced SQL is prefixed with the
69 | // resolver mode. It is idempotent: a logger that is already wrapped is
70 | // returned as is.
71 | func NewResolverModeLogger(l logger.Interface) logger.Interface {
72 | 	if _, ok := l.(resolverModeLogger); ok {
73 | 		return l
74 | 	}
75 | 	return resolverModeLogger{Interface: l}
76 | }
77 |
78 | // markStmtResolverMode records mode in stmt's context. It does nothing unless
79 | // the statement's logger is a resolverModeLogger.
80 | func markStmtResolverMode(stmt *gorm.Statement, mode ResolverMode) {
81 | 	if _, ok := stmt.Logger.(resolverModeLogger); ok {
82 | 		stmt.Context = context.WithValue(stmt.Context, resolverModeKey, mode)
83 | 	}
84 | }
85 |
--------------------------------------------------------------------------------
/policy.go:
--------------------------------------------------------------------------------
1 | package dbresolver
2 |
3 | import (
4 | 	"math/rand"
5 | 	"sync/atomic"
6 |
7 | 	"gorm.io/gorm"
8 | )
9 |
10 | // Policy chooses one connection pool from a resolver's sources or replicas.
11 | // A resolver uses the same Policy for both lists, and statements may be
12 | // resolved from many goroutines, so implementations should be safe for
13 | // concurrent use and must not assume connPools always has the same length.
14 | type Policy interface {
15 | 	Resolve([]gorm.ConnPool) gorm.ConnPool
16 | }
17 |
18 | // PolicyFunc adapts an ordinary function to the Policy interface.
19 | type PolicyFunc func([]gorm.ConnPool) gorm.ConnPool
20 |
21 | // Resolve calls f(connPools).
22 | func (f PolicyFunc) Resolve(connPools []gorm.ConnPool) gorm.ConnPool {
23 | 	return f(connPools)
24 | }
25 |
26 | // RandomPolicy picks a pool uniformly at random from the math/rand global
27 | // source. Resolve panics if connPools is empty.
28 | type RandomPolicy struct {
29 | }
30 |
31 | // Resolve returns a randomly chosen element of connPools.
32 | func (RandomPolicy) Resolve(connPools []gorm.ConnPool) gorm.ConnPool {
33 | 	return connPools[rand.Intn(len(connPools))]
34 | }
35 |
36 | // RoundRobinPolicy returns a Policy that cycles through connPools in order,
37 | // starting at index 1 on the first call and wrapping around. Resolve panics
38 | // if connPools is empty.
39 | func RoundRobinPolicy() Policy {
40 | 	var cursor int64
41 | 	return PolicyFunc(func(connPools []gorm.ConnPool) gorm.ConnPool {
42 | 		n := int64(len(connPools))
43 | 		// Advance the cursor with compare-and-swap. A plain read-modify-write
44 | 		// is a data race, and because one policy serves slices of different
45 | 		// lengths, a concurrent call on a longer slice could store an index
46 | 		// that a call on a shorter slice would then use, panicking with
47 | 		// index out of range. Indexing with our own `next` avoids that.
48 | 		for {
49 | 			cur := atomic.LoadInt64(&cursor)
50 | 			next := (cur + 1) % n
51 | 			if atomic.CompareAndSwapInt64(&cursor, cur, next) {
52 | 				return connPools[next]
53 | 			}
54 | 		}
55 | 	})
56 | }
57 |
58 | // StrictRoundRobinPolicy returns a Policy that cycles through connPools using
59 | // a single atomic counter, so every call gets a distinct counter value even
60 | // under contention. The first call returns index 1. Resolve panics if
61 | // connPools is empty.
62 | func StrictRoundRobinPolicy() Policy {
63 | 	var counter uint64
64 | 	return PolicyFunc(func(connPools []gorm.ConnPool) gorm.ConnPool {
65 | 		// Reduce in uint64 before indexing: converting the counter to int
66 | 		// first would truncate on 32-bit platforms and eventually go negative.
67 | 		return connPools[atomic.AddUint64(&counter, 1)%uint64(len(connPools))]
68 | 	})
69 | }
70 |
--------------------------------------------------------------------------------
/policy_test.go:
--------------------------------------------------------------------------------
1 | package dbresolver
2 |
3 | import (
4 | 	"sync"
5 | 	"testing"
6 |
7 | 	"gorm.io/gorm"
8 | )
9 |
10 | // newTestPools returns n distinct, non-nil pools so that tests can tell which
11 | // element a policy picked; nil pools would all compare equal.
12 | func newTestPools(n int) []gorm.ConnPool {
13 | 	pools := make([]gorm.ConnPool, n)
14 | 	for i := range pools {
15 | 		pools[i] = &gorm.PreparedStmtDB{}
16 | 	}
17 | 	return pools
18 | }
19 |
20 | func TestPolicy_RoundRobinPolicy(t *testing.T) {
21 | 	pools := newTestPools(3)
22 | 	roundRobin := RoundRobinPolicy()
23 | 	strict := StrictRoundRobinPolicy()
24 |
25 | 	for i := 0; i < 10; i++ {
26 | 		// Both policies start at index 1 and then wrap around.
27 | 		want := pools[(i+1)%3]
28 | 		if roundRobin.Resolve(pools) != want {
29 | 			t.Errorf("RoundRobinPolicy failed at call %d", i)
30 | 		}
31 | 		if strict.Resolve(pools) != want {
32 | 			t.Errorf("StrictRoundRobinPolicy failed at call %d", i)
33 | 		}
34 | 	}
35 | }
36 |
37 | // TestPolicy_ConcurrentMixedLengths shares one policy between slices of
38 | // different lengths, as a resolver does with its sources and replicas. Run
39 | // with -race to also detect unsynchronized access.
40 | func TestPolicy_ConcurrentMixedLengths(t *testing.T) {
41 | 	sources := newTestPools(3)
42 | 	replicas := newTestPools(2)
43 |
44 | 	cases := []struct {
45 | 		name   string
46 | 		policy Policy
47 | 	}{
48 | 		{"RoundRobinPolicy", RoundRobinPolicy()},
49 | 		{"StrictRoundRobinPolicy", StrictRoundRobinPolicy()},
50 | 	}
51 |
52 | 	for _, tc := range cases {
53 | 		var wg sync.WaitGroup
54 | 		for g := 0; g < 8; g++ {
55 | 			pools := sources
56 | 			if g%2 == 1 {
57 | 				pools = replicas
58 | 			}
59 | 			wg.Add(1)
60 | 			go func(name string, policy Policy, pools []gorm.ConnPool) {
61 | 				defer wg.Done()
62 | 				for i := 0; i < 1000; i++ {
63 | 					if policy.Resolve(pools) == nil {
64 | 						t.Errorf("%s returned a nil pool", name)
65 | 						return
66 | 					}
67 | 				}
68 | 			}(tc.name, tc.policy, pools)
69 | 		}
70 | 		wg.Wait()
71 | 	}
72 | }
73 |
74 | func BenchmarkPolicy_StrictRoundRobinPolicy(b *testing.B) {
75 | 	pools := newTestPools(3)
76 | 	policy := StrictRoundRobinPolicy()
77 | 	b.RunParallel(func(pb *testing.PB) {
78 | 		for pb.Next() {
79 | 			if policy.Resolve(pools) == nil {
80 | 				b.Errorf("StrictRoundRobinPolicy returned a nil pool")
81 | 			}
82 | 		}
83 | 	})
84 | }
85 |
--------------------------------------------------------------------------------
/resolver.go:
--------------------------------------------------------------------------------
1 | package dbresolver
2 |
3 | import (
4 | 	"gorm.io/gorm"
5 | )
6 |
7 | // resolver routes the statements of one registered configuration to its
8 | // source or replica connection pools.
9 | type resolver struct {
10 | 	sources           []gorm.ConnPool
11 | 	replicas          []gorm.ConnPool
12 | 	policy            Policy
13 | 	dbResolver        *DBResolver
14 | 	traceResolverMode bool
15 | }
16 |
17 | // resolve picks the connection pool for stmt: a replica when op is Read, a
18 | // source otherwise. The policy is consulted unless the chosen list holds
19 | // exactly one pool. With traceResolverMode set, the chosen mode is recorded
20 | // on stmt for logging. When the statement's DB has PrepareStmt enabled and a
21 | // prepared-statement store is registered for the pool, the pool is wrapped
22 | // in a gorm.PreparedStmtDB that shares that store.
23 | func (r *resolver) resolve(stmt *gorm.Statement, op Operation) (connPool gorm.ConnPool) {
24 | 	pools, mode := r.sources, ResolverModeSource
25 | 	if op == Read {
26 | 		pools, mode = r.replicas, ResolverModeReplica
27 | 	}
28 |
29 | 	if len(pools) == 1 {
30 | 		connPool = pools[0]
31 | 	} else {
32 | 		connPool = r.policy.Resolve(pools)
33 | 	}
34 |
35 | 	if r.traceResolverMode {
36 | 		markStmtResolverMode(stmt, mode)
37 | 	}
38 |
39 | 	if stmt.DB.PrepareStmt {
40 | 		if preparedStmt, ok := r.dbResolver.prepareStmtStore[connPool]; ok {
41 | 			return &gorm.PreparedStmtDB{
42 | 				ConnPool: connPool,
43 | 				Mux:      preparedStmt.Mux,
44 | 				Stmts:    preparedStmt.Stmts,
45 | 			}
46 | 		}
47 | 	}
48 |
49 | 	return connPool
50 | }
51 |
52 | // call runs fc on every source pool and then every replica pool, stopping at
53 | // and returning the first error.
54 | func (r *resolver) call(fc func(connPool gorm.ConnPool) error) error {
55 | 	for _, pool := range r.sources {
56 | 		if err := fc(pool); err != nil {
57 | 			return err
58 | 		}
59 | 	}
60 | 	for _, pool := range r.replicas {
61 | 		if err := fc(pool); err != nil {
62 | 			return err
63 | 		}
64 | 	}
65 | 	return nil
66 | }
67 |
--------------------------------------------------------------------------------
/utils.go:
--------------------------------------------------------------------------------
1 | package dbresolver
2 |
3 | import (
4 | 	"regexp"
5 | )
6 |
7 | // fromTableRegexp matches, case-insensitively, a table name that follows
8 | // FROM, UPDATE, MERGE INTO or INSERT ... INTO, optionally quoted with ', `
9 | // or ". Group 1 captures the table name.
10 | var fromTableRegexp = regexp.MustCompile("(?i)(?:FROM|UPDATE|MERGE INTO|INSERT [a-z ]*INTO) ['`\"]?([a-zA-Z0-9_]+)([ '`\",)]|$)")
11 |
12 | // getTableFromRawSQL returns the table name from the leftmost match of
13 | // fromTableRegexp in sql, or "" when there is none. It is a heuristic, not a
14 | // SQL parser.
15 | func getTableFromRawSQL(sql string) string {
16 | 	// Only the leftmost match is used, so FindStringSubmatch avoids scanning
17 | 	// the rest of the statement for matches that would be discarded.
18 | 	if matches := fromTableRegexp.FindStringSubmatch(sql); matches != nil {
19 | 		return matches[1]
20 | 	}
21 | 	return ""
22 | }
23 |
--------------------------------------------------------------------------------
/utils_test.go:
--------------------------------------------------------------------------------
1 | package dbresolver
2 |
3 | import "testing"
4 |
5 | func TestGetTableFromRawSQL(t *testing.T) {
6 | 	datas := [][2]string{
7 | 		{"select * from users as u", "users"},
8 | 		{"select name from users", "users"},
9 | 		{"select * from (select * from users) as u", "users"},
10 | 		{"select * from (select * from users)", "users"},
11 | 		{"select * from (select * from users), (select * from products)", "users"},
12 | 		{"select * from users, products", "users"},
13 | 		{"select * from users as u, products as p", "users"},
14 | 		{"UPDATE users SET column1 = value1, column2 = value2", "users"},
15 | 		{"DELETE FROM users WHERE condition;", "users"},
16 | 		{"INSERT INTO users (column1, column2) VALUES (v1, v2)", "users"},
17 | 		{"insert ignore into users (name,age) VALUES ('jinzhu',18);", "users"},
18 | 		{"MERGE INTO users USING ", "users"},
19 | 		{"select * from `users` where id = 1", "users"},
20 | 		{"SELECT * FROM \"users\"", "users"},
21 | 		{"SELECT 1", ""},
22 | 	}
23 |
24 | 	for _, data := range datas {
25 | 		if got := getTableFromRawSQL(data[0]); got != data[1] {
26 | 			t.Errorf("Failed to get table name from %v, expect: %v, got: %v", data[0], data[1], got)
27 | 		}
28 | 	}
29 | }
30 |
--------------------------------------------------------------------------------