├── src ├── raft │ ├── py.typed │ ├── __init__.py │ ├── adapters │ │ ├── __init__.py │ │ ├── run_server.py │ │ ├── persistent_log.py │ │ ├── transport.py │ │ └── network.py │ ├── messages.py │ ├── log.py │ └── server.py └── setup.py ├── pytest.ini ├── requirements.txt ├── mypy.ini ├── Makefile ├── tests ├── integration │ ├── test_persistent_log.py │ └── test_network.py ├── unit │ ├── figure_7.py │ ├── test_all_servers.py │ ├── test_candidate.py │ ├── test_client_interaction_multiple_servers.py │ ├── test_elections_multiple_servers.py │ ├── test_voting.py │ ├── test_log_replication_multiple_servers.py │ ├── test_log.py │ ├── test_follower.py │ └── test_leader.py └── e2e │ └── test_basic_log_replication_with_tcp.py ├── README.md └── .gitignore /src/raft/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/raft/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/raft/adapters/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | addopts = --color=yes -vv --tb=short 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytest 2 | # trio 3 | colorama 4 | pytest-icdiff 5 | pylint 6 | -------------------------------------------------------------------------------- /src/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name="raft", 5 | version="0.1", 6 |
packages=["raft"], 7 | scripts=[], 8 | ) 9 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | ignore_missing_imports = False 3 | namespace_packages = False 4 | check_untyped_defs = True 5 | disallow_untyped_calls = True 6 | disallow_untyped_decorators = True 7 | warn_unused_ignores = True 8 | warn_return_any = True 9 | disallow_untyped_defs = False 10 | scripts_are_modules = True 11 | 12 | # enable this to force Dict[x, y] instead of "dict" 13 | disallow_any_generics = False 14 | 15 | # this is maximum hardcore 16 | strict = False 17 | 18 | [mypy-pytest.*,setuptools.*,colorama.*] 19 | ignore_missing_imports = True 20 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | export COMPOSE_DOCKER_CLI_BUILD=1 2 | export DOCKER_BUILDKIT=1 3 | SHELL = zsh 4 | 5 | watch: 6 | ls **/*.py | entr pytest tests/unit 7 | 8 | pylint: ## runs just pylint 9 | pylint -j0 src tests 10 | 11 | black: ## run just black 12 | black -l 99 src tests --check --exclude migrations/* 13 | 14 | mypy: ## run just mypy 15 | dmypy run src tests/* 16 | 17 | lint: pylint black mypy ## run all linters locally 18 | 19 | pretty: ## make the code pretty 20 | black -l 99 src tests 21 | 22 | unit: ## run unit tests 23 | pytest tests/unit 24 | 25 | integration: ## run integration tests 26 | pytest tests/integration 27 | 28 | all: unit integration lint 29 | -------------------------------------------------------------------------------- /tests/integration/test_persistent_log.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=redefined-outer-name 2 | from pathlib import Path 3 | import tempfile 4 | import pytest 5 | from raft.log import Log 6 | from raft.adapters.persistent_log import 
PersistentLog, Entry 7 | 8 | @pytest.fixture 9 | def temp_path(): 10 | tf = Path(tempfile.NamedTemporaryFile().name) 11 | yield tf 12 | tf.unlink() 13 | 14 | 15 | 16 | def test_can_round_trip_a_log(temp_path): 17 | log = PersistentLog(temp_path) # type: Log 18 | entry1 = Entry(1, 'foo=1') 19 | entry2 = Entry(2, 'foo=2') 20 | log.add_entry(entry1, 0, 0, 0) 21 | log.add_entry(entry2, 1, 1, 0) 22 | new_log = PersistentLog(temp_path) 23 | assert new_log.read() == [entry1, entry2] 24 | -------------------------------------------------------------------------------- /tests/integration/test_network.py: -------------------------------------------------------------------------------- 1 | import socket 2 | import time 3 | from raft.adapters.network import TCPRaftNet 4 | from raft.messages import Message, AppendEntriesSucceeded 5 | from raft.adapters.transport import connect_tenaciously 6 | 7 | def test_sending_message(): 8 | s1net = TCPRaftNet('S1') 9 | s2net = TCPRaftNet('S2') 10 | s1net.start() 11 | s2net.start() 12 | 13 | # check servers are up and listening 14 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: 15 | host, port = s1net.host 16 | connect_tenaciously(sock, host=host, port=port) 17 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: 18 | host, port = s2net.host 19 | connect_tenaciously(sock, host=host, port=port) 20 | 21 | msg = Message(frm='S1', to='S2', cmd=AppendEntriesSucceeded(3)) 22 | s1net.dispatch(msg) 23 | time.sleep(0.2) 24 | msgs = s2net.get_messages('S2') 25 | assert msgs == [msg] 26 | -------------------------------------------------------------------------------- /tests/unit/figure_7.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | from raft.server import Server, Leader, Follower 3 | from raft.log import InMemoryLog, Entry 4 | 5 | logs_from_paper_str = ''' 6 | l,1114455666 7 | a,111445566 8 | b,1114 9 | c,11144556666 10 | d,111445566677 11 | 
e,1114444 12 | f,11122233333 13 | ''' 14 | 15 | def make_servers() -> Dict[str, Server]: 16 | servers = {} 17 | peers = [c for c in 'labcdef'] 18 | for line in logs_from_paper_str.strip().splitlines(): 19 | name, _, entries = line.strip().partition(',') 20 | log = InMemoryLog([ 21 | Entry(term=int(c), cmd=f'foo={c}') 22 | for c in entries 23 | ]) 24 | args = dict( 25 | name=name, peers=peers, now=0, log=log, currentTerm=int(entries[-1]), votedFor=None 26 | ) 27 | if name == 'l': 28 | servers[name] = Leader(**args) 29 | else: 30 | servers[name] = Follower(**args) 31 | return servers 32 | -------------------------------------------------------------------------------- /src/raft/messages.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Union 3 | from raft.log import Entry 4 | 5 | 6 | @dataclass 7 | class ClientSetCommand: 8 | guid: str 9 | cmd: str 10 | 11 | 12 | @dataclass 13 | class ClientSetSucceeded: 14 | guid: str 15 | 16 | 17 | @dataclass 18 | class AppendEntries: 19 | term: int 20 | leaderId: str 21 | prevLogIndex: int 22 | prevLogTerm: int 23 | entries: List[Entry] 24 | leaderCommit: int 25 | 26 | 27 | @dataclass 28 | class AppendEntriesSucceeded: 29 | matchIndex: int 30 | 31 | 32 | @dataclass 33 | class AppendEntriesFailed: 34 | term: int 35 | 36 | 37 | @dataclass 38 | class RequestVote: 39 | term: int 40 | candidateId: str 41 | lastLogIndex: int 42 | lastLogTerm: int 43 | 44 | 45 | @dataclass 46 | class VoteGranted: 47 | pass 48 | 49 | 50 | @dataclass 51 | class VoteDenied: 52 | term: int 53 | 54 | 55 | @dataclass 56 | class Message: 57 | frm: str 58 | to: str 59 | cmd: Union[ 60 | ClientSetCommand, 61 | AppendEntries, 62 | AppendEntriesSucceeded, 63 | AppendEntriesFailed, 64 | RequestVote, 65 | VoteGranted, 66 | VoteDenied, 67 | ] 68 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # Rafting Trip - August 3-7, 2020 2 | 3 | This was my code from David Beazley's amazing [Rafting 4 | Trip](https://dabeaz.com/raft.html) training course, 5 | where we tried to implement the 6 | [raft distributed consensus algorithm](https://raft.github.io/) 7 | in a week. 8 | 9 | I didn't finish, but I got fairly far. The code is here for my archive. 10 | 11 | 12 | * [x] transaction log class with accept/reject logic for append requests 13 | * [x] vaguely ports & adapters architecture, with raft protocol classes as core domain 14 | * [x] optional persistent log storage adapter 15 | * [x] log replication leader -> follower, with backtracking algorithm 16 | * [x] heartbeats + time events 17 | * [x] elections, including timeouts, conversion to candidate, voting, promotion to leader 18 | * [x] `RaftNet` abstraction for transport, with fake version for tests 19 | * [x] `run_server()` and `clock_tick()` abstractions for running the algo over time 20 | * [x] ability to unit test / simulate interaction of multiple nodes in-memory 21 | * [x] end-to-end and integration tests 22 | * [ ] client interactions `<--` this is the big missing piece, which might 23 | have caused me to rethink the architecture. It may require some intermediary 24 | piece in between Server objects / `run_server()` and clients.
25 | * [ ] cluster membership changes 26 | -------------------------------------------------------------------------------- /src/raft/adapters/run_server.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=redefined-outer-name 2 | import sys 3 | import time 4 | from typing import Tuple 5 | from raft.log import InMemoryLog 6 | from raft.adapters.network import RaftNetwork, TCPRaftNet 7 | from raft.server import Server, Follower 8 | 9 | def run_tcp_server(server: Server, raftnet: RaftNetwork): 10 | print(f'Starting server {server.name}') 11 | while True: 12 | clock_tick(server, raftnet, time.time()) 13 | time.sleep(0.01) 14 | 15 | def clock_tick(server: Server, raftnet: RaftNetwork, now: float): 16 | server.clock_tick(now=now) # am expecting this to handle timeouts, heartbeats, etc 17 | 18 | for m in raftnet.get_messages(server.name): 19 | server.handle_message(m) 20 | 21 | while server.outbox: 22 | m = server.outbox.pop(0) 23 | raftnet.dispatch(m) 24 | 25 | 26 | def _main(name: str) -> Tuple[Server, RaftNetwork]: 27 | # pylint: disable=import-outside-toplevel 28 | raftnet = TCPRaftNet(name) 29 | raftnet.start() 30 | server = Follower( 31 | name=name, log=InMemoryLog([]), currentTerm=0, votedFor=None 32 | ) 33 | import threading 34 | threading.Thread(target=run_tcp_server, args=(server, raftnet), daemon=True).start() 35 | return server, raftnet 36 | 37 | if __name__ == '__main__': 38 | server, raftnet = _main(sys.argv[1]) 39 | -------------------------------------------------------------------------------- /tests/unit/test_all_servers.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from raft.log import InMemoryLog 3 | from raft.server import Leader, Follower, Candidate 4 | from raft.messages import Message, AppendEntries, AppendEntriesFailed, RequestVote, VoteDenied 5 | 6 | some_messages = [ 7 | Message(frm="S2", to="S1", cmd=AppendEntries(term=5, 
leaderId="S2", prevLogIndex=22, prevLogTerm=4, entries=[], leaderCommit=9)), 8 | Message(frm="S2", to="S1", cmd=AppendEntriesFailed(term=5)), 9 | Message(frm="S2", to="S1", cmd=RequestVote(term=5, candidateId="S2", lastLogIndex=22, lastLogTerm=4)), 10 | Message(frm="S2", to="S1", cmd=VoteDenied(term=5)), 11 | ] 12 | @pytest.mark.parametrize('server_class', [Leader, Follower, Candidate]) 13 | @pytest.mark.parametrize('msg', some_messages) 14 | def test_messages_with_higher_term_should_convert_to_follower(server_class, msg): 15 | s = server_class("S1", peers=["S1", "S2", "S3"], now=0, log=InMemoryLog([]), currentTerm=4, votedFor="S3") 16 | s.handle_message(msg) 17 | assert s.currentTerm == 5 18 | assert isinstance(s, Follower) 19 | assert s.votedFor == None 20 | 21 | @pytest.mark.parametrize('server_class', [Leader, Follower, Candidate]) 22 | def test_clock_tick_stores_time(server_class): 23 | s = server_class("S1", peers=["S1", "S2", "S3"], now=0, log=InMemoryLog([]), currentTerm=4, votedFor=None) 24 | assert s.now == 0 25 | s.clock_tick(3) 26 | assert s.now == 3 27 | -------------------------------------------------------------------------------- /src/raft/adapters/persistent_log.py: -------------------------------------------------------------------------------- 1 | import json 2 | from dataclasses import asdict 3 | from pathlib import Path 4 | from typing import List 5 | 6 | from raft.log import Entry, InMemoryLog 7 | 8 | 9 | class PersistentLog: 10 | 11 | def __init__(self, path: Path): 12 | self.path = path 13 | if not self.path.exists(): 14 | existing_entries = [] 15 | else: 16 | existing_entries = [ 17 | Entry(**entry) for entry in json.loads(self.path.read_text()) 18 | ] 19 | self.log = InMemoryLog(existing_entries) 20 | 21 | @property 22 | def lastLogIndex(self) -> int: 23 | return self.log.lastLogIndex 24 | 25 | @property 26 | def last_log_term(self) -> int: 27 | return self.log.last_log_term 28 | 29 | def check_log(self, prevLogIndex: int, prevLogTerm: 
int) -> bool: 30 | return self.log.check_log(prevLogIndex, prevLogTerm) 31 | 32 | def entry_term(self, index: int) -> int: 33 | """term of entry at (1-based) index position""" 34 | return self.log.entry_term(index) 35 | 36 | def entry_at(self, index: int) -> Entry: 37 | """Entry at (1-based) index position""" 38 | return self.log.entry_at(index) 39 | 40 | def add_entry( 41 | self, 42 | entry: Entry, 43 | prevLogIndex: int, 44 | prevLogTerm: int, 45 | leaderCommit: int, 46 | ) -> bool: 47 | result = self.log.add_entry( 48 | entry, prevLogIndex, prevLogTerm, leaderCommit 49 | ) 50 | self._flush() 51 | return result 52 | 53 | def read(self) -> List[Entry]: 54 | return self.log.read() 55 | 56 | def _flush(self) -> None: 57 | self.path.write_text(json.dumps([asdict(e) for e in self.read()])) 58 | -------------------------------------------------------------------------------- /src/raft/adapters/transport.py: -------------------------------------------------------------------------------- 1 | import time 2 | from socket import socket 3 | from colorama import Style 4 | 5 | class ConnectionClosed(ConnectionError): 6 | pass 7 | 8 | 9 | def _log(msg: str) -> None: 10 | print(f'{Style.DIM}[transport] {msg}{Style.RESET_ALL}') 11 | 12 | HOST = '127.0.0.1' 13 | 14 | def send_message(sock: socket, msg: bytes) -> None: 15 | size_prefix = b'%12d' % len(msg) 16 | sock.sendall(size_prefix + msg) 17 | 18 | 19 | def _recv_exactly(sock: socket, num_bytes: int) -> bytes: 20 | # Receive exactly a requested number of bytes on a socket 21 | msg = b'' 22 | while num_bytes: 23 | part = sock.recv(num_bytes) # No guarantee we get the complete message 24 | if part == b'': 25 | raise EOFError(f'Client disconnected with partial message {msg!r}') 26 | msg += part 27 | num_bytes -= len(part) 28 | return msg 29 | 30 | def recv_message(sock: socket) -> bytes: 31 | try: 32 | expected_size = int(_recv_exactly(sock, 12)) 33 | except EOFError as err: 34 | raise ConnectionClosed('Client disconnected') from err 35 | return _recv_exactly(sock, expected_size) 36 | 37 | 38 | def connect_tenaciously(s: 
socket, port: int, host: str = HOST) -> socket: 39 | tries_left = 10 40 | while True: 41 | try: 42 | _log(f'connection attempt {11-tries_left}') 43 | s.connect((host, port)) 44 | return s 45 |
except ConnectionRefusedError: 46 | tries_left -= 1 47 | if tries_left == 0: 48 | raise 49 | time.sleep(0.05) 50 | -------------------------------------------------------------------------------- /tests/unit/test_candidate.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from raft.server import Follower, Leader, Candidate 3 | from raft.log import InMemoryLog, Entry 4 | from raft.messages import Message, RequestVote, VoteGranted, VoteDenied 5 | 6 | def make_candidate(peers=None) -> Candidate: 7 | if peers is None: 8 | peers = ["S1", "S2", "S3", "S4", "S5"] 9 | c = Follower( 10 | name="S1", 11 | peers=peers, 12 | now=0, 13 | log=InMemoryLog([Entry(1, "foo=1"), Entry(1, "foo=2")]), 14 | currentTerm=10, 15 | votedFor=None, 16 | ) 17 | c._become_candidate() 18 | return c 19 | 20 | def test_votes_granted_below_quorum_do_not_have_immediate_effect(): 21 | c = make_candidate() 22 | c.handle_message(Message(frm="S2", to="S1", cmd=VoteGranted())) 23 | assert isinstance(c, Candidate) 24 | 25 | 26 | def test_becomes_leader_on_first_vote_to_go_above_quorum(): 27 | c = make_candidate() 28 | c.handle_message(Message(frm="S2", to="S1", cmd=VoteGranted())) 29 | assert isinstance(c, Candidate) 30 | c.handle_message(Message(frm="S3", to="S1", cmd=VoteGranted())) 31 | assert isinstance(c, Leader) 32 | 33 | def test_becomes_leader_on_first_vote_in_three_servers_case(): 34 | c = make_candidate(peers=["S1", "S2", "S3"]) 35 | c.handle_message(Message(frm="S2", to="S1", cmd=VoteGranted())) 36 | assert isinstance(c, Leader) 37 | 38 | def test_becoming_leader_resets_leader_state_matchindex_and_nextindex(): 39 | c = make_candidate(peers=["S1", "S2", "S3"]) 40 | # maybe these are hanging around from previous state somehow 41 | c.matchIndex = {"S1": 1, "S2": 2, "S3": 3} 42 | c.nextIndex = {"S1": 1, "S2": 2, "S3": 3} 43 | c.handle_message(Message(frm="S2", to="S1", cmd=VoteGranted())) 44 | assert isinstance(c, Leader) 45 | assert c.matchIndex ==
{"S1": 0, "S2": 0, "S3": 0} 46 | assert c.nextIndex == {"S1": 3, "S2": 3, "S3": 3} 47 | -------------------------------------------------------------------------------- /tests/unit/test_client_interaction_multiple_servers.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from raft.adapters.network import FakeRaftNetwork 3 | from raft.adapters.run_server import clock_tick 4 | from raft.log import InMemoryLog, Entry 5 | from raft.messages import Message, ClientSetCommand, ClientSetSucceeded 6 | from raft.server import Leader, Follower, HEARTBEAT_FREQUENCY 7 | 8 | 9 | @pytest.mark.xfail 10 | def test_client_gets_response_but_only_when_new_entry_is_on_a_majority_of_servers(): 11 | peers = ["S1", "S2", "S3"] 12 | leader = Leader( 13 | name="S1", 14 | now=1, 15 | log=InMemoryLog([Entry(term=1, cmd='foo=old1'), Entry(term=1, cmd='foo=old2')]), 16 | peers=peers, 17 | currentTerm=1, 18 | votedFor=None, 19 | ) 20 | f1 = Follower( 21 | name="S2", 22 | peers=peers, 23 | now=1, 24 | log=InMemoryLog([]), 25 | currentTerm=1, 26 | votedFor=None, 27 | ) 28 | f2 = Follower( 29 | name="S3", 30 | peers=peers, 31 | now=1, 32 | log=InMemoryLog([]), 33 | currentTerm=1, 34 | votedFor=None, 35 | ) 36 | client_set = Message(frm="client.id", to="S1", cmd=ClientSetCommand(guid='gooey', cmd="foo=new")) 37 | 38 | raftnet = FakeRaftNetwork([]) 39 | raftnet.dispatch(client_set) 40 | 41 | for i in range(1, 11): 42 | print(f"*** --- CLOCK TIIIIICK {i} --- ***") 43 | clock_tick(leader, raftnet, 1 + i / 100.0) 44 | clock_tick(f1, raftnet, 1 + i / 100.0) 45 | clock_tick(f2, raftnet, 1 + i / 100.0) 46 | if ( 47 | (f1.log.read() and f1.log.read()[-1].cmd != 'foo=new') 48 | and 49 | (f2.log.read() and f2.log.read()[-1].cmd != 'foo=new') 50 | ): 51 | # no quorum yet 52 | assert not [m for m in raftnet._messages if m.to == 'client.id'] 53 | 54 | assert f1.log.read()[-1].cmd == 'foo=new' 55 | assert f2.log.read()[-1].cmd == 'foo=new' 56 | 57 | [response] =
raftnet.get_messages('client.id') 58 | assert response.cmd == ClientSetSucceeded(guid='gooey') 59 | -------------------------------------------------------------------------------- /tests/unit/test_elections_multiple_servers.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from raft.adapters.network import FakeRaftNetwork 3 | from raft.adapters.run_server import clock_tick 4 | from raft.log import InMemoryLog, Entry 5 | from raft.messages import Message, ClientSetCommand 6 | from raft.server import Leader, Follower, Candidate, MIN_ELECTION_TIMEOUT 7 | import figure_7 8 | 9 | def make_follower(name, peers) -> Follower: 10 | return Follower( 11 | name=name, 12 | peers=peers, 13 | now=0, 14 | log=InMemoryLog([]), 15 | currentTerm=0, 16 | votedFor=None, 17 | ) 18 | 19 | 20 | def test_simple_election(): 21 | peers = ["S1", "S2", "S3"] 22 | f1, f2, f3 = [make_follower(n, peers) for n in peers] 23 | 24 | raftnet = FakeRaftNetwork([]) 25 | start = int(MIN_ELECTION_TIMEOUT * 1000) - 1 26 | for i in range(start, start * 2): 27 | print(f"*** --- CLOCK TIIIIICK {i} --- ***") 28 | clock_tick(f1, raftnet, i / 1000) 29 | clock_tick(f2, raftnet, i / 1000) 30 | clock_tick(f3, raftnet, i / 1000) 31 | 32 | assert any(isinstance(f, Leader) for f in [f1, f2, f3]) 33 | 34 | 35 | @pytest.mark.xfail 36 | def test_figure_7_elections_always_get_committed_logs(): 37 | for _ in range(10): # do this lots of times to get a few random outcomes 38 | servers = figure_7.make_servers() 39 | del servers['l'] # oh noes, what will they do without a leader??? 
40 | 41 | raftnet = FakeRaftNetwork([]) 42 | start_ms = int(MIN_ELECTION_TIMEOUT * 1000) - 1 43 | for i in range(start_ms, start_ms * 10): 44 | # print(f"*** --- CLOCK TIIIIICK {i} --- ***") 45 | for _, s in servers.items(): 46 | clock_tick(s, raftnet, i / 1000.0) 47 | 48 | new_logs = '\n'.join( 49 | f'{n}:{"".join(str(e.term) for e in s.log.read())}' 50 | for n, s in servers.items() 51 | ) 52 | print(new_logs) 53 | for n, s in servers.items(): 54 | print(f'Checking log for {s}: {s.log.read()}') 55 | terms = [e.term for e in s.log.read()] 56 | if terms[:9] != list(map(int, '111445566')): 57 | for m in raftnet._message_backups: 58 | if m.to == n or m.frm == n: 59 | print(m) 60 | assert terms[:9] == list(map(int, '111445566')) 61 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /tests/e2e/test_basic_log_replication_with_tcp.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=too-many-locals 2 | import time 3 | import socket 4 | import threading 5 | 6 | from raft.adapters.run_server import run_tcp_server 7 | from raft.adapters.network import TCPRaftNet 8 | from raft.adapters.transport import connect_tenaciously 9 | from raft.log import InMemoryLog, Entry 10 | from raft.messages import Message, ClientSetCommand 11 | from raft.server import Leader, Follower 12 | 13 | 14 | def test_replication_with_tcp_servers(): 15 | networks = {name: TCPRaftNet(name) for name in TCPRaftNet.SERVERS} 16 | for net in networks.values(): 17 | net.start() 18 | 19 | # check networks are up and listening 20 | for net in networks.values(): 21 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: 22 | host, port = net.host 23 | connect_tenaciously(sock, host=host, port=port) 24 | 25 | leader_entries = [ 26 | Entry(term=1, cmd="monkeys=1"), 27 | Entry(term=2, cmd="bananas=2"), 28 | Entry(term=2, cmd="turtles=3"), 29 | ] 30 | one_wrong_entry = [ 31 | Entry(term=1, cmd="monkeys=1"), 32 | Entry(term=1, cmd="monkeys=2"), 33 | ] 34 | peers = ["S1", "S2", "S3"] 35 | leader = Leader( 36 | name="S1", 37 | log=InMemoryLog(leader_entries), 38 | 
peers=peers, 39 | currentTerm=2, 40 | votedFor=None, 41 | ) 42 | f1 = Follower( 43 | name="S2", peers=peers, log=InMemoryLog([]), currentTerm=2, votedFor=None 44 | ) 45 | f2 = Follower( 46 | name="S3", 47 | peers=peers, 48 | log=InMemoryLog(one_wrong_entry), 49 | currentTerm=2, 50 | votedFor=None, 51 | ) 52 | 53 | client_set = Message(frm="client.id", to="S1", cmd=ClientSetCommand(guid='guid', cmd="gherkins=4")) 54 | 55 | leadernet = networks["S1"] 56 | f1net = networks["S2"] 57 | f2net = networks["S3"] 58 | randomnet = networks["S5"] 59 | randomnet.dispatch(client_set) 60 | 61 | # start threads to actually run each server 62 | threading.Thread( 63 | target=run_tcp_server, args=(leader, leadernet), daemon=True 64 | ).start() 65 | threading.Thread(target=run_tcp_server, args=(f1, f1net), daemon=True).start() 66 | threading.Thread(target=run_tcp_server, args=(f2, f2net), daemon=True).start() 67 | 68 | time.sleep(0.3) 69 | 70 | expected = leader_entries + [Entry(term=2, cmd="gherkins=4")] 71 | assert leader.log.read() == expected 72 | assert f1.log.read() == expected 73 | assert f2.log.read() == expected 74 | -------------------------------------------------------------------------------- /tests/unit/test_voting.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from raft.server import Follower 3 | from raft.log import InMemoryLog 4 | from raft.messages import Message, RequestVote, VoteGranted, VoteDenied 5 | 6 | def make_follower(votedFor=None) -> Follower: 7 | return Follower( 8 | name="S1", 9 | peers=["S1", "S2", "S3"], 10 | now=0, 11 | log=InMemoryLog([]), 12 | currentTerm=10, 13 | votedFor=votedFor, 14 | ) 15 | 16 | def make_RequestVote(term: int, lastLogIndex: int, lastLogTerm: int) -> Message: 17 | return Message(frm="S2", to="S1", cmd=RequestVote(term=term, candidateId="S2", lastLogIndex=lastLogIndex, lastLogTerm=lastLogTerm)) 18 | 19 | def test_deny_vote_if_candidate_term_too_old(): 20 | s = 
make_follower() 21 | s.handle_message(make_RequestVote(term=s.currentTerm -1, lastLogIndex=s.log.lastLogIndex + 1, lastLogTerm=s.log.last_log_term)) 22 | [msg] = s.outbox 23 | assert msg.cmd == VoteDenied(term=s.currentTerm) 24 | 25 | def test_deny_vote_if_lastLogTerm_is_too_low(): 26 | s = make_follower() 27 | s.handle_message(make_RequestVote(term=s.currentTerm, lastLogIndex=s.log.lastLogIndex + 1, lastLogTerm=s.log.last_log_term - 1)) 28 | [msg] = s.outbox 29 | assert msg.cmd == VoteDenied(term=s.currentTerm) 30 | 31 | def test_deny_vote_if_lastLogIndex_is_too_low(): 32 | s = make_follower() 33 | s.handle_message(make_RequestVote(term=s.currentTerm, lastLogIndex=s.log.lastLogIndex - 1, lastLogTerm=s.log.last_log_term - 1)) 34 | [msg] = s.outbox 35 | assert msg.cmd == VoteDenied(term=s.currentTerm) 36 | 37 | 38 | def test_deny_vote_if_already_voted_for_someone_else_in_this_term(): 39 | s = make_follower(votedFor="S3") 40 | s.handle_message(make_RequestVote(term=s.currentTerm, lastLogIndex=s.log.lastLogIndex + 1, lastLogTerm=s.log.last_log_term + 1)) 41 | [msg] = s.outbox 42 | assert msg.cmd == VoteDenied(term=s.currentTerm) 43 | 44 | 45 | def test_grant_vote_if_term_greater_and_logindex_greater_and_lastlogterm_greater(): 46 | s = make_follower() 47 | s.handle_message(make_RequestVote(term=s.currentTerm + 1, lastLogIndex=s.log.lastLogIndex + 1, lastLogTerm=s.log.last_log_term + 1)) 48 | [msg] = s.outbox 49 | assert msg.cmd == VoteGranted() 50 | 51 | 52 | def test_grant_vote_again_if_already_voted_for_same_candidate(): 53 | s = make_follower(votedFor="S2") 54 | s.handle_message(make_RequestVote(term=s.currentTerm + 1, lastLogIndex=s.log.lastLogIndex + 1, lastLogTerm=s.log.last_log_term + 1)) 55 | [msg] = s.outbox 56 | assert msg.cmd == VoteGranted() 57 | 58 | 59 | def test_can_grant_vote_for_same_request_term(): 60 | s = make_follower() 61 | s.handle_message(make_RequestVote(term=s.currentTerm, lastLogIndex=s.log.lastLogIndex + 1, lastLogTerm=s.log.last_log_term + 
1)) 62 | [msg] = s.outbox 63 | assert msg.cmd == VoteGranted() 64 | 65 | def test_can_grant_vote_for_same_lastLogTerm(): 66 | s = make_follower() 67 | s.handle_message(make_RequestVote(term=s.currentTerm + 1, lastLogIndex=s.log.lastLogIndex + 1, lastLogTerm=s.log.last_log_term)) 68 | [msg] = s.outbox 69 | assert msg.cmd == VoteGranted() 70 | -------------------------------------------------------------------------------- /src/raft/log.py: -------------------------------------------------------------------------------- 1 | from typing import List, Protocol 2 | from dataclasses import dataclass 3 | 4 | 5 | @dataclass 6 | class Entry: 7 | term: int 8 | cmd: str 9 | 10 | 11 | class Log(Protocol): 12 | 13 | @property 14 | def lastLogIndex(self) -> int: 15 | """1-based index of latest entry""" 16 | ... 17 | 18 | @property 19 | def last_log_term(self) -> int: 20 | """term of latest entry""" 21 | ... 22 | 23 | def entry_term(self, index: int) -> int: 24 | """term of entry at (1-based) index position""" 25 | ... 26 | 27 | def entry_at(self, index: int) -> Entry: 28 | """Entry at (1-based) index position""" 29 | ... 30 | 31 | def check_log(self, prevLogIndex: int, prevLogTerm: int) -> bool: 32 | ... 33 | 34 | def add_entry( 35 | self, 36 | entry: Entry, 37 | prevLogIndex: int, 38 | prevLogTerm: int, 39 | leaderCommit: int, 40 | ) -> bool: 41 | ... 42 | 43 | def read(self) -> List[Entry]: 44 | ... 45 | 46 | 47 | class InMemoryLog: 48 | 49 | def __init__(self, log: List[Entry]) -> None: 50 | self._log = log 51 | 52 | def _has_entry_at(self, index: int) -> bool: 53 | """1-based""" 54 | return 0 < index <= len(self._log) 55 | 56 | def _replace_at(self, index: int, entry: Entry) -> None: 57 | """1-based index. 
truncates any after, unless entry matches""" 58 | if self._has_entry_at(index) and self.entry_at(index) == entry: 59 | return 60 | self._log = self._log[:index - 1] + [entry] 61 | 62 | @property 63 | def lastLogIndex(self) -> int: 64 | return len(self._log) 65 | 66 | @property 67 | def last_log_term(self) -> int: 68 | if len(self._log) == 0: 69 | return 0 70 | return self.entry_term(self.lastLogIndex) 71 | 72 | def entry_term(self, index: int) -> int: 73 | if index == 0: 74 | return 0 75 | return self.entry_at(index).term 76 | 77 | def entry_at(self, index: int) -> Entry: 78 | if index < 0: 79 | return self._log[index] 80 | return self._log[index - 1] 81 | 82 | def check_log(self, prevLogIndex: int, prevLogTerm: int) -> bool: 83 | """check whether prevLogIndex and prevLogTerm match. 1-based index""" 84 | if prevLogIndex == 0: 85 | return True 86 | if not self._has_entry_at(prevLogIndex): 87 | print(f'nope, no entry at {prevLogIndex}') 88 | return False 89 | if self.entry_term(prevLogIndex) != prevLogTerm: 90 | print(f'nope, entry at {prevLogIndex} had wrong term') 91 | return False 92 | return True 93 | 94 | def add_entry( 95 | self, 96 | entry: Entry, 97 | prevLogIndex: int, # 1-based 98 | prevLogTerm: int, 99 | leaderCommit: int, # 1-based, ignored for now. # TODO: remove? 
100 | ) -> bool: 101 | if not self.check_log(prevLogIndex, prevLogTerm): 102 | return False 103 | self._replace_at(prevLogIndex + 1, entry) 104 | return True 105 | 106 | 107 | def read(self) -> List[Entry]: 108 | return self._log 109 | -------------------------------------------------------------------------------- /tests/unit/test_log_replication_multiple_servers.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from raft.adapters.network import FakeRaftNetwork 3 | from raft.adapters.run_server import clock_tick 4 | from raft.log import InMemoryLog, Entry 5 | from raft.messages import Message, ClientSetCommand 6 | from raft.server import Leader, Follower, HEARTBEAT_FREQUENCY 7 | import figure_7 8 | 9 | 10 | def test_replication_one_server_simple_case(): 11 | leader = Leader( 12 | name="S1", 13 | now=1, 14 | log=InMemoryLog([]), 15 | peers=["S1", "S2"], 16 | currentTerm=1, 17 | votedFor=None, 18 | ) 19 | follower = Follower( 20 | name="S2", 21 | peers=["S1", "S2"], 22 | now=1, 23 | log=InMemoryLog([]), 24 | currentTerm=1, 25 | votedFor=None, 26 | ) 27 | client_set = Message(frm="client.id", to="S1", cmd=ClientSetCommand(guid='gooey', cmd="foo=1")) 28 | 29 | raftnet = FakeRaftNetwork([]) 30 | raftnet.dispatch(client_set) 31 | 32 | clock_tick(leader, raftnet, 1) 33 | clock_tick(follower, raftnet, 1) 34 | assert follower.log.read()[-1].cmd == "foo=1" 35 | 36 | 37 | def test_replication_multiple_servers_simple_case(): 38 | peers = ["S1", "S2", "S3"] 39 | leader = Leader( 40 | name="S1", now=1, log=InMemoryLog([]), peers=peers, currentTerm=1, votedFor=None 41 | ) 42 | f1 = Follower( 43 | name="S2", peers=peers, now=1, log=InMemoryLog([]), currentTerm=1, votedFor=None 44 | ) 45 | f2 = Follower( 46 | name="S3", peers=peers, now=1, log=InMemoryLog([]), currentTerm=1, votedFor=None 47 | ) 48 | 49 | client_set = Message(frm="client.id", to="S1", cmd=ClientSetCommand(guid='gooey', cmd="foo=1")) 50 | 51 | raftnet = 
FakeRaftNetwork([]) 52 | raftnet.dispatch(client_set) 53 | clock_tick(leader, raftnet, 1) 54 | clock_tick(f1, raftnet, 1) 55 | clock_tick(f2, raftnet, 1) 56 | assert f1.log.read()[-1].cmd == "foo=1" 57 | assert f2.log.read()[-1].cmd == "foo=1" 58 | 59 | 60 | def test_replication_backtracking(): 61 | peers = ["S1", "S2", "S3"] 62 | leader_entries = [ 63 | Entry(term=1, cmd="monkeys=1"), 64 | Entry(term=2, cmd="bananas=2"), 65 | Entry(term=2, cmd="turtles=3"), 66 | ] 67 | one_wrong_entry = [ 68 | Entry(term=1, cmd="monkeys=1"), 69 | Entry(term=1, cmd="monkeys=2"), 70 | ] 71 | 72 | leader = Leader( 73 | name="S1", 74 | now=1, 75 | log=InMemoryLog(leader_entries), 76 | peers=peers, 77 | currentTerm=2, 78 | votedFor=None, 79 | ) 80 | f1 = Follower( 81 | name="S2", peers=peers, now=1, log=InMemoryLog([]), currentTerm=2, votedFor=None 82 | ) 83 | f2 = Follower( 84 | name="S3", 85 | peers=peers, 86 | now=1, 87 | log=InMemoryLog(one_wrong_entry), 88 | currentTerm=2, 89 | votedFor=None, 90 | ) 91 | 92 | client_set = Message(frm="client.id", to="S1", cmd=ClientSetCommand(guid='gooey', cmd="gherkins=4")) 93 | 94 | raftnet = FakeRaftNetwork([]) 95 | raftnet.dispatch(client_set) 96 | 97 | for i in range(1, 11): # IDEA: while raftnet.messages? 
98 | print(f"*** --- CLOCK TIIIIICK {i} --- ***") 99 | clock_tick(leader, raftnet, 1 + i / 100.0) 100 | clock_tick(f1, raftnet, 1 + i / 100.0) 101 | clock_tick(f2, raftnet, 1 + i / 100.0) 102 | 103 | expected = leader_entries + [Entry(term=2, cmd="gherkins=4")] 104 | 105 | assert leader.log.read() == expected 106 | assert f1.log.read() == expected 107 | assert f2.log.read() == expected 108 | 109 | 110 | def test_figure_seven_from_paper(): 111 | servers = figure_7.make_servers() 112 | print(servers) 113 | 114 | raftnet = FakeRaftNetwork([]) 115 | one_heartbeat_in = HEARTBEAT_FREQUENCY + 0.0001 116 | 117 | for i in range(1, 100): 118 | print(f"*** --- CLOCK TIIIIICK {i} --- ***") 119 | for _, s in servers.items(): 120 | clock_tick(s, raftnet, one_heartbeat_in + i / 1000.0) 121 | 122 | for n, s in servers.items(): 123 | print(f'Checking log for server {n}: {s.log.read()}') 124 | terms = [e.term for e in s.log.read()] 125 | assert terms[:10] == list(map(int, '1114455666')) 126 | -------------------------------------------------------------------------------- /src/raft/adapters/network.py: -------------------------------------------------------------------------------- 1 | from colorama import Fore, Style 2 | from enum import Enum 3 | from typing import Any, Dict, List, Protocol 4 | from dataclasses import dataclass 5 | import pickle 6 | import queue 7 | import socket 8 | import threading 9 | 10 | from raft.messages import Message 11 | from raft.adapters import transport 12 | 13 | 14 | def _tid() -> str: 15 | tid = threading.get_ident() 16 | return f'Thread-{tid}' 17 | 18 | 19 | class RaftNetwork(Protocol): 20 | 21 | def get_messages(self, to: str) -> List[Message]: 22 | ... 23 | 24 | def dispatch(self, msg: Message) -> None: 25 | ... 
26 | 27 | 28 | class FakeRaftNetwork: 29 | def __init__(self, messages: List[Message]): 30 | self._messages = messages 31 | self._message_backups = [] # type: List[Message] 32 | 33 | def get_messages(self, to: str) -> List[Message]: 34 | """retrieve messages for someone, and take them out of the network""" 35 | theirs = [m for m in self._messages if m.to == to] 36 | self._message_backups.extend(theirs) 37 | for m in theirs: 38 | self._messages.remove(m) 39 | return theirs 40 | 41 | def dispatch(self, msg: Message) -> None: 42 | """put the message into the network""" 43 | self._messages.append(msg) 44 | 45 | 46 | 47 | # -- Dave's code, modified 48 | 49 | def receive_message(sock: socket.socket) -> Message: 50 | raw = transport.recv_message(sock) 51 | result = pickle.loads(raw) # TODO: messages include user-submitted data. so this is terrible 52 | assert isinstance(result, Message) 53 | return result 54 | 55 | def send_message(msg: Message, sock: socket.socket) -> None: 56 | transport.send_message(sock, pickle.dumps(msg)) 57 | 58 | 59 | @dataclass 60 | class Host: 61 | name: str 62 | hostname: str 63 | port: int 64 | 65 | 66 | class TCPRaftNet: 67 | SERVERS = { 68 | 'S1': ('localhost', 16001), 69 | 'S2': ('localhost', 16002), 70 | 'S3': ('localhost', 16003), 71 | 'S4': ('localhost', 16004), 72 | 'S5': ('localhost', 16005), 73 | } 74 | 75 | def __init__(self, name) -> None: 76 | self.name = name 77 | self.host = self.SERVERS[name] 78 | # Message queues. There is a separate outgoing queue for each destination. 79 | # There is a single incoming queue for all received messages. 
80 | self._outgoing = { 81 | name: queue.Queue() 82 | for name in self.SERVERS 83 | if name != self.name 84 | } # type: Dict[str, queue.Queue] 85 | self._incoming = queue.Queue() # type: queue.Queue[Message] 86 | 87 | def _debug(self, msg) -> None: 88 | print(f'{Fore.YELLOW}[raftnet][{self.name}][{_tid()}] {msg}{Style.RESET_ALL}') 89 | 90 | def dispatch(self, msg: Message) -> None: 91 | # Drop the message in a queue and walk away immediately. Does NOT block. 92 | self._outgoing[msg.to].put(msg) 93 | 94 | 95 | def get_messages(self, to: str) -> List[Message]: 96 | assert to == self.name # TODO: this argument is only really needed so 1 FakeRaftNetwork can be shared amongst multiple servers. eh. 97 | messages = [] 98 | while not self._incoming.empty(): 99 | messages.append(self._incoming.get()) 100 | return messages 101 | 102 | 103 | def start(self) -> None: 104 | # Launch various threads related to the network component 105 | 106 | # Thread responsible for listening for incoming connections 107 | threading.Thread(target=self.acceptor_thread, daemon=True).start() 108 | 109 | # Threads dedicated to sending outgoing messages 110 | for server_name in self._outgoing: 111 | threading.Thread(target=self.sender_thread, args=(server_name,), daemon=True).start() 112 | 113 | def acceptor_thread(self): 114 | self._debug('Starting acceptor thread') 115 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 116 | sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True) 117 | sock.bind(self.host) 118 | sock.listen() 119 | self._debug('Acceptor thread waiting for connections') 120 | while True: 121 | client, _ = sock.accept() 122 | threading.Thread(target=self.receiver_thread, args=(client,), daemon=True).start() 123 | 124 | def receiver_thread(self, sock): 125 | self._debug('Starting receiver thread') 126 | # Thread that deals with messages 127 | try: 128 | while True: 129 | msg = receive_message(sock) 130 | self._debug(f'Received msg {msg}') 131 | self._incoming.put(msg) 
132 | except (EOFError, ConnectionError): 133 | sock.close() 134 | 135 | def sender_thread(self, server_name): 136 | self._debug(f'starting sender thread') 137 | sock = None 138 | while True: 139 | msg = self._outgoing[server_name].get() 140 | self._debug(f'sender: sending msg {msg}') 141 | # Make some kind of best-effort to send the message 142 | if sock is None: 143 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 144 | try: 145 | sock.connect(self.SERVERS[server_name]) 146 | except OSError: 147 | sock.close() 148 | sock = None 149 | continue # Oh well. Throw the message away. 150 | 151 | try: 152 | self._debug('sender: actually sending now') 153 | send_message(msg, sock) 154 | except OSError: 155 | sock.close() 156 | sock = None 157 | -------------------------------------------------------------------------------- /tests/unit/test_log.py: -------------------------------------------------------------------------------- 1 | from raft.log import InMemoryLog, Entry 2 | 3 | def test_some_helpers(): 4 | entries = [ 5 | Entry(term=1, cmd="foo=1"), 6 | Entry(term=2, cmd="foo=2"), 7 | Entry(term=3, cmd="foo=3"), 8 | ] 9 | log = InMemoryLog(entries) 10 | assert log.lastLogIndex == 3 11 | assert log.entry_term(1) == 1 12 | assert log.entry_term(2) == 2 13 | assert log.entry_term(3) == 3 14 | assert log.entry_term(0) == 0 15 | assert log.entry_term(-1) == 3 16 | 17 | assert log.entry_at(1) == entries[0] 18 | assert log.entry_at(2) == entries[1] 19 | assert log.entry_at(3) == entries[2] 20 | assert log.entry_at(-1) == entries[2] 21 | 22 | 23 | 24 | def test_add_entry_happy_path(): 25 | old_entry = Entry(term=1, cmd="foo=1") 26 | log = InMemoryLog([old_entry]) 27 | new_entry = Entry(term=2, cmd="foo=2") 28 | result = log.add_entry(new_entry, prevLogIndex=1, prevLogTerm=1, leaderCommit=0) 29 | assert log.read()[-1] == new_entry 30 | assert result is True 31 | 32 | 33 | def test_add_first_entry(): 34 | log = InMemoryLog([]) 35 | new_entry = Entry(term=1, cmd="foo=2") 
36 | log.add_entry(new_entry, prevLogIndex=0, prevLogTerm=0, leaderCommit=0) 37 | assert log.read() == [new_entry] 38 | 39 | 40 | def test_idempotent_at_end(): 41 | old_entry = Entry(term=1, cmd="foo=1") 42 | log = InMemoryLog([old_entry]) 43 | new_entry = Entry(term=1, cmd="foo=2") 44 | log.add_entry(new_entry, prevLogIndex=1, prevLogTerm=1, leaderCommit=0) 45 | result = log.add_entry(new_entry, prevLogIndex=1, prevLogTerm=1, leaderCommit=0) 46 | assert log.read() == [old_entry, new_entry] 47 | assert result is True 48 | 49 | 50 | def test_cannot_add_past_end(): 51 | old_entry = Entry(term=1, cmd="foo=1") 52 | log = InMemoryLog([old_entry]) 53 | new_entry = Entry(term=1, cmd="foo=2") 54 | result = log.add_entry(new_entry, prevLogIndex=2, prevLogTerm=1, leaderCommit=0) 55 | assert log.read() == [old_entry] 56 | assert result is False 57 | 58 | 59 | def test_cannot_add_if_prevLogTerm_does_not_Match(): 60 | old_entry = Entry(term=1, cmd="foo=1") 61 | log = InMemoryLog([old_entry]) 62 | new_entry = Entry(term=1, cmd="foo=2") 63 | result = log.add_entry(new_entry, prevLogIndex=1, prevLogTerm=2, leaderCommit=0) 64 | assert log.read() == [old_entry] 65 | assert result is False 66 | 67 | 68 | def test_cannot_add_if_prevLogTerm_does_not_match(): 69 | old_entry = Entry(term=1, cmd="foo=1") 70 | log = InMemoryLog([old_entry]) 71 | new_entry = Entry(term=1, cmd="foo=2") 72 | result = log.add_entry(new_entry, prevLogIndex=1, prevLogTerm=2, leaderCommit=0) 73 | assert log.read() == [old_entry] 74 | assert result is False 75 | 76 | 77 | def test_can_overwrite_one_if_prevLogTerm_matches(): 78 | old_log = [Entry(term=1, cmd="foo=1"), Entry(term=1, cmd="foo=2")] 79 | log = InMemoryLog(old_log) 80 | new_entry = Entry(term=2, cmd="foo=3") 81 | result = log.add_entry(new_entry, prevLogIndex=1, prevLogTerm=1, leaderCommit=0) 82 | assert log.read() == [old_log[0], new_entry] 83 | assert result is True 84 | 85 | 86 | def test_edge_case_can_ovewrite_zeroth_entry_if_its_the_only_one(): 87 
| old_entry = Entry(term=1, cmd="foo=1") 88 | log = InMemoryLog([old_entry]) 89 | new_entry = Entry(term=1, cmd="foo=2") 90 | result = log.add_entry(new_entry, prevLogIndex=0, prevLogTerm=0, leaderCommit=0) 91 | assert log.read() == [new_entry] 92 | assert result is True 93 | 94 | 95 | def test_edge_case_can_ovewrite_zeroth_entry_and_all_following(): 96 | old_log = [ 97 | Entry(term=1, cmd="foo=1"), 98 | Entry(term=1, cmd="foo=2"), 99 | Entry(term=1, cmd="foo=3"), 100 | ] 101 | log = InMemoryLog(old_log) 102 | new_entry = Entry(term=2, cmd="bar=1") 103 | result = log.add_entry(new_entry, prevLogIndex=0, prevLogTerm=0, leaderCommit=0) 104 | assert log.read() == [new_entry] 105 | assert result is True 106 | 107 | 108 | def test_valid_overwrite_in_the_middle_of_the_log_kills_all_later_ones(): 109 | old_log = [ 110 | Entry(term=1, cmd="foo=1"), 111 | Entry(term=1, cmd="foo=2"), 112 | Entry(term=1, cmd="foo=3"), 113 | ] 114 | log = InMemoryLog(old_log) 115 | new_entry = Entry(term=2, cmd="bar=1") 116 | result = log.add_entry(new_entry, prevLogIndex=1, prevLogTerm=1, leaderCommit=0) 117 | assert log.read() == [old_log[0], new_entry] 118 | assert result is True 119 | 120 | 121 | def test_do_not_overwrite_if_new_entry_matches(): 122 | old_log = [ 123 | Entry(term=1, cmd="foo=1"), 124 | Entry(term=1, cmd="foo=2"), 125 | Entry(term=1, cmd="foo=3"), 126 | ] 127 | log = InMemoryLog(old_log) 128 | new_entry = old_log[1] 129 | result = log.add_entry(new_entry, prevLogIndex=1, prevLogTerm=1, leaderCommit=0) 130 | assert log.read() == old_log 131 | assert result is True 132 | 133 | 134 | def test_check_log_happy_path(): 135 | old_entry = Entry(term=1, cmd="foo=1") 136 | log = InMemoryLog([old_entry]) 137 | result = log.check_log(prevLogIndex=1, prevLogTerm=1) 138 | assert result is True 139 | 140 | 141 | def test_check_log_when_empty(): 142 | log = InMemoryLog([]) 143 | result = log.check_log(prevLogIndex=0, prevLogTerm=0) 144 | assert result is True 145 | 146 | def 
test_prevLogIndex_zero_is_always_true(): 147 | log = InMemoryLog([Entry(term=1, cmd='foo=1')]) 148 | result = log.check_log(prevLogIndex=0, prevLogTerm=0) 149 | assert result is True 150 | 151 | 152 | def test_check_log_index_past_end(): 153 | old_entry = Entry(term=1, cmd="foo=1") 154 | log = InMemoryLog([old_entry]) 155 | result = log.check_log(prevLogIndex=2, prevLogTerm=1) 156 | assert result is False 157 | 158 | 159 | def test_prevLogTerm_does_not_Match(): 160 | old_entry = Entry(term=1, cmd="foo=1") 161 | log = InMemoryLog([old_entry]) 162 | result = log.check_log(prevLogIndex=1, prevLogTerm=2) 163 | assert result is False 164 | -------------------------------------------------------------------------------- /tests/unit/test_follower.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from raft.server import Follower, ELECTION_TIMEOUT_JITTER, MIN_ELECTION_TIMEOUT 3 | from raft.log import InMemoryLog, Entry 4 | from raft.messages import ( 5 | AppendEntries, 6 | AppendEntriesSucceeded, 7 | AppendEntriesFailed, 8 | RequestVote, 9 | Message, 10 | ) 11 | 12 | 13 | def test_append_entries_adds_to_local_log_and_returns_success_response(): 14 | log = InMemoryLog([]) 15 | s = Follower( 16 | name="S2", 17 | peers=["S1", "S2", "S3"], 18 | now=1, 19 | log=log, 20 | currentTerm=1, 21 | votedFor=None, 22 | ) 23 | old_timeout = s._election_timeout 24 | 25 | new_entry = Entry(term=1, cmd="foo=bar") 26 | s.now = 2 27 | s.handle_message( 28 | Message( 29 | frm="S1", 30 | to="S2", 31 | cmd=AppendEntries( 32 | term=1, 33 | leaderId="S1", 34 | prevLogIndex=0, 35 | prevLogTerm=0, 36 | leaderCommit=0, 37 | entries=[new_entry], 38 | ), 39 | ) 40 | ) 41 | assert s.log.read() == [new_entry] 42 | expected_response = AppendEntriesSucceeded(matchIndex=1) 43 | assert s.outbox == [Message(frm="S2", to="S1", cmd=expected_response)] 44 | assert s._election_timeout > old_timeout 45 | 46 | 47 | def 
test_append_entries_success_returns_matchindex_at_given_position(): 48 | log = InMemoryLog([ 49 | Entry(term=1, cmd='cmd=1'), 50 | Entry(term=1, cmd='cmd=2'), 51 | ]) 52 | s = Follower( 53 | name="S2", 54 | peers=["S1", "S2", "S3"], 55 | now=1, 56 | log=log, 57 | currentTerm=1, 58 | votedFor=None, 59 | ) 60 | new_entry = Entry(term=1, cmd="cmd=3") 61 | s.now = 2 62 | s.handle_message( 63 | Message( 64 | frm="S1", 65 | to="S2", 66 | cmd=AppendEntries( 67 | term=1, 68 | leaderId="S1", 69 | prevLogIndex=2, 70 | prevLogTerm=1, 71 | leaderCommit=0, 72 | entries=[new_entry], 73 | ), 74 | ) 75 | ) 76 | expected_response = AppendEntriesSucceeded(matchIndex=3) 77 | assert s.outbox == [Message(frm="S2", to="S1", cmd=expected_response)] 78 | 79 | 80 | def test_append_entries_with_no_entry_aka_heartbeat_at_zero(): 81 | log = InMemoryLog([]) 82 | s = Follower( 83 | name="S2", 84 | peers=["S1", "S2", "S3"], 85 | now=1, 86 | log=log, 87 | currentTerm=1, 88 | votedFor=None, 89 | ) 90 | old_timeout = s._election_timeout 91 | s.now = 2 92 | s.handle_message( 93 | Message( 94 | frm="S1", 95 | to="S2", 96 | cmd=AppendEntries( 97 | term=1, 98 | leaderId="S1", 99 | prevLogIndex=0, 100 | prevLogTerm=0, 101 | leaderCommit=0, 102 | entries=[], 103 | ), 104 | ) 105 | ) 106 | assert s.log.read() == [] 107 | expected_response = AppendEntriesSucceeded(matchIndex=0) 108 | assert s.outbox == [Message(frm="S2", to="S1", cmd=expected_response)] 109 | assert s._election_timeout > old_timeout 110 | 111 | 112 | def test_append_entries_with_no_entry_aka_heartbeat_at_nonzero(): 113 | log = InMemoryLog([Entry(term=1, cmd="foo=1"), Entry(term=1, cmd="foo=2")]) 114 | s = Follower( 115 | name="S2", 116 | peers=["S1", "S2", "S3"], 117 | now=1, 118 | log=log, 119 | currentTerm=1, 120 | votedFor=None, 121 | ) 122 | s.handle_message( 123 | Message( 124 | frm="S1", 125 | to="S2", 126 | cmd=AppendEntries( 127 | term=1, 128 | leaderId="S1", 129 | prevLogIndex=2, 130 | prevLogTerm=1, 131 | leaderCommit=0, 132 | 
entries=[], 133 | ), 134 | ) 135 | ) 136 | expected_response = AppendEntriesSucceeded(matchIndex=2) 137 | assert s.outbox == [Message(frm="S2", to="S1", cmd=expected_response)] 138 | 139 | 140 | def test_append_entries_heartbeat_in_middle_of_log_returns_matchindex_at_that_position(): 141 | log = InMemoryLog([Entry(term=1, cmd="foo=1"), Entry(term=1, cmd="foo=2"), Entry(term=1, cmd='foo=3')]) 142 | s = Follower( 143 | name="S2", 144 | peers=["S1", "S2", "S3"], 145 | now=1, 146 | log=log, 147 | currentTerm=1, 148 | votedFor=None, 149 | ) 150 | s.handle_message( 151 | Message( 152 | frm="S1", 153 | to="S2", 154 | cmd=AppendEntries( 155 | term=1, 156 | leaderId="S1", 157 | prevLogIndex=2, 158 | prevLogTerm=1, 159 | leaderCommit=0, 160 | entries=[], 161 | ), 162 | ) 163 | ) 164 | expected_response = AppendEntriesSucceeded(matchIndex=2) 165 | assert s.outbox == [Message(frm="S2", to="S1", cmd=expected_response)] 166 | 167 | 168 | def test_append_entries_failed_response(): 169 | old_entries = [Entry(term=1, cmd="first=entry"), Entry(term=2, cmd="e=2")] 170 | log = InMemoryLog(old_entries) 171 | s = Follower( 172 | name="S2", 173 | peers=["S1", "S2", "S3"], 174 | now=1, 175 | log=log, 176 | currentTerm=2, 177 | votedFor=None, 178 | ) 179 | new_entry = Entry(term=1, cmd="term=wrong") 180 | s.handle_message( 181 | Message( 182 | frm="S1", 183 | to="S2", 184 | cmd=AppendEntries( 185 | term=2, 186 | leaderId="S1", 187 | prevLogIndex=2, 188 | prevLogTerm=1, 189 | leaderCommit=0, 190 | entries=[new_entry], 191 | ), 192 | ) 193 | ) 194 | assert s.log.read() == old_entries 195 | expected_response = AppendEntriesFailed(term=2) 196 | assert s.outbox == [Message(frm="S2", to="S1", cmd=expected_response)] 197 | 198 | 199 | def test_append_entries_failed_response_to_heartbeat(): 200 | old_entries = [Entry(term=1, cmd="first=entry"), Entry(term=2, cmd="e=2")] 201 | log = InMemoryLog(old_entries) 202 | s = Follower( 203 | name="S2", 204 | peers=["S1", "S2", "S3"], 205 | now=1, 206 | 
log=log, 207 | currentTerm=2, 208 | votedFor=None, 209 | ) 210 | s.handle_message( 211 | Message( 212 | frm="S1", 213 | to="S2", 214 | cmd=AppendEntries( 215 | term=2, 216 | leaderId="S1", 217 | prevLogIndex=2, 218 | prevLogTerm=1, 219 | leaderCommit=0, 220 | entries=[], 221 | ), 222 | ) 223 | ) 224 | assert s.log.read() == old_entries 225 | expected_response = AppendEntriesFailed(term=2) 226 | assert s.outbox == [Message(frm="S2", to="S1", cmd=expected_response)] 227 | 228 | 229 | def test_clock_tick_does_nothing_by_default(): 230 | term = 2 231 | s = Follower( 232 | name="S2", 233 | peers=["S1", "S2", "S3"], 234 | now=1, 235 | log=InMemoryLog([]), 236 | currentTerm=term, 237 | votedFor=None, 238 | ) 239 | a_tiny_amount_of_time = 0.001 240 | s.clock_tick(a_tiny_amount_of_time) 241 | assert s.currentTerm == term 242 | assert s.outbox == [] 243 | 244 | 245 | def test_calls_election_if_clock_tick_past_election_timeout(): 246 | log = [Entry(2, "foo=1"), Entry(3, "foo=2")] 247 | f = Follower( 248 | name="S2", 249 | peers=["S1", "S2", "S3"], 250 | now=0, 251 | log=InMemoryLog(log), 252 | currentTerm=3, 253 | votedFor=None, 254 | ) 255 | a_tiny_amount_of_time = 0.001 256 | f.clock_tick(a_tiny_amount_of_time) 257 | assert f.outbox == [] 258 | 259 | past_timeout = 1 260 | assert MIN_ELECTION_TIMEOUT + ELECTION_TIMEOUT_JITTER < past_timeout 261 | f.clock_tick(past_timeout) 262 | 263 | assert f.currentTerm == 4 264 | assert f.votedFor == "S2" 265 | expected_messages = [ 266 | Message( 267 | frm="S2", 268 | to="S1", 269 | cmd=RequestVote(term=4, candidateId="S2", lastLogIndex=2, lastLogTerm=3), 270 | ), 271 | Message( 272 | frm="S2", 273 | to="S3", 274 | cmd=RequestVote(term=4, candidateId="S2", lastLogIndex=2, lastLogTerm=3), 275 | ), 276 | ] 277 | assert f.outbox == expected_messages 278 | 279 | a_tiny_amount_of_time = 0.001 280 | f.clock_tick(past_timeout + a_tiny_amount_of_time) 281 | assert f.outbox == expected_messages # ie no change 282 | 
-------------------------------------------------------------------------------- /tests/unit/test_leader.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from raft.server import Server, Leader, HEARTBEAT_FREQUENCY 3 | from raft.log import InMemoryLog, Entry 4 | from raft.messages import ( 5 | AppendEntries, 6 | AppendEntriesSucceeded, 7 | AppendEntriesFailed, 8 | Message, 9 | ClientSetCommand, 10 | ) 11 | 12 | def test_init(): 13 | peers = ["S1", "S2", "S3"] 14 | old_entries = [Entry(term=1, cmd="old=1"), Entry(term=2, cmd="old=2")] 15 | log = InMemoryLog(old_entries) 16 | s = Leader(name="S1", now=1, log=log, peers=peers, currentTerm=2, votedFor=None) 17 | assert s.matchIndex == {'S2': 0, 'S3': 0} 18 | assert s.nextIndex == {'S2': 3, 'S3': 3} 19 | 20 | 21 | def test_handle_client_set_updates_local_log_and_puts_AppendEntries_in_outbox(): 22 | peers = ["S1", "S2", "S3", "S4", "S5"] 23 | old_entries = [Entry(term=1, cmd="old=1"), Entry(term=2, cmd="old=2")] 24 | log = InMemoryLog(old_entries) 25 | s = Leader(name="S1", now=1, log=log, peers=peers, currentTerm=2, votedFor=None) 26 | 27 | s.handle_message(Message(frm="client.id", to="S1", cmd=ClientSetCommand(guid='gaga', cmd="foo=bar"))) 28 | expected_entry = Entry(term=2, cmd="foo=bar") 29 | assert s.log.read() == old_entries + [expected_entry] 30 | expected_appendentries = AppendEntries( 31 | term=2, 32 | leaderId="S1", 33 | prevLogIndex=2, 34 | prevLogTerm=2, 35 | leaderCommit=0, 36 | entries=[expected_entry], 37 | ) 38 | assert s.outbox == [ 39 | Message(frm="S1", to=s, cmd=expected_appendentries) for s in peers if s != "S1" 40 | ] 41 | 42 | 43 | def test_successful_appendentries_response_updates_matchIndex_last_entry_case(): 44 | peers = ["S1", "S2", "S3", "S4", "S5"] 45 | old_entries = [Entry(term=1, cmd="old=1"), Entry(term=2, cmd="old=2")] 46 | log = InMemoryLog(old_entries) 47 | s = Leader(name="S1", now=1, log=log, peers=peers, currentTerm=2, 
votedFor=None) 48 | 49 | s.matchIndex["S2"] = 1 # arbitrarily 50 | s.nextIndex["S2"] = 2 51 | s.handle_message( 52 | Message(frm="S2", to="S1", cmd=AppendEntriesSucceeded(matchIndex=2)) 53 | ) 54 | assert s.matchIndex["S2"] == 2 55 | assert s.nextIndex["S2"] == 3 56 | assert s.outbox == [] 57 | 58 | 59 | def test_successful_appendentries_cannot_take_nextIndex_past_end(): 60 | peers = ["S1", "S2", "S3", "S4", "S5"] 61 | old_entries = [Entry(term=1, cmd="old=1"), Entry(term=2, cmd="old=2")] 62 | log = InMemoryLog(old_entries) 63 | s = Leader(name="S1", now=1, log=log, peers=peers, currentTerm=2, votedFor=None) 64 | 65 | s.nextIndex["S2"] = 3 66 | s.handle_message( 67 | Message(frm="S2", to="S1", cmd=AppendEntriesSucceeded(matchIndex=2)) 68 | ) 69 | assert s.nextIndex["S2"] == 3 70 | 71 | 72 | def test_duplicate_appendentries_responses_do_not_double_increment_index_counters(): 73 | peers = ["S1", "S2", "S3", "S4", "S5"] 74 | old_entries = [ 75 | Entry(term=1, cmd="old=1"), Entry(term=2, cmd="old=2"), Entry(term=2, cmd='old=3') 76 | ] 77 | log = InMemoryLog(old_entries) 78 | s = Leader(name="S1", now=1, log=log, peers=peers, currentTerm=2, votedFor=None) 79 | s.matchIndex["S2"] = 1 # arbitrarily 80 | s.nextIndex["S2"] = 2 81 | s.handle_message( 82 | Message(frm="S2", to="S1", cmd=AppendEntriesSucceeded(matchIndex=2)) 83 | ) 84 | assert s.matchIndex["S2"] == 2 85 | assert s.nextIndex["S2"] == 3 86 | s.handle_message( 87 | Message(frm="S2", to="S1", cmd=AppendEntriesSucceeded(matchIndex=2)) 88 | ) 89 | assert s.matchIndex["S2"] == 2 90 | assert s.nextIndex["S2"] == 3 91 | 92 | 93 | def test_failed_appendentries_decrements_nextindex_and_adds_new_AppendEntries_to_outbox(): 94 | peers = ["S1", "S2", "S3", "S4", "S5"] 95 | old_entries = [Entry(term=1, cmd="old=1"), Entry(term=2, cmd="old=2")] 96 | log = InMemoryLog(old_entries) 97 | s = Leader(name="S1", now=1, log=log, peers=peers, currentTerm=2, votedFor=None) 98 | s.matchIndex["S2"] = 2 # arbitrarily 99 | 
s.nextIndex["S2"] = 2 # arbitrarily 100 | s.handle_message(Message(frm="S2", to="S1", cmd=AppendEntriesFailed(term=2))) 101 | assert s.matchIndex["S2"] == 2 # should not move 102 | assert s.nextIndex["S2"] == 1 103 | assert s.outbox == [ 104 | Message( 105 | frm="S1", 106 | to="S2", 107 | cmd=AppendEntries( 108 | term=2, 109 | leaderId="S1", 110 | prevLogIndex=0, 111 | prevLogTerm=0, 112 | leaderCommit=0, 113 | entries=[old_entries[0]], 114 | ), 115 | ) 116 | ] 117 | 118 | 119 | def test_failed_appendentries_cannot_take_nextIndex_below_one(): 120 | peers = ["S1", "S2", "S3", "S4", "S5"] 121 | old_entries = [Entry(term=1, cmd="old=1"), Entry(term=2, cmd="old=2")] 122 | log = InMemoryLog(old_entries) 123 | s = Leader(name="S1", now=1, log=log, peers=peers, currentTerm=2, votedFor=None) 124 | s.nextIndex["S2"] = 1 125 | 126 | s.handle_message(Message(frm="S2", to="S1", cmd=AppendEntriesFailed(term=2))) 127 | assert s.nextIndex["S2"] == 1 128 | 129 | 130 | @pytest.mark.xfail 131 | def test_duplicate_failed_appendentries_do_not_double_decrement_or_double_reappend(): 132 | peers = ["S1", "S2", "S3", "S4", "S5"] 133 | old_entries = [Entry(term=1, cmd="old=1"), Entry(term=2, cmd="old=2")] 134 | log = InMemoryLog(old_entries) 135 | s = Leader(name="S1", now=1, log=log, peers=peers, currentTerm=2, votedFor=None) 136 | s.nextIndex["S2"] = 3 137 | s.matchIndex["S2"] = 2 # arbitrarily 138 | 139 | s.handle_message(Message(frm="S2", to="S1", cmd=AppendEntriesFailed(term=2))) 140 | assert s.matchIndex["S2"] == 2 141 | assert s.nextIndex["S2"] == 2 142 | 143 | s.handle_message(Message(frm="S2", to="S1", cmd=AppendEntriesFailed(term=2))) 144 | assert s.matchIndex["S2"] == 2 145 | assert s.nextIndex["S2"] == 2 # do we care? 
146 | assert s.outbox == [ 147 | Message( 148 | frm="S1", 149 | to="S2", 150 | cmd=AppendEntries( 151 | term=2, 152 | leaderId="S1", 153 | prevLogIndex=0, 154 | prevLogTerm=0, 155 | leaderCommit=0, 156 | entries=[old_entries[0]], 157 | ), 158 | ) 159 | ] 160 | 161 | def test_successful_appendentries_response_adds_AppendEntries_if_matchIndex_lower_than_lastIndex(): 162 | peers = ["S1", "S2", "S3", "S4", "S5"] 163 | old_entries = [Entry(term=1, cmd="old=1"), Entry(term=2, cmd="old=2")] 164 | log = InMemoryLog(old_entries) 165 | s = Leader(name="S1", now=1, log=log, peers=peers, currentTerm=2, votedFor=None) 166 | s.matchIndex["S2"] == 0 # arbitrarily 167 | s.handle_message( 168 | Message(frm="S2", to="S1", cmd=AppendEntriesSucceeded(matchIndex=1)) 169 | ) 170 | assert s.matchIndex["S2"] == 1 171 | assert s.outbox == [ 172 | Message( 173 | frm="S1", 174 | to="S2", 175 | cmd=AppendEntries( 176 | term=2, 177 | leaderId="S1", 178 | prevLogIndex=1, 179 | prevLogTerm=1, 180 | leaderCommit=0, 181 | entries=[old_entries[1]], 182 | ), 183 | ) 184 | ] 185 | 186 | 187 | def test_clock_tick_gives_first_heartbeat(): 188 | peers = ["S1", "S2", "S3", "S4", "S5"] 189 | old_entries = [Entry(term=1, cmd="old=1"), Entry(term=2, cmd="old=2")] 190 | log = InMemoryLog(old_entries) 191 | s = Leader(name="S1", now=1, log=log, peers=peers, currentTerm=2, votedFor=None) 192 | s.clock_tick(now=2) 193 | assert 2 - 1 > HEARTBEAT_FREQUENCY 194 | 195 | expected_appendentries = AppendEntries( 196 | term=2, 197 | leaderId="S1", 198 | prevLogIndex=2, 199 | prevLogTerm=2, 200 | leaderCommit=0, 201 | entries=[], 202 | ) 203 | assert s.outbox == [ 204 | Message(frm="S1", to=s, cmd=expected_appendentries) for s in peers if s != "S1" 205 | ] 206 | 207 | 208 | def test_heartbeat_is_custom_for_each_follower_based_on_nextIndex(): 209 | peers = ["S1", "S2", "S3", "S4", "S5"] 210 | old_entries = [Entry(term=1, cmd="old=1"), Entry(term=2, cmd="old=2")] 211 | log = InMemoryLog(old_entries) 212 | s = 
Leader(name="S1", now=1, log=log, peers=peers, currentTerm=2, votedFor=None) 213 | s.nextIndex['S2'] = 1 214 | s.nextIndex['S3'] = 2 215 | s.nextIndex['S4'] = 3 216 | s.nextIndex['S5'] = 3 217 | s.clock_tick(now=2) 218 | assert 2 - 1 > HEARTBEAT_FREQUENCY 219 | 220 | assert s.outbox == [ 221 | Message( 222 | frm='S1', to='S2', cmd=AppendEntries( 223 | term=2, 224 | leaderId="S1", 225 | prevLogIndex=0, 226 | prevLogTerm=0, 227 | leaderCommit=0, 228 | entries=[], 229 | )), 230 | Message( 231 | frm='S1', to='S3', cmd=AppendEntries( 232 | term=2, 233 | leaderId="S1", 234 | prevLogIndex=1, 235 | prevLogTerm=1, 236 | leaderCommit=0, 237 | entries=[], 238 | )), 239 | Message( 240 | frm='S1', to='S4', cmd=AppendEntries( 241 | term=2, 242 | leaderId="S1", 243 | prevLogIndex=2, 244 | prevLogTerm=2, 245 | leaderCommit=0, 246 | entries=[], 247 | )), 248 | Message( 249 | frm='S1', to='S5', cmd=AppendEntries( 250 | term=2, 251 | leaderId="S1", 252 | prevLogIndex=2, 253 | prevLogTerm=2, 254 | leaderCommit=0, 255 | entries=[], 256 | )), 257 | ] 258 | 259 | def test_heartbeat_only_appears_once_per_interval(): 260 | peers = ["S1", "S2", "S3", "S4", "S5"] 261 | old_entries = [Entry(term=1, cmd="old=1"), Entry(term=2, cmd="old=2")] 262 | log = InMemoryLog(old_entries) 263 | s = Leader(name="S1", now=1, log=log, peers=peers, currentTerm=2, votedFor=None) 264 | s.clock_tick(1) 265 | s.outbox[:] = [] # clear outbox 266 | too_soon = 1 + HEARTBEAT_FREQUENCY / 2.0 267 | s.clock_tick(too_soon) 268 | assert s.outbox == [] 269 | stilL_too_soon = 1 + HEARTBEAT_FREQUENCY - 0.001 270 | s.clock_tick(stilL_too_soon) 271 | assert s.outbox == [] 272 | just_after = 1 + HEARTBEAT_FREQUENCY + 0.001 273 | s.clock_tick(just_after) 274 | assert len(s.outbox) == 4 275 | two_heartbeats_in_theory = 1 + HEARTBEAT_FREQUENCY * 2 276 | 277 | # test we track time since last heartbeat, rather than from t=0 278 | assert two_heartbeats_in_theory < (just_after + HEARTBEAT_FREQUENCY) 279 | 
s.clock_tick(two_heartbeats_in_theory) 280 | assert len(s.outbox) == 4 281 | 282 | next_one = just_after + HEARTBEAT_FREQUENCY + 0.001 283 | s.clock_tick(next_one) 284 | assert len(s.outbox) == 8 285 | 286 | 287 | def test_becoming_follower_should_reset_matchindex_and_nextIndex(): 288 | s = Leader(name="S1", now=1, log=InMemoryLog([]), peers=["S1", "S2"], currentTerm=1, votedFor=None) 289 | s.matchIndex["S2"] = 99 290 | s.nextIndex["S2"] = 99 291 | s._become_follower() 292 | assert s.matchIndex == {} 293 | assert s.nextIndex == {} 294 | 295 | 296 | def test_updates_commitIndex_on_quorum_AppendEntriesSucceeded(): 297 | peers = ["S1", "S2", "S3", "S4", "S5"] 298 | old_entries = [Entry(term=1, cmd="old=1"), Entry(term=2, cmd="old=2")] 299 | log = InMemoryLog(old_entries) 300 | s = Leader(name="S1", now=1, log=log, peers=peers, currentTerm=1, votedFor=None) 301 | assert s.commitIndex == 0 302 | s.handle_message( 303 | Message(frm="S2", to="S1", cmd=AppendEntriesSucceeded(matchIndex=1)) 304 | ) 305 | assert s.commitIndex == 0 306 | s.handle_message( 307 | Message(frm="S2", to="S1", cmd=AppendEntriesSucceeded(matchIndex=2)) 308 | ) 309 | assert s.commitIndex == 0 310 | s.handle_message( 311 | Message(frm="S3", to="S1", cmd=AppendEntriesSucceeded(matchIndex=1)) 312 | ) 313 | assert s.commitIndex == 1 314 | -------------------------------------------------------------------------------- /src/raft/server.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Dict, List, Optional 3 | from raft.log import Log, Entry 4 | from raft.messages import ( 5 | Message, 6 | AppendEntries, 7 | AppendEntriesSucceeded, 8 | AppendEntriesFailed, 9 | ClientSetCommand, 10 | ClientSetSucceeded, 11 | RequestVote, 12 | VoteGranted, 13 | VoteDenied, 14 | ) 15 | 16 | HEARTBEAT_FREQUENCY = 0.02 17 | MIN_ELECTION_TIMEOUT = 0.15 18 | ELECTION_TIMEOUT_JITTER = 0.15 19 | 20 | 21 | class Server: 22 | def __init__( 23 | self, 24 | 
name: str, 25 | peers: List[str], 26 | now: float, 27 | log: Log, 28 | currentTerm: int, 29 | votedFor: Optional[str], 30 | ): 31 | self.name = name 32 | self.peers = peers 33 | self.now = now 34 | self._last_heartbeat = 0 # type: float 35 | self._reset_election_timeout() 36 | self.outbox = [] # type: List[Message] 37 | 38 | # Raft persistent state 39 | self.log = log 40 | self.currentTerm = currentTerm 41 | self.votedFor = votedFor 42 | 43 | # Raft volatile state 44 | self.commitIndex = 0 45 | self.lastApplied = 0 46 | 47 | def __repr__(self): 48 | return f"<{self.__class__.__name__}: term={self.currentTerm}, lastLogIndex={self.log.lastLogIndex}>" 49 | 50 | def _reset_election_timeout(self) -> None: 51 | jitter = random.randint(0, int(ELECTION_TIMEOUT_JITTER * 1000)) / 1000.0 52 | self._election_timeout = self.now + MIN_ELECTION_TIMEOUT + jitter 53 | 54 | def handle_message(self, msg: Message) -> None: 55 | print(f"{self.name} handling {msg}") 56 | if hasattr(msg.cmd, "term") and msg.cmd.term > self.currentTerm: 57 | self.currentTerm = msg.cmd.term 58 | self.votedFor = None 59 | self._become_follower() 60 | self._handle_message(msg) 61 | 62 | def _handle_message(self, msg: Message) -> None: 63 | raise NotImplementedError 64 | 65 | def clock_tick(self, now: float): 66 | raise NotImplementedError 67 | 68 | def _become_follower(self) -> None: 69 | print(f"** {self.name} is becoming a Follower **") 70 | self.__class__ = Follower 71 | 72 | 73 | class Leader(Server): 74 | def __init__( 75 | self, 76 | name: str, 77 | peers: List[str], 78 | now: float, 79 | log: Log, 80 | currentTerm: int, 81 | votedFor: Optional[str], 82 | ): 83 | super().__init__(name, peers, now, log, currentTerm, votedFor) 84 | self._setup_follower_tracking_indexes() 85 | 86 | def _become_follower(self) -> None: 87 | super()._become_follower() 88 | self.matchIndex.clear() 89 | self.nextIndex.clear() 90 | 91 | def _setup_follower_tracking_indexes(self) -> None: 92 | # Raft leader volatile state 93 | 
self.nextIndex = { 94 | server_name: self.log.lastLogIndex + 1 95 | for server_name in self.peers 96 | if server_name != self.name 97 | } # type: Dict[str, int] 98 | self.matchIndex = { 99 | server_name: 0 for server_name in self.peers if server_name != self.name 100 | } # type: Dict[str, int] 101 | 102 | def clock_tick(self, now: float) -> None: 103 | self.now = now 104 | if self.now > (self._last_heartbeat + HEARTBEAT_FREQUENCY): 105 | self._last_heartbeat = self.now 106 | self.outbox.extend( 107 | Message(frm=self.name, to=s, cmd=self._heartbeat_for(s)) 108 | for s in self.peers 109 | if s != self.name 110 | ) 111 | 112 | def _handle_message(self, msg: Message) -> None: 113 | if isinstance(msg.cmd, ClientSetCommand): 114 | self._handleClientSetCommand(frm=msg.frm, cmd=msg.cmd) 115 | 116 | if isinstance(msg.cmd, AppendEntriesSucceeded): 117 | self._handleAppendEntriesSucceeded(frm=msg.frm, cmd=msg.cmd) 118 | 119 | if isinstance(msg.cmd, AppendEntriesFailed): 120 | self._handleAppendEntriesFailed(frm=msg.frm) 121 | 122 | def _handleClientSetCommand(self, frm: str, cmd: ClientSetCommand): 123 | prevLogIndex = self.log.lastLogIndex 124 | prevLogTerm = self.log.last_log_term 125 | new_entry = Entry(term=self.currentTerm, cmd=cmd.cmd) 126 | assert self.log.add_entry( 127 | entry=new_entry, 128 | prevLogIndex=prevLogIndex, 129 | prevLogTerm=prevLogTerm, 130 | leaderCommit=1, 131 | ) 132 | print(f"server added {cmd} at position {prevLogIndex + 1}") 133 | ae = AppendEntries( 134 | term=self.currentTerm, 135 | leaderId=self.name, 136 | prevLogIndex=prevLogIndex, 137 | prevLogTerm=prevLogTerm, 138 | leaderCommit=0, 139 | entries=[new_entry], 140 | ) 141 | self.outbox.extend( 142 | Message(frm=self.name, to=s, cmd=ae) for s in self.peers if s != self.name 143 | ) 144 | if False: # TODO: 145 | self.outbox.append( 146 | Message(frm=self.name, to=frm, cmd=ClientSetSucceeded(guid=cmd.guid)) 147 | ) 148 | 149 | def _handleAppendEntriesSucceeded(self, frm: str, cmd: 
AppendEntriesSucceeded): 150 | self.matchIndex[frm] = cmd.matchIndex 151 | self.nextIndex[frm] = cmd.matchIndex + 1 152 | print(self.matchIndex) 153 | if self.matchIndex[frm] < self.log.lastLogIndex: 154 | self._send_next_entry(frm) 155 | self._commit_if_possible(cmd.matchIndex) 156 | 157 | def _handleAppendEntriesFailed(self, frm: str): 158 | self.nextIndex[frm] = max(self.nextIndex[frm] - 1, 1) 159 | index_to_resend = self.nextIndex[frm] 160 | print(f"{frm} failed, resending entry at {index_to_resend}") 161 | prevLogIndex = index_to_resend - 1 162 | prevLogTerm = self.log.entry_term(prevLogIndex) 163 | self.outbox.append( 164 | Message( 165 | frm=self.name, 166 | to=frm, 167 | cmd=AppendEntries( 168 | term=self.currentTerm, 169 | leaderId=self.name, 170 | prevLogIndex=prevLogIndex, 171 | prevLogTerm=prevLogTerm, 172 | leaderCommit=0, 173 | entries=[self.log.entry_at(index_to_resend)], 174 | ), 175 | ) 176 | ) 177 | 178 | def _heartbeat_for(self, follower) -> AppendEntries: 179 | print(f"making heartbeat for {follower}") 180 | prevLogIndex = self.nextIndex[follower] - 1 181 | prevLogTerm = self.log.entry_term(prevLogIndex) 182 | return AppendEntries( 183 | term=self.currentTerm, 184 | leaderId=self.name, 185 | prevLogIndex=prevLogIndex, 186 | prevLogTerm=prevLogTerm, 187 | leaderCommit=0, 188 | entries=[], 189 | ) 190 | 191 | def _next_entry_for(self, follower) -> AppendEntries: 192 | prevLogIndex = self.nextIndex[follower] - 1 193 | prevLogTerm = self.log.entry_term(prevLogIndex) 194 | entry = self.log.entry_at(self.nextIndex[follower]) 195 | return AppendEntries( 196 | term=self.currentTerm, 197 | leaderId=self.name, 198 | prevLogIndex=prevLogIndex, 199 | prevLogTerm=prevLogTerm, 200 | leaderCommit=0, 201 | entries=[entry], 202 | ) 203 | 204 | 205 | def _commit_if_possible(self, matchIndex: int): 206 | if self.commitIndex > matchIndex: 207 | return 208 | if self._have_quorum_at(matchIndex): 209 | self.commitIndex += 1 210 | # TODOs: 211 | # 
self.log.apply_state_machine_up_to(self.commitIndex) 212 | # self._send_any_pending_client_responses() 213 | 214 | def _have_quorum_at(self, matchIndex) -> bool: 215 | quorum = len(self.peers) // 2 216 | matching_follwers = len( 217 | [f for f, ix in self.matchIndex.items() if ix >= matchIndex] 218 | ) 219 | me = 1 220 | return (matching_follwers + me) > quorum 221 | 222 | def _send_next_entry(self, follower: str): 223 | next_to_send = self.log.entry_at(self.matchIndex[follower] + 1) 224 | prevLogIndex = self.matchIndex[follower] 225 | prevLogTerm = self.log.entry_term(prevLogIndex) 226 | self.outbox.append( 227 | Message( 228 | frm=self.name, 229 | to=follower, 230 | cmd=AppendEntries( 231 | term=self.currentTerm, 232 | leaderId=self.name, 233 | prevLogIndex=prevLogIndex, 234 | prevLogTerm=prevLogTerm, 235 | leaderCommit=0, 236 | entries=[next_to_send], 237 | ), 238 | ) 239 | ) 240 | 241 | 242 | class Follower(Server): 243 | def clock_tick(self, now: float): 244 | self.now = now 245 | if self.now > self._election_timeout: 246 | print( 247 | f"election timeout! 
{self.now} was greater than {self._election_timeout}" 248 | ) 249 | self._reset_election_timeout() 250 | self._become_candidate() 251 | 252 | def _handle_message(self, msg: Message) -> None: 253 | if isinstance(msg.cmd, AppendEntries): 254 | kvcmd = msg.cmd.entries[0].cmd if msg.cmd.entries else "HeArtBeAt" 255 | self._handle_AppendEntries(frm=msg.frm, cmd=msg.cmd) 256 | 257 | if isinstance(msg.cmd, RequestVote): 258 | self._handle_RequestVote(frm=msg.frm, cmd=msg.cmd) 259 | 260 | def _handle_RequestVote(self, frm: str, cmd: RequestVote) -> None: 261 | assert frm == cmd.candidateId 262 | if self._should_grant_vote(cmd): 263 | self.outbox.append(Message(frm=self.name, to=frm, cmd=VoteGranted())) 264 | else: 265 | self.outbox.append( 266 | Message(frm=self.name, to=frm, cmd=VoteDenied(term=self.currentTerm),) 267 | ) 268 | 269 | def _should_grant_vote(self, cmd: RequestVote) -> bool: 270 | if cmd.term < self.currentTerm: 271 | return False 272 | if cmd.lastLogTerm < self.log.last_log_term: 273 | return False 274 | if cmd.lastLogIndex < self.log.lastLogIndex: 275 | return False 276 | if self.votedFor and self.votedFor != cmd.candidateId: 277 | return False 278 | return True 279 | 280 | def _handle_AppendEntries(self, frm: str, cmd: AppendEntries) -> None: 281 | # TODO: this log.check_log() is rough. 282 | # lets convert to appendentries taking a list. 
283 | if not self.log.check_log(cmd.prevLogIndex, cmd.prevLogTerm): 284 | self.outbox.append( 285 | Message( 286 | frm=self.name, 287 | to=frm, 288 | cmd=AppendEntriesFailed(term=self.currentTerm), 289 | ) 290 | ) 291 | return 292 | self._reset_election_timeout() 293 | matchIndex = cmd.prevLogIndex 294 | for entry in cmd.entries: 295 | assert self.log.add_entry( 296 | entry, cmd.prevLogIndex, cmd.prevLogTerm, cmd.leaderCommit 297 | ) 298 | matchIndex += 1 299 | self.outbox.append( 300 | Message( 301 | frm=self.name, 302 | to=frm, 303 | cmd=AppendEntriesSucceeded(matchIndex=matchIndex), 304 | ) 305 | ) 306 | 307 | def _become_candidate(self) -> None: 308 | print(f"** {self.name} is becoming Candidate **") 309 | self.__class__ = Candidate 310 | self._call_election() # pylint: disable=no-member 311 | 312 | 313 | class Candidate(Server): 314 | def clock_tick(self, now: float) -> None: 315 | self.now = now 316 | 317 | def _handle_message(self, msg: Message) -> None: 318 | if isinstance(msg.cmd, VoteGranted): 319 | self._votes.add(msg.frm) 320 | if len(self._votes) > len(self.peers) / 2: 321 | self._become_leader() 322 | 323 | def _call_election(self): 324 | self.currentTerm += 1 325 | self.votedFor = self.name 326 | self._votes = set([self.votedFor]) 327 | self.outbox.extend( 328 | Message( 329 | frm=self.name, 330 | to=p, 331 | cmd=RequestVote( 332 | term=self.currentTerm, 333 | candidateId=self.name, 334 | lastLogIndex=self.log.lastLogIndex, 335 | lastLogTerm=self.log.last_log_term, 336 | ), 337 | ) 338 | for p in self.peers 339 | if p != self.name 340 | ) 341 | 342 | def _become_leader(self) -> None: 343 | print(f"** {self.name} is becoming Leader **") 344 | self.__class__ = Leader 345 | self._setup_follower_tracking_indexes() 346 | --------------------------------------------------------------------------------