├── simulator ├── doc.go ├── switcher_test.go ├── conn_mat_test.go ├── conn_mat.go ├── switcher.go ├── events_test.go ├── network_test.go ├── events.go └── network.go ├── go.mod ├── collcomm ├── allreduce │ ├── naive_test.go │ ├── tree_test.go │ ├── stream_test.go │ ├── allreduce.go │ ├── naive.go │ ├── tree.go │ ├── tester.go │ ├── bench_allreduce │ │ └── main.go │ └── stream.go ├── reduce_fn.go └── comms.go ├── go.sum ├── disthash ├── consistent_test.go ├── disthash.go ├── hash.go ├── consistent.go └── tester.go ├── raft ├── state_machines.go ├── raft.go ├── log.go ├── candidate.go ├── client.go ├── messages.go ├── follower.go ├── leader.go └── raft_test.go ├── paxos ├── basic_paxos_test.go └── basic_paxos.go └── README.md /simulator/doc.go: -------------------------------------------------------------------------------- 1 | // Package simulator simulates a computer network. 2 | package simulator 3 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/unixpickle/dist-sys 2 | 3 | go 1.19 4 | 5 | require github.com/unixpickle/essentials v1.3.0 6 | 7 | require github.com/google/uuid v1.3.1 8 | -------------------------------------------------------------------------------- /collcomm/allreduce/naive_test.go: -------------------------------------------------------------------------------- 1 | package allreduce 2 | 3 | import "testing" 4 | 5 | func TestNaiveAllreducer(t *testing.T) { 6 | RunAllreducerTests(t, NaiveAllreducer{}) 7 | } 8 | -------------------------------------------------------------------------------- /collcomm/allreduce/tree_test.go: -------------------------------------------------------------------------------- 1 | package allreduce 2 | 3 | import "testing" 4 | 5 | func TestTreeAllreducer(t *testing.T) { 6 | RunAllreducerTests(t, TreeAllreducer{}) 7 | } 8 | -------------------------------------------------------------------------------- /collcomm/allreduce/stream_test.go: -------------------------------------------------------------------------------- 1 | package allreduce 2 | 3 | import "testing" 4 | 5 | func TestStreamAllreducer(t *testing.T) { 6 | RunAllreducerTests(t, StreamAllreducer{}) 7 | } 8 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4= 2 | github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 3 | github.com/unixpickle/essentials v1.3.0 h1:H258Z5Uo1pVzFjxD2rwFWzHPN3s0J0jLs5kuxTRSfCs= 4 | github.com/unixpickle/essentials v1.3.0/go.mod h1:dQ1idvqrgrDgub3mfckQm7osVPzT3u9rB6NK/LEhmtQ= 5 | -------------------------------------------------------------------------------- /disthash/consistent_test.go: -------------------------------------------------------------------------------- 1 | package disthash 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | ) 7 | 8 | func TestConsistent(t *testing.T) { 9 | for _, size := range []int{1, 10, 200} { 10 | t.Run(fmt.Sprintf("Size%d", size), func(t *testing.T) { 11 | TestDistHash(t, func() DistHash { 12 | return NewConsistent(size) 13 | }) 14 | }) 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /collcomm/allreduce/allreduce.go: -------------------------------------------------------------------------------- 1 | // 
Package allreduce implements algorithms for summing or 2 | // maxing vectors across many different connected Nodes. 3 | package allreduce 4 | 5 | import "github.com/unixpickle/dist-sys/collcomm" 6 | 7 | // Allreducer is an algorithm that can apply a ReduceFn to 8 | // vectors that are distributed across nodes. 9 | // 10 | // It is not safe to call Allreduce() multiple times in a 11 | // row with the same Comms object. 12 | // A new set of ports must be used every time to avoid 13 | // interference. 14 | type Allreducer interface { 15 | Allreduce(c *collcomm.Comms, data []float64, fn collcomm.ReduceFn) []float64 16 | } 17 | -------------------------------------------------------------------------------- /collcomm/allreduce/naive.go: -------------------------------------------------------------------------------- 1 | package allreduce 2 | 3 | import "github.com/unixpickle/dist-sys/collcomm" 4 | 5 | // A NaiveAllreducer sends every gradient from every node 6 | // to every other node. 7 | type NaiveAllreducer struct{} 8 | 9 | // Allreduce runs fn() on all of the nodes' vectors on 10 | // every node. 11 | func (n NaiveAllreducer) Allreduce(c *collcomm.Comms, data []float64, 12 | fn collcomm.ReduceFn) []float64 { 13 | gatheredVecs := make([][]float64, len(c.Ports)) 14 | 15 | c.Bcast(data) 16 | 17 | for i := 0; i < len(gatheredVecs)-1; i++ { 18 | incoming, source := c.Recv() 19 | gatheredVecs[c.IndexOf(source)] = incoming 20 | } 21 | 22 | gatheredVecs[c.Index()] = data 23 | 24 | return fn(c.Handle, gatheredVecs...) 25 | } 26 | -------------------------------------------------------------------------------- /disthash/disthash.go: -------------------------------------------------------------------------------- 1 | // Package disthash implements distributed hashing 2 | // algorithms. 3 | package disthash 4 | 5 | import ( 6 | "github.com/unixpickle/dist-sys/simulator" 7 | ) 8 | 9 | // A DistHash is a distributed hashing algorithm. 10 | // It assigns different keys to different nodes. 11 | type DistHash interface { 12 | // AddSite adds a bucket to the hash table. 13 | AddSite(node *simulator.Node, id Hasher) 14 | 15 | // RemoveSite removes a bucket from the hash table. 16 | RemoveSite(node *simulator.Node) 17 | 18 | // Sites returns all buckets in the hash table. 19 | Sites() []*simulator.Node 20 | 21 | // KeySites returns the buckets where a key is stored. 22 | // 23 | // This must return at least one node, provided there 24 | // are any nodes in the hash table. 25 | KeySites(key Hasher) []*simulator.Node 26 | } 27 | -------------------------------------------------------------------------------- /collcomm/reduce_fn.go: -------------------------------------------------------------------------------- 1 | package collcomm 2 | 3 | import ( 4 | "github.com/unixpickle/dist-sys/simulator" 5 | ) 6 | 7 | // FlopTime is the amount of virtual time it takes to 8 | // perform a single floating-point operation. 9 | const FlopTime = 1e-9 10 | 11 | // A ReduceFn is an operation that reduces many vectors 12 | // into a single vector. 13 | type ReduceFn func(h *simulator.Handle, vecs ...[]float64) []float64 14 | 15 | // Sum is a ReduceFn that computes a vector sum. 16 | func Sum(h *simulator.Handle, vecs ...[]float64) []float64 { 17 | for _, v := range vecs[1:] { 18 | if len(v) != len(vecs[0]) { 19 | panic("mismatching lengths") 20 | } 21 | } 22 | res := make([]float64, len(vecs[0])) 23 | for _, v := range vecs { 24 | for i, x := range v { 25 | res[i] += x 26 | } 27 | } 28 | 29 | // Simulate computation time. 
30 | h.Sleep(FlopTime * float64(len(vecs)*len(vecs[0]))) 31 | 32 | return res 33 | } 34 | -------------------------------------------------------------------------------- /disthash/hash.go: -------------------------------------------------------------------------------- 1 | package disthash 2 | 3 | import ( 4 | "bytes" 5 | "crypto/md5" 6 | "encoding/gob" 7 | "math" 8 | ) 9 | 10 | // A Hasher is an object which can be turned into raw data 11 | // that quasi-uniquely represents the data. 12 | type Hasher interface { 13 | Hash() []byte 14 | } 15 | 16 | // A GobHasher hashes objects by encoding them as gobs. 17 | type GobHasher struct { 18 | V interface{} 19 | } 20 | 21 | // Hash returns the gob-encoded data. 22 | func (g *GobHasher) Hash() []byte { 23 | var buf bytes.Buffer 24 | if err := gob.NewEncoder(&buf).Encode(g.V); err != nil { 25 | panic(err) 26 | } 27 | return buf.Bytes() 28 | } 29 | 30 | // FloatHash hashes a value into a floating point number 31 | // in the range [0, 1). 32 | func FloatHash(data []byte) float64 { 33 | digest := md5.Sum(data) 34 | var number int64 35 | for i, x := range digest[:8] { 36 | number |= (int64(x) << uint(8*i)) 37 | } 38 | return math.Min(math.Nextafter(1, -1), float64(number)/math.Pow(2, 64)) 39 | } 40 | -------------------------------------------------------------------------------- /simulator/switcher_test.go: -------------------------------------------------------------------------------- 1 | package simulator 2 | 3 | import ( 4 | "math" 5 | "testing" 6 | ) 7 | 8 | func TestGreedyDropSwitcher(t *testing.T) { 9 | switcher := &GreedyDropSwitcher{ 10 | SendRates: []float64{1.0, 2.0, 3.0}, 11 | RecvRates: []float64{2.0, 1.0, 1.0}, 12 | } 13 | inputMatrices := [][]float64{ 14 | { 15 | 0.0, 1.0, 0.0, 16 | 0.0, 0.0, 1.0, 17 | 1.0, 0.0, 0.0, 18 | }, 19 | { 20 | 1.0, 0.0, 0.0, 21 | 1.0, 0.0, 0.0, 22 | 1.0, 0.0, 0.0, 23 | }, 24 | { 25 | 1.0, 1.0, 1.0, 26 | 1.0, 1.0, 1.0, 27 | 1.0, 1.0, 1.0, 28 | }, 29 | } 30 | outputMatrices := [][]float64{ 31 | { 32 | 0.0, 1.0, 0.0, 33 | 0.0, 0.0, 1.0, 34 | 2.0, 0.0, 0.0, 35 | }, 36 | { 37 | 1.0 / 3.0, 0.0, 0.0, 38 | 2.0 / 3.0, 0.0, 0.0, 39 | 3.0 / 3.0, 0.0, 0.0, 40 | }, 41 | { 42 | 1.0 / 3.0, 1.0 / 6.0, 1.0 / 6.0, 43 | 2.0 / 3.0, 2.0 / 6.0, 2.0 / 6.0, 44 | 3.0 / 3.0, 3.0 / 6.0, 3.0 / 6.0, 45 | }, 46 | } 47 | for i, input := range inputMatrices { 48 | output := outputMatrices[i] 49 | connMat := &ConnMat{numNodes: 3, rates: input} 50 | switcher.SwitchedRates(connMat) 51 | for j, actual := range connMat.rates { 52 | if math.Abs(actual-output[j]) > 0.001 { 53 | t.Errorf("test %d: expected %v but got %v", i, output, connMat.rates) 54 | break 55 | } 56 | } 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /raft/state_machines.go: -------------------------------------------------------------------------------- 1 | package raft 2 | 3 | // HashMapCommand is a command for a *HashMap. 4 | // 5 | // If Value is an empty string, this is a get command. 6 | type HashMapCommand struct { 7 | Key string 8 | Value string 9 | } 10 | 11 | func (h HashMapCommand) Size() int { 12 | return 2 + len(h.Key) + len(h.Value) 13 | } 14 | 15 | // StringResult is a Result type from a HashMap containing 16 | // a value stored at a given key. 17 | type StringResult struct { 18 | Value string 19 | } 20 | 21 | func (s StringResult) Size() int { 22 | return len(s.Value) 23 | } 24 | 25 | // HashMap is a state machine which takes HashMapCommands 26 | // and either sets or gets map values. 
27 | type HashMap struct { 28 | mapping map[string]string 29 | } 30 | 31 | func (h *HashMap) ApplyState(command HashMapCommand) Result { 32 | if h.mapping == nil { 33 | h.mapping = map[string]string{} 34 | } 35 | if command.Value == "" { 36 | x, _ := h.mapping[command.Key] 37 | return StringResult{x} 38 | } 39 | h.mapping[command.Key] = command.Value 40 | return StringResult{command.Value} 41 | } 42 | 43 | func (h *HashMap) Size() int { 44 | if h.mapping == nil { 45 | h.mapping = map[string]string{} 46 | } 47 | // One stop character for the whole map and each key / value 48 | size := 1 49 | for k, v := range h.mapping { 50 | size += len(k) + len(v) + 2 51 | } 52 | return size 53 | } 54 | 55 | func (h *HashMap) Clone() *HashMap { 56 | if h.mapping == nil { 57 | h.mapping = map[string]string{} 58 | } 59 | m := map[string]string{} 60 | for k, v := range h.mapping { 61 | m[k] = v 62 | } 63 | return &HashMap{mapping: m} 64 | } 65 | -------------------------------------------------------------------------------- /simulator/conn_mat_test.go: -------------------------------------------------------------------------------- 1 | package simulator 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestConnMatSums(t *testing.T) { 8 | mat := NewConnMat(4) 9 | mat.Set(1, 2, 3.0) 10 | mat.Set(0, 2, 2.0) 11 | mat.Set(2, 3, 4.0) 12 | if res := mat.SumDest(0); res != 0.0 { 13 | t.Errorf("expected sum of 0.0 but got %f", res) 14 | } 15 | if res := mat.SumDest(1); res != 0.0 { 16 | t.Errorf("expected sum of 0.0 but got %f", res) 17 | } 18 | if res := mat.SumDest(2); res != 5.0 { 19 | t.Errorf("expected sum of 5.0 but got %f", res) 20 | } 21 | if res := mat.SumDest(3); res != 4.0 { 22 | t.Errorf("expected sum of 4.0 but got %f", res) 23 | } 24 | if res := mat.SumSource(0); res != 2.0 { 25 | t.Errorf("expected sum of 2.0 but got %f", res) 26 | } 27 | if res := mat.SumSource(1); res != 3.0 { 28 | t.Errorf("expected sum of 3.0 but got %f", res) 29 | } 30 | if res := mat.SumSource(2); res != 4.0 { 31 | t.Errorf("expected sum of 4.0 but got %f", res) 32 | } 33 | if res := mat.SumSource(3); res != 0.0 { 34 | t.Errorf("expected sum of 0.0 but got %f", res) 35 | } 36 | } 37 | 38 | func TestConnMatScales(t *testing.T) { 39 | mat := NewConnMat(4) 40 | mat.Set(1, 2, 3.0) 41 | mat.Set(1, 3, 5.0) 42 | mat.Set(0, 2, 2.0) 43 | mat.Set(2, 3, 4.0) 44 | 45 | mat.ScaleSource(1, 2.0) 46 | for i, expected := range []float64{0, 0, 3.0 * 2.0, 5.0 * 2.0} { 47 | if res := mat.Get(1, i); res != expected { 48 | t.Errorf("column %d: expected %f but got %f", i, expected, res) 49 | } 50 | } 51 | 52 | mat.ScaleDest(3, 3.0) 53 | for i, expected := range []float64{0, 5.0 * 2.0 * 3.0, 4.0 * 3.0, 0} { 54 | if res := mat.Get(i, 3); res != expected { 55 | t.Errorf("row %d: expected %f but got %f", i, expected, res) 56 | } 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /raft/raft.go: -------------------------------------------------------------------------------- 1 | package raft 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/unixpickle/dist-sys/simulator" 7 | ) 8 | 9 | type Raft[C Command, S StateMachine[C, S]] struct { 10 | Context context.Context 11 | Handle *simulator.Handle 12 | Network simulator.Network 13 | Port *simulator.Port 14 | Others []*simulator.Port 15 | 16 | // Algorithm state. 
17 | 	Log  *Log[C, S]
18 | 	Term int64
19 | 
20 | 	// Settings for timeouts
21 | 	ElectionTimeout   float64
22 | 	HeartbeatInterval float64
23 | }
24 | 
25 | func (r *Raft[C, S]) RunLoop() {
26 | 	var followerMsg *simulator.Message
27 | 	for {
28 | 		f := &Follower[C, S]{
29 | 			Context: r.Context,
30 | 			Handle:  r.Handle,
31 | 			Network: r.Network,
32 | 			Port:    r.Port,
33 | 			Others:  r.Others,
34 | 			Log:     r.Log,
35 | 			Term:    r.Term,
36 | 
37 | 			ElectionTimeout: r.ElectionTimeout,
38 | 		}
39 | 		f.RunLoop(followerMsg)
40 | 		select {
41 | 		case <-r.Context.Done():
42 | 			return
43 | 		default:
44 | 		}
45 | 		r.Term = f.Term + 1
46 | 
47 | 		c := &Candidate[C, S]{
48 | 			Context: r.Context,
49 | 			Handle:  r.Handle,
50 | 			Network: r.Network,
51 | 			Port:    r.Port,
52 | 			Others:  r.Others,
53 | 			Log:     r.Log,
54 | 			Term:    r.Term,
55 | 
56 | 			ElectionTimeout: r.ElectionTimeout,
57 | 		}
58 | 		followerMsg = c.RunLoop()
59 | 		r.Term = c.Term
60 | 		select {
61 | 		case <-r.Context.Done():
62 | 			return
63 | 		default:
64 | 		}
65 | 		if followerMsg == nil {
66 | 			followerMsg = (&Leader[C, S]{
67 | 				Context:   r.Context,
68 | 				Handle:    r.Handle,
69 | 				Network:   r.Network,
70 | 				Port:      r.Port,
71 | 				Followers: r.Others,
72 | 				Log:       r.Log,
73 | 				Term:      r.Term,
74 | 
75 | 				HeartbeatInterval: r.HeartbeatInterval,
76 | 			}).RunLoop()
77 | 			select {
78 | 			case <-r.Context.Done():
79 | 				return
80 | 			default:
81 | 			}
82 | 		}
83 | 	}
84 | }
85 | 
--------------------------------------------------------------------------------
/collcomm/allreduce/tree.go:
--------------------------------------------------------------------------------
1 | package allreduce
2 | 
3 | import (
4 | 	"github.com/unixpickle/dist-sys/collcomm"
5 | 	"github.com/unixpickle/dist-sys/simulator"
6 | )
7 | 
8 | // A TreeAllreducer arranges the Ports in a binary tree
9 | // and performs a reduction by going up the tree to a
10 | // root node, and then back down the tree to the leaves.
11 | type TreeAllreducer struct{}
12 | 
13 | // Allreduce calls fn on vectors along a tree and returns
14 | // the resulting reduced vector.
15 | func (t TreeAllreducer) Allreduce(c *collcomm.Comms, data []float64,
16 | 	fn collcomm.ReduceFn) []float64 {
17 | 	parent, children := positionInTree(c)
18 | 
19 | 	messages := [][]float64{data}
20 | 	for range children {
21 | 		msg, _ := c.Recv()
22 | 		messages = append(messages, msg)
23 | 	}
24 | 
25 | 	finalVector := fn(c.Handle, messages...)
26 | 	if parent != nil {
27 | 		c.Send(parent, finalVector)
28 | 		finalVector, _ = c.Recv()
29 | 	}
30 | 
31 | 	for _, child := range children {
32 | 		c.Send(child, finalVector)
33 | 	}
34 | 
35 | 	return finalVector
36 | }
37 | 
38 | // positionInTree returns the child Ports and parent node
39 | // for a host in the reduction tree.
40 | //
41 | // There may be no children.
42 | // There may be no parent (for the root node).
43 | func positionInTree(c *collcomm.Comms) (parent *simulator.Port, children []*simulator.Port) { 44 | idx := c.Index() 45 | for depth := uint(0); true; depth++ { 46 | rowSize := 1 << depth 47 | rowStart := rowSize - 1 48 | if idx >= rowStart+rowSize { 49 | continue 50 | } 51 | rowIdx := idx - rowStart 52 | if depth > 0 { 53 | parent = c.Ports[rowIdx/2+(rowSize/2-1)] 54 | } 55 | firstChild := rowIdx*2 + (rowSize*2 - 1) 56 | for i := 0; i < 2; i++ { 57 | if firstChild+i < len(c.Ports) { 58 | children = append(children, c.Ports[firstChild+i]) 59 | } 60 | } 61 | return 62 | } 63 | panic("unreachable") 64 | } 65 | -------------------------------------------------------------------------------- /simulator/conn_mat.go: -------------------------------------------------------------------------------- 1 | package simulator 2 | 3 | // A ConnMat is a connectivity matrix. 4 | // 5 | // Entries in the matrix indicate a transfer rate from a 6 | // source node (row) to a destination node (column). 7 | type ConnMat struct { 8 | numNodes int 9 | rates []float64 10 | } 11 | 12 | // NewConnMat creates an all-zero connection matrix. 13 | func NewConnMat(numNodes int) *ConnMat { 14 | return &ConnMat{ 15 | numNodes: numNodes, 16 | rates: make([]float64, numNodes*numNodes), 17 | } 18 | } 19 | 20 | // NumNodes returns the number of nodes. 21 | func (c *ConnMat) NumNodes() int { 22 | return c.numNodes 23 | } 24 | 25 | // Get an entry in the matrix. 26 | func (c *ConnMat) Get(src, dst int) float64 { 27 | if src < 0 || dst < 0 || src >= c.numNodes || dst >= c.numNodes { 28 | panic("index out of bounds") 29 | } 30 | return c.rates[src*c.numNodes+dst] 31 | } 32 | 33 | // Set an entry in the matrix. 34 | func (c *ConnMat) Set(src, dst int, value float64) { 35 | if src < 0 || dst < 0 || src >= c.numNodes || dst >= c.numNodes { 36 | panic("index out of bounds") 37 | } 38 | c.rates[src*c.numNodes+dst] = value 39 | } 40 | 41 | // SumDest sums a column of the matrix. 42 | func (c *ConnMat) SumDest(dst int) float64 { 43 | if dst < 0 || dst >= c.numNodes { 44 | panic("index out of bounds") 45 | } 46 | var sum float64 47 | for i := dst; i < len(c.rates); i += c.numNodes { 48 | sum += c.rates[i] 49 | } 50 | return sum 51 | } 52 | 53 | // SumSource sums a row of the matrix. 54 | func (c *ConnMat) SumSource(src int) float64 { 55 | if src < 0 || src >= c.numNodes { 56 | panic("index out of bounds") 57 | } 58 | var sum float64 59 | for i := src * c.numNodes; i < (src+1)*c.numNodes; i++ { 60 | sum += c.rates[i] 61 | } 62 | return sum 63 | } 64 | 65 | // ScaleDest scales a column of the matrix. 66 | func (c *ConnMat) ScaleDest(dst int, scale float64) { 67 | if dst < 0 || dst >= c.numNodes { 68 | panic("index out of bounds") 69 | } 70 | for i := dst; i < len(c.rates); i += c.numNodes { 71 | c.rates[i] *= scale 72 | } 73 | } 74 | 75 | // ScaleSource scales a row of the matrix. 76 | func (c *ConnMat) ScaleSource(src int, scale float64) { 77 | if src < 0 || src >= c.numNodes { 78 | panic("index out of bounds") 79 | } 80 | for i := src * c.numNodes; i < (src+1)*c.numNodes; i++ { 81 | c.rates[i] *= scale 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /simulator/switcher.go: -------------------------------------------------------------------------------- 1 | package simulator 2 | 3 | // A Switcher is a switching algorithm that determines how 4 | // rapidly data flows in a graph of nodes. 5 | // One job of the Switcher is to decide how to deal with 6 | // oversubscription. 
7 | type Switcher interface { 8 | // Apply the switching algorithm to compute the 9 | // transfer rates of every connection. 10 | // 11 | // The mat argument is passed in with 1's wherever a 12 | // node wants to send data to another node, and 0's 13 | // everywhere else. 14 | // 15 | // When the function returns, mat indicates the rate 16 | // of data between every pair of nodes. 17 | SwitchedRates(mat *ConnMat) 18 | } 19 | 20 | // A GreedyDropSwitcher emulates a switch where outgoing 21 | // data is spread evenly across a node's outputs, and 22 | // inputs to a node are dropped uniformly at random when a 23 | // node is oversubscribed. 24 | // 25 | // This is equivalent to first normalizing the rows of a 26 | // connection matrix, and then normalizing the columns. 27 | type GreedyDropSwitcher struct { 28 | SendRates []float64 29 | RecvRates []float64 30 | } 31 | 32 | // NewGreedyDropSwitcher creates a GreedyDropSwitcher with 33 | // uniform upload and download rates across all nodes. 34 | func NewGreedyDropSwitcher(numNodes int, rate float64) *GreedyDropSwitcher { 35 | rates := make([]float64, numNodes) 36 | for i := range rates { 37 | rates[i] = rate 38 | } 39 | return &GreedyDropSwitcher{ 40 | SendRates: rates, 41 | RecvRates: rates, 42 | } 43 | } 44 | 45 | // NumNodes gets the number of nodes the switch expects. 46 | func (g *GreedyDropSwitcher) NumNodes() int { 47 | return len(g.SendRates) 48 | } 49 | 50 | // SwitchedRates performs the switching algorithm. 51 | func (g *GreedyDropSwitcher) SwitchedRates(mat *ConnMat) { 52 | if mat.NumNodes() != g.NumNodes() { 53 | panic("unexpected number of nodes") 54 | } 55 | 56 | // Split upload traffic evenly across sockets. 57 | for src := 0; src < g.NumNodes(); src++ { 58 | numDests := mat.SumSource(src) 59 | if numDests > 0 { 60 | mat.ScaleSource(src, g.SendRates[src]/numDests) 61 | } 62 | } 63 | 64 | // Drop download traffic in proportion to the number 65 | // of incoming packets from each socket. 66 | for dst := 0; dst < g.NumNodes(); dst++ { 67 | incomingRate := mat.SumDest(dst) 68 | if incomingRate > g.RecvRates[dst] { 69 | mat.ScaleDest(dst, g.RecvRates[dst]/incomingRate) 70 | } 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /collcomm/allreduce/tester.go: -------------------------------------------------------------------------------- 1 | package allreduce 2 | 3 | import ( 4 | "fmt" 5 | "math" 6 | "math/rand" 7 | "testing" 8 | 9 | "github.com/unixpickle/dist-sys/collcomm" 10 | 11 | "github.com/unixpickle/dist-sys/simulator" 12 | ) 13 | 14 | // RunAllreducerTests runs a battery of tests on an 15 | // Allreducer. 
16 | func RunAllreducerTests(t *testing.T, reducer Allreducer) { 17 | for _, numNodes := range []int{1, 2, 5, 15, 16, 17} { 18 | for _, size := range []int{0, 1337} { 19 | for _, randomized := range []bool{false, true} { 20 | testName := fmt.Sprintf("Nodes=%d,Size=%d,Random=%v", numNodes, size, randomized) 21 | t.Run(testName, func(t *testing.T) { 22 | loop := simulator.NewEventLoop() 23 | vectors := make([][]float64, numNodes) 24 | nodes := make([]*simulator.Node, numNodes) 25 | sum := make([]float64, size) 26 | for i := range nodes { 27 | vectors[i] = make([]float64, size) 28 | for j := range vectors[i] { 29 | vectors[i][j] = rand.NormFloat64() 30 | sum[j] += vectors[i][j] 31 | } 32 | nodes[i] = simulator.NewNode() 33 | } 34 | 35 | var network simulator.Network 36 | if randomized { 37 | network = simulator.RandomNetwork{} 38 | } else { 39 | switcher := simulator.NewGreedyDropSwitcher(numNodes, 1.0) 40 | network = simulator.NewSwitcherNetwork(switcher, nodes, 0.1) 41 | } 42 | 43 | results := make([][]float64, numNodes) 44 | collcomm.SpawnComms(loop, network, nodes, func(c *collcomm.Comms) { 45 | results[c.Index()] = reducer.Allreduce(c, vectors[c.Index()], collcomm.Sum) 46 | }) 47 | 48 | if err := loop.Run(); err != nil { 49 | t.Fatal(err) 50 | } 51 | 52 | verifyReductionResults(t, results, sum) 53 | }) 54 | } 55 | } 56 | } 57 | } 58 | 59 | func verifyReductionResults(t *testing.T, results [][]float64, expected []float64) { 60 | for i, res := range results[1:] { 61 | if len(res) != len(expected) { 62 | t.Errorf("result %d has length %d but expected %d", i, len(res), len(expected)) 63 | continue 64 | } 65 | for j, actual := range res { 66 | if actual != results[0][j] { 67 | t.Errorf("result %d is not identical to result 0", i) 68 | break 69 | } 70 | } 71 | } 72 | 73 | for i, x := range expected { 74 | if math.Abs(x-results[0][i]) > 1e-5 { 75 | t.Errorf("sum is incorrect (expected %f but got %f at component %d)", 76 | x, results[0][i], i) 77 | break 78 | } 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /disthash/consistent.go: -------------------------------------------------------------------------------- 1 | package disthash 2 | 3 | import ( 4 | "bytes" 5 | "encoding/binary" 6 | 7 | "github.com/unixpickle/essentials" 8 | 9 | "github.com/unixpickle/dist-sys/simulator" 10 | ) 11 | 12 | // Consistent implements consistent hashing. 13 | type Consistent struct { 14 | nodes []*circleNode 15 | numPoints int 16 | } 17 | 18 | // NewConsistent creates a consistent hash table that 19 | // generates the given number of points per node. 20 | func NewConsistent(numPoints int) *Consistent { 21 | return &Consistent{numPoints: numPoints} 22 | } 23 | 24 | // AddSite adds a bucket to the hash table. 25 | func (c *Consistent) AddSite(node *simulator.Node, id Hasher) { 26 | data := id.Hash() 27 | for i := 0; i < c.numPoints; i++ { 28 | var buf bytes.Buffer 29 | binary.Write(&buf, binary.LittleEndian, uint32(i)) 30 | buf.Write(data) 31 | pos := FloatHash(buf.Bytes()) 32 | c.nodes = append(c.nodes, &circleNode{Node: node, ID: id, Position: pos}) 33 | } 34 | essentials.VoodooSort(c.nodes, func(i, j int) bool { 35 | return c.nodes[i].Position < c.nodes[j].Position 36 | }) 37 | } 38 | 39 | // RemoveSite removes a bucket from the hash table. 
40 | func (c *Consistent) RemoveSite(node *simulator.Node) { 41 | var newNodes []*circleNode 42 | for _, point := range c.nodes { 43 | if point.Node != node { 44 | newNodes = append(newNodes, point) 45 | } 46 | } 47 | c.nodes = newNodes 48 | } 49 | 50 | // Sites returns all buckets in the hash table. 51 | func (c *Consistent) Sites() []*simulator.Node { 52 | nodes := map[*simulator.Node]bool{} 53 | var res []*simulator.Node 54 | for _, point := range c.nodes { 55 | if !nodes[point.Node] { 56 | res = append(res, point.Node) 57 | } 58 | nodes[point.Node] = true 59 | } 60 | return res 61 | } 62 | 63 | // KeySites returns the buckets where a key is stored. 64 | // 65 | // This must return at least one node, provided there 66 | // are any nodes in the hash table. 67 | func (c *Consistent) KeySites(key Hasher) []*simulator.Node { 68 | if len(c.nodes) == 0 { 69 | return nil 70 | } 71 | pos := FloatHash(key.Hash()) 72 | for _, point := range c.nodes { 73 | if point.Position > pos { 74 | return []*simulator.Node{point.Node} 75 | } 76 | } 77 | return []*simulator.Node{c.nodes[0].Node} 78 | } 79 | 80 | // A circleNode is one node around a circle with a 81 | // circumference of one. 82 | type circleNode struct { 83 | Node *simulator.Node 84 | ID Hasher 85 | 86 | // In the range [0, 1.0). 87 | Position float64 88 | } 89 | -------------------------------------------------------------------------------- /raft/log.go: -------------------------------------------------------------------------------- 1 | package raft 2 | 3 | // A LogEntry is an immutable step in a log that mutates or 4 | // returns state about a StateMachine. 5 | type Command interface { 6 | Size() int 7 | } 8 | 9 | // Result is some result from a command. 10 | type Result interface { 11 | Size() int 12 | } 13 | 14 | type LogEntry[C Command] struct { 15 | Term int64 16 | Command C 17 | } 18 | 19 | func (l LogEntry[C]) Size() int { 20 | return 8 + l.Command.Size() 21 | } 22 | 23 | // A StateMachine with log type L stores some state which 24 | // can be mutated by log entries, and each state transition 25 | // may return some value. 26 | type StateMachine[C Command, Self any] interface { 27 | ApplyState(C) Result 28 | Size() int 29 | Clone() Self 30 | } 31 | 32 | // A Log stores a state machine and log entries which are 33 | // applied to it. 34 | type Log[C Command, S StateMachine[C, S]] struct { 35 | // Origin is the state machine up to the last committed 36 | // state. 37 | Origin S 38 | 39 | // OriginIndex is the index of the next log index. 40 | OriginIndex int64 41 | 42 | // OriginTerm is the term of the latest log message 43 | // that was applied to the state machine. 44 | OriginTerm int64 45 | 46 | // Log entries starting from origin leading up to 47 | // latest state. 48 | Entries []LogEntry[C] 49 | } 50 | 51 | // LatestTermAndIndex gets the latest position in the log, 52 | // which can be used for leader election decisions. 53 | func (l *Log[C, S]) LatestTermAndIndex() (int64, int64) { 54 | if len(l.Entries) == 0 { 55 | return l.OriginTerm, l.OriginIndex 56 | } else { 57 | e := l.Entries[len(l.Entries)-1] 58 | return e.Term, l.OriginIndex + int64(len(l.Entries)) 59 | } 60 | } 61 | 62 | // Commit caches log entries before a log index and 63 | // advances the commit index. 
64 | func (l *Log[C, S]) Commit(commitIndex int64) []Result {
65 | 	if commitIndex <= l.OriginIndex {
66 | 		return nil
67 | 	}
68 | 	var results []Result
69 | 	for i := l.OriginIndex; i < commitIndex; i++ {
70 | 		results = append(results, l.Origin.ApplyState(l.Entries[i-l.OriginIndex].Command))
71 | 	}
72 | 	l.OriginTerm = l.Entries[commitIndex-l.OriginIndex-1].Term
73 | 	l.Entries = append([]LogEntry[C]{}, l.Entries[commitIndex-l.OriginIndex:]...)
74 | 	l.OriginIndex = commitIndex
75 | 	return results
76 | }
77 | 
78 | // Append adds a command to the log and returns the index
79 | // of the resulting log entry.
80 | func (l *Log[C, S]) Append(term int64, command C) int64 {
81 | 	l.Entries = append(l.Entries, LogEntry[C]{
82 | 		Term:    term,
83 | 		Command: command,
84 | 	})
85 | 	return l.OriginIndex + int64(len(l.Entries)-1)
86 | }
87 | 
--------------------------------------------------------------------------------
/raft/candidate.go:
--------------------------------------------------------------------------------
1 | package raft
2 | 
3 | import (
4 | 	"context"
5 | 
6 | 	"github.com/unixpickle/dist-sys/simulator"
7 | )
8 | 
9 | type Candidate[C Command, S StateMachine[C, S]] struct {
10 | 	Context context.Context
11 | 	Handle  *simulator.Handle
12 | 	Network simulator.Network
13 | 	Port    *simulator.Port
14 | 	Others  []*simulator.Port
15 | 
16 | 	// Algorithm state.
17 | 	Log  *Log[C, S]
18 | 	Term int64
19 | 
20 | 	// ElectionTimeout should be randomized per follower to
21 | 	// break ties in the common case.
22 | 	ElectionTimeout float64
23 | 
24 | 	// Internal state
25 | 	timerStream *simulator.EventStream
26 | 	timer       *simulator.Timer
27 | }
28 | 
29 | // RunLoop waits until an election is complete.
30 | //
31 | // Returns nil if this node won the election and is now
32 | // the leader. Returns a message from another node if
33 | // this node should revert to being a follower; that
34 | // message should be handled by the follower loop.
35 | func (c *Candidate[C, S]) RunLoop() *simulator.Message { 36 | c.timerStream = c.Handle.Stream() 37 | c.timer = c.Handle.Schedule(c.timerStream, nil, c.ElectionTimeout) 38 | 39 | defer func() { 40 | c.Handle.Cancel(c.timer) 41 | }() 42 | 43 | c.broadcastCandidacy() 44 | 45 | numVotes := 0 46 | for { 47 | result := c.Handle.Poll(c.timerStream, c.Port.Incoming) 48 | select { 49 | case <-c.Context.Done(): 50 | return nil 51 | default: 52 | } 53 | if result.Stream == c.timerStream { 54 | numVotes = 0 55 | c.Term++ 56 | c.timer = c.Handle.Schedule(c.timerStream, nil, c.ElectionTimeout) 57 | c.broadcastCandidacy() 58 | continue 59 | } 60 | 61 | rawMsg := result.Message.(*simulator.Message) 62 | 63 | if msg, ok := rawMsg.Message.(*RaftMessage[C, S]); ok { 64 | if term := msg.Term(); term < c.Term { 65 | continue 66 | } else if term > c.Term { 67 | return rawMsg 68 | } else if msg.VoteResponse != nil { 69 | if msg.VoteResponse.ReceivedVote { 70 | numVotes += 1 71 | } 72 | if numVotes >= len(c.Others)/2 { 73 | return nil 74 | } 75 | } else if msg.AppendLogs != nil { 76 | return rawMsg 77 | } 78 | } 79 | } 80 | } 81 | 82 | func (c *Candidate[C, S]) broadcastCandidacy() { 83 | var messages []*simulator.Message 84 | for _, port := range c.Others { 85 | latestTerm, latestIndex := c.Log.LatestTermAndIndex() 86 | raftMsg := &RaftMessage[C, S]{ 87 | Vote: &Vote{ 88 | Term: c.Term, 89 | LatestTerm: latestTerm, 90 | LatestIndex: latestIndex, 91 | }, 92 | } 93 | messages = append(messages, &simulator.Message{ 94 | Source: c.Port, 95 | Dest: port, 96 | Message: raftMsg, 97 | Size: float64(raftMsg.Size()), 98 | }) 99 | } 100 | c.Network.Send(c.Handle, messages...) 101 | } 102 | -------------------------------------------------------------------------------- /collcomm/allreduce/bench_allreduce/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "strconv" 6 | 7 | "github.com/unixpickle/dist-sys/collcomm" 8 | "github.com/unixpickle/dist-sys/collcomm/allreduce" 9 | "github.com/unixpickle/dist-sys/simulator" 10 | ) 11 | 12 | // RunInfo describes a specific network configuration. 13 | type RunInfo struct { 14 | NumNodes int 15 | Latency float64 16 | Rate float64 17 | } 18 | 19 | // Run creates a network and drops each host into its own 20 | // Goroutine. 21 | func (r *RunInfo) Run(loop *simulator.EventLoop, commFn func(c *collcomm.Comms)) { 22 | nodes := make([]*simulator.Node, r.NumNodes) 23 | for i := range nodes { 24 | nodes[i] = simulator.NewNode() 25 | } 26 | switcher := simulator.NewGreedyDropSwitcher(r.NumNodes, r.Rate) 27 | network := simulator.NewSwitcherNetwork(switcher, nodes, r.Latency) 28 | collcomm.SpawnComms(loop, network, nodes, commFn) 29 | loop.MustRun() 30 | } 31 | 32 | func main() { 33 | reducers := []allreduce.Allreducer{ 34 | allreduce.NaiveAllreducer{}, 35 | allreduce.TreeAllreducer{}, 36 | allreduce.StreamAllreducer{}, 37 | } 38 | reducerNames := []string{"Naive", "Tree", "Stream"} 39 | runs := []RunInfo{ 40 | { 41 | NumNodes: 2, 42 | Latency: 0.1, 43 | Rate: 1e6, 44 | }, 45 | { 46 | NumNodes: 16, 47 | Latency: 1e-3, 48 | Rate: 1e6, 49 | }, 50 | { 51 | NumNodes: 32, 52 | Latency: 0.1, 53 | Rate: 1e6, 54 | }, 55 | { 56 | NumNodes: 32, 57 | Latency: 0.1, 58 | Rate: 1e9, 59 | }, 60 | { 61 | NumNodes: 32, 62 | Latency: 1e-4, 63 | Rate: 1e9, 64 | }, 65 | } 66 | vecSizes := []int{10, 10000, 10000000} 67 | 68 | // Markdown table header. 
69 | fmt.Print("| Nodes | Latency | NIC rate | Size ") 70 | for _, reducerName := range reducerNames { 71 | fmt.Printf("| %s ", reducerName) 72 | } 73 | fmt.Println("|") 74 | for i := 0; i < 4+len(reducers); i++ { 75 | fmt.Print("|:--") 76 | } 77 | fmt.Println("|") 78 | 79 | // Markdown table body. 80 | for _, runInfo := range runs { 81 | for _, size := range vecSizes { 82 | fmt.Printf( 83 | "| %d | %s | %s | %d ", 84 | runInfo.NumNodes, 85 | strconv.FormatFloat(runInfo.Latency, 'f', -1, 64), 86 | strconv.FormatFloat(runInfo.Rate, 'E', -1, 64), 87 | size, 88 | ) 89 | for _, reducer := range reducers { 90 | loop := simulator.NewEventLoop() 91 | runInfo.Run(loop, func(c *collcomm.Comms) { 92 | vec := make([]float64, size) 93 | reducer.Allreduce(c, vec, FakeReduce) 94 | }) 95 | fmt.Printf("| %f ", loop.Time()) 96 | } 97 | fmt.Println("|") 98 | } 99 | } 100 | } 101 | 102 | // FakeReduce is a ReduceFn that takes no actual CPU time. 103 | func FakeReduce(h *simulator.Handle, vecs ...[]float64) []float64 { 104 | h.Sleep(collcomm.FlopTime * float64(len(vecs)*len(vecs[0]))) 105 | return make([]float64, len(vecs[0])) 106 | } 107 | -------------------------------------------------------------------------------- /collcomm/comms.go: -------------------------------------------------------------------------------- 1 | package collcomm 2 | 3 | import "github.com/unixpickle/dist-sys/simulator" 4 | 5 | // Comms manages a set of connections between a bunch of 6 | // nodes. 7 | // During a collective operation, each node has a local 8 | // Comms object that represents its view of the world. 9 | // A new Comms object should be used for each operation, 10 | // thus automatically handling multiplexing. 11 | type Comms struct { 12 | // Handle is the node's main Goroutine's handle on the 13 | // event loop. 14 | Handle *simulator.Handle 15 | 16 | // Port is the current node's port. 17 | Port *simulator.Port 18 | 19 | // Ports contains ports to all the nodes in the 20 | // network, including the current node. 21 | Ports []*simulator.Port 22 | 23 | // Network is the network connecting the nodes. 24 | Network simulator.Network 25 | } 26 | 27 | // SpawnComms creates Comms objects for every node in a 28 | // network and calls f for each node in its own Goroutine. 29 | func SpawnComms(loop *simulator.EventLoop, network simulator.Network, nodes []*simulator.Node, 30 | f func(c *Comms)) { 31 | ports := make([]*simulator.Port, len(nodes)) 32 | for i, node := range nodes { 33 | ports[i] = node.Port(loop) 34 | } 35 | for i := range nodes { 36 | port := ports[i] 37 | loop.Go(func(h *simulator.Handle) { 38 | f(&Comms{ 39 | Handle: h, 40 | Port: port, 41 | Ports: ports, 42 | Network: network, 43 | }) 44 | }) 45 | } 46 | } 47 | 48 | // Size gets the number of nodes. 49 | func (c *Comms) Size() int { 50 | return len(c.Ports) 51 | } 52 | 53 | // Bcast sends a vector to every other node. 54 | func (c *Comms) Bcast(vec []float64) { 55 | messages := make([]*simulator.Message, 0, len(c.Ports)-1) 56 | for _, port := range c.Ports { 57 | if port == c.Port { 58 | continue 59 | } 60 | messages = append(messages, &simulator.Message{ 61 | Source: c.Port, 62 | Dest: port, 63 | Message: vec, 64 | Size: float64(len(vec) * 8), 65 | }) 66 | } 67 | c.Network.Send(c.Handle, messages...) 68 | } 69 | 70 | // Send schedules a message to be sent to the destination. 
71 | func (c *Comms) Send(dst *simulator.Port, vec []float64) { 72 | c.Network.Send(c.Handle, &simulator.Message{ 73 | Source: c.Port, 74 | Dest: dst, 75 | Message: vec, 76 | Size: float64(len(vec) * 8), 77 | }) 78 | } 79 | 80 | // Recv receives the next vector. 81 | func (c *Comms) Recv() ([]float64, *simulator.Port) { 82 | res := c.Port.Recv(c.Handle) 83 | return res.Message.([]float64), res.Source 84 | } 85 | 86 | // Index returns the current node's index in the list of 87 | // nodes. 88 | func (c *Comms) Index() int { 89 | return c.IndexOf(c.Port) 90 | } 91 | 92 | // IndexOf returns any node's index. 93 | func (c *Comms) IndexOf(p *simulator.Port) int { 94 | for i, port := range c.Ports { 95 | if port == p { 96 | return i 97 | } 98 | } 99 | panic("unreachable") 100 | } 101 | -------------------------------------------------------------------------------- /raft/client.go: -------------------------------------------------------------------------------- 1 | package raft 2 | 3 | import ( 4 | "errors" 5 | "math/rand" 6 | 7 | "github.com/google/uuid" 8 | "github.com/unixpickle/dist-sys/simulator" 9 | ) 10 | 11 | var ( 12 | ErrClientTimeout = errors.New("raft server did not respond") 13 | ErrLeaderUnknown = errors.New("raft leader is not currently known") 14 | ) 15 | 16 | type Client[C Command] struct { 17 | Handle *simulator.Handle 18 | Network simulator.Network 19 | Port *simulator.Port 20 | Servers []*simulator.Port 21 | 22 | SendTimeout float64 23 | 24 | timerStream *simulator.EventStream 25 | timer *simulator.Timer 26 | leader *simulator.Port 27 | } 28 | 29 | // Send attempts to send the message to the servers and get 30 | // the result of the command. 31 | // 32 | // It retries until the command is executed, or the given 33 | // number of tries is exceeded. 34 | // 35 | // If numAttempts is 0, then retries are executed forever. 36 | func (c *Client[C]) Send(command C, numAttempts int) (Result, error) { 37 | c.timerStream = c.Handle.Stream() 38 | c.timer = c.Handle.Schedule(c.timerStream, nil, c.SendTimeout) 39 | defer func() { 40 | defer c.Handle.Cancel(c.timer) 41 | }() 42 | 43 | id := uuid.NewString() 44 | 45 | if c.leader == nil { 46 | c.leader = c.Servers[rand.Intn(len(c.Servers))] 47 | } 48 | 49 | msg := &CommandMessage[C]{ 50 | Command: command, 51 | ID: id, 52 | } 53 | 54 | madeAttempts := 0 55 | for { 56 | c.sendToLeader(msg) 57 | res, redir, err := c.waitForResult(id) 58 | if err != nil { 59 | madeAttempts++ 60 | if madeAttempts == numAttempts { 61 | return nil, err 62 | } 63 | if err == ErrLeaderUnknown { 64 | // Wait some time for the cluster to come back up. 
65 | c.Handle.Sleep(c.SendTimeout) 66 | } 67 | c.timer = c.Handle.Schedule(c.timerStream, nil, c.SendTimeout) 68 | c.leader = c.Servers[rand.Intn(len(c.Servers))] 69 | } else if redir != nil { 70 | c.leader = nil 71 | for _, p := range c.Servers { 72 | if p.Node == redir { 73 | c.leader = p 74 | } 75 | } 76 | if c.leader == nil { 77 | panic("unknown redirect server") 78 | } 79 | } else { 80 | return res, nil 81 | } 82 | } 83 | } 84 | 85 | func (c *Client[C]) sendToLeader(msg *CommandMessage[C]) { 86 | c.Network.Send(c.Handle, &simulator.Message{ 87 | Source: c.Port, 88 | Dest: c.leader, 89 | Message: msg, 90 | Size: float64(msg.Command.Size() + len(msg.ID)), 91 | }) 92 | } 93 | 94 | func (c *Client[C]) waitForResult(id string) (res Result, redir *simulator.Node, err error) { 95 | for { 96 | resp := c.Handle.Poll(c.timerStream, c.Port.Incoming) 97 | if resp.Stream == c.timerStream { 98 | return nil, nil, ErrClientTimeout 99 | } 100 | rawMsg := resp.Message.(*simulator.Message) 101 | if obj, ok := rawMsg.Message.(*CommandResponse); ok { 102 | if obj.ID != id { 103 | continue 104 | } 105 | if obj.Redirect != nil { 106 | return nil, obj.Redirect, nil 107 | } else if obj.LeaderUnknown { 108 | return nil, nil, ErrLeaderUnknown 109 | } else { 110 | return obj.Result, nil, nil 111 | } 112 | } 113 | } 114 | } 115 | -------------------------------------------------------------------------------- /raft/messages.go: -------------------------------------------------------------------------------- 1 | package raft 2 | 3 | import "github.com/unixpickle/dist-sys/simulator" 4 | 5 | type RaftMessage[C Command, S StateMachine[C, S]] struct { 6 | AppendLogs *AppendLogs[C, S] 7 | AppendLogsResponse *AppendLogsResponse[C, S] 8 | Vote *Vote 9 | VoteResponse *VoteResponse 10 | } 11 | 12 | func (r *RaftMessage[C, S]) Size() int { 13 | headerSize := 1 14 | if r.AppendLogs != nil { 15 | return headerSize + r.AppendLogs.Size() 16 | } else if r.AppendLogsResponse != nil { 17 | return headerSize + r.AppendLogsResponse.Size() 18 | } else if r.Vote != nil { 19 | return headerSize + r.Vote.Size() 20 | } else if r.VoteResponse != nil { 21 | return headerSize + r.VoteResponse.Size() 22 | } 23 | panic("unknown message type") 24 | } 25 | 26 | func (r *RaftMessage[C, S]) Term() int64 { 27 | if r.AppendLogs != nil { 28 | return r.AppendLogs.Term 29 | } else if r.AppendLogsResponse != nil { 30 | return r.AppendLogsResponse.Term 31 | } else if r.Vote != nil { 32 | return r.Vote.Term 33 | } else if r.VoteResponse != nil { 34 | return r.VoteResponse.Term 35 | } 36 | panic("unknown message type") 37 | } 38 | 39 | // AppendLogs is a message sent from leaders to followers. 40 | type AppendLogs[C Command, S StateMachine[C, S]] struct { 41 | Term int64 42 | CommitIndex int64 43 | 44 | // SeqNum is unique per term and helps leaders 45 | // determine if this is a stale response. 46 | SeqNum int64 47 | 48 | // OriginTerm is the term of the message corresponding 49 | // to the origin. 50 | OriginTerm int64 51 | 52 | // OriginIndex may be different than CommitIndex if we 53 | // are sending newer entries to this worker than are in 54 | // the committed state machine. 55 | // 56 | // In this case, it will be greater than CommitIndex, 57 | // and Origin will be nil. 58 | OriginIndex int64 59 | 60 | // Origin may be specified if we are too far behind. 61 | Origin *S 62 | 63 | Entries []LogEntry[C] 64 | } 65 | 66 | // Size tabulates the approximate number of bytes requires 67 | // to encode this message. 
68 | func (a *AppendLogs[C, S]) Size() int { 69 | res := 8 * 3 70 | if a.Origin != nil { 71 | res += (*a.Origin).Size() 72 | } 73 | for _, e := range a.Entries { 74 | res += e.Size() 75 | } 76 | return res 77 | } 78 | 79 | // AppendLogsResponse is sent by the followers to the 80 | // leader in response to an AppendLogs message. 81 | type AppendLogsResponse[C Command, S StateMachine[C, S]] struct { 82 | Term int64 83 | CommitIndex int64 84 | LatestIndex int64 85 | SeqNum int64 86 | 87 | // Success will be false if there was not enough 88 | // data to fill in the logs. 89 | Success bool 90 | } 91 | 92 | func (a *AppendLogsResponse[C, S]) Size() int { 93 | return 8 * 3 94 | } 95 | 96 | type Vote struct { 97 | Term int64 98 | 99 | // Last log message 100 | LatestTerm int64 101 | LatestIndex int64 102 | } 103 | 104 | func (v *Vote) Size() int { 105 | return 8 * 3 106 | } 107 | 108 | type VoteResponse struct { 109 | Term int64 110 | ReceivedVote bool 111 | } 112 | 113 | func (v *VoteResponse) Size() int { 114 | return 8 + 1 115 | } 116 | 117 | type CommandMessage[C Command] struct { 118 | Command C 119 | ID string 120 | } 121 | 122 | type CommandResponse struct { 123 | ID string 124 | Result Result // non-nil if Redirect is nil and Unknown is false 125 | LeaderUnknown bool // if true, the leader is not known 126 | Redirect *simulator.Node // non-nil if this is no longer the leader 127 | } 128 | 129 | func (c *CommandResponse) Size() int { 130 | if c.Result != nil { 131 | // Result size + header 132 | return 1 + c.Result.Size() 133 | } else { 134 | // IPv4 + port number + header 135 | return 1 + 6 136 | } 137 | } 138 | -------------------------------------------------------------------------------- /disthash/tester.go: -------------------------------------------------------------------------------- 1 | package disthash 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | 7 | "github.com/unixpickle/dist-sys/simulator" 8 | "github.com/unixpickle/essentials" 9 | ) 10 | 11 | // TestDistHash runs a battery of tests on a DistHash. 12 | func TestDistHash(t *testing.T, maker func() DistHash) { 13 | t.Run("StaticConsistency", func(t *testing.T) { 14 | TestStaticConsistency(t, maker()) 15 | }) 16 | t.Run("RemovalConsistency", func(t *testing.T) { 17 | TestRemovalConsistency(t, maker()) 18 | }) 19 | t.Run("AddConsistency", func(t *testing.T) { 20 | TestAddConsistency(t, maker()) 21 | }) 22 | } 23 | 24 | // TestStaticConsistency checks that a DistHash produces 25 | // the same bins when no sites are added/removed. 26 | func TestStaticConsistency(t *testing.T, d DistHash) { 27 | site1 := simulator.NewNode() 28 | site2 := simulator.NewNode() 29 | d.AddSite(site1, &GobHasher{V: "hi"}) 30 | d.AddSite(site2, &GobHasher{V: "hey"}) 31 | 32 | sites := make([][][]*simulator.Node, 2) 33 | for i := 0; i < 2; i++ { 34 | sites[i] = [][]*simulator.Node{} 35 | for j := 0; j < 100; j++ { 36 | keySites := copySites(d.KeySites(&GobHasher{V: j})) 37 | if len(keySites) == 0 { 38 | t.Error("no sites for key") 39 | } 40 | sites[i] = append(sites[i], keySites) 41 | } 42 | } 43 | if !reflect.DeepEqual(sites[0], sites[1]) { 44 | t.Error("sites were inconsistent") 45 | } 46 | } 47 | 48 | // TestRemovalConsistency checks that a DistHash does not 49 | // mess with objects stored in sites that are not removed, 50 | // even when other sites are removed. 
51 | func TestRemovalConsistency(t *testing.T, d DistHash) { 52 | var sites []*simulator.Node 53 | for i := 0; i < 3; i++ { 54 | site := simulator.NewNode() 55 | sites = append(sites, site) 56 | d.AddSite(site, &GobHasher{V: i}) 57 | } 58 | 59 | oldSites := [][]*simulator.Node{} 60 | for i := 0; i < 100; i++ { 61 | keySites := copySites(d.KeySites(&GobHasher{V: i})) 62 | if len(keySites) == 0 { 63 | t.Error("no sites for key") 64 | } 65 | oldSites = append(oldSites, keySites) 66 | } 67 | 68 | d.RemoveSite(sites[0]) 69 | 70 | for i := 0; i < 100; i++ { 71 | newSites := copySites(d.KeySites(&GobHasher{V: i})) 72 | if !essentials.Contains(oldSites[i], sites[0]) { 73 | if !reflect.DeepEqual(oldSites[i], newSites) { 74 | t.Errorf("site %d should be unaffected, but went from %v to %v", 75 | i, oldSites[i], newSites) 76 | } 77 | } else { 78 | if len(newSites) == 0 { 79 | t.Error("no sites for key") 80 | } else if essentials.Contains(newSites, sites[0]) { 81 | t.Error("removed site still in use") 82 | } 83 | } 84 | } 85 | } 86 | 87 | // TestAddConsistency checks that a DistHash does not move 88 | // keys around excessively when adding new sites. 89 | func TestAddConsistency(t *testing.T, d DistHash) { 90 | var sites []*simulator.Node 91 | for i := 0; i < 3; i++ { 92 | site := simulator.NewNode() 93 | sites = append(sites, site) 94 | d.AddSite(site, &GobHasher{V: i}) 95 | } 96 | 97 | oldSites := [][]*simulator.Node{} 98 | for i := 0; i < 100; i++ { 99 | keySites := copySites(d.KeySites(&GobHasher{V: i})) 100 | if len(keySites) == 0 { 101 | t.Error("no sites for key") 102 | } 103 | oldSites = append(oldSites, keySites) 104 | } 105 | 106 | newSite := simulator.NewNode() 107 | d.AddSite(newSite, &GobHasher{V: -1}) 108 | 109 | for i := 0; i < 100; i++ { 110 | newSites := copySites(d.KeySites(&GobHasher{V: i})) 111 | if !essentials.Contains(newSites, newSite) { 112 | if !reflect.DeepEqual(oldSites[i], newSites) { 113 | t.Errorf("site %d should be unaffected, but went from %v to %v", 114 | i, oldSites[i], newSites) 115 | } 116 | } else if len(newSites) == 0 { 117 | t.Error("no sites for key") 118 | } 119 | } 120 | } 121 | 122 | // copySites copies a slice in such a way that 123 | // reflect.DeepEqual will not be confused by unused 124 | // capacity. 125 | func copySites(sites []*simulator.Node) []*simulator.Node { 126 | return append([]*simulator.Node{}, sites...) 127 | } 128 | -------------------------------------------------------------------------------- /paxos/basic_paxos_test.go: -------------------------------------------------------------------------------- 1 | package paxos 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/unixpickle/dist-sys/simulator" 7 | ) 8 | 9 | func TestBasicOneProposer(t *testing.T) { 10 | loop := simulator.NewEventLoop() 11 | nodes := []*simulator.Node{} 12 | ports := []*simulator.Port{} 13 | for i := 0; i < 3; i++ { 14 | nodes = append(nodes, simulator.NewNode()) 15 | ports = append(ports, nodes[len(nodes)-1].Port(loop)) 16 | } 17 | switcher := simulator.NewGreedyDropSwitcher(len(nodes), 1e6) 18 | network := simulator.NewSwitcherNetwork(switcher, nodes, 0) 19 | 20 | doneStream := loop.Stream() 21 | paxos := &BasicPaxos{Timeout: 10.0} 22 | 23 | // Run an acceptor on every node. 24 | for i := range nodes { 25 | idx := i 26 | loop.Go(func(h *simulator.Handle) { 27 | paxos.Accept(h, network, ports[idx], doneStream) 28 | }) 29 | } 30 | 31 | // Run a proposer on one of the nodes. 
32 | loop.Go(func(h *simulator.Handle) { 33 | value := paxos.Propose(h, network, nodes[0].Port(loop), ports, "goodbye world", 13) 34 | if value != "goodbye world" { 35 | t.Errorf("unexpected value: %s", value) 36 | } 37 | for range ports { 38 | h.Schedule(doneStream, nil, 0) 39 | } 40 | }) 41 | 42 | loop.MustRun() 43 | } 44 | 45 | func TestBasicSlowProposer(t *testing.T) { 46 | // Make sure we try out many random seeds. 47 | for i := 0; i < 100; i++ { 48 | for _, randomized := range []bool{false, true} { 49 | for _, backoff := range []float64{0, 1} { 50 | loop := simulator.NewEventLoop() 51 | nodes := []*simulator.Node{} 52 | ports := []*simulator.Port{} 53 | for i := 0; i < 3; i++ { 54 | nodes = append(nodes, simulator.NewNode()) 55 | ports = append(ports, nodes[len(nodes)-1].Port(loop)) 56 | } 57 | var network simulator.Network 58 | if randomized { 59 | network = simulator.RandomNetwork{} 60 | } else { 61 | switcher := simulator.NewGreedyDropSwitcher(len(nodes), 1e6) 62 | network = simulator.NewSwitcherNetwork(switcher, nodes, 0.1) 63 | } 64 | 65 | doneStream := loop.Stream() 66 | paxos := &BasicPaxos{Backoff: backoff, Timeout: 1.5} 67 | 68 | // Run an acceptor on every node. 69 | for i := range nodes { 70 | idx := i 71 | loop.Go(func(h *simulator.Handle) { 72 | paxos.Accept(h, network, ports[idx], doneStream) 73 | }) 74 | } 75 | 76 | // A fast proposer that runs first. 77 | loop.Go(func(h *simulator.Handle) { 78 | value := paxos.Propose(h, network, nodes[0].Port(loop), ports, "goodbye world", 13) 79 | if value != "goodbye world" { 80 | t.Errorf("unexpected value: %s", value) 81 | } 82 | }) 83 | 84 | // A slow proposer should accept the faster one. 85 | loop.Go(func(h *simulator.Handle) { 86 | h.Sleep(1e5) 87 | value := paxos.Propose(h, network, nodes[1].Port(loop), ports, "hello world", 11) 88 | if value != "goodbye world" { 89 | t.Errorf("unexpected value: %s", value) 90 | } 91 | for range ports { 92 | h.Schedule(doneStream, nil, 0) 93 | } 94 | }) 95 | 96 | loop.MustRun() 97 | } 98 | } 99 | } 100 | } 101 | 102 | // TestExponentialBackoff tests the Basic Paxos backoff 103 | // mechanism in a case where two proposers would normally 104 | // compete forever. 105 | func TestBasicBackoff(t *testing.T) { 106 | const NumNodes = 5 107 | const NumProposers = 3 108 | 109 | loop := simulator.NewEventLoop() 110 | nodes := []*simulator.Node{} 111 | ports := []*simulator.Port{} 112 | for i := 0; i < NumNodes; i++ { 113 | nodes = append(nodes, simulator.NewNode()) 114 | ports = append(ports, nodes[len(nodes)-1].Port(loop)) 115 | } 116 | 117 | // The network is slow enough that large accept 118 | // messages will take much longer than prepare 119 | // messages. 120 | switcher := simulator.NewGreedyDropSwitcher(len(nodes), 1.0) 121 | network := simulator.NewSwitcherNetwork(switcher, nodes, 0) 122 | 123 | doneStream := loop.Stream() 124 | paxos := &BasicPaxos{Backoff: 1.0, Timeout: 1e10} 125 | 126 | // Run an acceptor on every node. 127 | for i := range nodes { 128 | idx := i 129 | loop.Go(func(h *simulator.Handle) { 130 | paxos.Accept(h, network, ports[idx], doneStream) 131 | }) 132 | } 133 | 134 | // Finish the acceptors once both proposers are done. 135 | proposerDones := loop.Stream() 136 | loop.Go(func(h *simulator.Handle) { 137 | for i := 0; i < NumProposers; i++ { 138 | h.Poll(proposerDones) 139 | } 140 | for range ports { 141 | h.Schedule(doneStream, nil, 0) 142 | } 143 | }) 144 | 145 | // Run competing proposers. 
146 | for i := 0; i < NumProposers; i++ { 147 | idx := i 148 | loop.Go(func(h *simulator.Handle) { 149 | value := paxos.Propose(h, network, nodes[idx].Port(loop), ports, "msg", 1000) 150 | if value != "msg" { 151 | t.Errorf("unexpected value: %s", value) 152 | } 153 | h.Schedule(proposerDones, nil, 0) 154 | }) 155 | } 156 | 157 | loop.MustRun() 158 | } 159 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # dist-sys 2 | 3 | The main goal of this repository is to play with distributed systems *without* needing a large cluster of machines. As I learn about distributed algorithms, I'm going to implement them here. 4 | 5 | # Resources 6 | 7 | * Allreduce algorithms: [Optimization of Collective Communication Operations in MPICH](http://www.mcs.anl.gov/~thakur/papers/ijhpca-coll.pdf) 8 | * Paxos 9 | * [Wikipedia page](https://en.wikipedia.org/wiki/Paxos_(computer_science)#Multi-Paxos) 10 | * [Youtube - The Paxos Algorithm | Rachid Guerraoui](https://www.youtube.com/watch?v=WX4gjowx45E) 11 | * [Youtube - Paxos lecture (Raft user study)](https://www.youtube.com/watch?v=JEpsBg0AO6o) 12 | * Distributed hashing 13 | * [Consistent hashing - Wikipedia](https://en.wikipedia.org/wiki/Consistent_hashing) 14 | * [Rendezvous hashing - Wikipedia](https://en.wikipedia.org/wiki/Rendezvous_hashing) 15 | 16 | # Packages 17 | 18 | * [simulator](#simulator) - simulate a network of machines using "virtual" time. 19 | * [collcomm](#collcomm) - collective communications, e.g. "allreduce" to quickly sum large vectors across many machines. 20 | * [paxos](paxos) - an implementation of the Paxos consensus algorithm. 21 | 22 | ## simulator 23 | 24 | This package provides an API for simulating a distributed network of machines. It has two core APIs: 25 | 26 | * An event loop, which schedules events in "virtual" time. 27 | * A simulated network, which sits on top of the event loop to provide realistic message delivery times. 28 | 29 | To use the event loop, first create an `EventLoop` with `NewEventLoop()`. Then use `loop.Go()` to run new Goroutines, each of which gets its own `Handle` to the event loop. To wait for events, use `handle.Poll()`. Virtual time will only pass while all Goroutines with `Handle`s are polling on those `Handle`s. This way, Goroutines can do any amount of real-world work without virtual time passing, and vice versa. If you want virtual time to reflect some time for computation, you can use `handle.Sleep()` to explicitly let a certain amount of virtual time pass. 30 | 31 | This event loop example shows communication between two Goroutines: 32 | 33 | ```go 34 | loop := NewEventLoop() 35 | stream := loop.Stream() 36 | loop.Go(func(h *Handle) { 37 | msg := h.Poll(stream).Message 38 | fmt.Println(msg, h.Time()) 39 | }) 40 | loop.Go(func(h *Handle) { 41 | message := "Hello, world!" 42 | delay := 15.5 43 | h.Schedule(stream, message, delay) 44 | }) 45 | loop.Run() 46 | // Output: Hello, world! 15.5 47 | ``` 48 | 49 | The network API sits on top of the event loop API. Here's an example of how one might use the network API to time a simple back-and-forth interaction between nodes: 50 | 51 | ```go 52 | loop := NewEventLoop() 53 | 54 | // A switch with two ports that do I/O at 2 bytes/sec. 
55 | switcher := NewGreedyDropSwitcher(2, 2.0)
56 | 
57 | node1, node2 := NewNode(), NewNode()
58 | port1, port2 := node1.Port(loop), node2.Port(loop)
59 | latency := 0.25
60 | network := NewSwitcherNetwork(switcher, []*Node{node1, node2}, latency)
61 | 
62 | // Goroutine for node 1.
63 | loop.Go(func(h *Handle) {
64 |     message := port1.Recv(h).Message.(string)
65 |     response := strings.ToUpper(message)
66 | 
67 |     // Simulate the time it took to do the calculation.
68 |     h.Sleep(0.125)
69 | 
70 |     network.Send(h, &Message{
71 |         Source:  port1,
72 |         Dest:    port2,
73 |         Message: response,
74 |         Size:    float64(len(message)),
75 |     })
76 | })
77 | 
78 | // Goroutine for node 2.
79 | loop.Go(func(h *Handle) {
80 |     msg := "this should be capitalized"
81 |     network.Send(h, &Message{
82 |         Source:  port2,
83 |         Dest:    port1,
84 |         Message: msg,
85 |         Size:    float64(len(msg)),
86 |     })
87 |     response := port2.Recv(h).Message.(string)
88 |     fmt.Println(response, h.Time())
89 | })
90 | 
91 | loop.MustRun()
92 | 
93 | // Output: THIS SHOULD BE CAPITALIZED 26.625
94 | ```
95 | 
96 | ## collcomm
97 | 
98 | This package is a re-implementation of some [MPI](https://en.wikipedia.org/wiki/Message_Passing_Interface) routines.
99 | 
100 | ### allreduce
101 | 
102 | This package implements the "allreduce" collective communication API. For those who are not familiar, allreduce is extremely useful in large-scale machine learning. It allows you to quickly sum vectors (gradients, in most cases) across a large cluster of machines.
103 | 
104 | I've implemented three algorithms for allreduce:
105 | 
106 | * Naive: every node sends its vector to every other node.
107 | * Tree: the nodes arrange themselves into a binary tree; partial results are reduced up the tree, and the final result is broadcast back down.
108 | * Stream: the nodes arrange themselves into a ring, and the vector streams around the ring twice. The first pass performs the reduction, and the second pass broadcasts the reduced vector to the other nodes.
109 | 
110 | Below is a sketch of how one of these allreducers is invoked, followed by performance results for the different algorithms in a simulated network. The last row of the table is the most realistic for a datacenter.
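This sketch assumes each node's goroutine already holds a `*collcomm.Comms` value named `c` (constructing the Comms, with its ports, network, and handle, is the collcomm package's job and is omitted here) and a local slice `localVector`; `c` and `localVector` are illustrative names, not part of the API:

```go
// Any of the three implementations can be swapped in here; they all
// satisfy the allreduce.Allreducer interface.
reducer := allreduce.StreamAllreducer{}
summed := reducer.Allreduce(c, localVector, collcomm.Sum)
// Every node returns the same vector: the element-wise sum of all of
// the nodes' localVector slices.
```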
111 | 112 | | Nodes | Latency | NIC rate | Size | Naive | Tree | Stream | 113 | |:--|:--|:--|:--|:--|:--|:--| 114 | | 2 | 0.1 | 1E+06 | 10 | 0.100080 | 0.200160 | 0.800232 | 115 | | 2 | 0.1 | 1E+06 | 10000 | 0.180020 | 0.360030 | 1.020026 | 116 | | 2 | 0.1 | 1E+06 | 10000000 | 80.120000 | 160.230000 | 200.860011 | 117 | | 16 | 0.001 | 1E+06 | 10 | 0.002200 | 0.008880 | 0.067676 | 118 | | 16 | 0.001 | 1E+06 | 10000 | 1.201160 | 1.047350 | 0.419556 | 119 | | 16 | 0.001 | 1E+06 | 10000000 | 1200.161000 | 1040.107250 | 305.143356 | 120 | | 32 | 0.1 | 1E+06 | 10 | 0.102480 | 1.001120 | 9.901000 | 121 | | 32 | 0.1 | 1E+06 | 10000 | 2.580320 | 2.272630 | 19.551941 | 122 | | 32 | 0.1 | 1E+06 | 10000000 | 2480.420000 | 1361.042500 | 336.256147 | 123 | | 32 | 0.1 | 1E+09 | 10 | 0.100003 | 1.000001 | 9.900001 | 124 | | 32 | 0.1 | 1E+09 | 10000 | 0.102800 | 1.001270 | 19.100486 | 125 | | 32 | 0.1 | 1E+09 | 10000000 | 2.900000 | 2.402500 | 19.185915 | 126 | | 32 | 0.0001 | 1E+09 | 10 | 0.000103 | 0.001001 | 0.009901 | 127 | | 32 | 0.0001 | 1E+09 | 10000 | 0.002900 | 0.002403 | 0.019586 | 128 | | 32 | 0.0001 | 1E+09 | 10000000 | 2.800100 | 1.490913 | 0.358567 | 129 | -------------------------------------------------------------------------------- /raft/follower.go: -------------------------------------------------------------------------------- 1 | package raft 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/unixpickle/dist-sys/simulator" 7 | ) 8 | 9 | type Follower[C Command, S StateMachine[C, S]] struct { 10 | Context context.Context 11 | Handle *simulator.Handle 12 | Network simulator.Network 13 | Port *simulator.Port 14 | Others []*simulator.Port 15 | 16 | // Algorithm state. 17 | Log *Log[C, S] 18 | Term int64 19 | 20 | // ElectionTimeout should be randomized per follower to 21 | // break ties in the common case. 22 | ElectionTimeout float64 23 | 24 | // Internal state 25 | timerStream *simulator.EventStream 26 | timer *simulator.Timer 27 | 28 | currentVote *simulator.Port 29 | leader *simulator.Port 30 | } 31 | 32 | // RunLoop runs as a follower until there is a timeout, at 33 | // which point the node should enter the candidate phase. 
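//
// If initialMessage is non-nil, it is handled before the follower begins
// polling its port and election timer.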
34 | func (f *Follower[C, S]) RunLoop(initialMessage *simulator.Message) { 35 | f.timerStream = f.Handle.Stream() 36 | f.timer = f.Handle.Schedule(f.timerStream, nil, f.ElectionTimeout) 37 | 38 | if initialMessage != nil { 39 | f.handleMessage(initialMessage) 40 | } 41 | 42 | for { 43 | result := f.Handle.Poll(f.timerStream, f.Port.Incoming) 44 | select { 45 | case <-f.Context.Done(): 46 | return 47 | default: 48 | } 49 | if result.Stream == f.timerStream { 50 | return 51 | } 52 | f.handleMessage(result.Message.(*simulator.Message)) 53 | } 54 | } 55 | 56 | func (f *Follower[C, S]) handleMessage(rawMsg *simulator.Message) { 57 | if sourcePortIndex(rawMsg, f.Others) == -1 { 58 | f.handleCommand(rawMsg.Source, rawMsg.Message.(*CommandMessage[C])) 59 | return 60 | } 61 | 62 | msg := rawMsg.Message.(*RaftMessage[C, S]) 63 | if term := msg.Term(); term < f.Term { 64 | return 65 | } else if term > f.Term { 66 | f.currentVote = nil 67 | f.leader = nil 68 | f.Term = term 69 | } 70 | 71 | if msg.AppendLogs != nil { 72 | f.handleAppendLogs(rawMsg.Source, msg.AppendLogs) 73 | } else if msg.Vote != nil { 74 | f.handleVote(rawMsg.Source, msg.Vote) 75 | } 76 | } 77 | 78 | func (f *Follower[C, S]) resetTimer() { 79 | f.Handle.Cancel(f.timer) 80 | f.timer = f.Handle.Schedule(f.timerStream, nil, f.ElectionTimeout) 81 | } 82 | 83 | func (f *Follower[C, S]) handleCommand(source *simulator.Port, msg *CommandMessage[C]) { 84 | var leader *simulator.Node 85 | if f.leader != nil { 86 | leader = f.leader.Node 87 | } 88 | resp := &CommandResponse{ 89 | ID: msg.ID, 90 | LeaderUnknown: leader == nil, 91 | Redirect: leader, 92 | } 93 | f.Network.Send(f.Handle, &simulator.Message{ 94 | Source: f.Port, 95 | Dest: source, 96 | Message: resp, 97 | Size: float64(resp.Size()), 98 | }) 99 | } 100 | 101 | func (f *Follower[C, S]) handleAppendLogs(source *simulator.Port, msg *AppendLogs[C, S]) { 102 | f.resetTimer() 103 | f.leader = source 104 | 105 | if msg.Origin != nil { 106 | // This is the easy case: they are forcing our log to be 107 | // in a particular state. 108 | f.Log.OriginIndex = msg.OriginIndex 109 | f.Log.OriginTerm = msg.OriginTerm 110 | f.Log.Origin = *msg.Origin 111 | f.Log.Entries = msg.Entries 112 | 113 | _, latestIndex := f.Log.LatestTermAndIndex() 114 | resp := &RaftMessage[C, S]{ 115 | AppendLogsResponse: &AppendLogsResponse[C, S]{ 116 | Term: f.Term, 117 | SeqNum: msg.SeqNum, 118 | CommitIndex: msg.CommitIndex, 119 | LatestIndex: latestIndex, 120 | Success: true, 121 | }, 122 | } 123 | f.Network.Send(f.Handle, &simulator.Message{ 124 | Source: f.Port, 125 | Dest: source, 126 | Message: resp, 127 | Size: float64(resp.Size()), 128 | }) 129 | return 130 | } 131 | 132 | resp := &RaftMessage[C, S]{ 133 | AppendLogsResponse: &AppendLogsResponse[C, S]{ 134 | Term: f.Term, 135 | SeqNum: msg.SeqNum, 136 | Success: false, 137 | }, 138 | } 139 | 140 | if msg.OriginIndex <= f.Log.OriginIndex { 141 | // We have already committed zero or more of these entries, 142 | // so the rest must be accepted. 143 | if int(f.Log.OriginIndex-msg.OriginIndex) <= len(msg.Entries) { 144 | f.Log.Entries = msg.Entries[f.Log.OriginIndex-msg.OriginIndex:] 145 | f.Log.Commit(msg.CommitIndex) 146 | } else { 147 | // The sender actually truncated our logs, meaning 148 | // they do not have a commit as far as we do. 
149 | panic("truncated log before commit index") 150 | // f.Log.Entries = []LogEntry[C]{} 151 | } 152 | resp.AppendLogsResponse.Success = true 153 | } else { 154 | startEntry := msg.OriginIndex - f.Log.OriginIndex 155 | // We want to make sure they aren't appending past our log 156 | if int(startEntry) <= len(f.Log.Entries) { 157 | // We don't want to accept this suffix if we weren't consistent 158 | // up to the prefix. 159 | if f.Log.Entries[startEntry-1].Term == msg.OriginTerm { 160 | // We can append all the entries after this. 161 | if len(msg.Entries) > 0 { 162 | // We only append and reallocate if this isn't a heartbeat. 163 | f.Log.Entries = append( 164 | append( 165 | []LogEntry[C]{}, 166 | f.Log.Entries[:startEntry]..., 167 | ), 168 | msg.Entries..., 169 | ) 170 | } 171 | f.Log.Commit(msg.CommitIndex) 172 | resp.AppendLogsResponse.Success = true 173 | } 174 | } 175 | } 176 | 177 | resp.AppendLogsResponse.CommitIndex = f.Log.OriginIndex 178 | _, resp.AppendLogsResponse.LatestIndex = f.Log.LatestTermAndIndex() 179 | 180 | f.Network.Send(f.Handle, &simulator.Message{ 181 | Source: f.Port, 182 | Dest: source, 183 | Message: resp, 184 | Size: float64(resp.Size()), 185 | }) 186 | } 187 | 188 | func (f *Follower[C, S]) handleVote(source *simulator.Port, msg *Vote) { 189 | latestTerm, latestIndex := f.Log.LatestTermAndIndex() 190 | resp := &RaftMessage[C, S]{ 191 | VoteResponse: &VoteResponse{ 192 | Term: f.Term, 193 | ReceivedVote: true, 194 | }, 195 | } 196 | if (f.currentVote != nil && source != f.currentVote) || 197 | latestTerm > msg.LatestTerm || 198 | (latestTerm == msg.LatestTerm && latestIndex > msg.LatestIndex) { 199 | resp.VoteResponse.ReceivedVote = false 200 | } else { 201 | f.currentVote = source 202 | } 203 | f.Network.Send(f.Handle, &simulator.Message{ 204 | Source: f.Port, 205 | Dest: source, 206 | Message: resp, 207 | Size: float64(resp.Size()), 208 | }) 209 | } 210 | -------------------------------------------------------------------------------- /simulator/events_test.go: -------------------------------------------------------------------------------- 1 | package simulator 2 | 3 | import ( 4 | "fmt" 5 | "math/rand" 6 | "testing" 7 | "time" 8 | ) 9 | 10 | func ExampleEventLoop() { 11 | loop := NewEventLoop() 12 | stream := loop.Stream() 13 | loop.Go(func(h *Handle) { 14 | msg := h.Poll(stream).Message 15 | fmt.Println(msg, h.Time()) 16 | }) 17 | loop.Go(func(h *Handle) { 18 | message := "Hello, world!" 19 | delay := 15.5 20 | h.Schedule(stream, message, delay) 21 | }) 22 | loop.MustRun() 23 | // Output: Hello, world! 
15.5 24 | } 25 | 26 | func TestEventLoopTimer(t *testing.T) { 27 | loop := NewEventLoop() 28 | stream := loop.Stream() 29 | value := make(chan interface{}, 1) 30 | loop.Go(func(h *Handle) { 31 | value <- h.Poll(stream).Message 32 | }) 33 | loop.Go(func(h *Handle) { 34 | h.Schedule(stream, 1337, 15.5) 35 | }) 36 | if err := loop.Run(); err != nil { 37 | t.Fatal(err) 38 | } 39 | if loop.Time() != 15.5 { 40 | t.Errorf("time should be 15.5 but is %f", loop.Time()) 41 | } 42 | select { 43 | case val := <-value: 44 | if val != 1337 { 45 | t.Errorf("value should be 1337 but is %f", val) 46 | } 47 | default: 48 | t.Error("timer never fired") 49 | } 50 | } 51 | 52 | func TestEventLoopTimerOrder(t *testing.T) { 53 | loop := NewEventLoop() 54 | 55 | stream1 := loop.Stream() 56 | stream2 := loop.Stream() 57 | 58 | values := make(chan interface{}, 2) 59 | 60 | for _, stream := range []*EventStream{stream1, stream2} { 61 | s := stream 62 | loop.Go(func(h *Handle) { 63 | event := h.Poll(s) 64 | if event.Stream != s { 65 | t.Error("incorrect stream") 66 | } 67 | values <- event.Message 68 | }) 69 | } 70 | 71 | loop.Go(func(h *Handle) { 72 | h.Schedule(stream1, 123, 5.0) 73 | h.Schedule(stream2, 1339, 7.0) 74 | }) 75 | 76 | if err := loop.Run(); err != nil { 77 | t.Fatal(err) 78 | } 79 | 80 | if loop.Time() != 7.0 { 81 | t.Errorf("time should be 7.0 but got %f", loop.Time()) 82 | } 83 | 84 | val1 := <-values 85 | val2 := <-values 86 | if val1 != 123 { 87 | t.Errorf("value 1 should be 123 but got %d", val1) 88 | } 89 | if val2 != 1339 { 90 | t.Errorf("value 2 should be 1339 but got %d", val2) 91 | } 92 | } 93 | 94 | // TestEventLoopMultiConsumer tests that the EventLoop 95 | // properly supports multiple threads reading from the 96 | // same event stream. 97 | func TestEventLoopMultiConsumer(t *testing.T) { 98 | orderings := map[[3]int]bool{} 99 | for i := 0; i < 10000; i++ { 100 | loop := NewEventLoop() 101 | stream := loop.Stream() 102 | var ordering [3]int 103 | for j := 0; j < 3; j++ { 104 | idx := j 105 | loop.Go(func(h *Handle) { 106 | msg := h.Poll(stream).Message 107 | ordering[idx] = msg.(int) 108 | }) 109 | } 110 | loop.Go(func(h *Handle) { 111 | h.Schedule(stream, 1, 1.0) 112 | h.Schedule(stream, 2, 2.0) 113 | h.Schedule(stream, 3, 3.0) 114 | }) 115 | if err := loop.Run(); err != nil { 116 | t.Fatal(err) 117 | } 118 | if loop.Time() != 3 { 119 | t.Errorf("time should be 3.0 but got %f", loop.Time()) 120 | } 121 | orderings[ordering] = true 122 | } 123 | if len(orderings) != 6 { 124 | t.Errorf("expected 6 possible orderings but saw %d", len(orderings)) 125 | } 126 | } 127 | 128 | // TestEventLoopBuffering tests that messages sent to an 129 | // EventStream will be queued if no Goroutine is currently 130 | // polling on the stream. 
131 | func TestEventLoopBuffering(t *testing.T) { 132 | loop := NewEventLoop() 133 | 134 | readFirst := loop.Stream() 135 | readSecond := loop.Stream() 136 | neverRead := loop.Stream() 137 | 138 | value := make(chan interface{}, 1) 139 | 140 | loop.Go(func(h *Handle) { 141 | h.Poll(readFirst) 142 | value <- h.Poll(readSecond).Message 143 | }) 144 | 145 | loop.Go(func(h *Handle) { 146 | h.Schedule(readSecond, 1337, 3.0) 147 | h.Sleep(2) 148 | h.Schedule(neverRead, 321, 4.0) 149 | h.Schedule(readFirst, 123, 7.0) 150 | }) 151 | 152 | if err := loop.Run(); err != nil { 153 | t.Fatal(err) 154 | } 155 | 156 | if loop.Time() != 9.0 { 157 | t.Errorf("time should be 9.0 but got %f", loop.Time()) 158 | } 159 | 160 | if val := <-value; val != 1337 { 161 | t.Errorf("expected 1337 but got %d", val) 162 | } 163 | } 164 | 165 | // TestEventLoopPollMulti tests polling multiple streams 166 | // at once. 167 | func TestEventLoopPollMulti(t *testing.T) { 168 | loop := NewEventLoop() 169 | 170 | first := loop.Stream() 171 | second := loop.Stream() 172 | third := loop.Stream() 173 | 174 | values := make(chan interface{}, 3) 175 | 176 | loop.Go(func(h *Handle) { 177 | for _, stream := range []*EventStream{first, second, third} { 178 | event := h.Poll(third, second, first) 179 | if event.Stream != stream { 180 | t.Error("incorrect stream order") 181 | } 182 | values <- event.Message 183 | } 184 | }) 185 | 186 | loop.Go(func(h *Handle) { 187 | h.Schedule(first, 133, 3.0) 188 | h.Sleep(3.5) 189 | h.Schedule(third, 333, 7.0) 190 | 191 | // Real time should play no part in the ordering 192 | // of messages. 193 | time.Sleep(time.Second / 4) 194 | 195 | h.Schedule(second, 233, 1.0) 196 | }) 197 | 198 | if err := loop.Run(); err != nil { 199 | t.Fatal(err) 200 | } 201 | 202 | if loop.Time() != 10.5 { 203 | t.Errorf("time should be 10.5 but got %f", loop.Time()) 204 | } 205 | 206 | for _, expected := range []int{133, 233, 333} { 207 | if val := <-values; val != expected { 208 | t.Errorf("expected %d but got %d", expected, val) 209 | } 210 | } 211 | } 212 | 213 | // TestEventLoopDeadlocks makes sure that the event loop 214 | // can detect deadlocks. 
215 | func TestEventLoopDeadlocks(t *testing.T) { 216 | loop := NewEventLoop() 217 | 218 | stream1 := loop.Stream() 219 | stream2 := loop.Stream() 220 | 221 | loop.Go(func(h *Handle) { 222 | h.Poll(stream1) 223 | h.Schedule(stream2, 1337, 0.0) 224 | }) 225 | 226 | loop.Go(func(h *Handle) { 227 | time.Sleep(time.Second / 4) 228 | h.Poll(stream2) 229 | h.Schedule(stream1, 1337, 0.0) 230 | }) 231 | 232 | if loop.Run() == nil { 233 | t.Error("did not detect deadlock") 234 | } 235 | } 236 | 237 | func BenchmarkTimerManipulation(b *testing.B) { 238 | for i := 0; i < b.N; i++ { 239 | loop := NewEventLoop() 240 | loop.Go(func(h *Handle) { 241 | stream := h.Stream() 242 | for j := 0; j < 32; j++ { 243 | timers := []*Timer{} 244 | for k := 0; k < 32; k++ { 245 | timers = append(timers, h.Schedule(stream, "hello", rand.Float64())) 246 | } 247 | for _, j := range rand.Perm(len(timers)) { 248 | h.Cancel(timers[j]) 249 | } 250 | } 251 | }) 252 | loop.Run() 253 | } 254 | } 255 | -------------------------------------------------------------------------------- /paxos/basic_paxos.go: -------------------------------------------------------------------------------- 1 | package paxos 2 | 3 | import ( 4 | "math" 5 | "math/rand" 6 | 7 | "github.com/unixpickle/dist-sys/simulator" 8 | "github.com/unixpickle/essentials" 9 | ) 10 | 11 | // BasicPaxos implements the simplest version of the Paxos 12 | // algorithm. 13 | // 14 | // This algorithm does not support re-configuration, 15 | // multiple values, more than N/2 failures, etc. 16 | type BasicPaxos struct { 17 | // Backoff controls the rate of exponential backoff. 18 | // It is a coefficient for a 2^attempts term. 19 | Backoff float64 20 | 21 | // Timeout is the number of seconds to wait before 22 | // giving up on a message. 23 | Timeout float64 24 | } 25 | 26 | // Propose runs a Paxos proposer until a value has been 27 | // accepted by a quorum of acceptors. 28 | // 29 | // The cluster must contain at least 3 acceptors. 30 | func (b *BasicPaxos) Propose(h *simulator.Handle, n simulator.Network, p *simulator.Port, 31 | acceptors []*simulator.Port, value interface{}, size int) interface{} { 32 | quorum := quorumSize(len(acceptors)) 33 | sendVal := &basicValue{value: value, size: size} 34 | var round int 35 | for i := 0; true; i++ { 36 | b.backoff(h, i) 37 | 38 | prepare := &basicPrepareReq{round: round} 39 | for _, acc := range acceptors { 40 | n.Send(h, &simulator.Message{ 41 | Source: p, 42 | Dest: acc, 43 | Size: float64(prepare.Size()), 44 | Message: prepare, 45 | }) 46 | } 47 | 48 | acceptorQuorum, prepResps := b.prepareResponses(h, p, acceptors, round) 49 | if len(acceptorQuorum) < quorum { 50 | round = essentials.MaxInt(round+1, basicPrepareNextRound(prepResps)) 51 | continue 52 | } 53 | 54 | if v := basicPrepareAcceptedValue(prepResps); v != nil { 55 | sendVal = v 56 | } 57 | 58 | accept := &basicAcceptReq{round: round, value: sendVal} 59 | for _, acc := range acceptorQuorum { 60 | n.Send(h, &simulator.Message{ 61 | Source: p, 62 | Dest: acc, 63 | Size: float64(accept.Size()), 64 | Message: accept, 65 | }) 66 | } 67 | 68 | if b.proposeResponse(h, p, len(acceptorQuorum), quorum, round) { 69 | return sendVal.value 70 | } 71 | 72 | round++ 73 | } 74 | panic("unreachable") 75 | } 76 | 77 | // Accept runs a Paxos acceptor. 78 | // 79 | // Returns when an event is sent to the done stream. 
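//
// A typical setup (mirroring basic_paxos_test.go) runs one acceptor per
// node and proposes from any of them; the names below (loop, nodes,
// ports, network) come from that test:
//
//	bp := &BasicPaxos{Timeout: 1.5}
//	doneStream := loop.Stream()
//	for i := range nodes {
//		idx := i
//		loop.Go(func(h *simulator.Handle) {
//			bp.Accept(h, network, ports[idx], doneStream)
//		})
//	}
//	loop.Go(func(h *simulator.Handle) {
//		value := bp.Propose(h, network, nodes[0].Port(loop), ports, "value", 5)
//		// Stop the acceptors once a value has been chosen.
//		for range ports {
//			h.Schedule(doneStream, nil, 0)
//		}
//		_ = value
//	})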
80 | func (b *BasicPaxos) Accept(h *simulator.Handle, n simulator.Network, p *simulator.Port, 81 | done *simulator.EventStream) { 82 | currentRound := -1 83 | var acceptedRound int 84 | var acceptedValue *basicValue 85 | for { 86 | event := h.Poll(done, p.Incoming) 87 | if event.Stream == done { 88 | return 89 | } 90 | msg := event.Message.(*simulator.Message) 91 | switch packet := msg.Message.(type) { 92 | case *basicPrepareReq: 93 | response := &basicPrepareResp{} 94 | if packet.round > currentRound { 95 | currentRound = packet.round 96 | response.success = true 97 | response.acceptedRound = acceptedRound 98 | response.acceptedValue = acceptedValue 99 | } 100 | response.currentRound = currentRound 101 | n.Send(h, &simulator.Message{ 102 | Source: msg.Dest, 103 | Dest: msg.Source, 104 | Size: float64(response.Size()), 105 | Message: response, 106 | }) 107 | case *basicAcceptReq: 108 | response := &basicAcceptResp{} 109 | response.round = packet.round 110 | if packet.round == currentRound { 111 | acceptedRound = packet.round 112 | acceptedValue = packet.value 113 | response.accepted = true 114 | } 115 | n.Send(h, &simulator.Message{ 116 | Source: msg.Dest, 117 | Dest: msg.Source, 118 | Size: float64(response.Size()), 119 | Message: response, 120 | }) 121 | default: 122 | panic("unexpected message type") 123 | } 124 | } 125 | } 126 | 127 | // backoff performs exponential backoff if necessary. 128 | func (b *BasicPaxos) backoff(h *simulator.Handle, tryNumber int) { 129 | if b.Backoff == 0 || tryNumber == 0 { 130 | return 131 | } 132 | delay := rand.Float64() * b.Backoff * math.Pow(2.0, float64(tryNumber)) 133 | h.Sleep(delay) 134 | } 135 | 136 | // prepareResponses reads responses from a prepare step. 137 | // 138 | // It returns the ports for all acceptors that accepted 139 | // the round number, and all the responses. 140 | func (b *BasicPaxos) prepareResponses(h *simulator.Handle, p *simulator.Port, 141 | acceptors []*simulator.Port, round int) ([]*simulator.Port, []*basicPrepareResp) { 142 | timeout := h.Stream() 143 | h.Schedule(timeout, nil, b.Timeout) 144 | 145 | var accPorts []*simulator.Port 146 | var responses []*basicPrepareResp 147 | for len(responses) < len(acceptors) { 148 | msg := h.Poll(timeout, p.Incoming) 149 | if msg.Stream == timeout { 150 | break 151 | } 152 | netMsg := msg.Message.(*simulator.Message) 153 | if resp, ok := netMsg.Message.(*basicPrepareResp); ok { 154 | responses = append(responses, resp) 155 | if resp.success && resp.currentRound == round { 156 | accPorts = append(accPorts, netMsg.Source) 157 | } 158 | } 159 | } 160 | return accPorts, responses 161 | } 162 | 163 | // proposeResponse reads responses to an accept request 164 | // and determines whether the value was accepted or not. 
165 | func (b *BasicPaxos) proposeResponse(h *simulator.Handle, p *simulator.Port, 166 | numSent, quorum, round int) bool { 167 | timeout := h.Stream() 168 | h.Schedule(timeout, nil, b.Timeout) 169 | 170 | var numAccepts int 171 | var numAcceptResponses int 172 | for numAcceptResponses < numSent && numAccepts < quorum { 173 | msg := h.Poll(timeout, p.Incoming) 174 | if msg.Stream == timeout { 175 | break 176 | } 177 | netMsg := msg.Message.(*simulator.Message) 178 | if resp, ok := netMsg.Message.(*basicAcceptResp); ok { 179 | if resp.round == round { 180 | numAcceptResponses++ 181 | if resp.accepted { 182 | numAccepts++ 183 | } 184 | } 185 | } 186 | } 187 | 188 | return numAccepts >= quorum 189 | } 190 | 191 | type basicValue struct { 192 | value interface{} 193 | size int 194 | } 195 | 196 | func (b *basicValue) Size() int { 197 | if b == nil { 198 | return 0 199 | } 200 | return b.size 201 | } 202 | 203 | type basicPrepareReq struct { 204 | round int 205 | } 206 | 207 | func (b *basicPrepareReq) Size() int { 208 | return 8 209 | } 210 | 211 | type basicPrepareResp struct { 212 | success bool 213 | 214 | // On success, this is the promised round number. 215 | // Otherwise, this is the round number that superseded 216 | // the requested round number. 217 | currentRound int 218 | 219 | // Non-nil if a previous value was accepted. 220 | acceptedValue *basicValue 221 | acceptedRound int 222 | } 223 | 224 | func basicPrepareAcceptedValue(resps []*basicPrepareResp) *basicValue { 225 | highest := -1 226 | var value *basicValue 227 | for _, r := range resps { 228 | if r.acceptedRound > highest && r.acceptedValue != nil { 229 | highest = r.acceptedRound 230 | value = r.acceptedValue 231 | } 232 | } 233 | return value 234 | } 235 | 236 | func basicPrepareNextRound(resps []*basicPrepareResp) int { 237 | var max int 238 | for _, r := range resps { 239 | if r.currentRound > max { 240 | max = r.currentRound 241 | } 242 | } 243 | return max + 1 244 | } 245 | 246 | func (b *basicPrepareResp) Size() int { 247 | return 17 + b.acceptedValue.Size() 248 | } 249 | 250 | type basicAcceptReq struct { 251 | round int 252 | value *basicValue 253 | } 254 | 255 | func (b *basicAcceptReq) Size() int { 256 | return 8 + b.value.Size() 257 | } 258 | 259 | type basicAcceptResp struct { 260 | round int 261 | accepted bool 262 | } 263 | 264 | func (b *basicAcceptResp) Size() int { 265 | return 9 266 | } 267 | 268 | func quorumSize(numNodes int) int { 269 | return numNodes/2 + 1 270 | } 271 | -------------------------------------------------------------------------------- /collcomm/allreduce/stream.go: -------------------------------------------------------------------------------- 1 | package allreduce 2 | 3 | import ( 4 | "github.com/unixpickle/dist-sys/collcomm" 5 | "github.com/unixpickle/dist-sys/simulator" 6 | "github.com/unixpickle/essentials" 7 | ) 8 | 9 | // A StreamAllreducer splits a vector up into smaller 10 | // messages and streams the messages through all the nodes 11 | // at once. 12 | // 13 | // The reduction has two phases: Reduce and Broadcast. 14 | // During Reduce, the fully reduced vector arrives at the 15 | // first node. 16 | // During Broadcast, the reduced vector is streamed from 17 | // the first node to all the other nodes. 18 | type StreamAllreducer struct { 19 | // Granularity determines how many chunks the data is 20 | // split up into. 21 | // The actual number of chunks is multiplied by the 22 | // number of nodes. 23 | // 24 | // If Granularity is 0, it is treated as 1. 
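	//
	// For example, with 4 nodes and Granularity 2, a vector is split into
	// roughly 4*2 = 8 chunks before being streamed around the ring.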
25 | Granularity int 26 | } 27 | 28 | // Allreduce calls fn on chunks of data at a time and 29 | // returns a vector resulting from the final reduction. 30 | func (s StreamAllreducer) Allreduce(c *collcomm.Comms, data []float64, 31 | fn collcomm.ReduceFn) []float64 { 32 | if len(data) == 0 || len(c.Ports) == 1 { 33 | return data 34 | } 35 | if c.Index() == 0 { 36 | return s.allreduceRoot(c, data) 37 | } 38 | return s.allreduceOther(c, data, fn) 39 | } 40 | 41 | func (s StreamAllreducer) allreduceRoot(c *collcomm.Comms, data []float64) []float64 { 42 | chunksOut := s.chunkify(c, data) 43 | reduced := make([]float64, 0, len(data)) 44 | 45 | // Kick off the reduction cycle. 46 | (&streamPacket{packetType: streamPacketReduce, payload: chunksOut[0]}).Send(c) 47 | chunksOut = chunksOut[1:] 48 | 49 | // Push the reduction through the ring. 50 | waitingReduceAck := true 51 | for len(reduced) < len(data) { 52 | packet := recvStreamPacket(c) 53 | switch packet.packetType { 54 | case streamPacketReduce: 55 | reduced = append(reduced, packet.payload...) 56 | (&streamPacket{packetType: streamPacketReduceAck}).Send(c) 57 | case streamPacketReduceAck: 58 | if !waitingReduceAck { 59 | panic("unexpected ACK") 60 | } 61 | if len(chunksOut) > 0 { 62 | (&streamPacket{packetType: streamPacketReduce, payload: chunksOut[0]}).Send(c) 63 | chunksOut = chunksOut[1:] 64 | } else { 65 | waitingReduceAck = false 66 | } 67 | default: 68 | panic("unexpected packet type") 69 | } 70 | } 71 | 72 | if len(chunksOut) > 0 { 73 | panic("unexpected reduction completion") 74 | } else if len(reduced) != len(data) { 75 | panic("excess data") 76 | } 77 | 78 | // Push the data through the bcast cycle. 79 | for _, chunk := range s.chunkify(c, reduced) { 80 | (&streamPacket{packetType: streamPacketBcast, payload: chunk}).Send(c) 81 | for { 82 | packet := recvStreamPacket(c) 83 | if packet.packetType == streamPacketReduceAck { 84 | if !waitingReduceAck { 85 | panic("unexpected ACK") 86 | } 87 | waitingReduceAck = false 88 | } else if packet.packetType == streamPacketBcastAck { 89 | break 90 | } else { 91 | panic("unexpected packet type") 92 | } 93 | } 94 | } 95 | 96 | return reduced 97 | } 98 | 99 | func (s StreamAllreducer) allreduceOther(c *collcomm.Comms, data []float64, fn collcomm.ReduceFn) []float64 { 100 | var reduced []float64 101 | 102 | isLastNode := c.Index()+1 == len(c.Ports) 103 | 104 | // Reduce our data into the stream. 105 | var reduceBlocked bool 106 | var reduceBuf []*streamPacket 107 | remainingData := data 108 | for len(reduced) == 0 { 109 | packet := recvStreamPacket(c) 110 | switch packet.packetType { 111 | case streamPacketReduce: 112 | (&streamPacket{packetType: streamPacketReduceAck}).Send(c) 113 | chunk := fn(c.Handle, packet.payload, remainingData[:len(packet.payload)]) 114 | remainingData = remainingData[len(packet.payload):] 115 | outPacket := &streamPacket{packetType: streamPacketReduce, payload: chunk} 116 | reduceBuf = append(reduceBuf, outPacket) 117 | case streamPacketReduceAck: 118 | if !reduceBlocked { 119 | panic("unexpected ACK") 120 | } 121 | reduceBlocked = false 122 | case streamPacketBcast: 123 | if len(reduceBuf) > 0 { 124 | panic("got bcast before reduce finished") 125 | } 126 | reduced = append(reduced, packet.payload...) 127 | (&streamPacket{packetType: streamPacketBcastAck}).Send(c) 128 | if !isLastNode { 129 | // Otherwise, the packet will never reach 130 | // the next node in the ring. 
131 | packet.Send(c) 132 | } 133 | default: 134 | panic("unexpected packet type") 135 | } 136 | if !reduceBlocked && len(reduceBuf) > 0 { 137 | reduceBuf[0].Send(c) 138 | essentials.OrderedDelete(&reduceBuf, 0) 139 | reduceBlocked = true 140 | } 141 | } 142 | 143 | // Read the broadcasted reduction. 144 | bcastBlocked := true 145 | var bcastBuf []*streamPacket 146 | for len(reduced) < len(data) || len(bcastBuf) > 0 { 147 | packet := recvStreamPacket(c) 148 | switch packet.packetType { 149 | case streamPacketReduceAck: 150 | if !reduceBlocked { 151 | panic("unexpected ACK") 152 | } 153 | reduceBlocked = false 154 | case streamPacketBcast: 155 | reduced = append(reduced, packet.payload...) 156 | (&streamPacket{packetType: streamPacketBcastAck}).Send(c) 157 | if !isLastNode { 158 | outPacket := &streamPacket{packetType: streamPacketBcast, payload: packet.payload} 159 | bcastBuf = append(bcastBuf, outPacket) 160 | } 161 | case streamPacketBcastAck: 162 | if !bcastBlocked { 163 | panic("unexpected ACK") 164 | } 165 | bcastBlocked = false 166 | default: 167 | panic("unexpected packet type") 168 | } 169 | if !bcastBlocked && len(bcastBuf) > 0 { 170 | bcastBuf[0].Send(c) 171 | essentials.OrderedDelete(&bcastBuf, 0) 172 | bcastBlocked = true 173 | } 174 | } 175 | 176 | if reduceBlocked { 177 | panic("missed expected ACK") 178 | } 179 | 180 | return reduced 181 | } 182 | 183 | func (s StreamAllreducer) chunkify(c *collcomm.Comms, data []float64) [][]float64 { 184 | granularity := s.Granularity 185 | if granularity == 0 { 186 | granularity = 1 187 | } 188 | chunkSize := len(data) / (len(c.Ports) * granularity) 189 | if chunkSize < 1 { 190 | chunkSize = 1 191 | } 192 | var res [][]float64 193 | for i := 0; i < len(data); i += chunkSize { 194 | if i+chunkSize > len(data) { 195 | res = append(res, data[i:]) 196 | } else { 197 | res = append(res, data[i:i+chunkSize]) 198 | } 199 | } 200 | return res 201 | } 202 | 203 | type streamPacketType int 204 | 205 | const ( 206 | streamPacketReduce streamPacketType = iota 207 | streamPacketReduceAck 208 | streamPacketBcast 209 | streamPacketBcastAck 210 | ) 211 | 212 | type streamPacket struct { 213 | packetType streamPacketType 214 | payload []float64 215 | } 216 | 217 | func recvStreamPacket(c *collcomm.Comms) *streamPacket { 218 | msg := c.Port.Recv(c.Handle) 219 | return msg.Message.(*streamPacket) 220 | } 221 | 222 | func (s *streamPacket) Size() float64 { 223 | return float64(len(s.payload)*8) + 1.0 224 | } 225 | 226 | // Send sends the packet to the appropriate host. 227 | // For ACKs, this is the previous host. 228 | // For other messages, this is the next host. 
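//
// For example, in a 4-node ring, node 2 sends its reduce and bcast
// payloads to node 3 and its ACKs back to node 1, while node 0 wraps
// around and sends its ACKs to node 3.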
229 | func (s *streamPacket) Send(c *collcomm.Comms) { 230 | idx := c.Index() 231 | var dstIdx int 232 | if s.packetType == streamPacketReduceAck || s.packetType == streamPacketBcastAck { 233 | dstIdx = idx - 1 234 | if dstIdx < 0 { 235 | dstIdx = len(c.Ports) - 1 236 | } 237 | } else { 238 | dstIdx = (idx + 1) % len(c.Ports) 239 | } 240 | c.Network.Send(c.Handle, &simulator.Message{ 241 | Source: c.Port, 242 | Dest: c.Ports[dstIdx], 243 | Message: s, 244 | Size: s.Size(), 245 | }) 246 | } 247 | 248 | type streamChunkInfo struct { 249 | data []float64 250 | start int 251 | } 252 | -------------------------------------------------------------------------------- /simulator/network_test.go: -------------------------------------------------------------------------------- 1 | package simulator 2 | 3 | import ( 4 | "fmt" 5 | "math" 6 | "math/rand" 7 | "strings" 8 | "testing" 9 | ) 10 | 11 | func ExampleNetwork() { 12 | loop := NewEventLoop() 13 | 14 | // A switch with two ports that do I/O at 2 bytes/sec. 15 | switcher := NewGreedyDropSwitcher(2, 2.0) 16 | 17 | node1 := NewNode() 18 | node2 := NewNode() 19 | latency := 0.25 20 | network := NewSwitcherNetwork(switcher, []*Node{node1, node2}, latency) 21 | port1 := node1.Port(loop) 22 | port2 := node2.Port(loop) 23 | 24 | // Goroutine for node 1. 25 | loop.Go(func(h *Handle) { 26 | message := port1.Recv(h).Message.(string) 27 | response := strings.ToUpper(message) 28 | 29 | // Simulate time it took to do the calculation. 30 | h.Sleep(0.125) 31 | 32 | network.Send(h, &Message{ 33 | Source: port1, 34 | Dest: port2, 35 | Message: response, 36 | Size: float64(len(message)), 37 | }) 38 | }) 39 | 40 | // Goroutine for node 2. 41 | loop.Go(func(h *Handle) { 42 | msg := "this should be capitalized" 43 | network.Send(h, &Message{ 44 | Source: port2, 45 | Dest: port1, 46 | Message: msg, 47 | Size: float64(len(msg)), 48 | }) 49 | response := port2.Recv(h).Message.(string) 50 | fmt.Println(response, h.Time()) 51 | }) 52 | 53 | loop.MustRun() 54 | 55 | // Output: THIS SHOULD BE CAPITALIZED 26.625 56 | } 57 | 58 | func TestSwitchedNetworkSingleMessage(t *testing.T) { 59 | loop := NewEventLoop() 60 | 61 | switcher := NewGreedyDropSwitcher(2, 2.0) 62 | node1 := NewNode() 63 | node2 := NewNode() 64 | network := NewSwitcherNetwork(switcher, []*Node{node1, node2}, 3.0) 65 | port1 := node1.Port(loop) 66 | port2 := node2.Port(loop) 67 | 68 | fmt.Println(node1 == node2) 69 | 70 | loop.Go(func(h *Handle) { 71 | network.Send(h, &Message{ 72 | Source: port1, 73 | Dest: port2, 74 | Message: "hi node 2", 75 | Size: 124.0, 76 | }) 77 | if val := port1.Recv(h).Message; val != "hi node 1" { 78 | t.Errorf("unexpected message: %s", val) 79 | } 80 | }) 81 | loop.Go(func(h *Handle) { 82 | network.Send(h, &Message{ 83 | Source: port2, 84 | Dest: port1, 85 | Message: "hi node 1", 86 | Size: 124.0, 87 | }) 88 | if val := port2.Recv(h).Message; val != "hi node 2" { 89 | t.Errorf("unexpected message: %s", val) 90 | } 91 | }) 92 | 93 | if err := loop.Run(); err != nil { 94 | t.Fatal(err) 95 | } 96 | 97 | expectedTime := 124.0/2.0 + 3.0 98 | if loop.Time() != expectedTime { 99 | t.Errorf("time should be %f but got %f", expectedTime, loop.Time()) 100 | } 101 | } 102 | 103 | func TestSwitchedNetworkOversubscribed(t *testing.T) { 104 | loop := NewEventLoop() 105 | 106 | dataRate := 4.0 107 | switcher := NewGreedyDropSwitcher(2, dataRate) 108 | node1 := NewNode() 109 | node2 := NewNode() 110 | network := NewSwitcherNetwork(switcher, []*Node{node1, node2}, 2.0) 111 | port1 := node1.Port(loop) 
112 | port2 := node2.Port(loop) 113 | 114 | loop.Go(func(h *Handle) { 115 | network.Send(h, &Message{ 116 | Source: port1, 117 | Dest: port2, 118 | Message: "hi node 2 (message 1)", 119 | Size: 123.0, 120 | }) 121 | network.Send(h, &Message{ 122 | Source: port1, 123 | Dest: port2, 124 | Message: "hi node 2 (message 2)", 125 | Size: 124.0, 126 | }) 127 | if val := port1.Recv(h).Message; val != "hi node 1" { 128 | t.Errorf("unexpected message: %s", val) 129 | } 130 | expectedTime := 1.0 + 2.0 + 124.0/dataRate 131 | if h.Time() != expectedTime { 132 | t.Errorf("expected time %f but got %f", expectedTime, h.Time()) 133 | } 134 | }) 135 | 136 | loop.Go(func(h *Handle) { 137 | // Make sure the other messages are in-flight. 138 | // This helps us test for the fact that we can 139 | // reschedule a message before the other messages. 140 | h.Sleep(1) 141 | 142 | network.Send(h, &Message{ 143 | Source: port2, 144 | Dest: port1, 145 | Message: "hi node 1", 146 | Size: 124.0, 147 | }) 148 | if val := port2.Recv(h).Message; val != "hi node 2 (message 1)" { 149 | t.Errorf("unexpected message: %s", val) 150 | } 151 | expectedTime := 2.0 + 2.0*123.0/dataRate 152 | if h.Time() != expectedTime { 153 | t.Errorf("expected time %f but got %f", expectedTime, h.Time()) 154 | } 155 | if val := port2.Recv(h).Message; val != "hi node 2 (message 2)" { 156 | t.Errorf("unexpected message: %s", val) 157 | } 158 | expectedTime += 1.0 / dataRate 159 | if h.Time() != expectedTime { 160 | t.Errorf("expected time %f but got %f", expectedTime, h.Time()) 161 | } 162 | }) 163 | 164 | if err := loop.Run(); err != nil { 165 | t.Fatal(err) 166 | } 167 | 168 | expectedTime := 2.0 + 2.0*123.0/dataRate + 1.0/dataRate 169 | if loop.Time() != expectedTime { 170 | t.Errorf("time should be %f but got %f", expectedTime, loop.Time()) 171 | } 172 | 173 | // Make sure that there are no stray messages. 174 | for _, port := range []*Port{port1, port2} { 175 | loop.Go(func(h *Handle) { 176 | h.Poll(port.Incoming) 177 | }) 178 | if loop.Run() == nil { 179 | t.Error("expected deadlock error") 180 | } 181 | } 182 | } 183 | 184 | func TestSwitchedNetworkBatchedEquivalence(t *testing.T) { 185 | loop := NewEventLoop() 186 | 187 | dataRate := 4.0 188 | switcher := NewGreedyDropSwitcher(2, dataRate) 189 | node1 := NewNode() 190 | node2 := NewNode() 191 | network := NewSwitcherNetwork(switcher, []*Node{node1, node2}, 2.0) 192 | 193 | testBatchedEquivalence(t, loop, network, node1.Port(loop), node2.Port(loop)) 194 | } 195 | 196 | func testBatchedEquivalence(t *testing.T, loop *EventLoop, network Network, p1, p2 *Port) { 197 | messages := []*Message{} 198 | for i := 0; i < 20; i++ { 199 | messages = append(messages, &Message{ 200 | Source: p1, 201 | Dest: p2, 202 | Message: rand.NormFloat64(), 203 | Size: rand.Float64() + 0.1, 204 | }) 205 | } 206 | 207 | var serialMessages []*Message 208 | var serialTimes []float64 209 | loop.Go(func(h *Handle) { 210 | for _, msg := range messages { 211 | network.Send(h, msg) 212 | } 213 | for range messages { 214 | serialMessages = append(serialMessages, p2.Recv(h)) 215 | serialTimes = append(serialTimes, h.Time()) 216 | } 217 | }) 218 | if err := loop.Run(); err != nil { 219 | t.Fatal(err) 220 | } 221 | 222 | t1 := loop.Time() 223 | 224 | loop.Go(func(h *Handle) { 225 | network.Send(h, messages...) 
226 | startTime := h.Time() 227 | for i := range messages { 228 | msg := p2.Recv(h) 229 | if serialMessages[i] != msg { 230 | t.Errorf("msg %d: expected %v but got %v", i, serialMessages[i], msg) 231 | } 232 | curTime := h.Time() - startTime 233 | if math.Abs(curTime-serialTimes[i])/curTime > 1e-5 { 234 | t.Errorf("msg %d: expected time %f but got %f", i, serialTimes[i], curTime) 235 | } 236 | } 237 | }) 238 | if err := loop.Run(); err != nil { 239 | t.Fatal(err) 240 | } 241 | 242 | t2 := loop.Time() 243 | 244 | if math.Abs(t2-2*t1)/t1 > 1e-5 { 245 | t.Errorf("expected end time %f but got %f", 2*t1, t2) 246 | } 247 | } 248 | 249 | func BenchmarkSwitchedNetworkSends(b *testing.B) { 250 | for i := 0; i < b.N; i++ { 251 | loop := NewEventLoop() 252 | nodes := make([]*Node, 8) 253 | ports := make([]*Port, 8) 254 | for i := range nodes { 255 | nodes[i] = NewNode() 256 | ports[i] = nodes[i].Port(loop) 257 | } 258 | switcher := NewGreedyDropSwitcher(len(nodes), 1.0) 259 | network := NewSwitcherNetwork(switcher, nodes, 0.1) 260 | for j := range ports { 261 | port := ports[j] 262 | loop.Go(func(h *Handle) { 263 | for _, other := range ports { 264 | network.Send(h, &Message{ 265 | Source: port, 266 | Dest: other, 267 | Message: "hello", 268 | Size: rand.Float64(), 269 | }) 270 | } 271 | }) 272 | } 273 | if err := loop.Run(); err != nil { 274 | b.Fatal(err) 275 | } 276 | } 277 | } 278 | -------------------------------------------------------------------------------- /simulator/events.go: -------------------------------------------------------------------------------- 1 | package simulator 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "math" 7 | "math/rand" 8 | "sync" 9 | 10 | "github.com/unixpickle/essentials" 11 | ) 12 | 13 | // An EventStream is a uni-directional channel of events 14 | // that are passed through an EventLoop. 15 | // 16 | // It is only safe to use an EventStream on one EventLoop 17 | // at once. 18 | type EventStream struct { 19 | loop *EventLoop 20 | pending []interface{} 21 | } 22 | 23 | // An Event is a message received on some EventStream. 24 | type Event struct { 25 | Message interface{} 26 | Stream *EventStream 27 | } 28 | 29 | // A Timer controls the delayed delivery of an event. 30 | // In particular, a Timer represents a single send that 31 | // will happen in the (virtual) future. 32 | type Timer struct { 33 | time float64 34 | event *Event 35 | } 36 | 37 | // Time gets the time when the timer will be fired. 38 | // 39 | // If the virtual time is lower than a timer's Time(), 40 | // then it is guaranteed that the timer has not fired. 41 | func (t *Timer) Time() float64 { 42 | return t.time 43 | } 44 | 45 | // A Handle is a Goroutine's mechanism for accessing an 46 | // EventLoop. Goroutines should not share Handles. 47 | type Handle struct { 48 | *EventLoop 49 | 50 | // These fields are empty when the Goroutine is 51 | // not polling on any streams. 52 | pollStreams []*EventStream 53 | pollChan chan<- *Event 54 | } 55 | 56 | // Poll waits for the next event from a set of streams. 
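//
// If one of the given streams already has a buffered event, Poll returns
// a buffered event immediately; otherwise it blocks until an event is
// delivered to any of the streams. A common pattern is to poll a timer
// stream alongside a port's incoming stream and branch on the Stream
// field of the result:
//
//	event := h.Poll(timerStream, port.Incoming)
//	if event.Stream == timerStream {
//		// Handle the timeout.
//	}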
57 | func (h *Handle) Poll(streams ...*EventStream) *Event { 58 | ch := make(chan *Event, 1) 59 | h.modifyHandles(func() { 60 | if h.pollStreams != nil { 61 | panic("Handle is shared between Goroutines") 62 | } 63 | for _, stream := range streams { 64 | if len(stream.pending) > 0 { 65 | msg := stream.pending[0] 66 | essentials.OrderedDelete(&stream.pending, 0) 67 | ch <- &Event{Message: msg, Stream: stream} 68 | return 69 | } 70 | } 71 | h.pollStreams = streams 72 | h.pollChan = ch 73 | }) 74 | return <-ch 75 | } 76 | 77 | // Schedule creates a Timer for delivering an event. 78 | func (h *Handle) Schedule(stream *EventStream, msg interface{}, delay float64) *Timer { 79 | if stream.loop != h.EventLoop { 80 | panic("EventStream is not associated with the correct EventLoop") 81 | } 82 | var timer *Timer 83 | h.modify(func() { 84 | timer = &Timer{ 85 | time: h.time + delay, 86 | event: &Event{Message: msg, Stream: stream}, 87 | } 88 | if math.IsInf(timer.time, 0) || math.IsNaN(timer.time) { 89 | panic(fmt.Sprintf("invalid deadline: %f", timer.time)) 90 | } 91 | h.timers = append(h.timers, timer) 92 | }) 93 | return timer 94 | } 95 | 96 | // Cancel stops a timer if the timer is scheduled. 97 | // 98 | // If the timer is not scheduled, this has no effect. 99 | func (h *Handle) Cancel(t *Timer) { 100 | h.modify(func() { 101 | for i, timer := range h.timers { 102 | if timer == t { 103 | essentials.UnorderedDelete(&h.timers, i) 104 | } 105 | } 106 | }) 107 | } 108 | 109 | // Sleep waits for a certain amount of virtual time to 110 | // elapse. 111 | func (h *Handle) Sleep(delay float64) { 112 | stream := h.Stream() 113 | h.Schedule(stream, nil, delay) 114 | h.Poll(stream) 115 | } 116 | 117 | // An EventLoop is a global scheduler for events in a 118 | // simulated distributed system. 119 | // 120 | // All Goroutines which access an EventLoop should be 121 | // started using the EventLoop.Go() method. 122 | // 123 | // The event loop will only run when all active Goroutines 124 | // are polling for an event. 125 | // This way, simulated machines don't have to worry about 126 | // real timing while performing computations. 127 | type EventLoop struct { 128 | lock sync.Mutex 129 | timers []*Timer 130 | handles []*Handle 131 | 132 | time float64 133 | 134 | running bool 135 | notifyCh chan struct{} 136 | } 137 | 138 | // NewEventLoop creates an event loop. 139 | // 140 | // The event loop's clock starts at 0. 141 | func NewEventLoop() *EventLoop { 142 | return &EventLoop{notifyCh: make(chan struct{}, 1)} 143 | } 144 | 145 | // Stream creates a new EventStream. 146 | func (e *EventLoop) Stream() *EventStream { 147 | return &EventStream{loop: e} 148 | } 149 | 150 | // Go runs a function in a Goroutine and passes it a new 151 | // handle to the EventLoop. 152 | func (e *EventLoop) Go(f func(h *Handle)) { 153 | h := &Handle{EventLoop: e} 154 | e.lock.Lock() 155 | e.handles = append(e.handles, h) 156 | e.lock.Unlock() 157 | go func() { 158 | f(h) 159 | e.modifyHandles(func() { 160 | for i, handle := range e.handles { 161 | if handle == h { 162 | essentials.UnorderedDelete(&e.handles, i) 163 | return 164 | } 165 | } 166 | panic("cannot free handle that does not exist") 167 | }) 168 | }() 169 | } 170 | 171 | // Run runs the loop and blocks until all handles have 172 | // been closed. 173 | // 174 | // It is not safe to run the loop from more than one 175 | // Goroutine at once. 176 | // 177 | // Returns with an error if there is a deadlock. 
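//
// Callers that treat a deadlock as fatal can check the error directly,
// as the tests in this repository do, or use MustRun instead:
//
//	if err := loop.Run(); err != nil {
//		t.Fatal(err)
//	}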
178 | func (e *EventLoop) Run() error { 179 | e.lock.Lock() 180 | if e.running { 181 | e.lock.Unlock() 182 | panic("EventLoop is already running.") 183 | } 184 | e.running = true 185 | e.lock.Unlock() 186 | 187 | defer func() { 188 | e.lock.Lock() 189 | e.running = false 190 | e.lock.Unlock() 191 | }() 192 | 193 | for range e.notifyCh { 194 | if shouldContinue, err := e.step(); !shouldContinue { 195 | return err 196 | } 197 | } 198 | 199 | panic("unreachable") 200 | } 201 | 202 | // MustRun is like Run, but it panics if there is a 203 | // deadlock. 204 | func (e *EventLoop) MustRun() { 205 | if err := e.Run(); err != nil { 206 | panic(err) 207 | } 208 | } 209 | 210 | // Time gets the current virtual time. 211 | func (e *EventLoop) Time() float64 { 212 | e.lock.Lock() 213 | defer e.lock.Unlock() 214 | return e.time 215 | } 216 | 217 | // modify calls a function f() such that f can safely 218 | // change the loop state. 219 | // 220 | // This assumes that handle states are not being modified, 221 | // meaning that no scheduling changes can occur. 222 | // If this is not the case, use modifyHandles. 223 | func (e *EventLoop) modify(f func()) { 224 | e.lock.Lock() 225 | defer e.lock.Unlock() 226 | f() 227 | } 228 | 229 | // modifyHandles is like modify(), but it may alter the 230 | // loop state in such a way that scheduling changes occur. 231 | func (e *EventLoop) modifyHandles(f func()) { 232 | e.lock.Lock() 233 | defer func() { 234 | e.lock.Unlock() 235 | select { 236 | case e.notifyCh <- struct{}{}: 237 | default: 238 | } 239 | }() 240 | f() 241 | } 242 | 243 | // step runs the next event on the loop, if possible. 244 | // 245 | // If the event loop can no longer run, the first return 246 | // value is false. 247 | // If this is due to an error, the second argument 248 | // indicates the error. 249 | func (e *EventLoop) step() (bool, error) { 250 | e.lock.Lock() 251 | defer e.lock.Unlock() 252 | 253 | if len(e.handles) == 0 { 254 | return false, nil 255 | } 256 | 257 | for _, h := range e.handles { 258 | if len(h.pollStreams) == 0 { 259 | // Do not run the loop while a Goroutine is 260 | // doing work in real-time. 261 | return true, nil 262 | } 263 | } 264 | 265 | for len(e.timers) > 0 { 266 | // Shuffle so that two timers with the same deadline 267 | // don't execute in a deterministic order. 268 | indices := rand.Perm(len(e.timers)) 269 | 270 | minTimerIdx := indices[0] 271 | for _, i := range indices[1:] { 272 | if e.timers[i].time < e.timers[minTimerIdx].time { 273 | minTimerIdx = i 274 | } 275 | } 276 | timer := e.timers[minTimerIdx] 277 | 278 | essentials.UnorderedDelete(&e.timers, minTimerIdx) 279 | e.time = math.Max(e.time, timer.time) 280 | if e.deliver(timer.event) { 281 | return true, nil 282 | } 283 | } 284 | 285 | return false, errors.New("deadlock: all Handles are polling") 286 | } 287 | 288 | func (e *EventLoop) deliver(event *Event) bool { 289 | // Shuffle the handles so that two receivers don't get 290 | // messages in a deterministic order. 
291 | indices := rand.Perm(len(e.handles)) 292 | for _, i := range indices { 293 | h := e.handles[i] 294 | for _, stream := range h.pollStreams { 295 | if stream == event.Stream { 296 | h.pollChan <- event 297 | h.pollChan = nil 298 | h.pollStreams = nil 299 | return true 300 | } 301 | } 302 | } 303 | event.Stream.pending = append(event.Stream.pending, event.Message) 304 | return false 305 | } 306 | -------------------------------------------------------------------------------- /raft/leader.go: -------------------------------------------------------------------------------- 1 | package raft 2 | 3 | import ( 4 | "context" 5 | "math" 6 | "sort" 7 | 8 | "github.com/unixpickle/dist-sys/simulator" 9 | ) 10 | 11 | type Leader[C Command, S StateMachine[C, S]] struct { 12 | Context context.Context 13 | 14 | // Network configuration 15 | Handle *simulator.Handle 16 | Network simulator.Network 17 | Port *simulator.Port 18 | Followers []*simulator.Port 19 | 20 | // Algorithm state. 21 | Log *Log[C, S] 22 | Term int64 23 | 24 | // Settings 25 | HeartbeatInterval float64 26 | 27 | // followerKnownLogIndices stores the latest of our log 28 | // indices confirmed by each follower. Starts at latest 29 | // commit, and may be reduced or increased based on 30 | // responses. 31 | followerKnownLogIndices []int64 32 | 33 | // followerSentLogIndices is like the above field, but 34 | // may be higher if we have sent some log entries that 35 | // have not been acknowledged. 36 | followerSentLogIndices []int64 37 | 38 | // Used to track when to send new AppendLogs and 39 | // determine if a response is stale. 40 | seqNum int64 41 | lastSeqNums []int64 42 | lastSentTime []float64 43 | 44 | // Used for triggering AppendLogs heartbeats. 45 | timer *simulator.Timer 46 | timerStream *simulator.EventStream 47 | 48 | // Maps log indices to client connections for sending state 49 | // machine results back to clients. 50 | callbacks map[int64]*simulator.Port 51 | commandIDs map[int64]string 52 | } 53 | 54 | // RunLoop runs the leader loop until we stop being the 55 | // leader, at which point the first non-leader message is 56 | // returned. 57 | func (l *Leader[C, S]) RunLoop() *simulator.Message { 58 | l.callbacks = map[int64]*simulator.Port{} 59 | l.commandIDs = map[int64]string{} 60 | l.followerKnownLogIndices = make([]int64, len(l.Followers)) 61 | l.followerSentLogIndices = make([]int64, len(l.Followers)) 62 | l.lastSeqNums = make([]int64, len(l.Followers)) 63 | l.lastSentTime = make([]float64, len(l.Followers)) 64 | l.seqNum = 1 65 | for i := 0; i < len(l.Followers); i++ { 66 | l.followerKnownLogIndices[i] = l.Log.OriginIndex 67 | l.followerSentLogIndices[i] = l.Log.OriginIndex 68 | l.lastSentTime[i] = math.Inf(-1) 69 | } 70 | l.timerStream = l.Handle.Stream() 71 | l.timer = l.Handle.Schedule(l.timerStream, nil, l.HeartbeatInterval/2) 72 | defer func() { 73 | // Timer will be updated every time it fires. 
74 | l.Handle.Cancel(l.timer) 75 | }() 76 | 77 | l.sendAppendLogs() 78 | 79 | for { 80 | event := l.Handle.Poll(l.timerStream, l.Port.Incoming) 81 | select { 82 | case <-l.Context.Done(): 83 | return nil 84 | default: 85 | } 86 | if event.Stream == l.timerStream { 87 | l.timer = l.Handle.Schedule(l.timerStream, nil, l.HeartbeatInterval/2) 88 | l.sendAppendLogs() 89 | } else { 90 | msg := event.Message.(*simulator.Message) 91 | if !l.handleMessage(msg) { 92 | return msg 93 | } 94 | } 95 | } 96 | } 97 | 98 | func (l *Leader[C, S]) handleMessage(rawMessage *simulator.Message) bool { 99 | followerIndex := sourcePortIndex(rawMessage, l.Followers) 100 | if followerIndex == -1 { 101 | command := rawMessage.Message.(*CommandMessage[C]) 102 | l.handleCommand(rawMessage.Source, command) 103 | return true 104 | } 105 | 106 | msg := rawMessage.Message.(*RaftMessage[C, S]) 107 | if t := msg.Term(); t < l.Term { 108 | return true 109 | } else if t > l.Term { 110 | return false 111 | } 112 | 113 | resp := msg.AppendLogsResponse 114 | 115 | if resp == nil { 116 | // This could be a residual vote from this term after 117 | // we had enough votes to become leader. 118 | return true 119 | } 120 | 121 | if resp.SeqNum != l.lastSeqNums[followerIndex] { 122 | // This is a stale response. 123 | return true 124 | } 125 | 126 | // Make note that no message is in flight so that we are 127 | // now free to send more messages. 128 | l.lastSeqNums[followerIndex] = 0 129 | 130 | if resp.Success { 131 | l.followerKnownLogIndices[followerIndex] = resp.LatestIndex 132 | l.maybeAdvanceCommit() 133 | 134 | if l.lastSeqNums[followerIndex] != 0 || 135 | l.followerSentLogIndices[followerIndex] == l.Log.OriginIndex+int64(len(l.Log.Entries)) { 136 | // This follower is now up-to-date or is being 137 | // updated by a commit AppendLogs message. 138 | return true 139 | } 140 | } else { 141 | l.followerKnownLogIndices[followerIndex] = resp.CommitIndex 142 | l.followerSentLogIndices[followerIndex] = resp.CommitIndex 143 | } 144 | 145 | // If we are here, there is more to send to this follower, 146 | // either because it failed to handle the last message, or 147 | // because new entries have been added since. 148 | msg = l.appendLogsForFollower(followerIndex) 149 | rawMsg := &simulator.Message{ 150 | Source: l.Port, 151 | Dest: l.Followers[followerIndex], 152 | Message: msg, 153 | Size: float64(msg.Size()), 154 | } 155 | l.Network.Send(l.Handle, rawMsg) 156 | 157 | return true 158 | } 159 | 160 | func (l *Leader[C, S]) handleCommand(port *simulator.Port, command *CommandMessage[C]) { 161 | idx := l.Log.Append(l.Term, command.Command) 162 | l.callbacks[idx] = port 163 | l.commandIDs[idx] = command.ID 164 | l.sendFreshAppendLogs() 165 | } 166 | 167 | func (l *Leader[C, S]) sendFreshAppendLogs() { 168 | for i, lastSent := range l.lastSeqNums { 169 | if lastSent != 0 { 170 | // A message is in flight, so we won't send a new one 171 | // until we get an ack. 172 | continue 173 | } 174 | l.lastSentTime[i] = math.Inf(-1) 175 | } 176 | l.sendAppendLogs() 177 | } 178 | 179 | func (l *Leader[C, S]) sendAppendLogs() { 180 | messages := make([]*simulator.Message, 0, len(l.Followers)) 181 | for i, port := range l.Followers { 182 | if l.Handle.Time() < l.lastSentTime[i]+l.HeartbeatInterval/2 { 183 | // Don't send redundant messages if a keepalive is unneeded. 
184 | continue 185 | } 186 | msg := l.appendLogsForFollower(i) 187 | messages = append(messages, &simulator.Message{ 188 | Source: l.Port, 189 | Dest: port, 190 | Message: msg, 191 | Size: float64(msg.Size()), 192 | }) 193 | } 194 | l.Network.Send(l.Handle, messages...) 195 | } 196 | 197 | // appendLogsForFollower creates an AppendLogs message and 198 | // updates our message book-keeping under the assumption 199 | // that the message will be sent. 200 | func (l *Leader[C, S]) appendLogsForFollower(i int) *RaftMessage[C, S] { 201 | logIndex := l.followerSentLogIndices[i] 202 | msg := &RaftMessage[C, S]{AppendLogs: &AppendLogs[C, S]{ 203 | Term: l.Term, 204 | SeqNum: l.seqNum, 205 | CommitIndex: l.Log.OriginIndex, 206 | }} 207 | 208 | // Book-keeping under the assumption that we will send 209 | // this message. 210 | l.lastSeqNums[i] = l.seqNum 211 | l.lastSentTime[i] = l.Handle.Time() 212 | l.followerSentLogIndices[i] = l.Log.OriginIndex + int64(len(l.Log.Entries)) 213 | l.seqNum++ 214 | 215 | if logIndex < l.Log.OriginIndex { 216 | // We must include the entire state machine. 217 | msg.AppendLogs.OriginIndex = l.Log.OriginIndex 218 | msg.AppendLogs.OriginTerm = l.Log.OriginTerm 219 | originState := l.Log.Origin.Clone() 220 | msg.AppendLogs.Origin = &originState 221 | msg.AppendLogs.Entries = append([]LogEntry[C]{}, l.Log.Entries...) 222 | } else if logIndex > l.Log.OriginIndex { 223 | msg.AppendLogs.OriginIndex = logIndex 224 | msg.AppendLogs.OriginTerm = l.Log.Entries[logIndex-l.Log.OriginIndex-1].Term 225 | msg.AppendLogs.Entries = append( 226 | []LogEntry[C]{}, 227 | l.Log.Entries[logIndex-l.Log.OriginIndex:]..., 228 | ) 229 | } else { 230 | msg.AppendLogs.OriginIndex = l.Log.OriginIndex 231 | msg.AppendLogs.OriginTerm = l.Log.OriginTerm 232 | msg.AppendLogs.Entries = append([]LogEntry[C]{}, l.Log.Entries...) 233 | } 234 | return msg 235 | } 236 | 237 | func (l *Leader[C, S]) maybeAdvanceCommit() { 238 | sorted := append([]int64{}, l.followerKnownLogIndices...) 239 | sort.Slice(sorted, func(i int, j int) bool { 240 | return sorted[i] < sorted[j] 241 | }) 242 | minCommit := sorted[len(sorted)/2] 243 | 244 | if minCommit > l.Log.OriginIndex { 245 | commitEntry := l.Log.Entries[minCommit-l.Log.OriginIndex-1] 246 | if commitEntry.Term != l.Term { 247 | // We can only safely commit log entries from the 248 | // current term, and previous entries will be 249 | // implicitly committed as a result. 250 | return 251 | } 252 | 253 | oldCommit := l.Log.OriginIndex 254 | results := l.Log.Commit(minCommit) 255 | 256 | // Now that we have committed, we can send results to clients. 257 | var callbackMessages []*simulator.Message 258 | for i, result := range results { 259 | index := int64(i) + oldCommit 260 | if cb, ok := l.callbacks[index]; ok { 261 | resp := &CommandResponse{Result: result, ID: l.commandIDs[index]} 262 | callbackMessages = append(callbackMessages, &simulator.Message{ 263 | Source: l.Port, 264 | Dest: cb, 265 | Message: resp, 266 | Size: float64(resp.Size()), 267 | }) 268 | delete(l.callbacks, index) 269 | delete(l.commandIDs, index) 270 | } 271 | } 272 | l.Network.Send(l.Handle, callbackMessages...) 
273 | 274 | l.sendFreshAppendLogs() 275 | } 276 | } 277 | 278 | func sourcePortIndex(msg *simulator.Message, ports []*simulator.Port) int { 279 | for i, f := range ports { 280 | if msg.Source == f { 281 | return i 282 | } 283 | } 284 | return -1 285 | } 286 | -------------------------------------------------------------------------------- /simulator/network.go: -------------------------------------------------------------------------------- 1 | package simulator 2 | 3 | import ( 4 | "math" 5 | "math/rand" 6 | "sync" 7 | 8 | "github.com/unixpickle/essentials" 9 | ) 10 | 11 | // A Node represents a machine on a virtual network. 12 | type Node struct { 13 | unused int 14 | } 15 | 16 | // NewNode creates a new, unique Node. 17 | func NewNode() *Node { 18 | return &Node{} 19 | } 20 | 21 | // Port creates a new Port connected to the Node. 22 | func (n *Node) Port(loop *EventLoop) *Port { 23 | return &Port{Node: n, Incoming: loop.Stream()} 24 | } 25 | 26 | // A Port identifies a point of communication on a Node. 27 | // Data is sent from Ports and received on Ports. 28 | type Port struct { 29 | // The Node to which the Port is attached. 30 | Node *Node 31 | 32 | // A stream of *Message objects. 33 | Incoming *EventStream 34 | } 35 | 36 | // Recv receives the next message. 37 | func (p *Port) Recv(h *Handle) *Message { 38 | return h.Poll(p.Incoming).Message.(*Message) 39 | } 40 | 41 | // A Message is a chunk of data sent between nodes over a 42 | // network. 43 | type Message struct { 44 | Source *Port 45 | Dest *Port 46 | Message interface{} 47 | Size float64 48 | } 49 | 50 | // A Network represents an abstract way of communicating 51 | // between nodes. 52 | type Network interface { 53 | // Send message objects from one node to another. 54 | // The message will arrive on the receiving port's 55 | // incoming EventStream if the communication is 56 | // successful. 57 | // 58 | // This is a non-blocking operation. 59 | // 60 | // It is preferrable to pass multiple messages in at 61 | // once, if possible. 62 | // Otherwise, the Network may have to continually 63 | // re-plan the entire message delivery timeline. 64 | Send(h *Handle, msgs ...*Message) 65 | } 66 | 67 | // A RandomNetwork is a network that assigns random delays 68 | // to every message. 69 | type RandomNetwork struct{} 70 | 71 | // Send sends the messages with random delays. 72 | func (r RandomNetwork) Send(h *Handle, msgs ...*Message) { 73 | for _, msg := range msgs { 74 | h.Schedule(msg.Dest.Incoming, msg, rand.Float64()) 75 | } 76 | } 77 | 78 | // A SwitcherNetwork is a network where data is passed 79 | // through a Switcher. Multiple messages along the same 80 | // edge are sent concurrently, potentially making each one 81 | // take longer to arrive at its destination. 82 | type SwitcherNetwork struct { 83 | lock sync.Mutex 84 | 85 | switcher Switcher 86 | nodes []*Node 87 | latency float64 88 | 89 | plan switchedPlan 90 | } 91 | 92 | // NewSwitcherNetwork creates a new SwitcherNetwork. 93 | // 94 | // The latency argument adds an extra constant-length 95 | // timeout to every message delivery. 96 | // The latency period does influence oversubscription, 97 | // so one message's latency period may interfere with 98 | // another message's transmission. 99 | // In practice, this may result in twice the latency-based 100 | // congestion that would actually occur in a network. 
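//
// A typical construction, as used in the paxos tests:
//
//	switcher := NewGreedyDropSwitcher(len(nodes), 1e6) // per-port rate, in bytes/sec
//	network := NewSwitcherNetwork(switcher, nodes, 0.1) // 0.1s of latency per message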
101 | func NewSwitcherNetwork(switcher Switcher, nodes []*Node, latency float64) *SwitcherNetwork { 102 | return &SwitcherNetwork{ 103 | switcher: switcher, 104 | nodes: nodes, 105 | latency: latency, 106 | } 107 | } 108 | 109 | // Send sends the message over the network. 110 | // 111 | // This may affect the speed of messages that are already 112 | // being transmitted. 113 | func (s *SwitcherNetwork) Send(h *Handle, msgs ...*Message) { 114 | s.lock.Lock() 115 | defer s.lock.Unlock() 116 | 117 | state := s.stopPlan(h) 118 | for _, msg := range msgs { 119 | state = append(state, &switchedMsg{ 120 | msg: msg, 121 | remainingLatency: s.latency, 122 | remainingSize: msg.Size, 123 | }) 124 | } 125 | s.createPlan(h, state) 126 | } 127 | 128 | func (s *SwitcherNetwork) stopPlan(h *Handle) []*switchedMsg { 129 | var currentState []*switchedMsg 130 | for _, step := range s.plan { 131 | if h.Time() >= step.endTime { 132 | // The timers may have fired, so we let this go. 133 | continue 134 | } 135 | if h.Time() >= step.startTime { 136 | // Interpolate in the current segment. 137 | elapsed := h.Time() - step.startTime 138 | for _, msg := range step.startState { 139 | currentState = append(currentState, msg.AddTime(elapsed)) 140 | } 141 | } 142 | for _, timer := range step.timers { 143 | h.Cancel(timer) 144 | } 145 | } 146 | return currentState 147 | } 148 | 149 | func (s *SwitcherNetwork) computeDataRates(state []*switchedMsg) { 150 | nodeToIndex := map[*Node]int{} 151 | for i, node := range s.nodes { 152 | nodeToIndex[node] = i 153 | } 154 | 155 | // Technically this is a tiny bit incorrect, since the 156 | // latency period isn't taken into account. 157 | // Really, during the latency period, the sender NIC 158 | // is clogged up but the receiver NIC is not. 159 | 160 | mat := NewConnMat(len(s.nodes)) 161 | counts := NewConnMat(len(s.nodes)) 162 | for _, msg := range state { 163 | src, dst := nodeToIndex[msg.msg.Source.Node], nodeToIndex[msg.msg.Dest.Node] 164 | mat.Set(src, dst, 1) 165 | counts.Set(src, dst, counts.Get(src, dst)+1) 166 | } 167 | s.switcher.SwitchedRates(mat) 168 | for _, msg := range state { 169 | src, dst := nodeToIndex[msg.msg.Source.Node], nodeToIndex[msg.msg.Dest.Node] 170 | msg.dataRate = mat.Get(src, dst) / counts.Get(src, dst) 171 | } 172 | } 173 | 174 | func (s *SwitcherNetwork) createPlan(h *Handle, state []*switchedMsg) { 175 | s.plan = make(switchedPlan, 0, len(state)) 176 | startTime := h.Time() 177 | for len(state) > 0 { 178 | s.computeDataRates(state) 179 | 180 | nextMsgs, newState, lowestETA := messagesWithLowestETA(state) 181 | 182 | timers := make([]*Timer, len(nextMsgs)) 183 | for i, msg := range nextMsgs { 184 | delay := startTime - h.Time() + lowestETA 185 | timers[i] = h.Schedule(msg.msg.Dest.Incoming, msg.msg, delay) 186 | } 187 | 188 | endTime := timers[0].Time() 189 | s.plan = append(s.plan, &switchedPlanSegment{ 190 | startTime: startTime, 191 | endTime: endTime, 192 | timers: timers, 193 | startState: state, 194 | }) 195 | 196 | for i, msg := range newState { 197 | newState[i] = msg.AddTime(endTime - startTime) 198 | } 199 | state = newState 200 | startTime = endTime 201 | } 202 | } 203 | 204 | // switchedMsg encodes the state of a message that is 205 | // being sent through the network. 206 | type switchedMsg struct { 207 | msg *Message 208 | 209 | remainingLatency float64 210 | 211 | remainingSize float64 212 | dataRate float64 213 | } 214 | 215 | // ETA gets the time until the message is sent. 
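// It is the remaining latency plus the time to transmit the remaining
// bytes at the current data rate, clamped at zero:
// max(0, remainingLatency + remainingSize/dataRate).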
216 | func (s *switchedMsg) ETA() float64 { 217 | return math.Max(0, s.remainingLatency+s.remainingSize/s.dataRate) 218 | } 219 | 220 | // AddTime updates the message's state to reflect a 221 | // certain amount of time elapsing. 222 | func (s *switchedMsg) AddTime(t float64) *switchedMsg { 223 | res := *s 224 | 225 | if t < res.remainingLatency { 226 | res.remainingLatency -= t 227 | return &res 228 | } 229 | 230 | t -= res.remainingLatency 231 | res.remainingLatency = 0 232 | res.remainingSize -= res.dataRate * t 233 | 234 | return &res 235 | } 236 | 237 | // switchedPlanSegment represents a period of time during 238 | // which the message state is not changing, aside from 239 | // more data being sent or more latency being paid for. 240 | // 241 | // Each segment ends with at least one Timer, which 242 | // notifies a node about a received message. 243 | type switchedPlanSegment struct { 244 | startTime float64 245 | endTime float64 246 | timers []*Timer 247 | 248 | startState []*switchedMsg 249 | } 250 | 251 | // switchedPlan represents a sequence of switched state 252 | // changes that, together, send all of the current 253 | // messages on the network. 254 | type switchedPlan []*switchedPlanSegment 255 | 256 | func messagesWithLowestETA(msgs []*switchedMsg) (lowest, rest []*switchedMsg, lowestETA float64) { 257 | etas := make([]float64, len(msgs)) 258 | for i, msg := range msgs { 259 | etas[i] = msg.ETA() 260 | } 261 | lowestETA = etas[0] 262 | for _, eta := range etas { 263 | if eta < lowestETA { 264 | lowestETA = eta 265 | } 266 | } 267 | 268 | lowest = make([]*switchedMsg, 0, 1) 269 | rest = make([]*switchedMsg, 0, len(msgs)-1) 270 | 271 | for i, msg := range msgs { 272 | if etas[i] == lowestETA { 273 | lowest = append(lowest, msg) 274 | } else { 275 | rest = append(rest, msg) 276 | } 277 | } 278 | 279 | return lowest, rest, lowestETA 280 | } 281 | 282 | // An OrderedNetwork delivers messages sent to endpoints in 283 | // order, while allowing non-determinism and temporarily 284 | // disconnected nodes. 285 | type OrderedNetwork struct { 286 | Rate float64 287 | MaxRandomLatency float64 288 | 289 | lock sync.Mutex 290 | nextTimes map[*Node]float64 291 | downNodes map[*Node]bool 292 | timers map[*Node][]*Timer 293 | 294 | sniffers map[*EventStream]struct{} 295 | } 296 | 297 | func NewOrderedNetwork(rate float64, maxRandomLatency float64) *OrderedNetwork { 298 | return &OrderedNetwork{ 299 | Rate: rate, 300 | MaxRandomLatency: maxRandomLatency, 301 | nextTimes: map[*Node]float64{}, 302 | downNodes: map[*Node]bool{}, 303 | timers: map[*Node][]*Timer{}, 304 | sniffers: map[*EventStream]struct{}{}, 305 | } 306 | } 307 | 308 | // Send sends the messages over the network in order. 
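//
// Messages whose source or destination node is currently down are
// dropped silently. Each accepted message is delayed by a uniform
// random latency in [0, MaxRandomLatency) plus Size/Rate of
// transmission time, and is queued behind earlier deliveries to the
// same destination so per-destination ordering is preserved. Active
// sniffers are notified of every accepted message immediately.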
309 | func (o *OrderedNetwork) Send(h *Handle, msgs ...*Message) { 310 | o.lock.Lock() 311 | defer o.lock.Unlock() 312 | 313 | o.cleanupTimers(h) 314 | 315 | curTime := h.Time() 316 | 317 | for _, msg := range msgs { 318 | src := msg.Source.Node 319 | dest := msg.Dest.Node 320 | if o.downNodes[src] || o.downNodes[dest] { 321 | continue 322 | } 323 | for sniffer := range o.sniffers { 324 | h.Schedule(sniffer, msg, 0) 325 | } 326 | latency := rand.Float64() * o.MaxRandomLatency 327 | delay := latency + msg.Size/o.Rate 328 | 329 | var timer *Timer 330 | if t, ok := o.nextTimes[dest]; !ok || t <= curTime { 331 | timer = h.Schedule(msg.Dest.Incoming, msg, delay) 332 | o.nextTimes[dest] = curTime + delay 333 | } else { 334 | timer = h.Schedule(msg.Dest.Incoming, msg, delay+(t-curTime)) 335 | o.nextTimes[dest] = delay + t 336 | } 337 | o.timers[dest] = append(o.timers[dest], timer) 338 | o.timers[src] = append(o.timers[src], timer) 339 | } 340 | } 341 | 342 | // SendInstantly is like Send, but it ignores scheduling to 343 | // ensure the message arrives at the destination in zero 344 | // time. 345 | // 346 | // This method still avoids sending to down nodes. 347 | func (o *OrderedNetwork) SendInstantly(h *Handle, msgs ...*Message) { 348 | o.lock.Lock() 349 | defer o.lock.Unlock() 350 | 351 | for _, msg := range msgs { 352 | src := msg.Source.Node 353 | dest := msg.Dest.Node 354 | if o.downNodes[src] || o.downNodes[dest] { 355 | continue 356 | } 357 | h.Schedule(msg.Dest.Incoming, msg, 0) 358 | } 359 | } 360 | 361 | // Sniff repeatedly calls f with all messages sent on the 362 | // network until f returns false. 363 | func (o *OrderedNetwork) Sniff(h *Handle, f func(*Message) bool) { 364 | stream := h.Stream() 365 | 366 | o.lock.Lock() 367 | o.sniffers[stream] = struct{}{} 368 | o.lock.Unlock() 369 | 370 | defer func() { 371 | o.lock.Lock() 372 | delete(o.sniffers, stream) 373 | o.lock.Unlock() 374 | }() 375 | 376 | for { 377 | event := h.Poll(stream) 378 | if !f(event.Message.(*Message)) { 379 | break 380 | } 381 | } 382 | } 383 | 384 | func (o *OrderedNetwork) cleanupTimers(h *Handle) { 385 | time := h.Time() 386 | o.filterTimer(h, func(t *Timer) bool { 387 | return t.Time() >= time 388 | }) 389 | } 390 | 391 | func (o *OrderedNetwork) SetDown(h *Handle, node *Node, down bool) { 392 | o.lock.Lock() 393 | defer o.lock.Unlock() 394 | 395 | o.downNodes[node] = down 396 | 397 | if !down { 398 | return 399 | } 400 | 401 | delete(o.nextTimes, node) 402 | 403 | // Kill all active messages to and from the node. 
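// Every pending timer was registered under both its source and its
// destination node, so after cancelling this node's timers we also
// sweep the cancelled timers out of every other node's list below.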
404 | o.cleanupTimers(h) 405 | timers := o.timers[node] 406 | delete(o.timers, node) 407 | canceled := map[*Timer]bool{} 408 | for _, t := range timers { 409 | canceled[t] = true 410 | h.Cancel(t) 411 | } 412 | o.filterTimer(h, func(t *Timer) bool { 413 | return !canceled[t] 414 | }) 415 | } 416 | 417 | func (o *OrderedNetwork) IsDown(n *Node) bool { 418 | o.lock.Lock() 419 | defer o.lock.Unlock() 420 | return o.downNodes[n] 421 | } 422 | 423 | func (o *OrderedNetwork) filterTimer(h *Handle, f func(t *Timer) bool) { 424 | var keys []*Node 425 | for k := range o.timers { 426 | keys = append(keys, k) 427 | } 428 | for _, k := range keys { 429 | timers := o.timers[k] 430 | for i := 0; i < len(timers); i++ { 431 | if !f(timers[i]) { 432 | essentials.UnorderedDelete(&timers, i) 433 | i-- 434 | } 435 | } 436 | o.timers[k] = timers 437 | } 438 | } 439 | -------------------------------------------------------------------------------- /raft/raft_test.go: -------------------------------------------------------------------------------- 1 | package raft 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "math/rand" 7 | "strconv" 8 | "sync" 9 | "testing" 10 | 11 | "github.com/unixpickle/dist-sys/simulator" 12 | ) 13 | 14 | func TestRaftSimpleCase(t *testing.T) { 15 | t.Run("Latency", func(t *testing.T) { 16 | testRaftSimpleCase(t, 5, true) 17 | }) 18 | t.Run("Instant", func(t *testing.T) { 19 | testRaftSimpleCase(t, 5, false) 20 | }) 21 | } 22 | 23 | func testRaftSimpleCase(t *testing.T, numNodes int, randomized bool) { 24 | env := NewRaftEnvironment(5, 1, randomized) 25 | env.Loop.Go(func(h *simulator.Handle) { 26 | defer env.Cancel() 27 | client := &Client[HashMapCommand]{ 28 | Handle: h, 29 | Network: env.Network, 30 | Port: env.Clients[0], 31 | Servers: env.Servers, 32 | SendTimeout: 10, 33 | } 34 | for i := 0; i < 10; i++ { 35 | value := "hello" + strconv.Itoa(i) 36 | x, err := client.Send(HashMapCommand{Key: strconv.Itoa(i), Value: value}, 0) 37 | if err != nil { 38 | t.Fatal(err) 39 | } else if v := x.(StringResult).Value; v != value { 40 | t.Fatalf("expected %#v but got %#v", v, x) 41 | } 42 | for j := 0; j <= i; j++ { 43 | expected := "hello" + strconv.Itoa(i) 44 | resp, err := client.Send(HashMapCommand{Key: strconv.Itoa(i)}, 0) 45 | if err != nil { 46 | t.Fatal(err) 47 | } else if v := resp.(StringResult).Value; v != expected { 48 | t.Fatalf("expected %#v but got %#v", expected, v) 49 | } 50 | } 51 | } 52 | }) 53 | 54 | env.Loop.MustRun() 55 | } 56 | 57 | func TestRaftNodeFailures(t *testing.T) { 58 | t.Run("Latency", func(t *testing.T) { 59 | testRaftNodeFailures(t, 5, true) 60 | }) 61 | t.Run("Instant", func(t *testing.T) { 62 | testRaftNodeFailures(t, 5, false) 63 | }) 64 | } 65 | 66 | func testRaftNodeFailures(t *testing.T, numNodes int, randomized bool) { 67 | env := NewRaftEnvironment(5, 1, randomized) 68 | env.Loop.Go(func(h *simulator.Handle) { 69 | defer env.Cancel() 70 | client := &Client[HashMapCommand]{ 71 | Handle: h, 72 | Network: env.Network, 73 | Port: env.Clients[0], 74 | Servers: env.Servers, 75 | SendTimeout: 10, 76 | } 77 | downMachine := -1 78 | for i := 0; i < 20; i++ { 79 | value := "hello" + strconv.Itoa(i) 80 | x, err := client.Send(HashMapCommand{Key: strconv.Itoa(i), Value: value}, 0) 81 | if err != nil { 82 | t.Fatal(err) 83 | } else if v := x.(StringResult).Value; v != value { 84 | t.Fatalf("expected %#v but got %#v", v, x) 85 | } 86 | for j := 0; j <= i; j++ { 87 | // With some probability, we bring a machine down or back up, 88 | // then wait a bit for it to 
potentially fail. 89 | if rand.Intn(2) == 0 { 90 | if downMachine == -1 { 91 | downMachine = rand.Intn(len(env.Servers)) 92 | env.Network.SetDown(h, env.Servers[downMachine].Node, true) 93 | } else { 94 | env.Network.SetDown(h, env.Servers[downMachine].Node, false) 95 | downMachine = -1 96 | } 97 | h.Sleep(rand.Float64() * 45) 98 | } 99 | 100 | expected := "hello" + strconv.Itoa(i) 101 | resp, err := client.Send(HashMapCommand{Key: strconv.Itoa(i)}, 0) 102 | if err != nil { 103 | t.Fatal(err) 104 | } else if v := resp.(StringResult).Value; v != expected { 105 | t.Fatalf("expected %#v but got %#v", expected, v) 106 | } 107 | } 108 | } 109 | }) 110 | 111 | env.Loop.MustRun() 112 | } 113 | 114 | func TestRaftMultiClientNodeFailures(t *testing.T) { 115 | t.Run("Latency", func(t *testing.T) { 116 | testRaftMultiClientNodeFailures(t, 5, true) 117 | }) 118 | t.Run("Instant", func(t *testing.T) { 119 | testRaftMultiClientNodeFailures(t, 5, false) 120 | }) 121 | } 122 | 123 | func testRaftMultiClientNodeFailures(t *testing.T, numNodes int, randomized bool) { 124 | env := NewRaftEnvironment(5, 2, randomized) 125 | 126 | var wg sync.WaitGroup 127 | for i, port := range env.Clients { 128 | wg.Add(1) 129 | keyPrefix := fmt.Sprintf("rank%d-", i) 130 | env.Loop.Go(func(h *simulator.Handle) { 131 | defer wg.Done() 132 | client := &Client[HashMapCommand]{ 133 | Handle: h, 134 | Network: env.Network, 135 | Port: port, 136 | Servers: env.Servers, 137 | SendTimeout: 10, 138 | } 139 | downMachine := -1 140 | for i := 0; i < 20; i++ { 141 | value := "hello" + strconv.Itoa(i) 142 | x, err := client.Send(HashMapCommand{Key: keyPrefix + strconv.Itoa(i), Value: value}, 0) 143 | if err != nil { 144 | t.Fatal(err) 145 | } else if v := x.(StringResult).Value; v != value { 146 | t.Fatalf("expected %#v but got %#v", v, x) 147 | } 148 | for j := 0; j <= i; j++ { 149 | // With some probability, we bring a machine down or back up, 150 | // then wait a bit for it to potentially fail. 151 | if rand.Intn(2) == 0 { 152 | if downMachine == -1 { 153 | downMachine = rand.Intn(len(env.Servers)) 154 | env.Network.SetDown(h, env.Servers[downMachine].Node, true) 155 | } else { 156 | env.Network.SetDown(h, env.Servers[downMachine].Node, false) 157 | downMachine = -1 158 | } 159 | h.Sleep(rand.Float64() * 45) 160 | } 161 | 162 | expected := "hello" + strconv.Itoa(i) 163 | resp, err := client.Send(HashMapCommand{Key: keyPrefix + strconv.Itoa(i)}, 0) 164 | if err != nil { 165 | t.Fatal(err) 166 | } else if v := resp.(StringResult).Value; v != expected { 167 | t.Fatalf("expected %#v but got %#v", expected, v) 168 | } 169 | } 170 | } 171 | }) 172 | } 173 | 174 | go func() { 175 | wg.Wait() 176 | env.Cancel() 177 | }() 178 | 179 | env.Loop.MustRun() 180 | } 181 | 182 | func TestRaftMultiClientInterleavedNodeFailures(t *testing.T) { 183 | // Catch non-deterministic failure 184 | for i := 0; i < 20; i++ { 185 | env := NewRaftEnvironment(5, 2, true) 186 | 187 | // Increase random latency to cause splits / partial sends. 188 | env.Network.MaxRandomLatency = 1.0 189 | 190 | var wg sync.WaitGroup 191 | for i, port := range env.Clients { 192 | wg.Add(1) 193 | keyPrefix := fmt.Sprintf("rank%d-", i) 194 | env.Loop.Go(func(h *simulator.Handle) { 195 | defer wg.Done() 196 | client := &Client[HashMapCommand]{ 197 | Handle: h, 198 | Network: env.Network, 199 | Port: port, 200 | Servers: env.Servers, 201 | 202 | // With lower values, servers seem to get overloaded. 
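// (A shorter timeout presumably makes the client re-send sooner,
// piling extra traffic onto a cluster that is already mid-election.)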
203 | SendTimeout: 30, 204 | } 205 | for i := 0; i < 20; i++ { 206 | value := "hello" + strconv.Itoa(i) 207 | x, err := client.Send(HashMapCommand{Key: keyPrefix + strconv.Itoa(i), Value: value}, 0) 208 | if err != nil { 209 | t.Fatal(err) 210 | } else if v := x.(StringResult).Value; v != value { 211 | t.Fatalf("expected %#v but got %#v", v, x) 212 | } 213 | for j := 0; j <= i; j++ { 214 | expected := "hello" + strconv.Itoa(i) 215 | resp, err := client.Send(HashMapCommand{Key: keyPrefix + strconv.Itoa(i)}, 0) 216 | if err != nil { 217 | t.Fatal(err) 218 | } else if v := resp.(StringResult).Value; v != expected { 219 | t.Fatalf("expected %#v but got %#v", expected, v) 220 | } 221 | h.Sleep(rand.Float64()*5 + 5) 222 | } 223 | } 224 | }) 225 | } 226 | 227 | // Randomly bring down at most two servers at once. 228 | for i := 0; i < 2; i++ { 229 | env.Loop.Go(func(h *simulator.Handle) { 230 | for { 231 | select { 232 | case <-env.Context.Done(): 233 | return 234 | default: 235 | } 236 | h.Sleep(rand.Float64() * 30) 237 | node := rand.Intn(len(env.Servers)) 238 | env.Network.SetDown(h, env.Servers[node].Node, true) 239 | h.Sleep(rand.Float64() * 120) 240 | env.Network.SetDown(h, env.Servers[node].Node, false) 241 | h.Sleep(rand.Float64()*60 + 120) 242 | } 243 | }) 244 | } 245 | 246 | go func() { 247 | wg.Wait() 248 | env.Cancel() 249 | }() 250 | 251 | env.Loop.MustRun() 252 | } 253 | } 254 | 255 | func TestRaftCommitOlderTerm(t *testing.T) { 256 | env := NewRaftEnvironment(5, 1, false) 257 | 258 | env.Loop.Go(func(h *simulator.Handle) { 259 | defer env.Cancel() 260 | 261 | clientPort := env.Clients[0] 262 | 263 | findLeader := func() *simulator.Port { 264 | // Sniff the network until we can discern who the current 265 | // leader is. 266 | var result *simulator.Port 267 | env.Network.Sniff(h, func(msg *simulator.Message) bool { 268 | if rm, ok := msg.Message.(*RaftMessage[HashMapCommand, *HashMap]); ok { 269 | if rm.AppendLogs != nil { 270 | result = msg.Source 271 | return false 272 | } 273 | } 274 | return true 275 | }) 276 | return result 277 | } 278 | 279 | bumpLeaderToTerm := func(leaderPort *simulator.Port, term int64) { 280 | // Send a bogus packet to the leader before it's disconnected 281 | // so that it starts voting with very high terms. 282 | // 283 | // This helps guarantee that this node will be elected leader 284 | // eventually once it comes back up. 285 | msg := &RaftMessage[HashMapCommand, *HashMap]{ 286 | Vote: &Vote{Term: term}, 287 | } 288 | var follower *simulator.Port 289 | for _, port := range env.Servers { 290 | if port != leaderPort && !env.Network.IsDown(port.Node) { 291 | follower = port 292 | break 293 | } 294 | } 295 | env.Network.SendInstantly(h, &simulator.Message{ 296 | Source: follower, 297 | Dest: leaderPort, 298 | Message: msg, 299 | Size: float64(msg.Size()), 300 | }) 301 | env.Network.SendInstantly(h) 302 | } 303 | 304 | // Wait until a leader is almost certainly alive. 305 | h.Sleep(120) 306 | 307 | leader := findLeader() 308 | 309 | msg := &CommandMessage[HashMapCommand]{ 310 | Command: HashMapCommand{Key: "key1", Value: "value1"}, 311 | ID: "id1", 312 | } 313 | env.Network.SendInstantly(h, &simulator.Message{ 314 | Source: clientPort, 315 | Dest: leader, 316 | Message: msg, 317 | Size: float64(msg.Command.Size() + len(msg.ID)), 318 | }) 319 | h.Sleep(1e-8) 320 | // Make sure leader will win next election. 
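// Term 10 should exceed whatever term the rest of the cluster reaches
// while this node is down; the replacement leader is later bumped to
// 100 so that it wins again in turn.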
321 | bumpLeaderToTerm(leader, 10) 322 | h.Sleep(1e-8) 323 | env.Network.SetDown(h, leader.Node, true) 324 | 325 | // A new leader should exist which doesn't know about our 326 | // previous log entry. 327 | h.Sleep(120) 328 | newLeader := findLeader() 329 | 330 | // Same deal: give them an entry only they know about. 331 | msg = &CommandMessage[HashMapCommand]{ 332 | Command: HashMapCommand{Key: "key2", Value: "value2"}, 333 | ID: "id2", 334 | } 335 | env.Network.SendInstantly(h, &simulator.Message{ 336 | Source: clientPort, 337 | Dest: newLeader, 338 | Message: msg, 339 | Size: float64(msg.Command.Size() + len(msg.ID)), 340 | }) 341 | h.Sleep(1e-8) 342 | bumpLeaderToTerm(newLeader, 100) 343 | h.Sleep(1e-8) 344 | env.Network.SetDown(h, newLeader.Node, true) 345 | env.Network.SetDown(h, leader.Node, false) 346 | 347 | // Give original leader a chance to erroneously commit 348 | // if it is incorrectly implemented. 349 | h.Sleep(120) 350 | newNewLeader := findLeader() 351 | if newNewLeader != leader { 352 | t.Fatal("expected to obtain original leader") 353 | } 354 | 355 | // Bring back up the second leader. 356 | env.Network.SetDown(h, newLeader.Node, false) 357 | env.Network.SetDown(h, leader.Node, true) 358 | 359 | h.Sleep(120) 360 | newNewNewLeader := findLeader() 361 | if newNewNewLeader != newLeader { 362 | t.Fatal("unexpected leader after long downtime (second leader)") 363 | } 364 | 365 | // Add another key to the second leader to force it 366 | // to commit. 367 | client := &Client[HashMapCommand]{ 368 | Handle: h, 369 | Network: env.Network, 370 | Port: clientPort, 371 | Servers: env.Servers, 372 | 373 | // With lower values, servers seem to get overloaded. 374 | SendTimeout: 30, 375 | } 376 | client.Send(HashMapCommand{Key: "key3", Value: "value3"}, 0) 377 | 378 | // Now we will use the cluster of the three nodes 379 | // were never sent data directly as followers. 380 | // This will tell us what state they have 381 | // internally. 382 | env.Network.SetDown(h, newLeader.Node, true) 383 | h.Sleep(120) 384 | 385 | // Make sure the key from newLeader is committed 386 | // and not the previous one. 387 | res, _ := client.Send(HashMapCommand{Key: "key1"}, 0) 388 | if res.(StringResult).Value != "" { 389 | t.Errorf("unexpected key1 value: %#v", res) 390 | } 391 | res, _ = client.Send(HashMapCommand{Key: "key2"}, 0) 392 | if res.(StringResult).Value != "value2" { 393 | t.Errorf("unexpected key2 value: %#v", res) 394 | } 395 | res, _ = client.Send(HashMapCommand{Key: "key3"}, 0) 396 | if res.(StringResult).Value != "value3" { 397 | t.Errorf("unexpected key3 value: %#v", res) 398 | } 399 | }) 400 | 401 | env.Loop.MustRun() 402 | } 403 | 404 | type RaftEnvironment struct { 405 | Cancel func() 406 | Context context.Context 407 | Servers []*simulator.Port 408 | Clients []*simulator.Port 409 | Loop *simulator.EventLoop 410 | Network *simulator.OrderedNetwork 411 | } 412 | 413 | func NewRaftEnvironment(numServers, numClients int, randomized bool) *RaftEnvironment { 414 | loop := simulator.NewEventLoop() 415 | nodes := []*simulator.Node{} 416 | ports := []*simulator.Port{} 417 | for i := 0; i < numServers+numClients; i++ { 418 | node := simulator.NewNode() 419 | nodes = append(nodes, node) 420 | ports = append(ports, node.Port(loop)) 421 | } 422 | 423 | // The network is always ordered, but may have random latency. 
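// With randomized == false the random component is zero, matching the
// "Instant" subtests; 0.1 gives the "Latency" subtests some jitter
// while staying well below the ~30-unit election timeouts.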
424 | var latency float64 425 | if randomized { 426 | latency = 0.1 427 | } else { 428 | latency = 0.0 429 | } 430 | network := simulator.NewOrderedNetwork(1e6, latency) 431 | 432 | context, cancelFn := context.WithCancel(context.Background()) 433 | 434 | for i := 0; i < numServers; i++ { 435 | index := i 436 | loop.Go(func(h *simulator.Handle) { 437 | var other []*simulator.Port 438 | var port *simulator.Port 439 | for i, p := range ports[:len(ports)-numClients] { 440 | if i == index { 441 | port = p 442 | } else { 443 | other = append(other, p) 444 | } 445 | } 446 | (&Raft[HashMapCommand, *HashMap]{ 447 | Context: context, 448 | Handle: h, 449 | Network: network, 450 | Port: port, 451 | Others: other, 452 | Log: &Log[HashMapCommand, *HashMap]{ 453 | Origin: &HashMap{}, 454 | }, 455 | ElectionTimeout: 30 + float64(index)*2, 456 | HeartbeatInterval: 10, 457 | }).RunLoop() 458 | }) 459 | } 460 | 461 | return &RaftEnvironment{ 462 | Cancel: cancelFn, 463 | Context: context, 464 | Servers: ports[:numServers], 465 | Clients: ports[numServers:], 466 | Loop: loop, 467 | Network: network, 468 | } 469 | } 470 | --------------------------------------------------------------------------------
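A condensed usage sketch, distilled from the tests above (it reuses the
RaftEnvironment helper and the HashMapCommand/StringResult types from this
package; the 5-server, 1-client configuration simply mirrors the tests):

    env := NewRaftEnvironment(5, 1, true)
    env.Loop.Go(func(h *simulator.Handle) {
        defer env.Cancel()
        client := &Client[HashMapCommand]{
            Handle:      h,
            Network:     env.Network,
            Port:        env.Clients[0],
            Servers:     env.Servers,
            SendTimeout: 10,
        }
        // Write a key through the replicated log, then read it back.
        if res, err := client.Send(HashMapCommand{Key: "greeting", Value: "hello"}, 0); err == nil {
            _ = res.(StringResult).Value // "hello"
        }
        if res, err := client.Send(HashMapCommand{Key: "greeting"}, 0); err == nil {
            _ = res.(StringResult).Value // reads back "hello"
        }
    })
    env.Loop.MustRun()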