├── .gitignore ├── .travis.yml ├── README.md ├── errors └── errors.go ├── hash.go ├── hash_test.go ├── helper.go ├── helper_test.go ├── node.go ├── node_test.go ├── orm └── xorm │ ├── .gitignore │ ├── by_session.go │ ├── by_session_test.go │ ├── condition.go │ ├── examples_test.go │ ├── interface.go │ ├── interface_test.go │ ├── xorm.go │ ├── xorm_function.go │ ├── xorm_function_test.go │ ├── xorm_parallel.go │ ├── xorm_parallel_test.go │ ├── xorm_session_list.go │ ├── xorm_session_manager.go │ ├── xorm_session_manager_test.go │ ├── xorm_session_manager_tx.go │ ├── xorm_session_manager_tx_test.go │ ├── xorm_test.go │ ├── xorm_wizard.go │ └── xorm_wizard_test.go ├── reflect.go ├── reflect_test.go ├── shard_cluster.go ├── shard_cluster_test.go ├── standard_cluster.go ├── standard_cluster_test.go ├── wizard.go └── wizard_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | 6 | # Folders 7 | _obj 8 | _test 9 | 10 | # Architecture specific extensions/prefixes 11 | *.[568vq] 12 | [568vq].out 13 | 14 | *.cgo1.go 15 | *.cgo2.c 16 | _cgo_defun.c 17 | _cgo_gotypes.go 18 | _cgo_export.* 19 | 20 | _testmain.go 21 | 22 | *.exe 23 | *.test 24 | *.prof 25 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | sudo: false 2 | language: go 3 | go: 4 | - 1.8 5 | - 1.9 6 | - tip 7 | matrix: 8 | allow_failures: 9 | - go: tip 10 | before_install: 11 | - go get golang.org/x/tools/cmd/cover 12 | - go get github.com/golang/lint/golint 13 | - go get github.com/modocache/gover 14 | - go get -d github.com/stretchr/testify/assert github.com/go-sql-driver/mysql github.com/mattn/go-sqlite3 15 | before_script: 16 | - go vet ./... 17 | - gofmt -s -l . 
18 | script: 19 | - go list -f '{{if len .TestGoFiles}}"go test -coverprofile={{.Dir}}/.coverprofile {{.ImportPath}}"{{end}}' ./... | xargs -I{} sh -c '{}' 20 | - gover . coverprofile.txt 21 | after_success: 22 | - bash <(curl -s https://codecov.io/bash) -f coverprofile.txt 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Wizard 2 | ==== 3 | [![Build Status](https://travis-ci.org/evalphobia/wizard.svg?branch=master)](https://travis-ci.org/evalphobia/wizard) [![codecov.io](https://codecov.io/github/evalphobia/wizard/coverage.svg?branch=master)](https://codecov.io/github/evalphobia/wizard?branch=master) 4 | [![GoDoc](https://godoc.org/github.com/evalphobia/wizard?status.svg)](https://godoc.org/github.com/evalphobia/wizard) [![Join the chat at https://gitter.im/evalphobia/wizard](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/evalphobia/wizard?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) 5 | 6 | Wizard is a database/sql management library for multi-instance databases and sharding in Go.
7 | Inspired by [MixedGauge](https://github.com/taiki45/mixed_gauge) 8 | 9 | ## Supported orm list 10 | 11 | - [xorm](https://github.com/go-xorm/xorm) 12 | 13 | ## Quick Usage 14 | 15 | ### Register database clusters 16 | 17 | ```go 18 | import ( 19 | _ "github.com/go-sql-driver/mysql" 20 | "github.com/go-xorm/xorm" 21 | 22 | "github.com/evalphobia/wizard" 23 | ) 24 | 25 | type Blog struct { 26 | ArticleID string `xorm:"article_id pk VARCHAR(255) not null"` 27 | Content string `xorm:"content text"` 28 | } 29 | 30 | type User struct { 31 | ID int64 `xorm:"id pk BIGINT(20) not null" shard_key:"true"` 32 | Name string `xorm:"name VARCHAR(100) not null"` 33 | } 34 | 35 | func main() { 36 | wiz = wizard.NewWizard() 37 | 38 | /** 39 | register normal cluster 40 | */ 41 | 42 | // create engines 43 | blogMaster, _ = xorm.NewEngine("mysql", "root:@tcp(db-master:3306)/blog?charset=utf8") 44 | blogSlave01, _ = xorm.NewEngine("mysql", "root:@tcp(db-slave01:3306)/blog?charset=utf8") 45 | blogSlave02, _ = xorm.NewEngine("mysql", "root:@tcp(db-slave02:3306)/blog?charset=utf8") 46 | 47 | // create cluster with master node; CreateCluster(name, master-instance) 48 | blogCluster := wiz.CreateCluster(Blog{}, blogMaster) 49 | blogCluster.RegisterSlave(blogSlave01) // add slaves 50 | blogCluster.RegisterSlave(blogSlave02) 51 | 52 | 53 | /** 54 | register shard clusters 55 | */ 56 | 57 | // shard one 58 | user01Master, _ = xorm.NewEngine("mysql", "root:@tcp(shard01-master:3306)/users?charset=utf8") 59 | user01Slave01, _ = xorm.NewEngine("mysql", "root:@tcp(shard01-slave01:3306)/users?charset=utf8") 60 | user01Slave02, _ = xorm.NewEngine("mysql", "root:@tcp(shard01-slave02:3306)/users?charset=utf8") 61 | 62 | // shard two 63 | user02Master, _ = xorm.NewEngine("mysql", "root:@tcp(shard02-master:3306)/users?charset=utf8") 64 | user02Slave01, _ = xorm.NewEngine("mysql", "root:@tcp(shard02-slave01:3306)/users?charset=utf8") 65 | user02Slave02, _ = xorm.NewEngine("mysql",
"root:@tcp(shard02-slave02:3306)/users?charset=utf8") 66 | 67 | // create shard clusters; CreateShardCluster(name, slot-size) 68 | shardClusters := wiz.CreateShardCluster(User{}, 1023) 69 | 70 | // create single shard set #1 71 | shardCluster01 := wizard.NewCluster(user01Master) 72 | shardCluster01.RegisterSlave(user01Slave01) 73 | shardCluster01.RegisterSlave(user01Slave02) 74 | 75 | // create single shard set #2 76 | shardCluster02 := wizard.NewCluster(user02Master) 77 | shardCluster02.RegisterSlave(user02Slave01) 78 | shardCluster02.RegisterSlave(user02Slave02) 79 | 80 | // register shards with slot; RegisterShard(min, max, cluster) 81 | shardClusters.RegisterShard(0, 500, shardCluster01) 82 | shardClusters.RegisterShard(501, 1022, shardCluster02) 83 | } 84 | ``` 85 | 86 | ### Query on database clusters 87 | 88 | ```go 89 | import ( 90 | "fmt" 91 | 92 | "github.com/evalphobia/wizard/orm/xorm" 93 | ) 94 | 95 | func main() { 96 | orm := xorm.New(wiz) 97 | 98 | blog := &Blog{ArticleID: "hello-world"} 99 | has, err := orm.Get(blog, func(s xorm.Session) (bool, error) { 100 | return s.Get(blog) 101 | }) 102 | // => SELECT * FROM blog WHERE article_id = "hello-world"; -- execute on blog SLAVE 103 | 104 | 105 | fmt.Println(has) 106 | fmt.Println(blog.Content) 107 | 108 | user := &User{ 109 | ID: 1600, // 1600 % 1023 = 577; => shard02 110 | Name: "Adam Smith", 111 | } 112 | 113 | err = orm.Begin(user) // => BEGIN; -- execute on user02-MASTER 114 | if err != nil { 115 | panic("Error on transaction beginning") 116 | } 117 | 118 | total, err := orm.Insert(user, func(s xorm.Session) (int64, error) { 119 | return s.Insert(user) 120 | }) 121 | // => INSERT INTO users VALUES(1600, "Adam Smith"); -- execute on user02-MASTER 122 | 123 | fmt.Println(total) 124 | 125 | newUser := &User{ID: 1600} 126 | has, err = orm.GetUsingMaster(newUser, func(s xorm.Session) (bool, error) { 127 | return s.Get(newUser) 128 | }) 129 | // => SELECT * FROM users WHERE id = 1600; -- execute on
user02-MASTER 130 | 131 | 132 | err = orm.Commit(user) 133 | // => COMMIT; -- execute on user02-MASTER 134 | if err != nil { 135 | panic("Error on transaction ending") 136 | } 137 | } 138 | ``` 139 | 140 | ### Notes 141 | 142 | - Clusters is selected by name, which can be any value like `string`, `struct`, `pointer`. 143 | - the pointer value automatically converts to the non-pointer value. 144 | - Struct field tag: `shard_key:"true"` is used as a shard-key 145 | - shard_key is divided by slot size and the mod value is used for shard mapping 146 | - string shard_key convert to int64 with CRC64 and divided by slot-size 147 | 148 | ### Other info 149 | 150 | - Slide ["golang.tokyo #7 Wizard"](https://www.slideshare.net/TakumaMorikawa/golangtokyo-7-wizard-database-sharding-library-for-golang) 151 | -------------------------------------------------------------------------------- /errors/errors.go: -------------------------------------------------------------------------------- 1 | package errors 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | ) 7 | 8 | type Err struct { 9 | Code int 10 | Info string 11 | } 12 | 13 | func (e Err) Error() string { 14 | return e.Info 15 | } 16 | 17 | func NewErr(code int, msg string) Err { 18 | return Err{ 19 | Code: code, 20 | Info: msg, 21 | } 22 | } 23 | 24 | func NewErrNilDB(name interface{}) Err { 25 | return Err{Code: 10000, Info: "cannot find db, name=" + fmt.Sprint(name)} 26 | } 27 | 28 | func NewErrNilDBs(es []error) Err { 29 | var messages []string 30 | for _, err := range es { 31 | messages = append(messages, err.Error()) 32 | } 33 | return Err{Code: 10001, Info: strings.Join(messages, " || ")} 34 | } 35 | 36 | func NewErrAlreadyRegistared(name interface{}) Err { 37 | return Err{Code: 11001, Info: "already registered table name=" + fmt.Sprint(name)} 38 | } 39 | 40 | func NewErrSlotSizeMin(min int64) Err { 41 | return Err{Code: 11002, Info: fmt.Sprintf("minimun slot size must be positive interger, value=%d", min)} 42 | } 43 | 44 | 
func NewErrSlotSizeMax(max, slot int64) Err { 45 | return Err{Code: 11003, Info: fmt.Sprintf("maximum slot size is out of range, DefinedSize=%d GivenSize=%d", slot, max)} 46 | } 47 | 48 | func NewErrSlotMinOverlapped(size int64) Err { 49 | return Err{Code: 11004, Info: fmt.Sprintf("minimun slot size is overlapped, value=%d", size)} 50 | } 51 | 52 | func NewErrSlotMaxOverlapped(size int64) Err { 53 | return Err{Code: 11005, Info: fmt.Sprintf("maximun slot size is overlapped, value=%d", size)} 54 | } 55 | 56 | func NewErrNoSession(name interface{}) Err { 57 | return Err{Code: 20001, Info: "cannot find session, name=" + fmt.Sprint(name)} 58 | } 59 | 60 | func NewErrDuplicateTx() Err { 61 | return Err{Code: 20002, Info: "transaction already exists"} 62 | } 63 | 64 | func NewErrWrongTx() Err { 65 | return Err{Code: 20003, Info: "something wrong with the transaction"} 66 | } 67 | func NewErrCommitAll(es []error) Err { 68 | messages := []string{"commit all error: "} 69 | for _, err := range es { 70 | messages = append(messages, err.Error()) 71 | } 72 | return Err{Code: 20004, Info: strings.Join(messages, " ")} 73 | } 74 | 75 | func NewErrRollbackAll(es []error) Err { 76 | messages := []string{"rollback all error: "} 77 | for _, err := range es { 78 | messages = append(messages, err.Error()) 79 | } 80 | return Err{Code: 20005, Info: strings.Join(messages, " ")} 81 | } 82 | 83 | func NewErrAnotherTx(name interface{}) Err { 84 | return Err{Code: 20006, Info: "transaction already exists, db=" + fmt.Sprint(name)} 85 | } 86 | 87 | func NewErrParallelQuery(es []error) Err { 88 | messages := []string{"parallel query error: "} 89 | for _, err := range es { 90 | messages = append(messages, err.Error()) 91 | } 92 | return Err{Code: 30001, Info: strings.Join(messages, " || ")} 93 | } 94 | 95 | func NewErrArgType(msg string) Err { 96 | return Err{Code: 30002, Info: msg} 97 | } 98 | -------------------------------------------------------------------------------- /hash.go: 
-------------------------------------------------------------------------------- 1 | package wizard 2 | 3 | import ( 4 | "fmt" 5 | "hash/crc64" 6 | "io" 7 | ) 8 | 9 | var hashTable = crc64.MakeTable(crc64.ISO) 10 | 11 | // getInt64 returns int64 value 12 | func getInt64(v interface{}) int64 { 13 | switch t := v.(type) { 14 | case int64: 15 | return t 16 | case int: 17 | return int64(t) 18 | case int8: 19 | return int64(t) 20 | case int16: 21 | return int64(t) 22 | case int32: 23 | return int64(t) 24 | case uint: 25 | return int64(t) 26 | case uint8: 27 | return int64(t) 28 | case uint16: 29 | return int64(t) 30 | case uint32: 31 | return int64(t) 32 | case uint64: 33 | return int64(t) 34 | case float32: 35 | return int64(t) 36 | case float64: 37 | return int64(t) 38 | } 39 | return hashToInt64(v) 40 | } 41 | 42 | // hashToInt64 converts any value to int64 using crc64 43 | func hashToInt64(v interface{}) int64 { 44 | str := fmt.Sprint(v) 45 | h := crc64.New(hashTable) 46 | io.WriteString(h, str) 47 | return int64(h.Sum64()) 48 | } 49 | -------------------------------------------------------------------------------- /hash_test.go: -------------------------------------------------------------------------------- 1 | package wizard 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func testAssertInt64(a *assert.Assertions, v interface{}) { 10 | a.Equal(int64(99), getInt64(v)) 11 | } 12 | 13 | func TestGetInt64(t *testing.T) { 14 | a := assert.New(t) 15 | 16 | var vInt = 99 17 | testAssertInt64(a, vInt) 18 | 19 | var vInt8 int8 = 99 20 | testAssertInt64(a, vInt8) 21 | 22 | var vInt16 int16 = 99 23 | testAssertInt64(a, vInt16) 24 | 25 | var vInt32 int32 = 99 26 | testAssertInt64(a, vInt32) 27 | 28 | var vInt64 int64 = 99 29 | testAssertInt64(a, vInt64) 30 | 31 | var vUInt uint = 99 32 | testAssertInt64(a, vUInt) 33 | 34 | var vUInt8 uint8 = 99 35 | testAssertInt64(a, vUInt8) 36 | 37 | var vUInt16 uint16 = 99 38 | testAssertInt64(a, 
vUInt16) 39 | 40 | var vUInt32 uint32 = 99 41 | testAssertInt64(a, vUInt32) 42 | 43 | var vUInt64 uint64 = 99 44 | testAssertInt64(a, vUInt64) 45 | 46 | var vFloat32 float32 = 99 47 | testAssertInt64(a, vFloat32) 48 | 49 | var vFloat64 float64 = 99 50 | testAssertInt64(a, vFloat64) 51 | 52 | var vStr = "foobar" 53 | a.Equal(int64(3297785893580976128), getInt64(vStr)) 54 | 55 | type myStruct struct{} 56 | a.Equal(int64(2612580365084131328), getInt64(myStruct{})) 57 | } 58 | 59 | func TestHashToInt64(t *testing.T) { 60 | assert := assert.New(t) 61 | 62 | assert.Equal(int64(3297785893580976128), hashToInt64("foobar")) 63 | 64 | type myStruct struct{} 65 | assert.Equal(int64(2612580365084131328), hashToInt64(myStruct{})) 66 | } 67 | -------------------------------------------------------------------------------- /helper.go: -------------------------------------------------------------------------------- 1 | package wizard 2 | 3 | // UseMaster returns db master 4 | func (w *Wizard) UseMaster(obj interface{}) interface{} { 5 | cluster := w.Select(obj) 6 | if cluster == nil { 7 | return nil 8 | } 9 | db := cluster.Master() 10 | if db == nil { 11 | return nil 12 | } 13 | return db.DB() 14 | } 15 | 16 | // UseMasters returns all db master instances for sharding 17 | func (w *Wizard) UseMasters(obj interface{}) []interface{} { 18 | var results []interface{} 19 | c := w.getCluster(obj) 20 | if c == nil { 21 | return results 22 | } 23 | for _, node := range c.Masters() { 24 | db := node.DB() 25 | if db == nil { 26 | continue 27 | } 28 | results = append(results, db) 29 | } 30 | return results 31 | } 32 | 33 | // UseSlave randomly returns db slave from the slaves 34 | // if any slave is not set, master is returned 35 | func (w *Wizard) UseSlave(obj interface{}) interface{} { 36 | cluster := w.Select(obj) 37 | if cluster == nil { 38 | return nil 39 | } 40 | db := cluster.Slave() 41 | if db == nil { 42 | return nil 43 | } 44 | return db.DB() 45 | } 46 | 47 | // UseSlaves randomly 
returns all db slave instances for sharding 48 | func (w *Wizard) UseSlaves(obj interface{}) []interface{} { 49 | var results []interface{} 50 | c := w.getCluster(obj) 51 | if c == nil { 52 | return results 53 | } 54 | for _, node := range c.Slaves() { 55 | db := node.DB() 56 | if db == nil { 57 | continue 58 | } 59 | results = append(results, db) 60 | } 61 | return results 62 | } 63 | 64 | // UseMasterByKey returns db master for sharding by shard key 65 | func (w *Wizard) UseMasterByKey(obj interface{}, key interface{}) interface{} { 66 | cluster := w.SelectByKey(obj, key) 67 | if cluster == nil { 68 | return nil 69 | } 70 | db := cluster.Master() 71 | if db == nil { 72 | return nil 73 | } 74 | return db.DB() 75 | } 76 | 77 | // UseSlaveByKey randomly returns db slave for sharding by shard key 78 | // if any slave is not set in the cluster, master is returned 79 | func (w *Wizard) UseSlaveByKey(obj interface{}, key interface{}) interface{} { 80 | cluster := w.SelectByKey(obj, key) 81 | if cluster == nil { 82 | return nil 83 | } 84 | db := cluster.Slave() 85 | if db == nil { 86 | return nil 87 | } 88 | return db.DB() 89 | } 90 | -------------------------------------------------------------------------------- /helper_test.go: -------------------------------------------------------------------------------- 1 | package wizard 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestUseMaster(t *testing.T) { 10 | assert := assert.New(t) 11 | 12 | var wiz *Wizard 13 | var c *StandardCluster 14 | 15 | wiz = NewWizard() 16 | c = wiz.CreateCluster("country_table", "db-master") 17 | c.RegisterSlave("db-slave") 18 | 19 | assert.Equal("db-master", wiz.UseMaster("country_table")) 20 | assert.Nil(wiz.UseMaster("city_table"), "Non registered name") 21 | 22 | var s *ShardCluster 23 | s = wiz.CreateShardCluster("user_table", 997) 24 | s.RegisterShard(0, 499, NewCluster("shard01-master")) 25 | s.RegisterShard(500, 996, 
NewCluster("shard02-master")) 26 | assert.Equal("shard01-master", wiz.UseMaster("user_table")) 27 | } 28 | 29 | func TestUseMasters(t *testing.T) { 30 | assert := assert.New(t) 31 | 32 | var wiz *Wizard 33 | wiz = NewWizard() 34 | wiz.CreateCluster("country_table", "db-master") 35 | 36 | assert.Contains(wiz.UseMasters("country_table"), "db-master") 37 | assert.Empty(wiz.UseMasters("city_table"), "Non registered name") 38 | 39 | var s *ShardCluster 40 | s = wiz.CreateShardCluster("user_table", 997) 41 | s.RegisterShard(0, 499, NewCluster("shard01-master")) 42 | s.RegisterShard(500, 996, NewCluster("shard02-master")) 43 | 44 | assert.Contains(wiz.UseMasters("user_table"), "shard01-master") 45 | assert.Contains(wiz.UseMasters("user_table"), "shard02-master") 46 | assert.Len(wiz.UseMasters("user_table"), 2) 47 | } 48 | 49 | func TestUseSlave(t *testing.T) { 50 | assert := assert.New(t) 51 | 52 | var wiz *Wizard 53 | var c *StandardCluster 54 | 55 | wiz = NewWizard() 56 | c = wiz.CreateCluster("country_table", "db-master") 57 | assert.Equal("db-master", wiz.UseSlave("country_table"), "Slave() return master when no slaves exists") 58 | 59 | c.RegisterSlave("db-slave") 60 | assert.Equal("db-slave", wiz.UseSlave("country_table")) 61 | 62 | assert.Nil(wiz.UseSlave("city_table"), "Non registered name") 63 | } 64 | 65 | func TestUseMasterByKey(t *testing.T) { 66 | assert := assert.New(t) 67 | 68 | var wiz *Wizard 69 | var c *StandardCluster 70 | 71 | wiz = NewWizard() 72 | c = wiz.CreateCluster("country_table", "db-master") 73 | c.RegisterSlave("db-slave") 74 | 75 | assert.Equal("db-master", wiz.UseMasterByKey("country_table", 1)) 76 | assert.Nil(wiz.UseMasterByKey("city_table", 1), "Non registered name") 77 | 78 | var s *ShardCluster 79 | s = wiz.CreateShardCluster("user_table", 997) 80 | s.RegisterShard(0, 499, NewCluster("shard01-master")) 81 | s.RegisterShard(500, 996, NewCluster("shard02-master")) 82 | assert.Equal("shard01-master", wiz.UseMasterByKey("user_table", 499)) 
83 | assert.Equal("shard02-master", wiz.UseMasterByKey("user_table", 500)) 84 | assert.Equal("shard02-master", wiz.UseMasterByKey("user_table", 996)) 85 | assert.Equal("shard01-master", wiz.UseMasterByKey("user_table", 997)) 86 | } 87 | 88 | // TODO: add test for multiple slaves 89 | func TestUseSlaveByKey(t *testing.T) { 90 | assert := assert.New(t) 91 | 92 | var wiz *Wizard 93 | var c *StandardCluster 94 | 95 | wiz = NewWizard() 96 | c = wiz.CreateCluster("country_table", "db-master") 97 | c.RegisterSlave("db-slave") 98 | 99 | assert.Equal("db-master", wiz.UseMasterByKey("country_table", 1)) 100 | assert.Nil(wiz.UseMasterByKey("city_table", 1), "Non registered name") 101 | 102 | var s *ShardCluster 103 | s = wiz.CreateShardCluster("user_table", 997) 104 | c1 := NewCluster("shard01-master") 105 | c1.RegisterSlave("shard01-slave") 106 | c2 := NewCluster("shard02-master") 107 | c2.RegisterSlave("shard02-slave") 108 | s.RegisterShard(0, 499, c1) 109 | s.RegisterShard(500, 996, c2) 110 | assert.Equal("shard01-slave", wiz.UseSlaveByKey("user_table", 499)) 111 | assert.Equal("shard02-slave", wiz.UseSlaveByKey("user_table", 500)) 112 | assert.Equal("shard02-slave", wiz.UseSlaveByKey("user_table", 996)) 113 | assert.Equal("shard01-slave", wiz.UseSlaveByKey("user_table", 997)) 114 | } 115 | -------------------------------------------------------------------------------- /node.go: -------------------------------------------------------------------------------- 1 | package wizard 2 | 3 | // Node is struct for single database instance 4 | type Node struct { 5 | db interface{} // db connection 6 | } 7 | 8 | // NewNode returns initialized Node 9 | func NewNode(db interface{}) *Node { 10 | return &Node{db: db} 11 | } 12 | 13 | // DB is used for returning database connection 14 | func (n *Node) DB() interface{} { 15 | return n.db 16 | } 17 | -------------------------------------------------------------------------------- /node_test.go: 
-------------------------------------------------------------------------------- 1 | package wizard 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestNewNode(t *testing.T) { 10 | assert := assert.New(t) 11 | 12 | var n *Node 13 | n = NewNode("db") 14 | assert.Equal("db", n.db, "db should be saved on NewNode()") 15 | 16 | type TestDB struct{} 17 | n = NewNode(TestDB{}) 18 | assert.Equal(TestDB{}, n.db, "db should be saved on NewNode()") 19 | } 20 | 21 | func TestDB(t *testing.T) { 22 | assert := assert.New(t) 23 | 24 | var n *Node 25 | n = NewNode("db") 26 | assert.Equal("db", n.DB(), "DB() should equal to Node.db") 27 | 28 | type TestDB struct{} 29 | n = NewNode(TestDB{}) 30 | assert.Equal(TestDB{}, n.DB(), "db should equal to Node.db") 31 | } 32 | -------------------------------------------------------------------------------- /orm/xorm/.gitignore: -------------------------------------------------------------------------------- 1 | xorm_test_*.db -------------------------------------------------------------------------------- /orm/xorm/by_session.go: -------------------------------------------------------------------------------- 1 | package xorm 2 | 3 | import ( 4 | "github.com/evalphobia/wizard" 5 | ) 6 | 7 | // GetBySession performs wrapped function for xorm.Sesion.Get() 8 | func GetBySession(s Session, fn func(Session) (bool, error)) (bool, error) { 9 | return fn(s) 10 | } 11 | 12 | // FindBySession performs wrapped function for xorm.Sesion.Find() 13 | func FindBySession(s Session, fn func(Session) error) error { 14 | return fn(s) 15 | } 16 | 17 | // CountBySession performs wrapped function for xorm.Sesion.Count() 18 | func CountBySession(s Session, fn func(Session) (int, error)) (int, error) { 19 | return fn(s) 20 | } 21 | 22 | // InsertBySession performs wrapped function for xorm.Sesion.Insert() 23 | func InsertBySession(s Session, fn func(Session) (int64, error)) (int64, error) { 24 | return fn(s) 25 | } 26 | 
27 | // UpdateBySession performs wrapped function for xorm.Sesion.Update() 28 | func UpdateBySession(s Session, fn func(Session) (bool, error)) (bool, error) { 29 | return fn(s) 30 | } 31 | 32 | // NormalizeValue returns non-pointer value 33 | func NormalizeValue(p interface{}) interface{} { 34 | return wizard.NormalizeValue(p) 35 | } 36 | -------------------------------------------------------------------------------- /orm/xorm/by_session_test.go: -------------------------------------------------------------------------------- 1 | package xorm 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/go-xorm/xorm" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestGetBySession(t *testing.T) { 11 | assert := assert.New(t) 12 | sess := &xorm.Session{} 13 | 14 | fn := func(s Session) (bool, error) { 15 | assert.Equal(sess, s) 16 | return true, nil 17 | } 18 | b, err := GetBySession(sess, fn) 19 | assert.True(b) 20 | assert.Nil(err) 21 | } 22 | 23 | func TestFindBySession(t *testing.T) { 24 | assert := assert.New(t) 25 | sess := &xorm.Session{} 26 | 27 | fn := func(s Session) error { 28 | assert.Equal(sess, s) 29 | return nil 30 | } 31 | err := FindBySession(sess, fn) 32 | assert.Nil(err) 33 | } 34 | 35 | func TestCountBySession(t *testing.T) { 36 | assert := assert.New(t) 37 | sess := &xorm.Session{} 38 | 39 | fn := func(s Session) (int, error) { 40 | assert.Equal(sess, s) 41 | return 99, nil 42 | } 43 | count, err := CountBySession(sess, fn) 44 | assert.Equal(99, count) 45 | assert.Nil(err) 46 | } 47 | 48 | func TestInsertBySession(t *testing.T) { 49 | assert := assert.New(t) 50 | sess := &xorm.Session{} 51 | 52 | fn := func(s Session) (int64, error) { 53 | assert.Equal(sess, s) 54 | return 99, nil 55 | } 56 | affected, err := InsertBySession(sess, fn) 57 | assert.Equal(int64(99), affected) 58 | assert.Nil(err) 59 | } 60 | 61 | func TestUpdateBySession(t *testing.T) { 62 | assert := assert.New(t) 63 | sess := &xorm.Session{} 64 | 65 | fn := func(s Session) 
(bool, error) { 66 | assert.Equal(sess, s) 67 | return true, nil 68 | } 69 | b, err := UpdateBySession(sess, fn) 70 | assert.True(b) 71 | assert.Nil(err) 72 | } 73 | 74 | func TestNormalizeValue(t *testing.T) { 75 | assert := assert.New(t) 76 | 77 | assert.Equal("xorm.Xorm", NormalizeValue(Xorm{}), "Struct should return the type name") 78 | assert.Equal("xorm.Xorm", NormalizeValue(&Xorm{}), "Struct pointer should return the type name") 79 | 80 | valueString := "foobar" 81 | assert.Equal(valueString, NormalizeValue(valueString)) 82 | assert.Equal(valueString, NormalizeValue(&valueString)) 83 | 84 | valueInt := 99 85 | assert.Equal(valueInt, NormalizeValue(valueInt)) 86 | assert.Equal(valueInt, NormalizeValue(&valueInt)) 87 | 88 | valueSlice := []string{"a", "b", "c"} 89 | assert.Equal(valueSlice, NormalizeValue(valueSlice)) 90 | assert.Equal(valueSlice, NormalizeValue(&valueSlice)) 91 | 92 | valueMap := map[interface{}]interface{}{"key": "value", 100: 403} 93 | assert.Equal(valueMap, NormalizeValue(valueMap)) 94 | assert.Equal(valueMap, NormalizeValue(&valueMap)) 95 | } 96 | -------------------------------------------------------------------------------- /orm/xorm/condition.go: -------------------------------------------------------------------------------- 1 | package xorm 2 | 3 | type Where struct { 4 | Statement string 5 | Args []interface{} 6 | } 7 | 8 | func NewWhere(s string, args ...interface{}) Where { 9 | return Where{ 10 | Statement: s, 11 | Args: args, 12 | } 13 | } 14 | 15 | type Order struct { 16 | Name string 17 | OrderByDesc bool 18 | } 19 | 20 | // FindCondition is conditions for FindParallel 21 | type FindCondition struct { 22 | Table interface{} 23 | Columns []string 24 | Selects string 25 | Where []Where 26 | WhereIn []Where 27 | Group []string 28 | Havings []string 29 | OrderBy []Order 30 | Limit int 31 | Offset int 32 | } 33 | 34 | func NewFindCondition(table interface{}) FindCondition { 35 | return FindCondition{ 36 | Table: table, 37 | } 38 | 
} 39 | 40 | func (c *FindCondition) Cols(cols ...string) { 41 | c.Columns = append(c.Columns, cols...) 42 | } 43 | 44 | func (c *FindCondition) Select(str string) { 45 | c.Selects = str 46 | } 47 | 48 | func (c *FindCondition) And(s string, args ...interface{}) { 49 | c.Where = append(c.Where, NewWhere(s, args...)) 50 | } 51 | 52 | func (c *FindCondition) In(s string, args ...interface{}) { 53 | c.WhereIn = append(c.WhereIn, NewWhere(s, args...)) 54 | } 55 | 56 | func (c *FindCondition) GroupBy(s ...string) { 57 | c.Group = append(c.Group, s...) 58 | } 59 | 60 | func (c *FindCondition) Having(s ...string) { 61 | c.Havings = append(c.Havings, s...) 62 | } 63 | 64 | func (c *FindCondition) OrderByAsc(s string) { 65 | o := Order{ 66 | Name: s, 67 | } 68 | c.OrderBy = append(c.OrderBy, o) 69 | } 70 | 71 | func (c *FindCondition) OrderByDesc(s string) { 72 | o := Order{ 73 | Name: s, 74 | OrderByDesc: true, 75 | } 76 | c.OrderBy = append(c.OrderBy, o) 77 | } 78 | 79 | func (c *FindCondition) SetLimit(i int) { 80 | c.Limit = i 81 | } 82 | 83 | func (c *FindCondition) SetOffset(i int) { 84 | c.Offset = i 85 | } 86 | 87 | // UpdateCondition is conditions for UpdateParallel 88 | type UpdateCondition struct { 89 | Table interface{} 90 | Where []Where 91 | WhereIn []Where 92 | AllColumns bool 93 | Columns []string 94 | MustColumns []string 95 | OmitColumns []string 96 | NullableColumns []string 97 | Increments []Where 98 | Decrements []Where 99 | } 100 | 101 | func NewUpdateCondition(table interface{}) UpdateCondition { 102 | return UpdateCondition{ 103 | Table: table, 104 | } 105 | } 106 | 107 | func (c *UpdateCondition) And(s string, args ...interface{}) { 108 | c.Where = append(c.Where, NewWhere(s, args...)) 109 | } 110 | 111 | func (c *UpdateCondition) In(s string, args ...interface{}) { 112 | c.WhereIn = append(c.WhereIn, NewWhere(s, args...)) 113 | } 114 | 115 | func (c *UpdateCondition) AllCols() { 116 | c.AllColumns = true 117 | } 118 | 119 | func (c *UpdateCondition) 
Cols(cols ...string) { 120 | c.Columns = append(c.Columns, cols...) 121 | } 122 | 123 | func (c *UpdateCondition) MustCols(cols ...string) { 124 | c.MustColumns = append(c.MustColumns, cols...) 125 | } 126 | 127 | func (c *UpdateCondition) Omit(cols ...string) { 128 | c.OmitColumns = append(c.OmitColumns, cols...) 129 | } 130 | 131 | func (c *UpdateCondition) Nullable(cols ...string) { 132 | c.NullableColumns = append(c.NullableColumns, cols...) 133 | } 134 | 135 | func (c *UpdateCondition) Incr(s string, args ...interface{}) { 136 | c.Increments = append(c.Increments, NewWhere(s, args...)) 137 | } 138 | 139 | func (c *UpdateCondition) Decr(s string, args ...interface{}) { 140 | c.Decrements = append(c.Decrements, NewWhere(s, args...)) 141 | } 142 | -------------------------------------------------------------------------------- /orm/xorm/examples_test.go: -------------------------------------------------------------------------------- 1 | package xorm 2 | 3 | import ( 4 | "fmt" 5 | 6 | _ "github.com/go-sql-driver/mysql" 7 | "github.com/go-xorm/xorm" 8 | 9 | "github.com/evalphobia/wizard" 10 | ) 11 | 12 | var w *wizard.Wizard 13 | var engine1, engine2, engine3, engine4 *xorm.Engine 14 | 15 | type User struct { 16 | ID int64 `xorm:"pk not null" shard_key:"true"` 17 | Name string `xorm:"varchar(255) not null"` 18 | } 19 | 20 | func _ExampleRegisterStandardDatabases() { 21 | engine1, _ = xorm.NewEngine("mysql", "root:@/example_user?charset=utf8") 22 | engine2, _ = xorm.NewEngine("mysql", "root:@/example_user?charset=utf8") 23 | engine3, _ = xorm.NewEngine("mysql", "root:@/example_foobar?charset=utf8") 24 | engine4, _ = xorm.NewEngine("mysql", "root:@/example_other?charset=utf8") 25 | 26 | w = wizard.NewWizard() 27 | stndardCluster := w.CreateCluster(User{}, engine1) // engine is master database used for table of User{} 28 | stndardCluster.RegisterSlave(engine2) // add slave 29 | 30 | _ = w.CreateCluster("foobar", engine3) // engine3 is master database used for table 
of foobar 31 | 32 | stndardCluster = wizard.NewCluster(engine4) 33 | w.SetDefault(stndardCluster) // engine4 is master database used for all the other tables 34 | } 35 | 36 | func _ExampleRegisterShardedDatabase() { 37 | engine1, _ = xorm.NewEngine("mysql", "root:@/example_user_a?charset=utf8") 38 | engine2, _ = xorm.NewEngine("mysql", "root:@/example_user_a?charset=utf8") 39 | engine3, _ = xorm.NewEngine("mysql", "root:@/example_user_b?charset=utf8") 40 | engine4, _ = xorm.NewEngine("mysql", "root:@/example_user_b?charset=utf8") 41 | 42 | w = wizard.NewWizard() 43 | shardClusters := w.CreateShardCluster(&User{}, 997) // create shard clusters for User{} with slotsize 997 44 | standardClusterA := wizard.NewCluster(engine1) 45 | standardClusterA.RegisterSlave(engine2) 46 | shardClusters.RegisterShard(0, 500, standardClusterA) 47 | 48 | standardClusterB := wizard.NewCluster(engine3) 49 | standardClusterB.RegisterSlave(engine4) 50 | shardClusters.RegisterShard(501, 996, standardClusterB) 51 | } 52 | 53 | func _ExampleGet() { 54 | orm := New(w) 55 | 56 | user := &User{ID: 99} 57 | has, err := orm.Get(user, func(s Session) (bool, error) { 58 | return s.Get(user) 59 | }) 60 | if err != nil { 61 | fmt.Printf("error occured, %s", err.Error()) 62 | return 63 | } 64 | 65 | if !has { 66 | fmt.Printf("cannot find the user. id:%d", user.ID) 67 | return 68 | } 69 | fmt.Printf("user found. id:%d, name:%s", user.ID, user.Name) 70 | } 71 | 72 | func _ExampleInsert() { 73 | orm := New(w) 74 | 75 | user := &User{ID: 99, Name: "Adam Smith"} 76 | total, err := orm.Insert(testID, user, func(s Session) (int64, error) { 77 | return s.Insert(testID, user) 78 | }) 79 | if err != nil { 80 | fmt.Printf("error occured, %s", err.Error()) 81 | return 82 | } 83 | if total < 1 { 84 | fmt.Printf("insert failed. 
id:%d", user.ID) 85 | return 86 | } 87 | } 88 | 89 | func _ExampleTransaction() { 90 | var err error 91 | orm := New(w) 92 | 93 | user1 := &User{ID: 1, Name: "Adam Smith"} 94 | user2 := &User{ID: 2, Name: "Benjamin Franklin"} 95 | 96 | s1, _ := orm.Transaction(testID, user1) 97 | s2, _ := orm.Transaction(testID, user2) 98 | 99 | _, err = s1.Insert(user1) 100 | if err != nil { 101 | orm.RollbackAll(testID) 102 | return 103 | } 104 | _, err = s2.Insert(user2) 105 | if err != nil { 106 | orm.RollbackAll(testID) 107 | return 108 | } 109 | orm.CommitAll(testID) 110 | } 111 | 112 | func _ExampleTransactionAuto() { 113 | var err error 114 | orm := New(w) 115 | 116 | user1 := &User{ID: 1, Name: "Adam Smith"} 117 | user2 := &User{ID: 2, Name: "Benjamin Franklin"} 118 | 119 | orm.SetAutoTransaction(testID, true) 120 | _, err = orm.Insert(testID, user1, func(s Session) (int64, error) { 121 | return s.Insert(user1) 122 | }) 123 | if err != nil { 124 | orm.RollbackAll(testID) 125 | return 126 | } 127 | _, err = orm.Insert(testID, user2, func(s Session) (int64, error) { 128 | return s.Insert(user2) 129 | }) 130 | if err != nil { 131 | orm.RollbackAll(testID) 132 | return 133 | } 134 | orm.CommitAll(testID) 135 | } 136 | -------------------------------------------------------------------------------- /orm/xorm/interface.go: -------------------------------------------------------------------------------- 1 | package xorm 2 | 3 | import ( 4 | "database/sql" 5 | "io" 6 | 7 | "github.com/go-xorm/core" 8 | "github.com/go-xorm/xorm" 9 | ) 10 | 11 | // ORM is wrapper interface for wizard.Xorm 12 | type ORM interface { 13 | ReadOnly(Identifier, bool) 14 | IsReadOnly(Identifier) bool 15 | SetAutoTransaction(Identifier, bool) 16 | IsAutoTransaction(Identifier) bool 17 | 18 | Master(interface{}) Engine 19 | MasterByKey(interface{}, interface{}) Engine 20 | Masters(interface{}) []Engine 21 | Slave(interface{}) Engine 22 | SlaveByKey(interface{}, interface{}) Engine 23 | Slaves(interface{}) 
[]Engine 24 | 25 | Get(interface{}, func(Session) (bool, error)) (bool, error) 26 | Find(interface{}, func(Session) error) error 27 | Count(interface{}, func(Session) (int64, error)) (int64, error) 28 | Insert(Identifier, interface{}, func(Session) (int64, error)) (int64, error) 29 | Update(Identifier, interface{}, func(Session) (int64, error)) (int64, error) 30 | FindParallel(interface{}, interface{}, string, ...interface{}) error 31 | FindParallelByCondition(interface{}, FindCondition) error 32 | CountParallelByCondition(interface{}, FindCondition) ([]int64, error) 33 | UpdateParallelByCondition(interface{}, UpdateCondition) (int64, error) 34 | GetUsingMaster(Identifier, interface{}, func(Session) (bool, error)) (bool, error) 35 | FindUsingMaster(Identifier, interface{}, func(Session) error) error 36 | CountUsingMaster(Identifier, interface{}, func(Session) (int64, error)) (int64, error) 37 | 38 | NewMasterSession(interface{}) (Session, error) 39 | 40 | UseMasterSession(Identifier, interface{}) (Session, error) 41 | UseMasterSessionByKey(Identifier, interface{}, interface{}) (Session, error) 42 | UseSlaveSession(Identifier, interface{}) (Session, error) 43 | UseSlaveSessionByKey(Identifier, interface{}, interface{}) (Session, error) 44 | UseAllMasterSessions(Identifier, interface{}) ([]Session, error) 45 | 46 | ForceNewTransaction(interface{}) (Session, error) 47 | Transaction(Identifier, interface{}) (Session, error) 48 | TransactionByKey(Identifier, interface{}, interface{}) (Session, error) 49 | AutoTransaction(Identifier, interface{}, Session) error 50 | CommitAll(Identifier) error 51 | RollbackAll(Identifier) error 52 | CloseAll(Identifier) 53 | } 54 | 55 | // Session is interface for xorm.Session 56 | type Session interface { 57 | xorm.Interface 58 | 59 | And(interface{}, ...interface{}) *xorm.Session 60 | Begin() error 61 | Close() 62 | Commit() error 63 | CreateTable(interface{}) error 64 | DropTable(interface{}) error 65 | ForUpdate() *xorm.Session 66 | 
Having(string) *xorm.Session 67 | Id(interface{}) *xorm.Session 68 | Init() 69 | InsertMulti(interface{}) (int64, error) 70 | LastSQL() (string, []interface{}) 71 | NoAutoTime() *xorm.Session 72 | Nullable(...string) *xorm.Session 73 | Or(interface{}, ...interface{}) *xorm.Session 74 | Rollback() error 75 | Select(string) *xorm.Session 76 | Sql(string, ...interface{}) *xorm.Session 77 | Sync2(...interface{}) error 78 | } 79 | 80 | // Engine is interface for xorm.Engine 81 | type Engine interface { 82 | xorm.EngineInterface 83 | 84 | After(func(interface{})) *xorm.Session 85 | AutoIncrStr() string 86 | Cascade(...bool) *xorm.Session 87 | ClearCacheBean(interface{}, string) error 88 | Close() error 89 | DataSourceName() string 90 | DriverName() string 91 | DumpAll(io.Writer, ...core.DbType) error 92 | GobRegister(interface{}) *xorm.Engine 93 | Having(string) *xorm.Session 94 | Id(interface{}) *xorm.Session 95 | Import(io.Reader) ([]sql.Result, error) 96 | ImportFile(string) ([]sql.Result, error) 97 | NoCache() *xorm.Session 98 | NoCascade() *xorm.Session 99 | Nullable(...string) *xorm.Session 100 | QuoteStr() string 101 | Select(string) *xorm.Session 102 | SetColumnMapper(core.IMapper) 103 | SetDisableGlobalCache(bool) 104 | SetTableMapper(core.IMapper) 105 | Sql(string, ...interface{}) *xorm.Session 106 | SupportInsertMany() bool 107 | } 108 | 109 | var _ Session = &xorm.Session{} 110 | var _ Engine = &xorm.Engine{} 111 | -------------------------------------------------------------------------------- /orm/xorm/interface_test.go: -------------------------------------------------------------------------------- 1 | package xorm 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | 7 | "github.com/go-xorm/xorm" 8 | _ "github.com/mattn/go-sqlite3" 9 | ) 10 | 11 | func TestInterface(t *testing.T) { 12 | wiz := testCreateWizard() 13 | 14 | var orm ORM 15 | orm = New(wiz) 16 | _ = orm 17 | 18 | var e Engine 19 | name := "test_if.db" 20 | e, _ = xorm.NewEngine("sqlite3", name) 21 | 
os.Remove(name) 22 | 23 | var s Session 24 | s = e.NewSession() 25 | _ = s 26 | } 27 | -------------------------------------------------------------------------------- /orm/xorm/xorm.go: -------------------------------------------------------------------------------- 1 | package xorm 2 | 3 | import "github.com/evalphobia/wizard" 4 | 5 | // Xorm manages database sessions for xorm 6 | type Xorm struct { 7 | *XormWizard 8 | *XormFunction 9 | *XormSessionManager 10 | *XormParallel 11 | 12 | Wiz *wizard.Wizard 13 | } 14 | 15 | // New creates initialized *Xorm 16 | func New(wiz *wizard.Wizard) *Xorm { 17 | orm := &Xorm{} 18 | orm.Wiz = wiz 19 | orm.XormFunction = &XormFunction{orm: orm} 20 | orm.XormWizard = &XormWizard{wiz} 21 | orm.XormSessionManager = &XormSessionManager{ 22 | orm: orm, 23 | list: make(map[Identifier]*SessionList), 24 | } 25 | orm.XormParallel = &XormParallel{orm: orm} 26 | return orm 27 | } 28 | -------------------------------------------------------------------------------- /orm/xorm/xorm_function.go: -------------------------------------------------------------------------------- 1 | package xorm 2 | 3 | import ( 4 | "github.com/evalphobia/wizard/errors" 5 | ) 6 | 7 | // XormFunction manages xorm functions 8 | type XormFunction struct { 9 | orm *Xorm 10 | } 11 | 12 | // Get executes xorm.Sessions.Get() in slave db 13 | func (xfn XormFunction) Get(obj interface{}, fn func(Session) (bool, error)) (bool, error) { 14 | db := xfn.orm.Slave(obj) 15 | if db == nil { 16 | return false, errors.NewErrNilDB(NormalizeValue(obj)) 17 | } 18 | return fn(db.NewSession()) 19 | } 20 | 21 | // Find executes xorm.Sessions.Find() in slave db 22 | func (xfn XormFunction) Find(obj interface{}, fn func(Session) error) error { 23 | db := xfn.orm.Slave(obj) 24 | if db == nil { 25 | return errors.NewErrNilDB(NormalizeValue(obj)) 26 | } 27 | return fn(db.NewSession()) 28 | } 29 | 30 | // Count executes xorm.Sessions.Count() in slave db 31 | func (xfn XormFunction) Count(obj 
interface{}, fn func(Session) (int64, error)) (int64, error) { 32 | db := xfn.orm.Slave(obj) 33 | if db == nil { 34 | return 0, errors.NewErrNilDB(NormalizeValue(obj)) 35 | } 36 | return fn(db.NewSession()) 37 | } 38 | 39 | // Insert executes xorm.Sessions.Insert() in master db 40 | func (xfn XormFunction) Insert(id Identifier, obj interface{}, fn func(Session) (int64, error)) (int64, error) { 41 | if xfn.orm.IsReadOnly(id) { 42 | return 0, nil 43 | } 44 | 45 | s, err := xfn.orm.UseMasterSession(id, obj) 46 | if err != nil { 47 | return 0, err 48 | } 49 | return fn(s) 50 | } 51 | 52 | // Update executes xorm.Sessions.Update() in master db 53 | func (xfn XormFunction) Update(id Identifier, obj interface{}, fn func(Session) (int64, error)) (int64, error) { 54 | if xfn.orm.IsReadOnly(id) { 55 | return 0, nil 56 | } 57 | 58 | s, err := xfn.orm.UseMasterSession(id, obj) 59 | if err != nil { 60 | return 0, err 61 | } 62 | return fn(s) 63 | } 64 | 65 | // GetUsingMaster executes xorm.Sessions.Get() in master db 66 | func (xfn XormFunction) GetUsingMaster(id Identifier, obj interface{}, fn func(Session) (bool, error)) (bool, error) { 67 | s, err := xfn.orm.UseMasterSession(id, obj) 68 | if err != nil { 69 | return false, err 70 | } 71 | return fn(s) 72 | } 73 | 74 | // FindUsingMaster executes xorm.Sessions.Find() in master db 75 | func (xfn XormFunction) FindUsingMaster(id Identifier, obj interface{}, fn func(Session) error) error { 76 | s, err := xfn.orm.UseMasterSession(id, obj) 77 | if err != nil { 78 | return err 79 | } 80 | return fn(s) 81 | } 82 | 83 | // CountUsingMaster executes xorm.Sessions.Count() in master db 84 | func (xfn XormFunction) CountUsingMaster(id Identifier, obj interface{}, fn func(Session) (int64, error)) (int64, error) { 85 | s, err := xfn.orm.UseMasterSession(id, obj) 86 | if err != nil { 87 | return 0, err 88 | } 89 | return fn(s) 90 | } 91 | -------------------------------------------------------------------------------- 
/orm/xorm/xorm_function_test.go: -------------------------------------------------------------------------------- 1 | package xorm 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestGet(t *testing.T) { 10 | assert := assert.New(t) 11 | wiz := testCreateWizard() 12 | orm := New(wiz) 13 | 14 | var row interface{} 15 | var has bool 16 | var err error 17 | fn := func(s Session) (bool, error) { 18 | return s.Get(row) 19 | } 20 | 21 | row = &testUser{ID: 1} 22 | has, err = orm.Get(row, fn) 23 | assert.Nil(err) 24 | assert.True(has) 25 | assert.Equal(int64(1), row.(*testUser).ID) 26 | assert.Equal("Adam", row.(*testUser).Name) 27 | 28 | row = &testUser{ID: 501} 29 | has, err = orm.Get(row, fn) 30 | assert.Nil(err) 31 | assert.True(has) 32 | assert.Equal(int64(501), row.(*testUser).ID) 33 | assert.Equal("Betty", row.(*testUser).Name) 34 | 35 | row = &testFoobar{ID: 1} 36 | has, err = orm.Get(row, fn) 37 | assert.Nil(err) 38 | assert.True(has) 39 | assert.Equal(int64(1), row.(*testFoobar).ID) 40 | assert.Equal("foobar#1", row.(*testFoobar).Name) 41 | 42 | row = &testCompany{ID: 2} 43 | has, err = orm.Get(row, fn) 44 | assert.Nil(err) 45 | assert.True(has) 46 | assert.Equal(int64(2), row.(*testCompany).ID) 47 | assert.Equal("BOX", row.(*testCompany).Name) 48 | 49 | // not found 50 | row = &testUser{ID: 4} 51 | has, err = orm.Get(row, fn) 52 | assert.Nil(err) 53 | assert.False(has) 54 | assert.Equal(int64(4), row.(*testUser).ID) 55 | assert.Equal("", row.(*testUser).Name) 56 | 57 | // not found 58 | row = &testUser{ID: 504} 59 | has, err = orm.Get(row, fn) 60 | assert.Nil(err) 61 | assert.False(has) 62 | assert.Equal(int64(504), row.(*testUser).ID) 63 | assert.Equal("", row.(*testUser).Name) 64 | } 65 | 66 | func TestFind(t *testing.T) { 67 | assert := assert.New(t) 68 | wiz := testCreateWizard() 69 | orm := New(wiz) 70 | 71 | var err error 72 | 73 | // user A 74 | var usersA []*testUser 75 | err = orm.Find(testUser{}, func(s 
Session) error { 76 | s.Where("id > 1") 77 | return s.Find(&usersA) 78 | }) 79 | assert.Nil(err) 80 | assert.Len(usersA, 2) 81 | 82 | // user B 83 | var usersB []*testUser 84 | err = orm.Find(testUser{ID: 500}, func(s Session) error { 85 | s.Where("id > 1") 86 | return s.Find(&usersB) 87 | }) 88 | assert.Nil(err) 89 | assert.Len(usersB, 3) 90 | 91 | var foobars []*testFoobar 92 | err = orm.Find(&testFoobar{}, func(s Session) error { 93 | s.And("id = 1") 94 | return s.Find(&foobars) 95 | }) 96 | assert.Nil(err) 97 | assert.Len(foobars, 1) 98 | 99 | var companies []*testCompany 100 | err = orm.Find(&testCompany{}, func(s Session) error { 101 | s.Where("id > 2") 102 | return s.Find(&companies) 103 | }) 104 | assert.Nil(err) 105 | assert.Len(companies, 1) 106 | } 107 | 108 | func TestCount(t *testing.T) { 109 | assert := assert.New(t) 110 | wiz := testCreateWizard() 111 | orm := New(wiz) 112 | 113 | var count int64 114 | var err error 115 | 116 | // user A 117 | count, err = orm.Count(&testUser{ID: 1}, func(s Session) (int64, error) { 118 | s.Where("id > 1") 119 | return s.Count(&testUser{}) 120 | }) 121 | assert.Nil(err) 122 | assert.EqualValues(2, count) 123 | 124 | // user B 125 | count, err = orm.Count(&testUser{ID: 501}, func(s Session) (int64, error) { 126 | s.Where("id > 1") 127 | return s.Count(&testUser{}) 128 | }) 129 | assert.Nil(err) 130 | assert.EqualValues(3, count) 131 | 132 | count, err = orm.Count(&testFoobar{}, func(s Session) (int64, error) { 133 | return s.Count(&testFoobar{ID: 1}) 134 | }) 135 | assert.Nil(err) 136 | assert.EqualValues(1, count) 137 | 138 | count, err = orm.Count(&testCompany{}, func(s Session) (int64, error) { 139 | s.Where("id > 2") 140 | return s.Count(&testCompany{}) 141 | }) 142 | assert.Nil(err) 143 | assert.EqualValues(1, count) 144 | } 145 | 146 | func TestInsert(t *testing.T) { 147 | assert := assert.New(t) 148 | wiz := testCreateWizard() 149 | orm := New(wiz) 150 | 151 | var row interface{} 152 | var affected int64 153 | 
var err error 154 | var success int64 = 1 155 | fn := func(s Session) (int64, error) { 156 | return s.Insert(row) 157 | } 158 | getFn := func(s Session) (bool, error) { 159 | return s.Get(row) 160 | } 161 | countFn := func(table, obj interface{}) int64 { 162 | count, _ := orm.Count(table, func(s Session) (int64, error) { 163 | return s.Count(obj) 164 | }) 165 | return count 166 | } 167 | 168 | testWaitForIO() 169 | 170 | // user A 171 | assert.EqualValues(3, countFn(testUser{ID: 1}, &testUser{})) 172 | row = &testUser{ID: 1000, Name: "Daniel"} 173 | affected, err = orm.Insert(testID, row, fn) 174 | assert.Nil(err) 175 | assert.Equal(success, affected) 176 | assert.EqualValues(4, countFn(testUser{ID: 1}, &testUser{})) 177 | 178 | row = &testUser{ID: 1000} 179 | orm.Get(row, getFn) 180 | assert.Equal("Daniel", row.(*testUser).Name) 181 | 182 | testWaitForIO() 183 | 184 | // user B 185 | assert.EqualValues(3, countFn(testUser{ID: 500}, &testUser{})) 186 | row = &testUser{ID: 1500, Name: "Dorothy"} 187 | affected, err = orm.Insert(testID, row, fn) 188 | assert.Nil(err) 189 | assert.Equal(success, affected) 190 | assert.EqualValues(4, countFn(testUser{ID: 500}, &testUser{})) 191 | 192 | row = &testUser{ID: 1500} 193 | orm.Get(row, getFn) 194 | assert.Equal("Dorothy", row.(*testUser).Name) 195 | 196 | testWaitForIO() 197 | 198 | // foobar 199 | assert.EqualValues(3, countFn(testFoobar{}, &testFoobar{})) 200 | row = &testFoobar{ID: 4, Name: "foobar#4"} 201 | affected, err = orm.Insert(testID, row, fn) 202 | assert.Nil(err) 203 | assert.Equal(success, affected) 204 | assert.EqualValues(4, countFn(testFoobar{}, &testFoobar{})) 205 | 206 | row = &testFoobar{ID: 4} 207 | orm.Get(row, getFn) 208 | assert.Equal("foobar#4", row.(*testFoobar).Name) 209 | 210 | testWaitForIO() 211 | 212 | // other 213 | assert.EqualValues(3, countFn(testCompany{}, &testCompany{})) 214 | row = &testCompany{ID: 4, Name: "Delta Air Lines"} 215 | affected, err = orm.Insert(testID, row, fn) 216 | 
assert.Nil(err) 217 | assert.Equal(success, affected) 218 | assert.EqualValues(4, countFn(testCompany{}, &testCompany{})) 219 | 220 | row = &testCompany{ID: 4} 221 | orm.Get(row, getFn) 222 | assert.Equal("Delta Air Lines", row.(*testCompany).Name) 223 | 224 | testWaitForIO() 225 | 226 | // multiple rows 227 | assert.EqualValues(4, countFn(testCompany{}, &testCompany{})) 228 | rows := []*testCompany{ 229 | {ID: 5, Name: "eureka"}, 230 | {ID: 6, Name: "Facebook"}, 231 | {ID: 7, Name: "Google"}, 232 | } 233 | affected, err = orm.Insert(testID, testCompany{}, func(s Session) (int64, error) { 234 | return s.Insert(&rows) 235 | }) 236 | assert.Nil(err) 237 | assert.Equal(int64(3), affected) 238 | assert.EqualValues(7, countFn(testCompany{}, &testCompany{})) 239 | 240 | testWaitForIO() 241 | 242 | // readonly 243 | orm.ReadOnly(testID, true) 244 | affected, err = orm.Insert(testID, row, fn) 245 | assert.Nil(err) 246 | assert.EqualValues(0, affected) 247 | 248 | initTestDB() 249 | } 250 | 251 | func TestUpdate(t *testing.T) { 252 | assert := assert.New(t) 253 | wiz := testCreateWizard() 254 | orm := New(wiz) 255 | 256 | var row interface{} 257 | var affected int64 258 | var err error 259 | var success int64 = 1 260 | getFn := func(s Session) (bool, error) { 261 | return s.Get(row) 262 | } 263 | 264 | // user A 265 | var user *testUser 266 | user = &testUser{ID: 1, Name: "Akira"} 267 | affected, err = orm.Update(testID, user, func(s Session) (int64, error) { 268 | s.Where("id = ?", user.ID) 269 | return s.Update(user) 270 | }) 271 | assert.Nil(err) 272 | assert.Equal(success, affected) 273 | 274 | row = &testUser{ID: 1} 275 | orm.Get(row, getFn) 276 | assert.Equal("Akira", row.(*testUser).Name) 277 | 278 | // // user B 279 | user = &testUser{ID: 501, Name: "Aiko"} 280 | affected, err = orm.Update(testID, user, func(s Session) (int64, error) { 281 | s.Where("id = ?", user.ID) 282 | return s.Update(user) 283 | }) 284 | assert.Nil(err) 285 | assert.Equal(success, affected) 
286 | 287 | row = &testUser{ID: 501} 288 | orm.Get(row, getFn) 289 | assert.Equal("Aiko", row.(*testUser).Name) 290 | 291 | // foobar 292 | var foobar *testFoobar 293 | foobar = &testFoobar{ID: 1, Name: "foobar#1b"} 294 | affected, err = orm.Update(testID, foobar, func(s Session) (int64, error) { 295 | s.Where("id = ?", foobar.ID) 296 | return s.Update(foobar) 297 | }) 298 | assert.Nil(err) 299 | assert.Equal(success, affected) 300 | 301 | row = &testFoobar{ID: 1} 302 | orm.Get(row, getFn) 303 | assert.Equal("foobar#1b", row.(*testFoobar).Name) 304 | 305 | // other 306 | var company *testCompany 307 | company = &testCompany{ID: 1, Name: "Alibaba"} 308 | affected, err = orm.Update(testID, company, func(s Session) (int64, error) { 309 | s.Where("id = ?", company.ID) 310 | return s.Update(company) 311 | }) 312 | assert.Nil(err) 313 | assert.Equal(success, affected) 314 | 315 | row = &testCompany{ID: 1} 316 | orm.Get(row, getFn) 317 | assert.Equal("Alibaba", row.(*testCompany).Name) 318 | 319 | // multiple rows 320 | foobar = &testFoobar{Name: "foobar#XXX"} 321 | affected, err = orm.Update(testID, foobar, func(s Session) (int64, error) { 322 | return s.Update(foobar) 323 | }) 324 | assert.Nil(err) 325 | assert.Equal(int64(3), affected) 326 | 327 | row = &testFoobar{ID: 1} 328 | orm.Get(row, getFn) 329 | assert.Equal("foobar#XXX", row.(*testFoobar).Name) 330 | row = &testFoobar{ID: 2} 331 | orm.Get(row, getFn) 332 | assert.Equal("foobar#XXX", row.(*testFoobar).Name) 333 | row = &testFoobar{ID: 3} 334 | orm.Get(row, getFn) 335 | assert.Equal("foobar#XXX", row.(*testFoobar).Name) 336 | 337 | // readonly 338 | orm.ReadOnly(testID, true) 339 | affected, err = orm.Update(testID, foobar, func(s Session) (int64, error) { 340 | return s.Update(foobar) 341 | }) 342 | assert.Nil(err) 343 | assert.EqualValues(0, affected) 344 | 345 | initTestDB() 346 | } 347 | 348 | func TestGetUsingMaster(t *testing.T) { 349 | assert := assert.New(t) 350 | wiz := testCreateWizard() 351 | orm := 
New(wiz) 352 | 353 | var row interface{} 354 | var has bool 355 | var err error 356 | fn := func(s Session) (bool, error) { 357 | return s.Get(row) 358 | } 359 | 360 | orm.SetAutoTransaction(testID, true) 361 | 362 | row = &testUser{ID: 4} 363 | s, _ := orm.Transaction(testID, row) 364 | s.Insert(row) 365 | 366 | // slave 367 | has, err = orm.Get(row, fn) 368 | assert.Nil(err) 369 | assert.False(has) 370 | 371 | // master 372 | has, err = orm.GetUsingMaster(testID, row, fn) 373 | assert.Nil(err) 374 | assert.True(has) 375 | 376 | s.Rollback() 377 | } 378 | 379 | func TestFindUsingMaster(t *testing.T) { 380 | assert := assert.New(t) 381 | wiz := testCreateWizard() 382 | orm := New(wiz) 383 | 384 | var foobars []testFoobar 385 | var err error 386 | fn := func(s Session) error { 387 | return s.Find(&foobars) 388 | } 389 | 390 | orm.SetAutoTransaction(testID, true) 391 | 392 | row := &testFoobar{ID: 4, Name: "foobar#4@FindUsingMaster"} 393 | s, _ := orm.Transaction(testID, row) 394 | s.Insert(row) 395 | 396 | // slave 397 | err = orm.Find(testFoobar{}, fn) 398 | assert.Nil(err) 399 | assert.Len(foobars, 3) 400 | 401 | // master 402 | foobars = foobars[:0] 403 | err = orm.FindUsingMaster(testID, testFoobar{}, fn) 404 | assert.Nil(err) 405 | assert.Len(foobars, 4) 406 | 407 | s.Rollback() 408 | } 409 | 410 | func TestCountUsingMaster(t *testing.T) { 411 | assert := assert.New(t) 412 | wiz := testCreateWizard() 413 | orm := New(wiz) 414 | 415 | var foobar = testFoobar{} 416 | var count int64 417 | var err error 418 | fn := func(s Session) (int64, error) { 419 | return s.Count(&foobar) 420 | } 421 | 422 | orm.SetAutoTransaction(testID, true) 423 | 424 | row := &testFoobar{ID: 4, Name: "foobar#4@CountUsingMaster"} 425 | s, _ := orm.Transaction(testID, row) 426 | s.Insert(row) 427 | 428 | count, err = orm.Count(foobar, fn) 429 | assert.Nil(err) 430 | assert.EqualValues(3, count) 431 | 432 | count, err = orm.CountUsingMaster(testID, foobar, fn) 433 | assert.Nil(err) 434 | 
assert.EqualValues(4, count) 435 | 436 | s.Rollback() 437 | } 438 | 439 | func TestFunctionNilDB(t *testing.T) { 440 | assert := assert.New(t) 441 | orm := New(emptyWiz) 442 | 443 | // var nilFoobars []*testFoobar 444 | var has bool 445 | var count, affected int64 446 | var err error 447 | 448 | fnGet := func(s Session) (bool, error) { return true, nil } 449 | fnFind := func(s Session) error { return nil } 450 | fnCount := func(s Session) (int64, error) { return 99, nil } 451 | 452 | // Get 453 | has, err = orm.Get(testFoobar{}, fnGet) 454 | assert.NotNil(err) 455 | assert.False(has) 456 | 457 | // Find 458 | err = orm.Find(testFoobar{}, fnFind) 459 | assert.NotNil(err) 460 | 461 | // Count 462 | count, err = orm.Count(testFoobar{}, fnCount) 463 | assert.NotNil(err) 464 | assert.EqualValues(0, count) 465 | 466 | // GetUsingMaster 467 | has, err = orm.GetUsingMaster(testID, testFoobar{}, fnGet) 468 | assert.NotNil(err) 469 | assert.False(has) 470 | 471 | // FindUsingMaster 472 | err = orm.FindUsingMaster(testID, testFoobar{}, fnFind) 473 | assert.NotNil(err) 474 | 475 | // CountUsingMaster 476 | count, err = orm.CountUsingMaster(testID, testFoobar{}, fnCount) 477 | assert.NotNil(err) 478 | assert.EqualValues(0, count) 479 | 480 | // Insert 481 | affected, err = orm.Insert(testID, testFoobar{}, fnCount) 482 | assert.NotNil(err) 483 | assert.EqualValues(0, affected) 484 | 485 | // nil db 486 | affected, err = orm.Update(testID, testFoobar{}, fnCount) 487 | assert.NotNil(err) 488 | assert.EqualValues(0, affected) 489 | } 490 | -------------------------------------------------------------------------------- /orm/xorm/xorm_parallel.go: -------------------------------------------------------------------------------- 1 | package xorm 2 | 3 | import ( 4 | "reflect" 5 | "strings" 6 | 7 | "github.com/evalphobia/wizard/errors" 8 | ) 9 | 10 | // XormParallel supports concurrent query 11 | type XormParallel struct { 12 | orm *Xorm 13 | } 14 | 15 | // FindParallel executes SELECT 
query to all of the shards 16 | func (xpr *XormParallel) FindParallel(listPtr interface{}, table interface{}, where string, args ...interface{}) error { 17 | cond := NewFindCondition(table) 18 | cond.And(where, args...) 19 | return xpr.FindParallelByCondition(listPtr, cond) 20 | } 21 | 22 | // FindParallelByCondition executes SELECT query to all of the shards with conditions 23 | func (xpr *XormParallel) FindParallelByCondition(listPtr interface{}, cond FindCondition) error { 24 | vt := reflect.TypeOf(listPtr) 25 | if vt.Kind() != reflect.Ptr { 26 | return errors.NewErrArgType("listPtr must be a pointer") 27 | } 28 | elem := vt.Elem() 29 | if elem.Kind() != reflect.Slice && elem.Kind() != reflect.Map { 30 | return errors.NewErrArgType("listPtr must be a pointer of slice or map") 31 | } 32 | 33 | // create session with the condition 34 | sessions := xpr.CreateFindSessions(cond) 35 | length := len(sessions) 36 | 37 | // execute query 38 | var errList []error 39 | results := make(chan reflect.Value, length) 40 | for _, s := range sessions { 41 | list := reflect.New(elem) 42 | go func(s Session, list reflect.Value) { 43 | defer s.Close() 44 | err := s.Find(list.Interface()) 45 | if err != nil { 46 | errList = append(errList, err) 47 | } 48 | results <- list 49 | }(s, list) 50 | } 51 | 52 | // wait for the results 53 | e := reflect.ValueOf(listPtr).Elem() 54 | for i := 0; i < length; i++ { 55 | v := <-results 56 | e.Set(reflect.AppendSlice(e, v.Elem())) 57 | } 58 | if len(errList) > 0 { 59 | return errors.NewErrParallelQuery(errList) 60 | } 61 | 62 | return nil 63 | } 64 | 65 | // CountParallelByCondition executes SELECT COUNT(*) query to all of the shards with conditions 66 | func (xpr *XormParallel) CountParallelByCondition(objPtr interface{}, cond FindCondition) ([]int64, error) { 67 | vt := reflect.TypeOf(objPtr) 68 | if vt.Kind() != reflect.Ptr { 69 | return nil, errors.NewErrArgType("objPtr must be a pointer") 70 | } 71 | 72 | // create session with the condition 
73 | sessions := xpr.CreateFindSessions(cond) 74 | length := len(sessions) 75 | 76 | // execute query 77 | var errList []error 78 | results := make(chan int64, length) 79 | for _, s := range sessions { 80 | go func(s Session) { 81 | defer s.Close() 82 | count, err := s.Count(objPtr) 83 | if err != nil { 84 | errList = append(errList, err) 85 | } 86 | results <- count 87 | }(s) 88 | } 89 | 90 | // wait for the results 91 | var counts []int64 92 | for i := 0; i < length; i++ { 93 | v := <-results 94 | counts = append(counts, v) 95 | } 96 | if len(errList) > 0 { 97 | return counts, errors.NewErrParallelQuery(errList) 98 | } 99 | 100 | return counts, nil 101 | } 102 | 103 | // CreateFindSessions creates new sessions with conditional clause 104 | func (xpr *XormParallel) CreateFindSessions(cond FindCondition) []Session { 105 | var sessions []Session 106 | slaves := xpr.orm.Slaves(cond.Table) 107 | 108 | for _, slave := range slaves { 109 | s := slave.NewSession() 110 | if len(cond.Columns) != 0 { 111 | s.Cols(cond.Columns...) 112 | } 113 | if cond.Selects != "" { 114 | s.Select(cond.Selects) 115 | } 116 | 117 | for _, w := range cond.Where { 118 | s.And(w.Statement, w.Args...) 119 | } 120 | for _, in := range cond.WhereIn { 121 | s.In(in.Statement, in.Args...) 
122 | } 123 | if len(cond.Group) > 0 { 124 | s.GroupBy(strings.Join(cond.Group, ", ")) 125 | } 126 | if len(cond.Havings) > 0 { 127 | s.Having(strings.Join(cond.Havings, " AND ")) 128 | } 129 | for _, o := range cond.OrderBy { 130 | if o.OrderByDesc { 131 | s.Desc(o.Name) 132 | } else { 133 | s.Asc(o.Name) 134 | } 135 | } 136 | if cond.Limit > 0 { 137 | s.Limit(cond.Limit, cond.Offset) 138 | } 139 | sessions = append(sessions, s) 140 | } 141 | return sessions 142 | } 143 | 144 | // UpdateParallelByCondition executes UPDATE query to all of the shards with conditions 145 | func (xpr *XormParallel) UpdateParallelByCondition(objPtr interface{}, cond UpdateCondition) (int64, error) { 146 | // create session with the condition 147 | sessions := xpr.CreateUpdateSessions(cond) 148 | length := len(sessions) 149 | 150 | // execute query 151 | var errList []error 152 | results := make(chan int64, length) 153 | for _, s := range sessions { 154 | go func(s Session, obj interface{}) { 155 | defer s.Close() 156 | count, err := s.Update(obj) 157 | if err != nil { 158 | errList = append(errList, err) 159 | } 160 | results <- count 161 | }(s, objPtr) 162 | } 163 | 164 | // wait for the results 165 | var counts int64 166 | for i := 0; i < length; i++ { 167 | v := <-results 168 | counts += v 169 | } 170 | if len(errList) > 0 { 171 | return counts, errors.NewErrParallelQuery(errList) 172 | } 173 | 174 | return counts, nil 175 | } 176 | 177 | // CreateUpdateSessions creates new sessions with conditional clause for UPDATE query 178 | func (xpr *XormParallel) CreateUpdateSessions(cond UpdateCondition) []Session { 179 | var sessions []Session 180 | masters := xpr.orm.Masters(cond.Table) 181 | for _, master := range masters { 182 | s := master.NewSession() 183 | for _, w := range cond.Where { 184 | s.And(w.Statement, w.Args...) 185 | } 186 | for _, in := range cond.WhereIn { 187 | s.In(in.Statement, in.Args...) 
188 | } 189 | 190 | if cond.AllColumns { 191 | s.AllCols() 192 | } 193 | for _, col := range cond.Columns { 194 | s.Cols(col) 195 | } 196 | for _, col := range cond.MustColumns { 197 | s.MustCols(col) 198 | } 199 | for _, col := range cond.OmitColumns { 200 | s.Omit(col) 201 | } 202 | for _, col := range cond.NullableColumns { 203 | s.Nullable(col) 204 | } 205 | 206 | for _, exp := range cond.Increments { 207 | s.Incr(exp.Statement, exp.Args...) 208 | } 209 | for _, exp := range cond.Decrements { 210 | s.Decr(exp.Statement, exp.Args...) 211 | } 212 | 213 | sessions = append(sessions, s) 214 | } 215 | return sessions 216 | } 217 | -------------------------------------------------------------------------------- /orm/xorm/xorm_parallel_test.go: -------------------------------------------------------------------------------- 1 | package xorm 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestFindParallel(t *testing.T) { 10 | assert := assert.New(t) 11 | wiz := testCreateWizard() 12 | orm := New(wiz) 13 | 14 | var err error 15 | var list []testUser 16 | 17 | err = orm.FindParallel(&list, testUser{}, "id > ? 
", 1) 18 | assert.Nil(err) 19 | assert.Len(list, 5) 20 | assert.Contains(list, testUser{ID: 2, Name: "Benjamin"}) 21 | assert.Contains(list, testUser{ID: 3, Name: "Charles"}) 22 | assert.Contains(list, testUser{ID: 500, Name: "Alice"}) 23 | assert.Contains(list, testUser{ID: 501, Name: "Betty"}) 24 | assert.Contains(list, testUser{ID: 502, Name: "Christina"}) 25 | } 26 | 27 | func TestFindParallelByCondition(t *testing.T) { 28 | assert := assert.New(t) 29 | wiz := testCreateWizard() 30 | orm := New(wiz) 31 | 32 | var err error 33 | var list []testUser 34 | 35 | // order by asc 36 | cond := NewFindCondition(testUser{}) 37 | cond.And("id > ?", 1) 38 | cond.SetLimit(1) 39 | cond.OrderByAsc("id") 40 | 41 | err = orm.FindParallelByCondition(&list, cond) 42 | assert.Nil(err) 43 | assert.Len(list, 2) 44 | assert.Contains(list, testUser{ID: 2, Name: "Benjamin"}) 45 | assert.Contains(list, testUser{ID: 500, Name: "Alice"}) 46 | 47 | // order by desc limit 1 48 | list = []testUser{} 49 | cond = NewFindCondition(testUser{}) 50 | cond.And("id > ?", 1) 51 | cond.SetLimit(1) 52 | cond.OrderByDesc("id") 53 | 54 | err = orm.FindParallelByCondition(&list, cond) 55 | assert.Nil(err) 56 | assert.Len(list, 2) 57 | assert.Contains(list, testUser{ID: 3, Name: "Charles"}) 58 | assert.Contains(list, testUser{ID: 502, Name: "Christina"}) 59 | 60 | // order by desc limit 1 offset 1 61 | list = []testUser{} 62 | cond = NewFindCondition(testUser{}) 63 | cond.And("id > ?", 1) 64 | cond.SetLimit(1) 65 | cond.OrderByDesc("id") 66 | cond.SetOffset(1) 67 | err = orm.FindParallelByCondition(&list, cond) 68 | assert.Nil(err) 69 | assert.Len(list, 2) 70 | assert.Contains(list, testUser{ID: 2, Name: "Benjamin"}) 71 | assert.Contains(list, testUser{ID: 501, Name: "Betty"}) 72 | } 73 | 74 | func TestCountParallelByCondition(t *testing.T) { 75 | assert := assert.New(t) 76 | wiz := testCreateWizard() 77 | orm := New(wiz) 78 | 79 | var err error 80 | var testObj testUser 81 | 82 | // order by asc 83 | cond 
:= NewFindCondition(testUser{}) 84 | cond.And("id > ?", 1) 85 | 86 | counts, err := orm.CountParallelByCondition(&testObj, cond) 87 | assert.Nil(err) 88 | assert.Len(counts, 2) 89 | assert.Contains(counts, int64(3)) 90 | assert.Contains(counts, int64(2)) 91 | } 92 | -------------------------------------------------------------------------------- /orm/xorm/xorm_session_list.go: -------------------------------------------------------------------------------- 1 | package xorm 2 | 3 | import "sync" 4 | 5 | // SessionList contains db sessions list for one group 6 | type SessionList struct { 7 | readOnly bool 8 | autoTx bool 9 | 10 | sessMu sync.RWMutex 11 | sessions map[interface{}]Session 12 | 13 | txMu sync.RWMutex 14 | transactions map[interface{}]Session 15 | } 16 | 17 | func newSessionList() *SessionList { 18 | return &SessionList{ 19 | sessions: make(map[interface{}]Session), 20 | transactions: make(map[interface{}]Session), 21 | } 22 | } 23 | 24 | func (l *SessionList) hasSession(db interface{}) bool { 25 | l.sessMu.RLock() 26 | defer l.sessMu.RUnlock() 27 | _, ok := l.sessions[db] 28 | return ok 29 | } 30 | 31 | func (l *SessionList) getSession(db interface{}) Session { 32 | l.sessMu.RLock() 33 | defer l.sessMu.RUnlock() 34 | return l.sessions[db] 35 | } 36 | 37 | func (l *SessionList) addSession(db interface{}, s Session) { 38 | l.sessMu.Lock() 39 | defer l.sessMu.Unlock() 40 | l.sessions[db] = s 41 | } 42 | 43 | func (l *SessionList) getSessions() map[interface{}]Session { 44 | return l.sessions 45 | } 46 | 47 | func (l *SessionList) clearSessions() { 48 | l.sessMu.Lock() 49 | defer l.sessMu.Unlock() 50 | l.sessions = make(map[interface{}]Session) 51 | } 52 | 53 | func (l *SessionList) getTransaction(db interface{}) Session { 54 | l.txMu.RLock() 55 | defer l.txMu.RUnlock() 56 | return l.transactions[db] 57 | } 58 | 59 | func (l *SessionList) addTransaction(db interface{}, s Session) { 60 | l.txMu.Lock() 61 | defer l.txMu.Unlock() 62 | l.transactions[db] = s 63 
| } 64 | 65 | func (l *SessionList) getTransactions() map[interface{}]Session { 66 | return l.transactions 67 | } 68 | 69 | func (l *SessionList) clearTransactions() { 70 | l.txMu.Lock() 71 | defer l.txMu.Unlock() 72 | l.transactions = make(map[interface{}]Session) 73 | } 74 | 75 | // ReadOnly set write proof flag 76 | func (l *SessionList) ReadOnly(b bool) { 77 | l.readOnly = b 78 | } 79 | 80 | // IsReadOnly checks in write proof mode or not 81 | func (l *SessionList) IsReadOnly() bool { 82 | return l.readOnly 83 | } 84 | 85 | // SetAutoTransaction sets auto transaction flag 86 | func (l *SessionList) SetAutoTransaction(b bool) { 87 | l.autoTx = b 88 | } 89 | 90 | // IsAutoTransaction checks in auto transaction mode or not 91 | func (l *SessionList) IsAutoTransaction() bool { 92 | return l.autoTx 93 | } 94 | -------------------------------------------------------------------------------- /orm/xorm/xorm_session_manager.go: -------------------------------------------------------------------------------- 1 | package xorm 2 | 3 | import ( 4 | "sync" 5 | 6 | "github.com/evalphobia/wizard/errors" 7 | ) 8 | 9 | // XormSessionManager manages database session list for xorm 10 | type XormSessionManager struct { 11 | orm *Xorm 12 | listMu sync.RWMutex 13 | list map[Identifier]*SessionList 14 | } 15 | 16 | // Identifier is unique object for using same sessions 17 | // e.g. *http.Request, context.Context, etc... 
18 | type Identifier interface{} 19 | 20 | func newSession(db Engine, obj interface{}) (Session, error) { 21 | if db == nil { 22 | return nil, errors.NewErrNilDB(NormalizeValue(obj)) 23 | } 24 | return db.NewSession(), nil 25 | } 26 | 27 | // SetAutoTransaction sets auto transaction flag of the SessionList 28 | func (xse *XormSessionManager) SetAutoTransaction(id Identifier, b bool) { 29 | sl := xse.getOrCreateSessionList(id) 30 | sl.SetAutoTransaction(b) 31 | } 32 | 33 | // IsAutoTransaction checks auto transaction flag of the SessionList 34 | func (xse *XormSessionManager) IsAutoTransaction(id Identifier) bool { 35 | sl := xse.getOrCreateSessionList(id) 36 | return sl.IsAutoTransaction() 37 | } 38 | 39 | // ReadOnly changes readonly flag of the SessionList 40 | func (xse *XormSessionManager) ReadOnly(id Identifier, b bool) { 41 | sl := xse.getOrCreateSessionList(id) 42 | sl.ReadOnly(b) 43 | } 44 | 45 | // IsReadOnly returns readonly flag of the SesionList 46 | func (xse *XormSessionManager) IsReadOnly(id Identifier) bool { 47 | sl := xse.getOrCreateSessionList(id) 48 | return sl.IsReadOnly() 49 | } 50 | 51 | // NewMasterSession returns new master session for the db of given object 52 | func (xse *XormSessionManager) NewMasterSession(obj interface{}) (Session, error) { 53 | db := xse.orm.Master(obj) 54 | if db == nil { 55 | return nil, errors.NewErrNilDB(NormalizeValue(obj)) 56 | } 57 | 58 | return db.NewSession(), nil 59 | } 60 | 61 | // UseMasterSession returns new master session for the db of given object 62 | func (xse *XormSessionManager) UseMasterSession(id Identifier, obj interface{}) (Session, error) { 63 | db := xse.orm.Master(obj) 64 | sl := xse.getOrCreateSessionList(id) 65 | if sl.IsAutoTransaction() { 66 | return xse.transaction(id, obj, db) 67 | } 68 | return xse.session(id, obj, db) 69 | } 70 | 71 | // UseMasterSessionByKey returns new master session by shard key 72 | func (xse *XormSessionManager) UseMasterSessionByKey(id Identifier, obj 
interface{}, key interface{}) (Session, error) { 73 | db := xse.orm.MasterByKey(obj, key) 74 | sl := xse.getOrCreateSessionList(id) 75 | if sl.IsAutoTransaction() { 76 | return xse.transaction(id, obj, db) 77 | } 78 | return xse.session(id, obj, db) 79 | } 80 | 81 | // UseAllMasterSessions returns all of master sessions for the db of given object 82 | func (xse *XormSessionManager) UseAllMasterSessions(id Identifier, obj interface{}) ([]Session, error) { 83 | dbs := xse.orm.Masters(obj) 84 | 85 | var sessions []Session 86 | var errList []error 87 | for _, db := range dbs { 88 | var s Session 89 | var err error 90 | 91 | switch { 92 | // case xse.orm.IsAutoTransaction(): 93 | // s, err = xse.orm.transaction(obj, db) 94 | default: 95 | s, err = xse.session(id, obj, db) 96 | } 97 | 98 | if err != nil { 99 | errList = append(errList, err) 100 | continue 101 | } 102 | sessions = append(sessions, s) 103 | } 104 | 105 | if len(errList) > 0 { 106 | return sessions, errors.NewErrNilDBs(errList) 107 | } 108 | return sessions, nil 109 | } 110 | 111 | // UseSlaveSession returns new slave session for the slave db of given object 112 | func (xse *XormSessionManager) UseSlaveSession(id Identifier, obj interface{}) (Session, error) { 113 | db := xse.orm.Slave(obj) 114 | return xse.session(id, obj, db) 115 | } 116 | 117 | // UseSlaveSessionByKey returns new slave session by shard key 118 | func (xse *XormSessionManager) UseSlaveSessionByKey(id Identifier, obj interface{}, key interface{}) (Session, error) { 119 | db := xse.orm.SlaveByKey(obj, key) 120 | return xse.session(id, obj, db) 121 | } 122 | 123 | // session returns the session for the db of given object 124 | // if old session exists for the object, return it, 125 | // if no session exists for the object, create new one and return it 126 | func (xse *XormSessionManager) session(id Identifier, obj interface{}, db Engine) (Session, error) { 127 | if db == nil { 128 | return nil, errors.NewErrNilDB(NormalizeValue(obj)) 129 | } 
130 | // use old session 131 | 132 | s := xse.getSessionFromList(id, db) 133 | if s != nil { 134 | return s, nil 135 | } 136 | 137 | // create new session 138 | s = db.NewSession() 139 | xse.addSessionIntoList(id, db, s) 140 | return s, nil 141 | } 142 | 143 | // getSessionFromList returns the session for the db 144 | func (xse *XormSessionManager) getSessionFromList(id Identifier, db interface{}) Session { 145 | if !xse.hasSessionList(id) { 146 | return nil 147 | } 148 | 149 | sl := xse.getOrCreateSessionList(id) 150 | return sl.getSession(db) 151 | } 152 | 153 | // addSessionIntoList saves the session for the db 154 | func (xse *XormSessionManager) addSessionIntoList(id Identifier, db interface{}, s Session) { 155 | sl := xse.getOrCreateSessionList(id) 156 | sl.addSession(db, s) 157 | } 158 | 159 | // CloseAll closes all of sessions and engines 160 | func (xse *XormSessionManager) CloseAll(id Identifier) { 161 | sl := xse.getOrCreateSessionList(id) 162 | 163 | for _, s := range sl.getSessions() { 164 | s.Close() 165 | } 166 | for _, s := range sl.getTransactions() { 167 | s.Close() 168 | } 169 | sl.clearSessions() 170 | sl.clearTransactions() 171 | 172 | xse.listMu.Lock() 173 | defer xse.listMu.Unlock() 174 | delete(xse.list, id) 175 | } 176 | 177 | func (xse *XormSessionManager) newSessionList(id Identifier) *SessionList { 178 | xse.listMu.Lock() 179 | defer xse.listMu.Unlock() 180 | 181 | if xse.list == nil { 182 | xse.list = make(map[Identifier]*SessionList) 183 | } 184 | xse.list[id] = newSessionList() 185 | return xse.list[id] 186 | } 187 | 188 | func (xse *XormSessionManager) hasSessionList(id Identifier) bool { 189 | xse.listMu.RLock() 190 | defer xse.listMu.RUnlock() 191 | 192 | _, ok := xse.list[id] 193 | return ok 194 | } 195 | 196 | func (xse *XormSessionManager) getOrCreateSessionList(id Identifier) *SessionList { 197 | if !xse.hasSessionList(id) { 198 | xse.newSessionList(id) 199 | } 200 | 201 | xse.listMu.RLock() 202 | defer xse.listMu.RUnlock() 203 
| return xse.list[id] 204 | } 205 | -------------------------------------------------------------------------------- /orm/xorm/xorm_session_manager_test.go: -------------------------------------------------------------------------------- 1 | package xorm 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | var testID = "my unique identifier" 10 | 11 | func TestNewMasterSession(t *testing.T) { 12 | assert := assert.New(t) 13 | wiz := testCreateWizard() 14 | orm := New(wiz) 15 | 16 | var row interface{} 17 | var s Session 18 | var has bool 19 | var err error 20 | 21 | // A 22 | row = &testUser{ID: 2} 23 | s, err = orm.NewMasterSession(row) 24 | assert.Nil(err) 25 | assert.NotNil(s) 26 | 27 | has, err = s.Get(row) 28 | assert.Nil(err) 29 | assert.True(has) 30 | 31 | // B 32 | row = &testUser{ID: 500} 33 | s, err = orm.NewMasterSession(row) 34 | assert.Nil(err) 35 | assert.NotNil(s) 36 | 37 | has, err = s.Get(row) 38 | assert.Nil(err) 39 | assert.True(has) 40 | } 41 | 42 | func TestReadOnly(t *testing.T) { 43 | assert := assert.New(t) 44 | wiz := testCreateWizard() 45 | orm := New(wiz) 46 | sl := orm.XormSessionManager.getOrCreateSessionList(testID) 47 | 48 | assert.False(sl.readOnly) 49 | orm.ReadOnly(testID, true) 50 | assert.True(sl.readOnly) 51 | orm.ReadOnly(testID, false) 52 | assert.False(sl.readOnly) 53 | } 54 | 55 | func TestIsReadOnly(t *testing.T) { 56 | assert := assert.New(t) 57 | wiz := testCreateWizard() 58 | orm := New(wiz) 59 | 60 | assert.False(orm.IsReadOnly(testID)) 61 | orm.ReadOnly(testID, true) 62 | assert.True(orm.IsReadOnly(testID)) 63 | orm.ReadOnly(testID, false) 64 | assert.False(orm.IsReadOnly(testID)) 65 | } 66 | 67 | func TestSetAutoTransaction(t *testing.T) { 68 | assert := assert.New(t) 69 | wiz := testCreateWizard() 70 | orm := New(wiz) 71 | sl := orm.XormSessionManager.getOrCreateSessionList(testID) 72 | 73 | assert.False(sl.autoTx) 74 | orm.SetAutoTransaction(testID, true) 75 | 
assert.True(sl.autoTx) 76 | orm.SetAutoTransaction(testID, false) 77 | assert.False(sl.autoTx) 78 | } 79 | 80 | func TestSetIsAutoTransaction(t *testing.T) { 81 | assert := assert.New(t) 82 | wiz := testCreateWizard() 83 | orm := New(wiz) 84 | 85 | assert.False(orm.IsAutoTransaction(testID)) 86 | orm.SetAutoTransaction(testID, true) 87 | assert.True(orm.IsAutoTransaction(testID)) 88 | orm.SetAutoTransaction(testID, false) 89 | assert.False(orm.IsAutoTransaction(testID)) 90 | } 91 | 92 | func TestUseMasterSession(t *testing.T) { 93 | assert := assert.New(t) 94 | wiz := testCreateWizard() 95 | orm := New(wiz) 96 | 97 | var row interface{} 98 | var s Session 99 | var has bool 100 | var err error 101 | 102 | // A 103 | row = &testUser{ID: 2} 104 | s, err = orm.UseMasterSession(testID, row) 105 | assert.Nil(err) 106 | assert.NotNil(s) 107 | 108 | has, err = s.Get(row) 109 | assert.Nil(err) 110 | assert.True(has) 111 | 112 | // B 113 | row = &testUser{ID: 500} 114 | s, err = orm.UseMasterSession(testID, row) 115 | assert.Nil(err) 116 | assert.NotNil(s) 117 | 118 | has, err = s.Get(row) 119 | assert.Nil(err) 120 | assert.True(has) 121 | 122 | // auto tx 123 | orm.SetAutoTransaction(testID, true) 124 | s, err = orm.UseMasterSession(testID, testUser{ID: 1}) 125 | s.Insert(testUser{ID: 4}) 126 | count, _ := s.Count(testUser{}) 127 | assert.EqualValues(4, count) 128 | 129 | s.Rollback() 130 | s.Init() 131 | count, _ = s.Count(testUser{}) 132 | assert.EqualValues(3, count) 133 | } 134 | 135 | func TestUseMasterSessionByKey(t *testing.T) { 136 | assert := assert.New(t) 137 | wiz := testCreateWizard() 138 | orm := New(wiz) 139 | 140 | var row *testUser 141 | var s Session 142 | var has bool 143 | var err error 144 | 145 | // A 146 | row = &testUser{} 147 | s, err = orm.UseMasterSessionByKey(testID, row, 1) 148 | assert.Nil(err) 149 | assert.NotNil(s) 150 | 151 | row.ID = 2 152 | has, err = s.Get(row) 153 | assert.Nil(err) 154 | assert.True(has) 155 | 156 | // B 157 | row = 
&testUser{} 158 | s, err = orm.UseMasterSessionByKey(testID, row, 900) 159 | assert.Nil(err) 160 | assert.NotNil(s) 161 | 162 | row.ID = 500 163 | has, err = s.Get(row) 164 | assert.Nil(err) 165 | assert.True(has) 166 | 167 | // auto tx 168 | orm.SetAutoTransaction(testID, true) 169 | s, err = orm.UseMasterSessionByKey(testID, testUser{}, 1) 170 | s.Insert(testUser{ID: 4}) 171 | count, _ := s.Count(testUser{}) 172 | assert.EqualValues(4, count) 173 | 174 | s.Rollback() 175 | s.Init() 176 | count, _ = s.Count(testUser{}) 177 | assert.EqualValues(3, count) 178 | } 179 | 180 | func TestUseSlaveSession(t *testing.T) { 181 | assert := assert.New(t) 182 | wiz := testCreateWizard() 183 | orm := New(wiz) 184 | 185 | var row interface{} 186 | var s Session 187 | var has bool 188 | var err error 189 | 190 | // A 191 | row = &testUser{ID: 2} 192 | s, err = orm.UseSlaveSession(testID, row) 193 | assert.Nil(err) 194 | assert.NotNil(s) 195 | 196 | has, err = s.Get(row) 197 | assert.Nil(err) 198 | assert.True(has) 199 | 200 | // B 201 | row = &testUser{ID: 500} 202 | s, err = orm.UseSlaveSession(testID, row) 203 | assert.Nil(err) 204 | assert.NotNil(s) 205 | 206 | has, err = s.Get(row) 207 | assert.Nil(err) 208 | assert.True(has) 209 | } 210 | 211 | func TestUseSlaveSessionByKey(t *testing.T) { 212 | assert := assert.New(t) 213 | wiz := testCreateWizard() 214 | orm := New(wiz) 215 | 216 | var row *testUser 217 | var s Session 218 | var has bool 219 | var err error 220 | 221 | // A 222 | row = &testUser{} 223 | s, err = orm.UseSlaveSessionByKey(testID, row, 1) 224 | assert.Nil(err) 225 | assert.NotNil(s) 226 | 227 | row.ID = 2 228 | has, err = s.Get(row) 229 | assert.Nil(err) 230 | assert.True(has) 231 | 232 | // B 233 | row = &testUser{} 234 | s, err = orm.UseSlaveSessionByKey(testID, row, 900) 235 | assert.Nil(err) 236 | assert.NotNil(s) 237 | 238 | row.ID = 500 239 | has, err = s.Get(row) 240 | assert.Nil(err) 241 | assert.True(has) 242 | } 243 | 
-------------------------------------------------------------------------------- /orm/xorm/xorm_session_manager_tx.go: -------------------------------------------------------------------------------- 1 | package xorm 2 | 3 | import ( 4 | "github.com/evalphobia/wizard/errors" 5 | ) 6 | 7 | // ForceNewTransaction returns the session with new transaction 8 | func (xse *XormSessionManager) ForceNewTransaction(obj interface{}) (Session, error) { 9 | db := xse.orm.Master(obj) 10 | s, err := newSession(db, obj) 11 | if err != nil { 12 | return nil, err 13 | } 14 | err = s.Begin() 15 | if err != nil { 16 | return nil, err 17 | } 18 | return s, nil 19 | } 20 | 21 | // Transaction returns the session with transaction for the db of given object 22 | func (xse *XormSessionManager) Transaction(id Identifier, obj interface{}) (Session, error) { 23 | db := xse.orm.Master(obj) 24 | return xse.transaction(id, obj, db) 25 | } 26 | 27 | // TransactionByKey returns the session with transaction by shard key 28 | func (xse *XormSessionManager) TransactionByKey(id Identifier, obj interface{}, key interface{}) (Session, error) { 29 | db := xse.orm.MasterByKey(obj, key) 30 | return xse.transaction(id, obj, db) 31 | } 32 | 33 | // transaction returns the session with transaction for the db of given object 34 | // if old transaction exists for the object, return it, 35 | // if no transaction exists for the object, create new one and return it 36 | func (xse *XormSessionManager) transaction(id Identifier, obj interface{}, db Engine) (Session, error) { 37 | if db == nil { 38 | return nil, errors.NewErrNilDB(NormalizeValue(obj)) 39 | } 40 | // use old transaction 41 | s := xse.getTransactionFromList(id, db) 42 | if s != nil { 43 | return s, nil 44 | } 45 | 46 | // create new transaction 47 | s = db.NewSession() 48 | err := s.Begin() 49 | if err != nil { 50 | return nil, err 51 | } 52 | 53 | // save created session with transaction 54 | xse.addTransactionIntoList(id, db, s) 55 | return s, nil 56 
| } 57 | 58 | // getTransactionFromList returns the session with transaction for the db 59 | func (xse *XormSessionManager) getTransactionFromList(id Identifier, db interface{}) Session { 60 | if !xse.hasSessionList(id) { 61 | return nil 62 | } 63 | sl := xse.getOrCreateSessionList(id) 64 | return sl.getTransaction(db) 65 | } 66 | 67 | // addTransactionIntoList saves the session with transaction for the db 68 | func (xse *XormSessionManager) addTransactionIntoList(id Identifier, db interface{}, s Session) { 69 | sl := xse.getOrCreateSessionList(id) 70 | sl.addTransaction(db, s) 71 | } 72 | 73 | // AutoTransaction starts transaction for the session and store it 74 | // if not in the AutoTransaction mode, nothing happens 75 | // if old transaction exists, return it 76 | func (xse *XormSessionManager) AutoTransaction(id Identifier, obj interface{}, s Session) error { 77 | sl := xse.getOrCreateSessionList(id) 78 | if !sl.IsAutoTransaction() { 79 | return nil 80 | } 81 | db := xse.orm.Master(obj) 82 | oldTx := sl.getTransaction(db) 83 | switch { 84 | case oldTx == s: 85 | return nil 86 | case oldTx != nil: 87 | return errors.NewErrAnotherTx(NormalizeValue(obj)) 88 | } 89 | 90 | err := s.Begin() 91 | if err != nil { 92 | return err 93 | } 94 | 95 | sl.addTransaction(db, s) 96 | return nil 97 | } 98 | 99 | // CommitAll commits all of transactions 100 | func (xse *XormSessionManager) CommitAll(id Identifier) error { 101 | if !xse.hasSessionList(id) { 102 | return nil 103 | } 104 | 105 | sl := xse.getOrCreateSessionList(id) 106 | switch { 107 | case sl == nil: 108 | return nil 109 | case sl.IsReadOnly(): 110 | return nil 111 | } 112 | 113 | var errList []error 114 | for _, s := range sl.getTransactions() { 115 | err := s.Commit() 116 | if err != nil { 117 | errList = append(errList, err) 118 | } 119 | s.Init() 120 | } 121 | 122 | sl.clearTransactions() 123 | if len(errList) > 0 { 124 | return errors.NewErrCommitAll(errList) 125 | } 126 | return nil 127 | } 128 | 129 | // 
RollbackAll aborts all of transactions 130 | func (xse *XormSessionManager) RollbackAll(id Identifier) error { 131 | if !xse.hasSessionList(id) { 132 | return nil 133 | } 134 | 135 | sl := xse.getOrCreateSessionList(id) 136 | switch { 137 | case sl == nil: 138 | return nil 139 | case sl.IsReadOnly(): 140 | return nil 141 | } 142 | 143 | var errList []error 144 | for _, s := range sl.getTransactions() { 145 | err := s.Rollback() 146 | if err != nil { 147 | errList = append(errList, err) 148 | } 149 | s.Init() 150 | } 151 | 152 | sl.clearTransactions() 153 | if len(errList) > 0 { 154 | return errors.NewErrRollbackAll(errList) 155 | } 156 | return nil 157 | } 158 | -------------------------------------------------------------------------------- /orm/xorm/xorm_session_manager_tx_test.go: -------------------------------------------------------------------------------- 1 | package xorm 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestForceNewTransaction(t *testing.T) { 10 | assert := assert.New(t) 11 | wiz := testCreateWizard() 12 | orm := New(wiz) 13 | 14 | s, err := orm.ForceNewTransaction(testUser{ID: 1}) 15 | assert.Nil(err) 16 | assert.NotNil(s) 17 | 18 | assert.EqualValues(3, countUserBySession(s), "initial users count") 19 | 20 | s.Insert(&testUser{ID: 4}) 21 | assert.EqualValues(4, countUserBySession(s), "users count after insert in the transaction") 22 | assert.EqualValues(3, countUserMaster(orm), "users count after insert not in the transaction") 23 | 24 | err = s.Rollback() 25 | assert.Nil(err) 26 | 27 | s.Init() 28 | assert.EqualValues(3, countUserBySession(s), "users count after rollback") 29 | } 30 | 31 | func TestTransaction(t *testing.T) { 32 | assert := assert.New(t) 33 | wiz := testCreateWizard() 34 | orm := New(wiz) 35 | xsm := orm.XormSessionManager 36 | sl := xsm.getOrCreateSessionList(testID) 37 | 38 | assert.Len(sl.getTransactions(), 0) 39 | s, err := orm.Transaction(testID, testUser{ID: 1}) 40 | 
assert.Nil(err) 41 | assert.NotNil(s) 42 | assert.Len(sl.getTransactions(), 1, "transaction is added") 43 | 44 | assert.EqualValues(3, countUserBySession(s), "initial users count") 45 | 46 | s.Insert(&testUser{ID: 4}) 47 | assert.EqualValues(4, countUserBySession(s), "users count after insert in the transaction") 48 | assert.EqualValues(3, countUserMaster(orm), "users count after insert not in the transaction") 49 | 50 | err = s.Rollback() 51 | assert.Nil(err) 52 | 53 | s.Init() 54 | assert.EqualValues(3, countUserBySession(s), "users count after rollback") 55 | } 56 | 57 | func TestTransactionByKey(t *testing.T) { 58 | assert := assert.New(t) 59 | wiz := testCreateWizard() 60 | orm := New(wiz) 61 | xsm := orm.XormSessionManager 62 | sl := xsm.getOrCreateSessionList(testID) 63 | 64 | assert.Len(sl.getTransactions(), 0) 65 | s, err := orm.TransactionByKey(testID, testUser{}, 1) 66 | assert.Nil(err) 67 | assert.NotNil(s) 68 | assert.Len(sl.getTransactions(), 1, "transaction is added") 69 | 70 | assert.EqualValues(3, countUserBySession(s), "initial users count") 71 | 72 | s.Insert(&testUser{ID: 4}) 73 | assert.EqualValues(4, countUserBySession(s), "users count after insert in the transaction") 74 | assert.EqualValues(3, countUserMaster(orm), "users count after insert not in the transaction") 75 | 76 | err = s.Rollback() 77 | assert.Nil(err) 78 | 79 | s.Init() 80 | assert.EqualValues(3, countUserBySession(s), "users count after rollback") 81 | } 82 | 83 | func TestAutoTransaction(t *testing.T) { 84 | assert := assert.New(t) 85 | wiz := testCreateWizard() 86 | orm := New(wiz) 87 | xsm := orm.XormSessionManager 88 | sl := xsm.getOrCreateSessionList(testID) 89 | assert.Len(sl.getTransactions(), 0) 90 | 91 | user1 := testUser{ID: 1} 92 | 93 | s, _ := orm.NewMasterSession(user1) 94 | 95 | err := orm.AutoTransaction(testID, user1, s) 96 | assert.Nil(err) 97 | assert.Len(sl.getTransactions(), 0, "transaction is not added") 98 | 99 | orm.SetAutoTransaction(testID, true) 100 | 
err = orm.AutoTransaction(testID, user1, s) 101 | assert.Nil(err) 102 | assert.Len(sl.getTransactions(), 1, "transaction is added") 103 | 104 | assert.EqualValues(3, countUserBySession(s), "initial users count") 105 | s.Insert(&testUser{ID: 4}) 106 | assert.EqualValues(4, countUserBySession(s), "users count after insert in the transaction") 107 | assert.EqualValues(4, countUserMaster(orm), "users count after insert in the transaction") 108 | 109 | s2, _ := newSession(orm.Master(user1), user1) 110 | assert.EqualValues(3, countUserBySession(s2), "users count after insert in another session") 111 | 112 | err = s.Rollback() 113 | assert.Nil(err) 114 | s.Init() 115 | assert.EqualValues(3, countUserBySession(s), "users count after rollback") 116 | 117 | } 118 | 119 | func TestAutoTransactionDuplicateTx(t *testing.T) { 120 | assert := assert.New(t) 121 | wiz := testCreateWizard() 122 | orm := New(wiz) 123 | xsm := orm.XormSessionManager 124 | sl := xsm.getOrCreateSessionList(testID) 125 | assert.Len(sl.getTransactions(), 0) 126 | 127 | var err error 128 | 129 | orm.SetAutoTransaction(testID, true) 130 | s1, _ := orm.NewMasterSession(testUser{ID: 1}) 131 | s2, _ := orm.NewMasterSession(testUser{ID: 500}) 132 | sl.addTransaction(orm.Master(testUser{ID: 1}), s1) 133 | 134 | err = orm.AutoTransaction(testID, testUser{ID: 1}, s1) 135 | assert.Nil(err, "error does not occur if same session exists") 136 | 137 | err = orm.AutoTransaction(testID, testUser{ID: 1}, s2) 138 | assert.NotNil(err, "error occurs if another session exists") 139 | } 140 | 141 | func TestCommitAll(t *testing.T) { 142 | assert := assert.New(t) 143 | wiz := testCreateWizard() 144 | orm := New(wiz) 145 | xsm := orm.XormSessionManager 146 | sl := xsm.getOrCreateSessionList(testID) 147 | assert.Len(sl.getTransactions(), 0) 148 | 149 | user1 := testUser{ID: 1} 150 | user500 := testUser{ID: 500} 151 | 152 | s1, _ := orm.NewMasterSession(user1) 153 | s2, _ := orm.NewMasterSession(user500) 154 | 155 | 
orm.SetAutoTransaction(testID, true) 156 | orm.AutoTransaction(testID, user1, s1) 157 | orm.AutoTransaction(testID, user500, s2) 158 | assert.Len(sl.getTransactions(), 2, "transaction is added") 159 | 160 | assert.EqualValues(3, countUserBySession(s1), "initial users count") 161 | assert.EqualValues(3, countUserBySession(s2), "initial users count") 162 | 163 | num, err := s1.Insert(&testUser{ID: 4}) 164 | assert.Nil(err) 165 | assert.EqualValues(1, num) 166 | s2.Insert(&testUser{ID: 504}) 167 | assert.EqualValues(4, countUserBySession(s1), "users count after insert in the transaction") 168 | assert.EqualValues(4, countUserBySession(s2), "users count after insert in the transaction") 169 | assert.EqualValues(4, countUserMaster(orm), "users count after insert in the transaction") 170 | assert.EqualValues(4, countUserMasterB(orm), "users count after insert in the transaction") 171 | 172 | orm.SetAutoTransaction(testID, false) 173 | s1b, _ := newSession(orm.Master(user1), user1) 174 | s2b, _ := newSession(orm.Master(user500), user500) 175 | assert.EqualValues(3, countUserBySession(s1b), "users count after insert in another session") 176 | assert.EqualValues(3, countUserBySession(s2b), "users count after insert in another session") 177 | 178 | orm.ReadOnly(testID, true) 179 | err = orm.CommitAll(testID) 180 | assert.Nil(err) 181 | assert.Len(sl.getTransactions(), 2, "transaction is not removed when readonly") 182 | 183 | orm.ReadOnly(testID, false) 184 | err = orm.CommitAll(testID) 185 | assert.Nil(err) 186 | assert.Len(sl.getTransactions(), 0, "transaction is removed") 187 | 188 | assert.EqualValues(4, countUserMaster(orm), "users count after commit") 189 | assert.EqualValues(4, countUserMasterB(orm), "users count after commit") 190 | 191 | initTestDB() 192 | } 193 | 194 | func TestRollbackAll(t *testing.T) { 195 | assert := assert.New(t) 196 | wiz := testCreateWizard() 197 | orm := New(wiz) 198 | xsm := orm.XormSessionManager 199 | sl := 
xsm.getOrCreateSessionList(testID) 200 | assert.Len(sl.getTransactions(), 0) 201 | 202 | user1 := testUser{ID: 1} 203 | user500 := testUser{ID: 500} 204 | 205 | s1, _ := orm.NewMasterSession(user1) 206 | s2, _ := orm.NewMasterSession(user500) 207 | 208 | orm.SetAutoTransaction(testID, true) 209 | orm.AutoTransaction(testID, user1, s1) 210 | orm.AutoTransaction(testID, user500, s2) 211 | assert.Len(sl.getTransactions(), 2, "transaction is added") 212 | 213 | assert.EqualValues(3, countUserBySession(s1), "initial users count") 214 | assert.EqualValues(3, countUserBySession(s2), "initial users count") 215 | 216 | s1.Insert(&testUser{ID: 4}) 217 | s2.Insert(&testUser{ID: 504}) 218 | assert.EqualValues(4, countUserBySession(s1), "users count after insert in the transaction") 219 | assert.EqualValues(4, countUserBySession(s2), "users count after insert in the transaction") 220 | assert.EqualValues(4, countUserMaster(orm), "users count after insert in the transaction") 221 | assert.EqualValues(4, countUserMasterB(orm), "users count after insert in the transaction") 222 | 223 | orm.SetAutoTransaction(testID, false) 224 | s1b, _ := newSession(orm.Master(user1), user1) 225 | s2b, _ := newSession(orm.Master(user500), user500) 226 | assert.EqualValues(3, countUserBySession(s1b), "users count after insert in another session") 227 | assert.EqualValues(3, countUserBySession(s2b), "users count after insert in another session") 228 | 229 | orm.ReadOnly(testID, true) 230 | err := orm.RollbackAll(testID) 231 | assert.Nil(err) 232 | assert.Len(sl.getTransactions(), 2, "transaction is not removed when readonly") 233 | assert.EqualValues(4, countUserBySession(s1), "rollback does not occur when read only") 234 | assert.EqualValues(4, countUserBySession(s2), "rollback does not occur when read only") 235 | 236 | orm.ReadOnly(testID, false) 237 | err = orm.RollbackAll(testID) 238 | assert.Nil(err) 239 | assert.Len(sl.getTransactions(), 0, "transaction is removed") 240 | 241 | 
assert.EqualValues(3, countUserMaster(orm), "users count after rollback") 242 | assert.EqualValues(3, countUserMasterB(orm), "users count after rollback") 243 | } 244 | 245 | func TestTransactionNilDB(t *testing.T) { 246 | assert := assert.New(t) 247 | orm := New(emptyWiz) 248 | 249 | var s Session 250 | var err error 251 | 252 | s, err = orm.ForceNewTransaction(testUser{ID: 1}) 253 | assert.NotNil(err) 254 | assert.Nil(s) 255 | 256 | s, err = orm.Transaction(testID, testUser{ID: 1}) 257 | assert.NotNil(err) 258 | assert.Nil(s) 259 | 260 | s, err = orm.TransactionByKey(testID, testUser{}, 1) 261 | assert.NotNil(err) 262 | assert.Nil(s) 263 | } 264 | -------------------------------------------------------------------------------- /orm/xorm/xorm_test.go: -------------------------------------------------------------------------------- 1 | package xorm 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | "time" 7 | 8 | "github.com/evalphobia/wizard" 9 | "github.com/go-xorm/xorm" 10 | _ "github.com/mattn/go-sqlite3" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | var ( 15 | dbUser01Master, dbUser01Slave01, dbUser01Slave02 Engine // user A 16 | dbUser02Master, dbUser02Slave01, dbUser02Slave02 Engine // user B 17 | dbFoobarMaster, dbFoobarSlave01, dbFoobarSlave02 Engine 18 | dbOther Engine 19 | wiz, emptyWiz *wizard.Wizard 20 | ) 21 | 22 | type testUser struct { 23 | ID int64 `xorm:"id pk not null" shard_key:"true"` 24 | Name string `xorm:"varchar(255) not null"` 25 | } 26 | 27 | func (u testUser) TableName() string { 28 | return "test_user" 29 | } 30 | 31 | type testFoobar struct { 32 | ID int64 `xorm:"id pk not null"` 33 | Name string `xorm:"varchar(255) not null"` 34 | } 35 | 36 | func (f testFoobar) TableName() string { 37 | return "test_foobar" 38 | } 39 | 40 | type testCompany struct { 41 | ID int64 `xorm:"id pk not null"` 42 | Name string `xorm:"varchar(255) not null"` 43 | } 44 | 45 | func (c testCompany) TableName() string { 46 | return "test_company" 47 | } 48 
| 49 | func init() { 50 | initTestDB() 51 | emptyWiz = wizard.NewWizard() 52 | } 53 | 54 | func initTestDB() { 55 | testInitializeEngines() 56 | testWaitForIO() 57 | testInitializeSchema() 58 | testWaitForIO() 59 | testInitializeData() 60 | testWaitForIO() 61 | } 62 | 63 | func testWaitForIO() { 64 | time.Sleep(80 * time.Millisecond) 65 | } 66 | 67 | func testInitializeEngines() { 68 | f1 := "xorm_test_user01.db" 69 | f2 := "xorm_test_user02.db" 70 | f3 := "xorm_test_foobar.db" 71 | f4 := "xorm_test_other.db" 72 | os.Remove(f1) 73 | os.Remove(f2) 74 | os.Remove(f3) 75 | os.Remove(f4) 76 | 77 | dbUser01Master, _ = xorm.NewEngine("sqlite3", f1) 78 | dbUser01Slave01, _ = xorm.NewEngine("sqlite3", f1) 79 | dbUser01Slave02, _ = xorm.NewEngine("sqlite3", f1) 80 | dbUser02Master, _ = xorm.NewEngine("sqlite3", f2) 81 | dbUser02Slave01, _ = xorm.NewEngine("sqlite3", f2) 82 | dbUser02Slave02, _ = xorm.NewEngine("sqlite3", f2) 83 | dbFoobarMaster, _ = xorm.NewEngine("sqlite3", f3) 84 | dbFoobarSlave01, _ = xorm.NewEngine("sqlite3", f3) 85 | dbFoobarSlave02, _ = xorm.NewEngine("sqlite3", f3) 86 | dbOther, _ = xorm.NewEngine("sqlite3", f4) 87 | } 88 | 89 | func testInitializeSchema() { 90 | dbUser01Master.Sync(&testUser{}) 91 | dbUser02Master.Sync(&testUser{}) 92 | dbFoobarMaster.Sync(&testFoobar{}) 93 | dbOther.Sync(&testCompany{}) 94 | } 95 | 96 | func testInitializeData() { 97 | dbUser01Master.Delete(testUser{}) 98 | dbUser02Master.Delete(testUser{}) 99 | dbFoobarMaster.Delete(testFoobar{}) 100 | dbOther.Delete(testCompany{}) 101 | 102 | dbUser01Master.Insert(testUser{ID: 1, Name: "Adam"}) 103 | dbUser01Master.Insert(testUser{ID: 2, Name: "Benjamin"}) 104 | dbUser01Master.Insert(testUser{ID: 3, Name: "Charles"}) 105 | dbUser02Master.Insert(testUser{ID: 500, Name: "Alice"}) 106 | dbUser02Master.Insert(testUser{ID: 501, Name: "Betty"}) 107 | dbUser02Master.Insert(testUser{ID: 502, Name: "Christina"}) 108 | dbFoobarMaster.Insert(testFoobar{ID: 1, Name: "foobar#1"}) 109 | 
dbFoobarMaster.Insert(testFoobar{ID: 2, Name: "foobar#2"}) 110 | dbFoobarMaster.Insert(testFoobar{ID: 3, Name: "foobar#3"}) 111 | dbOther.Insert(testCompany{ID: 1, Name: "Apple"}) 112 | dbOther.Insert(testCompany{ID: 2, Name: "BOX"}) 113 | dbOther.Insert(testCompany{ID: 3, Name: "Criteo"}) 114 | } 115 | 116 | func testCreateWizard() *wizard.Wizard { 117 | wiz := wizard.NewWizard() 118 | 119 | userShards := wiz.CreateShardCluster(testUser{}, 997) 120 | shard01 := wizard.NewCluster(dbUser01Master) 121 | shard01.RegisterSlave(dbUser01Slave01) 122 | shard01.RegisterSlave(dbUser01Slave02) 123 | userShards.RegisterShard(0, 499, shard01) // user A 124 | 125 | shard02 := wizard.NewCluster(dbUser02Master) 126 | shard02.RegisterSlave(dbUser02Slave01) 127 | shard02.RegisterSlave(dbUser02Slave02) 128 | userShards.RegisterShard(500, 996, shard02) // user B 129 | 130 | foobarCluster := wiz.CreateCluster(testFoobar{}, dbFoobarMaster) 131 | foobarCluster.RegisterSlave(dbFoobarSlave01) 132 | foobarCluster.RegisterSlave(dbFoobarSlave02) 133 | 134 | otherCluster := wizard.NewCluster(dbOther) 135 | wiz.SetDefault(otherCluster) 136 | return wiz 137 | } 138 | 139 | func countUserMaster(orm *Xorm) int64 { 140 | count, _ := orm.CountUsingMaster(testID, &testUser{ID: 1}, func(s Session) (int64, error) { 141 | return s.Count(&testUser{}) 142 | }) 143 | return count 144 | } 145 | 146 | func countUserMasterB(orm *Xorm) int64 { 147 | count, _ := orm.CountUsingMaster(testID, &testUser{ID: 500}, func(s Session) (int64, error) { 148 | return s.Count(&testUser{}) 149 | }) 150 | return count 151 | } 152 | 153 | func countUserSlave(orm *Xorm) int64 { 154 | count, _ := orm.Count(&testUser{ID: 1}, func(s Session) (int64, error) { 155 | return s.Count(&testUser{}) 156 | }) 157 | return count 158 | } 159 | 160 | func countUserSlaveB(orm *Xorm) int64 { 161 | count, _ := orm.Count(&testUser{ID: 500}, func(s Session) (int64, error) { 162 | return s.Count(&testUser{}) 163 | }) 164 | return count 165 | } 166 
| 167 | func countUserBySession(s Session) int64 { 168 | count, _ := s.Count(testUser{}) 169 | return count 170 | } 171 | 172 | func TestNew(t *testing.T) { 173 | assert := assert.New(t) 174 | wiz := wizard.NewWizard() 175 | 176 | orm := New(wiz) 177 | assert.Equal(wiz, orm.Wiz) 178 | } 179 | -------------------------------------------------------------------------------- /orm/xorm/xorm_wizard.go: -------------------------------------------------------------------------------- 1 | package xorm 2 | 3 | import ( 4 | "github.com/evalphobia/wizard" 5 | ) 6 | 7 | // XormWizard is struct for database selector 8 | type XormWizard struct { 9 | *wizard.Wizard 10 | } 11 | 12 | // Master returns master db for the given object 13 | func (xwiz XormWizard) Master(obj interface{}) Engine { 14 | db := xwiz.UseMaster(obj) 15 | if db == nil { 16 | return nil 17 | } 18 | return db.(Engine) 19 | } 20 | 21 | // MasterByKey returns master db by shard key 22 | func (xwiz XormWizard) MasterByKey(obj interface{}, key interface{}) Engine { 23 | db := xwiz.UseMasterByKey(obj, key) 24 | if db == nil { 25 | return nil 26 | } 27 | return db.(Engine) 28 | } 29 | 30 | // Masters returns all of sharded master db for the given object 31 | func (xwiz XormWizard) Masters(obj interface{}) []Engine { 32 | var results []Engine 33 | for _, db := range xwiz.UseMasters(obj) { 34 | e, ok := db.(Engine) 35 | if !ok || e == nil { 36 | continue 37 | } 38 | results = append(results, e) 39 | } 40 | return results 41 | } 42 | 43 | // Slave randomly returns one of the slave db for the given object 44 | func (xwiz XormWizard) Slave(obj interface{}) Engine { 45 | db := xwiz.UseSlave(obj) 46 | if db == nil { 47 | return nil 48 | } 49 | return db.(Engine) 50 | } 51 | 52 | // SlaveByKey randomly returns one of the slave db by shard key 53 | func (xwiz XormWizard) SlaveByKey(obj interface{}, key interface{}) Engine { 54 | db := xwiz.UseSlaveByKey(obj, key) 55 | if db == nil { 56 | return nil 57 | } 58 | return 
db.(Engine) 59 | } 60 | 61 | // Slaves randomly returns all of sharded slave db for the given object 62 | func (xwiz XormWizard) Slaves(obj interface{}) []Engine { 63 | var results []Engine 64 | for _, db := range xwiz.UseSlaves(obj) { 65 | e, ok := db.(Engine) 66 | if !ok || e == nil { 67 | continue 68 | } 69 | results = append(results, e) 70 | } 71 | return results 72 | } 73 | -------------------------------------------------------------------------------- /orm/xorm/xorm_wizard_test.go: -------------------------------------------------------------------------------- 1 | package xorm 2 | 3 | import ( 4 | "testing" 5 | 6 | _ "github.com/mattn/go-sqlite3" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestMaster(t *testing.T) { 11 | assert := assert.New(t) 12 | 13 | wiz := testCreateWizard() 14 | orm := New(wiz) 15 | 16 | assert.Equal(dbFoobarMaster, orm.Master(testFoobar{})) 17 | assert.Equal(dbOther, orm.Master("xxx")) 18 | 19 | emptyOrm := New(emptyWiz) 20 | assert.Equal(nil, emptyOrm.Master("empty")) 21 | } 22 | 23 | func TestMasterByKey(t *testing.T) { 24 | assert := assert.New(t) 25 | wiz := testCreateWizard() 26 | orm := New(wiz) 27 | 28 | assert.Equal(dbUser01Master, orm.MasterByKey(testUser{}, 1)) 29 | assert.Equal(dbUser01Master, orm.MasterByKey(testUser{}, 499)) 30 | assert.Equal(dbUser02Master, orm.MasterByKey(testUser{}, 500)) 31 | assert.Equal(dbUser02Master, orm.MasterByKey(testUser{}, 501)) 32 | assert.Equal(dbUser02Master, orm.MasterByKey(testUser{}, 996)) 33 | assert.Equal(dbUser01Master, orm.MasterByKey(testUser{}, 997)) 34 | 35 | emptyOrm := New(emptyWiz) 36 | assert.Equal(nil, emptyOrm.MasterByKey("empty", 1)) 37 | } 38 | 39 | func TestMasters(t *testing.T) { 40 | assert := assert.New(t) 41 | wiz := testCreateWizard() 42 | orm := New(wiz) 43 | 44 | shardMasters := []Engine{dbUser01Master, dbUser02Master} 45 | assert.Equal(shardMasters, orm.Masters(testUser{})) 46 | assert.Equal([]Engine{dbFoobarMaster}, orm.Masters(testFoobar{})) 
47 | assert.Equal([]Engine{dbOther}, orm.Masters("xxx")) 48 | 49 | emptyOrm := New(emptyWiz) 50 | assert.Empty(emptyOrm.Masters("empty")) 51 | } 52 | 53 | func TestSlave(t *testing.T) { 54 | assert := assert.New(t) 55 | wiz := testCreateWizard() 56 | orm := New(wiz) 57 | 58 | assert.Contains([]Engine{dbFoobarSlave01, dbFoobarSlave02}, orm.Slave(testFoobar{})) 59 | assert.Equal(dbOther, orm.Slave("xxx")) 60 | 61 | emptyOrm := New(emptyWiz) 62 | assert.Equal(nil, emptyOrm.Slave("empty")) 63 | } 64 | 65 | func TestSlaveByKey(t *testing.T) { 66 | assert := assert.New(t) 67 | wiz := testCreateWizard() 68 | orm := New(wiz) 69 | 70 | db01 := []Engine{dbUser01Slave01, dbUser01Slave02} 71 | db02 := []Engine{dbUser02Slave01, dbUser02Slave02} 72 | assert.Contains(db01, orm.SlaveByKey(testUser{}, 1)) 73 | assert.Contains(db01, orm.SlaveByKey(testUser{}, 499)) 74 | assert.Contains(db02, orm.SlaveByKey(testUser{}, 500)) 75 | assert.Contains(db02, orm.SlaveByKey(testUser{}, 501)) 76 | assert.Contains(db02, orm.SlaveByKey(testUser{}, 996)) 77 | assert.Contains(db01, orm.SlaveByKey(testUser{}, 997)) 78 | 79 | emptyOrm := New(emptyWiz) 80 | assert.Equal(nil, emptyOrm.SlaveByKey("empty", 1)) 81 | } 82 | 83 | func TestSlaves(t *testing.T) { 84 | assert := assert.New(t) 85 | wiz := testCreateWizard() 86 | orm := New(wiz) 87 | 88 | slaves := orm.Slaves(testFoobar{}) 89 | assert.Contains([]Engine{dbFoobarSlave01, dbFoobarSlave02}, slaves[0]) 90 | 91 | slaves = orm.Slaves("xxx") 92 | assert.Equal([]Engine{dbOther}, slaves) 93 | 94 | slaves = orm.Slaves(testUser{}) 95 | assert.Len(slaves, 2) 96 | assert.Contains([]Engine{dbUser01Slave01, dbUser01Slave02}, slaves[0]) 97 | assert.Contains([]Engine{dbUser02Slave01, dbUser02Slave02}, slaves[1]) 98 | 99 | emptyOrm := New(emptyWiz) 100 | assert.Empty(emptyOrm.Slaves("empty")) 101 | } 102 | -------------------------------------------------------------------------------- /reflect.go: 
-------------------------------------------------------------------------------- 1 | package wizard 2 | 3 | import ( 4 | "reflect" 5 | "strings" 6 | ) 7 | 8 | // used for shard key in the tag name of struct 9 | const TagName = "shard_key" 10 | 11 | // NormalizeValue returns value 12 | // if struct is passed, returns name of the struct 13 | // if pointer is passed, returns non-pointer value 14 | func NormalizeValue(p interface{}) interface{} { 15 | v := toValue(p) 16 | if v.Kind() == reflect.Struct { 17 | return v.Type().String() 18 | } 19 | return v.Interface() 20 | } 21 | 22 | func getShardKey(p interface{}) int64 { 23 | v := toValue(p) 24 | if v.Kind() != reflect.Struct { 25 | return 0 26 | } 27 | return getShardKeyFromStruct(p, TagName) 28 | } 29 | 30 | // toValue converts any value to reflect.Value 31 | func toValue(p interface{}) reflect.Value { 32 | v := reflect.ValueOf(p) 33 | if v.Kind() == reflect.Ptr { 34 | v = v.Elem() 35 | } 36 | return v 37 | } 38 | 39 | // toType converts any value to reflect.Type 40 | func toType(p interface{}) reflect.Type { 41 | t := reflect.ValueOf(p).Type() 42 | if t.Kind() == reflect.Ptr { 43 | t = t.Elem() 44 | } 45 | return t 46 | } 47 | 48 | func getShardKeyFromStruct(p interface{}, tagName string) int64 { 49 | t := toType(p) 50 | values := toValue(p) 51 | for i, max := 0, t.NumField(); i < max; i++ { 52 | f := t.Field(i) 53 | if f.PkgPath != "" && !f.Anonymous { 54 | continue 55 | } 56 | 57 | tag := parseTag(f, tagName) 58 | // search recursively when `extends` tag 59 | if tag == "extends" { 60 | v := values.Field(i) 61 | return getShardKeyFromStruct(v.Interface(), tagName) 62 | } 63 | if tag != "true" { 64 | continue 65 | } 66 | v := values.Field(i) 67 | return getInt64(v.Interface()) 68 | } 69 | return 0 70 | } 71 | 72 | // parseTag returns the first tag value of the struct field 73 | func parseTag(f reflect.StructField, tag string) string { 74 | res := strings.Split(f.Tag.Get(tag), ",") 75 | return res[0] 76 | } 77 | 
-------------------------------------------------------------------------------- /reflect_test.go: -------------------------------------------------------------------------------- 1 | package wizard 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestNormalizeValue(t *testing.T) { 10 | assert := assert.New(t) 11 | 12 | assert.Equal("wizard.ShardCluster", NormalizeValue(ShardCluster{}), "Struct should return the type name") 13 | assert.Equal("wizard.ShardCluster", NormalizeValue(&ShardCluster{}), "Struct pointer should return the type name") 14 | 15 | valueString := "foobar" 16 | assert.Equal(valueString, NormalizeValue(valueString)) 17 | assert.Equal(valueString, NormalizeValue(&valueString)) 18 | 19 | valueInt := 99 20 | assert.Equal(valueInt, NormalizeValue(valueInt)) 21 | assert.Equal(valueInt, NormalizeValue(&valueInt)) 22 | 23 | valueSlice := []string{"a", "b", "c"} 24 | assert.Equal(valueSlice, NormalizeValue(valueSlice)) 25 | assert.Equal(valueSlice, NormalizeValue(&valueSlice)) 26 | 27 | valueMap := map[interface{}]interface{}{"key": "value", 100: 403} 28 | assert.Equal(valueMap, NormalizeValue(valueMap)) 29 | assert.Equal(valueMap, NormalizeValue(&valueMap)) 30 | } 31 | 32 | func TestGetShardKey(t *testing.T) { 33 | assert := assert.New(t) 34 | 35 | type myStruct1 struct { 36 | UserID int64 37 | CountryID int64 38 | CityID int64 `shard_key:"false"` 39 | } 40 | 41 | type myStruct2 struct { 42 | UserID int64 43 | CountryID int64 `shard_key:"true"` 44 | CityID int64 45 | } 46 | 47 | type myStruct3 struct { 48 | UserID int64 `shard_key:"true"` 49 | CountryID int64 `shard_key:"true"` 50 | CityID int64 `shard_key:"true"` 51 | } 52 | 53 | type personStruct struct { 54 | Name string 55 | City string `shard_key:"true"` 56 | Tel string 57 | } 58 | 59 | m1 := myStruct1{UserID: 1, CountryID: 2, CityID: 3} 60 | m2 := myStruct2{UserID: 1, CountryID: 2, CityID: 3} 61 | m3 := myStruct3{UserID: 1, CountryID: 2, CityID: 3} 62 | 
63 | assert.Equal(int64(0), getShardKey(m1), "getShardKey() must return 0 when tag `shard_key:true` is missing") 64 | assert.Equal(m2.CountryID, getShardKey(m2)) 65 | assert.Equal(m3.UserID, getShardKey(m3), "getShardKey() must return 1st field value when multiple tag `shard_key:true` exists") 66 | 67 | adam := personStruct{Name: "Adam Smith", City: "Oxford", Tel: "+81 0120-000-000"} 68 | assert.Equal(getInt64("Oxford"), getShardKey(adam)) 69 | } 70 | -------------------------------------------------------------------------------- /shard_cluster.go: -------------------------------------------------------------------------------- 1 | package wizard 2 | 3 | import ( 4 | "github.com/evalphobia/wizard/errors" 5 | ) 6 | 7 | // ShardCluster is struct for sharded database cluster 8 | type ShardCluster struct { 9 | List []*ShardSet // sharded database clusters 10 | slotsize int64 11 | } 12 | 13 | // Master is dummy method for interface 14 | func (c ShardCluster) Master() *Node { 15 | return nil 16 | } 17 | 18 | // Masters returns all db masters from the sharded clusters 19 | func (c ShardCluster) Masters() []*Node { 20 | var result []*Node 21 | for _, s := range c.List { 22 | if s.set == nil { 23 | continue 24 | } 25 | result = append(result, s.set.Master()) 26 | } 27 | return result 28 | } 29 | 30 | // Slave is dummy method for interface 31 | func (c ShardCluster) Slave() *Node { 32 | return nil 33 | } 34 | 35 | // Slaves randomly returns all db slaves from the sharded clusters 36 | func (c ShardCluster) Slaves() []*Node { 37 | var result []*Node 38 | for _, s := range c.List { 39 | if s.set == nil { 40 | continue 41 | } 42 | result = append(result, s.set.Slave()) 43 | } 44 | return result 45 | } 46 | 47 | // SelectByKey returns sharded cluster by shard key 48 | func (c ShardCluster) SelectByKey(key interface{}) *StandardCluster { 49 | i := getInt64(key) 50 | mod := i % c.slotsize 51 | for _, shard := range c.List { 52 | if shard.InRange(mod) { 53 | return shard.set 54 | 
} 55 | } 56 | return nil 57 | } 58 | 59 | // RegisterShard adds cluster with hash slot range(min and max) 60 | func (c *ShardCluster) RegisterShard(min, max int64, s *StandardCluster) error { 61 | err := c.checkOverlapped(min, max) 62 | if err != nil { 63 | return err 64 | } 65 | 66 | ss := &ShardSet{ 67 | min: min, 68 | max: max, 69 | set: s, 70 | } 71 | err = ss.checkSlotSize(c.slotsize) 72 | if err != nil { 73 | return err 74 | } 75 | 76 | c.List = append(c.List, ss) 77 | return nil 78 | } 79 | 80 | // checkOverlapped checks the hash slot range is not overlapped among the shards 81 | func (c *ShardCluster) checkOverlapped(min, max int64) error { 82 | for _, ss := range c.List { 83 | switch { 84 | case ss.InRange(min): 85 | return errors.NewErrSlotMinOverlapped(min) 86 | case ss.InRange(max): 87 | return errors.NewErrSlotMaxOverlapped(max) 88 | } 89 | } 90 | return nil 91 | } 92 | 93 | // ShardSet is struct of sharded cluster 94 | type ShardSet struct { 95 | min int64 96 | max int64 97 | set *StandardCluster 98 | } 99 | 100 | // InRange checks given number is in range of this shard 101 | func (ss ShardSet) InRange(v int64) bool { 102 | return ss.min <= v && v <= ss.max 103 | } 104 | 105 | // checkSlotSize checks given number is not minus and within slotsize 106 | func (ss ShardSet) checkSlotSize(slot int64) error { 107 | switch { 108 | case !ss.isMinAboveZero(): 109 | return errors.NewErrSlotSizeMin(ss.min) 110 | case !ss.isMaxInSlotSize(slot): 111 | return errors.NewErrSlotSizeMax(ss.max, slot) 112 | } 113 | return nil 114 | } 115 | 116 | // isMinAboveZero checks given number is not minus 117 | func (ss ShardSet) isMinAboveZero() bool { 118 | return ss.min >= 0 119 | } 120 | 121 | // isMaxInSlotSize checks given number is within slotsize 122 | func (ss ShardSet) isMaxInSlotSize(slot int64) bool { 123 | return ss.max < slot 124 | } 125 | -------------------------------------------------------------------------------- /shard_cluster_test.go: 
-------------------------------------------------------------------------------- 1 | package wizard 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func testCreateCluster(prefix string) *StandardCluster { 10 | c := NewCluster(prefix + "-master") 11 | c.RegisterSlave(prefix + "-slave01") 12 | c.RegisterSlave(prefix + "-slave02") 13 | c.RegisterSlave(prefix + "-slave03") 14 | return c 15 | } 16 | 17 | func TestShardClusterMaster(t *testing.T) { 18 | assert := assert.New(t) 19 | 20 | var s *ShardCluster 21 | var node *Node 22 | 23 | s = &ShardCluster{slotsize: 1} 24 | node = s.Master() 25 | assert.Nil(node, "Master() should be always nil on ShardCluster") 26 | } 27 | 28 | func TestShardClusterMasters(t *testing.T) { 29 | assert := assert.New(t) 30 | 31 | var s *ShardCluster 32 | var err error 33 | 34 | s = &ShardCluster{slotsize: 1000} 35 | assert.Len(s.Masters(), 0) 36 | 37 | c := testCreateCluster("shard01") 38 | err = s.RegisterShard(0, 500, c) 39 | assert.Nil(err) 40 | assert.Len(s.Masters(), 1) 41 | 42 | c = testCreateCluster("shard02") 43 | err = s.RegisterShard(501, 999, c) 44 | assert.Nil(err) 45 | assert.Len(s.Masters(), 2) 46 | } 47 | 48 | func TestShardClusterSlave(t *testing.T) { 49 | assert := assert.New(t) 50 | 51 | var s *ShardCluster 52 | var node *Node 53 | 54 | s = &ShardCluster{slotsize: 1} 55 | node = s.Slave() 56 | assert.Nil(node, "Slave() should be always nil on ShardCluster") 57 | } 58 | 59 | func TestShardClusterSelectByKey(t *testing.T) { 60 | assert := assert.New(t) 61 | 62 | var s *ShardCluster 63 | var c *StandardCluster 64 | var err error 65 | 66 | s = &ShardCluster{slotsize: 2} 67 | err = s.RegisterShard(0, 0, testCreateCluster("shard01")) 68 | assert.Nil(err) 69 | err = s.RegisterShard(1, 1, testCreateCluster("shard02")) 70 | assert.Nil(err) 71 | 72 | c = s.SelectByKey(0) 73 | assert.Equal("shard01-master", c.Master().DB()) 74 | c = s.SelectByKey(1) 75 | assert.Equal("shard02-master", 
c.Master().DB()) 76 | c = s.SelectByKey(2) 77 | assert.Equal("shard01-master", c.Master().DB()) 78 | c = s.SelectByKey(3) 79 | assert.Equal("shard02-master", c.Master().DB()) 80 | c = s.SelectByKey(4) 81 | assert.Equal("shard01-master", c.Master().DB()) 82 | c = s.SelectByKey(5) 83 | assert.Equal("shard02-master", c.Master().DB()) 84 | } 85 | 86 | func TestShardClusterRegisterShard(t *testing.T) { 87 | assert := assert.New(t) 88 | 89 | var s *ShardCluster 90 | var err error 91 | 92 | s = &ShardCluster{slotsize: 10} 93 | err = s.RegisterShard(-1, 0, testCreateCluster("min-error")) 94 | assert.NotNil(err, "Slotsize cannot be under 0") 95 | assert.Len(s.List, 0) 96 | 97 | err = s.RegisterShard(0, 10, testCreateCluster("max-error")) 98 | assert.NotNil(err, "Slotsize cannot be greater equal than slotsize") 99 | assert.Len(s.List, 0) 100 | 101 | err = s.RegisterShard(5, 6, testCreateCluster("shard01")) 102 | assert.Nil(err) 103 | 104 | err = s.RegisterShard(6, 9, testCreateCluster("min-error")) 105 | assert.NotNil(err, "Slot min is already registered") 106 | 107 | err = s.RegisterShard(0, 5, testCreateCluster("max-error")) 108 | assert.NotNil(err, "Slot max is already registered") 109 | } 110 | 111 | func TestShardClusterCheckOverlapped(t *testing.T) { 112 | assert := assert.New(t) 113 | 114 | var s *ShardCluster 115 | var err error 116 | 117 | s = &ShardCluster{slotsize: 10} 118 | err = s.RegisterShard(5, 6, testCreateCluster("shard01")) 119 | assert.Nil(err) 120 | 121 | err = s.checkOverlapped(6, 9) 122 | assert.NotNil(err, "Slot min is already registered") 123 | 124 | err = s.checkOverlapped(0, 5) 125 | assert.NotNil(err, "Slot max is already registered") 126 | 127 | err = s.checkOverlapped(0, 4) 128 | assert.Nil(err) 129 | err = s.checkOverlapped(7, 9) 130 | assert.Nil(err) 131 | } 132 | 133 | func TestShardSetInRange(t *testing.T) { 134 | assert := assert.New(t) 135 | 136 | var ss *ShardSet 137 | ss = &ShardSet{ 138 | min: 10, 139 | max: 20, 140 | } 141 | 142 | 
assert.False(ss.InRange(9)) 143 | assert.True(ss.InRange(10)) 144 | assert.True(ss.InRange(11)) 145 | assert.True(ss.InRange(19)) 146 | assert.True(ss.InRange(20)) 147 | assert.False(ss.InRange(21)) 148 | } 149 | 150 | func TestShardSetCheckSlotSize(t *testing.T) { 151 | assert := assert.New(t) 152 | 153 | var ss *ShardSet 154 | var err error 155 | ss = &ShardSet{ 156 | min: 10, 157 | max: 20, 158 | } 159 | 160 | err = ss.checkSlotSize(19) 161 | assert.NotNil(err, "max must be greater than slotsize") 162 | err = ss.checkSlotSize(20) 163 | assert.NotNil(err, "max must be greater than slotsize") 164 | err = ss.checkSlotSize(21) 165 | assert.Nil(err) 166 | 167 | ss = &ShardSet{ 168 | min: -2, 169 | max: 20, 170 | } 171 | err = ss.checkSlotSize(21) 172 | assert.NotNil(err, "min must be greater equal than 0") 173 | } 174 | 175 | func TestShardSetIsMaxInSlotSize(t *testing.T) { 176 | assert := assert.New(t) 177 | 178 | var ss *ShardSet 179 | ss = &ShardSet{ 180 | min: 10, 181 | max: 20, 182 | } 183 | 184 | assert.False(ss.isMaxInSlotSize(19)) 185 | assert.False(ss.isMaxInSlotSize(20)) 186 | assert.True(ss.isMaxInSlotSize(21)) 187 | } 188 | 189 | func TestShardSetIsMinAboveZero(t *testing.T) { 190 | assert := assert.New(t) 191 | 192 | assert.False(ShardSet{min: -1}.isMinAboveZero()) 193 | assert.True(ShardSet{min: 0}.isMinAboveZero()) 194 | assert.True(ShardSet{min: 1}.isMinAboveZero()) 195 | } 196 | -------------------------------------------------------------------------------- /standard_cluster.go: -------------------------------------------------------------------------------- 1 | package wizard 2 | 3 | import ( 4 | "math/rand" 5 | "time" 6 | ) 7 | 8 | func init() { 9 | // initialized for slave balancing 10 | rand.Seed(time.Now().UnixNano()) 11 | } 12 | 13 | // StandardCluster is struct for typical(non-sharded) database cluster 14 | type StandardCluster struct { 15 | master *Node 16 | slaves []*Node 17 | } 18 | 19 | // NewCluster returns the StandardCluster 
initialized with master database 20 | func NewCluster(db interface{}) *StandardCluster { 21 | node := NewNode(db) 22 | return &StandardCluster{master: node} 23 | } 24 | 25 | // Master returns master database 26 | func (c StandardCluster) Master() *Node { 27 | return c.master 28 | } 29 | 30 | // Masters is dummy method for interface 31 | func (c StandardCluster) Masters() []*Node { 32 | return []*Node{c.master} 33 | } 34 | 35 | // Slave ramdomly returns the slave database. 36 | // if no slave is registered, master is returned 37 | func (c StandardCluster) Slave() *Node { 38 | if len(c.slaves) == 0 { 39 | return c.master 40 | } 41 | return c.slaves[rand.Intn(len(c.slaves))] 42 | } 43 | 44 | // Slaves is dummy method for interface 45 | func (c StandardCluster) Slaves() []*Node { 46 | return []*Node{c.Slave()} 47 | } 48 | 49 | // SelectByKey is dummy method for interface 50 | func (c *StandardCluster) SelectByKey(v interface{}) *StandardCluster { 51 | return c 52 | } 53 | 54 | // RegisterMaster set new master node 55 | func (c *StandardCluster) RegisterMaster(db interface{}) { 56 | c.master = &Node{db: db} 57 | } 58 | 59 | // RegisterSlave adds slave node 60 | func (c *StandardCluster) RegisterSlave(db interface{}) { 61 | c.slaves = append(c.slaves, &Node{db: db}) 62 | } 63 | -------------------------------------------------------------------------------- /standard_cluster_test.go: -------------------------------------------------------------------------------- 1 | package wizard 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestNewCluster(t *testing.T) { 10 | assert := assert.New(t) 11 | 12 | var c *StandardCluster 13 | c = NewCluster("db") 14 | assert.IsType(Node{}, *c.master) 15 | assert.Equal("db", c.master.db, "db should be saved on NewCluster()") 16 | } 17 | 18 | func TestStandardClusterMaster(t *testing.T) { 19 | assert := assert.New(t) 20 | 21 | var c *StandardCluster 22 | var node *Node 23 | 24 | c = 
NewCluster("db") 25 | node = c.Master() 26 | assert.Equal("db", node.db) 27 | assert.Equal(c.master, node, "Master() shoud equal to StandardCluster.master") 28 | 29 | c.RegisterMaster("db2") 30 | node = c.Master() 31 | assert.Equal("db2", node.db) 32 | } 33 | 34 | func TestStandardClusterMasters(t *testing.T) { 35 | assert := assert.New(t) 36 | 37 | var c *StandardCluster 38 | var nodes []*Node 39 | 40 | c = NewCluster("db") 41 | nodes = c.Masters() 42 | assert.Equal(c.master, nodes[0]) 43 | assert.Len(nodes, 1) 44 | } 45 | 46 | func TestStandardClusterSlave(t *testing.T) { 47 | assert := assert.New(t) 48 | 49 | var c *StandardCluster 50 | var node *Node 51 | 52 | c = NewCluster("master") 53 | node = c.Slave() 54 | assert.IsType(Node{}, *node) 55 | assert.Equal("master", node.db) 56 | assert.Equal(c.master, node, "Slave() shoud equal to StandardCluster.master when no slaves") 57 | assert.Len(c.slaves, 0) 58 | 59 | c.RegisterSlave("slave") 60 | node = c.Slave() 61 | assert.IsType(Node{}, *node) 62 | assert.Equal("slave", node.db) 63 | assert.Equal(c.slaves[0], node, "Slave() shoud equal to node in StandardCluster.slaves") 64 | assert.Len(c.slaves, 1) 65 | 66 | for i, max := 0, 100; i < max; i++ { 67 | c.RegisterSlave(i) 68 | } 69 | assert.Len(c.slaves, 101) 70 | 71 | node = c.Slave() 72 | db := node.db 73 | for i, max := 0, 10; i < max; i++ { 74 | node = c.Slave() 75 | if node.db != db { 76 | return 77 | } 78 | } 79 | t.Error("Slave() should return different nodes") 80 | } 81 | 82 | func TestStandardClusterSelectByKey(t *testing.T) { 83 | assert := assert.New(t) 84 | 85 | var c, c2, c3, c4 *StandardCluster 86 | 87 | c = NewCluster("db") 88 | c2 = c.SelectByKey(0) 89 | c3 = c.SelectByKey(1) 90 | c4 = c.SelectByKey(9999) 91 | assert.Equal(c, c2) 92 | assert.Equal(c, c3) 93 | assert.Equal(c, c4) 94 | } 95 | 96 | func TestStandardClusterRegisterMaster(t *testing.T) { 97 | assert := assert.New(t) 98 | 99 | var c *StandardCluster 100 | 101 | c = NewCluster("db") 102 | 
c.RegisterMaster("db2") 103 | c.RegisterMaster("db3") 104 | 105 | assert.Equal("db3", c.master.db) 106 | 107 | c.RegisterMaster("db4") 108 | assert.Equal("db4", c.master.db) 109 | } 110 | 111 | func TestStandardClusterRegisterSlave(t *testing.T) { 112 | assert := assert.New(t) 113 | 114 | var c *StandardCluster 115 | 116 | c = NewCluster("db") 117 | c.RegisterSlave("db2") 118 | c.RegisterSlave("db3") 119 | c.RegisterSlave("db4") 120 | 121 | assert.Equal("db2", c.slaves[0].db) 122 | assert.Equal("db3", c.slaves[1].db) 123 | assert.Equal("db4", c.slaves[2].db) 124 | } 125 | -------------------------------------------------------------------------------- /wizard.go: -------------------------------------------------------------------------------- 1 | package wizard 2 | 3 | import ( 4 | "github.com/evalphobia/wizard/errors" 5 | ) 6 | 7 | // Cluster is interface for [StandardCluster | ShardCluster] 8 | type Cluster interface { 9 | SelectByKey(interface{}) *StandardCluster 10 | Master() *Node 11 | Masters() []*Node 12 | Slave() *Node 13 | Slaves() []*Node 14 | } 15 | 16 | // Wizard manages all the database cluster for your app 17 | type Wizard struct { 18 | clusters map[interface{}]Cluster 19 | defaultCluster Cluster 20 | } 21 | 22 | // NewWizard returns initialized empty Wizard 23 | func NewWizard() *Wizard { 24 | return &Wizard{ 25 | clusters: make(map[interface{}]Cluster), 26 | } 27 | } 28 | 29 | // SetDefault set default cluster 30 | // if default is set, this cluster acts like catchall, handles all the other tables. 
31 | func (w *Wizard) SetDefault(c Cluster) { 32 | w.defaultCluster = c 33 | } 34 | 35 | // HasDefault checks default cluster is set or not 36 | func (w *Wizard) HasDefault() bool { 37 | return w.defaultCluster != nil 38 | } 39 | 40 | // getCluster returns the cluster by name mapping 41 | func (w *Wizard) getCluster(obj interface{}) Cluster { 42 | c, ok := w.clusters[NormalizeValue(obj)] 43 | switch { 44 | case ok: 45 | return c 46 | case w.HasDefault(): 47 | return w.defaultCluster 48 | default: 49 | return nil 50 | } 51 | } 52 | 53 | // RegisterTables adds cluster and tables for name mapping 54 | func (w *Wizard) RegisterTables(c Cluster, list ...interface{}) error { 55 | for _, obj := range list { 56 | v := NormalizeValue(obj) 57 | if _, ok := w.clusters[v]; ok { 58 | return errors.NewErrAlreadyRegistared(v) 59 | } 60 | w.clusters[v] = c 61 | } 62 | return nil 63 | } 64 | 65 | // setCluster set the cluster with name mapping 66 | func (w *Wizard) setCluster(c Cluster, obj interface{}) { 67 | w.clusters[NormalizeValue(obj)] = c 68 | } 69 | 70 | // CreateCluster set and returns the new StandardCluster 71 | func (w *Wizard) CreateCluster(obj interface{}, db interface{}) *StandardCluster { 72 | c := NewCluster(db) 73 | w.setCluster(c, obj) 74 | return c 75 | } 76 | 77 | // CreateShardCluster set and returns the new ShardCluster 78 | func (w *Wizard) CreateShardCluster(obj interface{}, slot int64) *ShardCluster { 79 | if slot < 1 { 80 | slot = 1 81 | } 82 | c := &ShardCluster{ 83 | slotsize: slot, 84 | } 85 | w.setCluster(c, obj) 86 | return c 87 | } 88 | 89 | // Select returns StandardCluster by name mapping (and implicit hash slot from struct field) 90 | func (w *Wizard) Select(obj interface{}) *StandardCluster { 91 | c := w.getCluster(obj) 92 | switch v := c.(type) { 93 | case *StandardCluster: 94 | return v 95 | case *ShardCluster: 96 | return v.SelectByKey(getShardKey(obj)) 97 | default: 98 | return nil 99 | } 100 | } 101 | 102 | // SelectByKey returns 
StandardCluster by name mapping and shard key 103 | // It returns nil if no cluster is found for obj. 104 | func (w *Wizard) SelectByKey(obj interface{}, key interface{}) *StandardCluster { 105 | c := w.getCluster(obj) 106 | if c == nil { 107 | return nil 108 | } 109 | return c.SelectByKey(key) 110 | } 111 | -------------------------------------------------------------------------------- /wizard_test.go: -------------------------------------------------------------------------------- 1 | package wizard 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestNewWizard(t *testing.T) { 10 | assert := assert.New(t) 11 | 12 | wiz := NewWizard() 13 | assert.NotNil(wiz) 14 | assert.Empty(wiz.clusters) 15 | } 16 | 17 | func TestSetDefault(t *testing.T) { 18 | assert := assert.New(t) 19 | 20 | wiz := NewWizard() 21 | assert.Nil(wiz.defaultCluster) 22 | 23 | c := NewCluster("db") 24 | wiz.SetDefault(c) 25 | assert.Equal(c, wiz.defaultCluster, "It can set StandardCluster") 26 | 27 | s := &ShardCluster{} 28 | wiz.SetDefault(s) 29 | assert.Equal(s, wiz.defaultCluster, "It can set ShardCluster") 30 | } 31 | 32 | func TestHasDefault(t *testing.T) { 33 | assert := assert.New(t) 34 | 35 | wiz := NewWizard() 36 | assert.False(wiz.HasDefault()) 37 | 38 | c := NewCluster("db") 39 | wiz.SetDefault(c) 40 | assert.True(wiz.HasDefault()) 41 | } 42 | 43 | func TestGetCluster(t *testing.T) { 44 | assert := assert.New(t) 45 | 46 | wiz := NewWizard() 47 | assert.Nil(wiz.getCluster("table name")) 48 | 49 | c := NewCluster("db") 50 | wiz.clusters["table name"] = c 51 | assert.Equal(c, wiz.getCluster("table name")) 52 | } 53 | 54 | func TestSetCluster(t *testing.T) { 55 | assert := assert.New(t) 56 | 57 | wiz := NewWizard() 58 | assert.Nil(wiz.clusters["table name"]) 59 | 60 | c := NewCluster("db") 61 | wiz.setCluster(c, "table name") 62 | assert.Equal(c, wiz.clusters["table name"]) 63 | } 64 | 65 | func TestCreateCluster(t *testing.T) { 66 | assert := assert.New(t) 67 | 68 | wiz := NewWizard() 69 | c := 
wiz.CreateCluster("table name", "db-master") 70 | assert.NotNil(c) 71 | assert.NotNil(c.master) 72 | assert.Empty(c.slaves) 73 | assert.Equal("db-master", c.master.db) 74 | } 75 | 76 | func TestCreateShardCluster(t *testing.T) { 77 | assert := assert.New(t) 78 | 79 | var s *ShardCluster 80 | wiz := NewWizard() 81 | 82 | var slotsize int64 = 99 83 | s = wiz.CreateShardCluster("table name", slotsize) 84 | assert.NotNil(s) 85 | assert.Empty(s.List) 86 | assert.Equal(int64(99), s.slotsize) 87 | assert.Equal(s, wiz.clusters["table name"]) 88 | 89 | var slotsizeZero int64 90 | s = wiz.CreateShardCluster("table name2", slotsizeZero) 91 | assert.NotNil(s) 92 | assert.Empty(s.List) 93 | assert.Equal(int64(1), s.slotsize) 94 | assert.Equal(s, wiz.clusters["table name2"]) 95 | 96 | var slotsizeMinus int64 = -99 97 | s = wiz.CreateShardCluster("table name", slotsizeMinus) 98 | assert.NotNil(s) 99 | assert.Empty(s.List) 100 | assert.Equal(int64(1), s.slotsize) 101 | assert.Equal(s, wiz.clusters["table name"]) 102 | } 103 | 104 | func TestSelect(t *testing.T) { 105 | assert := assert.New(t) 106 | wiz := NewWizard() 107 | 108 | // shard test 109 | type myStruct struct { 110 | ID int64 `shard_key:"true"` 111 | } 112 | s := wiz.CreateShardCluster(myStruct{}, 100) 113 | shardSet1 := NewCluster("shard01-master") 114 | shardSet2 := NewCluster("shard02-master") 115 | s.RegisterShard(0, 49, shardSet1) 116 | s.RegisterShard(50, 99, shardSet2) 117 | 118 | assert.Equal(shardSet1, wiz.Select(&myStruct{ID: 1})) 119 | assert.Equal(shardSet1, wiz.Select(&myStruct{ID: 49})) 120 | assert.Equal(shardSet2, wiz.Select(&myStruct{ID: 50})) 121 | assert.Equal(shardSet2, wiz.Select(&myStruct{ID: 99})) 122 | assert.Equal(shardSet1, wiz.Select(&myStruct{ID: 100})) 123 | assert.Equal(shardSet1, wiz.Select(&myStruct{ID: 149})) 124 | assert.Equal(shardSet2, wiz.Select(&myStruct{ID: 150})) 125 | assert.Equal(shardSet2, wiz.Select(&myStruct{ID: 199})) 126 | assert.Equal(shardSet1, wiz.Select(&myStruct{ID: 
200})) 127 | 128 | // standard test 129 | c1 := wiz.CreateCluster("standard table", "db-master") 130 | c2 := wiz.Select("standard table") 131 | assert.Equal(c1, c2) 132 | 133 | c3 := wiz.CreateCluster(myStruct{}, "db-master") 134 | c4 := wiz.Select(&myStruct{ID: 100}) 135 | assert.Equal(c3, c4) 136 | 137 | // error 138 | _ = wiz.CreateShardCluster("shard table", 100) 139 | nilShard := wiz.Select("shard table") 140 | assert.Nil(nilShard, "Select() returns nil for shardcluster when obj does not contain shardkey") 141 | 142 | nilTable := wiz.Select("not registered") 143 | assert.Nil(nilTable, "Select() returns nil when table name is not registered") 144 | } 145 | 146 | func TestSelectByKey(t *testing.T) { 147 | assert := assert.New(t) 148 | 149 | wiz := NewWizard() 150 | c1 := wiz.CreateCluster("standard table", "db-master") 151 | c2 := wiz.SelectByKey("standard table", 1) 152 | c3 := wiz.SelectByKey("standard table", 99) 153 | assert.Equal(c1, c2) 154 | assert.Equal(c1, c3) 155 | 156 | // object test 157 | type myStruct struct { 158 | ID int64 `shard_key:"true"` 159 | } 160 | s1 := wiz.CreateShardCluster(myStruct{}, 100) 161 | shardSet1 := NewCluster("shard01-master") 162 | shardSet2 := NewCluster("shard02-master") 163 | s1.RegisterShard(0, 49, shardSet1) 164 | s1.RegisterShard(50, 99, shardSet2) 165 | 166 | assert.Equal(shardSet1, wiz.SelectByKey(&myStruct{ID: 99}, 1)) 167 | assert.Equal(shardSet1, wiz.SelectByKey(&myStruct{ID: 99}, 49)) 168 | assert.Equal(shardSet2, wiz.SelectByKey(&myStruct{ID: 99}, 50)) 169 | assert.Equal(shardSet2, wiz.SelectByKey(&myStruct{ID: 99}, 99)) 170 | assert.Equal(shardSet1, wiz.SelectByKey(&myStruct{ID: 99}, 100)) 171 | assert.Equal(shardSet1, wiz.SelectByKey(&myStruct{ID: 99}, 149)) 172 | assert.Equal(shardSet2, wiz.SelectByKey(&myStruct{ID: 99}, 150)) 173 | assert.Equal(shardSet2, wiz.SelectByKey(&myStruct{ID: 99}, 199)) 174 | assert.Equal(shardSet1, wiz.SelectByKey(&myStruct{ID: 99}, 200)) 175 | 176 | // non object test 177 | s2 
:= wiz.CreateShardCluster("shard table", 100) 178 | s2.RegisterShard(0, 49, NewCluster("x01-master")) 179 | s2.RegisterShard(50, 99, NewCluster("x02-master")) 180 | c4 := s2.SelectByKey(5000) 181 | c5 := wiz.SelectByKey("shard table", 5000) 182 | // c4 must be non-nil, otherwise the equality below would pass vacuously (nil == nil) 183 | assert.NotNil(c4) 184 | assert.Equal(c4, c5, "SelectByKey() selects a shard by key for a shardcluster mapped by table name") 185 | 186 | // error 187 | nilTable := wiz.SelectByKey("not registered", 99) 188 | assert.Nil(nilTable, "SelectByKey() returns nil when table name is not registered") 189 | } 190 | --------------------------------------------------------------------------------