├── .gitignore ├── README.md ├── populate.sh ├── sharding.toml ├── launch.sh ├── config ├── config_test.go └── config.go ├── main.go ├── replication └── replication.go ├── web ├── web.go └── web_test.go ├── cmd └── bench │ └── main.go └── db ├── db_test.go └── db.go /.gitignore: -------------------------------------------------------------------------------- 1 | *.db 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Distrib KV 2 | Sources for the "distributed key-value database series" on YouTube: https://www.youtube.com/playlist?list=PLWwSgbaBp9XrMkjEhmTIC37WX2JfwZp7I 3 | -------------------------------------------------------------------------------- /populate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for shard in 127.0.0.2:8080; do 4 | echo $shard 5 | for i in {1..10000}; do 6 | curl "http://$shard/set?key=key-$RANDOM&value=value-$RANDOM" 7 | done 8 | done 9 | -------------------------------------------------------------------------------- /sharding.toml: -------------------------------------------------------------------------------- 1 | [[shards]] 2 | name = "Moscow" 3 | idx = 0 4 | address = "127.0.0.2:8080" 5 | replicas = ["127.0.0.22:8080"] 6 | 7 | [[shards]] 8 | name = "Minsk" 9 | idx = 1 10 | address = "127.0.0.3:8080" 11 | replicas = ["127.0.0.33:8080"] 12 | 13 | [[shards]] 14 | name = "Kiev" 15 | idx = 2 16 | address = "127.0.0.4:8080" 17 | replicas = ["127.0.0.44:8080"] 18 | 19 | [[shards]] 20 | name = "Tashkent" 21 | idx = 3 22 | address = "127.0.0.5:8080" 23 | replicas = ["127.0.0.55:8080"] 24 | -------------------------------------------------------------------------------- /launch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | trap 'killall distribkv' SIGINT 5 | 6 | cd $(dirname 
$0) 7 | 8 | killall distribkv || true 9 | sleep 0.1 10 | 11 | go install -v 12 | 13 | distribkv -db-location=moscow.db -http-addr=127.0.0.2:8080 -config-file=sharding.toml -shard=Moscow & 14 | distribkv -db-location=moscow-r.db -http-addr=127.0.0.22:8080 -config-file=sharding.toml -shard=Moscow -replica & 15 | 16 | distribkv -db-location=minsk.db -http-addr=127.0.0.3:8080 -config-file=sharding.toml -shard=Minsk & 17 | distribkv -db-location=minsk-r.db -http-addr=127.0.0.33:8080 -config-file=sharding.toml -shard=Minsk -replica & 18 | 19 | distribkv -db-location=kiev.db -http-addr=127.0.0.4:8080 -config-file=sharding.toml -shard=Kiev & 20 | distribkv -db-location=kiev-r.db -http-addr=127.0.0.44:8080 -config-file=sharding.toml -shard=Kiev -replica & 21 | 22 | distribkv -db-location=tashkent.db -http-addr=127.0.0.5:8080 -config-file=sharding.toml -shard=Tashkent & 23 | distribkv -db-location=tashkent-r.db -http-addr=127.0.0.55:8080 -config-file=sharding.toml -shard=Tashkent -replica & 24 | 25 | wait 26 | -------------------------------------------------------------------------------- /config/config_test.go: -------------------------------------------------------------------------------- 1 | package config_test 2 | 3 | import ( 4 | "io/ioutil" 5 | "os" 6 | "reflect" 7 | "testing" 8 | 9 | "github.com/YuriyNasretdinov/distribkv/config" 10 | ) 11 | 12 | func createConfig(t *testing.T, contents string) config.Config { 13 | t.Helper() 14 | 15 | f, err := ioutil.TempFile(os.TempDir(), "config.toml") 16 | if err != nil { 17 | t.Fatalf("Couldn't create a temp file: %v", err) 18 | } 19 | defer f.Close() 20 | 21 | name := f.Name() 22 | defer os.Remove(name) 23 | 24 | _, err = f.WriteString(contents) 25 | if err != nil { 26 | t.Fatalf("Could not write the config contents: %v", err) 27 | } 28 | 29 | c, err := config.ParseFile(name) 30 | if err != nil { 31 | t.Fatalf("Could not parse config: %v", err) 32 | } 33 | 34 | return c 35 | } 36 | 37 | func TestConfigParse(t *testing.T) { 38 
| got := createConfig(t, `[[shards]] 39 | name = "Moscow" 40 | idx = 0 41 | address = "localhost:8080"`) 42 | 43 | want := config.Config{ 44 | Shards: []config.Shard{ 45 | { 46 | Name: "Moscow", 47 | Idx: 0, 48 | Address: "localhost:8080", 49 | }, 50 | }, 51 | } 52 | 53 | if !reflect.DeepEqual(got, want) { 54 | t.Errorf("The config does match: got: %#v, want: %#v", got, want) 55 | } 56 | } 57 | 58 | func TestParseShards(t *testing.T) { 59 | c := createConfig(t, ` 60 | [[shards]] 61 | name = "Moscow" 62 | idx = 0 63 | address = "localhost:8080" 64 | [[shards]] 65 | name = "Minsk" 66 | idx = 1 67 | address = "localhost:8081"`) 68 | 69 | got, err := config.ParseShards(c.Shards, "Minsk") 70 | if err != nil { 71 | t.Fatalf("Could not parse shards %#v: %v", c.Shards, err) 72 | } 73 | 74 | want := &config.Shards{ 75 | Count: 2, 76 | CurIdx: 1, 77 | Addrs: map[int]string{ 78 | 0: "localhost:8080", 79 | 1: "localhost:8081", 80 | }, 81 | } 82 | 83 | if !reflect.DeepEqual(got, want) { 84 | t.Errorf("The shards config does match: got: %#v, want: %#v", got, want) 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /config/config.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "fmt" 5 | "hash/fnv" 6 | 7 | "github.com/BurntSushi/toml" 8 | ) 9 | 10 | // Shard describes a shard that holds the appropriate set of keys. 11 | // Each shard has unique set of keys. 12 | type Shard struct { 13 | Name string 14 | Idx int 15 | Address string 16 | } 17 | 18 | // Config describes the sharding config. 19 | type Config struct { 20 | Shards []Shard 21 | } 22 | 23 | // ParseFile parses the config and returns it upon success. 
24 | func ParseFile(filename string) (Config, error) { 25 | var c Config 26 | if _, err := toml.DecodeFile(filename, &c); err != nil { 27 | return Config{}, err 28 | } 29 | return c, nil 30 | } 31 | 32 | // Shards represents an easier-to-use representation of 33 | // the sharding config: the shards count, current index and 34 | // the addresses of all other shards too. 35 | type Shards struct { 36 | Count int 37 | CurIdx int 38 | Addrs map[int]string 39 | } 40 | 41 | // ParseShards converts and verifies the list of shards 42 | // specified in the config into a form that can be used 43 | // for routing. 44 | func ParseShards(shards []Shard, curShardName string) (*Shards, error) { 45 | shardCount := len(shards) 46 | shardIdx := -1 47 | addrs := make(map[int]string) 48 | 49 | for _, s := range shards { 50 | if _, ok := addrs[s.Idx]; ok { 51 | return nil, fmt.Errorf("duplicate shard index: %d", s.Idx) 52 | } 53 | 54 | addrs[s.Idx] = s.Address 55 | if s.Name == curShardName { 56 | shardIdx = s.Idx 57 | } 58 | } 59 | 60 | for i := 0; i < shardCount; i++ { 61 | if _, ok := addrs[i]; !ok { 62 | return nil, fmt.Errorf("shard %d is not found", i) 63 | } 64 | } 65 | 66 | if shardIdx < 0 { 67 | return nil, fmt.Errorf("shard %q was not found", curShardName) 68 | } 69 | 70 | return &Shards{ 71 | Addrs: addrs, 72 | Count: shardCount, 73 | CurIdx: shardIdx, 74 | }, nil 75 | } 76 | 77 | // Index returns the shard number for the corresponding key. 
78 | func (s *Shards) Index(key string) int { 79 | h := fnv.New64() 80 | h.Write([]byte(key)) 81 | return int(h.Sum64() % uint64(s.Count)) 82 | } 83 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "log" 6 | "net/http" 7 | 8 | "github.com/YuriyNasretdinov/distribkv/config" 9 | "github.com/YuriyNasretdinov/distribkv/db" 10 | "github.com/YuriyNasretdinov/distribkv/replication" 11 | "github.com/YuriyNasretdinov/distribkv/web" 12 | ) 13 | 14 | var ( 15 | dbLocation = flag.String("db-location", "", "The path to the bolt db database") 16 | httpAddr = flag.String("http-addr", "127.0.0.1:8080", "HTTP host and port") 17 | configFile = flag.String("config-file", "sharding.toml", "Config file for static sharding") 18 | shard = flag.String("shard", "", "The name of the shard for the data") 19 | replica = flag.Bool("replica", false, "Whether or not run as a read-only replica") 20 | ) 21 | 22 | func parseFlags() { 23 | flag.Parse() 24 | 25 | if *dbLocation == "" { 26 | log.Fatalf("Must provide db-location") 27 | } 28 | 29 | if *shard == "" { 30 | log.Fatalf("Must provide shard") 31 | } 32 | } 33 | 34 | func main() { 35 | parseFlags() 36 | 37 | c, err := config.ParseFile(*configFile) 38 | if err != nil { 39 | log.Fatalf("Error parsing config %q: %v", *configFile, err) 40 | } 41 | 42 | shards, err := config.ParseShards(c.Shards, *shard) 43 | if err != nil { 44 | log.Fatalf("Error parsing shards config: %v", err) 45 | } 46 | 47 | log.Printf("Shard count is %d, current shard: %d", shards.Count, shards.CurIdx) 48 | 49 | db, close, err := db.NewDatabase(*dbLocation, *replica) 50 | if err != nil { 51 | log.Fatalf("Error creating %q: %v", *dbLocation, err) 52 | } 53 | defer close() 54 | 55 | if *replica { 56 | leaderAddr, ok := shards.Addrs[shards.CurIdx] 57 | if !ok { 58 | log.Fatalf("Could not find address for 
// NextKeyValue contains the response for GetNextKeyForReplication.
//
// It is the JSON document the replication client decodes from the leader's
// /next-replication-key endpoint.
type NextKeyValue struct {
	// Key is the next key waiting to be replicated. The replication client
	// treats an empty Key as "nothing to replicate right now".
	Key string
	// Value is the value stored under Key.
	Value string
	// Err carries the error hit while looking up the next key, if any.
	// NOTE(review): encoding/json presumably cannot decode a non-null JSON
	// value into a field of interface type error, so a non-nil Err may
	// surface on the client as a decode failure rather than as this value.
	// Confirm before relying on it.
	Err error
}
29 | func ClientLoop(db *db.Database, leaderAddr string) { 30 | c := &client{db: db, leaderAddr: leaderAddr} 31 | for { 32 | present, err := c.loop() 33 | if err != nil { 34 | log.Printf("Loop error: %v", err) 35 | time.Sleep(time.Second) 36 | continue 37 | } 38 | 39 | if !present { 40 | time.Sleep(time.Millisecond * 100) 41 | } 42 | } 43 | } 44 | 45 | func (c *client) loop() (present bool, err error) { 46 | resp, err := http.Get("http://" + c.leaderAddr + "/next-replication-key") 47 | if err != nil { 48 | return false, err 49 | } 50 | 51 | var res NextKeyValue 52 | if err := json.NewDecoder(resp.Body).Decode(&res); err != nil { 53 | return false, err 54 | } 55 | defer resp.Body.Close() 56 | 57 | if res.Err != nil { 58 | return false, err 59 | } 60 | 61 | if res.Key == "" { 62 | return false, nil 63 | } 64 | 65 | if err := c.db.SetKeyOnReplica(res.Key, []byte(res.Value)); err != nil { 66 | return false, err 67 | } 68 | 69 | if err := c.deleteFromReplicationQueue(res.Key, res.Value); err != nil { 70 | log.Printf("DeleteKeyFromReplication failed: %v", err) 71 | } 72 | 73 | return true, nil 74 | } 75 | 76 | func (c *client) deleteFromReplicationQueue(key, value string) error { 77 | u := url.Values{} 78 | u.Set("key", key) 79 | u.Set("value", value) 80 | 81 | log.Printf("Deleting key=%q, value=%q from replication queue on %q", key, value, c.leaderAddr) 82 | 83 | resp, err := http.Get("http://" + c.leaderAddr + "/delete-replication-key?" 
+ u.Encode()) 84 | if err != nil { 85 | return err 86 | } 87 | defer resp.Body.Close() 88 | 89 | result, err := ioutil.ReadAll(resp.Body) 90 | if err != nil { 91 | return err 92 | } 93 | 94 | if !bytes.Equal(result, []byte("ok")) { 95 | return errors.New(string(result)) 96 | } 97 | 98 | return nil 99 | } 100 | -------------------------------------------------------------------------------- /web/web.go: -------------------------------------------------------------------------------- 1 | package web 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "io" 7 | "net/http" 8 | 9 | "github.com/YuriyNasretdinov/distribkv/config" 10 | "github.com/YuriyNasretdinov/distribkv/db" 11 | "github.com/YuriyNasretdinov/distribkv/replication" 12 | ) 13 | 14 | // Server contains HTTP method handlers to be used for the database. 15 | type Server struct { 16 | db *db.Database 17 | shards *config.Shards 18 | } 19 | 20 | // NewServer creates a new instance with HTTP handlers to be used to get and set values. 21 | func NewServer(db *db.Database, s *config.Shards) *Server { 22 | return &Server{ 23 | db: db, 24 | shards: s, 25 | } 26 | } 27 | 28 | func (s *Server) redirect(shard int, w http.ResponseWriter, r *http.Request) { 29 | url := "http://" + s.shards.Addrs[shard] + r.RequestURI 30 | fmt.Fprintf(w, "redirecting from shard %d to shard %d (%q)\n", s.shards.CurIdx, shard, url) 31 | 32 | resp, err := http.Get(url) 33 | if err != nil { 34 | w.WriteHeader(500) 35 | fmt.Fprintf(w, "Error redirecting the request: %v", err) 36 | return 37 | } 38 | defer resp.Body.Close() 39 | 40 | io.Copy(w, resp.Body) 41 | } 42 | 43 | // GetHandler handles read requests from the database. 
44 | func (s *Server) GetHandler(w http.ResponseWriter, r *http.Request) { 45 | r.ParseForm() 46 | key := r.Form.Get("key") 47 | 48 | shard := s.shards.Index(key) 49 | 50 | if shard != s.shards.CurIdx { 51 | s.redirect(shard, w, r) 52 | return 53 | } 54 | 55 | value, err := s.db.GetKey(key) 56 | 57 | fmt.Fprintf(w, "Shard = %d, current shard = %d, addr = %q, Value = %q, error = %v", shard, s.shards.CurIdx, s.shards.Addrs[shard], value, err) 58 | } 59 | 60 | // SetHandler handles write requests from the database. 61 | func (s *Server) SetHandler(w http.ResponseWriter, r *http.Request) { 62 | r.ParseForm() 63 | key := r.Form.Get("key") 64 | value := r.Form.Get("value") 65 | 66 | shard := s.shards.Index(key) 67 | if shard != s.shards.CurIdx { 68 | s.redirect(shard, w, r) 69 | return 70 | } 71 | 72 | err := s.db.SetKey(key, []byte(value)) 73 | fmt.Fprintf(w, "Error = %v, shardIdx = %d, current shard = %d", err, shard, s.shards.CurIdx) 74 | } 75 | 76 | // DeleteExtraKeysHandler deletes keys that don't belong to the current shard. 77 | func (s *Server) DeleteExtraKeysHandler(w http.ResponseWriter, r *http.Request) { 78 | fmt.Fprintf(w, "Error = %v", s.db.DeleteExtraKeys(func(key string) bool { 79 | return s.shards.Index(key) != s.shards.CurIdx 80 | })) 81 | } 82 | 83 | // GetNextKeyForReplication returns the next key for replication. 84 | func (s *Server) GetNextKeyForReplication(w http.ResponseWriter, r *http.Request) { 85 | enc := json.NewEncoder(w) 86 | k, v, err := s.db.GetNextKeyForReplication() 87 | enc.Encode(&replication.NextKeyValue{ 88 | Key: string(k), 89 | Value: string(v), 90 | Err: err, 91 | }) 92 | } 93 | 94 | // DeleteReplicationKey deletes the key from replica queue. 
// benchmark calls fn iter times in sequence, prints the average, fastest and
// slowest call time together with the achieved queries per second, and
// returns that QPS figure plus every string fn returned, in call order.
//
// When iter is zero or negative nothing is run and (0, nil) is returned.
func benchmark(name string, iter int, fn func() string) (qps float64, strs []string) {
	if iter <= 0 {
		// Dividing the elapsed time by zero iterations would panic.
		fmt.Printf("Func %s was not run: no iterations requested\n", name)
		return 0, nil
	}

	var slowest, fastest time.Duration
	strs = make([]string, 0, iter)

	start := time.Now()
	for i := 0; i < iter; i++ {
		iterStart := time.Now()
		strs = append(strs, fn())
		iterTime := time.Since(iterStart)

		if iterTime > slowest {
			slowest = iterTime
		}
		// Seed the minimum from the first call rather than a fixed ceiling.
		if i == 0 || iterTime < fastest {
			fastest = iterTime
		}
	}
	// Sample the clock once so the average and the QPS describe the same span.
	elapsed := time.Since(start)

	avg := elapsed / time.Duration(iter)
	qps = float64(iter) / elapsed.Seconds()
	fmt.Printf("Func %s took %s avg, %.1f QPS, %s max, %s min\n", name, avg, qps, slowest, fastest)

	return qps, strs
}
103 | mu.Unlock() 104 | 105 | wg.Done() 106 | }() 107 | } 108 | 109 | wg.Wait() 110 | 111 | log.Printf("Write total QPS: %.1f, set %d keys", totalQPS, len(allKeys)) 112 | 113 | return allKeys 114 | } 115 | 116 | func benchmarkRead(allKeys []string) { 117 | var totalQPS float64 118 | var mu sync.Mutex 119 | var wg sync.WaitGroup 120 | 121 | for i := 0; i < *concurrency; i++ { 122 | wg.Add(1) 123 | go func() { 124 | qps, _ := benchmark("read", *readIterations, func() string { return readRand(allKeys) }) 125 | mu.Lock() 126 | totalQPS += qps 127 | mu.Unlock() 128 | 129 | wg.Done() 130 | }() 131 | } 132 | 133 | wg.Wait() 134 | 135 | log.Printf("Read total QPS: %.1f", totalQPS) 136 | } 137 | 138 | func main() { 139 | rand.Seed(time.Now().UnixNano()) 140 | flag.Parse() 141 | 142 | fmt.Printf("Running with %d iterations and concurrency level %d\n", *iterations, *concurrency) 143 | 144 | allKeys := benchmarkWrite() 145 | 146 | go benchmarkWrite() 147 | benchmarkRead(allKeys) 148 | } 149 | -------------------------------------------------------------------------------- /db/db_test.go: -------------------------------------------------------------------------------- 1 | package db_test 2 | 3 | import ( 4 | "bytes" 5 | "io/ioutil" 6 | "os" 7 | "testing" 8 | 9 | "github.com/YuriyNasretdinov/distribkv/db" 10 | ) 11 | 12 | func createTempDb(t *testing.T, readOnly bool) *db.Database { 13 | t.Helper() 14 | 15 | f, err := ioutil.TempFile(os.TempDir(), "kvdb") 16 | if err != nil { 17 | t.Fatalf("Could not create temp file: %v", err) 18 | } 19 | name := f.Name() 20 | f.Close() 21 | t.Cleanup(func() { os.Remove(name) }) 22 | 23 | db, closeFunc, err := db.NewDatabase(name, readOnly) 24 | if err != nil { 25 | t.Fatalf("Could not create a new database: %v", err) 26 | } 27 | t.Cleanup(func() { closeFunc() }) 28 | 29 | return db 30 | } 31 | 32 | func TestGetSet(t *testing.T) { 33 | db := createTempDb(t, false) 34 | 35 | if err := db.SetKey("party", []byte("Great")); err != nil { 36 | 
t.Fatalf("Could not write key: %v", err) 37 | } 38 | 39 | value, err := db.GetKey("party") 40 | if err != nil { 41 | t.Fatalf(`Could not get the key "party": %v`, err) 42 | } 43 | 44 | if !bytes.Equal(value, []byte("Great")) { 45 | t.Errorf(`Unexpected value for key "party": got %q, want %q`, value, "Great") 46 | } 47 | 48 | k, v, err := db.GetNextKeyForReplication() 49 | if err != nil { 50 | t.Fatalf(`Unexpected error for GetNextKeyForReplication(): %v`, err) 51 | } 52 | 53 | if !bytes.Equal(k, []byte("party")) || !bytes.Equal(v, []byte("Great")) { 54 | t.Errorf(`GetNextKeyForReplication(): got %q, %q; want %q, %q`, k, v, "party", "Great") 55 | } 56 | } 57 | 58 | func TestDeleteReplicationKey(t *testing.T) { 59 | db := createTempDb(t, false) 60 | 61 | setKey(t, db, "party", "Great") 62 | 63 | k, v, err := db.GetNextKeyForReplication() 64 | if err != nil { 65 | t.Fatalf(`Unexpected error for GetNextKeyForReplication(): %v`, err) 66 | } 67 | 68 | if !bytes.Equal(k, []byte("party")) || !bytes.Equal(v, []byte("Great")) { 69 | t.Errorf(`GetNextKeyForReplication(): got %q, %q; want %q, %q`, k, v, "party", "Great") 70 | } 71 | 72 | if err := db.DeleteReplicationKey([]byte("party"), []byte("Bad")); err == nil { 73 | t.Fatalf(`DeleteReplicationKey("party", "Bad"): got nil error, want non-nil error`) 74 | } 75 | 76 | if err := db.DeleteReplicationKey([]byte("party"), []byte("Great")); err != nil { 77 | t.Fatalf(`DeleteReplicationKey("party", "Great"): got %q, want nil error`, err) 78 | } 79 | 80 | k, v, err = db.GetNextKeyForReplication() 81 | if err != nil { 82 | t.Fatalf(`Unexpected error for GetNextKeyForReplication(): %v`, err) 83 | } 84 | 85 | if k != nil || v != nil { 86 | t.Errorf(`GetNextKeyForReplication(): got %v, %v; want nil, nil`, k, v) 87 | } 88 | } 89 | 90 | func TestSetReadOnly(t *testing.T) { 91 | db := createTempDb(t, true) 92 | 93 | if err := db.SetKey("party", []byte("Bad")); err == nil { 94 | t.Fatalf("SetKey(%q, %q): got nil error, want non-nil error", 
"party", []byte("Bad")) 95 | } 96 | } 97 | 98 | func setKey(t *testing.T, d *db.Database, key, value string) { 99 | t.Helper() 100 | 101 | if err := d.SetKey(key, []byte(value)); err != nil { 102 | t.Fatalf("SetKey(%q, %q) failed: %v", key, value, err) 103 | } 104 | } 105 | 106 | func getKey(t *testing.T, d *db.Database, key string) string { 107 | t.Helper() 108 | 109 | value, err := d.GetKey(key) 110 | if err != nil { 111 | t.Fatalf("GetKey(%q) failed: %v", key, err) 112 | } 113 | 114 | return string(value) 115 | } 116 | 117 | func TestDeleteExtraKeys(t *testing.T) { 118 | db := createTempDb(t, false) 119 | 120 | setKey(t, db, "party", "Great") 121 | setKey(t, db, "us", "CapitalistPigs") 122 | 123 | if err := db.DeleteExtraKeys(func(name string) bool { return name == "us" }); err != nil { 124 | t.Fatalf("Could not delete extra keys: %v", err) 125 | } 126 | 127 | if value := getKey(t, db, "party"); value != "Great" { 128 | t.Errorf(`Unexpected value for key "party": got %q, want %q`, value, "Great") 129 | } 130 | 131 | if value := getKey(t, db, "us"); value != "" { 132 | t.Errorf(`Unexpected value for key "us": got %q, want %q`, value, "") 133 | } 134 | } 135 | -------------------------------------------------------------------------------- /web/web_test.go: -------------------------------------------------------------------------------- 1 | package web_test 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io/ioutil" 7 | "log" 8 | "net/http" 9 | "net/http/httptest" 10 | "os" 11 | "strings" 12 | "testing" 13 | 14 | "github.com/YuriyNasretdinov/distribkv/config" 15 | "github.com/YuriyNasretdinov/distribkv/web" 16 | 17 | "github.com/YuriyNasretdinov/distribkv/db" 18 | ) 19 | 20 | func createShardDb(t *testing.T, idx int) *db.Database { 21 | t.Helper() 22 | 23 | tmpFile, err := ioutil.TempFile(os.TempDir(), fmt.Sprintf("db%d", idx)) 24 | if err != nil { 25 | t.Fatalf("Could not create a temp db %d: %v", idx, err) 26 | } 27 | 28 | tmpFile.Close() 29 | 30 | name := 
tmpFile.Name() 31 | t.Cleanup(func() { os.Remove(name) }) 32 | 33 | db, closeFunc, err := db.NewDatabase(name, false) 34 | if err != nil { 35 | t.Fatalf("Could not create new database %q: %v", name, err) 36 | } 37 | t.Cleanup(func() { closeFunc() }) 38 | 39 | return db 40 | } 41 | 42 | func createShardServer(t *testing.T, idx int, addrs map[int]string) (*db.Database, *web.Server) { 43 | t.Helper() 44 | 45 | db := createShardDb(t, idx) 46 | 47 | cfg := &config.Shards{ 48 | Addrs: addrs, 49 | Count: len(addrs), 50 | CurIdx: idx, 51 | } 52 | 53 | s := web.NewServer(db, cfg) 54 | return db, s 55 | } 56 | 57 | func TestWebServer(t *testing.T) { 58 | var ts1GetHandler, ts1SetHandler func(w http.ResponseWriter, r *http.Request) 59 | var ts2GetHandler, ts2SetHandler func(w http.ResponseWriter, r *http.Request) 60 | 61 | ts1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 62 | if strings.HasPrefix(r.RequestURI, "/get") { 63 | ts1GetHandler(w, r) 64 | } else if strings.HasPrefix(r.RequestURI, "/set") { 65 | ts1SetHandler(w, r) 66 | } 67 | })) 68 | defer ts1.Close() 69 | 70 | ts2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 71 | if strings.HasPrefix(r.RequestURI, "/get") { 72 | ts2GetHandler(w, r) 73 | } else if strings.HasPrefix(r.RequestURI, "/set") { 74 | ts2SetHandler(w, r) 75 | } 76 | })) 77 | defer ts2.Close() 78 | 79 | addrs := map[int]string{ 80 | 0: strings.TrimPrefix(ts1.URL, "http://"), 81 | 1: strings.TrimPrefix(ts2.URL, "http://"), 82 | } 83 | 84 | db1, web1 := createShardServer(t, 0, addrs) 85 | db2, web2 := createShardServer(t, 1, addrs) 86 | 87 | // Calculated manually and depends on the sharding function. 
88 | keys := map[string]int{ 89 | "Soviet": 1, 90 | "USA": 0, 91 | } 92 | 93 | ts1GetHandler = web1.GetHandler 94 | ts1SetHandler = web1.SetHandler 95 | ts2GetHandler = web2.GetHandler 96 | ts2SetHandler = web2.SetHandler 97 | 98 | for key := range keys { 99 | // Send all to first shard to test redirects. 100 | _, err := http.Get(fmt.Sprintf(ts1.URL+"/set?key=%s&value=value-%s", key, key)) 101 | if err != nil { 102 | t.Fatalf("Could not set the key %q: %v", key, err) 103 | } 104 | } 105 | 106 | for key := range keys { 107 | // Send all to first shard to test redirects. 108 | resp, err := http.Get(fmt.Sprintf(ts1.URL+"/get?key=%s", key)) 109 | if err != nil { 110 | t.Fatalf("Get key %q error: %v", key, err) 111 | } 112 | contents, err := ioutil.ReadAll(resp.Body) 113 | if err != nil { 114 | t.Fatalf("Could read contents of the key %q: %v", key, err) 115 | } 116 | 117 | want := []byte("value-" + key) 118 | if !bytes.Contains(contents, want) { 119 | t.Errorf("Unexpected contents of the key %q: got %q, want the result to contain %q", key, contents, want) 120 | } 121 | 122 | log.Printf("Contents of key %q: %s", key, contents) 123 | } 124 | 125 | value1, err := db1.GetKey("USA") 126 | if err != nil { 127 | t.Fatalf("USA key error: %v", err) 128 | } 129 | 130 | want1 := "value-USA" 131 | if !bytes.Equal(value1, []byte(want1)) { 132 | t.Errorf("Unexpected value of USA key: got %q, want %q", value1, want1) 133 | } 134 | 135 | value2, err := db2.GetKey("Soviet") 136 | if err != nil { 137 | t.Fatalf("Soviet key error: %v", err) 138 | } 139 | 140 | want2 := "value-Soviet" 141 | if !bytes.Equal(value2, []byte(want2)) { 142 | t.Errorf("Unexpected value of Soviet key: got %q, want %q", value2, want2) 143 | } 144 | } 145 | -------------------------------------------------------------------------------- /db/db.go: -------------------------------------------------------------------------------- 1 | package db 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | "fmt" 7 | 8 | bolt 
"go.etcd.io/bbolt" 9 | ) 10 | 11 | var defaultBucket = []byte("default") 12 | var replicaBucket = []byte("replication") 13 | 14 | // Database is an open bolt database. 15 | type Database struct { 16 | db *bolt.DB 17 | readOnly bool 18 | } 19 | 20 | // NewDatabase returns an instance of a database that we can work with. 21 | func NewDatabase(dbPath string, readOnly bool) (db *Database, closeFunc func() error, err error) { 22 | boltDb, err := bolt.Open(dbPath, 0600, nil) 23 | if err != nil { 24 | return nil, nil, err 25 | } 26 | 27 | db = &Database{db: boltDb, readOnly: readOnly} 28 | closeFunc = boltDb.Close 29 | 30 | if err := db.createBuckets(); err != nil { 31 | closeFunc() 32 | return nil, nil, fmt.Errorf("creating default bucket: %w", err) 33 | } 34 | 35 | return db, closeFunc, nil 36 | } 37 | 38 | func (d *Database) createBuckets() error { 39 | return d.db.Update(func(tx *bolt.Tx) error { 40 | if _, err := tx.CreateBucketIfNotExists(defaultBucket); err != nil { 41 | return err 42 | } 43 | if _, err := tx.CreateBucketIfNotExists(replicaBucket); err != nil { 44 | return err 45 | } 46 | return nil 47 | }) 48 | } 49 | 50 | // SetKey sets the key to the requested value into the default database or returns an error. 51 | func (d *Database) SetKey(key string, value []byte) error { 52 | if d.readOnly { 53 | return errors.New("read-only mode") 54 | } 55 | 56 | return d.db.Update(func(tx *bolt.Tx) error { 57 | if err := tx.Bucket(defaultBucket).Put([]byte(key), value); err != nil { 58 | return err 59 | } 60 | 61 | return tx.Bucket(replicaBucket).Put([]byte(key), value) 62 | }) 63 | } 64 | 65 | // SetKeyOnReplica sets the key to the requested value into the default database and does not write 66 | // to the replication queue. 67 | // This method is intended to be used only on replicas. 
68 | func (d *Database) SetKeyOnReplica(key string, value []byte) error { 69 | return d.db.Update(func(tx *bolt.Tx) error { 70 | return tx.Bucket(defaultBucket).Put([]byte(key), value) 71 | }) 72 | } 73 | 74 | func copyByteSlice(b []byte) []byte { 75 | if b == nil { 76 | return nil 77 | } 78 | res := make([]byte, len(b)) 79 | copy(res, b) 80 | return res 81 | } 82 | 83 | // GetNextKeyForReplication returns the key and value for the keys that have 84 | // changed and have not yet been applied to replicas. 85 | // If there are no new keys, nil key and value will be returned. 86 | func (d *Database) GetNextKeyForReplication() (key, value []byte, err error) { 87 | err = d.db.View(func(tx *bolt.Tx) error { 88 | b := tx.Bucket(replicaBucket) 89 | k, v := b.Cursor().First() 90 | key = copyByteSlice(k) 91 | value = copyByteSlice(v) 92 | return nil 93 | }) 94 | 95 | if err != nil { 96 | return nil, nil, err 97 | } 98 | 99 | return key, value, nil 100 | } 101 | 102 | // DeleteReplicationKey deletes the key from the replication queue 103 | // if the value matches the contents or if the key is already absent. 104 | func (d *Database) DeleteReplicationKey(key, value []byte) (err error) { 105 | return d.db.Update(func(tx *bolt.Tx) error { 106 | b := tx.Bucket(replicaBucket) 107 | 108 | v := b.Get(key) 109 | if v == nil { 110 | return errors.New("key does not exist") 111 | } 112 | 113 | if !bytes.Equal(v, value) { 114 | return errors.New("value does not match") 115 | } 116 | 117 | return b.Delete(key) 118 | }) 119 | } 120 | 121 | // GetKey get the value of the requested from a default database. 
122 | func (d *Database) GetKey(key string) ([]byte, error) { 123 | var result []byte 124 | err := d.db.View(func(tx *bolt.Tx) error { 125 | b := tx.Bucket(defaultBucket) 126 | result = copyByteSlice(b.Get([]byte(key))) 127 | return nil 128 | }) 129 | 130 | if err == nil { 131 | return result, nil 132 | } 133 | return nil, err 134 | } 135 | 136 | // DeleteExtraKeys deletes the keys that do not belong to this shard. 137 | func (d *Database) DeleteExtraKeys(isExtra func(string) bool) error { 138 | var keys []string 139 | 140 | err := d.db.View(func(tx *bolt.Tx) error { 141 | b := tx.Bucket(defaultBucket) 142 | return b.ForEach(func(k, v []byte) error { 143 | ks := string(k) 144 | if isExtra(ks) { 145 | keys = append(keys, ks) 146 | } 147 | return nil 148 | }) 149 | }) 150 | 151 | if err != nil { 152 | return err 153 | } 154 | 155 | return d.db.Update(func(tx *bolt.Tx) error { 156 | b := tx.Bucket(defaultBucket) 157 | 158 | for _, k := range keys { 159 | if err := b.Delete([]byte(k)); err != nil { 160 | return err 161 | } 162 | } 163 | return nil 164 | }) 165 | } 166 | --------------------------------------------------------------------------------